
fed_cvae's People

Contributors

ceh-2000, emiliolr

Forkers

realrui

fed_cvae's Issues

Try changing local optimizers to SGD

Currently, we use Adam as the local optimizer, but this diverges from the standard in the literature. (This is largely because Adam introduces additional hyperparameters, complicating the tuning process.) Try switching the local optimizer to SGD with no momentum and re-tuning the learning rates.
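A minimal sketch of the switch, assuming a generic PyTorch model and a single `lr` hyperparameter (the names here are placeholders, not the repo's actual variables):

```python
import torch

# Placeholder model and learning rate standing in for a user's local model.
model = torch.nn.Linear(784, 10)
lr = 0.01

# Current choice: Adam, which brings extra hyperparameters (betas, eps) to tune.
adam_opt = torch.optim.Adam(model.parameters(), lr=lr)

# Proposed choice: plain SGD with no momentum, matching the FL literature.
sgd_opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.0)
```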

Hyperparameter tuning for FedVAE

Follow the protocol detailed in blue in this doc.

  • Look up how other FL papers tune - should we re-tune for every level of alpha (heterogeneity)? Should we split out a validation dataset to tune?

Describe the findings and paste Tensorboard plots into this document. Consider writing a shell script to automate experimentation.

Clean up classifier model

We should use a more standard classifier architecture, like the following from McMahan et al. (2017): "a CNN with two 5x5 convolution layers (the first with 32 channels, the second with 64, each followed with 2x2 max pooling), a fully connected layer with 512 units and ReLU activation, and a final softmax output layer (1,663,370 total parameters)."
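A PyTorch sketch of that architecture, assuming 28x28 single-channel inputs and 10 classes; "same" padding (padding=2) on the 5x5 convolutions is what reproduces the quoted 1,663,370-parameter count.

```python
import torch.nn as nn

class McMahanCNN(nn.Module):
    """Sketch of the FedAvg MNIST CNN described by McMahan et al. (2017)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, padding=2),   # 5x5 conv, 32 channels
            nn.ReLU(),
            nn.MaxPool2d(2),                              # 28x28 -> 14x14
            nn.Conv2d(32, 64, kernel_size=5, padding=2),  # 5x5 conv, 64 channels
            nn.ReLU(),
            nn.MaxPool2d(2),                              # 14x14 -> 7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(7 * 7 * 64, 512),   # fully connected, 512 units
            nn.ReLU(),
            nn.Linear(512, num_classes),  # softmax applied via the loss function
        )

    def forward(self, x):
        return self.classifier(self.features(x))
```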

`FedVAE`: possibly re-initialize classifier each global round

Currently, the classifier is trained continuously across all global epochs. However, our pipeline schematic indicates that it should only be trained after all communication is done. Add the ability to re-initialize the server classifier's weights each round (essentially training it from scratch each round) to see if this makes a difference at all (a sketch is given after the bullet below).

  • Motivation: initial samples from the aggregated decoder may not be very high quality, which could direct the classifier's weights toward a bad part of the loss landscape.
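A minimal sketch of the per-round re-initialization, assuming a standard PyTorch classifier (the commented server calls are hypothetical stand-ins for the repo's actual methods):

```python
import torch

def maybe_reset_classifier(classifier: torch.nn.Module, should_reset: bool) -> None:
    """Re-initialize every resettable layer so the classifier trains from scratch this round."""
    if should_reset:
        for module in classifier.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

# Hypothetical global loop: reset before training on freshly decoded samples.
# for epoch in range(glob_epochs):
#     maybe_reset_classifier(server.classifier, should_reset=True)
#     server.train_classifier_on_generated_samples()
```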

Re-tune `OneFedVAE`

Wait for #41, #43, #44.

Re-tune the OneFedVAE hyperparameters, since the current best values come from earlier tuning runs.

  • Rename this algorithm (OneFedVAE is too confusing - maybe FedVAE-Ens) and refactor the code

`FedVAE`: implement augmented classifier training scheme

Currently, in server_fed_vae.py we only sample latent variables from a tight uniform distribution to obtain high-quality samples for classifier training. It may help to either:

  1. Sample a small portion of zs from a multivariate normal, or
  2. Sample from a wider uniform distribution (either for all samples or for a small portion),

to obtain more intra-class variation for classifier training. It's likely that increasing the number of samples used for classifier training will also be necessary.

A similar approach may help for the knowledge-distillation fine-tuning of the server VAE as well (a sketch of the sampling scheme is given below).
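A sketch of one possible mixed sampling scheme; the fraction and uniform width are illustrative, not tuned values.

```python
import torch

def sample_latents(num_samples: int, z_dim: int, frac_normal: float = 0.1,
                   uniform_width: float = 1.0) -> torch.Tensor:
    """Draw most z's from a tight uniform (high-quality samples) and a small
    fraction from a standard normal to add intra-class variation."""
    n_normal = int(frac_normal * num_samples)
    n_uniform = num_samples - n_normal
    z_uniform = (torch.rand(n_uniform, z_dim) * 2 - 1) * uniform_width  # U(-w, w)
    z_normal = torch.randn(n_normal, z_dim)  # or a wider uniform, per option 2
    return torch.cat([z_uniform, z_normal], dim=0)
```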

Implement the application dataset

Wait until after meeting with Jay on 8/16.

Implement the chest x-ray dataset used in this paper.

  • Check that our VAE architecture is powerful enough to capture this data - train a single centralized CVAE and check samples! If it isn't good enough, settle on an alternative architecture.

Clean up printing and logging for all algorithms

After all algorithms are implemented: clean up the hyperparameters that are printed/logged to Tensorboard in main.py. For example, for the unachievable ideal (centralized model), alpha and the number of local epochs should not be printed/logged.

Implement "Unachievable Ideal" (Centralized Model)

Train the global classifier with the training data selected according to the sampling ratio.

  • Add an if-statement in the Data class to not separate the data according to the number of users (just hand back a single dataset of all available training data according to the sampling ratio); see the sketch below.
  • Add an if-statement in main to only train the global classifier. Another file should probably exist to train this model.
  • Log this to tensorboard as central_model_sampling_ratio=x.x_number_of_epochs=xxx.
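A rough sketch of the branching logic, using hypothetical function and argument names rather than the repo's actual Data API:

```python
import torch
from torch.utils.data import Dataset, Subset

def build_train_data(train_dataset: Dataset, sample_ratio: float,
                     centralized: bool, num_users: int):
    """Subsample by the sampling ratio, then either pool or split per user."""
    n_keep = int(sample_ratio * len(train_dataset))
    idx = torch.randperm(len(train_dataset))[:n_keep].tolist()
    pooled = Subset(train_dataset, idx)

    if centralized:
        # Unachievable ideal: hand back one pooled dataset, no per-user split.
        return [pooled]

    # Otherwise split across num_users as usual (e.g. non-IID Dirichlet split).
    per_user = len(pooled) // num_users
    return [Subset(pooled, list(range(u * per_user, (u + 1) * per_user)))
            for u in range(num_users)]
```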

Add the ability to do a weighted average of user weights

In the original FL paper (McMahan et al. (2017)), weights are averaged in proportion to the number of samples each user has in its local dataset. Currently, we do an unweighted average of user weights.

  • Make a weighted average based on the number of data samples possible in the average_weights function of utils.py (see Algorithm 1 of McMahan et al. (2017) for details, and the sketch below)
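A sketch of what the weighted version could look like; the signature is illustrative and falls back to the current unweighted mean when no sample counts are passed.

```python
import copy

def average_weights(user_state_dicts, num_samples=None):
    """Weighted (FedAvg-style) or unweighted average of user state dicts."""
    if num_samples is None:
        coeffs = [1.0 / len(user_state_dicts)] * len(user_state_dicts)
    else:
        total = float(sum(num_samples))
        coeffs = [n / total for n in num_samples]  # n_k / n from Algorithm 1

    avg = copy.deepcopy(user_state_dicts[0])
    for key in avg:
        avg[key] = sum(c * sd[key] for c, sd in zip(coeffs, user_state_dicts))
    return avg
```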

Implement FashionMNIST

We need more datasets than just MNIST.

  • FashionMNIST is pulled from PyTorch (see the sketch below).
  • FashionMNIST is selectable as a command line argument.
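A minimal sketch of the dataset selection, assuming a --dataset style command line flag (the flag name and lookup table are illustrative):

```python
from torchvision import datasets, transforms

# Both datasets ship with torchvision and share the same constructor signature.
DATASETS = {"mnist": datasets.MNIST, "fashionmnist": datasets.FashionMNIST}

def load_train_set(name: str, root: str = "./data"):
    return DATASETS[name](root=root, train=True, download=True,
                          transform=transforms.ToTensor())
```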

Random seeds check

For final experiments, we'd like to show the stability of FedVAE. To do this, we should run the model several times with different weight initializations but with the same dataset split; the random seed shouldn't affect how the dataset is distributed.

  • Separate the random seeds used for dataset and model.
  • Check that changing the model seed changes performance but leaves the dataset split the same; this can be verified by inspecting the non-IID dataset visualization, which should be identical across model seeds (see the sketch below).
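A minimal sketch of the separation, assuming the dataset partitioning draws from NumPy randomness and the model initialization from torch randomness (the exact sources in the repo may differ):

```python
import numpy as np
import torch

def set_seeds(dataset_seed: int, model_seed: int) -> None:
    """Seed the data split and the model initialization independently, so
    reruns change the weights but leave the non-IID split untouched."""
    np.random.seed(dataset_seed)   # e.g. Dirichlet draws for the data split
    torch.manual_seed(model_seed)  # weight init and other torch randomness
```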

Local computation experiments

Make it possible to easily test different amounts of local computation without restarting the run. After each round of local training, communicate upwards to the server and log test results; then, rather than communicating downwards, run another local epoch and repeat (see the sketch below).
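A rough sketch of the proposed loop; the server and user method names are hypothetical stand-ins for the repo's actual API:

```python
def local_computation_sweep(server, max_local_epochs: int) -> None:
    """Log results after every additional local epoch, without broadcasting
    the aggregated model back down, so one run covers many local budgets."""
    for local_epoch in range(max_local_epochs):
        for user in server.users:
            user.train(local_epochs=1)        # one more epoch of local work
        server.aggregate(server.users)        # upward communication only
        server.evaluate(step=local_epoch)     # log test results for this budget
        # No downward broadcast: users keep training from their own weights.
```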

Tweak knowledge distillation procedure for FedVAE

Currently, FedVAE generates an equal number of samples from each user (teacher decoder).

  • It may, however, be beneficial to sample according to the number of training data points seen by each user, so that users who saw substantially less data (more likely with lower alpha) have less of an effect during fine-tuning; a sketch follows below.
  • Also, it may help to perform weighted parameter averaging for the server decoder's initialization.
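A sketch of the proportional allocation from the first bullet (argument names are illustrative):

```python
import torch

def samples_per_user(num_training_points, total_kd_samples: int):
    """Allocate knowledge-distillation samples per teacher decoder in
    proportion to how much training data each user saw (instead of an even
    split); rounding may leave the total off by a sample or two."""
    counts = torch.tensor(num_training_points, dtype=torch.float)
    probs = counts / counts.sum()
    return (probs * total_kd_samples).round().long().tolist()

# e.g. users with [1000, 100, 900] training points and 2000 KD samples get
# roughly [1000, 100, 900] samples instead of ~667 each.
```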

Make learning rate a passable parameter

We want learning rate to be passed in with command line arguments.

  • All hard-coded instances of the learning rate should be passed in via the command line (a sketch of the arguments follows this list). This includes:
    • Local (user) learning rate, which is used by all algorithms.
    • Global decoder aggregation learning rate for FedVAE.
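A sketch of the new arguments using argparse; the flag names and defaults are illustrative and should follow whatever convention main.py already uses:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--local_LR", type=float, default=0.001,
                    help="Local (user) learning rate, used by all algorithms.")
parser.add_argument("--decoder_LR", type=float, default=0.001,
                    help="Global decoder aggregation learning rate for FedVAE.")
args = parser.parse_args()
```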

Implement Sampling a Fraction of Users For Extended Communication

Standard FL algorithms classically involve only a fraction of users during each communication round.

  • Add the fraction of users to sample as a command line arg in main.py and integrate this into server.py by adding a base method that uniformly samples users at random according to this fraction (see the sketch below).
  • Every model that extends server.py should be able to use this fraction of selected users, although one-shot methods should set this to 1.0 by default since they involve all users.
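A sketch of the proposed base-server sampling method (names are illustrative):

```python
import numpy as np

def sample_users(users, user_fraction: float):
    """Uniformly sample a fraction of users for this round; one-shot methods
    should pass user_fraction=1.0 so that every user participates."""
    num_selected = max(1, int(user_fraction * len(users)))
    chosen = np.random.choice(len(users), size=num_selected, replace=False)
    return [users[i] for i in chosen]
```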

Re-tune `FedVAE` hyperparameters

Wait for issues #41, #43, #44.

Re-tune FedVAE hyperparameters for just one epoch, since the previous best hyperparameters are from the few-shot setting.

  • Maybe rename this algorithm and refactor (call this OneFedVAE?) - it's one-shot and has one decoder

Implement `FedVAE`

Implement our algorithm as shown in the pipeline below:

[Pipeline schematic screenshot: Screen Shot 2022-07-19 at 2.06.02 PM]

  • Test the Kaiming weight initialization script (a sketch of such an initializer follows below)
  • Create file structure for VAE (decoder, encoder, view, linear_predict, etc.)
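A minimal sketch of a Kaiming (He) initialization hook that such a test could exercise; this is a generic illustration, not necessarily the repo's script.

```python
import torch.nn as nn

def kaiming_init(module: nn.Module) -> None:
    """Apply Kaiming-normal init to conv/linear weights and zero the biases."""
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Usage on any nn.Module (hypothetical model): model.apply(kaiming_init)
```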

Implement the Basic One-Shot Algorithm

Implement the one-shot FL algorithm described in Guha et al. (2019), specifically the ensembling version that doesn't require auxiliary data.

  • Add a command line argument to specify the sampling method for users. Add methods in the extended server class that allow for sampling via validation, amount of data, random sampling, and all users.
  • Implement an extended server class that overwrites the create_users method in the base class to split data subsets into additional training/validation subsets (if necessary for the sampling scheme).
  • The train method should allow all users to train and then should have selected users upload their models.
  • The evaluation step for this algorithm (the "global model") should just be based on the ensembled predictions (majority vote over classes or averaged logits) of the selected/uploaded user models; see the sketch below.
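A rough sketch of that ensembled evaluation step, assuming the selected user models share an output space (the function name is illustrative):

```python
import torch

def ensemble_predict(models, x, mode: str = "majority"):
    """Predict with the ensemble of uploaded user models: either majority
    vote over per-model predicted classes or an average of their logits."""
    with torch.no_grad():
        logits = torch.stack([m(x) for m in models])  # (num_models, batch, classes)
    if mode == "average_logits":
        return logits.mean(dim=0).argmax(dim=-1)      # (batch,)
    votes = logits.argmax(dim=-1)                     # (num_models, batch)
    return votes.mode(dim=0).values                   # majority vote per example
```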

Is there no condition in the encoder?

Hi, I have a small question. As far as I understand CVAEs, the encoder requires labels as inputs. Why does the conditional encoder in the code only take images as input?
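For reference, the textbook CVAE conditioning the question describes looks roughly like the sketch below, with the label one-hot encoded and concatenated with the image before encoding; this is a generic illustration, not the repo's actual architecture.

```python
import torch
import torch.nn as nn

class ConditionalEncoder(nn.Module):
    """Textbook CVAE encoder q(z | x, y): conditions on the label y."""

    def __init__(self, x_dim: int = 784, num_classes: int = 10,
                 hidden_dim: int = 400, z_dim: int = 50):
        super().__init__()
        self.num_classes = num_classes
        self.net = nn.Sequential(nn.Linear(x_dim + num_classes, hidden_dim), nn.ReLU())
        self.mu = nn.Linear(hidden_dim, z_dim)
        self.logvar = nn.Linear(hidden_dim, z_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        y_onehot = nn.functional.one_hot(y, self.num_classes).float()
        h = self.net(torch.cat([x.flatten(1), y_onehot], dim=1))
        return self.mu(h), self.logvar(h)
```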

One-shot FedVAE

We want to show that few-shot federated learning is a better setting for our model than one-shot. Thus, create a new algorithm that implements FedVAE as a one-shot algorithm.

See this paper for reference.

  • Add a new command line argument to select onefedvae.
  • Update the run name that gets saved to Tensorboard with onefedvae parameters.
  • Extend ServerFedVAE to ServerOneFedVAE and overwrite the server classifier training to sample a new dataset from collected decoders (same as ServerFedVAE for decoder knowledge distillation) and then train the classifier on this dataset.
  • Because this is a one-shot model, make sure we ignore the added global epochs parameter and only run for one epoch (see the one-shot algorithm as a reference).
  • Update the README.md with instructions to run.
  • For verification, local epochs should be very high so that user models can converge.

`FedVAE`: PMF calculation doesn't always sum to one

When running `python3 main.py --algorithm fedvae --num_users 2 --alpha 0.1 --sample_ratio 0.25 --glob_epochs 2 --local_epochs 3 --should_log 1 --z_dim 50 --beta 1.0 --classifier_num_train_samples 1000 --classifier_epochs 5 --decoder_num_train_samples 1000 --decoder_epochs 5`, the label distributions no longer sum to one.

We need to ensure that label distributions always sum to one. This may be a floating-point precision issue (a sketch of a defensive fix is given below).
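A minimal sketch of a defensive fix, assuming the label distribution ends up as a NumPy array before it is used for sampling:

```python
import numpy as np

def normalize_pmf(pmf) -> np.ndarray:
    """Re-divide the label distribution by its own sum so floating-point
    drift cannot leave it summing to slightly more or less than one (e.g.
    np.random.choice checks that probabilities sum to 1 within tolerance)."""
    pmf = np.asarray(pmf, dtype=np.float64)
    return pmf / pmf.sum()
```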

Make the unachievable ideal appear as a line in tensorboard

For every experiment, we want the hard-coded converged ideal value to appear as a line across the top of our plot.

  • Add the hard-coded value to main.py.
  • Log a new run with the writer that just logs the same hard-coded value args.glob_epochs times (a sketch is given below).
  • End this run and start a new one.

Note: This is purely for easier visualization, not a true experiment.
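A minimal sketch of the constant-line run, with an illustrative hard-coded value and tag:

```python
from torch.utils.tensorboard import SummaryWriter

IDEAL_ACCURACY = 0.99  # hard-coded converged ideal, filled in per dataset
glob_epochs = 10       # stands in for args.glob_epochs

writer = SummaryWriter(log_dir="runs/unachievable_ideal")
for epoch in range(glob_epochs):
    writer.add_scalar("test_accuracy", IDEAL_ACCURACY, epoch)
writer.close()  # end this run before starting the real experiment's run
```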
