
stable-continual-learning's People

Contributors

imirzadeh, michalzajac-ml

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar

stable-continual-learning's Issues

Change the hyper-parameters in split-CIFAR

In the script replicate_experiment_2.sh here, the command to run the split-CIFAR experiment is

python -m stable_sgd.main --dataset cifar100 --tasks 20 --epochs-per-task 1 --lr 0.1 --gamma 0.8 --hiddens 256 --batch-size 10 --dropout 0.5 --seed 1234

which only yields an average accuracy of around 2x%.

I found in a closed issue here that the setting should be

python -m stable_sgd.main --dataset cifar100 --tasks 20 --epochs-per-task 1 --lr 0.15 --gamma 0.85 --batch-size 10 --dropout 0.1 --seed 1234

Maybe the script needs an update, @imirzadeh?
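
For context, gamma appears to set a per-task exponential decay on the learning rate, so the two settings above diverge substantially over 20 tasks. A minimal sketch of that schedule, using a hypothetical helper task_lr (my reading of the code, not a verified excerpt):

def task_lr(base_lr, gamma, task_id):
    # Hypothetical sketch: per-task exponentially decayed learning rate,
    # assuming task_id is zero-based (task 0 uses the full base rate).
    return base_lr * (gamma ** task_id)

print(task_lr(0.10, 0.80, 19))  # ~0.0014: the script's setting on the 20th task
print(task_lr(0.15, 0.85, 19))  # ~0.0068: the closed-issue setting on the 20th task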

Issue with Resnet18 Implementation

I think something is missing in the ResNet-18 implementation:

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, config={}):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.conv2 = conv3x3(planes, planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
            )
        self.IC1 = nn.Sequential(
            nn.BatchNorm2d(planes),
            nn.Dropout(p=config['dropout'])
        )
        self.IC2 = nn.Sequential(
            nn.BatchNorm2d(planes),
            nn.Dropout(p=config['dropout'])
        )

    def forward(self, x):
        out = self.conv1(x)
        out = relu(out)
        out = self.IC1(out)

        out += self.shortcut(x)
        out = relu(out)
        out = self.IC2(out)
        return out

Two attributes are defined in the class BasicBlock; they represent the standard convolution operations used in ResNet architectures:

self.conv1 = conv3x3(in_planes, planes, stride)
self.conv2 = conv3x3(planes, planes)

The problem is that the forward pass uses only the first one (i.e. self.conv1), together with the two batch-normalization layers, to compute the output. Furthermore, when the model is loaded onto the GPU, both conv1 and conv2 are moved there, even though the second is never used in the forward pass. So I think the forward pass should be:

def forward(self, x):
    out = self.conv1(x)
    out = relu(out)
    out = self.IC1(out)
    out = self.conv2(out)  # apply the second convolution, previously skipped
    out = self.IC2(out)

    out += self.shortcut(x)
    out = relu(out)
    return out
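
To double-check that conv2 really is dead weight, one can run a backward pass and inspect the gradients. A self-contained sketch, where Block is a minimal stand-in I wrote for the quoted BasicBlock (stride 1 and identity shortcut assumed; this is not the repo's class):

import torch
import torch.nn as nn
from torch.nn.functional import relu

class Block(nn.Module):
    # Minimal stand-in for the quoted BasicBlock: stride 1, identity shortcut.
    def __init__(self, planes=16, p=0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.IC1 = nn.Sequential(nn.BatchNorm2d(planes), nn.Dropout(p))
        self.IC2 = nn.Sequential(nn.BatchNorm2d(planes), nn.Dropout(p))

    def forward(self, x):
        # original (buggy) forward: self.conv2 is never called
        out = self.IC1(relu(self.conv1(x)))
        out = out + x
        return self.IC2(relu(out))

block = Block()
block(torch.randn(2, 16, 8, 8)).sum().backward()
print(block.conv1.weight.grad is None)  # False -> conv1 was trained
print(block.conv2.weight.grad is None)  # True  -> conv2 is dead weight

If conv2.weight.grad stays None after backward(), the layer contributed nothing to the output.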

After fixing this problem, I am still not able to reproduce the results of the paper. On CIFAR-100, using the hyperparameters provided in this closed issue:

--dataset cifar100 --tasks 20 --epochs-per-task 1 --lr 0.15 --gamma 0.85 --batch-size 10 --dropout 0.1 --seed 1234

average accuracy = 51.34, forgetting = 0.112

If I instead use the hyperparameters provided in replicate_experiment_2.sh:

--dataset cifar100 --tasks 20 --epochs-per-task 1 --lr 0.1 --gamma 0.8 --hiddens 256 --batch-size 10 --dropout 0.5 --seed 1234

average accuracy = 44.55, forgetting = 0.054

These results differ substantially from those reported in the paper:

average accuracy = 59.9, forgetting = 0.08
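
For reference, here is a hedged sketch of how these two metrics are usually computed in the continual-learning literature (average accuracy after the final task, and forgetting as the gap between each old task's best and final accuracy); the repo's exact implementation may differ:

import numpy as np

def metrics(acc):
    # acc[i, j] = accuracy on task j after finishing training on task i
    T = acc.shape[0]
    avg_acc = acc[-1].mean()  # final accuracy, averaged over all tasks
    # forgetting: best past accuracy minus final accuracy, over all but the last task
    forgetting = (acc[:T - 1, :T - 1].max(axis=0) - acc[-1, :T - 1]).mean()
    return avg_acc, forgetting

acc = np.array([[0.90, 0.00, 0.00],
                [0.70, 0.88, 0.00],
                [0.65, 0.80, 0.92]])
print(metrics(acc))  # approximately (0.79, 0.165)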

Could you provide the correct hyperparameters?

Unable to reproduce results in experiment 2

Hi,

Thanks for sharing the code. However, I'm not able to reproduce the paper's results, particularly for experiment 2. I downloaded the code and ran bash replicate_experiment_2.sh as per the instructions. Stable SGD always returns an average accuracy less than, or approximately equal to, that of Naive SGD on all three datasets. I also tried varying hyper-parameters such as gamma, lr, and dropout, but the Stable SGD results did not come near the ones given in the paper.

Is there really a bug in replicate_experiment_2.sh, or is it my misunderstanding? I was wondering if you could provide any hints for reproducing the results.

Thanks in advance.

Cannot reproduce results for experiment 1 of the paper

Hello,

Thanks for sharing the code.
However, I've had trouble reproducing the results, starting with experiment 1.
I've set up a fresh Python 3.7 environment and installed dependencies with bash setup_and_install.sh. Then I ran bash replicate_experiment_1.sh, and here are the results I got:

  • Rotated MNIST: naive 87.0, stable 86.8 (average accuracy, averaged over 3 runs). According to the paper, stable should be > 90 (Figure 3c).
  • Permuted MNIST: naive 77.7, stable 61.7. According to the paper, stable should be > 90.

Any hints for reproducing the results would be very helpful.

Issues with Table 2 in the manuscript

Hello,

Thanks for sharing the code.
However, I've had trouble reproducing the results of the naive method on Permuted MNIST, starting with experiment 1.
Here are my results:
ACC: 82.26, forgetting: 0.18. However, as shown in Table 2 of the original manuscript, the naive method should achieve 44.4% ACC and 0.53 forgetting.
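
For clarity, here is a hedged sketch of what I understand the naive baseline to be: plain SGD trained sequentially on each task, with no replay or regularization. The names model, tasks, and train_naive are hypothetical placeholders, not the repo's API:

import torch

def train_naive(model, tasks, lr=0.1, epochs_per_task=1):
    # Sequentially fine-tune on each task; nothing protects earlier tasks,
    # which is why forgetting is expected to be high for this baseline.
    for task in tasks:  # each task is an iterable of (x, y) batches
        opt = torch.optim.SGD(model.parameters(), lr=lr)
        for _ in range(epochs_per_task):
            for x, y in task:
                opt.zero_grad()
                loss = torch.nn.functional.cross_entropy(model(x), y)
                loss.backward()
                opt.step()
    return model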
Could you share more details about the naive method in Table 2?
Thanks very much!
