
Comments (7)

Yukun-Huang commented on July 18, 2024

I have tried both the fixed-weight and non-fixed-weight versions, and the final results are as follows:

| version | kNN score |
|---|---|
| w/ fixed weights | 0.745 |
| w/o fixed weights | 0.829 |

from predict-cluster.

sukun1045 commented on July 18, 2024

Hello, sorry for the late reply. Yes, I agree, and thank you for pointing out the bug. It looks like something went wrong when we reorganized the code for this release. The score for w/o fixed weights looks correct, but the score for fixed weights should be higher than the w/o fixed weights version. Let me redo the experiment in TensorFlow and I will confirm with you. Moreover, the PyTorch implementation also shows better performance when using fixed weights. We will release the PyTorch version very soon. Thanks again for your interest.


Yukun-Huang commented on July 18, 2024

Thanks a lot for your response; I look forward to the follow-up update and the release of the PyTorch version. Actually, I have implemented a PyTorch version myself, and it achieved similar performance to the TensorFlow version. Maybe I got some details wrong.

By the way, I noticed that the skeleton data in Predict-Cluster/ucla_data is different from the original data in the NW-UCLA dataset. This is because it has been pre-processed with a view-invariant transformation, right? Will this pre-processing script be released? Thank you.


sukun1045 commented on July 18, 2024

Hi, the PyTorch implementation is now in the ucla_github_pytorch folder. Yes, I suspect something is wrong in the current TensorFlow implementation, because the FW result is not the same as what we got before. But the PyTorch version looks fine. You can check it out and see whether it has any bugs; FW should be able to reach at least 80% accuracy. For data pre-processing, we apply a view-invariant transform to the original data. We may add the pre-processing script later.


Yutasq commented on July 18, 2024

Hi @sukun1045 ,

Thanks a lot for your code. In the released PyTorch version, it seems the decoder weights are still being trained under the FW setting. When I print out the trainable parameters, all 30 tensor variables defined in your framework have `requires_grad=True`. My code to check the trainable parameters is:

```python
for p in model.parameters():
    print(p.requires_grad)
```

and it returns 30 `True`s. Also, when I count the trainable parameters, the FW and non-FW settings return the same number, 57532476, which means exactly that both encoder and decoder are trainable. My code to count the parameters is:

```python
import numpy as np

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum(np.prod(p.size()) for p in model_parameters)
```
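As a self-contained sanity check of the counting logic, `Tensor.numel()` avoids the numpy dependency. The model below is purely illustrative (a single GRU layer, not the actual PCNet):

```python
import torch.nn as nn

# Illustrative model, not the actual PCNet: count trainable parameters.
model = nn.GRU(input_size=75, hidden_size=128, batch_first=True)

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(n_trainable, n_total)  # equal here, since nothing is frozen yet
```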

I think one plausible way to disable the decoder is to replace lines 104-108 in PCNet.py with the following:

```python
if self.fix_weight:
    # Fix the decoder weights; self.decoder.gru.all_weights is a list
    # with one entry per layer, each a list of parameter tensors.
    for p in self.decoder.gru.all_weights[0]:
        p.requires_grad = False
```

With this change, the number of trainable parameters under FW drops to 44568636, which seems to work. And if that is the case, the FW performance is also worse than non-FW in PyTorch. So perhaps training the decoder gives better performance than disabling it? Please correct me if my understanding is wrong; looking forward to your response. Thank you.
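A simpler alternative is to freeze every decoder parameter through `decoder.parameters()` rather than indexing `gru.all_weights`. A minimal sketch, assuming a standard GRU encoder-decoder; the `Seq2Seq` class and sizes here are illustrative, not the actual PCNet definitions:

```python
import torch.nn as nn

class Seq2Seq(nn.Module):
    """Hypothetical minimal encoder-decoder; dimensions are illustrative."""
    def __init__(self, feat=75, hidden=128):
        super().__init__()
        self.encoder = nn.GRU(feat, hidden, batch_first=True)
        self.decoder = nn.GRU(feat, hidden, batch_first=True)

    def freeze_decoder(self):
        # Freezing every registered decoder parameter (weights and biases,
        # all layers) does not depend on the GRU's internal all_weights layout.
        for p in self.decoder.parameters():
            p.requires_grad = False

model = Seq2Seq()
model.freeze_decoder()
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(trainable, total)  # trainable < total: only the encoder remains trainable
```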


sukun1045 commented on July 18, 2024

Hi @Yutasq ,
Thank you so much for your analysis and for pointing out the problem! I have verified it, and you are correct. I guess we made a mistake by setting `gru.requires_grad` to False instead of the `requires_grad` of the tensors in `gru.all_weights`. It was also misleading that when we set the `fix_weight` boolean to False, the results always looked worse than when it was True; this might have been due to the random initialization differing between training runs. But now it is clear that if the decoder GRU weights are changing, this is definitely a mistake.

Sorry for making an error like this. The idea of this paper is to learn a better hidden-state representation for spatial-temporal series data in an unsupervised way. The intuition behind fixing the decoder weights is that even with a randomly initialized encoder (no training at all), the final encoder hidden state already reaches quite high accuracy. And since only the last encoder hidden state is used to evaluate accuracy, focusing on improving only the encoder seemed to be the right direction. While there is an issue in the Fixed Weight implementation, I still want to mention that the decoder receives zero inputs (it is non-autoregressive) even in the non-FW case, which can already be seen as a weaker way of training the decoder.
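The zero-input, non-autoregressive decoding described above can be sketched as follows; this is a minimal illustration with made-up dimensions, not the actual PCNet code:

```python
import torch
import torch.nn as nn

# The decoder receives all-zero inputs at every step, so the reconstruction
# depends only on the encoder's final hidden state h.
feat, hidden, steps = 75, 128, 10
encoder = nn.GRU(feat, hidden, batch_first=True)
decoder = nn.GRU(feat, hidden, batch_first=True)
out_proj = nn.Linear(hidden, feat)

x = torch.randn(4, steps, feat)   # a batch of 4 skeleton sequences
_, h = encoder(x)                 # h: final encoder hidden state
zeros = torch.zeros_like(x)       # decoder sees no real inputs
dec_out, _ = decoder(zeros, h)    # conditioned only on h
recon = out_proj(dec_out)         # predict the sequence back
print(recon.shape)                # torch.Size([4, 10, 75])
```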

Thanks again and feel free to leave any question or comment!


Yutasq commented on July 18, 2024

Hi @sukun1045 ,

Thanks a lot for your reply and clarification. I also agree that weakening the decoder by feeding it zeros indeed helps the encoder learn better, along with the many other good design choices that make it work. I learned a lot from your paper and your code. Many thanks for your great work!

