Comments (5)
I think I've found the problem. `type=mmfewshot.TFA` in the config file cannot be recognized, unlike algorithms from other repos such as `type=mmdet.RetinaNet`.
Maybe it's mainly due to the different file structure of mmfewshot compared with other OpenMMLab repos such as mmdetection. As a result, `builder.py` in mmfewshot is not found correctly by the `Registry` for models.
Also, mmfewshot is mainly built on mmcls and mmdet, so components of the TFA model in mmfewshot (for example Faster R-CNN) are originally built and registered in mmdet. This makes the `Registry` fail to find them under mmfewshot.
Switching to `type=TFA` and using the mmfewshot builder function to build models directly in `MMFewShotArchitecture` works fine for me for now.
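The toy scoped registry below makes the scoping argument concrete in plain Python. It is a simplified sketch loosely modelled on how mmcv 1.x resolves `scope.Name` keys: it checks a child registry named after the scope, and otherwise retries from the root. It is not mmcv's code. The layout, where the few-shot registry hangs off mmdet's registry, is a hypothetical chosen to reproduce the symptom. It is not a claim about how mmfewshot is actually wired.

```python
class ToyRegistry:
    """Toy scoped registry, loosely modelled on mmcv 1.x (not its real code)."""

    def __init__(self, scope, parent=None):
        self.scope = scope
        self.parent = parent
        self.children = {}
        self._modules = {}
        if parent is not None:
            # Children are looked up by their scope name, e.g. 'mmdet'.
            parent.children[scope] = self

    def register_module(self):
        def _register(cls):
            self._modules[cls.__name__] = cls
            return cls
        return _register

    def _root(self):
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    def get(self, key):
        # 'scope.Name' selects a scope; a bare 'Name' means "this registry".
        if '.' in key:
            scope, name = key.split('.', 1)
        else:
            scope, name = None, key
        if scope is None or scope == self.scope:
            return self._modules.get(name)
        if scope in self.children:
            return self.children[scope].get(name)
        root = self._root()
        if root is self:
            return None  # unknown scope and nowhere left to search
        # Otherwise retry from the root, which knows the top-level scopes.
        return root.get(key)

    def build(self, cfg):
        cfg = dict(cfg)
        obj_type = cfg.pop('type')
        cls = self.get(obj_type)
        if cls is None:
            raise KeyError(f'{obj_type!r} not found from scope {self.scope!r}')
        return cls(**cfg)


root = ToyRegistry('mmcv')
mmdet_models = ToyRegistry('mmdet', parent=root)
mmrazor_models = ToyRegistry('mmrazor', parent=root)
# Hypothetical layout: the few-shot registry hangs off mmdet's registry,
# so the root has no child called 'mmfewshot'.
fewshot_models = ToyRegistry('mmfewshot', parent=mmdet_models)


@mmdet_models.register_module()
class RetinaNet:
    pass


@fewshot_models.register_module()
class TFA:
    pass


# 'mmdet.RetinaNet' resolves from mmrazor by going through the root.
student = mmrazor_models.build(dict(type='mmdet.RetinaNet'))
# 'mmfewshot.TFA' does not: the root cannot see the nested 'mmfewshot' scope.
print(mmrazor_models.get('mmfewshot.TFA'))  # None
# Building with the registry that owns TFA (the workaround) succeeds.
tfa = fewshot_models.build(dict(type='TFA'))
```

Building through the registry that owns the class sidesteps the scope lookup entirely. This is the same idea as calling mmfewshot's builder directly in `MMFewShotArchitecture`.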
Can you make a PR so we can develop it together? @Chan-Sun
That's cool, but I only use the detection part of mmfewshot and the KD part of mmrazor, so I'm not sure whether it works for the other parts.
I'll try to make a PR when I finish this work.
Welcome!
This problem seems to be common to other OpenMMLab repos as well, for example mmocr.
Here is my solution, taking the TFA model in mmfewshot as an example.
- For the architecture code: add the class `MMFewShotArchitecture` and slightly modify the original `BaseArchitecture` code:
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import MODELS
from mmcv.runner import BaseModule


class BaseArchitecture(BaseModule):
    """Base class for architecture.

    Args:
        model (dict | :obj:`torch.nn.Module`): Config of the model to be
            slimmed, such as ``DETECTOR`` in MMDetection, or an
            already-built model.
    """

    def __init__(self, model, **kwargs):
        super(BaseArchitecture, self).__init__(**kwargs)
        #### modification code: build from a config dict, but accept a
        #### model that has already been built (e.g. by mmfewshot).
        if isinstance(model, dict):
            self.model = MODELS.build(model)
        else:
            self.model = model

    # ... the remaining methods of BaseArchitecture are unchanged.
# Copyright (c) OpenMMLab. All rights reserved.
from mmfewshot.detection.models.builder import build_detector
from mmrazor.models.architectures.base import BaseArchitecture
from mmrazor.models.builder import ARCHITECTURES


@ARCHITECTURES.register_module()
class MMFewShotArchitecture(BaseArchitecture):
    """Architecture based on MMFewShot."""

    def __init__(self, model, **kwargs):
        #### This is for the student model: build it with mmfewshot's
        #### builder so that type='TFA' is resolved.
        model = build_detector(model)
        super(MMFewShotArchitecture, self).__init__(model, **kwargs)
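The `isinstance(model, dict)` check lets the subclass hand an already-built model to the base class. Without it, the base class would try to build the model again with the generic registry. The plain-Python toy below mimics that dispatch so it runs without mmcv or mmfewshot installed. `generic_build`, `fewshot_build` and `FakeDetector` are hypothetical stand-ins for `MODELS.build`, mmfewshot's `build_detector` and a real detector.

```python
class FakeDetector:
    """Hypothetical placeholder for a built detector module."""

    def __init__(self, name, **kwargs):
        self.name = name
        self.cfg = kwargs


def _build(cfg, known):
    cfg = dict(cfg)
    name = cfg.pop('type')
    if name not in known:
        raise KeyError(f'{name!r} is not registered')
    return FakeDetector(name, **cfg)


def generic_build(cfg):
    # Stand-in for MODELS.build: it does not know few-shot detectors.
    return _build(cfg, known={'RetinaNet', 'FasterRCNN'})


def fewshot_build(cfg):
    # Stand-in for mmfewshot's build_detector.
    return _build(cfg, known={'TFA'})


class ToyBaseArchitecture:
    def __init__(self, model):
        # Same dispatch as the modified BaseArchitecture: build a config
        # dict, but keep an already-built model as-is.
        if isinstance(model, dict):
            self.model = generic_build(model)
        else:
            self.model = model


class ToyFewShotArchitecture(ToyBaseArchitecture):
    def __init__(self, model):
        # Build with the few-shot builder first, then pass the built object.
        super().__init__(fewshot_build(model))


arch = ToyFewShotArchitecture(dict(type='TFA', num_classes=20))
print(arch.model.name, arch.model.cfg)  # TFA {'num_classes': 20}
```

Passing `dict(type='TFA')` straight to `ToyBaseArchitecture` raises `KeyError`, which mirrors the original failure.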
- For the config code:
model = dict(
    type='TFA',
    ………………
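In mmrazor 0.x configs, the student `model` is typically wrapped in an `algorithm` dict that also names the architecture and the distiller. The fragment below is a sketch under that assumption. The `GeneralDistill` algorithm type, the key names and the `teacher` placeholder should be checked against the mmrazor version in use. The TFA fields stay elided.

```python
# Student config: the TFA model above (its fields are elided here too).
model = dict(type='TFA')
# Hypothetical placeholder: the teacher detector config.
teacher = dict(type='TFA')

algorithm = dict(
    type='GeneralDistill',  # assumed mmrazor 0.x distillation algorithm
    architecture=dict(
        type='MMFewShotArchitecture',  # registered via ARCHITECTURES
        model=model,
    ),
    distiller=dict(
        type='SingleTeacherDistiller',
        teacher=teacher,
        teacher_trainable=False,
        components=[],  # distillation components and losses go here
    ),
)
```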
- For the distillation framework in mmrazor, taking `SingleTeacherDistiller` as an example, do not forget to modify the `build_teacher` function:
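from mmfewshot.detection.models import build_detector
from mmrazor.models.builder import DISTILLERS
from mmrazor.models.distillers.base import BaseDistiller


@DISTILLERS.register_module()
class SingleTeacherDistiller(BaseDistiller):

    def __init__(self,
                 teacher,
                 teacher_trainable=False,
                 teacher_norm_eval=True,
                 components=tuple(),
                 **kwargs):
        super().__init__(**kwargs)
        self.teacher = self.build_teacher(teacher)
        # ... other attributes set in the original __init__ are unchanged.

    def build_teacher(self, cfg):
        """Build a model from the `cfg`."""
        ### This is for the teacher model: use mmfewshot's builder
        ### instead of the default registry.
        teacher = build_detector(cfg)
        return teacher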
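The distiller's `__init__` obtains the teacher through `self.build_teacher`, so changing that one method is enough to route the teacher through mmfewshot's builder. The toy below shows these hook mechanics in plain Python; the class names are hypothetical. By the same reasoning, subclassing `SingleTeacherDistiller` and overriding only `build_teacher` (registered under a new name) might avoid editing mmrazor's file in place. That variant has not been verified against mmrazor.

```python
class ToyBaseDistiller:
    """Toy base distiller: the teacher is always built via a hook method."""

    def __init__(self, teacher):
        self.teacher = self.build_teacher(teacher)

    def build_teacher(self, cfg):
        # Stand-in for the default registry-based build (MODELS.build).
        return ('generic builder', cfg['type'])


class ToyFewShotDistiller(ToyBaseDistiller):
    def build_teacher(self, cfg):
        # Stand-in for mmfewshot's build_detector(cfg).
        return ('mmfewshot builder', cfg['type'])


print(ToyBaseDistiller(dict(type='RetinaNet')).teacher)
# ('generic builder', 'RetinaNet')
print(ToyFewShotDistiller(dict(type='TFA')).teacher)
# ('mmfewshot builder', 'TFA')
```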