├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── _base_
│   │   ├── datasets
│   │   │   └── shift.py
│   │   ├── default_runtime.py
│   │   └── models
│   │       ├── faster-rcnn_r50_fpn.py
│   │       └── yolox_x_8x8.py
│   ├── continuous
│   │   ├── mean_teacher_adapter_yolox
│   │   │   └── yolox_x_8xb4-24e_shift_from_clear_daytime.py
│   │   └── no_adap_yolox
│   │       ├── yolox_x_8xb4-24e_shift_from_all.py
│   │       ├── yolox_x_8xb4-24e_shift_from_clear_daytime.py
│   │       └── yolox_x_8xb4-24e_shift_from_clear_night.py
│   └── source
│       └── yolox
│           ├── README.md
│           ├── amp_yolox_x_8xb4-24e_shift_all.py
│           ├── amp_yolox_x_8xb4-24e_shift_clear_daytime.py
│           ├── amp_yolox_x_8xb4-24e_shift_clear_night.py
│           ├── amp_yolox_x_8xb4-24e_shift_daytime.py
│           ├── amp_yolox_x_8xb4-24e_shift_night.py
│           ├── yolox_x_8xb4-24e_shift_all.py
│           ├── yolox_x_8xb4-24e_shift_clear_daytime.py
│           ├── yolox_x_8xb4-24e_shift_clear_night.py
│           ├── yolox_x_8xb4-24e_shift_daytime.py
│           └── yolox_x_8xb4-24e_shift_night.py
├── docs
│   ├── challenge.md
│   ├── dataset_prepare.md
│   ├── get_started.md
│   ├── model_zoo.md
│   └── train_test.md
├── requirements.txt
├── requirements
│   ├── build.txt
│   ├── mminstall.txt
│   └── runtime.txt
├── resources
│   └── shift-logo.png
├── scripts
│   ├── continuous
│   │   ├── mean_teacher_adapter_yolox
│   │   │   ├── slurm_test_yolox_shift_from_clear_daytime.sh
│   │   │   ├── test_yolox_shift_from_clear_daytime.sh
│   │   │   └── val_yolox_shift_from_clear_daytime.sh
│   │   └── no_adap_yolox
│   │       ├── test_yolox_shift_from_clear_daytime.sh
│   │       ├── test_yolox_shift_from_clear_night.sh
│   │       ├── val_yolox_shift_from_clear_daytime.sh
│   │       └── val_yolox_shift_from_clear_night.sh
│   ├── source
│   │   ├── slurm_train_yolox_24e_shift_clear_daytime.sh
│   │   ├── slurm_train_yolox_24e_shift_daytime.sh
│   │   ├── test_yolox_shift_clear_daytime.sh
│   │   ├── test_yolox_shift_clear_night.sh
│   │   └── train_yolox_shift_clear_daytime.sh
│   └── tmp
│       ├── edit_json.py
│       └── edit_json_ids.py
├── setup.cfg
├── setup.py
├── shift_tta
│   ├── .mim
│   │   ├── configs
│   │   └── tools
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── shift_dataset.py
│   │   └── utils
│   │       ├── __init__.py
│   │       └── filters.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   └── metrics
│   │       ├── __init__.py
│   │       └── shift_video_metrics.py
│   ├── fileio
│   │   ├── __init__.py
│   │   └── backends
│   │       ├── __init__.py
│   │       ├── tar_backend.py
│   │       └── zip_backend.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── adapters
│   │   │   ├── __init__.py
│   │   │   ├── base_adapter.py
│   │   │   └── mean_teacher_yolox_adapter.py
│   │   ├── detectors
│   │   │   ├── __init__.py
│   │   │   ├── adaptive_detector.py
│   │   │   └── base.py
│   │   └── losses
│   │       ├── __init__.py
│   │       └── yolox_consistency_loss.py
│   ├── registry.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── setup_env.py
│   └── version.py
└── tools
    ├── dist_test.sh
    ├── dist_train.sh
    ├── install
    │   └── setup_venv.sh
    ├── shift
    │   └── download.py
    ├── test.py
    ├── test.sh
    ├── train.py
    └── train.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/en/_build/
68 | docs/zh_cn/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | data
108 | ckpts
109 | .vscode
110 | .idea
111 | .DS_Store
112 |
113 | # custom
114 | *.pkl
115 | *.pkl.json
116 | *.log.json
117 | work_dirs/
118 | shift_tta/.mim
119 | shift_tta.egg-info
120 | errors/
121 | outputs/
122 | checkpoints/
123 | scripts/tmp/
124 |
125 | # Pytorch
126 | *.pth
127 | *.py~
128 | *.sh~
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 ETH VIS Group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
17 |
18 |
19 |
20 | ## Introduction
21 |
22 | [SHIFT](https://www.vis.xyz/shift/) is a driving dataset for continuous multi-task domain adaptation. It is maintained by the [VIS Group](https://www.vis.xyz/) at ETH Zurich.
23 |
24 | The main branch works with **PyTorch 1.6+**.
25 |
26 |
27 | https://github.com/SysCV/shift-detection-tta/assets/44324619/9ddc4b31-7ca9-46b1-a1c5-3b9107e04f9e
28 |
29 |
30 |
31 |
32 | ## Tutorial
33 | ### Get started
34 |
35 | Please refer to [get_started.md](docs/get_started.md) for install instructions.
36 |
37 | ### Prepare the SHIFT dataset
38 |
39 | Please refer to [dataset_prepare.md](docs/dataset_prepare.md) for instructions on how to download and prepare the SHIFT dataset.
40 |
41 | ### Usage
42 |
43 | Please refer to [train_test.md](docs/train_test.md) for instructions on how to train and test your own model.
44 |
45 | ### Participate in the Challenge on Continuous Test-time Adaptation for Object Detection
46 |
47 | Please refer to [challenge.md](docs/challenge.md) for instructions on how to participate in the challenge and for training, test, and adaptation instructions.
48 |
49 | The challenge is organized for the Workshop on [Visual Continual Learning @ ICCV2023](https://wvcl.vis.xyz). Check out [wvcl.vis.xyz/challenges](https://wvcl.vis.xyz/challenges) for additional details on this and other challenges.
50 |
51 | We will award the top three teams of each challenge with a certificate and a prize of 1000, 500, and 300 USD, respectively. The winners of each challenge will be invited to give a presentation at the workshop. Teams will be selected based on the performance of their methods on the test set.
52 |
53 | We will also award one team from each challenge with an innovation award. The innovation award is given to the team that proposes the most innovative method and/or insightful analysis. The winner will receive a certificate and an additional prize of 300 USD.
54 |
55 | **Please note** that this challenge is part of the track **Challenge B - Continual Test-time Adaptation**, together with the challenge on "Continuous Test-time Adaptation for Semantic Segmentation". Since the challenge on "Continuous Test-time Adaptation for Object Detection" constitutes half of track B, the prizes are half of the amounts mentioned above.
56 |
57 | # Continuous Test-time Adaptation for Object Detection
58 |
59 | ## Model zoo
60 |
61 | Results and models are available in the [model zoo](docs/model_zoo.md).
62 |
63 | ### Object Detection
64 |
65 | Supported Adaptation Methods
66 | - [x] [no_adap](configs/continuous/no_adap_yolox)
67 | - [x] [mean_teacher_adapter_yolox](configs/continuous/mean_teacher_adapter_yolox)
68 |
69 | Supported Datasets
70 |
71 | - [x] [SHIFT](https://www.vis.xyz/shift/)
72 |
73 |
74 | ## Citation
75 |
76 | If you find this project useful in your research, please consider citing:
77 |
78 | - SHIFT, the dataset powering this challenge and the continuous adaptation tasks:
79 |
80 | ```latex
81 | @inproceedings{sun2022shift,
82 | title={SHIFT: a synthetic driving dataset for continuous multi-task domain adaptation},
83 | author={Sun, Tao and Segu, Mattia and Postels, Janis and Wang, Yuxuan and Van Gool, Luc and Schiele, Bernt and Tombari, Federico and Yu, Fisher},
84 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
85 | pages={21371--21382},
86 | year={2022}
87 | }
88 | ```
89 |
90 | - DARTH, the test-time adaptation method that introduces the mean-teacher-based detection consistency loss used here for detection adaptation:
91 | ```latex
92 | @inproceedings{segu2023darth,
93 | title={Darth: holistic test-time adaptation for multiple object tracking},
94 | author={Segu, Mattia and Schiele, Bernt and Yu, Fisher},
95 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
96 | pages={9717--9727},
97 | year={2023}
98 | }
99 | ```
100 |
101 | ## License
102 |
103 | This project is released under the [MIT License](LICENSE).
104 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/shift.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/shift-detection-tta/a5c4fb0e906caa57c0b572d651f24f01ff6bdcdf/configs/_base_/datasets/shift.py
--------------------------------------------------------------------------------
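Note: the datasets config above is referenced by its raw URL rather than inlined. A minimal sketch to materialize it locally, assuming network access (the target path mirrors the tree above):

```python
import urllib.request
from pathlib import Path

# Fetch the pinned shift.py dataset config from the URL given above.
URL = ('https://raw.githubusercontent.com/SysCV/shift-detection-tta/'
      'a5c4fb0e906caa57c0b572d651f24f01ff6bdcdf/configs/_base_/datasets/shift.py')
dst = Path('configs/_base_/datasets/shift.py')
dst.parent.mkdir(parents=True, exist_ok=True)
with urllib.request.urlopen(URL) as resp:
    dst.write_bytes(resp.read())
```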
/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | default_scope = 'shift_tta'
2 |
3 | default_hooks = dict(
4 | timer=dict(type='IterTimerHook'),
5 | logger=dict(type='LoggerHook', interval=50),
6 | param_scheduler=dict(type='ParamSchedulerHook'),
7 | checkpoint=dict(type='CheckpointHook', interval=1),
8 | sampler_seed=dict(type='DistSamplerSeedHook'),
9 | visualization=dict(
10 | type='mmtrack.TrackVisualizationHook',
11 | draw=False),
12 | )
13 |
14 | env_cfg = dict(
15 | cudnn_benchmark=False,
16 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
17 | dist_cfg=dict(backend='nccl'),
18 | )
19 |
20 | vis_backends = [dict(type='LocalVisBackend')]
21 | visualizer = dict(
22 | type='mmtrack.TrackLocalVisualizer', vis_backends=vis_backends, name='visualizer')
23 |
24 | log_level = 'INFO'
25 | load_from = None
26 | resume = False
--------------------------------------------------------------------------------
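Note: everything in default_runtime.py is pulled into the concrete configs below through their `_base_` lists. A minimal sketch of how the merge can be inspected, assuming mmengine is installed and the repo root is the working directory:

```python
from mmengine.config import Config

# Load a child config; mmengine recursively merges its _base_ files.
cfg = Config.fromfile(
    'configs/continuous/no_adap_yolox/'
    'yolox_x_8xb4-24e_shift_from_clear_daytime.py')
print(cfg.default_scope)        # 'shift_tta', from default_runtime.py
print(cfg.model.detector.type)  # 'YOLOX', from the model _base_ file
```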
/configs/_base_/models/faster-rcnn_r50_fpn.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | data_preprocessor=dict(
3 | type='mmtrack.TrackDataPreprocessor',
4 | mean=[123.675, 116.28, 103.53],
5 | std=[58.395, 57.12, 57.375],
6 | bgr_to_rgb=True,
7 | rgb_to_bgr=False,
8 | pad_size_divisor=32),
9 | detector=dict(
10 | type='FasterRCNN',
11 | _scope_='mmdet',
12 | backbone=dict(
13 | type='ResNet',
14 | depth=50,
15 | num_stages=4,
16 | out_indices=(0, 1, 2, 3),
17 | frozen_stages=1,
18 | norm_cfg=dict(type='BN', requires_grad=True),
19 | norm_eval=True,
20 | style='pytorch',
21 | init_cfg=dict(
22 | type='Pretrained', checkpoint='torchvision://resnet50')),
23 | neck=dict(
24 | type='FPN',
25 | in_channels=[256, 512, 1024, 2048],
26 | out_channels=256,
27 | num_outs=5),
28 | rpn_head=dict(
29 | type='RPNHead',
30 | in_channels=256,
31 | feat_channels=256,
32 | anchor_generator=dict(
33 | type='AnchorGenerator',
34 | scales=[8],
35 | ratios=[0.5, 1.0, 2.0],
36 | strides=[4, 8, 16, 32, 64]),
37 | bbox_coder=dict(
38 | type='DeltaXYWHBBoxCoder',
39 | target_means=[.0, .0, .0, .0],
40 | target_stds=[1.0, 1.0, 1.0, 1.0]),
41 | loss_cls=dict(
42 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
43 | loss_bbox=dict(
44 | type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
45 | roi_head=dict(
46 | type='StandardRoIHead',
47 | bbox_roi_extractor=dict(
48 | type='SingleRoIExtractor',
49 | roi_layer=dict(
50 | type='RoIAlign', output_size=7, sampling_ratio=0),
51 | out_channels=256,
52 | featmap_strides=[4, 8, 16, 32]),
53 | bbox_head=dict(
54 | type='Shared2FCBBoxHead',
55 | in_channels=256,
56 | fc_out_channels=1024,
57 | roi_feat_size=7,
58 | num_classes=80,
59 | bbox_coder=dict(
60 | type='DeltaXYWHBBoxCoder',
61 | target_means=[0., 0., 0., 0.],
62 | target_stds=[0.1, 0.1, 0.2, 0.2]),
63 | reg_class_agnostic=False,
64 | loss_cls=dict(
65 | type='CrossEntropyLoss',
66 | use_sigmoid=False,
67 | loss_weight=1.0),
68 | loss_bbox=dict(type='SmoothL1Loss', loss_weight=1.0))),
69 | train_cfg=dict(
70 | rpn=dict(
71 | assigner=dict(
72 | type='MaxIoUAssigner',
73 | pos_iou_thr=0.7,
74 | neg_iou_thr=0.3,
75 | min_pos_iou=0.3,
76 | match_low_quality=True,
77 | ignore_iof_thr=-1),
78 | sampler=dict(
79 | type='RandomSampler',
80 | num=256,
81 | pos_fraction=0.5,
82 | neg_pos_ub=-1,
83 | add_gt_as_proposals=False),
84 | allowed_border=-1,
85 | pos_weight=-1,
86 | debug=False),
87 | rpn_proposal=dict(
88 | nms_pre=2000,
89 | max_per_img=1000,
90 | nms=dict(type='nms', iou_threshold=0.7),
91 | min_bbox_size=0),
92 | rcnn=dict(
93 | assigner=dict(
94 | type='MaxIoUAssigner',
95 | pos_iou_thr=0.5,
96 | neg_iou_thr=0.5,
97 | min_pos_iou=0.5,
98 | match_low_quality=False,
99 | ignore_iof_thr=-1),
100 | sampler=dict(
101 | type='RandomSampler',
102 | num=512,
103 | pos_fraction=0.25,
104 | neg_pos_ub=-1,
105 | add_gt_as_proposals=True),
106 | pos_weight=-1,
107 | debug=False)),
108 | test_cfg=dict(
109 | rpn=dict(
110 | nms_pre=1000,
111 | max_per_img=1000,
112 | nms=dict(type='nms', iou_threshold=0.7),
113 | min_bbox_size=0),
114 | rcnn=dict(
115 | score_thr=0.05,
116 | nms=dict(type='nms', iou_threshold=0.5),
117 | max_per_img=100))
118 | # soft-nms is also supported for rcnn testing
119 | # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
120 | ))
--------------------------------------------------------------------------------
/configs/_base_/models/yolox_x_8x8.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | img_scale = (640, 640)
3 |
4 | model = dict(
5 | data_preprocessor=dict(
6 | type='mmtrack.TrackDataPreprocessor',
7 | pad_size_divisor=32,
8 | batch_augments=[
9 | dict(
10 | type='mmdet.BatchSyncRandomResize',
11 | random_size_range=(480, 800),
12 | size_divisor=32,
13 | interval=10)
14 | ]),
15 | detector=dict(
16 | _scope_='mmdet',
17 | type='YOLOX',
18 | backbone=dict(
19 | type='CSPDarknet', deepen_factor=1.33, widen_factor=1.25),
20 | neck=dict(
21 | type='YOLOXPAFPN',
22 | in_channels=[320, 640, 1280],
23 | out_channels=320,
24 | num_csp_blocks=4),
25 | bbox_head=dict(
26 | type='YOLOXHead',
27 | num_classes=80,
28 | in_channels=320,
29 | feat_channels=320),
30 | train_cfg=dict(
31 | assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
32 | test_cfg=dict(
33 | score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))))
--------------------------------------------------------------------------------
/configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../_base_/models/yolox_x_8x8.py',
3 | '../../_base_/default_runtime.py'
4 | ]
5 |
6 | dataset_type = 'SHIFTDataset'
7 | data_root = 'data/shift/'
8 | attributes = dict(weather_coarse='clear', timeofday_coarse='daytime')
9 |
10 | img_scale = (800, 1440)
11 | batch_size = 2
12 |
13 | model = dict(
14 | type='AdaptiveDetector',
15 | data_preprocessor=dict(
16 | type='mmtrack.TrackDataPreprocessor',
17 | pad_size_divisor=32,
18 | batch_augments=[
19 | dict(
20 | type='mmdet.BatchSyncRandomResize',
21 | random_size_range=(576, 1024),
22 | size_divisor=32,
23 | interval=10)
24 | ]),
25 | detector=dict(
26 | _scope_='mmdet',
27 | bbox_head=dict(num_classes=6),
28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
29 | init_cfg=dict(
30 | type='Pretrained',
31 | checkpoint= # noqa: E251
32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
33 | )),
34 | adapter=dict(
35 | type='MeanTeacherYOLOXAdapter',
36 | episodic=True, # do NOT change this. episodic must be set to True for the WVCL ICCV 2023 SHIFT Challenges
37 | optim_wrapper=dict(
38 | type='OptimWrapper',
39 | optimizer=dict(
40 | type='SGD', lr=0.00025, momentum=0.9, weight_decay=5e-4, nesterov=True),
41 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)),
42 | optim_steps=5,
43 | teacher=dict(
44 | type='ExponentialMovingAverage',
45 | momentum=0.0002,
46 | update_buffers=True),
47 | loss=dict(
48 | type='YOLOXConsistencyLoss',
49 | weight=0.01,
50 | ),
51 | pipeline=[
52 | dict(type='LoadImageFromFile',
53 | backend_args=dict(
54 | backend='tar',
55 | tar_path=data_root + 'continuous/videos/1x/val/front/img_decompressed.tar',
56 | )
57 | ),
58 | dict(type='mmtrack.LoadTrackAnnotations'),
59 | ],
60 | teacher_pipeline=[
61 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
62 | dict(
63 | type='mmdet.Pad',
64 | size_divisor=32,
65 | pad_val=dict(img=(114.0, 114.0, 114.0))),
66 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True),
67 | ],
68 | student_pipeline=[
69 | dict(type='mmdet.YOLOXHSVRandomAug'),
70 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
71 | dict(
72 | type='mmdet.Pad',
73 | size_divisor=32,
74 | pad_val=dict(img=(114.0, 114.0, 114.0))),
75 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True),
76 | ],
77 | views=2,
78 | ))
79 |
80 | train_pipeline = [
81 | dict(
82 | type='mmdet.Mosaic',
83 | img_scale=img_scale,
84 | pad_val=114.0,
85 | bbox_clip_border=False),
86 | dict(
87 | type='mmdet.RandomAffine',
88 | scaling_ratio_range=(0.1, 2),
89 | border=(-img_scale[0] // 2, -img_scale[1] // 2),
90 | bbox_clip_border=False),
91 | dict(
92 | type='mmdet.MixUp',
93 | img_scale=img_scale,
94 | ratio_range=(0.8, 1.6),
95 | pad_val=114.0,
96 | bbox_clip_border=False),
97 | dict(type='mmdet.YOLOXHSVRandomAug'),
98 | dict(type='mmdet.RandomFlip', prob=0.5),
99 | dict(
100 | type='mmdet.Resize',
101 | scale=img_scale,
102 | keep_ratio=True,
103 | clip_object_border=False),
104 | dict(
105 | type='mmdet.Pad',
106 | size_divisor=32,
107 | pad_val=dict(img=(114.0, 114.0, 114.0))),
108 | dict(
109 | type='mmdet.FilterAnnotations',
110 | min_gt_bbox_wh=(1, 1),
111 | keep_empty=False),
112 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
113 | ]
114 | test_pipeline = [
115 | dict(type='LoadImageFromFile',
116 | backend_args=dict(
117 | backend='tar',
118 | tar_path=data_root + 'continuous/videos/1x/val/front/img_decompressed.tar',
119 | )
120 | ),
121 | dict(type='mmtrack.LoadTrackAnnotations'),
122 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
123 | dict(
124 | type='mmdet.Pad',
125 | size_divisor=32,
126 | pad_val=dict(img=(114.0, 114.0, 114.0))),
127 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
128 | ]
129 |
130 | train_dataset = dict(
131 | # use MultiImageMixDataset wrapper to support mosaic and mixup
132 | type='mmdet.MultiImageMixDataset',
133 | dataset=dict(
134 | type='SHIFTDataset',
135 | load_as_video=False,
136 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json',
137 | data_prefix=dict(img=''),
138 | ref_img_sampler=None,
139 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')),
140 | pipeline=[
141 | dict(type='LoadImageFromFile',
142 | backend_args=dict(
143 | backend='zip',
144 | zip_path=data_root + 'discrete/images/train/front/img.zip',
145 | )
146 | ),
147 | dict(type='mmtrack.LoadTrackAnnotations'),
148 | ],
149 | filter_cfg=dict(
150 | attributes=attributes,
151 | filter_empty_gt=False,
152 | min_size=32
153 | )),
154 | pipeline=train_pipeline)
155 | train_dataloader = dict(
156 | batch_size=batch_size,
157 | num_workers=4,
158 | persistent_workers=True,
159 | sampler=dict(type='DefaultSampler', shuffle=True),
160 | dataset=train_dataset)
161 |
162 | val_dataset = dict(
163 | type='SHIFTDataset',
164 | load_as_video=True,
165 | ann_file=data_root + 'continuous/videos/1x/val/front/det_2d_cocoformat.json',
166 | data_prefix=dict(img=''),
167 | ref_img_sampler=None,
168 | test_mode=True,
169 | filter_cfg=dict(attributes=attributes),
170 | pipeline=test_pipeline,
171 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')))
172 | val_dataloader = dict(
173 | batch_size=1,
174 | num_workers=4,
175 | persistent_workers=True,
176 | drop_last=False,
177 | sampler=dict(type='mmtrack.VideoSampler'),
178 | dataset=val_dataset)
179 | test_dataloader = val_dataloader
180 | # optimizer
181 | # default: 8 GPUs
182 | lr = 0.0005 / 8 * batch_size
183 | optim_wrapper = dict(
184 | type='OptimWrapper',
185 | optimizer=dict(
186 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True),
187 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
188 |
189 | # some hyperparameters
190 | # training settings
191 | total_epochs = 12
192 | num_last_epochs = 2
193 | resume_from = None
194 | interval = 5
195 |
196 | train_cfg = dict(
197 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1)
198 | val_cfg = dict(type='ValLoop')
199 | test_cfg = dict(type='TestLoop')
200 | # learning policy
201 | param_scheduler = [
202 | dict(
203 | # use quadratic formula to warm up for 1 epoch
204 | # and lr is updated by iteration
205 | type='mmdet.QuadraticWarmupLR',
206 | by_epoch=True,
207 | begin=0,
208 | end=1,
209 | convert_to_iter_based=True),
210 | dict(
211 | # use cosine lr from epoch 1 to epoch (total_epochs - num_last_epochs)
212 | type='mmdet.CosineAnnealingLR',
213 | eta_min=lr * 0.05,
214 | begin=1,
215 | T_max=total_epochs - num_last_epochs,
216 | end=total_epochs - num_last_epochs,
217 | by_epoch=True,
218 | convert_to_iter_based=True),
219 | dict(
220 | # use fixed lr during the last num_last_epochs epochs
221 | type='mmdet.ConstantLR',
222 | by_epoch=True,
223 | factor=1,
224 | begin=total_epochs - num_last_epochs,
225 | end=total_epochs,
226 | )
227 | ]
228 |
229 | custom_hooks = [
230 | dict(
231 | type='mmtrack.YOLOXModeSwitchHook',
232 | num_last_epochs=num_last_epochs,
233 | priority=48),
234 | dict(type='mmdet.SyncNormHook', priority=48),
235 | dict(
236 | type='mmdet.EMAHook',
237 | ema_type='mmdet.ExpMomentumEMA',
238 | momentum=0.0001,
239 | update_buffers=True,
240 | priority=49)
241 | ]
242 | default_hooks = dict(checkpoint=dict(interval=1))
243 |
244 | # evaluator
245 | val_evaluator = [
246 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True),
247 | ]
248 | test_evaluator = val_evaluator
--------------------------------------------------------------------------------
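Note: the `teacher=dict(type='ExponentialMovingAverage', momentum=0.0002, ...)` entry above uses mmengine's EMA convention, where the update is `avg = (1 - momentum) * avg + momentum * source`. A standalone sketch of that rule (a hypothetical helper for intuition, not the repo's adapter code):

```python
import torch

@torch.no_grad()
def ema_update(teacher: torch.nn.Module, student: torch.nn.Module,
               momentum: float = 2e-4) -> None:
    """avg = (1 - momentum) * avg + momentum * source, as in mmengine's
    ExponentialMovingAverage; update_buffers=True additionally averages
    buffers (e.g. BN running stats) the same way."""
    for t, s in zip(teacher.parameters(), student.parameters()):
        t.mul_(1.0 - momentum).add_(s, alpha=momentum)
```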
/configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_all.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../_base_/models/yolox_x_8x8.py',
3 | '../../_base_/default_runtime.py'
4 | ]
5 |
6 | dataset_type = 'SHIFTDataset'
7 | data_root = 'data/shift/'
8 | attributes = None
9 |
10 | img_scale = (800, 1440)
11 | batch_size = 2
12 |
13 | model = dict(
14 | type='AdaptiveDetector',
15 | data_preprocessor=dict(
16 | type='mmtrack.TrackDataPreprocessor',
17 | pad_size_divisor=32,
18 | batch_augments=[
19 | dict(
20 | type='mmdet.BatchSyncRandomResize',
21 | random_size_range=(576, 1024),
22 | size_divisor=32,
23 | interval=10)
24 | ]),
25 | detector=dict(
26 | _scope_='mmdet',
27 | bbox_head=dict(num_classes=6),
28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
29 | init_cfg=dict(
30 | type='Pretrained',
31 | checkpoint= # noqa: E251
32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
33 | )),
34 | adapter=None)
35 |
36 | train_pipeline = [
37 | dict(
38 | type='mmdet.Mosaic',
39 | img_scale=img_scale,
40 | pad_val=114.0,
41 | bbox_clip_border=False),
42 | dict(
43 | type='mmdet.RandomAffine',
44 | scaling_ratio_range=(0.1, 2),
45 | border=(-img_scale[0] // 2, -img_scale[1] // 2),
46 | bbox_clip_border=False),
47 | dict(
48 | type='mmdet.MixUp',
49 | img_scale=img_scale,
50 | ratio_range=(0.8, 1.6),
51 | pad_val=114.0,
52 | bbox_clip_border=False),
53 | dict(type='mmdet.YOLOXHSVRandomAug'),
54 | dict(type='mmdet.RandomFlip', prob=0.5),
55 | dict(
56 | type='mmdet.Resize',
57 | scale=img_scale,
58 | keep_ratio=True,
59 | clip_object_border=False),
60 | dict(
61 | type='mmdet.Pad',
62 | size_divisor=32,
63 | pad_val=dict(img=(114.0, 114.0, 114.0))),
64 | dict(
65 | type='mmdet.FilterAnnotations',
66 | min_gt_bbox_wh=(1, 1),
67 | keep_empty=False),
68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
69 | ]
70 | test_pipeline = [
71 | dict(type='LoadImageFromFile',
72 | backend_args=dict(
73 | backend='tar',
74 | tar_path=data_root + 'continuous/videos/1x/val/front/img_decompressed.tar',
75 | )
76 | ),
77 | dict(type='mmtrack.LoadTrackAnnotations'),
78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
79 | dict(
80 | type='mmdet.Pad',
81 | size_divisor=32,
82 | pad_val=dict(img=(114.0, 114.0, 114.0))),
83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
84 | ]
85 |
86 | train_dataset = dict(
87 | # use MultiImageMixDataset wrapper to support mosaic and mixup
88 | type='mmdet.MultiImageMixDataset',
89 | dataset=dict(
90 | type='SHIFTDataset',
91 | load_as_video=False,
92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json',
93 | data_prefix=dict(img=''),
94 | ref_img_sampler=None,
95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')),
96 | pipeline=[
97 | dict(type='LoadImageFromFile',
98 | backend_args=dict(
99 | backend='zip',
100 | zip_path=data_root + 'discrete/images/train/front/img.zip',
101 | )
102 | ),
103 | dict(type='mmtrack.LoadTrackAnnotations'),
104 | ],
105 | filter_cfg=dict(
106 | attributes=attributes,
107 | filter_empty_gt=False,
108 | min_size=32
109 | )),
110 | pipeline=train_pipeline)
111 | train_dataloader = dict(
112 | batch_size=batch_size,
113 | num_workers=4,
114 | persistent_workers=True,
115 | sampler=dict(type='DefaultSampler', shuffle=True),
116 | dataset=train_dataset)
117 |
118 | val_dataset = dict(
119 | type='SHIFTDataset',
120 | load_as_video=True,
121 | ann_file=data_root + 'continuous/videos/1x/val/front/det_2d_cocoformat.json',
122 | data_prefix=dict(img=''),
123 | ref_img_sampler=None,
124 | test_mode=True,
125 | filter_cfg=dict(attributes=attributes),
126 | pipeline=test_pipeline,
127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')))
128 | val_dataloader = dict(
129 | batch_size=1,
130 | num_workers=4,
131 | persistent_workers=True,
132 | drop_last=False,
133 | sampler=dict(type='mmtrack.VideoSampler'),
134 | dataset=val_dataset)
135 | test_dataloader = val_dataloader
136 | # optimizer
137 | # default: 8 GPUs
138 | lr = 0.0005 / 8 * batch_size
139 | optim_wrapper = dict(
140 | type='OptimWrapper',
141 | optimizer=dict(
142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True),
143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
144 |
145 | # some hyperparameters
146 | # training settings
147 | total_epochs = 12
148 | num_last_epochs = 2
149 | resume_from = None
150 | interval = 5
151 |
152 | train_cfg = dict(
153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1)
154 | val_cfg = dict(type='ValLoop')
155 | test_cfg = dict(type='TestLoop')
156 | # learning policy
157 | param_scheduler = [
158 | dict(
159 | # use quadratic formula to warm up for 1 epoch
160 | # and lr is updated by iteration
161 | type='mmdet.QuadraticWarmupLR',
162 | by_epoch=True,
163 | begin=0,
164 | end=1,
165 | convert_to_iter_based=True),
166 | dict(
167 | # use cosine lr from epoch 1 to epoch (total_epochs - num_last_epochs)
168 | type='mmdet.CosineAnnealingLR',
169 | eta_min=lr * 0.05,
170 | begin=1,
171 | T_max=total_epochs - num_last_epochs,
172 | end=total_epochs - num_last_epochs,
173 | by_epoch=True,
174 | convert_to_iter_based=True),
175 | dict(
176 | # use fixed lr during the last num_last_epochs epochs
177 | type='mmdet.ConstantLR',
178 | by_epoch=True,
179 | factor=1,
180 | begin=total_epochs - num_last_epochs,
181 | end=total_epochs,
182 | )
183 | ]
184 |
185 | custom_hooks = [
186 | dict(
187 | type='mmtrack.YOLOXModeSwitchHook',
188 | num_last_epochs=num_last_epochs,
189 | priority=48),
190 | dict(type='mmdet.SyncNormHook', priority=48),
191 | dict(
192 | type='mmdet.EMAHook',
193 | ema_type='mmdet.ExpMomentumEMA',
194 | momentum=0.0001,
195 | update_buffers=True,
196 | priority=49)
197 | ]
198 | default_hooks = dict(checkpoint=dict(interval=1))
199 |
200 | # evaluator
201 | val_evaluator = [
202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True),
203 | ]
204 | test_evaluator = val_evaluator
--------------------------------------------------------------------------------
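Note: the `lr = 0.0005 / 8 * batch_size` line in these configs is the linear scaling rule, anchoring the base lr 0.0005 at a reference batch size of 8. A quick arithmetic check with the `batch_size = 2` used here:

```python
# Linear lr scaling as written in the configs: base lr 0.0005 at a
# reference batch size of 8, rescaled to the actual batch size of 2.
batch_size = 2
lr = 0.0005 / 8 * batch_size
print(lr)  # 0.000125
assert abs(lr - 1.25e-4) < 1e-12
```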
/configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../_base_/models/yolox_x_8x8.py',
3 | '../../_base_/default_runtime.py'
4 | ]
5 |
6 | dataset_type = 'SHIFTDataset'
7 | data_root = 'data/shift/'
8 | attributes = dict(weather_coarse='clear', timeofday_coarse='daytime')
9 |
10 | img_scale = (800, 1440)
11 | batch_size = 2
12 |
13 | model = dict(
14 | type='AdaptiveDetector',
15 | data_preprocessor=dict(
16 | type='mmtrack.TrackDataPreprocessor',
17 | pad_size_divisor=32,
18 | batch_augments=[
19 | dict(
20 | type='mmdet.BatchSyncRandomResize',
21 | random_size_range=(576, 1024),
22 | size_divisor=32,
23 | interval=10)
24 | ]),
25 | detector=dict(
26 | _scope_='mmdet',
27 | bbox_head=dict(num_classes=6),
28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
29 | init_cfg=dict(
30 | type='Pretrained',
31 | checkpoint= # noqa: E251
32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
33 | )),
34 | adapter=None)
35 |
36 | train_pipeline = [
37 | dict(
38 | type='mmdet.Mosaic',
39 | img_scale=img_scale,
40 | pad_val=114.0,
41 | bbox_clip_border=False),
42 | dict(
43 | type='mmdet.RandomAffine',
44 | scaling_ratio_range=(0.1, 2),
45 | border=(-img_scale[0] // 2, -img_scale[1] // 2),
46 | bbox_clip_border=False),
47 | dict(
48 | type='mmdet.MixUp',
49 | img_scale=img_scale,
50 | ratio_range=(0.8, 1.6),
51 | pad_val=114.0,
52 | bbox_clip_border=False),
53 | dict(type='mmdet.YOLOXHSVRandomAug'),
54 | dict(type='mmdet.RandomFlip', prob=0.5),
55 | dict(
56 | type='mmdet.Resize',
57 | scale=img_scale,
58 | keep_ratio=True,
59 | clip_object_border=False),
60 | dict(
61 | type='mmdet.Pad',
62 | size_divisor=32,
63 | pad_val=dict(img=(114.0, 114.0, 114.0))),
64 | dict(
65 | type='mmdet.FilterAnnotations',
66 | min_gt_bbox_wh=(1, 1),
67 | keep_empty=False),
68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
69 | ]
70 | test_pipeline = [
71 | dict(type='LoadImageFromFile',
72 | backend_args=dict(
73 | backend='tar',
74 | tar_path=data_root + 'continuous/videos/1x/val/front/img_decompressed.tar',
75 | )
76 | ),
77 | dict(type='mmtrack.LoadTrackAnnotations'),
78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
79 | dict(
80 | type='mmdet.Pad',
81 | size_divisor=32,
82 | pad_val=dict(img=(114.0, 114.0, 114.0))),
83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
84 | ]
85 |
86 | train_dataset = dict(
87 | # use MultiImageMixDataset wrapper to support mosaic and mixup
88 | type='mmdet.MultiImageMixDataset',
89 | dataset=dict(
90 | type='SHIFTDataset',
91 | load_as_video=False,
92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json',
93 | data_prefix=dict(img=''),
94 | ref_img_sampler=None,
95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')),
96 | pipeline=[
97 | dict(type='LoadImageFromFile',
98 | backend_args=dict(
99 | backend='zip',
100 | zip_path=data_root + 'discrete/images/train/front/img.zip',
101 | )
102 | ),
103 | dict(type='mmtrack.LoadTrackAnnotations'),
104 | ],
105 | filter_cfg=dict(
106 | attributes=attributes,
107 | filter_empty_gt=False,
108 | min_size=32
109 | )),
110 | pipeline=train_pipeline)
111 | train_dataloader = dict(
112 | batch_size=batch_size,
113 | num_workers=4,
114 | persistent_workers=True,
115 | sampler=dict(type='DefaultSampler', shuffle=True),
116 | dataset=train_dataset)
117 |
118 | val_dataset = dict(
119 | type='SHIFTDataset',
120 | load_as_video=True,
121 | ann_file=data_root + 'continuous/videos/1x/val/front/det_2d_cocoformat.json',
122 | data_prefix=dict(img=''),
123 | ref_img_sampler=None,
124 | test_mode=True,
125 | filter_cfg=dict(attributes=attributes),
126 | pipeline=test_pipeline,
127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')))
128 | val_dataloader = dict(
129 | batch_size=1,
130 | num_workers=4,
131 | persistent_workers=True,
132 | drop_last=False,
133 | sampler=dict(type='mmtrack.VideoSampler'),
134 | dataset=val_dataset)
135 | test_dataloader = val_dataloader
136 | # optimizer
137 | # default: 8 GPUs
138 | lr = 0.0005 / 8 * batch_size
139 | optim_wrapper = dict(
140 | type='OptimWrapper',
141 | optimizer=dict(
142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True),
143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
144 |
145 | # some hyperparameters
146 | # training settings
147 | total_epochs = 12
148 | num_last_epochs = 2
149 | resume_from = None
150 | interval = 5
151 |
152 | train_cfg = dict(
153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1)
154 | val_cfg = dict(type='ValLoop')
155 | test_cfg = dict(type='TestLoop')
156 | # learning policy
157 | param_scheduler = [
158 | dict(
159 | # use quadratic formula to warm up for 1 epoch
160 | # and lr is updated by iteration
161 | type='mmdet.QuadraticWarmupLR',
162 | by_epoch=True,
163 | begin=0,
164 | end=1,
165 | convert_to_iter_based=True),
166 | dict(
167 | # use cosine lr from epoch 1 to epoch (total_epochs - num_last_epochs)
168 | type='mmdet.CosineAnnealingLR',
169 | eta_min=lr * 0.05,
170 | begin=1,
171 | T_max=total_epochs - num_last_epochs,
172 | end=total_epochs - num_last_epochs,
173 | by_epoch=True,
174 | convert_to_iter_based=True),
175 | dict(
176 | # use fixed lr during the last num_last_epochs epochs
177 | type='mmdet.ConstantLR',
178 | by_epoch=True,
179 | factor=1,
180 | begin=total_epochs - num_last_epochs,
181 | end=total_epochs,
182 | )
183 | ]
184 |
185 | custom_hooks = [
186 | dict(
187 | type='mmtrack.YOLOXModeSwitchHook',
188 | num_last_epochs=num_last_epochs,
189 | priority=48),
190 | dict(type='mmdet.SyncNormHook', priority=48),
191 | dict(
192 | type='mmdet.EMAHook',
193 | ema_type='mmdet.ExpMomentumEMA',
194 | momentum=0.0001,
195 | update_buffers=True,
196 | priority=49)
197 | ]
198 | default_hooks = dict(checkpoint=dict(interval=1))
199 |
200 | # evaluator
201 | val_evaluator = [
202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True),
203 | ]
204 | test_evaluator = val_evaluator
--------------------------------------------------------------------------------
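Note: as a back-of-the-envelope check of the cosine stage in the schedule above (`begin=1`, `end=10`, `eta_min = lr * 0.05`), the annealing formula can be evaluated directly. A plain-math sketch, ignoring the per-iteration conversion that `convert_to_iter_based=True` applies:

```python
import math

lr = 0.0005 / 8 * 2    # 1.25e-4, as in the config
eta_min = lr * 0.05
T_max = 12 - 2         # total_epochs - num_last_epochs

def cosine(t: float) -> float:
    # eta_min + (lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
    return eta_min + 0.5 * (lr - eta_min) * (1 + math.cos(math.pi * t / T_max))

print(cosine(0))      # 1.25e-04 at the start of the stage
print(cosine(T_max))  # 6.25e-06, i.e. eta_min at the end
```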
/configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_night.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../_base_/models/yolox_x_8x8.py',
3 | '../../_base_/default_runtime.py'
4 | ]
5 |
6 | dataset_type = 'SHIFTDataset'
7 | data_root = 'data/shift/'
8 | attributes = dict(weather_coarse='clear', timeofday_coarse='night')
9 |
10 | img_scale = (800, 1440)
11 | batch_size = 2
12 |
13 | model = dict(
14 | type='AdaptiveDetector',
15 | data_preprocessor=dict(
16 | type='mmtrack.TrackDataPreprocessor',
17 | pad_size_divisor=32,
18 | batch_augments=[
19 | dict(
20 | type='mmdet.BatchSyncRandomResize',
21 | random_size_range=(576, 1024),
22 | size_divisor=32,
23 | interval=10)
24 | ]),
25 | detector=dict(
26 | _scope_='mmdet',
27 | bbox_head=dict(num_classes=6),
28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
29 | init_cfg=dict(
30 | type='Pretrained',
31 | checkpoint= # noqa: E251
32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
33 | )),
34 | adapter=None)
35 |
36 | train_pipeline = [
37 | dict(
38 | type='mmdet.Mosaic',
39 | img_scale=img_scale,
40 | pad_val=114.0,
41 | bbox_clip_border=False),
42 | dict(
43 | type='mmdet.RandomAffine',
44 | scaling_ratio_range=(0.1, 2),
45 | border=(-img_scale[0] // 2, -img_scale[1] // 2),
46 | bbox_clip_border=False),
47 | dict(
48 | type='mmdet.MixUp',
49 | img_scale=img_scale,
50 | ratio_range=(0.8, 1.6),
51 | pad_val=114.0,
52 | bbox_clip_border=False),
53 | dict(type='mmdet.YOLOXHSVRandomAug'),
54 | dict(type='mmdet.RandomFlip', prob=0.5),
55 | dict(
56 | type='mmdet.Resize',
57 | scale=img_scale,
58 | keep_ratio=True,
59 | clip_object_border=False),
60 | dict(
61 | type='mmdet.Pad',
62 | size_divisor=32,
63 | pad_val=dict(img=(114.0, 114.0, 114.0))),
64 | dict(
65 | type='mmdet.FilterAnnotations',
66 | min_gt_bbox_wh=(1, 1),
67 | keep_empty=False),
68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
69 | ]
70 | test_pipeline = [
71 | dict(type='LoadImageFromFile',
72 | backend_args=dict(
73 | backend='tar',
74 | tar_path=data_root + 'continuous/videos/1x/val/front/img_decompressed.tar',
75 | )
76 | ),
77 | dict(type='mmtrack.LoadTrackAnnotations'),
78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
79 | dict(
80 | type='mmdet.Pad',
81 | size_divisor=32,
82 | pad_val=dict(img=(114.0, 114.0, 114.0))),
83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
84 | ]
85 |
86 | train_dataset = dict(
87 | # use MultiImageMixDataset wrapper to support mosaic and mixup
88 | type='mmdet.MultiImageMixDataset',
89 | dataset=dict(
90 | type='SHIFTDataset',
91 | load_as_video=False,
92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json',
93 | data_prefix=dict(img=''),
94 | ref_img_sampler=None,
95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')),
96 | pipeline=[
97 | dict(type='LoadImageFromFile',
98 | backend_args=dict(
99 | backend='zip',
100 | zip_path=data_root + 'discrete/images/train/front/img.zip',
101 | )
102 | ),
103 | dict(type='mmtrack.LoadTrackAnnotations'),
104 | ],
105 | filter_cfg=dict(
106 | attributes=attributes,
107 | filter_empty_gt=False,
108 | min_size=32
109 | )),
110 | pipeline=train_pipeline)
111 | train_dataloader = dict(
112 | batch_size=batch_size,
113 | num_workers=4,
114 | persistent_workers=True,
115 | sampler=dict(type='DefaultSampler', shuffle=True),
116 | dataset=train_dataset)
117 |
118 | val_dataset = dict(
119 | type='SHIFTDataset',
120 | load_as_video=True,
121 | ann_file=data_root + 'continuous/videos/1x/val/front/det_2d_cocoformat.json',
122 | data_prefix=dict(img=''),
123 | ref_img_sampler=None,
124 | test_mode=True,
125 | filter_cfg=dict(attributes=attributes),
126 | pipeline=test_pipeline,
127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')))
128 | val_dataloader = dict(
129 | batch_size=1,
130 | num_workers=4,
131 | persistent_workers=True,
132 | drop_last=False,
133 | sampler=dict(type='mmtrack.VideoSampler'),
134 | dataset=val_dataset)
135 | test_dataloader = val_dataloader
136 | # optimizer
137 | # default: 8 GPUs
138 | lr = 0.0005 / 8 * batch_size
139 | optim_wrapper = dict(
140 | type='OptimWrapper',
141 | optimizer=dict(
142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True),
143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
144 |
145 | # some hyperparameters
146 | # training settings
147 | total_epochs = 12
148 | num_last_epochs = 2
149 | resume_from = None
150 | interval = 5
151 |
152 | train_cfg = dict(
153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1)
154 | val_cfg = dict(type='ValLoop')
155 | test_cfg = dict(type='TestLoop')
156 | # learning policy
157 | param_scheduler = [
158 | dict(
159 | # use quadratic formula to warm up for 1 epoch
160 | # and lr is updated by iteration
161 | type='mmdet.QuadraticWarmupLR',
162 | by_epoch=True,
163 | begin=0,
164 | end=1,
165 | convert_to_iter_based=True),
166 | dict(
167 | # use cosine lr from epoch 1 to epoch (total_epochs - num_last_epochs)
168 | type='mmdet.CosineAnnealingLR',
169 | eta_min=lr * 0.05,
170 | begin=1,
171 | T_max=total_epochs - num_last_epochs,
172 | end=total_epochs - num_last_epochs,
173 | by_epoch=True,
174 | convert_to_iter_based=True),
175 | dict(
176 | # use fixed lr during the last num_last_epochs epochs
177 | type='mmdet.ConstantLR',
178 | by_epoch=True,
179 | factor=1,
180 | begin=total_epochs - num_last_epochs,
181 | end=total_epochs,
182 | )
183 | ]
184 |
185 | custom_hooks = [
186 | dict(
187 | type='mmtrack.YOLOXModeSwitchHook',
188 | num_last_epochs=num_last_epochs,
189 | priority=48),
190 | dict(type='mmdet.SyncNormHook', priority=48),
191 | dict(
192 | type='mmdet.EMAHook',
193 | ema_type='mmdet.ExpMomentumEMA',
194 | momentum=0.0001,
195 | update_buffers=True,
196 | priority=49)
197 | ]
198 | default_hooks = dict(checkpoint=dict(interval=1))
199 |
200 | # evaluator
201 | val_evaluator = [
202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True),
203 | ]
204 | test_evaluator = val_evaluator
--------------------------------------------------------------------------------
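Note: the `backend='tar'` / `backend='zip'` entries in the load steps above refer to the custom fileio backends under shift_tta/fileio/backends; conceptually they fetch raw image bytes from inside the archive without unpacking it. A rough sketch of the idea with the standard library only (member names are illustrative, not real archive paths):

```python
import tarfile
import zipfile

# Read image bytes out of the continuous-set tar (illustrative member name).
with tarfile.open(
        'data/shift/continuous/videos/1x/val/front/img_decompressed.tar') as tf:
    member = tf.extractfile('0123-abc/00000000_img_front.jpg')
    tar_bytes = member.read() if member else None

# Same idea for the discrete-set zip archive.
with zipfile.ZipFile('data/shift/discrete/images/train/front/img.zip') as zf:
    zip_bytes = zf.read('0123-abc/00000000_img_front.jpg')
```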
/configs/source/yolox/README.md:
--------------------------------------------------------------------------------
1 | # TODO
--------------------------------------------------------------------------------
/configs/source/yolox/amp_yolox_x_8xb4-24e_shift_all.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./yolox_x_8xb4-24e_shift_all.py']
2 |
3 | # fp16 settings
4 | optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic')
5 | test_cfg = dict(type='TestLoop', fp16=True)
--------------------------------------------------------------------------------
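Note: `AmpOptimWrapper` with `loss_scale='dynamic'` corresponds to PyTorch's dynamic gradient scaler. A minimal sketch of the underlying mechanism, assuming a CUDA device (the model and optimizer are illustrative, not the repo's):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(8, 2).cuda()                      # illustrative model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # illustrative optimizer
scaler = GradScaler()                                     # dynamic loss scaling

x = torch.randn(4, 8, device='cuda')
with autocast():                 # fp16 forward pass
    loss = model(x).sum()
scaler.scale(loss).backward()    # backward on the scaled loss
scaler.step(optimizer)           # unscales grads; skips the step on inf/nan
scaler.update()                  # adjusts the scale factor for the next step
```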
/configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./yolox_x_8xb4-24e_shift_clear_daytime.py']
2 |
3 | # fp16 settings
4 | optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic')
5 | test_cfg = dict(type='TestLoop', fp16=True)
--------------------------------------------------------------------------------
/configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_night.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./yolox_x_8xb4-24e_shift_clear_night.py']
2 |
3 | # fp16 settings
4 | optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic')
5 | test_cfg = dict(type='TestLoop', fp16=True)
--------------------------------------------------------------------------------
/configs/source/yolox/amp_yolox_x_8xb4-24e_shift_daytime.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./yolox_x_8xb4-24e_shift_daytime.py']
2 |
3 | # fp16 settings
4 | optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic')
5 | test_cfg = dict(type='TestLoop', fp16=True)
--------------------------------------------------------------------------------
/configs/source/yolox/amp_yolox_x_8xb4-24e_shift_night.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./yolox_x_8xb4-24e_shift_night.py']
2 |
3 | # fp16 settings
4 | optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic')
5 | test_cfg = dict(type='TestLoop', fp16=True)
--------------------------------------------------------------------------------
/configs/source/yolox/yolox_x_8xb4-24e_shift_all.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../_base_/models/yolox_x_8x8.py',
3 | '../../_base_/default_runtime.py'
4 | ]
5 |
6 | dataset_type = 'SHIFTDataset'
7 | data_root = 'data/shift/'
8 | attributes = None
9 |
10 | img_scale = (800, 1440)
11 | batch_size = 2
12 |
13 | model = dict(
14 | type='AdaptiveDetector',
15 | data_preprocessor=dict(
16 | type='mmtrack.TrackDataPreprocessor',
17 | pad_size_divisor=32,
18 | batch_augments=[
19 | dict(
20 | type='mmdet.BatchSyncRandomResize',
21 | random_size_range=(576, 1024),
22 | size_divisor=32,
23 | interval=10)
24 | ]),
25 | detector=dict(
26 | _scope_='mmdet',
27 | bbox_head=dict(num_classes=6),
28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
29 | init_cfg=dict(
30 | type='Pretrained',
31 | checkpoint= # noqa: E251
32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
33 | )),
34 | adapter=None)
35 |
36 | train_pipeline = [
37 | dict(
38 | type='mmdet.Mosaic',
39 | img_scale=img_scale,
40 | pad_val=114.0,
41 | bbox_clip_border=False),
42 | dict(
43 | type='mmdet.RandomAffine',
44 | scaling_ratio_range=(0.1, 2),
45 | border=(-img_scale[0] // 2, -img_scale[1] // 2),
46 | bbox_clip_border=False),
47 | dict(
48 | type='mmdet.MixUp',
49 | img_scale=img_scale,
50 | ratio_range=(0.8, 1.6),
51 | pad_val=114.0,
52 | bbox_clip_border=False),
53 | dict(type='mmdet.YOLOXHSVRandomAug'),
54 | dict(type='mmdet.RandomFlip', prob=0.5),
55 | dict(
56 | type='mmdet.Resize',
57 | scale=img_scale,
58 | keep_ratio=True,
59 | clip_object_border=False),
60 | dict(
61 | type='mmdet.Pad',
62 | size_divisor=32,
63 | pad_val=dict(img=(114.0, 114.0, 114.0))),
64 | dict(
65 | type='mmdet.FilterAnnotations',
66 | min_gt_bbox_wh=(1, 1),
67 | keep_empty=False),
68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
69 | ]
70 | test_pipeline = [
71 | dict(type='LoadImageFromFile',
72 | backend_args=dict(
73 | backend='zip',
74 | zip_path=data_root + 'discrete/images/val/front/img.zip',
75 | )
76 | ),
77 | dict(type='mmtrack.LoadTrackAnnotations'),
78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
79 | dict(
80 | type='mmdet.Pad',
81 | size_divisor=32,
82 | pad_val=dict(img=(114.0, 114.0, 114.0))),
83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
84 | ]
85 |
86 | train_dataset = dict(
87 | # use MultiImageMixDataset wrapper to support mosaic and mixup
88 | type='mmdet.MultiImageMixDataset',
89 | dataset=dict(
90 | type='SHIFTDataset',
91 | load_as_video=False,
92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json',
93 | data_prefix=dict(img=''),
94 | ref_img_sampler=None,
95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')),
96 | pipeline=[
97 | dict(type='LoadImageFromFile',
98 | backend_args=dict(
99 | backend='zip',
100 | zip_path=data_root + 'discrete/images/train/front/img.zip',
101 | )
102 | ),
103 | dict(type='mmtrack.LoadTrackAnnotations'),
104 | ],
105 | filter_cfg=dict(
106 | attributes=attributes,
107 | filter_empty_gt=False,
108 | min_size=32
109 | )),
110 | pipeline=train_pipeline)
111 | train_dataloader = dict(
112 | batch_size=batch_size,
113 | num_workers=4,
114 | persistent_workers=True,
115 | sampler=dict(type='DefaultSampler', shuffle=True),
116 | dataset=train_dataset)
117 |
118 | val_dataset = dict(
119 | type='SHIFTDataset',
120 | load_as_video=False,
121 | ann_file=data_root + 'discrete/images/val/front/det_2d_cocoformat.json',
122 | data_prefix=dict(img=''),
123 | ref_img_sampler=None,
124 | test_mode=True,
125 | filter_cfg=dict(attributes=attributes),
126 | pipeline=test_pipeline,
127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')))
128 | val_dataloader = dict(
129 | batch_size=1,
130 | num_workers=4,
131 | persistent_workers=True,
132 | drop_last=False,
133 | sampler=dict(type='DefaultSampler', shuffle=False),
134 | dataset=val_dataset)
135 | test_dataloader = val_dataloader
136 | # optimizer
137 | # default: 8 GPUs
138 | lr = 0.0005 / 8 * batch_size
139 | optim_wrapper = dict(
140 | type='OptimWrapper',
141 | optimizer=dict(
142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True),
143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
144 |
145 | # some hyperparameters
146 | # training settings
147 | total_epochs = 24
148 | num_last_epochs = 2
149 | resume_from = None
150 | interval = 2
151 |
152 | train_cfg = dict(
153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1)
154 | val_cfg = dict(type='ValLoop')
155 | test_cfg = dict(type='TestLoop')
156 | # learning policy
157 | param_scheduler = [
158 | dict(
159 | # use quadratic formula to warm up for 1 epoch
160 | # and lr is updated by iteration
161 | type='mmdet.QuadraticWarmupLR',
162 | by_epoch=True,
163 | begin=0,
164 | end=1,
165 | convert_to_iter_based=True),
166 | dict(
167 | # use cosine lr from epoch 1 to epoch (total_epochs - num_last_epochs)
168 | type='mmdet.CosineAnnealingLR',
169 | eta_min=lr * 0.05,
170 | begin=1,
171 | T_max=total_epochs - num_last_epochs,
172 | end=total_epochs - num_last_epochs,
173 | by_epoch=True,
174 | convert_to_iter_based=True),
175 | dict(
176 | # use fixed lr during the last num_last_epochs epochs
177 | type='mmdet.ConstantLR',
178 | by_epoch=True,
179 | factor=1,
180 | begin=total_epochs - num_last_epochs,
181 | end=total_epochs,
182 | )
183 | ]
184 |
185 | custom_hooks = [
186 | dict(
187 | type='mmtrack.YOLOXModeSwitchHook',
188 | num_last_epochs=num_last_epochs,
189 | priority=48),
190 | dict(type='mmdet.SyncNormHook', priority=48),
191 | dict(
192 | type='mmdet.EMAHook',
193 | ema_type='mmdet.ExpMomentumEMA',
194 | momentum=0.0001,
195 | update_buffers=True,
196 | priority=49)
197 | ]
198 | default_hooks = dict(checkpoint=dict(interval=1))
199 |
200 | # evaluator
201 | val_evaluator = [
202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True),
203 | ]
204 | test_evaluator = val_evaluator
--------------------------------------------------------------------------------
/configs/source/yolox/yolox_x_8xb4-24e_shift_clear_daytime.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../_base_/models/yolox_x_8x8.py',
3 | '../../_base_/default_runtime.py'
4 | ]
5 |
6 | dataset_type = 'SHIFTDataset'
7 | data_root = 'data/shift/'
8 | attributes = dict(weather_coarse='clear', timeofday_coarse='daytime')
9 |
10 | img_scale = (800, 1440)
11 | batch_size = 2
12 |
13 | model = dict(
14 | type='AdaptiveDetector',
15 | data_preprocessor=dict(
16 | type='mmtrack.TrackDataPreprocessor',
17 | pad_size_divisor=32,
18 | batch_augments=[
19 | dict(
20 | type='mmdet.BatchSyncRandomResize',
21 | random_size_range=(576, 1024),
22 | size_divisor=32,
23 | interval=10)
24 | ]),
25 | detector=dict(
26 | _scope_='mmdet',
27 | bbox_head=dict(num_classes=6),
28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
29 | init_cfg=dict(
30 | type='Pretrained',
31 | checkpoint= # noqa: E251
32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
33 | )),
34 | adapter=None)
35 |
36 | train_pipeline = [
37 | dict(
38 | type='mmdet.Mosaic',
39 | img_scale=img_scale,
40 | pad_val=114.0,
41 | bbox_clip_border=False),
42 | dict(
43 | type='mmdet.RandomAffine',
44 | scaling_ratio_range=(0.1, 2),
45 | border=(-img_scale[0] // 2, -img_scale[1] // 2),
46 | bbox_clip_border=False),
47 | dict(
48 | type='mmdet.MixUp',
49 | img_scale=img_scale,
50 | ratio_range=(0.8, 1.6),
51 | pad_val=114.0,
52 | bbox_clip_border=False),
53 | dict(type='mmdet.YOLOXHSVRandomAug'),
54 | dict(type='mmdet.RandomFlip', prob=0.5),
55 | dict(
56 | type='mmdet.Resize',
57 | scale=img_scale,
58 | keep_ratio=True,
59 | clip_object_border=False),
60 | dict(
61 | type='mmdet.Pad',
62 | size_divisor=32,
63 | pad_val=dict(img=(114.0, 114.0, 114.0))),
64 | dict(
65 | type='mmdet.FilterAnnotations',
66 | min_gt_bbox_wh=(1, 1),
67 | keep_empty=False),
68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
69 | ]
70 | test_pipeline = [
71 | dict(type='LoadImageFromFile',
72 | backend_args=dict(
73 | backend='zip',
74 | zip_path=data_root + 'discrete/images/val/front/img.zip',
75 | )
76 | ),
77 | dict(type='mmtrack.LoadTrackAnnotations'),
78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
79 | dict(
80 | type='mmdet.Pad',
81 | size_divisor=32,
82 | pad_val=dict(img=(114.0, 114.0, 114.0))),
83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
84 | ]
85 |
86 | train_dataset = dict(
87 | # use MultiImageMixDataset wrapper to support mosaic and mixup
88 | type='mmdet.MultiImageMixDataset',
89 | dataset=dict(
90 | type='SHIFTDataset',
91 | load_as_video=False,
92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json',
93 | data_prefix=dict(img=''),
94 | ref_img_sampler=None,
95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')),
96 | pipeline=[
97 | dict(type='LoadImageFromFile',
98 | backend_args=dict(
99 | backend='zip',
100 | zip_path=data_root + 'discrete/images/train/front/img.zip',
101 | )
102 | ),
103 | dict(type='mmtrack.LoadTrackAnnotations'),
104 | ],
105 | filter_cfg=dict(
106 | attributes=attributes,
107 | filter_empty_gt=False,
108 | min_size=32
109 | )),
110 | pipeline=train_pipeline)
111 | train_dataloader = dict(
112 | batch_size=batch_size,
113 | num_workers=4,
114 | persistent_workers=True,
115 | sampler=dict(type='DefaultSampler', shuffle=True),
116 | dataset=train_dataset)
117 |
118 | val_dataset = dict(
119 | type='SHIFTDataset',
120 | load_as_video=False,
121 | ann_file=data_root + 'discrete/images/val/front/det_2d_cocoformat.json',
122 | data_prefix=dict(img=''),
123 | ref_img_sampler=None,
124 | test_mode=True,
125 | filter_cfg=dict(attributes=attributes),
126 | pipeline=test_pipeline,
127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')))
128 | val_dataloader = dict(
129 | batch_size=1,
130 | num_workers=4,
131 | persistent_workers=True,
132 | drop_last=False,
133 | sampler=dict(type='DefaultSampler', shuffle=False),
134 | dataset=val_dataset)
135 | test_dataloader = val_dataloader
136 | # optimizer
137 | # default: 8 GPUs
138 | lr = 0.0005 / 8 * batch_size
139 | optim_wrapper = dict(
140 | type='OptimWrapper',
141 | optimizer=dict(
142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True),
143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
144 |
145 | # some hyperparameters
146 | # training settings
147 | total_epochs = 24
148 | num_last_epochs = 2
149 | resume_from = None
150 | interval = 2
151 |
152 | train_cfg = dict(
153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1)
154 | val_cfg = dict(type='ValLoop')
155 | test_cfg = dict(type='TestLoop')
156 | # learning policy
157 | param_scheduler = [
158 | dict(
159 | # use quadratic formula to warm up for 1 epoch
160 | # and lr is updated by iteration
161 | type='mmdet.QuadraticWarmupLR',
162 | by_epoch=True,
163 | begin=0,
164 | end=1,
165 | convert_to_iter_based=True),
166 | dict(
167 | # use cosine lr from epoch 1 to epoch (total_epochs - num_last_epochs)
168 | type='mmdet.CosineAnnealingLR',
169 | eta_min=lr * 0.05,
170 | begin=1,
171 | T_max=total_epochs - num_last_epochs,
172 | end=total_epochs - num_last_epochs,
173 | by_epoch=True,
174 | convert_to_iter_based=True),
175 | dict(
176 |         # use fixed lr during the last num_last_epochs epochs
177 | type='mmdet.ConstantLR',
178 | by_epoch=True,
179 | factor=1,
180 | begin=total_epochs - num_last_epochs,
181 | end=total_epochs,
182 | )
183 | ]
184 |
185 | custom_hooks = [
186 | dict(
187 | type='mmtrack.YOLOXModeSwitchHook',
188 | num_last_epochs=num_last_epochs,
189 | priority=48),
190 | dict(type='mmdet.SyncNormHook', priority=48),
191 | dict(
192 | type='mmdet.EMAHook',
193 | ema_type='mmdet.ExpMomentumEMA',
194 | momentum=0.0001,
195 | update_buffers=True,
196 | priority=49)
197 | ]
198 | default_hooks = dict(checkpoint=dict(interval=1))
199 |
200 | # evaluator
201 | val_evaluator = [
202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True),
203 | ]
204 | test_evaluator = val_evaluator
--------------------------------------------------------------------------------
/configs/source/yolox/yolox_x_8xb4-24e_shift_clear_night.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../_base_/models/yolox_x_8x8.py',
3 | '../../_base_/default_runtime.py'
4 | ]
5 |
6 | dataset_type = 'SHIFTDataset'
7 | data_root = 'data/shift/'
8 | attributes = dict(weather_coarse='clear', timeofday_coarse='night')
9 |
10 | img_scale = (800, 1440)
11 | batch_size = 2
12 |
13 | model = dict(
14 | type='AdaptiveDetector',
15 | data_preprocessor=dict(
16 | type='mmtrack.TrackDataPreprocessor',
17 | pad_size_divisor=32,
18 | batch_augments=[
19 | dict(
20 | type='mmdet.BatchSyncRandomResize',
21 | random_size_range=(576, 1024),
22 | size_divisor=32,
23 | interval=10)
24 | ]),
25 | detector=dict(
26 | _scope_='mmdet',
27 | bbox_head=dict(num_classes=6),
28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
29 | init_cfg=dict(
30 | type='Pretrained',
31 | checkpoint= # noqa: E251
32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
33 | )),
34 | adapter=None)
35 |
36 | train_pipeline = [
37 | dict(
38 | type='mmdet.Mosaic',
39 | img_scale=img_scale,
40 | pad_val=114.0,
41 | bbox_clip_border=False),
42 | dict(
43 | type='mmdet.RandomAffine',
44 | scaling_ratio_range=(0.1, 2),
45 | border=(-img_scale[0] // 2, -img_scale[1] // 2),
46 | bbox_clip_border=False),
47 | dict(
48 | type='mmdet.MixUp',
49 | img_scale=img_scale,
50 | ratio_range=(0.8, 1.6),
51 | pad_val=114.0,
52 | bbox_clip_border=False),
53 | dict(type='mmdet.YOLOXHSVRandomAug'),
54 | dict(type='mmdet.RandomFlip', prob=0.5),
55 | dict(
56 | type='mmdet.Resize',
57 | scale=img_scale,
58 | keep_ratio=True,
59 | clip_object_border=False),
60 | dict(
61 | type='mmdet.Pad',
62 | size_divisor=32,
63 | pad_val=dict(img=(114.0, 114.0, 114.0))),
64 | dict(
65 | type='mmdet.FilterAnnotations',
66 | min_gt_bbox_wh=(1, 1),
67 | keep_empty=False),
68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
69 | ]
70 | test_pipeline = [
71 | dict(type='LoadImageFromFile',
72 | backend_args=dict(
73 | backend='zip',
74 | zip_path=data_root + 'discrete/images/val/front/img.zip',
75 | )
76 | ),
77 | dict(type='mmtrack.LoadTrackAnnotations'),
78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
79 | dict(
80 | type='mmdet.Pad',
81 | size_divisor=32,
82 | pad_val=dict(img=(114.0, 114.0, 114.0))),
83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
84 | ]
85 |
86 | train_dataset = dict(
87 | # use MultiImageMixDataset wrapper to support mosaic and mixup
88 | type='mmdet.MultiImageMixDataset',
89 | dataset=dict(
90 | type='SHIFTDataset',
91 | load_as_video=False,
92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json',
93 | data_prefix=dict(img=''),
94 | ref_img_sampler=None,
95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')),
96 | pipeline=[
97 | dict(type='LoadImageFromFile',
98 | backend_args=dict(
99 | backend='zip',
100 | zip_path=data_root + 'discrete/images/train/front/img.zip',
101 | )
102 | ),
103 | dict(type='mmtrack.LoadTrackAnnotations'),
104 | ],
105 | filter_cfg=dict(
106 | attributes=attributes,
107 | filter_empty_gt=False,
108 | min_size=32
109 | )),
110 | pipeline=train_pipeline)
111 | train_dataloader = dict(
112 | batch_size=batch_size,
113 | num_workers=4,
114 | persistent_workers=True,
115 | sampler=dict(type='DefaultSampler', shuffle=True),
116 | dataset=train_dataset)
117 |
118 | val_dataset=dict(
119 | type='SHIFTDataset',
120 | load_as_video=False,
121 | ann_file=data_root + 'discrete/images/val/front/det_2d_cocoformat.json',
122 | data_prefix=dict(img=''),
123 | ref_img_sampler=None,
124 | test_mode=True,
125 | filter_cfg=dict(attributes=attributes),
126 | pipeline=test_pipeline,
127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')))
128 | val_dataloader = dict(
129 | batch_size=1,
130 | num_workers=4,
131 | persistent_workers=True,
132 | drop_last=False,
133 | sampler=dict(type='DefaultSampler', shuffle=False),
134 | dataset=val_dataset)
135 | test_dataloader = val_dataloader
136 | # optimizer
137 | # default 8 gpu
138 | lr = 0.0005 / 8 * batch_size
139 | optim_wrapper = dict(
140 | type='OptimWrapper',
141 | optimizer=dict(
142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True),
143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
144 |
145 | # some hyper parameters
146 | # training settings
147 | total_epochs = 24
148 | num_last_epochs = 2
149 | resume_from = None
150 | interval = 2
151 |
152 | train_cfg = dict(
153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1)
154 | val_cfg = dict(type='ValLoop')
155 | test_cfg = dict(type='TestLoop')
156 | # learning policy
157 | param_scheduler = [
158 | dict(
159 |         # use quadratic formula to warm up for 1 epoch
160 | # and lr is updated by iteration
161 | type='mmdet.QuadraticWarmupLR',
162 | by_epoch=True,
163 | begin=0,
164 | end=1,
165 | convert_to_iter_based=True),
166 | dict(
167 |         # use cosine lr from epoch 1 to epoch (total_epochs - num_last_epochs)
168 | type='mmdet.CosineAnnealingLR',
169 | eta_min=lr * 0.05,
170 | begin=1,
171 | T_max=total_epochs - num_last_epochs,
172 | end=total_epochs - num_last_epochs,
173 | by_epoch=True,
174 | convert_to_iter_based=True),
175 | dict(
176 |         # use fixed lr during the last num_last_epochs epochs
177 | type='mmdet.ConstantLR',
178 | by_epoch=True,
179 | factor=1,
180 | begin=total_epochs - num_last_epochs,
181 | end=total_epochs,
182 | )
183 | ]
184 |
185 | custom_hooks = [
186 | dict(
187 | type='mmtrack.YOLOXModeSwitchHook',
188 | num_last_epochs=num_last_epochs,
189 | priority=48),
190 | dict(type='mmdet.SyncNormHook', priority=48),
191 | dict(
192 | type='mmdet.EMAHook',
193 | ema_type='mmdet.ExpMomentumEMA',
194 | momentum=0.0001,
195 | update_buffers=True,
196 | priority=49)
197 | ]
198 | default_hooks = dict(checkpoint=dict(interval=1))
199 |
200 | # evaluator
201 | val_evaluator = [
202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True),
203 | ]
204 | test_evaluator = val_evaluator
--------------------------------------------------------------------------------
/configs/source/yolox/yolox_x_8xb4-24e_shift_daytime.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../_base_/models/yolox_x_8x8.py',
3 | '../../_base_/default_runtime.py'
4 | ]
5 |
6 | dataset_type = 'SHIFTDataset'
7 | data_root = 'data/shift/'
8 | attributes = dict(timeofday_coarse='daytime')
9 |
10 | img_scale = (800, 1440)
11 | batch_size = 2
12 |
13 | model = dict(
14 | type='AdaptiveDetector',
15 | data_preprocessor=dict(
16 | type='mmtrack.TrackDataPreprocessor',
17 | pad_size_divisor=32,
18 | batch_augments=[
19 | dict(
20 | type='mmdet.BatchSyncRandomResize',
21 | random_size_range=(576, 1024),
22 | size_divisor=32,
23 | interval=10)
24 | ]),
25 | detector=dict(
26 | _scope_='mmdet',
27 | bbox_head=dict(num_classes=6),
28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
29 | init_cfg=dict(
30 | type='Pretrained',
31 | checkpoint= # noqa: E251
32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
33 | )),
34 | adapter=None)
35 |
36 | train_pipeline = [
37 | dict(
38 | type='mmdet.Mosaic',
39 | img_scale=img_scale,
40 | pad_val=114.0,
41 | bbox_clip_border=False),
42 | dict(
43 | type='mmdet.RandomAffine',
44 | scaling_ratio_range=(0.1, 2),
45 | border=(-img_scale[0] // 2, -img_scale[1] // 2),
46 | bbox_clip_border=False),
47 | dict(
48 | type='mmdet.MixUp',
49 | img_scale=img_scale,
50 | ratio_range=(0.8, 1.6),
51 | pad_val=114.0,
52 | bbox_clip_border=False),
53 | dict(type='mmdet.YOLOXHSVRandomAug'),
54 | dict(type='mmdet.RandomFlip', prob=0.5),
55 | dict(
56 | type='mmdet.Resize',
57 | scale=img_scale,
58 | keep_ratio=True,
59 | clip_object_border=False),
60 | dict(
61 | type='mmdet.Pad',
62 | size_divisor=32,
63 | pad_val=dict(img=(114.0, 114.0, 114.0))),
64 | dict(
65 | type='mmdet.FilterAnnotations',
66 | min_gt_bbox_wh=(1, 1),
67 | keep_empty=False),
68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
69 | ]
70 | test_pipeline = [
71 | dict(type='LoadImageFromFile',
72 | backend_args=dict(
73 | backend='zip',
74 | zip_path=data_root + 'discrete/images/val/front/img.zip',
75 | )
76 | ),
77 | dict(type='mmtrack.LoadTrackAnnotations'),
78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
79 | dict(
80 | type='mmdet.Pad',
81 | size_divisor=32,
82 | pad_val=dict(img=(114.0, 114.0, 114.0))),
83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
84 | ]
85 |
86 | train_dataset = dict(
87 | # use MultiImageMixDataset wrapper to support mosaic and mixup
88 | type='mmdet.MultiImageMixDataset',
89 | dataset=dict(
90 | type='SHIFTDataset',
91 | load_as_video=False,
92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json',
93 | data_prefix=dict(img=''),
94 | ref_img_sampler=None,
95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')),
96 | pipeline=[
97 | dict(type='LoadImageFromFile',
98 | backend_args=dict(
99 | backend='zip',
100 | zip_path=data_root + 'discrete/images/train/front/img.zip',
101 | )
102 | ),
103 | dict(type='mmtrack.LoadTrackAnnotations'),
104 | ],
105 | filter_cfg=dict(
106 | attributes=attributes,
107 | filter_empty_gt=False,
108 | min_size=32
109 | )),
110 | pipeline=train_pipeline)
111 | train_dataloader = dict(
112 | batch_size=batch_size,
113 | num_workers=4,
114 | persistent_workers=True,
115 | sampler=dict(type='DefaultSampler', shuffle=True),
116 | dataset=train_dataset)
117 |
118 | val_dataset=dict(
119 | type='SHIFTDataset',
120 | load_as_video=False,
121 | ann_file=data_root + 'discrete/images/val/front/det_2d_cocoformat.json',
122 | data_prefix=dict(img=''),
123 | ref_img_sampler=None,
124 | test_mode=True,
125 | filter_cfg=dict(attributes=attributes),
126 | pipeline=test_pipeline,
127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')))
128 | val_dataloader = dict(
129 | batch_size=1,
130 | num_workers=4,
131 | persistent_workers=True,
132 | drop_last=False,
133 | sampler=dict(type='DefaultSampler', shuffle=False),
134 | dataset=val_dataset)
135 | test_dataloader = val_dataloader
136 | # optimizer
137 | # default 8 gpu
138 | lr = 0.0005 / 8 * batch_size
139 | optim_wrapper = dict(
140 | type='OptimWrapper',
141 | optimizer=dict(
142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True),
143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
144 |
145 | # some hyper parameters
146 | # training settings
147 | total_epochs = 24
148 | num_last_epochs = 2
149 | resume_from = None
150 | interval = 2
151 |
152 | train_cfg = dict(
153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1)
154 | val_cfg = dict(type='ValLoop')
155 | test_cfg = dict(type='TestLoop')
156 | # learning policy
157 | param_scheduler = [
158 | dict(
159 |         # use quadratic formula to warm up for 1 epoch
160 | # and lr is updated by iteration
161 | type='mmdet.QuadraticWarmupLR',
162 | by_epoch=True,
163 | begin=0,
164 | end=1,
165 | convert_to_iter_based=True),
166 | dict(
167 |         # use cosine lr from epoch 1 to epoch (total_epochs - num_last_epochs)
168 | type='mmdet.CosineAnnealingLR',
169 | eta_min=lr * 0.05,
170 | begin=1,
171 | T_max=total_epochs - num_last_epochs,
172 | end=total_epochs - num_last_epochs,
173 | by_epoch=True,
174 | convert_to_iter_based=True),
175 | dict(
176 |         # use fixed lr during the last num_last_epochs epochs
177 | type='mmdet.ConstantLR',
178 | by_epoch=True,
179 | factor=1,
180 | begin=total_epochs - num_last_epochs,
181 | end=total_epochs,
182 | )
183 | ]
184 |
185 | custom_hooks = [
186 | dict(
187 | type='mmtrack.YOLOXModeSwitchHook',
188 | num_last_epochs=num_last_epochs,
189 | priority=48),
190 | dict(type='mmdet.SyncNormHook', priority=48),
191 | dict(
192 | type='mmdet.EMAHook',
193 | ema_type='mmdet.ExpMomentumEMA',
194 | momentum=0.0001,
195 | update_buffers=True,
196 | priority=49)
197 | ]
198 | default_hooks = dict(checkpoint=dict(interval=1))
199 |
200 | # evaluator
201 | val_evaluator = [
202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True),
203 | ]
204 | test_evaluator = val_evaluator
--------------------------------------------------------------------------------
/configs/source/yolox/yolox_x_8xb4-24e_shift_night.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../_base_/models/yolox_x_8x8.py',
3 | '../../_base_/default_runtime.py'
4 | ]
5 |
6 | dataset_type = 'SHIFTDataset'
7 | data_root = 'data/shift/'
8 | attributes = dict(timeofday_coarse='night')
9 |
10 | img_scale = (800, 1440)
11 | batch_size = 2
12 |
13 | model = dict(
14 | type='AdaptiveDetector',
15 | data_preprocessor=dict(
16 | type='mmtrack.TrackDataPreprocessor',
17 | pad_size_divisor=32,
18 | batch_augments=[
19 | dict(
20 | type='mmdet.BatchSyncRandomResize',
21 | random_size_range=(576, 1024),
22 | size_divisor=32,
23 | interval=10)
24 | ]),
25 | detector=dict(
26 | _scope_='mmdet',
27 | bbox_head=dict(num_classes=6),
28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
29 | init_cfg=dict(
30 | type='Pretrained',
31 | checkpoint= # noqa: E251
32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
33 | )),
34 | adapter=None)
35 |
36 | train_pipeline = [
37 | dict(
38 | type='mmdet.Mosaic',
39 | img_scale=img_scale,
40 | pad_val=114.0,
41 | bbox_clip_border=False),
42 | dict(
43 | type='mmdet.RandomAffine',
44 | scaling_ratio_range=(0.1, 2),
45 | border=(-img_scale[0] // 2, -img_scale[1] // 2),
46 | bbox_clip_border=False),
47 | dict(
48 | type='mmdet.MixUp',
49 | img_scale=img_scale,
50 | ratio_range=(0.8, 1.6),
51 | pad_val=114.0,
52 | bbox_clip_border=False),
53 | dict(type='mmdet.YOLOXHSVRandomAug'),
54 | dict(type='mmdet.RandomFlip', prob=0.5),
55 | dict(
56 | type='mmdet.Resize',
57 | scale=img_scale,
58 | keep_ratio=True,
59 | clip_object_border=False),
60 | dict(
61 | type='mmdet.Pad',
62 | size_divisor=32,
63 | pad_val=dict(img=(114.0, 114.0, 114.0))),
64 | dict(
65 | type='mmdet.FilterAnnotations',
66 | min_gt_bbox_wh=(1, 1),
67 | keep_empty=False),
68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
69 | ]
70 | test_pipeline = [
71 | dict(type='LoadImageFromFile',
72 | backend_args=dict(
73 | backend='zip',
74 | zip_path=data_root + 'discrete/images/val/front/img.zip',
75 | )
76 | ),
77 | dict(type='mmtrack.LoadTrackAnnotations'),
78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
79 | dict(
80 | type='mmdet.Pad',
81 | size_divisor=32,
82 | pad_val=dict(img=(114.0, 114.0, 114.0))),
83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True)
84 | ]
85 |
86 | train_dataset = dict(
87 | # use MultiImageMixDataset wrapper to support mosaic and mixup
88 | type='mmdet.MultiImageMixDataset',
89 | dataset=dict(
90 | type='SHIFTDataset',
91 | load_as_video=False,
92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json',
93 | data_prefix=dict(img=''),
94 | ref_img_sampler=None,
95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')),
96 | pipeline=[
97 | dict(type='LoadImageFromFile',
98 | backend_args=dict(
99 | backend='zip',
100 | zip_path=data_root + 'discrete/images/train/front/img.zip',
101 | )
102 | ),
103 | dict(type='mmtrack.LoadTrackAnnotations'),
104 | ],
105 | filter_cfg=dict(
106 | attributes=attributes,
107 | filter_empty_gt=False,
108 | min_size=32
109 | )),
110 | pipeline=train_pipeline)
111 | train_dataloader = dict(
112 | batch_size=batch_size,
113 | num_workers=4,
114 | persistent_workers=True,
115 | sampler=dict(type='DefaultSampler', shuffle=True),
116 | dataset=train_dataset)
117 |
118 | val_dataset=dict(
119 | type='SHIFTDataset',
120 | load_as_video=False,
121 | ann_file=data_root + 'discrete/images/val/front/det_2d_cocoformat.json',
122 | data_prefix=dict(img=''),
123 | ref_img_sampler=None,
124 | test_mode=True,
125 | filter_cfg=dict(attributes=attributes),
126 | pipeline=test_pipeline,
127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')))
128 | val_dataloader = dict(
129 | batch_size=1,
130 | num_workers=4,
131 | persistent_workers=True,
132 | drop_last=False,
133 | sampler=dict(type='DefaultSampler', shuffle=False),
134 | dataset=val_dataset)
135 | test_dataloader = val_dataloader
136 | # optimizer
137 | # default 8 gpu
138 | lr = 0.0005 / 8 * batch_size
139 | optim_wrapper = dict(
140 | type='OptimWrapper',
141 | optimizer=dict(
142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True),
143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
144 |
145 | # some hyper parameters
146 | # training settings
147 | total_epochs = 24
148 | num_last_epochs = 2
149 | resume_from = None
150 | interval = 2
151 |
152 | train_cfg = dict(
153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1)
154 | val_cfg = dict(type='ValLoop')
155 | test_cfg = dict(type='TestLoop')
156 | # learning policy
157 | param_scheduler = [
158 | dict(
159 |         # use quadratic formula to warm up for 1 epoch
160 | # and lr is updated by iteration
161 | type='mmdet.QuadraticWarmupLR',
162 | by_epoch=True,
163 | begin=0,
164 | end=1,
165 | convert_to_iter_based=True),
166 | dict(
167 |         # use cosine lr from epoch 1 to epoch (total_epochs - num_last_epochs)
168 | type='mmdet.CosineAnnealingLR',
169 | eta_min=lr * 0.05,
170 | begin=1,
171 | T_max=total_epochs - num_last_epochs,
172 | end=total_epochs - num_last_epochs,
173 | by_epoch=True,
174 | convert_to_iter_based=True),
175 | dict(
176 |         # use fixed lr during the last num_last_epochs epochs
177 | type='mmdet.ConstantLR',
178 | by_epoch=True,
179 | factor=1,
180 | begin=total_epochs - num_last_epochs,
181 | end=total_epochs,
182 | )
183 | ]
184 |
185 | custom_hooks = [
186 | dict(
187 | type='mmtrack.YOLOXModeSwitchHook',
188 | num_last_epochs=num_last_epochs,
189 | priority=48),
190 | dict(type='mmdet.SyncNormHook', priority=48),
191 | dict(
192 | type='mmdet.EMAHook',
193 | ema_type='mmdet.ExpMomentumEMA',
194 | momentum=0.0001,
195 | update_buffers=True,
196 | priority=49)
197 | ]
198 | default_hooks = dict(checkpoint=dict(interval=1))
199 |
200 | # evaluator
201 | val_evaluator = [
202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True),
203 | ]
204 | test_evaluator = val_evaluator
--------------------------------------------------------------------------------
/docs/challenge.md:
--------------------------------------------------------------------------------
1 | # Workshop on Visual Continual Learning @ ICCV2023
2 |
3 | # Challenge on Continual Test-time Adaptation for Object Detection
4 | ## Goal
5 | We introduce the [1st Challenge on Continual Test-time Adaptation for Object Detection](https://wvcl.vis.xyz/challenges).
6 |
7 | The goal of this challenge is to train an object detector on the SHIFT clear-daytime subset (source domain) and adapt it to the set of SHIFT sequences with continuous domain shift starting from clear-daytime conditions.
8 |
9 | ## Rules
10 | - Using additional data is **not** allowed;
11 | - Any detector architecture can be used;
12 | - The model should be adapted on the fly to each target sequence, and reset to its original state at the end of every sequence.
13 |
14 | You can find a reference implementation of an [AdaptiveDetector](shift_tta/models/detectors/adaptive_detector.py) class wrapping any object detector together with an adapter, a [BaseAdapter](shift_tta/models/adapters/base_adapter.py) class, and a reference implementation of a [mean-teacher adapter](shift_tta/models/adapters/mean_teacher_yolox_adapter.py) based on YOLOX.
15 |
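16 | To make the protocol concrete, below is a minimal sketch of the per-sequence adapt-and-reset loop implied by the rules. It is illustrative only: `predict` and `adapt` are hypothetical stand-ins, not the actual `AdaptiveDetector` API.
17 |
18 | ```python
19 | import copy
20 |
21 | def continual_adaptation(model, sequences):
22 |     """Illustrative challenge protocol: adapt online, reset per sequence."""
23 |     # Snapshot the source model so it can be restored after every sequence.
24 |     source_state = copy.deepcopy(model.state_dict())
25 |     predictions = []
26 |     for sequence in sequences:
27 |         for frame in sequence:
28 |             predictions.append(model.predict(frame))  # hypothetical inference call
29 |             model.adapt(frame)                        # hypothetical online update
30 |         # Rule: reset the model to its original state at the end of every sequence.
31 |         model.load_state_dict(source_state)
32 |     return predictions
33 | ```
34 |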
16 | ## Prize
17 | We will award the top three teams of each challenge with a certificate and a prize of 1000, 500, and 300 USD, respectively. The winners of each challenge will be invited to give a presentation at the workshop. Teams will be selected based on the performance of their methods on the test set.
18 |
19 | We will also award one team from each challenge with an Innovation Award. The Innovation Award is given to the team that proposes the most innovative method and/or insightful analysis. The winner will receive a certificate and an additional prize of 300 USD.
20 |
21 | **Please note** that this challenge is part of the track **Challenge B - Continual Test-time Adaptation**, together with the challenge on "Continuous Test-time Adaptation for Semantic Segmentation". Since the challenge on "Continuous Test-time Adaptation for Object Detection" constitutes half of track B, the prizes should be considered half of what is mentioned above.
22 |
23 | ## Instructions
24 |
25 | ### Train a model on the source domain
26 | First, train an object detection model on the source domain. You may choose any object detector architecture.
27 |
28 | You can find a reference training script at [scripts/source/train_yolox_shift_clear_daytime.sh](scripts/source/train_yolox_shift_clear_daytime.sh) to train a YOLOX model on the SHIFT clear-daytime discrete set.
29 |
30 | We use the discrete set of SHIFT to train the object detector.
31 |
32 | You can also download a YOLOX checkpoint trained with the above-mentioned script at [this link](https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth).
33 |
34 | ### Test the source model on the target domain
35 | Then, validate the source model on the validation set of the continuous target domain. In particular, we validate on the videos presenting continuous domain shift starting from the clear-daytime conditions. The validation set should be used for validating your method under continuous domain shift and for hyperparameter search.
36 |
37 | You can find a reference validation script at [scripts/continuous/no_adap_yolox/val_yolox_shift_from_clear_daytime.sh](scripts/continuous/no_adap_yolox/val_yolox_shift_from_clear_daytime.sh).
38 |
39 |
40 | ### Continuously adapt a model to the validation target domain
41 | You can now validate your test-time adaptation baseline on the validation videos presenting continuous domain shift starting from the clear-daytime conditions. The validation set should be used for validating your method under continuous domain shift and for hyperparameter search.
42 |
43 | We implemented a baseline adapter based on a detection consistency loss and a mean-teacher formulation. You can find an implementation of the adapter at [mean_teacher_yolox_adapter](shift_tta/models/adapters/mean_teacher_yolox_adapter.py), and the corresponding config file at [configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py](configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py).
44 |
45 | You can run the adaptation script on the validation set using [scripts/continuous/mean_teacher_adapter_yolox/val_yolox_shift_from_clear_daytime.sh](scripts/continuous/mean_teacher_adapter_yolox/val_yolox_shift_from_clear_daytime.sh).
46 |
47 | ### Continuously adapt a model to the test target domain
48 | Finally, collect your results on the test set and submit to our evaluation [benchmark](https://evalai.vis.xyz/web/challenges/challenge-page/6/overview).
49 |
50 | You can now test your test-time adaptation baseline on the test videos presenting continuous domain shift starting from the clear-daytime conditions.
51 |
52 | We implemented a baseline adapter based on a detection consistency loss and a mean-teacher formulation. You can find an implementation of the adapter at [mean_teacher_yolox_adapter](shift_tta/models/adapters/mean_teacher_yolox_adapter.py), and the corresponding config file at [configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py](configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py).
53 |
54 | You can run the adaptation script on the test set using [scripts/continuous/mean_teacher_adapter_yolox/test_yolox_shift_from_clear_daytime.sh](scripts/continuous/mean_teacher_adapter_yolox/test_yolox_shift_from_clear_daytime.sh).
55 |
56 |
57 | ## Submission
58 |
59 | ### Submit your results
60 | Running the above-mentioned scripts with the following `CFG_OPTIONS` stores results in the [Scalabel](https://www.scalabel.ai/) format in `${WORK_DIR}/results`:
61 |
62 | ```bash
63 | declare -a CFG_OPTIONS=(
64 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results"
65 | )
66 | ```
67 |
68 | Identify the file ending with `.scalabel.json` and submit it to our [evaluation benchmark](https://evalai.vis.xyz/web/challenges/challenge-page/6/overview) to participate in the challenge.
69 |
70 | ### Submit a technical report
71 |
72 | We require participants to submit a short report providing details on their solution to [vcl.iccvworkshop.2023@gmail.com](mailto:vcl.iccvworkshop.2023@gmail.com).
73 |
74 | Remember that we will also award one team from each challenge with an Innovation Award. The Innovation Award is given to the team that proposes the most innovative method and/or insightful analysis. The winner will receive a certificate and an additional prize of 300 USD.
75 |
76 | Optionally, participants may submit their code or open a pull request after the challenge deadline if they want their adapter to be included in this repository.
77 |
--------------------------------------------------------------------------------
/docs/dataset_prepare.md:
--------------------------------------------------------------------------------
1 | ## SHIFT Dataset Preparation
2 |
3 | This page provides the instructions for the [SHIFT](https://www.vis.xyz/shift/) dataset preparation.
4 |
5 | ### 1. Downloading the Dataset
6 |
7 | Please download the SHIFT dataset from the [official website](https://www.vis.xyz/shift/download/) to your $DATADIR. We recommend symlinking the root of the datasets to `$SHIFT_DETECTION_TTA/data`; this avoids storing large files in your project directory, which is a requirement on several high-performance computing systems.
8 |
9 | Other directories that we recommend symlinking are `checkpoints/`, `data/`, and `work_dir/`.
10 |
11 | Symlink your data directory to the `$SHIFT_DETECTION_TTA` base directory using:
12 |
13 | ```shell
14 | ln -s $DATADIR/ $SHIFT_DETECTION_TTA/
15 | ```
16 |
17 | Then, use the official [download.py](https://github.com/SysCV/shift-dev/blob/main/download.py) script provided with the SHIFT devkit to download the dataset.
18 |
19 | ```shell
20 | mkdir -p $DATADIR/shift
21 |
22 | # Download the discrete shift set for training source models
23 | python tools/shift/download.py \
24 | --view "[front]" --group "[img, det_2d]" \
25 | --split "[train, val]" --framerate "[images]" \
26 | --shift "discrete" \
27 | $DATADIR/shift
28 |
29 | # Download the continuous shift set for test-time adaptation
30 | python tools/shift/download.py \
31 | --view "[front]" --group "[img, det_2d]" \
32 | --split "[val, test]" --framerate "[videos]" \
33 | --shift "continuous/1x" \
34 | $DATADIR/shift
35 | ```
36 |
37 | #### 1.1 Data Structure
38 |
39 | We here report the recommended data structure. If your folder structure is different from the following, you may need to change the corresponding paths in the config files.
40 |
41 | ```
42 | shift-detection-tta
43 | ├── shift_tta
44 | ├── tools
45 | ├── configs
46 | ├── data
47 | │ ├── shift
48 | │ │ ├── discrete
49 | │ │ │ ├── images
50 | │ │ │ │ ├── train
51 | │ │ │ │ │ ├── front
52 | │ │ │ │ │ │ ├── img.zip
53 | │ │ │ │ │ │ ├── det_2d.json (the official annotation files)
54 | │ │ │ │ ├── val
55 | │ │ │ │ │ ├── front
56 | │ │ │ │ │ │ ├── img.zip
57 | │ │ │ │ │ │ ├── det_2d.json (the official annotation files)
57 | │ │ ├── continuous
58 | │ │ │ ├── videos
59 | │ │ │ │ ├── 1x
60 | │ │ │ │ │ ├── val
61 | │ │ │ │ │ │ ├── front
62 | │ │ │ │ │ │ │ ├── img.tar
63 | │ │ │ │ │ │ │ ├── det_2d.json (the official annotation files)
64 | │ │ │ │ │ ├── test
65 | │ │ │ │ │ │ ├── front
66 | │ │ │ │ │ │ │ ├── img.tar
67 | │ │ │ │ │ │ │ ├── det_2d.json (the official annotation files)
67 | ```
68 |
69 |
70 | ### 2. Process the Dataset
71 |
72 | ### 2.1 Decompress the Dataset
73 | To ensure reproducible decompression of videos, we recommend using the [Docker image](https://github.com/SysCV/shift-dev/blob/main/Dockerfile) from the [official SHIFT devkit](https://github.com/SysCV/shift-dev). You can refer to the Docker engine's installation documentation if you do not have Docker set up.
74 |
75 | ```shell
76 | # clone the SHIFT devkit
77 | git clone git@github.com:SysCV/shift-dev.git
78 | cd shift-dev
79 |
80 | # build and install our Docker image
81 | docker build -t shift_dataset_decompress .
82 |
83 | # run the container (the mode is set to "tar")
84 | docker run -v <path_to_data>:/data -e MODE=tar shift_dataset_decompress
85 | # Here, <path_to_data> denotes the root path under which all tar files will be processed recursively. The mode and number of jobs can be configured through the environment variables MODE and JOBS.
86 | ```
87 |
88 | The folder structure will be as follows after you run these scripts:
89 |
90 | ```
91 | shift-detection-tta
92 | ├── shift_tta
93 | ├── tools
94 | ├── configs
95 | ├── data
96 | │ ├── shift
97 | │ │ ├── discrete
98 | │ │ │ ├── images
99 | │ │ │ │ ├── train
100 | │ │ │ │ │ ├── front
101 | │ │ │ │ │ │ ├── img.zip
102 | │ │ │ │ │ │ ├── det_2d.json (the official annotation files)
103 | │ │ │ │ ├── val
104 | │ │ │ │ │ ├── front
105 | │ │ │ │ │ │ ├── img.zip
106 | │ │ │ │ │ │ ├── det_2d.json (the official annotation files)
107 | │ │ ├── continuous
108 | │ │ │ ├── videos
109 | │ │ │ │ ├── 1x
110 | │ │ │ │ │ ├── val
111 | │ │ │ │ │ │ ├── front
112 | │ │ │ │ │ │ │ ├── img.tar
113 | │ │ │ │ │ │ │ ├── img_decompressed.tar
114 | │ │ │ │ │ │ │ ├── det_2d.json (the official annotation files)
115 | │ │ │ │ │ ├── test
116 | │ │ │ │ │ │ ├── front
117 | │ │ │ │ │ │ │ ├── img.tar
118 | │ │ │ │ │ │ │ ├── img_decompressed.tar
119 | │ │ │ │ │ │ │ ├── det_2d.json (the official annotation files)
120 | ```
121 |
122 | ### 2.2 Convert Annotations
123 |
124 | We use [CocoVID](https://github.com/open-mmlab/mmtracking/blob/master/mmtrack/datasets/parsers/coco_video_parser.py) to maintain all datasets in this codebase.
125 |
126 | You therefore need to convert the official annotations to this format. We provide conversion scripts; their usage is as follows:
127 |
128 | ```shell
129 | # SHIFT discrete (images, detection-like)
130 | python -m scalabel.label.to_coco -m box_track -i $DATADIR/shift/discrete/images/$SET_NAME/front/det_2d.json -o $DATADIR/shift/discrete/images/$SET_NAME/front/det_2d_cocoformat.json
131 |
132 | # SHIFT continuous (videos, tracking-like)
133 | python -m scalabel.label.to_coco -m box_track -i $DATADIR/shift/continuous/videos/1x/$SET_NAME/front/det_2d.json -o $DATADIR/shift/continuous/videos/1x/$SET_NAME/front/det_2d_cocoformat.json
134 | ```
135 |
136 | where `$SET_NAME` is one of `[train, val, test]`.
137 |
138 |
139 | The folder structure will be as follows after you run these scripts:
140 |
141 | ```
142 | shift-detection-tta
143 | ├── shift_tta
144 | ├── tools
145 | ├── configs
146 | ├── data
147 | │ ├── shift
148 | │ │ ├── discrete
149 | │ │ │ ├── images
150 | │ │ │ │ ├── train
151 | │ │ │ │ │ ├── front
152 | │ │ │ │ │ │ ├── img.zip
153 | │ │ │ │ │ │ ├── det_2d.json (the official annotation files)
154 | │ │ │ │ │ │ ├── det_2d_cocoformat.json (the converted annotation file)
155 | │ │ │ │ ├── val
156 | │ │ │ │ │ ├── front
157 | │ │ │ │ │ │ ├── img.zip
158 | │ │ │ │ │ │ ├── det_2d.json (the official annotation files)
159 | │ │ │ │ │ │ ├── det_2d_cocoformat.json (the converted annotation file)
160 | │ │ ├── continuous
161 | │ │ │ ├── videos
162 | │ │ │ │ ├── 1x
163 | │ │ │ │ │ ├── val
164 | │ │ │ │ │ │ ├── front
165 | │ │ │ │ │ │ │ ├── img.tar
166 | │ │ │ │ │ │ │ ├── img_decompressed.tar
167 | │ │ │ │ │ │ │ ├── det_2d.json (the official annotation files)
168 | │ │ │ │ │ │ │ ├── det_2d_cocoformat.json (the converted annotation file)
169 | │ │ │ │ │ ├── test
170 | │ │ │ │ │ │ ├── front
171 | │ │ │ │ │ │ │ ├── img.tar
172 | │ │ │ │ │ │ │ ├── img_decompressed.tar
173 | │ │ │ │ │ │ │ ├── det_2d.json (the official annotation files)
174 | │ │ │ │ │ │ │ ├── det_2d_cocoformat.json (the converted annotation file)
175 | ```
176 |
177 | ### 2.3 Dataset Loading
178 | Some high-performance clusters do not support folders with a large number of files. For this reason, we implemented a [ZipBackend](shift_tta/fileio/backends/zip_backend.py) and a [TarBackend](shift_tta/fileio/backends/tar_backend.py) for loading data directly from `.zip` and `.tar` files.
179 |
180 | For usage, refer to the [`shift.py`](configs/_base_/datasets/shift.py) config file.
181 |
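182 | For illustration, the snippet below shows the `backend_args` pattern used by the configs in this repository to read images from archives. The zip loader mirrors the source-training configs; the tar backend name is an assumption by analogy (the provided scripts only override `backend_args.tar_path`), so check [tar_backend.py](shift_tta/fileio/backends/tar_backend.py) for the exact key.
183 |
184 | ```python
185 | # LoadImageFromFile reads directly from an archive when given backend_args.
186 | zip_loader = dict(
187 |     type='LoadImageFromFile',
188 |     backend_args=dict(
189 |         backend='zip',
190 |         zip_path='data/shift/discrete/images/train/front/img.zip'))
191 |
192 | tar_loader = dict(
193 |     type='LoadImageFromFile',
194 |     backend_args=dict(
195 |         backend='tar',  # assumed by analogy with the zip backend
196 |         tar_path='data/shift/continuous/videos/1x/val/front/img_decompressed.tar'))
197 | ```
198 |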
182 | **TODO**: we might have to split the dataset file into two, depending on how we want to handle testing on the discrete val set and adapting to the continuous val set.
183 |
--------------------------------------------------------------------------------
/docs/get_started.md:
--------------------------------------------------------------------------------
1 | ## Prerequisites
2 |
3 | - Linux | macOS | Windows
4 | - Python 3.6+
5 | - PyTorch 1.6+
6 | - CUDA 9.2+ (If you build PyTorch from source, CUDA 9.0 is also compatible)
7 | - GCC 5+
8 | - [MMCV](https://mmcv.readthedocs.io/en/latest/get_started/installation.html)
9 | - [MMEngine](https://mmengine.readthedocs.io/en/latest/get_started/installation.html)
10 | - [MMDetection](https://mmdetection.readthedocs.io/en/latest/get_started.html#installation)
11 | - [MMTracking](https://mmtracking.readthedocs.io/en/latest/install.html#installation)
12 |
13 | The compatible MMTracking, MMEngine, MMCV, and MMDetection versions are as below. Please install the correct version to avoid installation issues.
14 |
15 | | MMTracking version | MMEngine version | MMCV version | MMDetection version |
16 | | :----------------: | :--------------: | :--------------------: | :---------------------: |
17 | | 1.x | mmengine>=0.1.0 | mmcv>=2.0.0rc1,\<2.0.0 | mmdet>=3.0.0rc0,\<3.0.0 |
18 | | 1.0.0rc1 | mmengine>=0.1.0 | mmcv>=2.0.0rc1,\<2.0.0 | mmdet>=3.0.0rc0,\<3.0.0 |
19 |
20 | ## Installation
21 |
22 | ### Detailed Instructions
23 |
24 | 1. Create a conda virtual environment and activate it.
25 |
26 | ```shell
27 | conda create -n shift-tta python=3.9 -y
28 | conda activate shift-tta
29 | ```
30 |
31 | 2. Install PyTorch and torchvision following the [official instructions](https://pytorch.org/). Here we use PyTorch 1.11.0 and CUDA 11.3.
32 | You may also switch to another version by specifying the version number.
33 |
34 | **Install with conda**
35 |
36 | ```shell
37 | conda install pytorch=1.11.0 torchvision cudatoolkit=11.3 -c pytorch
38 | ```
39 |
40 | **Install with pip**
41 |
42 | ```shell
43 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
44 | ```
45 |
46 | 3. Install MMEngine
47 |
48 | ```shell
49 | pip install mmengine
50 | ```
51 |
52 | 4. Install MMCV. We recommend installing the pre-built package as below.
53 |
54 | ```shell
55 | pip install 'mmcv>=2.0.0rc1' -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html
56 | ```
57 |
58 | mmcv is only compiled on PyTorch 1.x.0 because the compatibility usually holds between 1.x.0 and 1.x.1. If your PyTorch version is 1.x.1, you can install mmcv compiled with PyTorch 1.x.0 and it usually works well.
59 |
60 | ```shell
61 | # We can ignore the micro version of PyTorch
62 | pip install 'mmcv>=2.0.0rc1' -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html
63 | ```
64 |
65 | See [here](https://mmcv.readthedocs.io/en/latest/get_started/installation.html) for different versions of MMCV compatible with different PyTorch and CUDA versions.
66 | Optionally, you can compile mmcv from source with the following commands:
67 |
68 | ```shell
69 | git clone -b 2.x https://github.com/open-mmlab/mmcv.git
70 | cd mmcv
71 | MMCV_WITH_OPS=1 pip install -e . # package mmcv, which contains cuda ops, will be installed after this step
72 | # pip install -e . # package mmcv, which contains no cuda ops, will be installed after this step
73 | cd ..
74 | ```
75 |
76 | **Important**: You need to run `pip uninstall mmcv-lite` first if you have mmcv-lite installed, because if mmcv-lite and mmcv are both installed, there will be a `ModuleNotFoundError`.
77 |
78 | 5. Install MMDetection
79 |
80 | ```shell
81 | pip install 'mmdet>=3.0.0rc0'
82 | ```
83 |
84 | Optionally, you can also build MMDetection from source in case you want to modify the code:
85 |
86 | ```shell
87 | git clone -b 3.x https://github.com/open-mmlab/mmdetection.git
88 | cd mmdetection
89 | pip install -r requirements/build.txt
90 | pip install -v -e . # or "python setup.py develop"
91 | ```
92 |
93 | 6. Install MMTracking
94 |
95 | ```shell
96 | pip install 'mmtrack>=1.0.0rc1'
97 | ```
98 |
99 | Optionally, you can also build MMTracking from source in case you want to modify the code:
100 |
101 | ```shell
102 | git clone -b 1.x https://github.com/open-mmlab/mmtracking.git
103 | cd mmtracking
104 | pip install -r requirements/build.txt
105 | pip install -v -e . # or "python setup.py develop"
106 | ```
107 |
108 | 7. Clone the shift-detection-tta repository.
109 |
110 | ```shell
111 | git clone git@github.com:SysCV/shift-detection-tta.git
112 | cd shift-detection-tta
113 | ```
114 |
115 | 8. Install build requirements and then install shift-detection-tta.
116 |
117 | ```shell
118 | pip install -r requirements/build.txt
119 | pip install -v -e . # or "python setup.py develop"
120 | ```
121 |
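122 | You can optionally verify the installation by importing the main packages. A minimal check, assuming each package exposes `__version__` (the OpenMMLab packages do; `shift_tta` defines its version in `shift_tta/version.py`):
123 |
124 | ```python
125 | # Print the installed versions of the core dependencies and of shift_tta.
126 | import mmcv
127 | import mmdet
128 | import mmengine
129 | import mmtrack
130 | import shift_tta
131 |
132 | for pkg in (mmengine, mmcv, mmdet, mmtrack, shift_tta):
133 |     print(f'{pkg.__name__}: {pkg.__version__}')
134 | ```
135 |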
122 | Note:
123 |
124 | a. Following the above instructions, shift-detection-tta is installed in `dev` mode;
125 | any local modifications made to the code will take effect without the need to reinstall it.
126 |
127 | b. If you would like to use `opencv-python-headless` instead of `opencv-python`,
128 | you can install it before installing MMCV.
129 |
130 | ### A from-scratch setup script
131 |
132 | Assuming that you already have CUDA 11.3 installed, here is a full script for setting up shift-detection-tta with conda.
133 |
134 | ```shell
135 | conda create -n shift-tta python=3.9 -y
136 | conda activate shift-tta
137 |
138 | conda install pytorch=1.11.0 torchvision cudatoolkit=11.3 -c pytorch -y
139 |
140 | pip install -U openmim
141 | # install mmengine from main branch
142 | python -m pip install git+https://github.com/open-mmlab/mmengine.git@62f9504d701251db763f56658436fd23a586fe25
143 | # install mmcv
144 | mim install 'mmcv == 2.0.0rc4'
145 | # install mmdetection
146 | mim install 'mmdet == 3.0.0rc5'
147 | # install mmclassification from dev-1.x branch at specific commit
148 | python -m pip install git+https://github.com/open-mmlab/mmclassification.git@3ff80f5047fe3f3780a05d387f913dd02999611d
149 | # install mmtracking from dev-1.x branch at specific commit
150 | python -m pip install git+https://github.com/open-mmlab/mmtracking.git@9e4cb98a3cdac749242cd8decb3a172058d4fd6e
151 |
152 | # install trackeval for compatibility with mmtrack
153 | python -m pip install git+https://github.com/JonathonLuiten/TrackEval.git
154 | # install scalabel
155 | python -m pip install git+https://github.com/scalabel/scalabel.git
156 |
157 | # install shift-detection-tta
158 | git clone git@github.com:SysCV/shift-detection-tta.git
159 | cd shift-detection-tta
160 | python -m pip install --no-input -r requirements.txt
161 | pip install --no-input -v -e .
162 | ```
163 |
164 |
165 | Alternatively (and recommended), clone the repository and directly run the install script [setup_venv.sh](tools/install/setup_venv.sh):
166 |
167 | ```shell
168 | git clone git@github.com:SysCV/shift-detection-tta.git
169 | cd shift-detection-tta
170 | ./tools/install/setup_env.sh
171 | ```
172 |
173 |
--------------------------------------------------------------------------------
/docs/model_zoo.md:
--------------------------------------------------------------------------------
1 | # Model Zoo
2 |
3 | ## Continual test-time adaptation for object detection
4 | ### No Adaptation
5 |
6 | Please refer to [no_adap_yolox](configs/continuous/no_adap_yolox) for details.
7 |
8 | ### Mean-teacher
9 |
10 | Please refer to [mean_teacher_adapter_yolox](configs/continuous/mean_teacher_adapter_yolox) for details.
11 |
--------------------------------------------------------------------------------
/docs/train_test.md:
--------------------------------------------------------------------------------
1 | # Learn to train and test
2 |
3 | ## Train
4 |
5 | This section will show how to train existing models on supported datasets.
6 | The following training environments are supported:
7 |
8 | - CPU
9 | - single GPU
10 | - single node multiple GPUs
11 | - multiple nodes
12 |
13 | You can also manage jobs with Slurm.
14 |
15 | Important:
16 |
17 | - You can change the evaluation interval during training by modifying the `train_cfg` as
18 | `train_cfg = dict(val_interval=10)`. That means evaluating the model every 10 epochs.
19 | - The default learning rate in all config files is for 8 GPUs.
20 |   According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677),
21 |   you need to set the learning rate proportional to the total batch size if you use a different number of GPUs or images per GPU,
22 |   e.g., `lr=0.01` for 8 GPUs * 1 img/gpu and `lr=0.04` for 16 GPUs * 2 imgs/gpu (see the short example after this list).
23 | - During training, log files and checkpoints will be saved to the working directory,
24 | which is specified by CLI argument `--work-dir`. It uses `./work_dirs/CONFIG_NAME` as default.
25 | - If you want the mixed precision training, simply specify CLI argument `--amp`.
26 |
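27 | As a short worked example of the scaling rule above (the numbers are the ones from the note; `scale_lr` is just an illustrative helper, not part of this codebase):
28 |
29 | ```python
30 | def scale_lr(base_lr: float, base_total_batch: int, gpus: int, imgs_per_gpu: int) -> float:
31 |     """Linear scaling rule: lr is proportional to the total batch size."""
32 |     return base_lr * (gpus * imgs_per_gpu) / base_total_batch
33 |
34 | # Base lr 0.01 was set for 8 GPUs * 1 img/gpu (total batch size 8).
35 | print(scale_lr(0.01, base_total_batch=8, gpus=8, imgs_per_gpu=1))   # -> 0.01
36 | print(scale_lr(0.01, base_total_batch=8, gpus=16, imgs_per_gpu=2))  # -> 0.04
37 | ```
38 |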
27 | #### 1. Train on CPU
28 |
29 | By default, the model is put on a CUDA device;
30 | only if there are no CUDA devices will it be put on the CPU.
31 | So if you want to train the model on the CPU, you need to run `export CUDA_VISIBLE_DEVICES=-1` to disable GPU visibility first.
32 | More details in [MMEngine](https://github.com/open-mmlab/mmengine/blob/ca282aee9e402104b644494ca491f73d93a9544f/mmengine/runner/runner.py#L849-L850).
33 |
34 | ```shell script
35 | CUDA_VISIBLE_DEVICES=-1 python tools/train.py ${CONFIG_FILE} [optional arguments]
36 | ```
37 |
38 | An example of training the YOLOX detector on CPU:
39 |
40 | ```shell script
41 | CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py
42 | ```
43 |
44 | #### 2. Train on single GPU
45 |
46 | If you want to train the model on a single GPU, you can directly use `tools/train.py` as follows.
47 |
48 | ```shell script
49 | python tools/train.py ${CONFIG_FILE} [optional arguments]
50 | ```
51 |
52 | You can use `export CUDA_VISIBLE_DEVICES=$GPU_ID` to select the GPU.
53 |
54 | An example of training the YOLOX detector on a single GPU:
55 |
56 | ```shell script
57 | CUDA_VISIBLE_DEVICES=2 python tools/train.py configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py
58 | ```
59 |
60 | #### 3. Train on single node multiple GPUs
61 |
62 | We provide `tools/dist_train.sh` to launch training on multiple GPUs.
63 | The basic usage is as follows.
64 |
65 | ```shell script
66 | bash ./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]
67 | ```
68 |
69 | If you would like to launch multiple jobs on a single machine,
70 | e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs,
71 | you need to specify different ports (29500 by default) for each job to avoid communication conflict.
72 |
73 | For example, you can set the port in commands as follows.
74 |
75 | ```shell script
76 | CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4
77 | CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4
78 | ```
79 |
80 | An example of training the YOLOX detector on a single node with multiple GPUs:
81 |
82 | ```shell script
83 | bash ./tools/dist_train.sh configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py 8
84 | ```
85 |
86 | #### 4. Train on multiple nodes
87 |
88 | If you launch with multiple machines connected via Ethernet, you can simply run the following commands:
89 |
90 | On the first machine:
91 |
92 | ```shell script
93 | NNODES=2 NODE_RANK=0 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR bash tools/dist_train.sh $CONFIG $GPUS
94 | ```
95 |
96 | On the second machine:
97 |
98 | ```shell script
99 | NNODES=2 NODE_RANK=1 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR bash tools/dist_train.sh $CONFIG $GPUS
100 | ```
101 |
102 | Usually it is slow if you do not have high-speed networking like InfiniBand.
103 |
104 | #### 5. Train with Slurm
105 |
106 | [Slurm](https://slurm.schedmd.com/) is a good job scheduling system for computing clusters.
107 | On a cluster managed by Slurm, you can use `slurm_train.sh` to spawn training jobs.
108 | It supports both single-node and multi-node training.
109 |
110 | The basic usage is as follows.
111 |
112 | ```shell script
113 | bash ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR} ${GPUS}
114 | ```
115 |
116 | An example of training the YOLOX detector on SHIFT clear-daytime with Slurm:
117 |
118 | ```shell script
119 | PORT=29501 \
120 | GPUS_PER_NODE=8 \
121 | SRUN_ARGS="--quotatype=reserved" \
122 | bash ./tools/slurm_train.sh \
123 | mypartition \
124 | YOLOX_shift_clear_daytime \
125 | configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py \
126 | ./work_dirs/YOLOX_shift_clear_daytime \
127 | 8
128 | ```
129 |
130 | ## Test
131 |
132 | This section will show how to test existing models on supported datasets.
133 | The following testing environments are supported:
134 |
135 | - CPU
136 | - single GPU
137 | - single node multiple GPUs
138 | - multiple nodes
139 |
140 | You can also manage jobs with Slurm.
141 |
142 | Important:
143 |
144 | - You can set the results saving path by modifying the key `outfile_prefix` in the evaluator.
145 |   For example, `val_evaluator = dict(outfile_prefix='results/YOLOX_shift_from_clear_daytime')`.
146 |   Otherwise, a temporary file will be created and removed after evaluation.
147 | - If you just want the formatted results without evaluation, you can set `format_only=True`.
148 |   For example, `test_evaluator = dict(type='SHIFTVideoMetric', metric='bbox', outfile_prefix='results/YOLOX_shift_from_clear_daytime', format_only=True)` (see the sketch after this list).
149 |
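150 | Combining the two options above, here is a sketch of evaluator configs in this repository's style; `SHIFTVideoMetric` is the metric used by the provided configs, and the exact set of supported keys is defined in `shift_tta/evaluation/metrics/shift_video_metrics.py`.
151 |
152 | ```python
153 | # Save evaluation results under a chosen prefix.
154 | val_evaluator = [
155 |     dict(
156 |         type='SHIFTVideoMetric',
157 |         metric=['bbox'],
158 |         classwise=True,
159 |         outfile_prefix='results/YOLOX_shift_from_clear_daytime'),
160 | ]
161 |
162 | # Only write the formatted results, skipping evaluation.
163 | test_evaluator = [
164 |     dict(
165 |         type='SHIFTVideoMetric',
166 |         metric=['bbox'],
167 |         outfile_prefix='results/YOLOX_shift_from_clear_daytime',
168 |         format_only=True),
169 | ]
170 | ```
171 |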
150 | #### 1. Test on CPU
151 |
152 | By default, the model is put on a CUDA device;
153 | only if there are no CUDA devices will it be put on the CPU.
154 | So if you want to test the model on the CPU, you need to run `export CUDA_VISIBLE_DEVICES=-1` to disable GPU visibility first.
155 | More details in [MMEngine](https://github.com/open-mmlab/mmengine/blob/ca282aee9e402104b644494ca491f73d93a9544f/mmengine/runner/runner.py#L849-L850).
156 |
157 | ```shell script
158 | CUDA_VISIBLE_DEVICES=-1 python tools/test.py ${CONFIG_FILE} [optional arguments]
159 | ```
160 |
161 | An example of testing the YOLOX clear-daytime model on CPU:
162 |
163 | ```shell script
164 | CUDA_VISIBLE_DEVICES=-1 python tools/test.py configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py --checkpoint checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
165 | ```
166 |
167 | #### 2. Test on single GPU
168 |
169 | If you want to test the model on a single GPU, you can directly use `tools/test.py` as follows.
170 |
171 | ```shell script
172 | python tools/test.py ${CONFIG_FILE} [optional arguments]
173 | ```
174 |
175 | You can use `export CUDA_VISIBLE_DEVICES=$GPU_ID` to select the GPU.
176 |
177 | An example of testing the YOLOX clear-daytime model on single GPU:
178 |
179 | ```shell script
180 | CUDA_VISIBLE_DEVICES=2 python tools/test.py configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py --checkpoint checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
181 | ```
182 |
183 | #### 3. Test on single node multiple GPUs
184 |
185 | We provide `tools/dist_test.sh` to launch testing on multiple GPUs.
186 | The basic usage is as follows.
187 |
188 | ```shell script
189 | bash ./tools/dist_test.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]
190 | ```
191 |
192 | An example of testing the YOLOX clear-daytime model on single node multiple GPUs:
193 |
194 | ```shell script
195 | bash ./tools/dist_test.sh configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py 8 --checkpoint checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
196 | ```
197 |
198 | #### 4. Test on multiple nodes
199 |
200 | You can test on multiple nodes, which is similar to "Train on multiple nodes".
201 |
202 | #### 5. Test with Slurm
203 |
204 | On a cluster managed by Slurm, you can use `slurm_test.sh` to spawn testing jobs.
205 | It supports both single-node and multi-node testing.
206 |
207 | The basic usage is as follows.
208 |
209 | ```shell script
210 | bash ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${GPUS}
211 | ```
212 |
213 | An example of testing the YOLOX clear-daytime model with Slurm:
214 |
215 | ```shell script
216 | PORT=29501 \
217 | GPUS_PER_NODE=8 \
218 | SRUN_ARGS="--quotatype=reserved" \
219 | bash ./tools/slurm_test.sh \
220 | mypartition \
221 | YOLOX_clear_daytime \
222 | configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py \
223 | 8 \
224 | --checkpoint checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
225 | ```
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements/build.txt
2 | -r requirements/runtime.txt
--------------------------------------------------------------------------------
/requirements/build.txt:
--------------------------------------------------------------------------------
1 | cython
2 | numpy
--------------------------------------------------------------------------------
/requirements/mminstall.txt:
--------------------------------------------------------------------------------
1 | mmcls>=1.0.0rc0
2 | mmcv>=2.0.0rc1
3 | mmdet>=3.0.0rc0
4 | mmengine>=0.1.0
5 | mmtrack>=1.0.0rc1
--------------------------------------------------------------------------------
/requirements/runtime.txt:
--------------------------------------------------------------------------------
1 | attributee
2 | lap
3 | matplotlib
4 | mmcls>=1.0.0rc0
5 | motmetrics
6 | packaging
7 | pandas<=1.3.5
8 | pycocotools
9 | pydantic==1.9.1
10 | pytest
11 | scikit-learn
12 | scikit-image<=0.19.3
13 | scipy<=1.7.3
14 | seaborn
15 | tabulate
16 | terminaltables
17 | tqdm
18 |
--------------------------------------------------------------------------------
/resources/shift-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/shift-detection-tta/a5c4fb0e906caa57c0b572d651f24f01ff6bdcdf/resources/shift-logo.png
--------------------------------------------------------------------------------
/scripts/continuous/mean_teacher_adapter_yolox/slurm_test_yolox_shift_from_clear_daytime.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | TIME=12:00:00 # TIME=(24:00:00)
3 | PARTITION=gpu22 # PARTITION=(gpu16 | gpu20 | gpu22)
4 | GPUS_TYPE=a40 # GPUS_TYPE=(Quadro_RTX_8000 | a40 | a100)
5 | GPUS=4
6 | CPUS=16
7 | MEM_PER_CPU=22000
8 | SBATCH_ARGS=${SBATCH_ARGS:-""}
9 | # derive per-node resources from the settings above
10 | GPUS_PER_NODE=${GPUS}
11 | CPUS_PER_TASK=${CPUS}
13 |
14 | ###############
15 | ##### Your args
16 | CONFIG=configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py
17 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
18 | WORK_DIR=work_dirs/continuous/mean_teacher_yolox/test/yolox_x_8xb4-24e_shift_from_clear_daytime
19 | JOB_NAME=yolox_x_8xb4-24e_shift_from_clear_daytime  # job name used by the sbatch call below
19 |
20 | declare -a CFG_OPTIONS=(
21 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results"
22 | "test_dataloader.dataset.ann_file=data/shift/continuous/videos/1x/test/front/det_2d_cocoformat.json"
23 | "test_dataloader.dataset.pipeline.0.backend_args.tar_path=data/shift/continuous/videos/1x/test/front/img_decompressed.tar"
24 | )
25 | #####
26 | ###############
27 |
28 | if [ $GPUS -gt 1 ]
29 | then
30 | CMD=tools/dist_test.sh
31 | else
32 | CMD=tools/test.sh
33 | fi
34 |
35 | echo "Launching ${CMD} on ${GPUS} gpus."
36 | echo "Starting job ${JOB_NAME} from ${CONFIG} using --cfg-options ${CFG_OPTIONS[*]}"
37 |
38 | mkdir -p errors/
39 | mkdir -p outputs/
40 |
41 | ID=$(sbatch \
42 | --parsable \
43 | -t ${TIME} \
44 | --job-name=${JOB_NAME} \
45 | -p ${PARTITION} \
46 | --gres=gpu:${GPUS_TYPE}:${GPUS_PER_NODE} \
47 | -e errors/%j.log \
48 | -o outputs/%j.log \
49 | --mail-type=BEGIN,END,FAIL \
50 | ${SBATCH_ARGS} \
51 | ${CMD} \
52 | ${CONFIG} \
53 | ${GPUS} \
54 | --checkpoint ${CKPT} \
55 | --cfg-options ${CFG_OPTIONS[@]})
56 |
57 |
--------------------------------------------------------------------------------
/scripts/continuous/mean_teacher_adapter_yolox/test_yolox_shift_from_clear_daytime.sh:
--------------------------------------------------------------------------------
1 | CONFIG=configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py
2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
3 | WORK_DIR=work_dirs/continuous/mean_teacher_yolox/test/yolox_x_8xb4-24e_shift_from_clear_daytime
4 |
5 | declare -a CFG_OPTIONS=(
6 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results"
7 | "test_dataloader.dataset.ann_file=data/shift/continuous/videos/1x/test/front/det_2d_cocoformat.json"
8 | "test_dataloader.dataset.pipeline.0.backend_args.tar_path=data/shift/continuous/videos/1x/test/front/img_decompressed.tar"
9 | )
10 |
11 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \
12 | python tools/test.py \
13 | ${CONFIG} \
14 | --checkpoint ${CKPT} \
15 | --cfg-options ${CFG_OPTIONS[@]}
16 |
--------------------------------------------------------------------------------
/scripts/continuous/mean_teacher_adapter_yolox/val_yolox_shift_from_clear_daytime.sh:
--------------------------------------------------------------------------------
1 | CONFIG=configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py
2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
3 | WORK_DIR=work_dirs/continuous/mean_teacher_yolox/val/yolox_x_8xb4-24e_shift_from_clear_daytime
4 |
5 | declare -a CFG_OPTIONS=(
6 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results"
7 | )
8 |
9 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \
10 | python tools/test.py \
11 | ${CONFIG} \
12 | --checkpoint ${CKPT} \
13 | --cfg-options ${CFG_OPTIONS[@]}
14 |
--------------------------------------------------------------------------------
/scripts/continuous/no_adap_yolox/test_yolox_shift_from_clear_daytime.sh:
--------------------------------------------------------------------------------
1 | CONFIG=configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py
2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
3 | WORK_DIR=work_dirs/continuous/no_adap_yolox/test/yolox_x_8xb4-24e_shift_from_clear_daytime
4 |
5 | declare -a CFG_OPTIONS=(
6 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results"
7 | "test_dataloader.dataset.ann_file=data/shift/continuous/videos/1x/test/front/det_2d_cocoformat.json"
8 | "test_dataloader.dataset.pipeline.0.backend_args.tar_path=data/shift/continuous/videos/1x/test/front/img_decompressed.tar"
9 | )
10 |
11 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \
12 | python tools/test.py \
13 | ${CONFIG} \
14 | --checkpoint ${CKPT} \
15 | --work-dir ${WORK_DIR} \
16 |     --cfg-options "${CFG_OPTIONS[@]}"
--------------------------------------------------------------------------------
/scripts/continuous/no_adap_yolox/test_yolox_shift_from_clear_night.sh:
--------------------------------------------------------------------------------
1 | CONFIG=configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_night.py
2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
3 | WORK_DIR=work_dirs/continuous/no_adap_yolox/test/yolox_x_8xb4-24e_shift_from_clear_night
4 |
5 | declare -a CFG_OPTIONS=(
6 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results"
7 | "test_dataloader.dataset.ann_file=data/shift/continuous/videos/1x/test/front/det_2d_cocoformat.json"
8 | "test_dataloader.dataset.pipeline.0.backend_args.tar_path=data/shift/continuous/videos/1x/test/front/img_decompressed.tar"
9 | )
10 |
11 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \
12 | python tools/test.py \
13 |     ${CONFIG} \
14 |     --checkpoint ${CKPT} \
15 |     --work-dir ${WORK_DIR} \
16 |     --cfg-options "${CFG_OPTIONS[@]}"
--------------------------------------------------------------------------------
/scripts/continuous/no_adap_yolox/val_yolox_shift_from_clear_daytime.sh:
--------------------------------------------------------------------------------
1 | CONFIG=configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py
2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
3 | WORK_DIR=work_dirs/continuous/no_adap_yolox/val/yolox_x_8xb4-24e_shift_from_clear_daytime
4 |
5 | declare -a CFG_OPTIONS=(
6 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results"
7 | )
8 |
9 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \
10 | python tools/test.py \
11 | ${CONFIG} \
12 | --checkpoint ${CKPT} \
13 | --work-dir ${WORK_DIR} \
14 |     --cfg-options "${CFG_OPTIONS[@]}"
--------------------------------------------------------------------------------
/scripts/continuous/no_adap_yolox/val_yolox_shift_from_clear_night.sh:
--------------------------------------------------------------------------------
1 | CONFIG=configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_night.py
2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
3 | WORK_DIR=work_dirs/continuous/no_adap_yolox/val/yolox_x_8xb4-24e_shift_from_clear_night
4 |
5 | declare -a CFG_OPTIONS=(
6 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results"
7 | )
8 |
9 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \
10 | python tools/test.py \
11 |     ${CONFIG} \
12 |     --checkpoint ${CKPT} \
13 |     --work-dir ${WORK_DIR} \
14 |     --cfg-options "${CFG_OPTIONS[@]}"
--------------------------------------------------------------------------------
/scripts/source/slurm_train_yolox_24e_shift_clear_daytime.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | TIME=12:00:00 # TIME=(24:00:00)
3 | PARTITION=gpu22 # PARTITION=(gpu16 | gpu20 | gpu22)
4 | GPUS_TYPE=a40 # GPUS_TYPE=(Quadro_RTX_8000 | a40 | a100)
5 | GPUS=4
6 | CPUS=16
7 | MEM_PER_CPU=22000
8 | SBATCH_ARGS=${SBATCH_ARGS:-""}
9 |
10 | # Per-node resources mirror the single-node request above.
11 | GPUS_PER_NODE=${GPUS}
12 | CPUS_PER_TASK=${CPUS}
13 |
14 | ###############
15 | ##### Your args
16 | CONFIG=configs/source/yolox/yolox_x_8xb4-24e_shift_clear_daytime.py
17 | declare -a CFG_OPTIONS=(
18 | "data.workers_per_gpu=4"
19 | "data.samples_per_gpu=4"
20 | )
21 | JOB_NAME=${JOB_NAME:-$(basename ${CONFIG} .py)}  # was echoed/used below but never set
22 | #####
23 | ###############
24 | if [ $GPUS -gt 1 ]
25 | then
26 | CMD=tools/dist_train.sh
27 | else
28 | CMD=tools/train.sh
29 | fi
30 |
31 | echo "Launching ${CMD} on ${GPUS} gpus."
32 | echo "Starting job ${JOB_NAME} from ${CONFIG} using --cfg-options ${CFG_OPTIONS[*]}"
33 |
34 | mkdir -p errors/
35 | mkdir -p outputs/
36 |
37 | ID=$(sbatch \
38 | --parsable \
39 | -t ${TIME} \
40 | --job-name=${JOB_NAME} \
41 | -p ${PARTITION} \
42 | --gres=gpu:${GPUS_TYPE}:${GPUS_PER_NODE} \
43 | -e errors/%j.log \
44 | -o outputs/%j.log \
45 | --mail-type=BEGIN,END,FAIL \
46 | ${SBATCH_ARGS} \
47 | ${CMD} \
48 | ${CONFIG} \
49 | ${GPUS} \
50 |     --cfg-options "${CFG_OPTIONS[@]}")
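51 |
52 | # Note: the data.* keys above follow the legacy MMDet 2.x config schema; with
53 | # the MMEngine-style configs used elsewhere in this repo, the equivalent
54 | # overrides would be train_dataloader.num_workers and
55 | # train_dataloader.batch_size.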
--------------------------------------------------------------------------------
/scripts/source/slurm_train_yolox_24e_shift_daytime.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | TIME=12:00:00 # TIME=(24:00:00)
3 | PARTITION=gpu22 # PARTITION=(gpu16 | gpu20 | gpu22)
4 | GPUS_TYPE=a40 # GPUS_TYPE=(Quadro_RTX_8000 | a40 | a100)
5 | GPUS=4
6 | CPUS=16
7 | MEM_PER_CPU=22000
8 | SBATCH_ARGS=${SBATCH_ARGS:-""}
9 |
10 | # Per-node resources mirror the single-node request above.
11 | GPUS_PER_NODE=${GPUS}
12 | CPUS_PER_TASK=${CPUS}
13 |
14 | ###############
15 | ##### Your args
16 | CONFIG=configs/source/yolox/yolox_x_8xb4-24e_shift_daytime.py
17 | declare -a CFG_OPTIONS=(
18 | "data.workers_per_gpu=4"
19 | "data.samples_per_gpu=4"
20 | )
21 | JOB_NAME=${JOB_NAME:-$(basename ${CONFIG} .py)}  # was echoed/used below but never set
22 | #####
23 | ###############
24 | if [ $GPUS -gt 1 ]
25 | then
26 | CMD=tools/dist_train.sh
27 | else
28 | CMD=tools/train.sh
29 | fi
30 |
31 | echo "Launching ${CMD} on ${GPUS} gpus."
32 | echo "Starting job ${JOB_NAME} from ${CONFIG} using --cfg-options ${CFG_OPTIONS[*]}"
33 |
34 | mkdir -p errors/
35 | mkdir -p outputs/
36 |
37 | ID=$(sbatch \
38 | --parsable \
39 | -t ${TIME} \
40 | --job-name=${JOB_NAME} \
41 | -p ${PARTITION} \
42 | --gres=gpu:${GPUS_TYPE}:${GPUS_PER_NODE} \
43 | -e errors/%j.log \
44 | -o outputs/%j.log \
45 | --mail-type=BEGIN,END,FAIL \
46 | ${SBATCH_ARGS} \
47 | ${CMD} \
48 | ${CONFIG} \
49 | ${GPUS} \
50 |     --cfg-options "${CFG_OPTIONS[@]}")
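51 |
52 | # --parsable makes sbatch print only the job ID; echo it for log lookup
53 | # (a small convenience addition):
54 | echo "Submitted job ${ID} (outputs/${ID}.log)"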
--------------------------------------------------------------------------------
/scripts/source/test_yolox_shift_clear_daytime.sh:
--------------------------------------------------------------------------------
1 | CONFIG=configs/source/yolox/yolox_x_8xb4-24e_shift_clear_daytime.py
2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
3 |
4 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \
5 | python tools/test.py \
6 | ${CONFIG} \
7 | --checkpoint ${CKPT}
--------------------------------------------------------------------------------
/scripts/source/test_yolox_shift_clear_night.sh:
--------------------------------------------------------------------------------
1 | CONFIG=configs/source/yolox/yolox_x_8xb4-24e_shift_clear_night.py
2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth
3 |
4 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \
5 | python tools/test.py \
6 | ${CONFIG} \
7 | --checkpoint ${CKPT}
--------------------------------------------------------------------------------
/scripts/source/train_yolox_shift_clear_daytime.sh:
--------------------------------------------------------------------------------
1 | # CONFIG=configs/source/yolox/yolox_x_8xb4-24e_shift_clear_daytime.py
2 | CONFIG=configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py
3 |
4 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/train.py \
5 | python tools/train.py \
6 | ${CONFIG}
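7 |
8 | # The amp_ config variant trains with automatic mixed precision (AMP);
9 | # switch to the commented CONFIG above for full-precision training.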
--------------------------------------------------------------------------------
/scripts/tmp/edit_json.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | ann_file = 'data/shift/continuous/videos/1x/val/front/det_2d_cocoformat.json'
4 | dest_file = 'data/shift/continuous/videos/1x/val/front/det_2d_cocoformat_tmp.json'
5 |
6 | with open(ann_file, 'r') as fp:
7 | anns = json.load(fp)
8 |
9 | for im in anns['images']:
10 | im['file_name'] = im['file_name'].replace('_img_front.jpg', '.jpg')
11 |
12 | with open(dest_file, 'w') as fp:
13 | json.dump(anns, fp)
14 |
15 |
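16 | # Example of the rename performed above (illustrative file name):
17 | #   '00000010_img_front.jpg' -> '00000010.jpg'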
--------------------------------------------------------------------------------
/scripts/tmp/edit_json_ids.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | ann_file = 'data/shift/continuous/videos/1x/test/front/det_2d_cocoformat.json'
4 | dest_file = 'data/shift/continuous/videos/1x/test/front/det_2d_cocoformat_tmp.json'
5 |
6 | with open(ann_file, 'r') as fp:
7 | anns = json.load(fp)
8 |
9 | print('Starting conversion')
10 | for im in anns['images']:
11 | im['file_name'] = im['file_name'].replace(str(im['frame_id']).zfill(8), str(im['frame_id']//10).zfill(8))
12 | im['frame_id'] = im['frame_id'] // 10
13 | print('Conversion done')
14 |
15 | with open(dest_file, 'w') as fp:
16 | json.dump(anns, fp)
17 |
18 |
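19 | # Example of the renumbering above (frame ids appear downsampled by 10x):
20 | #   frame_id 20, file '00000020.jpg' -> frame_id 2, file '00000002.jpg'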
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | line_length = 79
3 | multi_line_output = 0
4 | extra_standard_library = setuptools
5 | known_first_party = shift_tta
6 | known_third_party = PIL,addict,cv2,lap,matplotlib,mmcls,mmcv,mmdet,mmtrack,motmetrics,numpy,packaging,pandas,pycocotools,pytest,pytorch_sphinx_theme,requests,scipy,script_utils,seaborn,tao,terminaltables,torch,tqdm
7 | no_lines_before = STDLIB,LOCALFOLDER
8 | default_section = THIRDPARTY
9 |
10 | [yapf]
11 | BASED_ON_STYLE = pep8
12 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
13 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
14 |
15 | [codespell]
16 | ignore-words-list = mot,gool
17 | skip = *.json
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import shutil
4 | import sys
5 | import warnings
6 | from setuptools import find_packages, setup
7 |
8 | import torch
9 | from torch.utils.cpp_extension import (BuildExtension, CppExtension,
10 | CUDAExtension)
11 |
12 |
13 | def readme():
14 | with open('README.md', encoding='utf-8') as f:
15 | content = f.read()
16 | return content
17 |
18 |
19 | version_file = 'shift_tta/version.py'
20 |
21 |
22 | def get_version():
23 | with open(version_file, 'r') as f:
24 | exec(compile(f.read(), version_file, 'exec'))
25 | return locals()['__version__']
26 |
27 |
28 | def make_cuda_ext(name, module, sources, sources_cuda=[]):
29 |
30 | define_macros = []
31 | extra_compile_args = {'cxx': []}
32 |
33 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
34 | define_macros += [('WITH_CUDA', None)]
35 | extension = CUDAExtension
36 | extra_compile_args['nvcc'] = [
37 | '-D__CUDA_NO_HALF_OPERATORS__',
38 | '-D__CUDA_NO_HALF_CONVERSIONS__',
39 | '-D__CUDA_NO_HALF2_OPERATORS__',
40 | ]
41 | sources += sources_cuda
42 | else:
43 | print(f'Compiling {name} without CUDA')
44 | extension = CppExtension
45 |
46 | return extension(
47 | name=f'{module}.{name}',
48 | sources=[os.path.join(*module.split('.'), p) for p in sources],
49 | define_macros=define_macros,
50 | extra_compile_args=extra_compile_args)
51 |
52 |
53 | def parse_requirements(fname='requirements.txt', with_version=True):
54 | """Parse the package dependencies listed in a requirements file but strips
55 | specific versioning information.
56 |
57 | Args:
58 | fname (str): path to requirements file
59 |         with_version (bool, default=True): if True include version specs
60 |
61 | Returns:
62 | List[str]: list of requirements items
63 |
64 | CommandLine:
65 | python -c "import setup; print(setup.parse_requirements())"
66 | """
67 | import re
68 | import sys
69 | from os.path import exists
70 | require_fpath = fname
71 |
72 | def parse_line(line):
73 | """Parse information from a line in a requirements text file."""
74 | if line.startswith('-r '):
75 | # Allow specifying requirements in other files
76 | target = line.split(' ')[1]
77 | for info in parse_require_file(target):
78 | yield info
79 | else:
80 | info = {'line': line}
81 | if line.startswith('-e '):
82 | info['package'] = line.split('#egg=')[1]
83 | elif '@git+' in line:
84 | info['package'] = line
85 | else:
86 | # Remove versioning from the package
87 | pat = '(' + '|'.join(['>=', '==', '>']) + ')'
88 | parts = re.split(pat, line, maxsplit=1)
89 | parts = [p.strip() for p in parts]
90 |
91 | info['package'] = parts[0]
92 | if len(parts) > 1:
93 | op, rest = parts[1:]
94 | if ';' in rest:
95 | # Handle platform specific dependencies
96 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
97 | version, platform_deps = map(str.strip,
98 | rest.split(';'))
99 | info['platform_deps'] = platform_deps
100 | else:
101 | version = rest # NOQA
102 | info['version'] = (op, version)
103 | yield info
104 |
105 | def parse_require_file(fpath):
106 | with open(fpath, 'r') as f:
107 | for line in f.readlines():
108 | line = line.strip()
109 | if line and not line.startswith('#'):
110 | for info in parse_line(line):
111 | yield info
112 |
113 | def gen_packages_items():
114 | if exists(require_fpath):
115 | for info in parse_require_file(require_fpath):
116 | parts = [info['package']]
117 | if with_version and 'version' in info:
118 | parts.extend(info['version'])
119 | if not sys.version.startswith('3.4'):
120 | # apparently package_deps are broken in 3.4
121 | platform_deps = info.get('platform_deps')
122 | if platform_deps is not None:
123 | parts.append(';' + platform_deps)
124 | item = ''.join(parts)
125 | yield item
126 |
127 | packages = list(gen_packages_items())
128 | return packages
129 |
130 |
131 | def add_mim_extension():
132 | """Add extra files that are required to support MIM into the package.
133 |
134 | These files will be added by creating a symlink to the originals if the
135 | package is installed in `editable` mode (e.g. pip install -e .), or by
136 | copying from the originals otherwise.
137 | """
138 |
139 | # parse installment mode
140 | if 'develop' in sys.argv:
141 | # installed by `pip install -e .`
142 | mode = 'symlink'
143 | elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv:
144 | # installed by `pip install .`
145 | # or create source distribution by `python setup.py sdist`
146 | mode = 'copy'
147 | else:
148 | return
149 |
150 | filenames = ['tools', 'configs', 'model-index.yml']
151 | repo_path = osp.dirname(__file__)
152 | mim_path = osp.join(repo_path, 'shift_tta', '.mim')
153 | os.makedirs(mim_path, exist_ok=True)
154 |
155 | for filename in filenames:
156 | if osp.exists(filename):
157 | src_path = osp.join(repo_path, filename)
158 | tar_path = osp.join(mim_path, filename)
159 |
160 | if osp.isfile(tar_path) or osp.islink(tar_path):
161 | os.remove(tar_path)
162 | elif osp.isdir(tar_path):
163 | shutil.rmtree(tar_path)
164 |
165 | if mode == 'symlink':
166 | src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
167 | try:
168 | os.symlink(src_relpath, tar_path)
169 | except OSError:
170 | # Creating a symbolic link on windows may raise an
171 | # `OSError: [WinError 1314]` due to privilege. If
172 | # the error happens, the src file will be copied
173 | mode = 'copy'
174 | warnings.warn(
175 | f'Failed to create a symbolic link for {src_relpath}, '
176 | f'and it will be copied to {tar_path}')
177 | else:
178 | continue
179 |
180 | if mode == 'copy':
181 | if osp.isfile(src_path):
182 | shutil.copyfile(src_path, tar_path)
183 | elif osp.isdir(src_path):
184 | shutil.copytree(src_path, tar_path)
185 | else:
186 | warnings.warn(f'Cannot copy file {src_path}.')
187 | else:
188 | raise ValueError(f'Invalid mode {mode}')
189 |
190 |
191 | if __name__ == '__main__':
192 | add_mim_extension()
193 | setup(
194 | name='shift-tta',
195 | version=get_version(),
196 | description='Test-time adaptation platform for object detection from the VIS group at ETH Zurich',
197 | long_description=readme(),
198 | long_description_content_type='text/markdown',
199 | author='Mattia Segu',
200 | author_email='mattia.segu@vision.ee.ethz.ch',
201 | keywords='computer vision, object detection, test-time adaptation',
202 | url='https://github.com/SysCV/shift-detection-tta',
203 | packages=find_packages(exclude=('configs', 'tools', 'demo')),
204 | include_package_data=True,
205 | classifiers=[
206 | 'Development Status :: 4 - Beta',
207 | 'License :: OSI Approved :: MIT License',
208 | 'Operating System :: OS Independent',
209 | 'Programming Language :: Python :: 3',
210 | 'Programming Language :: Python :: 3.6',
211 | 'Programming Language :: Python :: 3.7',
212 | 'Programming Language :: Python :: 3.8',
213 | 'Programming Language :: Python :: 3.9',
214 | ],
215 | license='MIT License',
216 | install_requires=parse_requirements('requirements/runtime.txt'),
217 | extras_require={
218 | 'all': parse_requirements('requirements.txt'),
219 | 'build': parse_requirements('requirements/build.txt'),
220 | 'mim': parse_requirements('requirements/mminstall.txt')
221 | },
222 | ext_modules=[],
223 | cmdclass={'build_ext': BuildExtension},
224 | zip_safe=False)
--------------------------------------------------------------------------------
/shift_tta/.mim/configs:
--------------------------------------------------------------------------------
1 | ../../configs
--------------------------------------------------------------------------------
/shift_tta/.mim/tools:
--------------------------------------------------------------------------------
1 | ../../tools
--------------------------------------------------------------------------------
/shift_tta/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 |
4 | import mmcv
5 | import mmdet
6 | import mmtrack
7 | from packaging.version import parse
8 |
9 | from .version import __version__, version_info
10 |
11 | MMCV_MIN = '2.0.0rc1'
12 | MMCV_MAX = '2.0.0'
13 |
14 | MMDET_MIN = '3.0.0rc0'
15 | MMDET_MAX = '3.0.0'
16 |
17 | MMTRACK_MIN = '1.0.0rc1'
18 | MMTRACK_MAX = '1.0.0'
19 |
20 |
21 | def digit_version(version_str: str, length: int = 4):
22 | """Convert a version string into a tuple of integers.
23 |
24 | This method is usually used for comparing two versions. For pre-release
25 | versions: alpha < beta < rc.
26 |
27 | Args:
28 | version_str (str): The version string.
29 | length (int): The maximum number of version levels. Default: 4.
30 |
31 | Returns:
32 | tuple[int]: The version info in digits (integers).
33 | """
34 | version = parse(version_str)
35 | assert version.release, f'failed to parse version {version_str}'
36 | release = list(version.release)
37 | release = release[:length]
38 | if len(release) < length:
39 | release = release + [0] * (length - len(release))
40 | if version.is_prerelease:
41 | mapping = {'a': -3, 'b': -2, 'rc': -1}
42 | val = -4
43 | # version.pre can be None
44 | if version.pre:
45 | if version.pre[0] not in mapping:
46 | warnings.warn(f'unknown prerelease version {version.pre[0]}, '
47 | 'version checking may go wrong')
48 | else:
49 | val = mapping[version.pre[0]]
50 | release.extend([val, version.pre[-1]])
51 | else:
52 | release.extend([val, 0])
53 |
54 | elif version.is_postrelease:
55 | release.extend([1, version.post])
56 | else:
57 | release.extend([0, 0])
58 | return tuple(release)
59 |
60 |
61 | mmcv_min_version = digit_version(MMCV_MIN)
62 | mmcv_max_version = digit_version(MMCV_MAX)
63 | mmcv_version = digit_version(mmcv.__version__)
64 |
65 |
66 | assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \
67 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \
68 | f'Please install mmcv>={MMCV_MIN}, <{MMCV_MAX}.'
69 |
70 | mmdet_min_version = digit_version(MMDET_MIN)
71 | mmdet_max_version = digit_version(MMDET_MAX)
72 | mmdet_version = digit_version(mmdet.__version__)
73 |
74 |
75 | assert (mmdet_min_version <= mmdet_version < mmdet_max_version), \
76 | f'MMDet=={mmdet.__version__} is used but incompatible. ' \
77 | f'Please install mmdet>={MMDET_MIN}, <{MMDET_MAX}.'
78 |
79 | mmtrack_min_version = digit_version(MMTRACK_MIN)
80 | mmtrack_max_version = digit_version(MMTRACK_MAX)
81 | mmtrack_version = digit_version(mmtrack.__version__)
82 |
83 |
84 | assert (mmtrack_min_version <= mmtrack_version < mmtrack_max_version), \
85 | f'MMTrack=={mmtrack.__version__} is used but incompatible. ' \
86 | f'Please install mmtrack>={MMTRACK_MIN}, <{MMTRACK_MAX}.'
87 |
88 | __all__ = ['__version__', 'version_info', 'digit_version']
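89 |
90 | # e.g. digit_version('2.0.0rc1') -> (2, 0, 0, 0, -1, 1) while
91 | # digit_version('2.0.0') -> (2, 0, 0, 0, 0, 0), so tuple comparison ranks
92 | # release candidates below the corresponding final release.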
--------------------------------------------------------------------------------
/shift_tta/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .shift_dataset import SHIFTDataset
2 | from .utils import check_attributes
3 |
4 | __all__ = [
5 | 'SHIFTDataset',
6 | 'check_attributes',
7 | ]
--------------------------------------------------------------------------------
/shift_tta/datasets/shift_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from mmengine.fileio import FileClient
4 | from mmtrack.datasets import BaseVideoDataset
5 | from mmtrack.datasets.api_wrappers import CocoVID
6 |
7 | from shift_tta.registry import DATASETS
8 |
9 | from .utils import check_attributes
10 |
11 |
12 | @DATASETS.register_module()
13 | class SHIFTDataset(BaseVideoDataset):
14 | """Dataset class for SHIFT.
15 | Args:
16 | attributes (Optional[Dict[str, ...]]): a dictionary containing the
17 | allowed attributes. Dataset samples will be filtered based on
18 | the allowed attributes. If None, load all samples. Default: None.
19 | """
20 |
21 | METAINFO = dict(
22 |         classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')
23 | )
24 |
25 | def __init__(self,
26 | attributes=None,
27 | *args,
28 | **kwargs):
29 | self.attributes = attributes
30 | super().__init__(*args, **kwargs)
31 |
32 | def filter_by_attributes(self):
33 | """Filter annotations according to filter_cfg.attributes.
34 |
35 | Returns:
36 | list[int]: Filtered results.
37 | """
38 | if self.load_as_video:
39 | valid_data_indices = self._filter_video_by_attributes()
40 | else:
41 | valid_data_indices = self._filter_image_by_attributes()
42 |
43 | return valid_data_indices
44 |
45 | def _filter_video_by_attributes(self):
46 | """Filter video annotations according to filter_cfg.attributes.
47 |
48 | Annotations are filtered based on the attributes of the first
49 | frame in the video.
50 |
51 | Returns:
52 | list[int]: Filtered results.
53 | """
54 | file_client = FileClient.infer_client(uri=self.ann_file)
55 | with file_client.get_local_path(self.ann_file) as local_path:
56 | coco = CocoVID(local_path)
57 |
58 | valid_data_indices = []
59 | data_id = 0
60 | vid_ids = coco.get_vid_ids()
61 | for vid_id in vid_ids:
62 | img_ids = coco.get_img_ids_from_vid(vid_id)
63 | if not len(img_ids) > 0:
64 | continue
65 | raw_img_info = coco.load_imgs([img_ids[0]])[0]
66 | if check_attributes(
67 | raw_img_info['attributes'], self.filter_cfg['attributes']):
68 | valid_data_indices.extend(
69 | list(range(data_id, data_id + len(img_ids))))
70 | data_id += len(img_ids)
71 |
72 | set_valid_data_indices = set(self.valid_data_indices)
73 | valid_data_indices = [
74 | id for id in valid_data_indices if id in set_valid_data_indices
75 | ]
76 | return valid_data_indices
77 |
78 | def _filter_image_by_attributes(self):
79 | """Filter image annotations according to filter_cfg.attributes.
80 |
81 | Returns:
82 | list[int]: Filtered results.
83 | """
84 | valid_data_indices = []
85 | for i, data_info in enumerate(self.data_list):
86 | img_id = data_info['img_id']
87 | if check_attributes(
88 | data_info['attributes'], self.filter_cfg['attributes']
89 | ):
90 | valid_data_indices.append(i)
91 |
92 | set_valid_data_indices = set(self.valid_data_indices)
93 | valid_data_indices = [
94 | id for id in valid_data_indices if id in set_valid_data_indices
95 | ]
96 | return valid_data_indices
97 |
98 | def filter_data(self) -> List[int]:
99 | """Filter annotations according to filter_cfg.
100 |
101 | Returns:
102 | list[int]: Filtered results.
103 | """
104 | # filter data by attributes (useful for domain filtering)
105 | if self.filter_cfg is not None:
106 | self.valid_data_indices = self.filter_by_attributes()
107 |
108 | return super().filter_data()
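109 |
110 | # Example filter_cfg wiring (attribute keys are illustrative, not taken from
111 | # the SHIFT annotations):
112 | #   filter_cfg=dict(attributes=dict(weather_coarse='clear',
113 | #                                   timeofday_coarse='daytime'))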
--------------------------------------------------------------------------------
/shift_tta/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .filters import check_attributes
2 |
3 | __all__ = ['check_attributes']
--------------------------------------------------------------------------------
/shift_tta/datasets/utils/filters.py:
--------------------------------------------------------------------------------
1 | """Dataset filtering utils."""
2 | from typing import Any, Dict, List, Optional, Union
3 |
4 |
5 | def _check_attributes(
6 | attributes: Union[bool, float, str],
7 | allowed_attributes: Union[bool, float, str, List[float], List[str]],
8 | ) -> bool:
9 | """Check if attributes are allowed.
10 | Args:
11 | attributes: Attributes of current frame.
12 | allowed_attributes: Attributes allowed.
13 | Return:
14 | boolean, whether frame attributes are allowed.
15 | """
16 | if isinstance(allowed_attributes, list):
17 |         # a list means any of the listed values is allowed
18 | return attributes in allowed_attributes
19 | return attributes == allowed_attributes
20 |
21 |
22 | def check_attributes(attributes, allowed_attributes=None):
23 | """Check if a dictionary of attributes is allowed.
24 | Args:
25 | attributes (Dict[str, str]): attributes to check
26 | allowed_attributes (Dict[str, List[str]]): allowed attributes
27 | Return:
28 | boolean, whether frame attributes are allowed.
29 | """
30 | check = True
31 | if allowed_attributes:
32 | for key in allowed_attributes:
33 | allowed_attribute = allowed_attributes[key]
34 | check = check and _check_attributes(
35 | attributes[key], allowed_attribute
36 | )
37 | if not check:
38 | return check
39 | return check
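40 |
41 | # Minimal usage sketch (attribute keys illustrative):
42 | #   check_attributes({'weather': 'rainy', 'timeofday': 'night'},
43 | #                    {'weather': ['clear', 'rainy']})  # -> True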
--------------------------------------------------------------------------------
/shift_tta/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import * # noqa: F401,F403
--------------------------------------------------------------------------------
/shift_tta/evaluation/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .shift_video_metrics import SHIFTVideoMetric
2 |
3 |
4 | __all__ = ['SHIFTVideoMetric']
--------------------------------------------------------------------------------
/shift_tta/evaluation/metrics/shift_video_metrics.py:
--------------------------------------------------------------------------------
1 |
2 | import warnings
3 | from typing import Optional, Sequence
4 |
5 | from mmdet.datasets.api_wrappers import COCO
6 | from mmdet.evaluation import CocoMetric
7 | from mmdet.structures.mask import encode_mask_results
8 | from mmengine.dist import broadcast_object_list, is_main_process
9 | from mmengine.fileio import FileClient, dump, load
10 | from mmtrack.evaluation.metrics import CocoVideoMetric
11 |
12 | from scalabel.label.typing import Dataset, Frame, Label, Box3D, RLE
13 | from scalabel.label.transforms import bbox_to_box2d, coco_rle_to_rle, polygon_to_poly2ds
14 |
15 | from shift_tta.registry import METRICS
16 |
17 |
18 | @METRICS.register_module()
19 | class SHIFTVideoMetric(CocoVideoMetric):
20 | """SHIFT evaluation metric.
21 |
22 | Wraps CocoVideoMetric to implement dumping to coco json format so that
23 | it is compatible for conversion to the scalabel format.
24 |
25 | Args:
26 | to_scalabel (bool): Whether to dump results to a Scalabel-style json
27 | file. Defaults to True.
28 | """
29 | def __init__(self, to_scalabel: bool = True, **kwargs) -> None:
30 | super().__init__(**kwargs)
31 | self.to_scalabel = to_scalabel
32 |
33 | def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
34 | """Process one batch of data samples and predictions. The processed
35 | results should be stored in ``self.results``, which will be used to
36 | compute the metrics when all batches have been processed.
37 |
38 | Note that we only modify ``pred['pred_instances']`` in ``CocoMetric``
39 | to ``pred['pred_det_instances']`` here.
40 |
41 | Args:
42 | data_batch (dict): A batch of data from the dataloader.
43 | data_samples (Sequence[dict]): A batch of data samples that
44 | contain annotations and predictions.
45 | """
46 | for data_sample in data_samples:
47 | result = dict()
48 | pred = data_sample['pred_det_instances']
49 | result['img_id'] = data_sample['img_id']
50 | result['img_name'] = data_sample['img_path'].split('/')[-1]
51 | result['video_name'] = data_sample['img_path'].split('/')[-2]
52 | result['bboxes'] = pred['bboxes'].cpu().numpy()
53 | result['scores'] = pred['scores'].cpu().numpy()
54 | result['labels'] = pred['labels'].cpu().numpy()
55 | # encode mask to RLE
56 | if 'masks' in pred:
57 | result['masks'] = encode_mask_results(
58 | pred['masks'].detach().cpu().numpy())
59 | # some detectors use different scores for bbox and mask
60 | if 'mask_scores' in pred:
61 | result['mask_scores'] = pred['mask_scores'].cpu().numpy()
62 |
63 | # parse gt
64 | gt = dict()
65 | gt['width'] = data_sample['ori_shape'][1]
66 | gt['height'] = data_sample['ori_shape'][0]
67 | gt['img_id'] = data_sample['img_id']
68 | if self._coco_api is None:
69 | assert 'instances' in data_sample, \
70 | 'ground truth is required for evaluation when ' \
71 | '`ann_file` is not provided'
72 | gt['anns'] = data_sample['instances']
73 | # add converted result to the results list
74 | self.results.append((gt, result))
75 |
76 | def results2scalabel(self, results: Sequence[dict],
77 | outfile_prefix: str) -> dict:
78 | """Dump the detection results to a Scalabel style json file.
79 |
80 | There are 3 types of results: proposals, bbox predictions, mask
81 | predictions, and they have different data types. This method will
82 | automatically recognize the type, and dump them to json files.
83 |
84 | Args:
85 | results (Sequence[dict]): Testing results of the
86 | dataset.
87 | outfile_prefix (str): The filename prefix of the json files. If the
88 | prefix is "somepath/xxx", the json files will be named
89 | "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
90 | "somepath/xxx.proposal.json".
91 |
92 | Returns:
93 | dict: Possible keys are "bbox", "segm", "proposal", and
94 | values are corresponding filenames.
95 | """
96 | bbox_frames = []
97 | segm_frames = [] if 'masks' in results[0] else None
98 | for idx, result in enumerate(results):
99 | image_id = result.get('img_id', idx)
100 | bboxes = result['bboxes']
101 | scores = result['scores']
102 | # bbox results
103 | labels = []
104 | for i, label in enumerate(result['labels']):
105 | label = Label(
106 | id=i,
107 | # category=self.cat_ids[label], # check if this is text cat
108 | category=self.dataset_meta['classes'][label],
109 | box2d=bbox_to_box2d(self.xyxy2xywh(bboxes[i])),
110 | score=float(scores[i]),
111 | )
112 | labels.append(label)
113 |
114 | bbox_frame = Frame(
115 | name=result["img_name"],
116 | videoName=result["video_name"],
117 |                 frameIndex=int(result["img_name"].split('.')[0].split('_')[0]),
118 | labels=labels,
119 | )
120 | bbox_frames.append(bbox_frame)
121 |
122 | if segm_frames is None:
123 | continue
124 |
125 | # segm results
126 | masks = result['masks']
127 | mask_scores = result.get('mask_scores', scores)
128 | labels = []
129 | for i, label in enumerate(result['labels']):
130 | if isinstance(masks[i]['counts'], bytes):
131 | masks[i]['counts'] = masks[i]['counts'].decode()
132 |
133 | label = Label(
134 |                     id=i,
135 | category=self.dataset_meta['classes'][label],
136 | box2d=bbox_to_box2d(self.xyxy2xywh(bboxes[i])),
137 | score=float(mask_scores[i])
138 | )
139 | if isinstance(masks[i], list):
140 | label.poly2d = polygon_to_poly2ds(masks[i])
141 | else:
142 | label.rle = coco_rle_to_rle(masks[i])
143 | labels.append(label)
144 | segm_frame = Frame(
145 | name=result["img_name"],
146 | videoName=result["video_name"],
147 |                 frameIndex=int(result["img_name"].split('.')[0].split('_')[0]),
148 | labels=labels,
149 | )
150 | segm_frames.append(segm_frame)
151 |
152 | bbox_ds = Dataset(frames=bbox_frames, groups=None, config=None)
153 |
154 | result_files = dict()
155 | result_files['bbox'] = f'{outfile_prefix}.bbox.scalabel.json'
156 | result_files['proposal'] = f'{outfile_prefix}.bbox.scalabel.json'
157 | with open(result_files['bbox'], "w") as f:
158 | f.write(bbox_ds.json(exclude_unset=True))
159 |
160 | if segm_frames is not None:
161 | segm_ds = Dataset(frames=segm_frames, groups=None, config=None)
162 | result_files['segm'] = f'{outfile_prefix}.segm.scalabel.json'
163 | with open(result_files['segm'], "w") as f:
164 | f.write(segm_ds.json(exclude_unset=True))
165 |
166 | return result_files
167 |
168 | def results2json(self, results: Sequence[dict],
169 | outfile_prefix: str) -> dict:
170 | """Dump the detection results to a COCO / Scalabel style json file.
171 |
172 | There are 3 types of results: proposals, bbox predictions, mask
173 | predictions, and they have different data types. This method will
174 | automatically recognize the type, and dump them to json files.
175 |
176 | Args:
177 | results (Sequence[dict]): Testing results of the
178 | dataset.
179 | outfile_prefix (str): The filename prefix of the json files. If the
180 | prefix is "somepath/xxx", the json files will be named
181 | "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
182 | "somepath/xxx.proposal.json".
183 |
184 | Returns:
185 | dict: Possible keys are "bbox", "segm", "proposal", and
186 | values are corresponding filenames.
187 | """
188 |
189 | if self.to_scalabel:
190 | _ = self.results2scalabel(results, outfile_prefix)
191 |
192 | return super().results2json(results, outfile_prefix)
193 |
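194 | # Note: frameIndex above is recovered from the image file name, which is
195 | # assumed to start with a zero-padded frame id (e.g. '00000001_img_front.jpg').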
--------------------------------------------------------------------------------
/shift_tta/fileio/__init__.py:
--------------------------------------------------------------------------------
1 | from .backends import TarBackend, ZipBackend
2 |
3 |
4 | __all__ = [
5 | 'TarBackend',
6 | 'ZipBackend'
7 | ]
--------------------------------------------------------------------------------
/shift_tta/fileio/backends/__init__.py:
--------------------------------------------------------------------------------
1 | from .tar_backend import TarBackend
2 | from .zip_backend import ZipBackend
3 |
4 | __all__ = [
5 | 'TarBackend',
6 | 'ZipBackend'
7 | ]
--------------------------------------------------------------------------------
/shift_tta/fileio/backends/tar_backend.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Union
3 | from tarfile import TarFile
4 |
5 | import os
6 |
7 | from mmengine.fileio.backends import BaseStorageBackend, register_backend
8 |
9 |
10 | @register_backend('tar')
11 | class TarBackend(BaseStorageBackend):
12 | """Backend for loading data from .tar files.
13 |
14 | This backend works with filepaths pointing to valid .tar files. We assume
15 | that the given .tar file contains the whole dataset associated to this
16 | backend.
17 | """
18 |
19 | def __init__(self, tar_path='', **kwargs):
20 | self.tar_path = str(tar_path)
21 | self._client = None
22 |
23 | def get(self, filepath: Union[str, Path]) -> bytes:
24 | """Get values according to the filepath.
25 |
26 | Args:
27 | filepath (str or Path): Here, filepath is the tar key.
28 |
29 | Returns:
30 | bytes: Expected bytes object.
31 |
32 | Examples:
33 | >>> backend = TarBackend('path/to/tar')
34 | >>> backend.get('key')
35 | b'hello world'
36 | """
37 | if self._client is None:
38 | self._client = self._get_client()
39 |
40 | filepath = str(filepath)
41 | try:
42 | with self._client.extractfile(filepath) as data:
43 | data = data.read()
44 | except KeyError as e:
45 | raise ValueError(
46 | f"Value '{filepath}' not found in {self._client}!") from e
47 | return data
48 |
49 | def get_text(self, filepath, encoding=None):
50 | raise NotImplementedError
51 |
52 | def _get_client(self) -> TarFile:
53 | """Get Tar client.
54 |
55 | Returns:
56 | TarFile: the tar file.
57 | """
58 |
59 | if not os.path.exists(self.tar_path):
60 | raise FileNotFoundError(
61 | f"Corresponding tar file not found:" f" {self.tar_path}")
62 |
63 | return TarFile(self.tar_path)
64 |
65 | def __del__(self):
66 | if self._client is not None:
67 | self._client.close()
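68 |
69 | # Usage sketch: a load transform in a dataset pipeline can select this
70 | # backend through backend_args, as the test scripts in this repo do, e.g.
71 | #   backend_args=dict(backend='tar', tar_path='data/shift/continuous/videos/1x/test/front/img_decompressed.tar')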
--------------------------------------------------------------------------------
/shift_tta/fileio/backends/zip_backend.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from abc import abstractmethod
4 | from zipfile import ZipFile
5 |
6 | from pathlib import Path
7 | from typing import Literal, Union
8 |
9 | from mmengine.fileio.backends import BaseStorageBackend, register_backend
10 |
11 |
12 | @register_backend('zip')
13 | class ZipBackend(BaseStorageBackend):
14 | """Backend for loading data from .zip files.
15 |
16 | This backend works with filepaths pointing to valid .zip files. We assume
17 | that the given .zip file contains the whole dataset associated to this
18 | backend.
19 | """
20 |
21 | def __init__(self, zip_path: Union[str, Path] = '') -> None:
22 | self.zip_path = str(zip_path)
23 | self._client = None
24 |
25 | def get(self, filepath: str) -> bytes:
26 | """Get values according to the filepath.
27 |
28 | Args:
29 | filepath (str or Path): Here, filepath is the zip key.
30 |
31 | Returns:
32 | bytes: Expected bytes object.
33 |
34 | Examples:
35 |             >>> backend = ZipBackend('path/to/zip')
36 | >>> backend.get('key')
37 | b'hello world'
38 | """
39 |
40 | if self._client is None:
41 | self._client = self._get_client('r')
42 |
43 | filepath = str(filepath)
44 | try:
45 | with self._client.open(filepath) as data:
46 | data = data.read()
47 | except KeyError as e:
48 | raise ValueError(
49 | f"Value '{filepath}' not found in {self._client}!") from e
50 | return bytes(data)
51 |
52 | def get_text(self, filepath, encoding=None):
53 | raise NotImplementedError
54 |
55 | def _get_client(self, mode: Literal["r", "w", "a", "x"]) -> ZipFile:
56 | """Get Zip client.
57 |
58 | Args:
59 | mode (str): Mode to open the file in.
60 |
61 | Returns:
62 | ZipFile: the zip file.
63 | """
64 | assert len(mode) == 1, "Mode must be a single character for zip file."
65 |
66 | if not os.path.exists(self.zip_path):
67 | raise FileNotFoundError(
68 | f"Corresponding zip file not found:" f" {self.zip_path}")
69 |
70 | return ZipFile(self.zip_path, mode)
71 |
72 | def __del__(self):
73 | if self._client is not None:
74 | self._client.close()
--------------------------------------------------------------------------------
/shift_tta/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .adapters import * # noqa: F401,F403
2 | from .detectors import * # noqa: F401,F403
3 | from .losses import * # noqa: F401,F403
4 |
--------------------------------------------------------------------------------
/shift_tta/models/adapters/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_adapter import BaseAdapter
2 | from .mean_teacher_yolox_adapter import MeanTeacherYOLOXAdapter
3 |
4 | __all__ = [
5 | 'BaseAdapter',
6 | 'MeanTeacherYOLOXAdapter'
7 | ]
--------------------------------------------------------------------------------
/shift_tta/models/adapters/base_adapter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from abc import ABCMeta, abstractmethod
3 | from typing import List
4 |
5 | import torch
6 | from copy import deepcopy
7 |
8 | from mmengine.structures import InstanceData
9 | from mmtrack.structures import TrackDataSample
10 |
11 |
12 | class BaseAdapter(metaclass=ABCMeta):
13 | """Base adapter model.
14 |
15 | Args:
16 | episodic (bool, optional). If episodic is True, the model will be reset
17 | to its initial state at the end of every evaluated sequence.
18 | Defaults to True.
19 | """
20 |
21 | def __init__(self,
22 | episodic: bool = True) -> None:
23 | super().__init__()
24 | self.episodic = episodic
25 | self.fp16_enabled = False
26 |
27 | self.source_model_state = None
28 |
29 | def _init_source_model_state(self, model) -> None:
30 | """Init self.source_model_state.
31 |
32 | Args:
33 | model (nn.Module): detection model.
34 | """
35 | self.source_model_state = deepcopy(model.state_dict())
36 |
37 | def _restore_source_model_state(self, model) -> None:
38 | """Init self.source_model_state.
39 |
40 | Args:
41 | model (nn.Module): detection model.
42 | """
43 |
44 | if self.source_model_state is None:
45 |             raise RuntimeError("cannot reset without a saved source model state")
46 | model.load_state_dict(self.source_model_state, strict=True)
47 |
48 | def reset(self, model) -> None:
49 | """Reset the model state to self.source_model_state.
50 |
51 | Args:
52 | model (nn.Module): detection model.
53 | """
54 | if self.source_model_state is None:
55 | self._init_source_model_state(model)
56 | else:
57 | self._restore_source_model_state(model)
58 |
59 | @property
60 | def with_episodic(self) -> bool:
61 | """Whether the model has to be reset at the end of every sequence."""
62 |         return self.episodic
63 |
64 | @abstractmethod
65 | def _adapt(self, *args, **kwargs):
66 | """Adapt the model."""
67 | pass
68 |
69 | def adapt(self, model: torch.nn.Module, img: torch.Tensor,
70 | feats: List[torch.Tensor], data_sample: TrackDataSample,
71 | **kwargs) -> InstanceData:
72 | """Adapt the model.
73 |
74 |
75 | Args:
76 | model (nn.Module): detection model.
77 | img (Tensor): of shape (T, C, H, W) encoding input image.
78 | Typically these should be mean centered and std scaled.
79 | The T denotes the number of key images and usually is 1 in
80 | ByteTrack method.
81 | feats (list[Tensor]): Multi level feature maps of `img`.
82 | data_sample (:obj:`TrackDataSample`): The data sample.
83 | It includes information such as `pred_det_instances`.
84 |
85 | Returns:
86 | :obj:`InstanceData`: Detection results of the input images.
87 | Each InstanceData usually contains ``bboxes``, ``labels``,
88 | ``scores`` and ``instances_id``.
89 | """
90 | metainfo = data_sample.metainfo
91 | bboxes = data_sample.pred_det_instances.bboxes
92 | labels = data_sample.pred_det_instances.labels
93 | scores = data_sample.pred_det_instances.scores
94 |
95 | frame_id = metainfo.get('frame_id', -1)
96 | if self.with_episodic and frame_id == 0:
97 | self.reset(model)
98 |
99 | # adapt model
100 | self._adapt() # TODO: implement your own adapt method here
101 |
102 | # update pred_det_instances
103 | pred_det_instances = InstanceData()
104 | pred_det_instances.bboxes = bboxes
105 | pred_det_instances.labels = labels
106 | pred_det_instances.scores = scores
107 |
108 | return pred_det_instances
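109 |
110 | # Minimal subclass sketch (illustrative): an adapter whose _adapt is a no-op,
111 | # so predictions pass through while the episodic reset logic still applies.
112 | #
113 | # class NoOpAdapter(BaseAdapter):
114 | #     def _adapt(self, *args, **kwargs):
115 | #         pass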
--------------------------------------------------------------------------------
/shift_tta/models/adapters/mean_teacher_yolox_adapter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import List, Optional, Tuple
3 |
4 | import torch
5 |
6 | from copy import deepcopy
7 |
8 | from mmengine.dataset import Compose
9 | from mmengine.optim import build_optim_wrapper
10 | from mmengine.structures import InstanceData
11 | from mmtrack.structures import TrackDataSample
12 |
13 | from shift_tta.registry import MODELS
14 | from .base_adapter import BaseAdapter
15 |
16 |
17 | @MODELS.register_module()
18 | class MeanTeacherYOLOXAdapter(BaseAdapter):
19 | """Mean-teacher YOLOX adapter model.
20 |
21 | Args:
22 | teacher (dict): Configuration of teacher. Defaults to None.
23 | optim_wrapper (dict): Configuration of optimizer wrapper.
24 | Defaults to None.
25 | loss (dict): Configuration of loss. Defaults to None.
26 | pipeline (list(dict)): Configuration of image transforms.
27 | Defaults to None.
28 | """
29 |
30 | def __init__(self,
31 | teacher: Optional[dict] = None,
32 | optim_wrapper: Optional[dict] = None,
33 | optim_steps: int = 0,
34 | loss: Optional[dict] = dict(
35 | type='ROIConsistencyLoss',
36 | weight=0.01,
37 | ),
38 |                  pipeline: Optional[List[dict]] = None,
39 |                  teacher_pipeline: Optional[List[dict]] = None,
40 |                  student_pipeline: Optional[List[dict]] = None,
41 | views: int = 1,
42 | **kwargs) -> None:
43 | super().__init__(**kwargs)
44 |
45 |         self.teacher = None
46 |         # store the cfg even when None; _init_source_model_state checks it
47 |         self.teacher_cfg = teacher
48 |
49 | # build optimizer
50 |         self.optim_wrapper = None
51 |         # store the cfg unconditionally; it is built lazily on the first reset
52 |         self.optim_wrapper_cfg = optim_wrapper
53 |         self.optim_steps = optim_steps
54 |
55 | # build loss
56 | self.loss = MODELS.build(loss)
57 |
58 | # build image transforms
59 | self.pipeline = Compose(pipeline)
60 | self.teacher_pipeline = Compose(teacher_pipeline)
61 | self.student_pipeline = Compose(student_pipeline)
62 | self.views = views
63 |
64 | # TODO: implement param_scheduler for optimizer (e.g. lr decay)
65 |
66 | def _init_source_model_state(self, model) -> None:
67 | """Init self.source_model_state.
68 |
69 | Args:
70 | model (nn.Module): detection model.
71 | """
72 | super()._init_source_model_state(model)
73 |
74 | self.optim_wrapper = build_optim_wrapper(model, self.optim_wrapper_cfg)
75 | self.optim_wrapper_state = deepcopy(self.optim_wrapper.state_dict())
76 |
77 | if self.teacher_cfg is not None:
78 | self.teacher_cfg['model'] = model
79 | self.teacher = MODELS.build(self.teacher_cfg)
80 | self.teacher_model_state = deepcopy(self.teacher.state_dict())
81 |
82 | def _reset_optimizer(self) -> None:
83 | """Reset optimizer state.
84 |
85 | Args:
86 | model (nn.Module): detection model."""
87 | if self.optim_wrapper is not None:
88 | self.optim_wrapper.load_state_dict(self.optim_wrapper_state)
89 |
90 | def _restore_source_model_state(self, model) -> None:
91 | """Init self.source_model_state.
92 |
93 | Args:
94 | model (nn.Module): detection model.
95 | """
96 | super()._restore_source_model_state(model)
97 |
98 | if self.teacher is not None:
99 | self.teacher.load_state_dict(
100 | self.teacher_model_state, strict=True)
101 |
102 | def _detect_forward(self, detector: torch.nn.Module, img: torch.Tensor,
103 | batch_data_samples: TrackDataSample, rescale: bool = True):
104 | """Detector forward pass."""
105 | feats = detector.extract_feat(img)
106 | outs = detector.bbox_head.forward(feats)
107 |
108 | batch_img_metas = [
109 | data_samples.metainfo for data_samples in batch_data_samples
110 | ]
111 | predictions = detector.bbox_head.predict_by_feat(
112 | *outs, batch_img_metas=batch_img_metas, rescale=rescale)
113 | det_results = detector.add_pred_to_datasample(
114 | batch_data_samples, predictions)
115 |
116 | return det_results, outs
117 |
118 | def _expand_view(self, outs: Tuple[torch.Tensor], views: int = 1):
119 | """Expand batch size of each element in outs to views."""
120 | outs = tuple(list(o.repeat_interleave(views, dim=0) for o in out)
121 | for out in outs)
122 | return outs
123 |
124 | def _adapt(self, model: torch.nn.Module,
125 | teacher_img: torch.Tensor,
126 | student_imgs: torch.Tensor,
127 | teacher_data_samples: List[TrackDataSample],
128 | student_data_samples: List[TrackDataSample],
129 | *args, **kwargs) -> InstanceData:
130 | """Adapt the model."""
131 |
132 | # teacher forward
133 | teacher_det_results, teacher_outs = self._detect_forward(
134 | self.teacher.module.detector, teacher_img, teacher_data_samples)
135 | teacher_outs = self._expand_view(teacher_outs, views=self.views)
136 | teacher_outs = dict(
137 | cls_score=teacher_outs[0],
138 | bbox_pred=teacher_outs[1],
139 | objectness=teacher_outs[2])
140 |
141 | # student forward
142 | _, outs = self._detect_forward(
143 | model.detector, student_imgs, student_data_samples)
144 | outs = dict(
145 | cls_score=outs[0],
146 | bbox_pred=outs[1],
147 | objectness=outs[2])
148 |
149 | # adapt
150 | loss = self.loss(outs, teacher_outs)
151 | loss.backward()
152 | self.optim_wrapper.step()
153 | self.optim_wrapper.zero_grad()
154 |
155 | return teacher_det_results
156 |
157 | def adapt(self, model: torch.nn.Module, img: torch.Tensor,
158 | feats: List[torch.Tensor], data_sample: TrackDataSample,
159 | **kwargs) -> InstanceData:
160 | """Adapt the model.
161 |
162 | Args:
163 | model (nn.Module): detection model.
164 | img (Tensor): of shape (T, C, H, W) encoding input image.
165 | Typically these should be mean centered and std scaled.
166 | The T denotes the number of key images and usually is 1 in
167 | ByteTrack method.
168 | feats (list[Tensor]): Multi level feature maps of `img`.
169 | data_sample (:obj:`TrackDataSample`): The data sample.
170 | It includes information such as `pred_det_instances`.
171 |
172 | Returns:
173 | :obj:`InstanceData`: Detection results of the input images.
174 | Each InstanceData usually contains ``bboxes``, ``labels``,
175 | ``scores`` and ``instances_id``.
176 | """
177 | metainfo = data_sample.metainfo
178 | frame_id = metainfo.get('frame_id', -1)
179 | if self.with_episodic and frame_id == 0:
180 | self.reset(model)
181 |
182 | # adapt model
183 | # TODO: apply multiple image transforms
184 | # data_sample = self.transforms(deepcopy(data_sample))
185 | # TODO: create a batch
186 | # TODO: compute teacher prediction on clean target image
187 | # TODO: compute distill loss into augmented batch
188 | # (concat targets to batch size)
189 |
190 | # make teacher and student views
191 | results = dict(img_path=data_sample.img_path,
192 | instances=data_sample.instances,
193 | )
194 | results = self.pipeline(results)
195 | teacher_results = self.teacher_pipeline(results)
196 | teacher_img = teacher_results['inputs']['img'].to(img)
197 | teacher_data_samples = [teacher_results['data_samples']]
198 |
199 | student_imgs = []
200 | student_data_samples = []
201 | for _ in range(self.views):
202 | student_results = self.student_pipeline(results)
203 | student_imgs.append(student_results['inputs']['img'])
204 | student_data_samples.append(student_results['data_samples'])
205 | student_imgs = torch.cat(student_imgs).to(img)
206 |
207 | with torch.enable_grad():
208 | model.requires_grad_(True)
209 | model.train(True)
210 | for _ in range(self.optim_steps):
211 | outs = self._adapt(
212 | model, teacher_img, student_imgs,
213 | teacher_data_samples, student_data_samples)
214 | self.teacher.update_parameters(model)
215 |
216 | self._reset_optimizer()
217 |
218 | # update pred_det_instances
219 | pred_det_instances = outs[0].pred_instances.clone()
220 |
221 | return pred_det_instances
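222 |
223 | # Assumptions in the flow above: `self.teacher` must expose a `.module`
224 | # attribute and `update_parameters(model)`, e.g. an EMA wrapper such as
225 | # mmengine's ExponentialMovingAverage, and `optim_steps` must be >= 1,
226 | # otherwise `outs` is never assigned before being read.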
--------------------------------------------------------------------------------
/shift_tta/models/detectors/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseAdaptiveDetector
2 | from .adaptive_detector import AdaptiveDetector
3 |
4 | __all__ = [
5 | 'BaseAdaptiveDetector',
6 | 'AdaptiveDetector'
7 | ]
--------------------------------------------------------------------------------
/shift_tta/models/detectors/adaptive_detector.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict, Optional
3 |
4 | import torch
5 | from torch import Tensor
6 |
7 | from mmtrack.utils import OptConfigType, OptMultiConfig, SampleList
8 |
9 | from shift_tta.registry import MODELS
10 | from .base import BaseAdaptiveDetector
11 |
12 |
13 | @MODELS.register_module()
14 | class AdaptiveDetector(BaseAdaptiveDetector):
15 | """AdaptiveDetector: baseline test-time adaptation method for object detection.
16 |
17 | Args:
18 | detector (dict): Configuration of detector. Defaults to None.
19 | adapter (dict): Configuration of adapter. Defaults to None.
20 | data_preprocessor (dict or ConfigDict, optional): The pre-process
21 | config of :class:`TrackDataPreprocessor`. it usually includes,
22 | ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
23 | init_cfg (dict or list[dict]): Configuration of initialization.
24 | Defaults to None.
25 | """
26 |
27 | def __init__(self,
28 | detector: Optional[dict] = None,
29 | adapter: Optional[dict] = None,
30 | data_preprocessor: OptConfigType = None,
31 | init_cfg: OptMultiConfig = None):
32 | super().__init__(data_preprocessor, init_cfg)
33 |
34 | if detector is not None:
35 | self.detector = MODELS.build(detector)
36 |
37 | if adapter is not None:
38 | self.adapter = MODELS.build(adapter)
39 |
40 | def loss(self, inputs: Dict[str, Tensor], data_samples: SampleList,
41 | **kwargs) -> dict:
42 | """Calculate losses from a batch of inputs and data samples.
43 |
44 | Args:
45 | inputs (Dict[str, Tensor]): of shape (N, T, C, H, W) encoding
46 | input images. Typically these should be mean centered and std
47 | scaled. The N denotes batch size.The T denotes the number of
48 | key/reference frames.
49 | - img (Tensor) : The key images.
50 | - ref_img (Tensor): The reference images.
51 | data_samples (list[:obj:`TrackDataSample`]): The batch
52 | data samples. It usually includes information such
53 | as `gt_instance`.
54 |
55 | Returns:
56 | dict: A dictionary of loss components.
57 | """
58 | # modify the inputs shape to fit mmdet
59 | img = inputs['img']
60 | assert img.size(1) == 1
61 | # convert 'inputs' shape to (N, C, H, W)
62 | img = torch.squeeze(img, dim=1)
63 | return self.detector.loss(img, data_samples, **kwargs)
64 |
65 | def predict(self, inputs: Dict[str, Tensor], data_samples: SampleList,
66 | **kwargs) -> SampleList:
67 | """Predict results from a batch of inputs and data samples with post-
68 | processing.
69 |
70 | Args:
71 | inputs (Dict[str, Tensor]): of shape (N, T, C, H, W) encoding
72 | input images. Typically these should be mean centered and std
73 |                 scaled. The N denotes batch size. The T denotes the number of
74 | key/reference frames.
75 | - img (Tensor) : The key images.
76 | - ref_img (Tensor): The reference images.
77 | data_samples (list[:obj:`TrackDataSample`]): The batch
78 | data samples. It usually includes information such
79 | as `gt_instance`.
80 |
81 | Returns:
82 | SampleList: Tracking results of the input images.
83 | Each TrackDataSample usually contains ``pred_det_instances``
84 | or ``pred_track_instances``.
85 | """
86 | img = inputs['img']
87 | assert img.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).'
88 | assert img.size(0) == 1, \
89 |             'AdaptiveDetector inference only supports batch size 1 per gpu for now.'
90 | img = img[0]
91 |
92 | assert len(data_samples) == 1, \
93 |             'AdaptiveDetector inference only supports batch size 1 per gpu for now.'
94 |
95 | data_sample = data_samples[0]
96 |
97 | if self.with_adapter:
98 | adapted_det_instances = self.adapter.adapt(
99 | model=self,
100 | img=img,
101 | feats=None,
102 | data_sample=data_sample,
103 | **kwargs)
104 | data_sample.pred_det_instances = adapted_det_instances
105 | else:
106 | det_results = self.detector.predict(img, data_samples)
107 | assert len(det_results) == 1, 'Batch inference is not supported.'
108 | data_sample.pred_det_instances = \
109 | det_results[0].pred_instances.clone()
110 |
111 | return [data_sample]
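112 |
113 | # Config sketch (values illustrative): pair a detector with an adapter, e.g.
114 | #   model = dict(
115 | #       type='AdaptiveDetector',
116 | #       detector=dict(type='YOLOX', ...),
117 | #       adapter=dict(type='MeanTeacherYOLOXAdapter', ...))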
--------------------------------------------------------------------------------
/shift_tta/models/detectors/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from abc import ABCMeta, abstractmethod
3 | from typing import Dict, List, Tuple, Union
4 |
5 | from mmengine.model import BaseModel
6 | from torch import Tensor
7 |
8 | from mmtrack.utils import (ForwardResults, OptConfigType, OptMultiConfig,
9 | OptSampleList, SampleList)
10 |
11 |
12 | class BaseAdaptiveDetector(BaseModel, metaclass=ABCMeta):
13 | """Base class for test-time adaptation of object detection.
14 |
15 | Args:
16 | data_preprocessor (dict or ConfigDict, optional): The pre-process
17 | config of :class:`TrackDataPreprocessor`. it usually includes,
18 | ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
19 | init_cfg (dict or list[dict]): Initialization config dict.
20 | """
21 |
22 | def __init__(self,
23 | data_preprocessor: OptConfigType = None,
24 | init_cfg: OptMultiConfig = None) -> None:
25 | super().__init__(
26 | data_preprocessor=data_preprocessor, init_cfg=init_cfg)
27 |
28 | def freeze_module(self, module: Union[List[str], Tuple[str], str]) -> None:
29 | """Freeze module during training."""
30 | if isinstance(module, str):
31 | modules = [module]
32 | elif isinstance(module, (list, tuple)):
33 | modules = module
34 | else:
35 | raise TypeError(
36 | 'module must be a str, a list, or a tuple.')
37 | for module in modules:
38 | m = getattr(self, module)
39 | m.eval()
40 | for param in m.parameters():
41 | param.requires_grad = False
42 |
43 | @property
44 | def with_detector(self) -> bool:
45 | """bool: whether the framework has a detector."""
46 | return hasattr(self, 'detector') and self.detector is not None
47 |
48 | @property
49 | def with_adapter(self) -> bool:
50 | """bool: whether the framework has an adapter model."""
51 | return hasattr(self, 'adapter') and self.adapter is not None
52 |
53 | def forward(self,
54 | inputs: Dict[str, Tensor],
55 | data_samples: OptSampleList = None,
56 | mode: str = 'predict',
57 | **kwargs) -> ForwardResults:
58 | """The unified entry for a forward process in both training and test.
59 |
60 | The method should accept three modes: "tensor", "predict" and "loss":
61 |
62 | - "tensor": Forward the whole network and return tensor or tuple of
63 | tensor without any post-processing, same as a common nn.Module.
64 | - "predict": Forward and return the predictions, which are fully
65 | processed to a list of :obj:`TrackDataSample`.
66 | - "loss": Forward and return a dict of losses according to the given
67 | inputs and data samples.
68 |
69 | Note that this method doesn't handle back propagation or optimizer
70 | updating, which are done in :meth:`train_step`.
71 |
72 | Args:
73 | inputs (Dict[str, Tensor]): of shape (N, T, C, H, W)
74 | encoding input images. Typically these should be mean centered
75 | and std scaled. The N denotes batch size. The T denotes the
76 | number of key/reference frames.
77 | - img (Tensor): The key images.
78 | - ref_img (Tensor): The reference images.
79 | data_samples (list[:obj:`TrackDataSample`], optional): The
80 | annotation data of every sample. Defaults to None.
81 | mode (str): Return what kind of value. Defaults to 'predict'.
82 |
83 | Returns:
84 | The return type depends on ``mode``.
85 |
86 | - If ``mode="tensor"``, return a tensor or a tuple of tensor.
87 | - If ``mode="predict"``, return a list of :obj:`TrackDataSample`.
88 | - If ``mode="loss"``, return a dict of tensor.
89 | """
90 | if mode == 'loss':
91 | return self.loss(inputs, data_samples, **kwargs)
92 | elif mode == 'predict':
93 | return self.predict(inputs, data_samples, **kwargs)
94 | elif mode == 'tensor':
95 | return self._forward(inputs, data_samples, **kwargs)
96 | else:
97 | raise RuntimeError(f'Invalid mode "{mode}". '
98 | 'Only supports loss, predict and tensor mode')
99 |
100 | @abstractmethod
101 | def loss(self, inputs: Dict[str, Tensor], data_samples: SampleList,
102 | **kwargs) -> Union[dict, tuple]:
103 | """Calculate losses from a batch of inputs and data samples."""
104 | pass
105 |
106 | @abstractmethod
107 | def predict(self, inputs: Dict[str, Tensor], data_samples: SampleList,
108 | **kwargs) -> SampleList:
109 | """Predict results from a batch of inputs and data samples with post-
110 | processing."""
111 | pass
112 |
113 | def _forward(self,
114 | inputs: Dict[str, Tensor],
115 | data_samples: OptSampleList = None,
116 | **kwargs):
117 | """Network forward process. Usually includes backbone, neck and head
118 | forward without any post-processing.
119 |
120 | Args:
121 | inputs (Dict[str, Tensor]): of shape (N, T, C, H, W).
122 | data_samples (List[:obj:`TrackDataSample`], optional): The
123 | Data Samples. It usually includes information such as
124 | `gt_instance`.
125 |
126 | Returns:
127 | tuple[list]: A tuple of features from ``head`` forward.
128 | """
129 | raise NotImplementedError(
130 | "_forward function (namely 'tensor' mode) is not supported now")
--------------------------------------------------------------------------------
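
For orientation, a toy subclass showing the two abstract hooks that BaseAdaptiveDetector.forward dispatches to; DummyAdaptiveDetector is hypothetical and does no real detection or adaptation.

from typing import Dict

from torch import Tensor

from mmtrack.utils import SampleList
from shift_tta.models.detectors.base import BaseAdaptiveDetector


class DummyAdaptiveDetector(BaseAdaptiveDetector):
    """Toy subclass; a real one wraps a detector and, optionally, an adapter."""

    def loss(self, inputs: Dict[str, Tensor], data_samples: SampleList,
             **kwargs) -> dict:
        # mode='loss' in forward() routes here
        return {'loss_dummy': inputs['img'].float().mean() * 0.0}

    def predict(self, inputs: Dict[str, Tensor], data_samples: SampleList,
                **kwargs) -> SampleList:
        # mode='predict' in forward() routes here
        return data_samples
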
/shift_tta/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .yolox_consistency_loss import YOLOXConsistencyLoss
2 |
3 | __all__ = ['YOLOXConsistencyLoss']
--------------------------------------------------------------------------------
/shift_tta/models/losses/yolox_consistency_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.functional import mse_loss
4 |
5 | from shift_tta.registry import MODELS
6 |
7 |
8 | @MODELS.register_module()
9 | class YOLOXConsistencyLoss(nn.Module):
10 | """YOLOXConsistencyLoss
11 | Args:
12 | weight (float, optional): Weight of the loss. Default to 1.0.
13 | obj_weight (float, optional): Weight of the objectness consistency loss.
14 | Default to 1.0.
15 | reg_weight (float, optional): Weight of the regression consistency loss.
16 | Default to 1.0.
17 | cls_weight (float, optional): Weight of the classification consistency loss.
18 | Default to 1.0.
19 | """
20 |
21 | def __init__(self,
22 | weight=1.0,
23 | obj_weight=1.0,
24 | reg_weight=1.0,
25 | cls_weight=1.0,
26 | ):
27 | super().__init__()
28 | self.weight = weight
29 | self.obj_weight = obj_weight
30 | self.reg_weight = reg_weight
31 | self.cls_weight = cls_weight
32 |
33 | def forward(self, inputs, targets, **kwargs):
34 | """Forward pass.
35 |
36 | Args:
37 | inputs (dict): Multi-level YOLOX head outputs of the student,
38 | with keys 'objectness', 'bbox_pred' and 'cls_score', each
39 | mapping to a list of per-level tensors.
40 | targets (dict): The corresponding head outputs of the teacher,
41 | with matching keys and per-level tensor shapes.
42 |
43 | Returns:
44 | Tensor: The weighted YOLOX consistency loss.
45 | """
46 | teacher_obj = targets["objectness"]
47 | teacher_reg = targets["bbox_pred"]
48 | teacher_cls = targets["cls_score"]
49 |
50 | student_obj = inputs["objectness"]
51 | student_reg = inputs["bbox_pred"]
52 | student_cls = inputs["cls_score"]
53 |
54 | obj_elements = 0
55 | reg_elements = 0
56 | cls_elements = 0
57 | obj_loss = 0.
58 | reg_loss = 0.
59 | cls_loss = 0.
60 | for (t_obj, t_reg, t_cls, s_obj, s_reg, s_cls) in zip(
61 | teacher_obj, teacher_reg, teacher_cls,
62 | student_obj, student_reg, student_cls,
63 | ):
64 | assert s_obj.shape == t_obj.shape
65 | assert s_reg.shape == t_reg.shape
66 | assert s_cls.shape == t_cls.shape
67 |
68 | _obj_loss = mse_loss(t_obj, s_obj, reduction="none")
69 | _cls_loss = mse_loss(t_cls, s_cls, reduction="none")
70 | _reg_loss = mse_loss(t_reg, s_reg, reduction="none")
71 |
72 | obj_loss += torch.sum(_obj_loss)
73 | reg_loss += torch.sum(_reg_loss)
74 | cls_loss += torch.sum(_cls_loss)
75 |
76 | obj_elements += _obj_loss.numel()
77 | reg_elements += _reg_loss.numel()
78 | cls_elements += _cls_loss.numel()
79 |
80 | obj_loss = obj_loss / obj_elements
81 | reg_loss = reg_loss / reg_elements
82 | cls_loss = cls_loss / cls_elements
83 |
84 | loss = self.obj_weight * obj_loss
85 | loss += self.reg_weight * reg_loss
86 | loss += self.cls_weight * cls_loss
87 |
88 | return self.weight * loss
--------------------------------------------------------------------------------
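
As a quick sanity check, the loss can be exercised with random per-level head outputs; the level count and tensor shapes below are illustrative, assuming only that teacher and student shapes match per level.

import torch

from shift_tta.models.losses import YOLOXConsistencyLoss

loss_fn = YOLOXConsistencyLoss(weight=1.0, obj_weight=1.0, reg_weight=1.0, cls_weight=1.0)

# one entry per FPN level: objectness (1 ch), bbox regression (4 ch), class scores (6 ch)
teacher = dict(
    objectness=[torch.rand(1, 1, 80, 80)],
    bbox_pred=[torch.rand(1, 4, 80, 80)],
    cls_score=[torch.rand(1, 6, 80, 80)],
)
student = {k: [v + 0.01 for v in vs] for k, vs in teacher.items()}
loss = loss_fn(student, teacher)  # scalar; per-branch mean MSE, re-weighted and summed
print(float(loss))
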
/shift_tta/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """MMDetection provides 17 registry nodes to support using modules across
3 | projects. Each node is a child of the root registry in MMEngine.
4 |
5 | More details can be found at
6 | https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
7 | """
8 |
9 | from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
10 | from mmengine.registry import DATASETS as MMENGINE_DATASETS
11 | from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
12 | from mmengine.registry import HOOKS as MMENGINE_HOOKS
13 | from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS
14 | from mmengine.registry import LOOPS as MMENGINE_LOOPS
15 | from mmengine.registry import METRICS as MMENGINE_METRICS
16 | from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
17 | from mmengine.registry import MODELS as MMENGINE_MODELS
18 | from mmengine.registry import \
19 | OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS
20 | from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS
21 | from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
22 | from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
23 | from mmengine.registry import \
24 | RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS
25 | from mmengine.registry import RUNNERS as MMENGINE_RUNNERS
26 | from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS
27 | from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
28 | from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS
29 | from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS
30 | from mmengine.registry import \
31 | WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS
32 | from mmengine.registry import Registry
33 |
34 | # manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
35 | RUNNERS = Registry(
36 | 'runner', parent=MMENGINE_RUNNERS)
37 | # manage runner constructors that define how to initialize runners
38 | RUNNER_CONSTRUCTORS = Registry(
39 | 'runner constructor',
40 | parent=MMENGINE_RUNNER_CONSTRUCTORS)
41 | # manage all kinds of loops like `EpochBasedTrainLoop`
42 | LOOPS = Registry(
43 | 'loop', parent=MMENGINE_LOOPS)
44 | # manage all kinds of hooks like `CheckpointHook`
45 | HOOKS = Registry(
46 | 'hook', parent=MMENGINE_HOOKS)
47 |
48 | # manage data-related modules
49 | DATASETS = Registry(
50 | 'dataset', parent=MMENGINE_DATASETS)
51 | DATA_SAMPLERS = Registry(
52 | 'data sampler',
53 | parent=MMENGINE_DATA_SAMPLERS)
54 | TRANSFORMS = Registry(
55 | 'transform',
56 | parent=MMENGINE_TRANSFORMS)
57 |
58 | # manage all kinds of modules inheriting `nn.Module`
59 | MODELS = Registry('model', parent=MMENGINE_MODELS)
60 | # manage all kinds of model wrappers like 'MMDistributedDataParallel'
61 | MODEL_WRAPPERS = Registry(
62 | 'model_wrapper',
63 | parent=MMENGINE_MODEL_WRAPPERS)
64 | # manage all kinds of weight initialization modules like `Uniform`
65 | WEIGHT_INITIALIZERS = Registry(
66 | 'weight initializer',
67 | parent=MMENGINE_WEIGHT_INITIALIZERS)
68 |
69 | # manage all kinds of optimizers like `SGD` and `Adam`
70 | OPTIMIZERS = Registry(
71 | 'optimizer',
72 | parent=MMENGINE_OPTIMIZERS)
73 | # manage optimizer wrapper
74 | OPTIM_WRAPPERS = Registry(
75 | 'optim_wrapper',
76 | parent=MMENGINE_OPTIM_WRAPPERS)
77 | # manage constructors that customize the optimization hyperparameters.
78 | OPTIM_WRAPPER_CONSTRUCTORS = Registry(
79 | 'optimizer constructor',
80 | parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
81 | # manage all kinds of parameter schedulers like `MultiStepLR`
82 | PARAM_SCHEDULERS = Registry(
83 | 'parameter scheduler',
84 | parent=MMENGINE_PARAM_SCHEDULERS)
85 | # manage all kinds of metrics
86 | METRICS = Registry(
87 | 'metric', parent=MMENGINE_METRICS)
88 | # manage evaluator
89 | EVALUATOR = Registry(
90 | 'evaluator', parent=MMENGINE_EVALUATOR)
91 |
92 | # manage task-specific modules like anchor generators and box coders
93 | TASK_UTILS = Registry(
94 | 'task util', parent=MMENGINE_TASK_UTILS)
95 |
96 | # manage visualizer
97 | VISUALIZERS = Registry(
98 | 'visualizer',
99 | parent=MMENGINE_VISUALIZERS)
100 | # manage visualizer backend
101 | VISBACKENDS = Registry(
102 | 'vis_backend',
103 | parent=MMENGINE_VISBACKENDS)
104 |
105 | # manage logprocessor
106 | LOG_PROCESSORS = Registry(
107 | 'log_processor',
108 | parent=MMENGINE_LOG_PROCESSORS)
--------------------------------------------------------------------------------
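
A minimal registration example against shift_tta's MODELS node; TinyBlock is hypothetical and only illustrates that configs can then reference it by type name.

import torch
import torch.nn as nn

from shift_tta.registry import MODELS


@MODELS.register_module()
class TinyBlock(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, channels: int = 8):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


# build from a config dict, as mmengine runners do internally
block = MODELS.build(dict(type='TinyBlock', channels=4))
out = block(torch.randn(1, 4, 16, 16))
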
/shift_tta/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .setup_env import register_all_modules
2 |
3 | __all__ = [
4 | 'register_all_modules',
5 | ]
--------------------------------------------------------------------------------
/shift_tta/utils/setup_env.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import warnings
3 |
4 | from mmtrack.utils import register_all_modules as register_all_mmtrack_modules
5 | from mmengine import DefaultScope
6 |
7 |
8 | def register_all_modules(init_default_scope: bool = True) -> None:
9 | """Register all modules in mmtrack into the registries.
10 |
11 | Args:
12 | init_default_scope (bool): Whether to initialize the shift_tta default
13 | scope. When `init_default_scope=True`, the global default scope will
14 | be set to `shift_tta`, and all registries will build modules from
15 | shift_tta's registry node. To understand more about the registry,
16 | please refer to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
17 | Defaults to True.
18 | """ # noqa
19 | import shift_tta.fileio # noqa: F401,F403
20 | import shift_tta.datasets # noqa: F401,F403
21 | import shift_tta.evaluation # noqa: F401,F403
22 | import shift_tta.models # noqa: F401,F403
23 |
24 | # register parent modules
25 | register_all_mmtrack_modules(init_default_scope=False)
26 |
27 | if init_default_scope:
28 | never_created = DefaultScope.get_current_instance() is None \
29 | or not DefaultScope.check_instance_created('shift_tta')
30 | if never_created:
31 | DefaultScope.get_instance('shift_tta', scope_name='shift_tta')
32 | return
33 | current_scope = DefaultScope.get_current_instance()
34 | if current_scope.scope_name != 'shift_tta':
35 | warnings.warn('The current default scope '
36 | f'"{current_scope.scope_name}" is not "shift_tta", '
37 | '`register_all_modules` will force the current '
38 | 'default scope to be "shift_tta". If this is not '
39 | 'expected, please set `init_default_scope=False`.')
40 | # avoid name conflict
41 | new_instance_name = f'shift_tta-{datetime.datetime.now()}'
42 | DefaultScope.get_instance(new_instance_name, scope_name='shift_tta')
--------------------------------------------------------------------------------
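
For reference, tools/train.py and tools/test.py call this with init_default_scope=False and let the runner set the scope; a standalone script would typically do:

from shift_tta.utils import register_all_modules

# registers shift_tta modules (and the parent mmtrack modules) and sets the
# global default scope to 'shift_tta'
register_all_modules(init_default_scope=True)
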
/shift_tta/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.0.0'
2 |
3 |
4 | def parse_version_info(version_str):
5 | version_info = []
6 | for x in version_str.split('.'):
7 | if x.isdigit():
8 | version_info.append(int(x))
9 | elif x.find('rc') != -1:
10 | patch_version = x.split('rc')
11 | version_info.append(int(patch_version[0]))
12 | version_info.append(f'rc{patch_version[1]}')
13 | return tuple(version_info)
14 |
15 |
16 | version_info = parse_version_info(__version__)
--------------------------------------------------------------------------------
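
A couple of worked examples of the helper above (the 'rc' case keeps the numeric patch and appends the release-candidate tag):

from shift_tta.version import parse_version_info

assert parse_version_info('0.0.0') == (0, 0, 0)
assert parse_version_info('1.2.0rc3') == (1, 2, 0, 'rc3')
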
/tools/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | GPUS=$2
5 | NNODES=${NNODES:-1}
6 | NODE_RANK=${NODE_RANK:-0}
7 | PORT=${PORT:-29500}
8 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
9 |
10 | # Make conda available
11 | eval "$(conda shell.bash hook)"
12 | # Activate a conda environment
13 | conda activate shift-tta
14 |
15 | export MPLBACKEND=Agg
16 |
17 | python -m torch.distributed.launch \
18 | --nnodes=$NNODES \
19 | --node_rank=$NODE_RANK \
20 | --master_addr=$MASTER_ADDR \
21 | --nproc_per_node=$GPUS \
22 | --master_port=$PORT \
23 | tools/test.py \
24 | $CONFIG \
25 | --launcher pytorch \
26 | ${@:3}
--------------------------------------------------------------------------------
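
For reference, a hypothetical multi-GPU evaluation (config from this repo, checkpoint path illustrative):

    bash tools/dist_test.sh configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py 4 --checkpoint work_dirs/source.pth
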
/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | GPUS=$2
5 | NNODES=${NNODES:-1}
6 | NODE_RANK=${NODE_RANK:-0}
7 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
8 |
9 | # use a random master port in 24000-29500 unless PORT is set by the caller
10 | PORT=${PORT:-$(shuf -i 24000-29500 -n 1)}
11 |
12 | # Make conda available
13 | eval "$(conda shell.bash hook)"
14 | # Activate a conda environment
15 | conda activate shift-tta
16 |
17 | export MPLBACKEND=Agg
18 |
19 | python -m torch.distributed.launch \
20 | --nnodes=$NNODES \
21 | --node_rank=$NODE_RANK \
22 | --master_addr=$MASTER_ADDR \
23 | --nproc_per_node=$GPUS \
24 | --master_port=$PORT \
25 | tools/train.py \
26 | $CONFIG \
27 | --launcher pytorch \
28 | ${@:3}
--------------------------------------------------------------------------------
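
A hypothetical 8-GPU training launch; set PORT explicitly to override the randomly drawn master port:

    PORT=29510 bash tools/dist_train.sh configs/source/yolox/yolox_x_8xb4-24e_shift_clear_daytime.py 8
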
/tools/install/setup_venv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | conda create -n shift-tta python=3.9 -y
3 | eval "$(conda shell.bash hook)" && conda activate shift-tta
4 |
5 | conda install pytorch=1.11.0 torchvision cudatoolkit=11.3 -c pytorch -y
6 |
7 | pip install -U openmim
8 | # install mmengine from main branch
9 | python -m pip install git+https://github.com/open-mmlab/mmengine.git@62f9504d701251db763f56658436fd23a586fe25
10 | mim install 'mmcv == 2.0.0rc4'
11 | mim install 'mmdet == 3.0.0rc5'
12 | # install mmclassification from dev-1.x branch at specific commit
13 | python -m pip install git+https://github.com/open-mmlab/mmclassification.git@3ff80f5047fe3f3780a05d387f913dd02999611d
14 | # install mmtracking from dev-1.x branch at specific commit
15 | python -m pip install git+https://github.com/open-mmlab/mmtracking.git@9e4cb98a3cdac749242cd8decb3a172058d4fd6e
16 | python -m pip install git+https://github.com/JonathonLuiten/TrackEval.git
17 | python -m pip install git+https://github.com/scalabel/scalabel.git
18 | python -m pip install --no-input -r requirements.txt
19 |
20 | # install shift-detection-tta
21 | pip install --no-input -v -e .
22 |
--------------------------------------------------------------------------------
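
A quick post-install sanity check (assuming the editable install above succeeded):

    python -c "from shift_tta.version import __version__; print(__version__)"
    python -c "from shift_tta.utils import register_all_modules; register_all_modules()"
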
/tools/shift/download.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """
4 | Download script for SHIFT Dataset.
5 |
6 | The data is released under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 License.
7 | Homepage: www.vis.xyz/shift/.
8 | (C)2022, VIS Group, ETH Zurich.
9 |
10 |
11 | Script usage example:
12 | python download.py --view "[front, left_stereo]" \ # list of view abbreviation to download
13 | --group "[img, semseg]" \ # list of data group abbreviation to download
14 | --split "[train, val, test]" \ # list of split to download
15 | --framerate "[images, videos]" \ # chooses the desired frame rate (images=1fps, videos=10fps)
16 | --shift "discrete" \ # type of domain shifts. Options: discrete, continuous/1x, continuous/10x, continuous/100x
17 | dataset_root # path where to store the downloaded data
18 |
19 | You can set any option to "all" to download all of its choices. For example,
20 | python download.py --view "all" --group "[img]" --split "all" --framerate "[images]" .
21 | downloads all the RGB images in the dataset.
22 | """
23 |
24 | import argparse
25 | import logging
26 | import os
27 | import sys
28 | import tempfile
29 |
30 | import tqdm
31 |
32 | if sys.version_info >= (3, 6):
33 | import urllib.request as urllib
34 | else:
35 | import urllib
36 |
37 |
38 | BASE_URL = "https://dl.cv.ethz.ch/shift/"
39 |
40 | FRAME_RATES = [("images", "images (1 fps)"), ("videos", "videos (10 fps)")]
41 |
42 | SPLITS = [
43 | ("train", "training set"),
44 | ("val", "validation set"),
45 | ("test", "testing set"),
46 | ]
47 |
48 | VIEWS = [
49 | ("front", "Front"),
50 | ("left_45", "Left 45°"),
51 | ("left_90", "Left 90°"),
52 | ("right_45", "Right 45°"),
53 | ("right_90", "Right 90°"),
54 | ("left_stereo", "Front (Stereo)"),
55 | ("center", "Center (for LiDAR)"),
56 | ]
57 |
58 | DATA_GROUPS = [
59 | ("img", "zip", "RGB Image"),
60 | ("det_2d", "json", "2D Detection and Tracking"),
61 | ("det_3d", "json", "3D Detection and Tracking"),
62 | ("semseg", "zip", "Semantic Segmentation"),
63 | ("det_insseg_2d", "json", "Instance Segmentation"),
64 | ("flow", "zip", "Optical Flow"),
65 | ("depth", "zip", "Depth Maps"),
66 | ("seq", "csv", "Sequence Info"),
67 | ("lidar", "zip", "LiDAR Point Cloud"),
68 | ]
69 |
70 |
71 | class ProgressBar(tqdm.tqdm):
72 | def update_to(self, batch=1, batch_size=1, total=None):
73 | if total is not None:
74 | self.total = total
75 | self.update(batch * batch_size - self.n)
76 |
77 |
78 | def setup_logger():
79 | log_formatter = logging.Formatter(
80 | "[%(asctime)s] SHIFT Downloader - %(levelname)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
81 | )
82 | logger = logging.getLogger("logger")
83 | logger.setLevel(logging.DEBUG)
84 | ch = logging.StreamHandler()
85 | ch.setLevel(logging.DEBUG)
86 | ch.setFormatter(log_formatter)
87 | logger.addHandler(ch)
88 | return logger
89 |
90 |
91 | def get_url_discrete(rate, split, view, group, ext):
92 | url = BASE_URL + "discrete/{rate}/{split}/{view}/{group}.{ext}".format(
93 | rate=rate, split=split, view=view, group=group, ext=ext
94 | )
95 | return url
96 |
97 |
98 | def get_url_continuous(rate, shift_length, split, view, group, ext):
99 | url = BASE_URL + "continuous/{rate}/{shift_length}/{split}/{view}/{group}.{ext}".format(
100 | rate=rate, shift_length=shift_length, split=split, view=view, group=group, ext=ext
101 | )
102 | return url
103 |
104 |
105 | def string_to_list(option_str):
106 | option_str = option_str.replace(" ", "").lstrip("[").rstrip("]")
107 | return option_str.split(",")
108 |
109 |
110 | def parse_options(option_str, bounds, name):
111 | if option_str == "all":
112 | return bounds
113 | candidates = {}
114 | for item in bounds:
115 | candidates[item[0]] = item
116 | used = []
117 | try:
118 | option_list = string_to_list(option_str)
119 | except Exception as e:
120 | sys.exit("Error in parsing options: " + str(e))
121 | for option in option_list:
122 | if option not in candidates:
123 | logger.info(
124 | "Invalid option '{option}' for '{name}'. ".format(option=option, name=name)
125 | + "Please check the download document (https://www.vis.xyz/shift/download/)."
126 | )
127 | else:
128 | used.append(candidates[option])
129 | if len(used) == 0:
130 | logger.error(
131 | "No '{name}' is specified to download. ".format(name=name)
132 | + "If you want to download all {name}s, please use '--{name} all'.".format(name=name)
133 | )
134 | sys.exit(1)
135 | return used
136 |
137 |
138 | def download_file(url, out_file):
139 | out_dir = os.path.dirname(out_file)
140 | if not os.path.isdir(out_dir):
141 | os.makedirs(out_dir)
142 | if not os.path.isfile(out_file):
143 | logger.info("downloading " + url)
144 | fh, out_file_tmp = tempfile.mkstemp(dir=out_dir)
145 | f = os.fdopen(fh, "w")
146 | f.close()
147 | filename = url.split("/")[-1]
148 | with ProgressBar(unit="B", unit_scale=True, miniters=1, desc=filename) as t:
149 | urllib.urlretrieve(url, out_file_tmp, reporthook=t.update_to)
150 | os.rename(out_file_tmp, out_file)
151 | else:
152 | logger.warning("Skipping download of existing file " + out_file)
153 |
154 |
155 | def main():
156 | parser = argparse.ArgumentParser(description="Downloads SHIFT Dataset public release.")
157 | parser.add_argument("out_dir", help="output directory in which to store the data.")
158 | parser.add_argument("--split", type=str, default="", help="specific splits to download.")
159 | parser.add_argument("--view", type=str, default="", help="specific views to download.")
160 | parser.add_argument("--group", type=str, default="", help="specific data groups to download.")
161 | parser.add_argument("--framerate", type=str, default="", help="specific frame rate to download.")
162 | parser.add_argument(
163 | "--shift",
164 | type=str,
165 | default="discrete",
166 | choices=["discrete", "continuous/1x", "continuous/10x", "continuous/100x"],
167 | help="specific shift type to download.",
168 | )
169 | args = parser.parse_args()
170 |
171 | print(
172 | "Welcome to use SHIFT Dataset download script! \n"
173 | "By continuing you confirm that you have agreed to the SHIFT's user license.\n"
174 | )
175 |
176 | frame_rates = parse_options(args.framerate, FRAME_RATES, "frame rate")
177 | splits = parse_options(args.split, SPLITS, "split")
178 | views = parse_options(args.view, VIEWS, "view")
179 | data_groups = parse_options(args.group, DATA_GROUPS, "data group")
180 | total_files = len(frame_rates) * len(splits) * len(views) * len(data_groups)
181 | logger.info("Number of files to download: " + str(total_files))
182 |
183 | if "lidar" in data_groups and views != ["center"]:
184 | logger.error("LiDAR data only available for Center view!")
185 | sys.exit(1)
186 |
187 | for rate, rate_name in frame_rates:
188 | for split, split_name in splits:
189 | for view, view_name in views:
190 | for group, ext, group_name in data_groups:
191 | if rate == "videos" and group in ["img"]:
192 | ext = "tar"
193 | if args.shift == "discrete":
194 | url = get_url_discrete(rate, split, view, group, ext)
195 | out_file = os.path.join(args.out_dir, "discrete", rate, split, view, group + "." + ext)
196 | else:
197 | shift_length = args.shift.split("/")[-1]
198 | url = get_url_continuous(rate, shift_length, split, view, group, ext)
199 | out_file = os.path.join(
200 | args.out_dir, "continuous", rate, shift_length, split, view, group + "." + ext
201 | )
202 | logger.info(
203 | "Downloading - Shift: {shift}, Framerate: {rate}, Split: {split}, View: {view}, Data group: {group}.".format(
204 | shift=args.shift,
205 | rate=rate_name,
206 | split=split_name,
207 | view=view_name,
208 | group=group_name,
209 | url=url,
210 | )
211 | )
212 | try:
213 | download_file(url, out_file)
214 | except Exception as e:
215 | logger.error("Error in downloading " + str(e))
216 |
217 | logger.info("Done!")
218 |
219 |
220 | if __name__ == "__main__":
221 | logger = setup_logger()
222 | main()
--------------------------------------------------------------------------------
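
For example, a hypothetical command fetching front-view RGB videos of the continuous 1x validation split into ./data/shift:

    python tools/shift/download.py --view "[front]" --group "[img]" --split "[val]" --framerate "[videos]" --shift "continuous/1x" ./data/shift
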
/tools/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os
4 | import os.path as osp
5 |
6 | from mmengine.config import Config, DictAction
7 | from mmengine.model import is_model_wrapper
8 | from mmengine.registry import RUNNERS
9 | from mmengine.runner import Runner
10 |
11 | from shift_tta.utils import register_all_modules
12 |
13 |
14 | # TODO: support fuse_conv_bn, visualization, and format_only
15 | def parse_args():
16 | parser = argparse.ArgumentParser(
17 | description='shift-tta test (and eval) a model')
18 | parser.add_argument('config', help='test config file path')
19 | parser.add_argument('--checkpoint', help='checkpoint file')
20 | parser.add_argument(
21 | '--work-dir',
22 | help='the directory to save the file containing evaluation metrics')
23 | parser.add_argument(
24 | '--cfg-options',
25 | nargs='+',
26 | action=DictAction,
27 | help='override some settings in the used config, the key-value pair '
28 | 'in xxx=yyy format will be merged into config file. If the value to '
29 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
30 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
31 | 'Note that the quotation marks are necessary and that no white space '
32 | 'is allowed.')
33 | parser.add_argument(
34 | '--launcher',
35 | choices=['none', 'pytorch', 'slurm', 'mpi'],
36 | default='none',
37 | help='job launcher')
38 | parser.add_argument('--local_rank', type=int, default=0)
39 | args = parser.parse_args()
40 | if 'LOCAL_RANK' not in os.environ:
41 | os.environ['LOCAL_RANK'] = str(args.local_rank)
42 | return args
43 |
44 |
45 | def main():
46 | args = parse_args()
47 |
48 | # register all modules in shift-tta into the registries
49 | # do not init the default scope here because it will be init in the runner
50 | register_all_modules(init_default_scope=False)
51 |
52 | # load config
53 | cfg = Config.fromfile(args.config)
54 | cfg.launcher = args.launcher
55 | if args.cfg_options is not None:
56 | cfg.merge_from_dict(args.cfg_options)
57 |
58 | # work_dir is determined in this priority: CLI > segment in file > filename
59 | if args.work_dir is not None:
60 | # update configs according to CLI args if args.work_dir is not None
61 | cfg.work_dir = args.work_dir
62 | elif cfg.get('work_dir', None) is None:
63 | # use config filename as default work_dir if cfg.work_dir is None
64 | cfg.work_dir = osp.join('./work_dirs',
65 | osp.splitext(osp.basename(args.config))[0])
66 |
67 | cfg.load_from = args.checkpoint
68 |
69 | # build the runner from config
70 | if 'runner_type' not in cfg:
71 | # build the default runner
72 | runner = Runner.from_cfg(cfg)
73 | else:
74 | # build customized runner from the registry
75 | # if 'runner_type' is set in the cfg
76 | runner = RUNNERS.build(cfg)
77 |
78 | if is_model_wrapper(runner.model):
79 | runner.model.module.init_weights()
80 | else:
81 | runner.model.init_weights()
82 |
83 | # start testing
84 | runner.test()
85 |
86 |
87 | if __name__ == '__main__':
88 | main()
--------------------------------------------------------------------------------
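
A hypothetical single-process evaluation run (paths illustrative); --cfg-options can override any config key in place:

    python tools/test.py configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py --checkpoint work_dirs/source.pth --work-dir work_dirs/eval
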
/tools/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CONFIG=$1
4 | GPUS=$2  # note: currently unused; this wrapper runs a single process
5 | PY_ARGS=${@:3}
6 |
7 | # Make conda available
8 | eval "$(conda shell.bash hook)"
9 | # Activate a conda environment
10 | conda activate shift-tta
11 |
12 | python tools/test.py \
13 | ${CONFIG} \
14 | ${PY_ARGS}
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import logging
4 | import os
5 | import os.path as osp
6 |
7 | from mmengine.config import Config, DictAction
8 | from mmengine.logging import print_log
9 | from mmengine.registry import RUNNERS
10 | from mmengine.runner import Runner
11 |
12 | from shift_tta.utils import register_all_modules
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description='Train a model')
17 | parser.add_argument('config', help='train config file path')
18 | parser.add_argument('--work-dir', help='the dir to save logs and models')
19 | parser.add_argument(
20 | '--amp',
21 | action='store_true',
22 | default=False,
23 | help='enable automatic-mixed-precision training')
24 | parser.add_argument(
25 | '--auto-scale-lr',
26 | action='store_true',
27 | help='enable automatically scaling LR.')
28 | parser.add_argument(
29 | '--resume',
30 | action='store_true',
31 | help='resume from the latest checkpoint in the work_dir automatically')
32 | parser.add_argument(
33 | '--cfg-options',
34 | nargs='+',
35 | action=DictAction,
36 | help='override some settings in the used config, the key-value pair '
37 | 'in xxx=yyy format will be merged into config file. If the value to '
38 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
39 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
40 | 'Note that the quotation marks are necessary and that no white space '
41 | 'is allowed.')
42 | parser.add_argument(
43 | '--launcher',
44 | choices=['none', 'pytorch', 'slurm', 'mpi'],
45 | default='none',
46 | help='job launcher')
47 | parser.add_argument('--local_rank', type=int, default=0)
48 | args = parser.parse_args()
49 | if 'LOCAL_RANK' not in os.environ:
50 | os.environ['LOCAL_RANK'] = str(args.local_rank)
51 |
52 | return args
53 |
54 |
55 | def main():
56 | args = parse_args()
57 |
58 | # register all modules in shift-tta into the registries
59 | # do not init the default scope here because it will be init in the runner
60 | register_all_modules(init_default_scope=False)
61 |
62 | # load config
63 | cfg = Config.fromfile(args.config)
64 | cfg.launcher = args.launcher
65 | if args.cfg_options is not None:
66 | cfg.merge_from_dict(args.cfg_options)
67 |
68 | # work_dir is determined in this priority: CLI > segment in file > filename
69 | if args.work_dir is not None:
70 | # update configs according to CLI args if args.work_dir is not None
71 | cfg.work_dir = args.work_dir
72 | elif cfg.get('work_dir', None) is None:
73 | # use config filename as default work_dir if cfg.work_dir is None
74 | cfg.work_dir = osp.join('./work_dirs',
75 | osp.splitext(osp.basename(args.config))[0])
76 |
77 | # enable automatic-mixed-precision training
78 | if args.amp is True:
79 | optim_wrapper = cfg.optim_wrapper.type
80 | if optim_wrapper == 'AmpOptimWrapper':
81 | print_log(
82 | 'AMP training is already enabled in your config.',
83 | logger='current',
84 | level=logging.WARNING)
85 | else:
86 | assert optim_wrapper == 'OptimWrapper', (
87 | '`--amp` is only supported when the optimizer wrapper type is '
88 | f'`OptimWrapper` but got {optim_wrapper}.')
89 | cfg.optim_wrapper.type = 'AmpOptimWrapper'
90 | cfg.optim_wrapper.loss_scale = 'dynamic'
91 |
92 | # enable automatically scaling LR
93 | if args.auto_scale_lr:
94 | if 'auto_scale_lr' in cfg and \
95 | 'enable' in cfg.auto_scale_lr and \
96 | 'base_batch_size' in cfg.auto_scale_lr:
97 | cfg.auto_scale_lr.enable = True
98 | else:
99 | raise RuntimeError('Can not find "auto_scale_lr" or '
100 | '"auto_scale_lr.enable" or '
101 | '"auto_scale_lr.base_batch_size" in your'
102 | ' configuration file.')
103 | cfg.resume = args.resume
104 |
105 | # build the runner from config
106 | if 'runner_type' not in cfg:
107 | # build the default runner
108 | runner = Runner.from_cfg(cfg)
109 | else:
110 | # build customized runner from the registry
111 | # if 'runner_type' is set in the cfg
112 | runner = RUNNERS.build(cfg)
113 |
114 | # start training
115 | runner.train()
116 |
117 |
118 | if __name__ == '__main__':
119 | main()
--------------------------------------------------------------------------------
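
A hypothetical training launch with mixed precision enabled via the --amp flag handled above (the config path is from this repo; the override key is illustrative):

    python tools/train.py configs/source/yolox/yolox_x_8xb4-24e_shift_clear_daytime.py --amp --cfg-options train_cfg.max_epochs=24
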
/tools/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CONFIG=$1
4 | GPUS=$2  # note: currently unused; this wrapper runs a single process
5 | PY_ARGS=${@:3}
6 |
7 | # Make conda available
8 | eval "$(conda shell.bash hook)"
9 | # Activate a conda environment
10 | conda activate shift-tta
11 |
12 | python tools/train.py \
13 | ${CONFIG} \
14 | ${PY_ARGS}
--------------------------------------------------------------------------------