[Bug] (suggested fix) `mmrazor.models.algorithms.quantization.mm_architecture.MMArchitectureQuant.sync_qparams()` fails if there are modules present in other modes but not in forward `mode='tensor'`

elisa-aleman commented on September 28, 2024

Comments (4)

elisa-aleman commented on September 28, 2024

Added more context and suggested a fix

elisa-aleman commented on September 28, 2024

After trying to deploy the quantized model, I realized the suggested fix may be unnecessary and could cause further issues: mmdeploy/tools/deploy.py forces model.architecture.test_cfg.flip_test=False for pose estimators, so the quantized state_dict would contain extra weights and the model deployment would fail.

I then tried:

python /tools/train.py \
    ${qat_topdown_cfg} \
    --cfg-options \
         model.architecture.test_cfg.flip_test=False \
    --work-dir /path/here/

But the model still fails to sync without my patch.

elisa-aleman commented on September 28, 2024

I realized that sync_qparams() is also called during the training loop with the loss mode as the source mode, so my previous fix actually discards any progress made during training. I suggest this new fix, which does not reset the fake weight values when no match is found in the source mode. I have yet to finish deploying this model, so the fix is subject to change.

@@ -121,7 +121,7 @@ class MMArchitectureQuant(BaseAlgorithm):
                 in some subtle ways, so we need to sync them here.
         """
 
-        def traverse(module, prefix):
+        def traverse(module, prefix, mode, src_mode):
             for name, child in module._modules.items():
                 if module is None:
                     continue
@@ -129,7 +129,13 @@ class MMArchitectureQuant(BaseAlgorithm):
                 if isinstance(child, FakeQuantizeBase):
                     for name, param in child.named_parameters():
                         param_name = f'{child_name}.{name}'
-                        src_param = src_state_dict[param_name]
+                        src_param = src_state_dict.get(param_name)
+                        if '_dup' in param_name and src_param is None:
+                            param_name = '.'.join([section.split('_dup')[0] for section in param_name.split('.')])
+                            src_param = src_state_dict.get(param_name)
+                        if src_param is None:
+                            print(f"{param_name} in mode: '{mode}' but not found in source mode: '{src_mode}', skipping sync.")
+                            continue
                         if src_param.shape == param.shape:
                             param.data.copy_(src_param)
                         else:
@@ -140,20 +146,26 @@ class MMArchitectureQuant(BaseAlgorithm):
                             param.data.copy_(src_param)
                     for name, buffer in child.named_buffers():
                         buffer_name = f'{child_name}.{name}'
-                        src_buffer = src_state_dict[buffer_name]
+                        src_buffer = src_state_dict.get(buffer_name)
+                        if '_dup' in buffer_name and src_buffer is None:
+                            buffer_name = '.'.join([section.split('_dup')[0] for section in buffer_name.split('.')])
+                            src_buffer = src_state_dict.get(buffer_name)
+                        if src_buffer is None:
+                            print(f"{buffer_name} in mode: '{mode}' but not found in source mode: '{src_mode}', skipping sync.")
+                            continue
                         if src_buffer.shape == buffer.shape:
                             buffer.data.copy_(src_buffer)
                         else:
                             buffer.resize_(src_buffer.shape)
                             buffer.data.copy_(src_buffer)
                 else:
-                    traverse(child, f'{child_name}.')
+                    traverse(child, f'{child_name}.', mode, src_mode)
         src_state_dict = self.qmodels[src_mode].state_dict()
         for mode in self.forward_modes:
             if mode == src_mode:
                 continue
-            traverse(self.qmodels[mode], '')
+            traverse(self.qmodels[mode], '', mode, src_mode)

     def _get_rewriter_context_in_mmdeploy(self, deploy_cfg):
         """Get rewriter context in mmdeploy according to the deploy related
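
The name fallback used in the patch above can be tried in isolation. The following is a minimal sketch, assuming plain dicts and floats in place of real state dicts and tensors; the helper names and the example keys are hypothetical and are not part of mmrazor.

```python
def strip_dup_suffix(name: str) -> str:
    """Drop everything from '_dup' onward in each dot-separated section."""
    return '.'.join(section.split('_dup')[0] for section in name.split('.'))


def lookup_with_dup_fallback(src_state_dict, name):
    """Return (resolved_name, value); value is None when nothing matches.

    Mirrors the order used in the patch: try the exact name first, then
    retry with the '_dup' suffixes stripped, and let the caller skip the
    sync when both lookups fail.
    """
    value = src_state_dict.get(name)
    if value is None and '_dup' in name:
        name = strip_dup_suffix(name)
        value = src_state_dict.get(name)
    return name, value


# Hypothetical source-mode entries; real values would be tensors.
src_state_dict = {'backbone.conv1.weight_fake_quant.scale': 0.5}

for key in ('backbone.conv1_dup1.weight_fake_quant.scale',
            'head.flip_branch.weight_fake_quant.scale'):
    resolved, value = lookup_with_dup_fallback(src_state_dict, key)
    if value is None:
        print(f'{resolved} not found in source mode, skipping sync.')
    else:
        print(f'{key} -> {resolved}')
```

Returning None instead of raising keeps the decision to skip with the caller, which matches the patch: a missing entry leaves the existing parameter or buffer as it is.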

elisa-aleman commented on September 28, 2024

After some fixing, the solution to this issue is to refactor the model so that FX tracing works in all modes up to the wrapped methods that differ between modes. As long as the only difference in tracing comes after the .forward() method, the syncing won't fail.
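
One way to check whether a refactor has reached that state is to compare the state_dict keys of each traced mode against the source mode. This is a rough sketch with hypothetical helper names and made-up keys; with a real model, the key lists would presumably come from something like `self.qmodels[mode].state_dict().keys()`, which is an assumption based on the attributes visible in the diff.

```python
def missing_in_source(src_keys, mode_keys):
    """Return keys present in a target mode but absent from the source mode."""
    src = set(src_keys)
    return sorted(key for key in mode_keys if key not in src)


def report_mismatches(keys_by_mode, src_mode):
    """Map every non-source mode to the keys that sync would fail to find."""
    src_keys = keys_by_mode[src_mode]
    return {
        mode: missing_in_source(src_keys, keys)
        for mode, keys in keys_by_mode.items()
        if mode != src_mode
    }


# Hypothetical key sets: 'predict' traces an extra branch that 'tensor' lacks.
keys_by_mode = {
    'tensor': ['backbone.conv1.weight_fake_quant.scale'],
    'loss': ['backbone.conv1.weight_fake_quant.scale'],
    'predict': ['backbone.conv1.weight_fake_quant.scale',
                'head.flip_branch.weight_fake_quant.scale'],
}

for mode, missing in report_mismatches(keys_by_mode, 'tensor').items():
    status = 'ok' if not missing else f'missing in source: {missing}'
    print(f'{mode}: {status}')
```

An empty list for every mode would suggest the traced graphs agree on their fake-quantize entries, so the unpatched lookup should find each name.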
