frepo's People

Contributors

yongchaoz

frepo's Issues

ImageNet-1K

Could you please provide the hyperparameters used for training the model on ImageNet synthetic data (e.g., learning rate, number of epochs, flip augmentation, ZCA preprocessing)?

Many thanks for considering my request.

Can I apply FRePo to a regression task?

Dear YongChao,

Can I apply FRePo to multivariable time series forecasting tasks?

I tried and failed: I get a huge meta-loss, on the order of 10^10 (I keep the feature dimension larger than the number of samples, so that is not the cause).

The shape of my feature map here is (batch_size, num_nodes, seq_len, dim), so feat_tar has shape (1024x170x1, 512) and feat_syn has shape (10x170x1, 512), which is much larger than in image classification. Do you think this is the reason?
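For reference, this is my understanding of the meta-loss I am computing (a minimal numpy sketch of kernel ridge regression on the feature maps; the shapes are scaled down from mine, and the regularizer is a placeholder, not the repo's exact value):

import numpy as np

# Toy shapes; mine are feat_tar (1024*170*1, 512) and feat_syn (10*170*1, 512).
n_tar, n_syn, d = 2048, 100, 512
feat_tar = np.random.randn(n_tar, d)   # features of real data
feat_syn = np.random.randn(n_syn, d)   # features of synthetic data
y_tar = np.random.randn(n_tar, 1)      # real regression targets
y_syn = np.random.randn(n_syn, 1)      # learned synthetic targets

k_ss = feat_syn @ feat_syn.T                 # (n_syn, n_syn) Gram matrix
k_ts = feat_tar @ feat_syn.T                 # (n_tar, n_syn) cross kernel
reg = 1e-6 * np.trace(k_ss) / n_syn          # placeholder regularization
pred = k_ts @ np.linalg.solve(k_ss + reg * np.eye(n_syn), y_syn)
meta_loss = np.mean((pred - y_tar) ** 2)     # this is the value that explodes for me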

Looking forward to your reply!

Memory usage (PyTorch)

Hello Yongchao,

When running the PyTorch script (using the commands from the README, only with distill_torch), the process often gets killed because it runs out of memory, both on my local machine (after roughly 50K iterations) and on the cluster (after 200K iterations). Could there be some kind of memory leak? When I restart the process, it continues from the checkpoint without any issues.
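For what it is worth, I have been watching resident memory with a small psutil helper (not part of the repo), called every few hundred iterations, to confirm the growth:

import os
import psutil

_process = psutil.Process(os.getpid())

def log_memory(step):
    # Print the resident set size so memory growth is visible in the training log.
    rss_gb = _process.memory_info().rss / 1024 ** 3
    print(f'step {step}: resident memory {rss_gb:.2f} GiB')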

Thank you in advance for your help!
Best, Jovan

Reproducibility Issues

Not sure what I am doing wrong, but I am not able to get anywhere near the results reported in the paper when using the Google Drive checkpoint images in PyTorch. Here is what I have done:

  • Trained LeNet on mnist_ipc1_llTrue and mnist_ipc1_llFalse
  • Adam optimizer with multiple learning rates in the range 0.003 to 0.00003
  • Either no weight decay or a weight decay of 0.0001
  • Trained for anywhere from 50 to 10,000 epochs

With all these variations, I have only achieved 8-20% accuracy on the real MNIST test set. The paper does not mention any need to preprocess the test set. The training images look good to me and match those in the paper (although their colors are inverted relative to the paper).

The labels seem off, though. The learned labels look okay, but when learn_label is false, the labels are the identity matrix minus 0.1, i.e. [0.9, -0.1, -0.1, ...] for class 0. I replaced these labels with the identity matrix and got around 70% accuracy on MNIST (with very large variance depending on weight initialization). I was hoping for somewhere around 85-90% accuracy here, but maybe that is because LeNet is a very different architecture from the conv the data was distilled on.
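Concretely, this is the replacement I did (a small numpy sketch; num_classes=10 for MNIST):

import numpy as np

num_classes = 10
# What the learn_label=False checkpoint stores: identity minus 1/num_classes.
y_proto = np.eye(num_classes) - 1.0 / num_classes
# The plain one-hot labels I swapped in instead.
y_hard = np.eye(num_classes)[y_proto.argmax(axis=1)]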

I also tried the exact conv-bn described in the paper; however, with the checkpoint labels (both learn_label true and false) it gets 10-20% accuracy. With the identity-matrix trick I used for LeNet, I only get around 40% accuracy, not 90+%.

The hyperparameters for these experiments are not well documented in the paper, so I could be using completely wrong settings. Are there any other ideas for what I could be doing wrong when reproducing on MNIST?

Suggestions for measuring FLOPs

I want to measure the FLOPs required for the distillation process. Do you have any suggestions on how to do that in your JAX implementation?
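For reference, the closest tool I have found so far is XLA's compiled cost analysis via JAX's AOT API (a sketch; I am not sure it accounts for everything in the distillation loop, and the return type varies across JAX versions):

import jax
import jax.numpy as jnp

def step(x):
    # Stand-in for one distillation step.
    return jnp.sum(x @ x.T)

x = jnp.ones((128, 128))
compiled = jax.jit(step).lower(x).compile()
cost = compiled.cost_analysis()
# Older JAX versions return a one-element list of dicts, newer ones a dict.
stats = cost[0] if isinstance(cost, list) else cost
print(stats.get('flops'))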

Would you like to provide scripts or a hyperparameter table for other settings?

By directly using your script for 1 IPC on CIFAR100, I got 27.7% accuracy (+0.5 compared to the paper). However, I only got 39.9% (-1.4) on 10 IPC of CIFAR100 when I directly reused the 1 IPC hyperparameters. For 50 IPC of CIFAR100 I got NaN; below is the log:

INFO:absl:Saved checkpoint at train_log/cifar100/step500K_num5000/conv_w128_d3_batch_llTrue/state10_reset100/saved_ckpt/checkpoint_1
INFO:absl:[500] monitor/learning_rate=0.0002999993448611349, monitor/steps_per_second=1.837596, proto/x_proto_norm=nan, proto/y_proto_margin_max=nan, proto/y_proto_margin_mean=nan, train/grad_norm_x=nan, train/grad_norm_y=nan, train/kernel_loss=nan, train/label_loss=nan, train/top5accuracy=0.049998436123132706, train/total_loss=nan
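For reference, one mitigation I may try is clipping the meta-gradient before the Adam update (a sketch with optax; the threshold 1.0 is a guess, not a value from the repo):

import optax

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # guard against exploding meta-gradients
    optax.adam(learning_rate=3e-4),
)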

Achieving random test accuracy when evaluating pre-distilled data with PyTorch models

Could you provide an evaluation script for PyTorch? I am trying to evaluate the distilled data in PyTorch and cannot figure out how to get good accuracy. Your pytorch branch mainly concerns distilling in PyTorch, whereas I am only trying to use the pre-distilled images. I am not sure whether I am using the correct methods, as there are no comments. I am calling evaluate_synset with a conv-bn model, scraping together whatever information on "args" I can find throughout the codebase.

Here is my attempt at trying to get evaluations close to the paper:

Define the conv model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class conv(nn.Module):
    def __init__(self):
        super(conv, self).__init__()
        self.conv1 = nn.Conv2d(3, 128, 3)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 256, 3)
        self.bn2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(256, 1024, 3)
        self.fc = nn.Linear(4096, 10)  # 1024 channels * 2 * 2 spatial for 32x32 inputs

    def forward(self, x):
        x = F.avg_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = self.bn1(x)
        x = F.avg_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = self.bn2(x)
        x = F.avg_pool2d(F.relu(self.conv3(x)), (2, 2))
        x = x.view(-1, int(x.nelement() / x.shape[0]))  # flatten to (batch, 4096)
        x = self.fc(x)
        return x
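As a sanity check (not from the repo), the flattened feature size works out to 1024 * 2 * 2 = 4096 for 32x32 CIFAR-10 inputs:

import torch

model = conv()
x = torch.randn(2, 3, 32, 32)  # dummy CIFAR-10 batch
print(model(x).shape)          # torch.Size([2, 10]), so the fc input is indeed 4096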

Define the distilled dataset; I actually only use its images and labels for the method call:

import numpy as np
from torch.utils.data import Dataset
from flax.training import checkpoints  # flax is needed to read the JAX checkpoint

class DistilledDataset(Dataset):
    def __init__(self, data_path, transform=None, target_transform=None):
        state = checkpoints.restore_checkpoint(data_path, None)
        x_proto = state['params']['x_proto']  # also tried the ema_average variant
        y_proto = state['params']['y_proto']
        self.labels = y_proto  # also tried np.eye(10)
        # Transpose to channels-first, since the flax checkpoint puts channels last.
        self.images = np.transpose(x_proto, (0, 3, 1, 2))
        self.transform = transform
        self.target_transform = target_transform
        print(np.shape(self.images), np.shape(self.labels))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        return image, label

For the test set, I have tried several approaches and none of them worked.

  1. Define the test data and normalize it (the standard off-the-shelf way):

import torch
import torchvision
import torchvision.transforms as transforms

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
])
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            transform=test_transform, download=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=10,
                                          shuffle=True)
  2. Using your method with one_hot - 1 / num_classes:

from torch.utils.data import TensorDataset, DataLoader

channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset('CIFAR10', './data')
x_test = np.transpose(torch.tensor(dst_test.data), (0, 3, 1, 2))  # to channels-first
y_test = torch.tensor(dst_test.targets)
y_test = F.one_hot(y_test, num_classes=num_classes) - 1 / num_classes  # centered one-hot
dst_test = TensorDataset(x_test, y_test)
testloader = DataLoader(dst_test, batch_size=256, shuffle=False, num_workers=0)
  3. Using your method but without the one_hot transform, i.e. just get_dataset.

I was wondering whether I need to apply any ZCA step to my test set; however, after looking at your code in the torch branch, I could not find it being applied to the test set (and if it is, it must happen inside the get_dataset method I used).
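For completeness, this is how I would fit ZCA on the training images and apply the same transform to the test set (a sketch with kornia, mirroring the zca_trans in my args below; I am not sure it matches the repo's preprocessing):

import torch
import kornia

# Placeholders for the real CIFAR-10 tensors of shape (N, 3, 32, 32) in [0, 1].
x_train = torch.rand(512, 3, 32, 32)
x_test = torch.rand(128, 3, 32, 32)

zca = kornia.enhance.ZCAWhitening(eps=0.1, compute_inv=True)
zca.fit(x_train)               # fit the whitening transform on training data only
x_test_whitened = zca(x_test)  # apply the identical transform to the test set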

These are all the args I could scrape together from the code, from experimentation, and from other distillation repos. The hyperparameters listed in the repo are only distillation hyperparameters, yet the evaluation parameters seem just as important; without them, the distilled data cannot be applied elsewhere.

# Needs: import argparse; import kornia; ParamDiffAug from the repo's torch utilities.
args = argparse.Namespace(
    dataset_name='cifar10', data_path=None, zca_path=None, ckpt_dir=None, ckpt_name='',
    res_dir=None, random_seed=0, eval_batch_size=1000, arch='conv', width=128, depth=3,
    normalization='identity', pooling='avg', use_chunk=False, chunk_size=2000,
    optimizer='adam', learning_rate=0.0003, weight_decay=0.0003, loss='mse',
    temperature=1.0, num_eval=10, device='cuda', lr_net=0.01, epoch_eval_train='1000',
    batch_train=256, dataset='cifar10', dsa=True,
    dsa_strategy='color_crop_cutout_flip_scale_rotate', dsa_param=ParamDiffAug(),
    dc_aug_param=None, zca_trans=kornia.enhance.ZCAWhitening(eps=0.1, compute_inv=True))

Final call that should train my model on the distilled data and, hopefully, reach a test accuracy near the one reported in the paper:

evaluate_synset(1, model, torch.tensor(distilled_dataset.images), torch.tensor(distilled_dataset.labels), test_loader, args)

Even when varying these hyperparameters, I have not been able to achieve more than 10%, i.e. random test accuracy, on CIFAR-10.

Results of a few different runs with the different test sets

Train epoch 0, acc = 0.0, loss = 0.3656125068664551!
/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py:4298: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py:4236: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.
  warnings.warn(
Train epoch 100, acc = 0.3, loss = 92.42935180664062!
Train epoch 200, acc = 0.1, loss = 0.3106668293476105!
Train epoch 300, acc = 0.2, loss = 0.2642415165901184!
Train epoch 400, acc = 0.4, loss = 0.23028026521205902!
Train epoch 500, acc = 0.4, loss = 0.17225635051727295!
Train epoch 600, acc = 0.2, loss = 0.2151159942150116!
Train epoch 700, acc = 0.4, loss = 0.11859626322984695!
Train epoch 800, acc = 0.4, loss = 0.1117740198969841!
Train epoch 900, acc = 0.4, loss = 0.14984123408794403!
Train epoch 1000, acc = 0.4, loss = 0.11243434995412827!
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([10])) that is different to the input size (torch.Size([10, 10])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
[2023-04-29 02:16:55] Evaluate_01: epoch = 1000 train time = 4 s train loss = 0.112434 train acc = 0.4000, test acc = 0.1289
(conv(
   (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1))
   (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
   (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (conv3): Conv2d(256, 1024, kernel_size=(3, 3), stride=(1, 1))
   (fc): Linear(in_features=4096, out_features=10, bias=True)
 ),
 0.4,
 0.1289)

Running for 5,000 epochs

Train epoch 0, acc = 0.1, loss = 0.2970341145992279!
Train epoch 500, acc = 1.0, loss = 0.009703082963824272!
Train epoch 1000, acc = 0.9, loss = 0.08426003903150558!
Train epoch 1500, acc = 1.0, loss = 0.004442617297172546!
Train epoch 2000, acc = 1.0, loss = 0.0022026468068361282!
Train epoch 2500, acc = 1.0, loss = 0.004193977452814579!
Train epoch 3000, acc = 1.0, loss = 0.005132946185767651!
Train epoch 3500, acc = 1.0, loss = 0.0019280301639810205!
Train epoch 4000, acc = 1.0, loss = 0.0062325275503098965!
Train epoch 4500, acc = 1.0, loss = 0.006385618820786476!
Train epoch 5000, acc = 1.0, loss = 0.006081026047468185!
[2023-04-29 02:19:11] Evaluate_01: epoch = 5000 train time = 23 s train loss = 0.006081 train acc = 1.0000, test acc = 0.1041
(conv(
   (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1))
   (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
   (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (conv3): Conv2d(256, 1024, kernel_size=(3, 3), stride=(1, 1))
   (fc): Linear(in_features=4096, out_features=10, bias=True)
 ),
 1.0,
 0.1041)

And running this for 10,000 epochs

Train epoch 0, acc = 0.1, loss = 0.34674468636512756!
Train epoch 1000, acc = 0.9, loss = 0.04066995531320572!
Train epoch 2000, acc = 1.0, loss = 0.022518476471304893!
Train epoch 3000, acc = 1.0, loss = 0.004260644316673279!
Train epoch 4000, acc = 1.0, loss = 0.009586067870259285!
Train epoch 5000, acc = 1.0, loss = 0.003190208226442337!
Train epoch 6000, acc = 1.0, loss = 0.0058468966744840145!
Train epoch 7000, acc = 1.0, loss = 0.0008022473775781691!
Train epoch 8000, acc = 1.0, loss = 0.000369836954632774!
Train epoch 9000, acc = 1.0, loss = 0.00015043983876239508!
Train epoch 10000, acc = 1.0, loss = 0.0006692138849757612!
[2023-04-29 02:18:30] Evaluate_01: epoch = 10000 train time = 46 s train loss = 0.000669 train acc = 1.0000, test acc = 0.1035
(conv(
   (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1))
   (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
   (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (conv3): Conv2d(256, 1024, kernel_size=(3, 3), stride=(1, 1))
   (fc): Linear(in_features=4096, out_features=10, bias=True)
 ),
 1.0,
 0.1035)

Notice the variability in training accuracy between the first run and the second and third runs.

A problem with the environment

I followed the environment setup instructions in your README.md, but something seems wrong; could you help me fix it? I want to evaluate the checkpoints you provide, but I am unfamiliar with JAX.
[screenshot of the error]

I successfully created the environment from environment.yaml using conda, and set the environment variables as follows:

# Configure Environment Variable (Change to your own path)

export LD_LIBRARY_PATH=/home/name/anaconda3/envs/frepo/lib:$LD_LIBRARY_PATH
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/home/name/anaconda3/envs/frepo
export PATH=/home/name/anaconda3/envs/frepo:$PATH
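For what it is worth, after setting these variables I verify that JAX actually sees the GPU before running anything else (the exact device repr varies across JAX versions):

import jax

print(jax.devices())            # expect a CUDA/GPU device here, not CpuDevice
print(jax.numpy.ones(3) * 2.0)  # trivial op to force backend initialization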

Attribute error

Dear author,
I followed your README instructions to build the conda env and run the script, but I encounter an attribute error when distilling images for cifar10.
My script is:
export LD_LIBRARY_PATH=/home/name/anaconda3/envs/frepo/lib:$LD_LIBRARY_PATH
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/home/name/anaconda3/envs/frepo
export PATH=/home/name/anaconda3/envs/frepo:$PATH

path="--dataset_name=cifar10 --train_log=train_log --train_img=train_img --zca_path=data/zca --data_path=~/tensorflow_datasets --save_image=True"
exp="--learn_label=True --random_seed=0"
arch="--arch=conv --width=128 --depth=3 --normalization=batch"
hyper="--max_online_updates=100 --num_nn_state=10 --num_train_steps=500000"
ckpt="--ckpt_dir=train_log/cifar100/step500K_num100/conv_w128_d3_batch_llTrue/state10_reset100 --ckpt_name=best_ckpt --res_dir=dd/cifar100 --num_eval=5"
python -m script.distill $path $exp $arch $hyper --num_prototypes_per_class=1

But I encountered the following error:
[screenshot of the error]
Could you please help me figure out how to fix it?
Many thanks!

Question about lb_margin_th

Dear authors,

Thanks for your excellent work!

I have a question about the function lb_margin_th. In the line linked below, if I understand correctly, val[0] and val[1] are meant to contain the top-2 logit values. If that is the case, I think the line should be:

margin = jnp.minimum(val[..., 0] - val[..., 1], 1 / dim)

because the top_k function operates on the last axis by default.

Or did I get something wrong? Looking forward to your reply, and thanks in advance!

https://github.com/yongchao97/FRePo/blob/43e028a5839a5de367701b9e9544a08ffccb3166/lib/datadistillation/frepo.py#L156
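For illustration, a small snippet showing that top_k operates on the last axis, so val[..., 0] - val[..., 1] gives per-example margins while val[0] - val[1] would subtract whole rows (toy logits, not from the repo):

import jax.numpy as jnp
from jax import lax

logits = jnp.array([[3.0, 1.0, 0.5],
                    [0.2, 2.0, 1.5]])  # (batch=2, dim=3)
val, idx = lax.top_k(logits, 2)        # val has shape (2, 2): top-2 per row
print(val[..., 0] - val[..., 1])       # [2.0, 0.5], the intended per-example margins
print(val[0] - val[1])                 # [1.0, -0.5], row 0 minus row 1: not a margin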
