BestJuly commented on July 19, 2024

Hi @Mrbishuai, thank you for your interest.

That does not sound reasonable, because the settings are different.
Could you provide more detailed information, such as:

  1. Have you made sure that, in your version of the code, the input data is the same for the different settings of --neg?
  2. Have you checked the retrieval accuracies of the different (i.e. repeat and shuffle) self-supervised learning models?

Because you mentioned identical accuracies and ft_classify.py, I guess one reason might be the model-loading part. I fixed the random seed, so if you fine-tune models without successfully loading the pretrained weights, the results will be the same for every run.
For the loading part, you can change model.load_state_dict(xxx, strict=False) to model.load_state_dict(xxx, strict=True) to check which layers have been loaded successfully. It should raise an error, because for classification the fully-connected layer is newly added. You can also uncomment this to check.
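As a rough sketch of that check, the snippet below uses a tiny hypothetical `TinyNet` (not the real R3DNet) whose layer names mimic the ones in this thread. With `strict=False`, PyTorch's `load_state_dict` does not raise, but it returns the `missing_keys` and `unexpected_keys` it found, so you can confirm that only the new FC layer is missing. The second case assumes, purely for illustration, a checkpoint wrapped under a `'model'` key: no backbone weight is loaded and still nothing is raised, which together with a fixed seed would give identical results on every run.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for R3DNet: a small backbone plus a `linear` FC head."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv3d(3, 8, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(8)
        self.linear = nn.Linear(8, num_classes)


def check_loading(model: nn.Module, weights) -> tuple:
    """Load non-strictly and return the (missing, unexpected) key lists for inspection."""
    result = model.load_state_dict(weights, strict=False)
    return list(result.missing_keys), list(result.unexpected_keys)


# Simulated SSL checkpoint: same backbone names, but no classification head yet.
ssl_weights = {k: v for k, v in TinyNet().state_dict().items() if not k.startswith("linear.")}

# Case 1: keys match, so only the newly added FC layer is reported missing.
missing_ok, unexpected_ok = check_loading(TinyNet(), ssl_weights)
print(missing_ok, unexpected_ok)  # ['linear.weight', 'linear.bias'] []

# Case 2: weights wrapped as {'model': ...}; strict=False still does not raise,
# but no backbone weight is loaded and the backbone parameters show up as missing.
missing_bad, unexpected_bad = check_loading(TinyNet(), {"model": ssl_weights})
print("conv1.weight" in missing_bad, unexpected_bad)  # True ['model']
```

If the missing list contains anything besides the FC head, the pretrained weights were most likely not applied.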

Mrbishuai commented on July 19, 2024

Hi @BestJuly, thank you for your reply!
I think the pre-training part is OK. When I run train_ssl.py with different settings of --neg (repeat or shuffle), the displayed parameters are different.
I also suspect there is a problem with loading the model.

When I download your pre-trained model r3d_res_repeat_cls.pth and test it directly, the code is:
pretrained_weights = torch.load(args.ckpt, map_location='cpu')
if args.mode == 'train':
    model.load_state_dict(pretrained_weights, strict=False)
else:
    model.load_state_dict(pretrained_weights, strict=True)

When I run the ft_classify.py you provided to generate best_model and then test it, this code has to be modified to:
pretrained_weights = torch.load(args.ckpt, map_location='cpu')
if args.mode == 'train':
    model.load_state_dict(pretrained_weights, strict=False)
else:
    model.load_state_dict(pretrained_weights['model'], strict=True)

After carefully analyzing their parameters, I found that your pre-trained model r3d_res_repeat_cls.pth is of class 'collections.OrderedDict' with length 74, while my generated best_model.pt is of class dict with length 1.
I don't know whether there is a problem here.
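One way to see where such a type and length difference can come from is to save the same weights in two shapes, a flat state dict and a dict that nests it under a `'model'` key, and inspect what `torch.load` returns. This is a minimal sketch: the single `nn.Linear` layer is a hypothetical stand-in, so its flat checkpoint has 2 entries rather than the 74 reported for R3DNet, and `describe_checkpoint` is an illustrative helper, not part of the repository.

```python
import os
import tempfile

import torch
import torch.nn as nn


def describe_checkpoint(path: str):
    """Return (type name, number of top-level entries, top-level keys) of a checkpoint file."""
    obj = torch.load(path, map_location="cpu")
    return type(obj).__name__, len(obj), list(obj.keys())


model = nn.Linear(4, 2)  # hypothetical tiny model with keys 'weight' and 'bias'
with tempfile.TemporaryDirectory() as tmp:
    flat_path = os.path.join(tmp, "flat.pth")
    wrapped_path = os.path.join(tmp, "wrapped.pt")
    torch.save(model.state_dict(), flat_path)                # the state dict itself
    torch.save({"model": model.state_dict()}, wrapped_path)  # state dict nested under 'model'
    flat_info = describe_checkpoint(flat_path)
    wrapped_info = describe_checkpoint(wrapped_path)

print(flat_info)     # ('OrderedDict', 2, ['weight', 'bias'])
print(wrapped_info)  # ('dict', 1, ['model'])
```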

BestJuly commented on July 19, 2024

@Mrbishuai This might be caused by different versions of the saving code.

I am not sure whether I used the same code for the checkpoint I provided.
(In the current situation, I think it is highly likely that the provided checkpoint was generated by different code.)

There are different options for saving model parameters:

# option 1
state = {'model': model.state_dict(),}
torch.save(state, PATH)

# option 2
torch.save(model.state_dict(), PATH)

You should use model.load_state_dict(pretrained_weights['model'], strict=True) for option 1 and model.load_state_dict(pretrained_weights, strict=True) for option 2.
This also explains the different types you mentioned: dict for option 1 and collections.OrderedDict for option 2.
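If it is unclear which option produced a given checkpoint, a small helper can accept both. The sketch below assumes that option 1 always uses the key `'model'`, as in the snippet above; `extract_state_dict` and the in-memory `roundtrip` are hypothetical helpers, and `nn.Linear` again stands in for the real network.

```python
import io
from collections.abc import Mapping

import torch
import torch.nn as nn


def extract_state_dict(checkpoint: Mapping, key: str = "model") -> Mapping:
    """Return the parameter dict for checkpoints saved with option 1 or option 2."""
    # Option 1 nests the weights under `key`; option 2 is already the state dict.
    if key in checkpoint and isinstance(checkpoint[key], Mapping):
        return checkpoint[key]
    return checkpoint


def roundtrip(obj):
    """Save and reload through an in-memory buffer, mimicking torch.save/torch.load on disk."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    buffer.seek(0)
    return torch.load(buffer, map_location="cpu")


trained = nn.Linear(4, 2)
option1 = roundtrip({"model": trained.state_dict()})
option2 = roundtrip(trained.state_dict())

for checkpoint in (option1, option2):
    fresh = nn.Linear(4, 2)
    # strict=True succeeds for both formats once the wrapper is removed.
    fresh.load_state_dict(extract_state_dict(checkpoint), strict=True)
    assert torch.equal(fresh.weight, trained.weight)
```

The `isinstance` check keeps the helper from unwrapping a flat state dict that happens to contain a tensor entry named `model`.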

strict=True is used for loading fine-tuned models, and strict=False is used for loading SSL models, because the model definitions differ.

Mrbishuai commented on July 19, 2024

Hi @BestJuly.
When I change model.load_state_dict(xxx, strict=False) to model.load_state_dict(xxx, strict=True), I get this error:
Traceback (most recent call last):
  File "ft_classify.py", line 238, in <module>
    model.load_state_dict(pretrained_weights, strict=True)
  File "/home/bishuai/anaconda3/envs/IIC/lib/python3.7/site-packages/torch/nn/modules/module.py", line 839, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for R3DNet:
	Missing key(s) in state_dict: "conv1.temporal_spatial_conv.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "conv2.block1.conv1.temporal_spatial_conv.weight", "conv2.block1.bn1.weight", "conv2.block1.bn1.bias", "conv2.block1.bn1.running_mean", "conv2.block1.bn1.running_var", "conv2.block1.conv2.temporal_spatial_conv.weight", "conv2.block1.bn2.weight", "conv2.block1.bn2.bias", "conv2.block1.bn2.running_mean", "conv2.block1.bn2.running_var", "conv3.block1.downsampleconv.temporal_spatial_conv.weight", "conv3.block1.downsamplebn.weight", "conv3.block1.downsamplebn.bias", "conv3.block1.downsamplebn.running_mean", "conv3.block1.downsamplebn.running_var", "conv3.block1.conv1.temporal_spatial_conv.weight", "conv3.block1.bn1.weight", "conv3.block1.bn1.bias", "conv3.block1.bn1.running_mean", "conv3.block1.bn1.running_var", "conv3.block1.conv2.temporal_spatial_conv.weight", "conv3.block1.bn2.weight", "conv3.block1.bn2.bias", "conv3.block1.bn2.running_mean", "conv3.block1.bn2.running_var", "conv4.block1.downsampleconv.temporal_spatial_conv.weight", "conv4.block1.downsamplebn.weight", "conv4.block1.downsamplebn.bias", "conv4.block1.downsamplebn.running_mean", "conv4.block1.downsamplebn.running_var", "conv4.block1.conv1.temporal_spatial_conv.weight", "conv4.block1.bn1.weight", "conv4.block1.bn1.bias", "conv4.block1.bn1.running_mean", "conv4.block1.bn1.running_var", "conv4.block1.conv2.temporal_spatial_conv.weight", "conv4.block1.bn2.weight", "conv4.block1.bn2.bias", "conv4.block1.bn2.running_mean", "conv4.block1.bn2.running_var", "conv5.block1.downsampleconv.temporal_spatial_conv.weight", "conv5.block1.downsamplebn.weight", "conv5.block1.downsamplebn.bias", "conv5.block1.downsamplebn.running_mean", "conv5.block1.downsamplebn.running_var", "conv5.block1.conv1.temporal_spatial_conv.weight", "conv5.block1.bn1.weight", "conv5.block1.bn1.bias", "conv5.block1.bn1.running_mean", "conv5.block1.bn1.running_var", "conv5.block1.conv2.temporal_spatial_conv.weight", "conv5.block1.bn2.weight", "conv5.block1.bn2.bias", "conv5.block1.bn2.running_mean", "conv5.block1.bn2.running_var", "linear.weight", "linear.bias".

BestJuly commented on July 19, 2024

Hi @Mrbishuai,
Since the model architectures are the same, you need to check the weight names and the saving structures I mentioned.

For example, you may meet this problem when the name of each layer contains base_network;
then the error should contain both "Missing key(s)" and "Unexpected key(s)", because the names are different.
I am not sure whether your error report also contained other information indicating this.
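A rough sketch of the renaming case, assuming the SSL model keeps its backbone under an attribute literally named `base_network` (the `SSLWrapper` class and `strip_prefix` helper below are hypothetical): loading the raw keys reports both missing and unexpected keys, and stripping the prefix lets a strict load succeed.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_prefix(state_dict, prefix: str = "base_network."):
    """Remove `prefix` from every key that starts with it; other keys are kept unchanged."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
    )


class SSLWrapper(nn.Module):
    """Hypothetical SSL model that stores its backbone as `base_network`."""

    def __init__(self):
        super().__init__()
        self.base_network = nn.Linear(4, 2)


ssl_state = SSLWrapper().state_dict()  # keys: base_network.weight, base_network.bias
target = nn.Linear(4, 2)

before = target.load_state_dict(ssl_state, strict=False)
print(before.missing_keys)     # ['weight', 'bias']
print(before.unexpected_keys)  # ['base_network.weight', 'base_network.bias']

target.load_state_dict(strip_prefix(ssl_state), strict=True)  # names now match
assert torch.equal(target.weight, ssl_state["base_network.weight"])
```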

Also, I want to mention again that there are two options

# option 1
## save checkpoint
state = {'model': model.state_dict(),}
torch.save(state, PATH)
## load checkpoint
model.load_state_dict(pretrained_weights['model'], strict=True)

# option 2
## save checkpoint
torch.save(model.state_dict(), PATH)
## load checkpoint
model.load_state_dict(pretrained_weights, strict=True)

If you use only the loading part of option 2 to load models saved with option 1, errors will be raised.

The correct behavior you should see is:

  1. If you load SSL pre-trained models with strict=True, errors will be raised, but they should only mention the mismatched FC layers; you can then switch to strict=False and start fine-tuning.
  2. If you load my provided classification model, no errors should occur. (Because I do not remember which option I used when training the provided model, please try both options.)
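The first case can be turned into a guarded loader: try strict=True first, and if that fails, fall back to strict=False only when every mismatch belongs to the new FC head. This is a sketch under assumptions: the head is assumed to be named `linear` (as in the traceback above), and `TinyClassifier` and `load_ssl_backbone` are hypothetical names rather than code from the repository.

```python
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for R3DNet: a `backbone` plus a newly added `linear` head."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.backbone = nn.Linear(16, 8)
        self.linear = nn.Linear(8, num_classes)


def load_ssl_backbone(model: nn.Module, weights, head_prefix: str = "linear.") -> list:
    """Load SSL weights; return the keys left uninitialized (expected: only the FC head)."""
    try:
        model.load_state_dict(weights, strict=True)  # e.g. a full fine-tuned checkpoint
        return []
    except RuntimeError:
        result = model.load_state_dict(weights, strict=False)
        mismatched = [k for k in result.missing_keys + result.unexpected_keys
                      if not k.startswith(head_prefix)]
        if mismatched:
            raise RuntimeError(f"mismatches outside the FC head: {mismatched}")
        return list(result.missing_keys)


ssl_weights = {k: v for k, v in TinyClassifier().state_dict().items()
               if not k.startswith("linear.")}
print(load_ssl_backbone(TinyClassifier(), ssl_weights))  # ['linear.weight', 'linear.bias']

try:  # a wrapped checkpoint is rejected instead of being silently ignored
    load_ssl_backbone(TinyClassifier(), {"model": ssl_weights})
except RuntimeError as err:
    print("rejected:", err)
```

Failing loudly here is the point of the design: a silent strict=False load combined with a fixed seed is exactly what makes different runs look identical.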
