Comments (6)
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29813 \
./tools/dist_train.sh configs/h2rbox/h2rbox_r50_adamw_fpn_1x_dota_le90.py 8
from h2rbox-mmrotate.
Thanks for your quick reply. I did use tools/dist_train.sh
for multiple GPU training. This is the script I used for training with multiple GPUS:
bash tools/dist_train.sh configs/h2rbox/h2rbox-le90_r50_fpn_adamw-1x_logo.py 8
I used my custom dataset, in COCO format, configs/_base_/datasets/dota_coco_logo.py
:
# dataset settings
dataset_type = 'mmdet.CocoDataset'
data_root = 'data/split_ms_dota/'
file_client_args = dict(backend='disk')
train_pipeline = [
dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args),
dict(
type='mmdet.LoadAnnotations',
with_bbox=True,
with_mask=False,
poly2mask=False),
# Horizontal GTBox, (x,y,w,h,theta)
dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
dict(
type='mmdet.RandomFlip',
prob=0.75,
direction=['horizontal', 'vertical', 'diagonal']),
dict(type='mmdet.PackDetInputs')
]
val_pipeline = [
dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args),
dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
# avoid bboxes being resized
dict(
type='mmdet.LoadAnnotations',
with_bbox=True,
with_mask=False,
poly2mask=False),
# Horizontal GTBox, (x,y,w,h,theta)
dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'instances'))
]
test_pipeline = [
dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args),
dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor'))
]
metainfo = dict(
classes=('ukn'))
train_dataloader = dict(
batch_size=2,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
batch_sampler=None,
dataset=dict(
type=dataset_type,
metainfo=metainfo,
# data_root=data_root,
ann_file='train/train.json',
data_prefix=dict(img=''),
filter_cfg=dict(filter_empty_gt=True),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
metainfo=metainfo,
# data_root=data_root,
ann_file='val/val.json',
data_prefix=dict(img=''),
test_mode=True,
pipeline=val_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='RotatedCocoMetric', metric='bbox', classwise=True)
test_evaluator = val_evaluator
# inference on test dataset and format the output results
# for submission. Note: the test set has no annotation.
# test_dataloader = dict(
# batch_size=1,
# num_workers=2,
# persistent_workers=True,
# drop_last=False,
# sampler=dict(type='DefaultSampler', shuffle=False),
# dataset=dict(
# type=dataset_type,
# ann_file='test/test.json',
# data_prefix=dict(img='test/images/'),
# test_mode=True,
# pipeline=test_pipeline))
# test_evaluator = dict(
# type='DOTAMetric',
# format_only=True,
# merge_patches=True,
# outfile_prefix='./work_dirs/dota/Task1')
The config file is configs/h2rbox/h2rbox-le90_r50_fpn_adamw-1x_logo.py
:
_base_ = [
'../_base_/datasets/dota_coco_logo.py', '../_base_/schedules/schedule_1x.py',
'../_base_/default_runtime.py'
]
angle_version = 'le90'
# model settings
model = dict(
type='H2RBoxDetector',
crop_size=(800, 800),
data_preprocessor=dict(
type='mmdet.DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=32,
boxtype2tensor=False),
backbone=dict(
type='mmdet.ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='mmdet.FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5,
relu_before_extra_convs=True),
bbox_head=dict(
type='H2RBoxHead',
num_classes=1,
in_channels=256,
angle_version='le90',
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
center_sampling=True,
center_sample_radius=1.5,
norm_on_bbox=True,
centerness_on_reg=True,
use_hbbox_loss=False,
scale_angle=True,
bbox_coder=dict(
type='DistanceAnglePointCoder', angle_version=angle_version),
loss_cls=dict(
type='mmdet.FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='mmdet.IoULoss', loss_weight=1.0),
loss_centerness=dict(
type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
# square_classes=[9, 11],
crop_size=(800, 800),
loss_bbox_ss=dict(
type='H2RBoxConsistencyLoss',
loss_weight=0.4,
center_loss_cfg=dict(type='mmdet.L1Loss', loss_weight=0.0),
shape_loss_cfg=dict(type='mmdet.IoULoss', loss_weight=1.0),
angle_loss_cfg=dict(type='mmdet.L1Loss', loss_weight=1.0))),
# training and testing settings
train_cfg=None,
test_cfg=dict(
nms_pre=2000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms_rotated', iou_threshold=0.1),
max_per_img=2000))
# load hbox annotations
# train_pipeline = [
# dict(
# type='mmdet.LoadImageFromFile',
# file_client_args={{_base_.file_client_args}}),
# dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
# # Horizontal GTBox, (x1,y1,x2,y2)
# dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='hbox')),
# # Horizontal GTBox, (x,y,w,h,theta)
# dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
# dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
# dict(
# type='mmdet.RandomFlip',
# prob=0.75,
# direction=['horizontal', 'vertical', 'diagonal']),
# dict(type='mmdet.PackDetInputs')
# ]
metainfo = dict(classes=('ukn'))
train_ann_file='/mmu_cd/fuliangcheng/datasets/logo/OpenBrand/openbrand_and_fake_goods_cdp4w_relabel_ecomm_s1_s3_relabel_topgmv_p0.20230201.json.coco.json'
train_dataloader = dict(dataset=dict(ann_file=train_ann_file, metainfo=metainfo))
val_ann_file='/mmu_cd/fuliangcheng/datasets/logo/OpenBrand/retrieval/brand_logo_gmv1w商品_detect_input.json.coco.json'
val_dataloader = dict(dataset=dict(ann_file=val_ann_file, metainfo=metainfo))
# optimizer
optim_wrapper = dict(
optimizer=dict(
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05))
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=6)
from h2rbox-mmrotate.
Try add find_unused_parameters=True
in config。
from h2rbox-mmrotate.
Indeed, after find_unused_parameters=True
is added in the config file, the training starts successfully. However, there is a warning stating that did not find any unused parameters
. If that is the case, not sure why the training is terminated when this parameter is not set. Any idea?
from h2rbox-mmrotate.
Have you successfully trained the default configuration file for multiple GPUs (not your own dataset)?
from h2rbox-mmrotate.
I did not try reproduce with DOTA dataset. The dataset is too big to download. After find_unused_parameters=True
is added, the training seems to work as expected as the loss is dropping.
from h2rbox-mmrotate.
Related Issues (19)
- Some small questions for r2h function. HOT 2
- 数据增强 HOT 5
- HRSC dataset test result HOT 2
- visualization of test resuts
- KFIOU的3D实现
- can you share your trained model HOT 1
- Training on HBBs is not giving any performance increase on OBBs test set.
- What is the dotav1 annotation box file used for h2rbox? HOT 5
- 模型训练 HOT 1
- code HOT 1
- About training on HRSID HOT 2
- ImportError HOT 1
- no module name h2rbox HOT 11
- sar图像旋转检测 HOT 16
- .txt format annotation of horizontal boxes for ssdd and hrsid datasets on h2rbox? HOT 14
- Want each category mAP HOT 1
- H2RBoxHead' object has no attribute 'seprate_angle' HOT 1
- Welcome update to OpenMMLab 2.0
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from h2rbox-mmrotate.