├── .gitignore ├── LICENSE ├── README.md ├── configs ├── _base_ │ ├── datasets │ │ └── shift.py │ ├── default_runtime.py │ └── models │ │ ├── faster-rcnn_r50_fpn.py │ │ └── yolox_x_8x8.py ├── continuous │ ├── mean_teacher_adapter_yolox │ │ └── yolox_x_8xb4-24e_shift_from_clear_daytime.py │ └── no_adap_yolox │ │ ├── yolox_x_8xb4-24e_shift_from_all.py │ │ ├── yolox_x_8xb4-24e_shift_from_clear_daytime.py │ │ └── yolox_x_8xb4-24e_shift_from_clear_night.py └── source │ └── yolox │ ├── README.md │ ├── amp_yolox_x_8xb4-24e_shift_all.py │ ├── amp_yolox_x_8xb4-24e_shift_clear_daytime.py │ ├── amp_yolox_x_8xb4-24e_shift_clear_night.py │ ├── amp_yolox_x_8xb4-24e_shift_daytime.py │ ├── amp_yolox_x_8xb4-24e_shift_night.py │ ├── yolox_x_8xb4-24e_shift_all.py │ ├── yolox_x_8xb4-24e_shift_clear_daytime.py │ ├── yolox_x_8xb4-24e_shift_clear_night.py │ ├── yolox_x_8xb4-24e_shift_daytime.py │ └── yolox_x_8xb4-24e_shift_night.py ├── docs ├── challenge.md ├── dataset_prepare.md ├── get_started.md ├── model_zoo.md └── train_test.md ├── requirements.txt ├── requirements ├── build.txt ├── mminstall.txt └── runtime.txt ├── resources └── shift-logo.png ├── scripts ├── continuous │ ├── mean_teacher_adapter_yolox │ │ ├── slurm_test_yolox_shift_from_clear_daytime.sh │ │ ├── test_yolox_shift_from_clear_daytime.sh │ │ └── val_yolox_shift_from_clear_daytime.sh │ └── no_adap_yolox │ │ ├── test_yolox_shift_from_clear_daytime.sh │ │ ├── test_yolox_shift_from_clear_night.sh │ │ ├── val_yolox_shift_from_clear_daytime.sh │ │ └── val_yolox_shift_from_clear_night.sh ├── source │ ├── slurm_train_yolox_24e_shift_clear_daytime.sh │ ├── slurm_train_yolox_24e_shift_daytime.sh │ ├── test_yolox_shift_clear_daytime.sh │ ├── test_yolox_shift_clear_night.sh │ └── train_yolox_shift_clear_daytime.sh └── tmp │ ├── edit_json.py │ └── edit_json_ids.py ├── setup.cfg ├── setup.py ├── shift_tta ├── .mim │ ├── configs │ └── tools ├── __init__.py ├── datasets │ ├── __init__.py │ ├── shift_dataset.py │ └── utils │ │ ├── __init__.py │ │ └── filters.py ├── evaluation │ ├── __init__.py │ └── metrics │ │ ├── __init__.py │ │ └── shift_video_metrics.py ├── fileio │ ├── __init__.py │ └── backends │ │ ├── __init__.py │ │ ├── tar_backend.py │ │ └── zip_backend.py ├── models │ ├── __init__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── base_adapter.py │ │ └── mean_teacher_yolox_adapter.py │ ├── detectors │ │ ├── __init__.py │ │ ├── adaptive_detector.py │ │ └── base.py │ └── losses │ │ ├── __init__.py │ │ └── yolox_consistency_loss.py ├── registry.py ├── utils │ ├── __init__.py │ └── setup_env.py └── version.py └── tools ├── dist_test.sh ├── dist_train.sh ├── install └── setup_venv.sh ├── shift └── download.py ├── test.py ├── test.sh ├── train.py └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/en/_build/ 68 | docs/zh_cn/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | data 108 | ckpts 109 | .vscode 110 | .idea 111 | .DS_Store 112 | 113 | # custom 114 | *.pkl 115 | *.pkl.json 116 | *.log.json 117 | work_dirs/ 118 | shift_tta/.mim 119 | shift_tta.egg-info/ 120 | errors/ 121 | outputs/ 122 | checkpoints/ 123 | scripts/tmp/ 124 | 125 | # PyTorch 126 | *.pth 127 | *.py~ 128 | *.sh~ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ETH VIS Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 
3 | 
4 | 
5 | 
6 | [ SHIFT Project Page ](https://www.vis.xyz/shift/)
7 | 
8 | &nbsp;&nbsp;&nbsp;
9 | 
10 | [ SHIFT Paper (CVPR 2022) ]
11 | 
12 | &nbsp;&nbsp;&nbsp;
13 | 
14 | [ VIS Group ](https://www.vis.xyz/)
15 | 
16 | 
17 | 
18 | 
19 | ## Introduction
20 | 
21 | [SHIFT](https://www.vis.xyz/shift/) is a driving dataset for continuous multi-task domain adaptation. It is maintained by the [VIS Group](https://www.vis.xyz/) at ETH Zurich.
22 | 
23 | The main branch works with **PyTorch 1.6+**.
24 | 
25 | 
26 | 
27 | https://github.com/SysCV/shift-detection-tta/assets/44324619/9ddc4b31-7ca9-46b1-a1c5-3b9107e04f9e
28 | 
29 | 
30 | 
31 | 
32 | ## Tutorial
33 | ### Get started
34 | 
35 | Please refer to [get_started.md](docs/get_started.md) for installation instructions.
36 | 
37 | ### Prepare the SHIFT dataset
38 | 
39 | Please refer to [dataset_prepare.md](docs/dataset_prepare.md) for instructions on how to download and prepare the SHIFT dataset.
40 | 
41 | ### Usage
42 | 
43 | Please refer to [train_test.md](docs/train_test.md) for instructions on how to train and test your own model. A minimal programmatic sketch is also shown further below, after the model zoo pointer.
44 | 
45 | ### Participate in the Challenge on Continuous Test-time Adaptation for Object Detection
46 | 
47 | Please refer to [challenge.md](docs/challenge.md) for instructions on how to participate in the challenge, as well as for training, testing, and adaptation guidelines.
48 | 
49 | The challenge is organized for the Workshop on [Visual Continual Learning @ ICCV 2023](https://wvcl.vis.xyz). Check out [wvcl.vis.xyz/challenges](https://wvcl.vis.xyz/challenges) for additional details on this and other challenges.
50 | 
51 | We will award the top three teams of each challenge with a certificate and a prize of 1000, 500, and 300 USD, respectively. The winners of each challenge will be invited to give a presentation at the workshop. Teams will be selected based on the performance of their methods on the test set.
52 | 
53 | We will also award one team from each challenge with an innovation award. The innovation award is given to the team that proposes the most innovative method and/or the most insightful analysis. The winner will receive a certificate and an additional prize of 300 USD.
54 | 
55 | **Please note** that this challenge is part of the track **Challenge B - Continual Test-time Adaptation**, together with the challenge on "Continuous Test-time Adaptation for Semantic Segmentation". Since the challenge on "Continuous Test-time Adaptation for Object Detection" constitutes half of track B, the prizes should be considered half of the amounts mentioned above.
56 | 
57 | # Continuous Test-time Adaptation for Object Detection
58 | 
59 | ## Model zoo
60 | 
61 | Results and models are available in the [model zoo](docs/model_zoo.md).
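All configs in this repository are standard OpenMMLab config files, so the model zoo entries can be reproduced either through the `tools/train.py` / `tools/test.py` entry points or programmatically. The snippet below is a minimal sketch, assuming the usual `mmengine` Runner workflow that the `tools/` scripts wrap and that `shift_tta` is installed so its `default_scope` resolves; the work directory and checkpoint paths are illustrative placeholders, not files shipped with this repository.

```python
# Minimal sketch: evaluate a trained detector with a continuous-evaluation config.
# Assumes mmengine is installed and the SHIFT data layout from docs/dataset_prepare.md;
# the work_dir and checkpoint paths below are hypothetical placeholders.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py')
cfg.work_dir = 'work_dirs/no_adap_yolox_from_clear_daytime'  # illustrative output dir
cfg.load_from = 'ckpts/yolox_x_shift_clear_daytime.pth'      # hypothetical checkpoint path

runner = Runner.from_cfg(cfg)  # builds the detector, dataloaders, and evaluator from the config
runner.test()                  # runs the TestLoop with the SHIFTVideoMetric evaluator
```

For source training, point `Config.fromfile` at one of the `configs/source/yolox/*.py` configs and call `runner.train()` instead.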
62 | 63 | ### Object Detection 64 | 65 | Supported Adaptation Methods 66 | - [x] [no_adap](configs/continuous/no_adap_yolox) 67 | - [x] [mean_teacher_adapter_yolox](configs/continuous/mean_teacher_adapter_yolox) 68 | 69 | Supported Datasets 70 | 71 | - [x] [SHIFT](https://www.vis.xyz/shift/) 72 | 73 | 74 | ## Citation 75 | 76 | If you find this project useful in your research, please consider citing: 77 | 78 | - SHIFT, the dataset powering this challenge and the continuous adaptation tasks: 79 | 80 | ```latex 81 | @inproceedings{sun2022shift, 82 | title={SHIFT: a synthetic driving dataset for continuous multi-task domain adaptation}, 83 | author={Sun, Tao and Segu, Mattia and Postels, Janis and Wang, Yuxuan and Van Gool, Luc and Schiele, Bernt and Tombari, Federico and Yu, Fisher}, 84 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 85 | pages={21371--21382}, 86 | year={2022} 87 | } 88 | ``` 89 | 90 | - DARTH, the test-time adaptation method introducing the detection consistency loss for detection adaptation based on mean-teacher: 91 | ```latex 92 | @inproceedings{segu2023darth, 93 | title={Darth: holistic test-time adaptation for multiple object tracking}, 94 | author={Segu, Mattia and Schiele, Bernt and Yu, Fisher}, 95 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 96 | pages={9717--9727}, 97 | year={2023} 98 | } 99 | ``` 100 | 101 | ## License 102 | 103 | This project is released under the [MIT License](LICENSE). 104 | -------------------------------------------------------------------------------- /configs/_base_/datasets/shift.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/shift-detection-tta/a5c4fb0e906caa57c0b572d651f24f01ff6bdcdf/configs/_base_/datasets/shift.py -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | default_scope = 'shift_tta' 2 | 3 | default_hooks = dict( 4 | timer=dict(type='IterTimerHook'), 5 | logger=dict(type='LoggerHook', interval=50), 6 | param_scheduler=dict(type='ParamSchedulerHook'), 7 | checkpoint=dict(type='CheckpointHook', interval=1), 8 | sampler_seed=dict(type='DistSamplerSeedHook'), 9 | visualization=dict( 10 | type='mmtrack.TrackVisualizationHook', 11 | draw=False), 12 | ) 13 | 14 | env_cfg = dict( 15 | cudnn_benchmark=False, 16 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 17 | dist_cfg=dict(backend='nccl'), 18 | ) 19 | 20 | vis_backends = [dict(type='LocalVisBackend')] 21 | visualizer = dict( 22 | type='mmtrack.TrackLocalVisualizer', vis_backends=vis_backends, name='visualizer') 23 | 24 | log_level = 'INFO' 25 | load_from = None 26 | resume = False -------------------------------------------------------------------------------- /configs/_base_/models/faster-rcnn_r50_fpn.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | data_preprocessor=dict( 3 | type='mmtrack.TrackDataPreprocessor', 4 | mean=[123.675, 116.28, 103.53], 5 | std=[58.395, 57.12, 57.375], 6 | bgr_to_rgb=True, 7 | rgb_to_bgr=False, 8 | pad_size_divisor=32), 9 | detector=dict( 10 | type='FasterRCNN', 11 | _scope_='mmdet', 12 | backbone=dict( 13 | type='ResNet', 14 | depth=50, 15 | num_stages=4, 16 | out_indices=(0, 1, 2, 3), 17 | frozen_stages=1, 18 | norm_cfg=dict(type='BN', requires_grad=True), 19 | 
norm_eval=True, 20 | style='pytorch', 21 | init_cfg=dict( 22 | type='Pretrained', checkpoint='torchvision://resnet50')), 23 | neck=dict( 24 | type='FPN', 25 | in_channels=[256, 512, 1024, 2048], 26 | out_channels=256, 27 | num_outs=5), 28 | rpn_head=dict( 29 | type='RPNHead', 30 | in_channels=256, 31 | feat_channels=256, 32 | anchor_generator=dict( 33 | type='AnchorGenerator', 34 | scales=[8], 35 | ratios=[0.5, 1.0, 2.0], 36 | strides=[4, 8, 16, 32, 64]), 37 | bbox_coder=dict( 38 | type='DeltaXYWHBBoxCoder', 39 | target_means=[.0, .0, .0, .0], 40 | target_stds=[1.0, 1.0, 1.0, 1.0]), 41 | loss_cls=dict( 42 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 43 | loss_bbox=dict( 44 | type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), 45 | roi_head=dict( 46 | type='StandardRoIHead', 47 | bbox_roi_extractor=dict( 48 | type='SingleRoIExtractor', 49 | roi_layer=dict( 50 | type='RoIAlign', output_size=7, sampling_ratio=0), 51 | out_channels=256, 52 | featmap_strides=[4, 8, 16, 32]), 53 | bbox_head=dict( 54 | type='Shared2FCBBoxHead', 55 | in_channels=256, 56 | fc_out_channels=1024, 57 | roi_feat_size=7, 58 | num_classes=80, 59 | bbox_coder=dict( 60 | type='DeltaXYWHBBoxCoder', 61 | target_means=[0., 0., 0., 0.], 62 | target_stds=[0.1, 0.1, 0.2, 0.2]), 63 | reg_class_agnostic=False, 64 | loss_cls=dict( 65 | type='CrossEntropyLoss', 66 | use_sigmoid=False, 67 | loss_weight=1.0), 68 | loss_bbox=dict(type='SmoothL1Loss', loss_weight=1.0))), 69 | train_cfg=dict( 70 | rpn=dict( 71 | assigner=dict( 72 | type='MaxIoUAssigner', 73 | pos_iou_thr=0.7, 74 | neg_iou_thr=0.3, 75 | min_pos_iou=0.3, 76 | match_low_quality=True, 77 | ignore_iof_thr=-1), 78 | sampler=dict( 79 | type='RandomSampler', 80 | num=256, 81 | pos_fraction=0.5, 82 | neg_pos_ub=-1, 83 | add_gt_as_proposals=False), 84 | allowed_border=-1, 85 | pos_weight=-1, 86 | debug=False), 87 | rpn_proposal=dict( 88 | nms_pre=2000, 89 | max_per_img=1000, 90 | nms=dict(type='nms', iou_threshold=0.7), 91 | min_bbox_size=0), 92 | rcnn=dict( 93 | assigner=dict( 94 | type='MaxIoUAssigner', 95 | pos_iou_thr=0.5, 96 | neg_iou_thr=0.5, 97 | min_pos_iou=0.5, 98 | match_low_quality=False, 99 | ignore_iof_thr=-1), 100 | sampler=dict( 101 | type='RandomSampler', 102 | num=512, 103 | pos_fraction=0.25, 104 | neg_pos_ub=-1, 105 | add_gt_as_proposals=True), 106 | pos_weight=-1, 107 | debug=False)), 108 | test_cfg=dict( 109 | rpn=dict( 110 | nms_pre=1000, 111 | max_per_img=1000, 112 | nms=dict(type='nms', iou_threshold=0.7), 113 | min_bbox_size=0), 114 | rcnn=dict( 115 | score_thr=0.05, 116 | nms=dict(type='nms', iou_threshold=0.5), 117 | max_per_img=100)) 118 | # soft-nms is also supported for rcnn testing 119 | # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) 120 | )) -------------------------------------------------------------------------------- /configs/_base_/models/yolox_x_8x8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | img_scale = (640, 640) 3 | 4 | model = dict( 5 | data_preprocessor=dict( 6 | type='mmtrack.TrackDataPreprocessor', 7 | pad_size_divisor=32, 8 | batch_augments=[ 9 | dict( 10 | type='mmdet.BatchSyncRandomResize', 11 | random_size_range=(480, 800), 12 | size_divisor=32, 13 | interval=10) 14 | ]), 15 | detector=dict( 16 | _scope_='mmdet', 17 | type='YOLOX', 18 | backbone=dict( 19 | type='CSPDarknet', deepen_factor=1.33, widen_factor=1.25), 20 | neck=dict( 21 | type='YOLOXPAFPN', 22 | in_channels=[320, 640, 1280], 23 | out_channels=320, 24 | 
num_csp_blocks=4), 25 | bbox_head=dict( 26 | type='YOLOXHead', 27 | num_classes=80, 28 | in_channels=320, 29 | feat_channels=320), 30 | train_cfg=dict( 31 | assigner=dict(type='SimOTAAssigner', center_radius=2.5)), 32 | test_cfg=dict( 33 | score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))) -------------------------------------------------------------------------------- /configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/yolox_x_8x8.py', 3 | '../../_base_/default_runtime.py' 4 | ] 5 | 6 | dataset_type = 'SHIFTDataset' 7 | data_root = 'data/shift/' 8 | attributes = dict(weather_coarse='clear', timeofday_coarse='daytime') 9 | 10 | img_scale = (800, 1440) 11 | batch_size = 2 12 | 13 | model = dict( 14 | type='AdaptiveDetector', 15 | data_preprocessor=dict( 16 | type='mmtrack.TrackDataPreprocessor', 17 | pad_size_divisor=32, 18 | batch_augments=[ 19 | dict( 20 | type='mmdet.BatchSyncRandomResize', 21 | random_size_range=(576, 1024), 22 | size_divisor=32, 23 | interval=10) 24 | ]), 25 | detector=dict( 26 | _scope_='mmdet', 27 | bbox_head=dict(num_classes=6), 28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)), 29 | init_cfg=dict( 30 | type='Pretrained', 31 | checkpoint= # noqa: E251 32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501 33 | )), 34 | adapter=dict( 35 | type='MeanTeacherYOLOXAdapter', 36 | episodic=True, # do NOT change this. episodic must be set to True for the WVCL ICCV 2023 SHIFT Challenges 37 | optim_wrapper=dict( 38 | type='OptimWrapper', 39 | optimizer=dict( 40 | type='SGD', lr=0.00025, momentum=0.9, weight_decay=5e-4, nesterov=True), 41 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)), 42 | optim_steps=5, 43 | teacher=dict( 44 | type='ExponentialMovingAverage', 45 | momentum=0.0002, 46 | update_buffers=True), 47 | loss=dict( 48 | type='YOLOXConsistencyLoss', 49 | weight=0.01, 50 | ), 51 | pipeline = [ 52 | dict(type='LoadImageFromFile', 53 | backend_args=dict( 54 | backend='tar', 55 | tar_path=data_root + 'continuous/videos/1x/val/front/img_decompressed.tar', 56 | ) 57 | ), 58 | dict(type='mmtrack.LoadTrackAnnotations'), 59 | ], 60 | teacher_pipeline = [ 61 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True), 62 | dict( 63 | type='mmdet.Pad', 64 | size_divisor=32, 65 | pad_val=dict(img=(114.0, 114.0, 114.0))), 66 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True), 67 | ], 68 | student_pipeline = [ 69 | dict(type='mmdet.YOLOXHSVRandomAug'), 70 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True), 71 | dict( 72 | type='mmdet.Pad', 73 | size_divisor=32, 74 | pad_val=dict(img=(114.0, 114.0, 114.0))), 75 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True), 76 | ], 77 | views=2, 78 | )) 79 | 80 | train_pipeline = [ 81 | dict( 82 | type='mmdet.Mosaic', 83 | img_scale=img_scale, 84 | pad_val=114.0, 85 | bbox_clip_border=False), 86 | dict( 87 | type='mmdet.RandomAffine', 88 | scaling_ratio_range=(0.1, 2), 89 | border=(-img_scale[0] // 2, -img_scale[1] // 2), 90 | bbox_clip_border=False), 91 | dict( 92 | type='mmdet.MixUp', 93 | img_scale=img_scale, 94 | ratio_range=(0.8, 1.6), 95 | pad_val=114.0, 96 | bbox_clip_border=False), 97 | dict(type='mmdet.YOLOXHSVRandomAug'), 98 | dict(type='mmdet.RandomFlip', prob=0.5), 99 | dict( 100 | type='mmdet.Resize', 
101 | scale=img_scale, 102 | keep_ratio=True, 103 | clip_object_border=False), 104 | dict( 105 | type='mmdet.Pad', 106 | size_divisor=32, 107 | pad_val=dict(img=(114.0, 114.0, 114.0))), 108 | dict( 109 | type='mmdet.FilterAnnotations', 110 | min_gt_bbox_wh=(1, 1), 111 | keep_empty=False), 112 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 113 | ] 114 | test_pipeline = [ 115 | dict(type='LoadImageFromFile', 116 | backend_args=dict( 117 | backend='tar', 118 | tar_path=data_root + 'continuous/videos/1x/val/front/img_decompressed.tar', 119 | ) 120 | ), 121 | dict(type='mmtrack.LoadTrackAnnotations'), 122 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True), 123 | dict( 124 | type='mmdet.Pad', 125 | size_divisor=32, 126 | pad_val=dict(img=(114.0, 114.0, 114.0))), 127 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 128 | ] 129 | 130 | train_dataset = dict( 131 | # use MultiImageMixDataset wrapper to support mosaic and mixup 132 | type='mmdet.MultiImageMixDataset', 133 | dataset=dict( 134 | type='SHIFTDataset', 135 | load_as_video=False, 136 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json', 137 | data_prefix=dict(img=''), 138 | ref_img_sampler=None, 139 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')), 140 | pipeline=[ 141 | dict(type='LoadImageFromFile', 142 | backend_args=dict( 143 | backend='zip', 144 | zip_path=data_root + 'discrete/images/train/front/img.zip', 145 | ) 146 | ), 147 | dict(type='mmtrack.LoadTrackAnnotations'), 148 | ], 149 | filter_cfg=dict( 150 | attributes=attributes, 151 | filter_empty_gt=False, 152 | min_size=32 153 | )), 154 | pipeline=train_pipeline) 155 | train_dataloader = dict( 156 | batch_size=batch_size, 157 | num_workers=4, 158 | persistent_workers=True, 159 | sampler=dict(type='DefaultSampler', shuffle=True), 160 | dataset=train_dataset) 161 | 162 | val_dataset=dict( 163 | type='SHIFTDataset', 164 | load_as_video=True, 165 | ann_file=data_root + 'continuous/videos/1x/val/front/det_2d_cocoformat.json', 166 | data_prefix=dict(img=''), 167 | ref_img_sampler=None, 168 | test_mode=True, 169 | filter_cfg=dict(attributes=attributes), 170 | pipeline=test_pipeline, 171 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle'))) 172 | val_dataloader = dict( 173 | batch_size=1, 174 | num_workers=4, 175 | persistent_workers=True, 176 | drop_last=False, 177 | sampler=dict(type='mmtrack.VideoSampler'), 178 | dataset=val_dataset) 179 | test_dataloader = val_dataloader 180 | # optimizer 181 | # default 8 gpu 182 | lr = 0.0005 / 8 * batch_size 183 | optim_wrapper = dict( 184 | type='OptimWrapper', 185 | optimizer=dict( 186 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True), 187 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) 188 | 189 | # some hyper parameters 190 | # training settings 191 | total_epochs = 12 192 | num_last_epochs = 2 193 | resume_from = None 194 | interval = 5 195 | 196 | train_cfg = dict( 197 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1) 198 | val_cfg = dict(type='ValLoop') 199 | test_cfg = dict(type='TestLoop') 200 | # learning policy 201 | param_scheduler = [ 202 | dict( 203 | # use quadratic formula to warm up 1 epochs 204 | # and lr is updated by iteration 205 | type='mmdet.QuadraticWarmupLR', 206 | by_epoch=True, 207 | begin=0, 208 | end=1, 209 | convert_to_iter_based=True), 210 | dict( 211 | # use cosine lr from 1 to epoch #(total_epochs - num_last_epochs) 212 | 
type='mmdet.CosineAnnealingLR', 213 | eta_min=lr * 0.05, 214 | begin=1, 215 | T_max=total_epochs - num_last_epochs, 216 | end=total_epochs - num_last_epochs, 217 | by_epoch=True, 218 | convert_to_iter_based=True), 219 | dict( 220 | # use fixed lr during last 10 epochs 221 | type='mmdet.ConstantLR', 222 | by_epoch=True, 223 | factor=1, 224 | begin=total_epochs - num_last_epochs, 225 | end=total_epochs, 226 | ) 227 | ] 228 | 229 | custom_hooks = [ 230 | dict( 231 | type='mmtrack.YOLOXModeSwitchHook', 232 | num_last_epochs=num_last_epochs, 233 | priority=48), 234 | dict(type='mmdet.SyncNormHook', priority=48), 235 | dict( 236 | type='mmdet.EMAHook', 237 | ema_type='mmdet.ExpMomentumEMA', 238 | momentum=0.0001, 239 | update_buffers=True, 240 | priority=49) 241 | ] 242 | default_hooks = dict(checkpoint=dict(interval=1)) 243 | 244 | # evaluator 245 | val_evaluator = [ 246 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True), 247 | ] 248 | test_evaluator = val_evaluator -------------------------------------------------------------------------------- /configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_all.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/yolox_x_8x8.py', 3 | '../../_base_/default_runtime.py' 4 | ] 5 | 6 | dataset_type = 'SHIFTDataset' 7 | data_root = 'data/shift/' 8 | attributes = None 9 | 10 | img_scale = (800, 1440) 11 | batch_size = 2 12 | 13 | model = dict( 14 | type='AdaptiveDetector', 15 | data_preprocessor=dict( 16 | type='mmtrack.TrackDataPreprocessor', 17 | pad_size_divisor=32, 18 | batch_augments=[ 19 | dict( 20 | type='mmdet.BatchSyncRandomResize', 21 | random_size_range=(576, 1024), 22 | size_divisor=32, 23 | interval=10) 24 | ]), 25 | detector=dict( 26 | _scope_='mmdet', 27 | bbox_head=dict(num_classes=6), 28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)), 29 | init_cfg=dict( 30 | type='Pretrained', 31 | checkpoint= # noqa: E251 32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501 33 | )), 34 | adapter=None) 35 | 36 | train_pipeline = [ 37 | dict( 38 | type='mmdet.Mosaic', 39 | img_scale=img_scale, 40 | pad_val=114.0, 41 | bbox_clip_border=False), 42 | dict( 43 | type='mmdet.RandomAffine', 44 | scaling_ratio_range=(0.1, 2), 45 | border=(-img_scale[0] // 2, -img_scale[1] // 2), 46 | bbox_clip_border=False), 47 | dict( 48 | type='mmdet.MixUp', 49 | img_scale=img_scale, 50 | ratio_range=(0.8, 1.6), 51 | pad_val=114.0, 52 | bbox_clip_border=False), 53 | dict(type='mmdet.YOLOXHSVRandomAug'), 54 | dict(type='mmdet.RandomFlip', prob=0.5), 55 | dict( 56 | type='mmdet.Resize', 57 | scale=img_scale, 58 | keep_ratio=True, 59 | clip_object_border=False), 60 | dict( 61 | type='mmdet.Pad', 62 | size_divisor=32, 63 | pad_val=dict(img=(114.0, 114.0, 114.0))), 64 | dict( 65 | type='mmdet.FilterAnnotations', 66 | min_gt_bbox_wh=(1, 1), 67 | keep_empty=False), 68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 69 | ] 70 | test_pipeline = [ 71 | dict(type='LoadImageFromFile', 72 | backend_args=dict( 73 | backend='tar', 74 | tar_path=data_root + 'continuous/videos/1x/val/front/img_decompressed.tar', 75 | ) 76 | ), 77 | dict(type='mmtrack.LoadTrackAnnotations'), 78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True), 79 | dict( 80 | type='mmdet.Pad', 81 | size_divisor=32, 82 | pad_val=dict(img=(114.0, 114.0, 114.0))), 83 | 
dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 84 | ] 85 | 86 | train_dataset = dict( 87 | # use MultiImageMixDataset wrapper to support mosaic and mixup 88 | type='mmdet.MultiImageMixDataset', 89 | dataset=dict( 90 | type='SHIFTDataset', 91 | load_as_video=False, 92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json', 93 | data_prefix=dict(img=''), 94 | ref_img_sampler=None, 95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')), 96 | pipeline=[ 97 | dict(type='LoadImageFromFile', 98 | backend_args=dict( 99 | backend='zip', 100 | zip_path=data_root + 'discrete/images/train/front/img.zip', 101 | ) 102 | ), 103 | dict(type='mmtrack.LoadTrackAnnotations'), 104 | ], 105 | filter_cfg=dict( 106 | attributes=attributes, 107 | filter_empty_gt=False, 108 | min_size=32 109 | )), 110 | pipeline=train_pipeline) 111 | train_dataloader = dict( 112 | batch_size=batch_size, 113 | num_workers=4, 114 | persistent_workers=True, 115 | sampler=dict(type='DefaultSampler', shuffle=True), 116 | dataset=train_dataset) 117 | 118 | val_dataset=dict( 119 | type='SHIFTDataset', 120 | load_as_video=True, 121 | ann_file=data_root + 'continuous/videos/1x/val/front/det_2d_cocoformat.json', 122 | data_prefix=dict(img=''), 123 | ref_img_sampler=None, 124 | test_mode=True, 125 | filter_cfg=dict(attributes=attributes), 126 | pipeline=test_pipeline, 127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle'))) 128 | val_dataloader = dict( 129 | batch_size=1, 130 | num_workers=4, 131 | persistent_workers=True, 132 | drop_last=False, 133 | sampler=dict(type='mmtrack.VideoSampler'), 134 | dataset=val_dataset) 135 | test_dataloader = val_dataloader 136 | # optimizer 137 | # default 8 gpu 138 | lr = 0.0005 / 8 * batch_size 139 | optim_wrapper = dict( 140 | type='OptimWrapper', 141 | optimizer=dict( 142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True), 143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) 144 | 145 | # some hyper parameters 146 | # training settings 147 | total_epochs = 12 148 | num_last_epochs = 2 149 | resume_from = None 150 | interval = 5 151 | 152 | train_cfg = dict( 153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1) 154 | val_cfg = dict(type='ValLoop') 155 | test_cfg = dict(type='TestLoop') 156 | # learning policy 157 | param_scheduler = [ 158 | dict( 159 | # use quadratic formula to warm up 1 epochs 160 | # and lr is updated by iteration 161 | type='mmdet.QuadraticWarmupLR', 162 | by_epoch=True, 163 | begin=0, 164 | end=1, 165 | convert_to_iter_based=True), 166 | dict( 167 | # use cosine lr from 1 to epoch #(total_epochs - num_last_epochs) 168 | type='mmdet.CosineAnnealingLR', 169 | eta_min=lr * 0.05, 170 | begin=1, 171 | T_max=total_epochs - num_last_epochs, 172 | end=total_epochs - num_last_epochs, 173 | by_epoch=True, 174 | convert_to_iter_based=True), 175 | dict( 176 | # use fixed lr during last 10 epochs 177 | type='mmdet.ConstantLR', 178 | by_epoch=True, 179 | factor=1, 180 | begin=total_epochs - num_last_epochs, 181 | end=total_epochs, 182 | ) 183 | ] 184 | 185 | custom_hooks = [ 186 | dict( 187 | type='mmtrack.YOLOXModeSwitchHook', 188 | num_last_epochs=num_last_epochs, 189 | priority=48), 190 | dict(type='mmdet.SyncNormHook', priority=48), 191 | dict( 192 | type='mmdet.EMAHook', 193 | ema_type='mmdet.ExpMomentumEMA', 194 | momentum=0.0001, 195 | update_buffers=True, 196 | priority=49) 197 | ] 198 | default_hooks = 
dict(checkpoint=dict(interval=1)) 199 | 200 | # evaluator 201 | val_evaluator = [ 202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True), 203 | ] 204 | test_evaluator = val_evaluator -------------------------------------------------------------------------------- /configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/yolox_x_8x8.py', 3 | '../../_base_/default_runtime.py' 4 | ] 5 | 6 | dataset_type = 'SHIFTDataset' 7 | data_root = 'data/shift/' 8 | attributes = dict(weather_coarse='clear', timeofday_coarse='daytime') 9 | 10 | img_scale = (800, 1440) 11 | batch_size = 2 12 | 13 | model = dict( 14 | type='AdaptiveDetector', 15 | data_preprocessor=dict( 16 | type='mmtrack.TrackDataPreprocessor', 17 | pad_size_divisor=32, 18 | batch_augments=[ 19 | dict( 20 | type='mmdet.BatchSyncRandomResize', 21 | random_size_range=(576, 1024), 22 | size_divisor=32, 23 | interval=10) 24 | ]), 25 | detector=dict( 26 | _scope_='mmdet', 27 | bbox_head=dict(num_classes=6), 28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)), 29 | init_cfg=dict( 30 | type='Pretrained', 31 | checkpoint= # noqa: E251 32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501 33 | )), 34 | adapter=None) 35 | 36 | train_pipeline = [ 37 | dict( 38 | type='mmdet.Mosaic', 39 | img_scale=img_scale, 40 | pad_val=114.0, 41 | bbox_clip_border=False), 42 | dict( 43 | type='mmdet.RandomAffine', 44 | scaling_ratio_range=(0.1, 2), 45 | border=(-img_scale[0] // 2, -img_scale[1] // 2), 46 | bbox_clip_border=False), 47 | dict( 48 | type='mmdet.MixUp', 49 | img_scale=img_scale, 50 | ratio_range=(0.8, 1.6), 51 | pad_val=114.0, 52 | bbox_clip_border=False), 53 | dict(type='mmdet.YOLOXHSVRandomAug'), 54 | dict(type='mmdet.RandomFlip', prob=0.5), 55 | dict( 56 | type='mmdet.Resize', 57 | scale=img_scale, 58 | keep_ratio=True, 59 | clip_object_border=False), 60 | dict( 61 | type='mmdet.Pad', 62 | size_divisor=32, 63 | pad_val=dict(img=(114.0, 114.0, 114.0))), 64 | dict( 65 | type='mmdet.FilterAnnotations', 66 | min_gt_bbox_wh=(1, 1), 67 | keep_empty=False), 68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 69 | ] 70 | test_pipeline = [ 71 | dict(type='LoadImageFromFile', 72 | backend_args=dict( 73 | backend='tar', 74 | tar_path=data_root + 'continuous/videos/1x/val/front/img_decompressed.tar', 75 | ) 76 | ), 77 | dict(type='mmtrack.LoadTrackAnnotations'), 78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True), 79 | dict( 80 | type='mmdet.Pad', 81 | size_divisor=32, 82 | pad_val=dict(img=(114.0, 114.0, 114.0))), 83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 84 | ] 85 | 86 | train_dataset = dict( 87 | # use MultiImageMixDataset wrapper to support mosaic and mixup 88 | type='mmdet.MultiImageMixDataset', 89 | dataset=dict( 90 | type='SHIFTDataset', 91 | load_as_video=False, 92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json', 93 | data_prefix=dict(img=''), 94 | ref_img_sampler=None, 95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')), 96 | pipeline=[ 97 | dict(type='LoadImageFromFile', 98 | backend_args=dict( 99 | backend='zip', 100 | zip_path=data_root + 'discrete/images/train/front/img.zip', 101 | ) 102 | ), 103 | dict(type='mmtrack.LoadTrackAnnotations'), 104 | ], 105 | 
filter_cfg=dict( 106 | attributes=attributes, 107 | filter_empty_gt=False, 108 | min_size=32 109 | )), 110 | pipeline=train_pipeline) 111 | train_dataloader = dict( 112 | batch_size=batch_size, 113 | num_workers=4, 114 | persistent_workers=True, 115 | sampler=dict(type='DefaultSampler', shuffle=True), 116 | dataset=train_dataset) 117 | 118 | val_dataset=dict( 119 | type='SHIFTDataset', 120 | load_as_video=True, 121 | ann_file=data_root + 'continuous/videos/1x/val/front/det_2d_cocoformat.json', 122 | data_prefix=dict(img=''), 123 | ref_img_sampler=None, 124 | test_mode=True, 125 | filter_cfg=dict(attributes=attributes), 126 | pipeline=test_pipeline, 127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle'))) 128 | val_dataloader = dict( 129 | batch_size=1, 130 | num_workers=4, 131 | persistent_workers=True, 132 | drop_last=False, 133 | sampler=dict(type='mmtrack.VideoSampler'), 134 | dataset=val_dataset) 135 | test_dataloader = val_dataloader 136 | # optimizer 137 | # default 8 gpu 138 | lr = 0.0005 / 8 * batch_size 139 | optim_wrapper = dict( 140 | type='OptimWrapper', 141 | optimizer=dict( 142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True), 143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) 144 | 145 | # some hyper parameters 146 | # training settings 147 | total_epochs = 12 148 | num_last_epochs = 2 149 | resume_from = None 150 | interval = 5 151 | 152 | train_cfg = dict( 153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1) 154 | val_cfg = dict(type='ValLoop') 155 | test_cfg = dict(type='TestLoop') 156 | # learning policy 157 | param_scheduler = [ 158 | dict( 159 | # use quadratic formula to warm up 1 epochs 160 | # and lr is updated by iteration 161 | type='mmdet.QuadraticWarmupLR', 162 | by_epoch=True, 163 | begin=0, 164 | end=1, 165 | convert_to_iter_based=True), 166 | dict( 167 | # use cosine lr from 1 to epoch #(total_epochs - num_last_epochs) 168 | type='mmdet.CosineAnnealingLR', 169 | eta_min=lr * 0.05, 170 | begin=1, 171 | T_max=total_epochs - num_last_epochs, 172 | end=total_epochs - num_last_epochs, 173 | by_epoch=True, 174 | convert_to_iter_based=True), 175 | dict( 176 | # use fixed lr during last 10 epochs 177 | type='mmdet.ConstantLR', 178 | by_epoch=True, 179 | factor=1, 180 | begin=total_epochs - num_last_epochs, 181 | end=total_epochs, 182 | ) 183 | ] 184 | 185 | custom_hooks = [ 186 | dict( 187 | type='mmtrack.YOLOXModeSwitchHook', 188 | num_last_epochs=num_last_epochs, 189 | priority=48), 190 | dict(type='mmdet.SyncNormHook', priority=48), 191 | dict( 192 | type='mmdet.EMAHook', 193 | ema_type='mmdet.ExpMomentumEMA', 194 | momentum=0.0001, 195 | update_buffers=True, 196 | priority=49) 197 | ] 198 | default_hooks = dict(checkpoint=dict(interval=1)) 199 | 200 | # evaluator 201 | val_evaluator = [ 202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True), 203 | ] 204 | test_evaluator = val_evaluator -------------------------------------------------------------------------------- /configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_night.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/yolox_x_8x8.py', 3 | '../../_base_/default_runtime.py' 4 | ] 5 | 6 | dataset_type = 'SHIFTDataset' 7 | data_root = 'data/shift/' 8 | attributes = dict(weather_coarse='clear', timeofday_coarse='night') 9 | 10 | img_scale = (800, 1440) 11 | batch_size = 2 12 | 13 | model = dict( 14 | 
type='AdaptiveDetector', 15 | data_preprocessor=dict( 16 | type='mmtrack.TrackDataPreprocessor', 17 | pad_size_divisor=32, 18 | batch_augments=[ 19 | dict( 20 | type='mmdet.BatchSyncRandomResize', 21 | random_size_range=(576, 1024), 22 | size_divisor=32, 23 | interval=10) 24 | ]), 25 | detector=dict( 26 | _scope_='mmdet', 27 | bbox_head=dict(num_classes=6), 28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)), 29 | init_cfg=dict( 30 | type='Pretrained', 31 | checkpoint= # noqa: E251 32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501 33 | )), 34 | adapter=None) 35 | 36 | train_pipeline = [ 37 | dict( 38 | type='mmdet.Mosaic', 39 | img_scale=img_scale, 40 | pad_val=114.0, 41 | bbox_clip_border=False), 42 | dict( 43 | type='mmdet.RandomAffine', 44 | scaling_ratio_range=(0.1, 2), 45 | border=(-img_scale[0] // 2, -img_scale[1] // 2), 46 | bbox_clip_border=False), 47 | dict( 48 | type='mmdet.MixUp', 49 | img_scale=img_scale, 50 | ratio_range=(0.8, 1.6), 51 | pad_val=114.0, 52 | bbox_clip_border=False), 53 | dict(type='mmdet.YOLOXHSVRandomAug'), 54 | dict(type='mmdet.RandomFlip', prob=0.5), 55 | dict( 56 | type='mmdet.Resize', 57 | scale=img_scale, 58 | keep_ratio=True, 59 | clip_object_border=False), 60 | dict( 61 | type='mmdet.Pad', 62 | size_divisor=32, 63 | pad_val=dict(img=(114.0, 114.0, 114.0))), 64 | dict( 65 | type='mmdet.FilterAnnotations', 66 | min_gt_bbox_wh=(1, 1), 67 | keep_empty=False), 68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 69 | ] 70 | test_pipeline = [ 71 | dict(type='LoadImageFromFile', 72 | backend_args=dict( 73 | backend='tar', 74 | tar_path=data_root + 'continuous/videos/1x/val/front/img_decompressed.tar', 75 | ) 76 | ), 77 | dict(type='mmtrack.LoadTrackAnnotations'), 78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True), 79 | dict( 80 | type='mmdet.Pad', 81 | size_divisor=32, 82 | pad_val=dict(img=(114.0, 114.0, 114.0))), 83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 84 | ] 85 | 86 | train_dataset = dict( 87 | # use MultiImageMixDataset wrapper to support mosaic and mixup 88 | type='mmdet.MultiImageMixDataset', 89 | dataset=dict( 90 | type='SHIFTDataset', 91 | load_as_video=False, 92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json', 93 | data_prefix=dict(img=''), 94 | ref_img_sampler=None, 95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')), 96 | pipeline=[ 97 | dict(type='LoadImageFromFile', 98 | backend_args=dict( 99 | backend='zip', 100 | zip_path=data_root + 'discrete/images/train/front/img.zip', 101 | ) 102 | ), 103 | dict(type='mmtrack.LoadTrackAnnotations'), 104 | ], 105 | filter_cfg=dict( 106 | attributes=attributes, 107 | filter_empty_gt=False, 108 | min_size=32 109 | )), 110 | pipeline=train_pipeline) 111 | train_dataloader = dict( 112 | batch_size=batch_size, 113 | num_workers=4, 114 | persistent_workers=True, 115 | sampler=dict(type='DefaultSampler', shuffle=True), 116 | dataset=train_dataset) 117 | 118 | val_dataset=dict( 119 | type='SHIFTDataset', 120 | load_as_video=True, 121 | ann_file=data_root + 'continuous/videos/1x/val/front/det_2d_cocoformat.json', 122 | data_prefix=dict(img=''), 123 | ref_img_sampler=None, 124 | test_mode=True, 125 | filter_cfg=dict(attributes=attributes), 126 | pipeline=test_pipeline, 127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle'))) 128 | val_dataloader = 
dict( 129 | batch_size=1, 130 | num_workers=4, 131 | persistent_workers=True, 132 | drop_last=False, 133 | sampler=dict(type='mmtrack.VideoSampler'), 134 | dataset=val_dataset) 135 | test_dataloader = val_dataloader 136 | # optimizer 137 | # default 8 gpu 138 | lr = 0.0005 / 8 * batch_size 139 | optim_wrapper = dict( 140 | type='OptimWrapper', 141 | optimizer=dict( 142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True), 143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) 144 | 145 | # some hyper parameters 146 | # training settings 147 | total_epochs = 12 148 | num_last_epochs = 2 149 | resume_from = None 150 | interval = 5 151 | 152 | train_cfg = dict( 153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1) 154 | val_cfg = dict(type='ValLoop') 155 | test_cfg = dict(type='TestLoop') 156 | # learning policy 157 | param_scheduler = [ 158 | dict( 159 | # use quadratic formula to warm up 1 epochs 160 | # and lr is updated by iteration 161 | type='mmdet.QuadraticWarmupLR', 162 | by_epoch=True, 163 | begin=0, 164 | end=1, 165 | convert_to_iter_based=True), 166 | dict( 167 | # use cosine lr from 1 to epoch #(total_epochs - num_last_epochs) 168 | type='mmdet.CosineAnnealingLR', 169 | eta_min=lr * 0.05, 170 | begin=1, 171 | T_max=total_epochs - num_last_epochs, 172 | end=total_epochs - num_last_epochs, 173 | by_epoch=True, 174 | convert_to_iter_based=True), 175 | dict( 176 | # use fixed lr during last 10 epochs 177 | type='mmdet.ConstantLR', 178 | by_epoch=True, 179 | factor=1, 180 | begin=total_epochs - num_last_epochs, 181 | end=total_epochs, 182 | ) 183 | ] 184 | 185 | custom_hooks = [ 186 | dict( 187 | type='mmtrack.YOLOXModeSwitchHook', 188 | num_last_epochs=num_last_epochs, 189 | priority=48), 190 | dict(type='mmdet.SyncNormHook', priority=48), 191 | dict( 192 | type='mmdet.EMAHook', 193 | ema_type='mmdet.ExpMomentumEMA', 194 | momentum=0.0001, 195 | update_buffers=True, 196 | priority=49) 197 | ] 198 | default_hooks = dict(checkpoint=dict(interval=1)) 199 | 200 | # evaluator 201 | val_evaluator = [ 202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True), 203 | ] 204 | test_evaluator = val_evaluator -------------------------------------------------------------------------------- /configs/source/yolox/README.md: -------------------------------------------------------------------------------- 1 | # TODO -------------------------------------------------------------------------------- /configs/source/yolox/amp_yolox_x_8xb4-24e_shift_all.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./yolox_x_8xb4-24e_shift_all.py'] 2 | 3 | # fp16 settings 4 | optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic') 5 | test_cfg = dict(type='TestLoop', fp16=True) -------------------------------------------------------------------------------- /configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./yolox_x_8xb4-24e_shift_clear_daytime.py'] 2 | 3 | # fp16 settings 4 | optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic') 5 | test_cfg = dict(type='TestLoop', fp16=True) -------------------------------------------------------------------------------- /configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_night.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./yolox_x_8xb4-24e_shift_clear_night.py'] 2 | 3 | # fp16 settings 
4 | optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic') 5 | test_cfg = dict(type='TestLoop', fp16=True) -------------------------------------------------------------------------------- /configs/source/yolox/amp_yolox_x_8xb4-24e_shift_daytime.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./yolox_x_8xb4-24e_shift_daytime.py'] 2 | 3 | # fp16 settings 4 | optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic') 5 | test_cfg = dict(type='TestLoop', fp16=True) -------------------------------------------------------------------------------- /configs/source/yolox/amp_yolox_x_8xb4-24e_shift_night.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./yolox_x_8xb4-24e_shift_night.py'] 2 | 3 | # fp16 settings 4 | optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic') 5 | test_cfg = dict(type='TestLoop', fp16=True) -------------------------------------------------------------------------------- /configs/source/yolox/yolox_x_8xb4-24e_shift_all.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/yolox_x_8x8.py', 3 | '../../_base_/default_runtime.py' 4 | ] 5 | 6 | dataset_type = 'SHIFTDataset' 7 | data_root = 'data/shift/' 8 | attributes = None 9 | 10 | img_scale = (800, 1440) 11 | batch_size = 2 12 | 13 | model = dict( 14 | type='AdaptiveDetector', 15 | data_preprocessor=dict( 16 | type='mmtrack.TrackDataPreprocessor', 17 | pad_size_divisor=32, 18 | batch_augments=[ 19 | dict( 20 | type='mmdet.BatchSyncRandomResize', 21 | random_size_range=(576, 1024), 22 | size_divisor=32, 23 | interval=10) 24 | ]), 25 | detector=dict( 26 | _scope_='mmdet', 27 | bbox_head=dict(num_classes=6), 28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)), 29 | init_cfg=dict( 30 | type='Pretrained', 31 | checkpoint= # noqa: E251 32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501 33 | )), 34 | adapter=None) 35 | 36 | train_pipeline = [ 37 | dict( 38 | type='mmdet.Mosaic', 39 | img_scale=img_scale, 40 | pad_val=114.0, 41 | bbox_clip_border=False), 42 | dict( 43 | type='mmdet.RandomAffine', 44 | scaling_ratio_range=(0.1, 2), 45 | border=(-img_scale[0] // 2, -img_scale[1] // 2), 46 | bbox_clip_border=False), 47 | dict( 48 | type='mmdet.MixUp', 49 | img_scale=img_scale, 50 | ratio_range=(0.8, 1.6), 51 | pad_val=114.0, 52 | bbox_clip_border=False), 53 | dict(type='mmdet.YOLOXHSVRandomAug'), 54 | dict(type='mmdet.RandomFlip', prob=0.5), 55 | dict( 56 | type='mmdet.Resize', 57 | scale=img_scale, 58 | keep_ratio=True, 59 | clip_object_border=False), 60 | dict( 61 | type='mmdet.Pad', 62 | size_divisor=32, 63 | pad_val=dict(img=(114.0, 114.0, 114.0))), 64 | dict( 65 | type='mmdet.FilterAnnotations', 66 | min_gt_bbox_wh=(1, 1), 67 | keep_empty=False), 68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 69 | ] 70 | test_pipeline = [ 71 | dict(type='LoadImageFromFile', 72 | backend_args=dict( 73 | backend='zip', 74 | zip_path=data_root + 'discrete/images/val/front/img.zip', 75 | ) 76 | ), 77 | dict(type='mmtrack.LoadTrackAnnotations'), 78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True), 79 | dict( 80 | type='mmdet.Pad', 81 | size_divisor=32, 82 | pad_val=dict(img=(114.0, 114.0, 114.0))), 83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 84 | ] 85 | 86 | train_dataset = dict( 87 | 
# use MultiImageMixDataset wrapper to support mosaic and mixup 88 | type='mmdet.MultiImageMixDataset', 89 | dataset=dict( 90 | type='SHIFTDataset', 91 | load_as_video=False, 92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json', 93 | data_prefix=dict(img=''), 94 | ref_img_sampler=None, 95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')), 96 | pipeline=[ 97 | dict(type='LoadImageFromFile', 98 | backend_args=dict( 99 | backend='zip', 100 | zip_path=data_root + 'discrete/images/train/front/img.zip', 101 | ) 102 | ), 103 | dict(type='mmtrack.LoadTrackAnnotations'), 104 | ], 105 | filter_cfg=dict( 106 | attributes=attributes, 107 | filter_empty_gt=False, 108 | min_size=32 109 | )), 110 | pipeline=train_pipeline) 111 | train_dataloader = dict( 112 | batch_size=batch_size, 113 | num_workers=4, 114 | persistent_workers=True, 115 | sampler=dict(type='DefaultSampler', shuffle=True), 116 | dataset=train_dataset) 117 | 118 | val_dataset=dict( 119 | type='SHIFTDataset', 120 | load_as_video=False, 121 | ann_file=data_root + 'discrete/images/val/front/det_2d_cocoformat.json', 122 | data_prefix=dict(img=''), 123 | ref_img_sampler=None, 124 | test_mode=True, 125 | filter_cfg=dict(attributes=attributes), 126 | pipeline=test_pipeline, 127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle'))) 128 | val_dataloader = dict( 129 | batch_size=1, 130 | num_workers=4, 131 | persistent_workers=True, 132 | drop_last=False, 133 | sampler=dict(type='DefaultSampler', shuffle=False), 134 | dataset=val_dataset) 135 | test_dataloader = val_dataloader 136 | # optimizer 137 | # default 8 gpu 138 | lr = 0.0005 / 8 * batch_size 139 | optim_wrapper = dict( 140 | type='OptimWrapper', 141 | optimizer=dict( 142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True), 143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) 144 | 145 | # some hyper parameters 146 | # training settings 147 | total_epochs = 24 148 | num_last_epochs = 2 149 | resume_from = None 150 | interval = 2 151 | 152 | train_cfg = dict( 153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1) 154 | val_cfg = dict(type='ValLoop') 155 | test_cfg = dict(type='TestLoop') 156 | # learning policy 157 | param_scheduler = [ 158 | dict( 159 | # use quadratic formula to warm up 1 epochs 160 | # and lr is updated by iteration 161 | type='mmdet.QuadraticWarmupLR', 162 | by_epoch=True, 163 | begin=0, 164 | end=1, 165 | convert_to_iter_based=True), 166 | dict( 167 | # use cosine lr from 1 to epoch #(total_epochs - num_last_epochs) 168 | type='mmdet.CosineAnnealingLR', 169 | eta_min=lr * 0.05, 170 | begin=1, 171 | T_max=total_epochs - num_last_epochs, 172 | end=total_epochs - num_last_epochs, 173 | by_epoch=True, 174 | convert_to_iter_based=True), 175 | dict( 176 | # use fixed lr during last 10 epochs 177 | type='mmdet.ConstantLR', 178 | by_epoch=True, 179 | factor=1, 180 | begin=total_epochs - num_last_epochs, 181 | end=total_epochs, 182 | ) 183 | ] 184 | 185 | custom_hooks = [ 186 | dict( 187 | type='mmtrack.YOLOXModeSwitchHook', 188 | num_last_epochs=num_last_epochs, 189 | priority=48), 190 | dict(type='mmdet.SyncNormHook', priority=48), 191 | dict( 192 | type='mmdet.EMAHook', 193 | ema_type='mmdet.ExpMomentumEMA', 194 | momentum=0.0001, 195 | update_buffers=True, 196 | priority=49) 197 | ] 198 | default_hooks = dict(checkpoint=dict(interval=1)) 199 | 200 | # evaluator 201 | val_evaluator = [ 202 | dict(type='SHIFTVideoMetric', 
metric=['bbox'], classwise=True), 203 | ] 204 | test_evaluator = val_evaluator -------------------------------------------------------------------------------- /configs/source/yolox/yolox_x_8xb4-24e_shift_clear_daytime.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/yolox_x_8x8.py', 3 | '../../_base_/default_runtime.py' 4 | ] 5 | 6 | dataset_type = 'SHIFTDataset' 7 | data_root = 'data/shift/' 8 | attributes = dict(weather_coarse='clear', timeofday_coarse='daytime') 9 | 10 | img_scale = (800, 1440) 11 | batch_size = 2 12 | 13 | model = dict( 14 | type='AdaptiveDetector', 15 | data_preprocessor=dict( 16 | type='mmtrack.TrackDataPreprocessor', 17 | pad_size_divisor=32, 18 | batch_augments=[ 19 | dict( 20 | type='mmdet.BatchSyncRandomResize', 21 | random_size_range=(576, 1024), 22 | size_divisor=32, 23 | interval=10) 24 | ]), 25 | detector=dict( 26 | _scope_='mmdet', 27 | bbox_head=dict(num_classes=6), 28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)), 29 | init_cfg=dict( 30 | type='Pretrained', 31 | checkpoint= # noqa: E251 32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501 33 | )), 34 | adapter=None) 35 | 36 | train_pipeline = [ 37 | dict( 38 | type='mmdet.Mosaic', 39 | img_scale=img_scale, 40 | pad_val=114.0, 41 | bbox_clip_border=False), 42 | dict( 43 | type='mmdet.RandomAffine', 44 | scaling_ratio_range=(0.1, 2), 45 | border=(-img_scale[0] // 2, -img_scale[1] // 2), 46 | bbox_clip_border=False), 47 | dict( 48 | type='mmdet.MixUp', 49 | img_scale=img_scale, 50 | ratio_range=(0.8, 1.6), 51 | pad_val=114.0, 52 | bbox_clip_border=False), 53 | dict(type='mmdet.YOLOXHSVRandomAug'), 54 | dict(type='mmdet.RandomFlip', prob=0.5), 55 | dict( 56 | type='mmdet.Resize', 57 | scale=img_scale, 58 | keep_ratio=True, 59 | clip_object_border=False), 60 | dict( 61 | type='mmdet.Pad', 62 | size_divisor=32, 63 | pad_val=dict(img=(114.0, 114.0, 114.0))), 64 | dict( 65 | type='mmdet.FilterAnnotations', 66 | min_gt_bbox_wh=(1, 1), 67 | keep_empty=False), 68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 69 | ] 70 | test_pipeline = [ 71 | dict(type='LoadImageFromFile', 72 | backend_args=dict( 73 | backend='zip', 74 | zip_path=data_root + 'discrete/images/val/front/img.zip', 75 | ) 76 | ), 77 | dict(type='mmtrack.LoadTrackAnnotations'), 78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True), 79 | dict( 80 | type='mmdet.Pad', 81 | size_divisor=32, 82 | pad_val=dict(img=(114.0, 114.0, 114.0))), 83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 84 | ] 85 | 86 | train_dataset = dict( 87 | # use MultiImageMixDataset wrapper to support mosaic and mixup 88 | type='mmdet.MultiImageMixDataset', 89 | dataset=dict( 90 | type='SHIFTDataset', 91 | load_as_video=False, 92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json', 93 | data_prefix=dict(img=''), 94 | ref_img_sampler=None, 95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')), 96 | pipeline=[ 97 | dict(type='LoadImageFromFile', 98 | backend_args=dict( 99 | backend='zip', 100 | zip_path=data_root + 'discrete/images/train/front/img.zip', 101 | ) 102 | ), 103 | dict(type='mmtrack.LoadTrackAnnotations'), 104 | ], 105 | filter_cfg=dict( 106 | attributes=attributes, 107 | filter_empty_gt=False, 108 | min_size=32 109 | )), 110 | pipeline=train_pipeline) 111 | train_dataloader = dict( 
112 | batch_size=batch_size, 113 | num_workers=4, 114 | persistent_workers=True, 115 | sampler=dict(type='DefaultSampler', shuffle=True), 116 | dataset=train_dataset) 117 | 118 | val_dataset=dict( 119 | type='SHIFTDataset', 120 | load_as_video=False, 121 | ann_file=data_root + 'discrete/images/val/front/det_2d_cocoformat.json', 122 | data_prefix=dict(img=''), 123 | ref_img_sampler=None, 124 | test_mode=True, 125 | filter_cfg=dict(attributes=attributes), 126 | pipeline=test_pipeline, 127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle'))) 128 | val_dataloader = dict( 129 | batch_size=1, 130 | num_workers=4, 131 | persistent_workers=True, 132 | drop_last=False, 133 | sampler=dict(type='DefaultSampler', shuffle=False), 134 | dataset=val_dataset) 135 | test_dataloader = val_dataloader 136 | # optimizer 137 | # default 8 gpu 138 | lr = 0.0005 / 8 * batch_size 139 | optim_wrapper = dict( 140 | type='OptimWrapper', 141 | optimizer=dict( 142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True), 143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) 144 | 145 | # some hyper parameters 146 | # training settings 147 | total_epochs = 24 148 | num_last_epochs = 2 149 | resume_from = None 150 | interval = 2 151 | 152 | train_cfg = dict( 153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1) 154 | val_cfg = dict(type='ValLoop') 155 | test_cfg = dict(type='TestLoop') 156 | # learning policy 157 | param_scheduler = [ 158 | dict( 159 | # use quadratic formula to warm up 1 epochs 160 | # and lr is updated by iteration 161 | type='mmdet.QuadraticWarmupLR', 162 | by_epoch=True, 163 | begin=0, 164 | end=1, 165 | convert_to_iter_based=True), 166 | dict( 167 | # use cosine lr from 1 to epoch #(total_epochs - num_last_epochs) 168 | type='mmdet.CosineAnnealingLR', 169 | eta_min=lr * 0.05, 170 | begin=1, 171 | T_max=total_epochs - num_last_epochs, 172 | end=total_epochs - num_last_epochs, 173 | by_epoch=True, 174 | convert_to_iter_based=True), 175 | dict( 176 | # use fixed lr during last 10 epochs 177 | type='mmdet.ConstantLR', 178 | by_epoch=True, 179 | factor=1, 180 | begin=total_epochs - num_last_epochs, 181 | end=total_epochs, 182 | ) 183 | ] 184 | 185 | custom_hooks = [ 186 | dict( 187 | type='mmtrack.YOLOXModeSwitchHook', 188 | num_last_epochs=num_last_epochs, 189 | priority=48), 190 | dict(type='mmdet.SyncNormHook', priority=48), 191 | dict( 192 | type='mmdet.EMAHook', 193 | ema_type='mmdet.ExpMomentumEMA', 194 | momentum=0.0001, 195 | update_buffers=True, 196 | priority=49) 197 | ] 198 | default_hooks = dict(checkpoint=dict(interval=1)) 199 | 200 | # evaluator 201 | val_evaluator = [ 202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True), 203 | ] 204 | test_evaluator = val_evaluator -------------------------------------------------------------------------------- /configs/source/yolox/yolox_x_8xb4-24e_shift_clear_night.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/yolox_x_8x8.py', 3 | '../../_base_/default_runtime.py' 4 | ] 5 | 6 | dataset_type = 'SHIFTDataset' 7 | data_root = 'data/shift/' 8 | attributes = dict(weather_coarse='clear', timeofday_coarse='night') 9 | 10 | img_scale = (800, 1440) 11 | batch_size = 2 12 | 13 | model = dict( 14 | type='AdaptiveDetector', 15 | data_preprocessor=dict( 16 | type='mmtrack.TrackDataPreprocessor', 17 | pad_size_divisor=32, 18 | batch_augments=[ 19 | dict( 20 | type='mmdet.BatchSyncRandomResize', 
21 | random_size_range=(576, 1024), 22 | size_divisor=32, 23 | interval=10) 24 | ]), 25 | detector=dict( 26 | _scope_='mmdet', 27 | bbox_head=dict(num_classes=6), 28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)), 29 | init_cfg=dict( 30 | type='Pretrained', 31 | checkpoint= # noqa: E251 32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501 33 | )), 34 | adapter=None) 35 | 36 | train_pipeline = [ 37 | dict( 38 | type='mmdet.Mosaic', 39 | img_scale=img_scale, 40 | pad_val=114.0, 41 | bbox_clip_border=False), 42 | dict( 43 | type='mmdet.RandomAffine', 44 | scaling_ratio_range=(0.1, 2), 45 | border=(-img_scale[0] // 2, -img_scale[1] // 2), 46 | bbox_clip_border=False), 47 | dict( 48 | type='mmdet.MixUp', 49 | img_scale=img_scale, 50 | ratio_range=(0.8, 1.6), 51 | pad_val=114.0, 52 | bbox_clip_border=False), 53 | dict(type='mmdet.YOLOXHSVRandomAug'), 54 | dict(type='mmdet.RandomFlip', prob=0.5), 55 | dict( 56 | type='mmdet.Resize', 57 | scale=img_scale, 58 | keep_ratio=True, 59 | clip_object_border=False), 60 | dict( 61 | type='mmdet.Pad', 62 | size_divisor=32, 63 | pad_val=dict(img=(114.0, 114.0, 114.0))), 64 | dict( 65 | type='mmdet.FilterAnnotations', 66 | min_gt_bbox_wh=(1, 1), 67 | keep_empty=False), 68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 69 | ] 70 | test_pipeline = [ 71 | dict(type='LoadImageFromFile', 72 | backend_args=dict( 73 | backend='zip', 74 | zip_path=data_root + 'discrete/images/val/front/img.zip', 75 | ) 76 | ), 77 | dict(type='mmtrack.LoadTrackAnnotations'), 78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True), 79 | dict( 80 | type='mmdet.Pad', 81 | size_divisor=32, 82 | pad_val=dict(img=(114.0, 114.0, 114.0))), 83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 84 | ] 85 | 86 | train_dataset = dict( 87 | # use MultiImageMixDataset wrapper to support mosaic and mixup 88 | type='mmdet.MultiImageMixDataset', 89 | dataset=dict( 90 | type='SHIFTDataset', 91 | load_as_video=False, 92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json', 93 | data_prefix=dict(img=''), 94 | ref_img_sampler=None, 95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')), 96 | pipeline=[ 97 | dict(type='LoadImageFromFile', 98 | backend_args=dict( 99 | backend='zip', 100 | zip_path=data_root + 'discrete/images/train/front/img.zip', 101 | ) 102 | ), 103 | dict(type='mmtrack.LoadTrackAnnotations'), 104 | ], 105 | filter_cfg=dict( 106 | attributes=attributes, 107 | filter_empty_gt=False, 108 | min_size=32 109 | )), 110 | pipeline=train_pipeline) 111 | train_dataloader = dict( 112 | batch_size=batch_size, 113 | num_workers=4, 114 | persistent_workers=True, 115 | sampler=dict(type='DefaultSampler', shuffle=True), 116 | dataset=train_dataset) 117 | 118 | val_dataset=dict( 119 | type='SHIFTDataset', 120 | load_as_video=False, 121 | ann_file=data_root + 'discrete/images/val/front/det_2d_cocoformat.json', 122 | data_prefix=dict(img=''), 123 | ref_img_sampler=None, 124 | test_mode=True, 125 | filter_cfg=dict(attributes=attributes), 126 | pipeline=test_pipeline, 127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle'))) 128 | val_dataloader = dict( 129 | batch_size=1, 130 | num_workers=4, 131 | persistent_workers=True, 132 | drop_last=False, 133 | sampler=dict(type='DefaultSampler', shuffle=False), 134 | dataset=val_dataset) 135 | test_dataloader = 
val_dataloader 136 | # optimizer 137 | # the default lr assumes 8 GPUs 138 | lr = 0.0005 / 8 * batch_size 139 | optim_wrapper = dict( 140 | type='OptimWrapper', 141 | optimizer=dict( 142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True), 143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) 144 | 145 | # some hyperparameters 146 | # training settings 147 | total_epochs = 24 148 | num_last_epochs = 2 149 | resume_from = None 150 | interval = 2 151 | 152 | train_cfg = dict( 153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1) 154 | val_cfg = dict(type='ValLoop') 155 | test_cfg = dict(type='TestLoop') 156 | # learning policy 157 | param_scheduler = [ 158 | dict( 159 | # use the quadratic formula to warm up for 1 epoch, 160 | # and lr is updated by iteration 161 | type='mmdet.QuadraticWarmupLR', 162 | by_epoch=True, 163 | begin=0, 164 | end=1, 165 | convert_to_iter_based=True), 166 | dict( 167 | # use cosine lr from epoch 1 to epoch (total_epochs - num_last_epochs) 168 | type='mmdet.CosineAnnealingLR', 169 | eta_min=lr * 0.05, 170 | begin=1, 171 | T_max=total_epochs - num_last_epochs, 172 | end=total_epochs - num_last_epochs, 173 | by_epoch=True, 174 | convert_to_iter_based=True), 175 | dict( 176 | # use a fixed lr during the last num_last_epochs epochs 177 | type='mmdet.ConstantLR', 178 | by_epoch=True, 179 | factor=1, 180 | begin=total_epochs - num_last_epochs, 181 | end=total_epochs, 182 | ) 183 | ] 184 | 185 | custom_hooks = [ 186 | dict( 187 | type='mmtrack.YOLOXModeSwitchHook', 188 | num_last_epochs=num_last_epochs, 189 | priority=48), 190 | dict(type='mmdet.SyncNormHook', priority=48), 191 | dict( 192 | type='mmdet.EMAHook', 193 | ema_type='mmdet.ExpMomentumEMA', 194 | momentum=0.0001, 195 | update_buffers=True, 196 | priority=49) 197 | ] 198 | default_hooks = dict(checkpoint=dict(interval=1)) 199 | 200 | # evaluator 201 | val_evaluator = [ 202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True), 203 | ] 204 | test_evaluator = val_evaluator -------------------------------------------------------------------------------- /configs/source/yolox/yolox_x_8xb4-24e_shift_daytime.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/yolox_x_8x8.py', 3 | '../../_base_/default_runtime.py' 4 | ] 5 | 6 | dataset_type = 'SHIFTDataset' 7 | data_root = 'data/shift/' 8 | attributes = dict(timeofday_coarse='daytime') 9 | 10 | img_scale = (800, 1440) 11 | batch_size = 2 12 | 13 | model = dict( 14 | type='AdaptiveDetector', 15 | data_preprocessor=dict( 16 | type='mmtrack.TrackDataPreprocessor', 17 | pad_size_divisor=32, 18 | batch_augments=[ 19 | dict( 20 | type='mmdet.BatchSyncRandomResize', 21 | random_size_range=(576, 1024), 22 | size_divisor=32, 23 | interval=10) 24 | ]), 25 | detector=dict( 26 | _scope_='mmdet', 27 | bbox_head=dict(num_classes=6), 28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)), 29 | init_cfg=dict( 30 | type='Pretrained', 31 | checkpoint= # noqa: E251 32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501 33 | )), 34 | adapter=None) 35 | 36 | train_pipeline = [ 37 | dict( 38 | type='mmdet.Mosaic', 39 | img_scale=img_scale, 40 | pad_val=114.0, 41 | bbox_clip_border=False), 42 | dict( 43 | type='mmdet.RandomAffine', 44 | scaling_ratio_range=(0.1, 2), 45 | border=(-img_scale[0] // 2, -img_scale[1] // 2), 46 | bbox_clip_border=False), 47 | dict( 48 | type='mmdet.MixUp', 49 |
img_scale=img_scale, 50 | ratio_range=(0.8, 1.6), 51 | pad_val=114.0, 52 | bbox_clip_border=False), 53 | dict(type='mmdet.YOLOXHSVRandomAug'), 54 | dict(type='mmdet.RandomFlip', prob=0.5), 55 | dict( 56 | type='mmdet.Resize', 57 | scale=img_scale, 58 | keep_ratio=True, 59 | clip_object_border=False), 60 | dict( 61 | type='mmdet.Pad', 62 | size_divisor=32, 63 | pad_val=dict(img=(114.0, 114.0, 114.0))), 64 | dict( 65 | type='mmdet.FilterAnnotations', 66 | min_gt_bbox_wh=(1, 1), 67 | keep_empty=False), 68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 69 | ] 70 | test_pipeline = [ 71 | dict(type='LoadImageFromFile', 72 | backend_args=dict( 73 | backend='zip', 74 | zip_path=data_root + 'discrete/images/val/front/img.zip', 75 | ) 76 | ), 77 | dict(type='mmtrack.LoadTrackAnnotations'), 78 | dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True), 79 | dict( 80 | type='mmdet.Pad', 81 | size_divisor=32, 82 | pad_val=dict(img=(114.0, 114.0, 114.0))), 83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 84 | ] 85 | 86 | train_dataset = dict( 87 | # use MultiImageMixDataset wrapper to support mosaic and mixup 88 | type='mmdet.MultiImageMixDataset', 89 | dataset=dict( 90 | type='SHIFTDataset', 91 | load_as_video=False, 92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json', 93 | data_prefix=dict(img=''), 94 | ref_img_sampler=None, 95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')), 96 | pipeline=[ 97 | dict(type='LoadImageFromFile', 98 | backend_args=dict( 99 | backend='zip', 100 | zip_path=data_root + 'discrete/images/train/front/img.zip', 101 | ) 102 | ), 103 | dict(type='mmtrack.LoadTrackAnnotations'), 104 | ], 105 | filter_cfg=dict( 106 | attributes=attributes, 107 | filter_empty_gt=False, 108 | min_size=32 109 | )), 110 | pipeline=train_pipeline) 111 | train_dataloader = dict( 112 | batch_size=batch_size, 113 | num_workers=4, 114 | persistent_workers=True, 115 | sampler=dict(type='DefaultSampler', shuffle=True), 116 | dataset=train_dataset) 117 | 118 | val_dataset=dict( 119 | type='SHIFTDataset', 120 | load_as_video=False, 121 | ann_file=data_root + 'discrete/images/val/front/det_2d_cocoformat.json', 122 | data_prefix=dict(img=''), 123 | ref_img_sampler=None, 124 | test_mode=True, 125 | filter_cfg=dict(attributes=attributes), 126 | pipeline=test_pipeline, 127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle'))) 128 | val_dataloader = dict( 129 | batch_size=1, 130 | num_workers=4, 131 | persistent_workers=True, 132 | drop_last=False, 133 | sampler=dict(type='DefaultSampler', shuffle=False), 134 | dataset=val_dataset) 135 | test_dataloader = val_dataloader 136 | # optimizer 137 | # the default lr assumes 8 GPUs 138 | lr = 0.0005 / 8 * batch_size 139 | optim_wrapper = dict( 140 | type='OptimWrapper', 141 | optimizer=dict( 142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True), 143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) 144 | 145 | # some hyperparameters 146 | # training settings 147 | total_epochs = 24 148 | num_last_epochs = 2 149 | resume_from = None 150 | interval = 2 151 | 152 | train_cfg = dict( 153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1) 154 | val_cfg = dict(type='ValLoop') 155 | test_cfg = dict(type='TestLoop') 156 | # learning policy 157 | param_scheduler = [ 158 | dict( 159 | # use the quadratic formula to warm up for 1 epoch, 160 | # and lr is updated by iteration 161 | type='mmdet.QuadraticWarmupLR', 162 |
by_epoch=True, 163 | begin=0, 164 | end=1, 165 | convert_to_iter_based=True), 166 | dict( 167 | # use cosine lr from epoch 1 to epoch (total_epochs - num_last_epochs) 168 | type='mmdet.CosineAnnealingLR', 169 | eta_min=lr * 0.05, 170 | begin=1, 171 | T_max=total_epochs - num_last_epochs, 172 | end=total_epochs - num_last_epochs, 173 | by_epoch=True, 174 | convert_to_iter_based=True), 175 | dict( 176 | # use a fixed lr during the last num_last_epochs epochs 177 | type='mmdet.ConstantLR', 178 | by_epoch=True, 179 | factor=1, 180 | begin=total_epochs - num_last_epochs, 181 | end=total_epochs, 182 | ) 183 | ] 184 | 185 | custom_hooks = [ 186 | dict( 187 | type='mmtrack.YOLOXModeSwitchHook', 188 | num_last_epochs=num_last_epochs, 189 | priority=48), 190 | dict(type='mmdet.SyncNormHook', priority=48), 191 | dict( 192 | type='mmdet.EMAHook', 193 | ema_type='mmdet.ExpMomentumEMA', 194 | momentum=0.0001, 195 | update_buffers=True, 196 | priority=49) 197 | ] 198 | default_hooks = dict(checkpoint=dict(interval=1)) 199 | 200 | # evaluator 201 | val_evaluator = [ 202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True), 203 | ] 204 | test_evaluator = val_evaluator -------------------------------------------------------------------------------- /configs/source/yolox/yolox_x_8xb4-24e_shift_night.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/yolox_x_8x8.py', 3 | '../../_base_/default_runtime.py' 4 | ] 5 | 6 | dataset_type = 'SHIFTDataset' 7 | data_root = 'data/shift/' 8 | attributes = dict(timeofday_coarse='night') 9 | 10 | img_scale = (800, 1440) 11 | batch_size = 2 12 | 13 | model = dict( 14 | type='AdaptiveDetector', 15 | data_preprocessor=dict( 16 | type='mmtrack.TrackDataPreprocessor', 17 | pad_size_divisor=32, 18 | batch_augments=[ 19 | dict( 20 | type='mmdet.BatchSyncRandomResize', 21 | random_size_range=(576, 1024), 22 | size_divisor=32, 23 | interval=10) 24 | ]), 25 | detector=dict( 26 | _scope_='mmdet', 27 | bbox_head=dict(num_classes=6), 28 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)), 29 | init_cfg=dict( 30 | type='Pretrained', 31 | checkpoint= # noqa: E251 32 | 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501 33 | )), 34 | adapter=None) 35 | 36 | train_pipeline = [ 37 | dict( 38 | type='mmdet.Mosaic', 39 | img_scale=img_scale, 40 | pad_val=114.0, 41 | bbox_clip_border=False), 42 | dict( 43 | type='mmdet.RandomAffine', 44 | scaling_ratio_range=(0.1, 2), 45 | border=(-img_scale[0] // 2, -img_scale[1] // 2), 46 | bbox_clip_border=False), 47 | dict( 48 | type='mmdet.MixUp', 49 | img_scale=img_scale, 50 | ratio_range=(0.8, 1.6), 51 | pad_val=114.0, 52 | bbox_clip_border=False), 53 | dict(type='mmdet.YOLOXHSVRandomAug'), 54 | dict(type='mmdet.RandomFlip', prob=0.5), 55 | dict( 56 | type='mmdet.Resize', 57 | scale=img_scale, 58 | keep_ratio=True, 59 | clip_object_border=False), 60 | dict( 61 | type='mmdet.Pad', 62 | size_divisor=32, 63 | pad_val=dict(img=(114.0, 114.0, 114.0))), 64 | dict( 65 | type='mmdet.FilterAnnotations', 66 | min_gt_bbox_wh=(1, 1), 67 | keep_empty=False), 68 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 69 | ] 70 | test_pipeline = [ 71 | dict(type='LoadImageFromFile', 72 | backend_args=dict( 73 | backend='zip', 74 | zip_path=data_root + 'discrete/images/val/front/img.zip', 75 | ) 76 | ), 77 | dict(type='mmtrack.LoadTrackAnnotations'), 78 | dict(type='mmdet.Resize',
scale=img_scale, keep_ratio=True), 79 | dict( 80 | type='mmdet.Pad', 81 | size_divisor=32, 82 | pad_val=dict(img=(114.0, 114.0, 114.0))), 83 | dict(type='mmtrack.PackTrackInputs', pack_single_img=True) 84 | ] 85 | 86 | train_dataset = dict( 87 | # use MultiImageMixDataset wrapper to support mosaic and mixup 88 | type='mmdet.MultiImageMixDataset', 89 | dataset=dict( 90 | type='SHIFTDataset', 91 | load_as_video=False, 92 | ann_file=data_root + 'discrete/images/train/front/det_2d_cocoformat.json', 93 | data_prefix=dict(img=''), 94 | ref_img_sampler=None, 95 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle')), 96 | pipeline=[ 97 | dict(type='LoadImageFromFile', 98 | backend_args=dict( 99 | backend='zip', 100 | zip_path=data_root + 'discrete/images/train/front/img.zip', 101 | ) 102 | ), 103 | dict(type='mmtrack.LoadTrackAnnotations'), 104 | ], 105 | filter_cfg=dict( 106 | attributes=attributes, 107 | filter_empty_gt=False, 108 | min_size=32 109 | )), 110 | pipeline=train_pipeline) 111 | train_dataloader = dict( 112 | batch_size=batch_size, 113 | num_workers=4, 114 | persistent_workers=True, 115 | sampler=dict(type='DefaultSampler', shuffle=True), 116 | dataset=train_dataset) 117 | 118 | val_dataset=dict( 119 | type='SHIFTDataset', 120 | load_as_video=False, 121 | ann_file=data_root + 'discrete/images/val/front/det_2d_cocoformat.json', 122 | data_prefix=dict(img=''), 123 | ref_img_sampler=None, 124 | test_mode=True, 125 | filter_cfg=dict(attributes=attributes), 126 | pipeline=test_pipeline, 127 | metainfo=dict(classes=('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle'))) 128 | val_dataloader = dict( 129 | batch_size=1, 130 | num_workers=4, 131 | persistent_workers=True, 132 | drop_last=False, 133 | sampler=dict(type='DefaultSampler', shuffle=False), 134 | dataset=val_dataset) 135 | test_dataloader = val_dataloader 136 | # optimizer 137 | # the default lr assumes 8 GPUs 138 | lr = 0.0005 / 8 * batch_size 139 | optim_wrapper = dict( 140 | type='OptimWrapper', 141 | optimizer=dict( 142 | type='SGD', lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True), 143 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) 144 | 145 | # some hyperparameters 146 | # training settings 147 | total_epochs = 24 148 | num_last_epochs = 2 149 | resume_from = None 150 | interval = 2 151 | 152 | train_cfg = dict( 153 | type='EpochBasedTrainLoop', max_epochs=total_epochs, val_interval=1) 154 | val_cfg = dict(type='ValLoop') 155 | test_cfg = dict(type='TestLoop') 156 | # learning policy 157 | param_scheduler = [ 158 | dict( 159 | # use the quadratic formula to warm up for 1 epoch, 160 | # and lr is updated by iteration 161 | type='mmdet.QuadraticWarmupLR', 162 | by_epoch=True, 163 | begin=0, 164 | end=1, 165 | convert_to_iter_based=True), 166 | dict( 167 | # use cosine lr from epoch 1 to epoch (total_epochs - num_last_epochs) 168 | type='mmdet.CosineAnnealingLR', 169 | eta_min=lr * 0.05, 170 | begin=1, 171 | T_max=total_epochs - num_last_epochs, 172 | end=total_epochs - num_last_epochs, 173 | by_epoch=True, 174 | convert_to_iter_based=True), 175 | dict( 176 | # use a fixed lr during the last num_last_epochs epochs 177 | type='mmdet.ConstantLR', 178 | by_epoch=True, 179 | factor=1, 180 | begin=total_epochs - num_last_epochs, 181 | end=total_epochs, 182 | ) 183 | ] 184 | 185 | custom_hooks = [ 186 | dict( 187 | type='mmtrack.YOLOXModeSwitchHook', 188 | num_last_epochs=num_last_epochs, 189 | priority=48), 190 | dict(type='mmdet.SyncNormHook', priority=48), 191 | dict( 192 | type='mmdet.EMAHook', 193 |
ema_type='mmdet.ExpMomentumEMA', 194 | momentum=0.0001, 195 | update_buffers=True, 196 | priority=49) 197 | ] 198 | default_hooks = dict(checkpoint=dict(interval=1)) 199 | 200 | # evaluator 201 | val_evaluator = [ 202 | dict(type='SHIFTVideoMetric', metric=['bbox'], classwise=True), 203 | ] 204 | test_evaluator = val_evaluator -------------------------------------------------------------------------------- /docs/challenge.md: -------------------------------------------------------------------------------- 1 | # Workshop on Visual Continual Learning @ ICCV2023 2 | 3 | # Challenge on Continual Test-time Adaptation for Object Detection 4 | ## Goal 5 | We introduce the [1st Challenge on Continual Test-time Adaptation for Object Detection](https://wvcl.vis.xyz/challenges). 6 | 7 | The goal of this challenge is to train an object detector on the SHIFT clear-daytime subset (the source domain) and to adapt it to the set of SHIFT sequences with continuous domain shift starting from clear-daytime conditions. 8 | 9 | ## Rules 10 | - Using additional data is **not** allowed; 11 | - Any detector architecture can be used; 12 | - The model should be adapted on the fly to each target sequence, and reset to its original state at the end of every sequence. 13 | 14 | You can find a reference implementation of an [AdaptiveDetector](shift_tta/models/detectors/adaptive_detector.py) class that wraps any object detector together with an adapter, a [BaseAdapter](shift_tta/models/adapters/base_adapter.py) base class, and a reference implementation of a [mean-teacher adapter](shift_tta/models/adapters/mean_teacher_yolox_adapter.py) based on YOLOX. 15 | 16 | ## Prize 17 | We will award the top three teams of each challenge with a certificate and a prize of 1000, 500, and 300 USD, respectively. The winners of each challenge will be invited to give a presentation at the workshop. Teams will be selected based on the performance of their methods on the test set. 18 | 19 | We will also award one team from each challenge with an Innovation Award. The Innovation Award is given to the team that proposes the most innovative method and/or insightful analysis. The winner will receive a certificate and an additional prize of 300 USD. 20 | 21 | **Please note** that this challenge is part of the track **Challenge B - Continual Test-time Adaptation**, together with the challenge on "Continuous Test-time Adaptation for Semantic Segmentation". Since the challenge on "Continuous Test-time Adaptation for Object Detection" constitutes half of track B, the prizes are half of the amounts mentioned above. 22 | 23 | ## Instructions 24 | 25 | ### Train a model on the source domain 26 | First, train an object detection model on the source domain. You may choose any object detector architecture. 27 | 28 | You can find a reference training script at [scripts/source/train_yolox_shift_clear_daytime.sh](scripts/source/train_yolox_shift_clear_daytime.sh) to train a YOLOX model on the SHIFT clear-daytime discrete set. 29 | 30 | We use the discrete set of SHIFT to train the object detector. 31 | 32 | You can also download a YOLOX checkpoint pre-trained with the above-mentioned script from this [link](https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth). 33 | 34 | ### Test the source model on the target domain 35 | Then, validate the source model on the validation set of the continuous target domain.
In particular, we validate on the videos presenting continuous domain shift starting from the clear-daytime conditions. The validation set should be used for validating your method under continuous domain shift and for hyperparameter search. 36 | 37 | You can find a reference validation script at [scripts/continuous/no_adap_yolox/val_yolox_shift_from_clear_daytime.sh](scripts/continuous/no_adap_yolox/val_yolox_shift_from_clear_daytime.sh). 38 | 39 | 40 | ### Continuously adapt a model to the validation target domain 41 | You can now validate your test-time adaptation baseline on the validation videos presenting continuous domain shift starting from the clear-daytime conditions. 42 | 43 | We implemented a baseline adapter based on a detection consistency loss and a mean-teacher formulation. You can find an implementation of the adapter at [mean_teacher_yolox_adapter](shift_tta/models/adapters/mean_teacher_yolox_adapter.py), and the corresponding config file at [configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py](configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py). 44 | 45 | You can run the adaptation script on the validation set using [scripts/continuous/mean_teacher_adapter_yolox/val_yolox_shift_from_clear_daytime.sh](scripts/continuous/mean_teacher_adapter_yolox/val_yolox_shift_from_clear_daytime.sh). 46 | 47 | ### Continuously adapt a model to the test target domain 48 | Finally, collect your results on the test set and submit them to our evaluation [benchmark](https://evalai.vis.xyz/web/challenges/challenge-page/6/overview). 49 | 50 | You can now test your test-time adaptation baseline on the test videos presenting continuous domain shift starting from the clear-daytime conditions. 51 | 52 | You can use the same mean-teacher baseline adapter and config file described above. 53 | 54 | You can run the adaptation script on the test set using [scripts/continuous/mean_teacher_adapter_yolox/test_yolox_shift_from_clear_daytime.sh](scripts/continuous/mean_teacher_adapter_yolox/test_yolox_shift_from_clear_daytime.sh). 55 | 56 | 57 | ## Submission 58 | 59 | ### Submit your results 60 | Running the above-mentioned scripts with the following `CFG_OPTIONS` stores results in the [Scalabel](https://www.scalabel.ai/) format in `${WORK_DIR}/results`: 61 | 62 | ```bash 63 | declare -a CFG_OPTIONS=( 64 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results" 65 | ) 66 | ``` 67 | 68 | Identify the file ending with `.scalabel.json` and submit it to our [evaluation benchmark](https://evalai.vis.xyz/web/challenges/challenge-page/6/overview) to participate in the challenge. 69 | 70 | ### Submit a technical report 71 | 72 | We require participants to submit a short report providing details on their solution to [vcl.iccvworkshop.2023@gmail.com](mailto:vcl.iccvworkshop.2023@gmail.com). 73 | 74 | Remember that we will also award one team from each challenge with an Innovation Award.
The Innovation Award is given to the team that proposes the most innovative method and/or insightful analysis. The winner will receive a certificate and an additional prize of 300 USD. 75 | 76 | Optionally, participants may submit their code or open a pull request after the challenge deadline if they want their adapter included in this repository. 77 | -------------------------------------------------------------------------------- /docs/dataset_prepare.md: -------------------------------------------------------------------------------- 1 | ## SHIFT Dataset Preparation 2 | 3 | This page provides the instructions for the [SHIFT](https://www.vis.xyz/shift/) dataset preparation. 4 | 5 | ### 1. Downloading the Dataset 6 | 7 | Please download the SHIFT dataset from the [official website](https://www.vis.xyz/shift/download/) to your `$DATADIR`. It is recommended to symlink the dataset root to `$SHIFT_DETECTION_TTA/data`. This avoids storing large files in your project directory, a requirement of several high-performance computing systems. 8 | 9 | Other directories that we recommend symlinking are `checkpoints/`, `data/`, and `work_dir/`. 10 | 11 | Symlink your data directory to the `$SHIFT_DETECTION_TTA` base directory using: 12 | 13 | ```shell 14 | ln -s $DATADIR/ $SHIFT_DETECTION_TTA/ 15 | ``` 16 | 17 | Then, use the official [download.py](https://github.com/SysCV/shift-dev/blob/main/download.py) script from the SHIFT devkit (a copy is included in this repository at `tools/shift/download.py`) to download the dataset. 18 | 19 | ```shell 20 | mkdir -p $DATADIR/shift 21 | 22 | # Download the discrete shift set for training source models 23 | python tools/shift/download.py \ 24 | --view "[front]" --group "[img, det_2d]" \ 25 | --split "[train, val]" --framerate "[images]" \ 26 | --shift "discrete" \ 27 | $DATADIR/shift 28 | 29 | # Download the continuous shift set for test-time adaptation 30 | python tools/shift/download.py \ 31 | --view "[front]" --group "[img, det_2d]" \ 32 | --split "[val, test]" --framerate "[videos]" \ 33 | --shift "continuous/1x" \ 34 | $DATADIR/shift 35 | ``` 36 | 37 | #### 1.1 Data Structure 38 | 39 | We report the recommended data structure below. If your folder structure differs from the following, you may need to change the corresponding paths in the config files. 40 | 41 | ``` 42 | shift-detection-tta 43 | ├── shift_tta 44 | ├── tools 45 | ├── configs 46 | ├── data 47 | │   ├── shift 48 | │   │   ├── discrete 49 | │   │   │   ├── images 50 | │   │   │   │   ├── train 51 | │   │   │   │   │   ├── front 52 | │   │   │   │   │   │   ├── img.zip 53 | │   │   │   │   │   │   ├── det_2d.json (the official annotation files) 54 | │   │   │   │   ├── val 55 | │   │   │   │   │   ├── front 56 | │   │   │   │   │   │   ├── img.zip 57 | │   │   │   │   │   │   ├── det_2d.json (the official annotation files) 58 | │   │   ├── continuous 59 | │   │   │   ├── videos 60 | │   │   │   │   ├── 1x 61 | │   │   │   │   │   ├── val 62 | │   │   │   │   │   │   ├── front 63 | │   │   │   │   │   │   │   ├── img.tar 64 | │   │   │   │   │   │   │   ├── det_2d.json (the official annotation files) 65 | │   │   │   │   │   ├── test 66 | │   │   │   │   │   │   ├── front 67 | │   │   │   │   │   │   │   ├── img.tar 68 | │   │   │   │   │   │   │   ├── det_2d.json (the official annotation files) 69 | ``` 70 | ### 2. Process the Dataset 71 | 72 | ### 2.1 Decompress the Dataset 73 | To ensure reproducible decompression of videos, we recommend using the [Docker image](https://github.com/SysCV/shift-dev/blob/main/Dockerfile) from the [official SHIFT devkit](https://github.com/SysCV/shift-dev). Refer to the Docker engine's [installation docs](https://docs.docker.com/engine/install/) if Docker is not yet set up on your machine.
74 | 75 | ```shell 76 | # clone the SHIFT devkit 77 | git clone git@github.com:SysCV/shift-dev.git 78 | cd shift-dev 79 | 80 | # build the Docker image 81 | docker build -t shift_dataset_decompress . 82 | 83 | # run the container (the mode is set to "tar") 84 | docker run -v <data_dir>:/data -e MODE=tar shift_dataset_decompress 85 | # Here, <data_dir> denotes the root path under which all tar files will be processed recursively. The mode and the number of jobs can be configured through the environment variables MODE and JOBS. 86 | ``` 87 | 88 | The folder structure will be as follows after you run these scripts: 89 | 90 | ``` 91 | shift-detection-tta 92 | ├── shift_tta 93 | ├── tools 94 | ├── configs 95 | ├── data 96 | │   ├── shift 97 | │   │   ├── discrete 98 | │   │   │   ├── images 99 | │   │   │   │   ├── train 100 | │   │   │   │   │   ├── front 101 | │   │   │   │   │   │   ├── img.zip 102 | │   │   │   │   │   │   ├── det_2d.json (the official annotation files) 103 | │   │   │   │   ├── val 104 | │   │   │   │   │   ├── front 105 | │   │   │   │   │   │   ├── img.zip 106 | │   │   │   │   │   │   ├── det_2d.json (the official annotation files) 107 | │   │   ├── continuous 108 | │   │   │   ├── videos 109 | │   │   │   │   ├── 1x 110 | │   │   │   │   │   ├── val 111 | │   │   │   │   │   │   ├── front 112 | │   │   │   │   │   │   │   ├── img.tar 113 | │   │   │   │   │   │   │   ├── img_decompressed.tar 114 | │   │   │   │   │   │   │   ├── det_2d.json (the official annotation files) 115 | │   │   │   │   │   ├── test 116 | │   │   │   │   │   │   ├── front 117 | │   │   │   │   │   │   │   ├── img.tar 118 | │   │   │   │   │   │   │   ├── img_decompressed.tar 119 | │   │   │   │   │   │   │   ├── det_2d.json (the official annotation files) 120 | ``` 121 | 122 | ### 2.2 Convert Annotations 123 | 124 | We use [CocoVID](https://github.com/open-mmlab/mmtracking/blob/master/mmtrack/datasets/parsers/coco_video_parser.py) to maintain all datasets in this codebase. 125 | 126 | You therefore need to convert the official annotations to this style. We provide conversion commands; their usage is as follows: 127 | 128 | ```shell 129 | # SHIFT discrete (images, detection-like) 130 | python -m scalabel.label.to_coco -m box_track -i $DATADIR/shift/discrete/images/$SET_NAME/front/det_2d.json -o $DATADIR/shift/discrete/images/$SET_NAME/front/det_2d_cocoformat.json 131 | 132 | # SHIFT continuous (videos, tracking-like) 133 | python -m scalabel.label.to_coco -m box_track -i $DATADIR/shift/continuous/videos/1x/$SET_NAME/front/det_2d.json -o $DATADIR/shift/continuous/videos/1x/$SET_NAME/front/det_2d_cocoformat.json 134 | ``` 135 | 136 | where `$SET_NAME` is one of `[train, val]` for the discrete set and one of `[val, test]` for the continuous set.
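If you want to sanity-check a converted file, a minimal sketch like the following can help. It assumes only that the converted JSON follows the CocoVID layout with top-level `videos`, `images`, and `annotations` keys; the path below is an example for the discrete train split.

```python
import json

# Example path: converted annotations of the discrete train split.
ann_file = 'data/shift/discrete/images/train/front/det_2d_cocoformat.json'

with open(ann_file, 'r') as fp:
    anns = json.load(fp)

# CocoVID groups images into videos: every image should reference an
# existing video id and carry a frame index.
assert {'videos', 'images', 'annotations'} <= set(anns)
video_ids = {video['id'] for video in anns['videos']}
assert all(img['video_id'] in video_ids for img in anns['images'])
print(len(anns['videos']), 'videos,', len(anns['images']), 'images,',
      len(anns['annotations']), 'annotations')
```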
137 | 138 | 139 | The folder structure will be as follows after you run these scripts: 140 | 141 | ``` 142 | shift-detection-tta 143 | ├── shift_tta 144 | ├── tools 145 | ├── configs 146 | ├── data 147 | │   ├── shift 148 | │   │   ├── discrete 149 | │   │   │   ├── images 150 | │   │   │   │   ├── train 151 | │   │   │   │   │   ├── front 152 | │   │   │   │   │   │   ├── img.zip 153 | │   │   │   │   │   │   ├── det_2d.json (the official annotation files) 154 | │   │   │   │   │   │   ├── det_2d_cocoformat.json (the converted annotation file) 155 | │   │   │   │   ├── val 156 | │   │   │   │   │   ├── front 157 | │   │   │   │   │   │   ├── img.zip 158 | │   │   │   │   │   │   ├── det_2d.json (the official annotation files) 159 | │   │   │   │   │   │   ├── det_2d_cocoformat.json (the converted annotation file) 160 | │   │   ├── continuous 161 | │   │   │   ├── videos 162 | │   │   │   │   ├── 1x 163 | │   │   │   │   │   ├── val 164 | │   │   │   │   │   │   ├── front 165 | │   │   │   │   │   │   │   ├── img.tar 166 | │   │   │   │   │   │   │   ├── img_decompressed.tar 167 | │   │   │   │   │   │   │   ├── det_2d.json (the official annotation files) 168 | │   │   │   │   │   │   │   ├── det_2d_cocoformat.json (the converted annotation file) 169 | │   │   │   │   │   ├── test 170 | │   │   │   │   │   │   ├── front 171 | │   │   │   │   │   │   │   ├── img.tar 172 | │   │   │   │   │   │   │   ├── img_decompressed.tar 173 | │   │   │   │   │   │   │   ├── det_2d.json (the official annotation files) 174 | │   │   │   │   │   │   │   ├── det_2d_cocoformat.json (the converted annotation file) 175 | ``` 176 | 177 | ### 2.3 Dataset Loading 178 | Some high-performance clusters do not handle folders with a large number of files well. For this reason, we implemented a [ZipBackend](shift_tta/fileio/backends/zip_backend.py) and a [TarBackend](shift_tta/fileio/backends/tar_backend.py) for loading data directly from `.zip` and `.tar` files. 179 | 180 | For usage, refer to the [`shift.py`](configs/_base_/datasets/shift.py) config file. 181 | 182 | # TODO: we might have to split the dataset file into two, depending on how we want to handle testing on the discrete val set and adapting to the continuous val set. -------------------------------------------------------------------------------- /docs/get_started.md: -------------------------------------------------------------------------------- 1 | ## Prerequisites 2 | 3 | - Linux | macOS | Windows 4 | - Python 3.6+ 5 | - PyTorch 1.6+ 6 | - CUDA 9.2+ (If you build PyTorch from source, CUDA 9.0 is also compatible) 7 | - GCC 5+ 8 | - [MMCV](https://mmcv.readthedocs.io/en/latest/get_started/installation.html) 9 | - [MMEngine](https://mmengine.readthedocs.io/en/latest/get_started/installation.html) 10 | - [MMDetection](https://mmdetection.readthedocs.io/en/latest/get_started.html#installation) 11 | - [MMTracking](https://mmtracking.readthedocs.io/en/latest/install.html#installation) 12 | 13 | The compatible MMTracking, MMEngine, MMCV, and MMDetection versions are as below. Please install the correct version to avoid installation issues. 14 | 15 | | MMTracking version | MMEngine version | MMCV version | MMDetection version | 16 | | :----------------: | :--------------: | :--------------------: | :---------------------: | 17 | | 1.x | mmengine>=0.1.0 | mmcv>=2.0.0rc1,\<2.0.0 | mmdet>=3.0.0rc0,\<3.0.0 | 18 | | 1.0.0rc1 | mmengine>=0.1.0 | mmcv>=2.0.0rc1,\<2.0.0 | mmdet>=3.0.0rc0,\<3.0.0 | 19 | 20 | ## Installation 21 | 22 | ### Detailed Instructions 23 | 24 | 1. Create a conda virtual environment and activate it. 25 | 26 | ```shell 27 | conda create -n shift-tta python=3.9 -y 28 | conda activate shift-tta 29 | ``` 30 | 31 | 2. Install PyTorch and torchvision following the [official instructions](https://pytorch.org/).
Here we use PyTorch 1.11.0 and CUDA 11.3. 32 | You may also switch to another version by specifying the version number. 33 | 34 | **Install with conda** 35 | 36 | ```shell 37 | conda install pytorch=1.11.0 torchvision cudatoolkit=11.3 -c pytorch 38 | ``` 39 | 40 | **Install with pip** 41 | 42 | ```shell 43 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 44 | ``` 45 | 46 | 3. Install MMEngine 47 | 48 | ```shell 49 | pip install mmengine 50 | ``` 51 | 52 | 4. Install MMCV. We recommend installing the pre-built package as below. 53 | 54 | ```shell 55 | pip install 'mmcv>=2.0.0rc1' -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html 56 | ``` 57 | 58 | mmcv is only compiled on PyTorch 1.x.0 because the compatibility usually holds between 1.x.0 and 1.x.1. If your PyTorch version is 1.x.1, you can install mmcv compiled with PyTorch 1.x.0 and it usually works well. 59 | 60 | ```shell 61 | # We can ignore the micro version of PyTorch 62 | pip install 'mmcv>=2.0.0rc1' -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html 63 | ``` 64 | 65 | See [here](https://mmcv.readthedocs.io/en/latest/get_started/installation.html) for the different versions of MMCV compatible with different PyTorch and CUDA versions. 66 | Optionally, you can choose to compile mmcv from source with the following commands 67 | 68 | ```shell 69 | git clone -b 2.x https://github.com/open-mmlab/mmcv.git 70 | cd mmcv 71 | MMCV_WITH_OPS=1 pip install -e . # package mmcv, which contains cuda ops, will be installed after this step 72 | # pip install -e . # package mmcv, which contains no cuda ops, will be installed after this step 73 | cd .. 74 | ``` 75 | 76 | **Important**: You need to run `pip uninstall mmcv-lite` first if you have mmcv-lite installed, because if both mmcv-lite and mmcv are installed, there will be a `ModuleNotFoundError`. 77 | 78 | 5. Install MMDetection 79 | 80 | ```shell 81 | pip install 'mmdet>=3.0.0rc0' 82 | ``` 83 | 84 | Optionally, you can also build MMDetection from source in case you want to modify the code: 85 | 86 | ```shell 87 | git clone -b 3.x https://github.com/open-mmlab/mmdetection.git 88 | cd mmdetection 89 | pip install -r requirements/build.txt 90 | pip install -v -e . # or "python setup.py develop" 91 | ``` 92 | 93 | 6. Install MMTracking 94 | 95 | ```shell 96 | pip install 'mmtrack>=1.0.0rc1' 97 | ``` 98 | 99 | Optionally, you can also build MMTracking from source in case you want to modify the code: 100 | 101 | ```shell 102 | git clone -b 1.x https://github.com/open-mmlab/mmtracking.git 103 | cd mmtracking 104 | pip install -r requirements/build.txt 105 | pip install -v -e . # or "python setup.py develop" 106 | ``` 107 | 108 | 7. Clone the shift-detection-tta repository. 109 | 110 | ```shell 111 | git clone git@github.com:SysCV/shift-detection-tta.git 112 | cd shift-detection-tta 113 | ``` 114 | 115 | 8. Install build requirements and then install shift-detection-tta. 116 | 117 | ```shell 118 | pip install -r requirements/build.txt 119 | pip install -v -e . # or "python setup.py develop" 120 | ``` 121 | 122 | Note: 123 | 124 | a. Following the above instructions, shift-detection-tta is installed in `dev` mode; 125 | any local modifications made to the code will take effect without the need to reinstall it. 126 | 127 | b. If you would like to use `opencv-python-headless` instead of `opencv-python`, 128 | you can install it before installing MMCV.
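As a quick sanity check of the installation, you can print the versions of the core packages and compare them with the compatibility table above; this is a minimal sketch, assuming the packages installed without errors:

```python
# Print the installed versions of the core OpenMMLab packages and compare
# them with the compatibility table at the top of this page.
import mmcv
import mmdet
import mmengine

print('mmengine:', mmengine.__version__)
print('mmcv:', mmcv.__version__)
print('mmdet:', mmdet.__version__)
```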
129 | 130 | ### A from-scratch setup script 131 | 132 | Assuming that you already have CUDA 11.3 installed, here is a full script for setting up shift-detection-tta with conda. 133 | 134 | ```shell 135 | conda create -n shift-tta python=3.9 -y 136 | conda activate shift-tta 137 | 138 | conda install pytorch=1.11.0 torchvision cudatoolkit=11.3 -c pytorch -y 139 | 140 | pip install -U openmim 141 | # install mmengine from main branch 142 | python -m pip install git+https://github.com/open-mmlab/mmengine.git@62f9504d701251db763f56658436fd23a586fe25 143 | # install mmcv 144 | mim install 'mmcv == 2.0.0rc4' 145 | # install mmdetection 146 | mim install 'mmdet == 3.0.0rc5' 147 | # install mmclassification from dev-1.x branch at specific commit 148 | python -m pip install git+https://github.com/open-mmlab/mmclassification.git@3ff80f5047fe3f3780a05d387f913dd02999611d 149 | # install mmtracking from dev-1.x branch at specific commit 150 | python -m pip install git+https://github.com/open-mmlab/mmtracking.git@9e4cb98a3cdac749242cd8decb3a172058d4fd6e 151 | 152 | # install trackeval for compatibility with mmtrack 153 | python -m pip install git+https://github.com/JonathonLuiten/TrackEval.git 154 | # install scalabel 155 | python -m pip install git+https://github.com/scalabel/scalabel.git 156 | 157 | # install shift-detection-tta 158 | git clone git@github.com:SysCV/shift-detection-tta.git 159 | cd shift-detection-tta 160 | python -m pip install --no-input -r requirements.txt 161 | pip install --no-input -v -e . 162 | ``` 163 | 164 | 165 | Alternatively (and recommended), clone the repository and directly run the install script [setup_venv.sh](tools/install/setup_venv.sh): 166 | 167 | ```shell 168 | git clone git@github.com:SysCV/shift-detection-tta.git 169 | cd shift-detection-tta 170 | ./tools/install/setup_venv.sh 171 | ``` 172 | 173 | -------------------------------------------------------------------------------- /docs/model_zoo.md: -------------------------------------------------------------------------------- 1 | # Model Zoo 2 | 3 | ## Continual test-time adaptation for object detection 4 | ### No Adaptation 5 | 6 | Please refer to [no_adap_yolox](configs/continuous/no_adap_yolox) for details. 7 | 8 | ### Mean-teacher 9 | 10 | Please refer to [mean_teacher_adapter_yolox](configs/continuous/mean_teacher_adapter_yolox) for details. 11 | -------------------------------------------------------------------------------- /docs/train_test.md: -------------------------------------------------------------------------------- 1 | # Learn to train and test 2 | 3 | ## Train 4 | 5 | This section will show how to train existing models on supported datasets. 6 | The following training environments are supported: 7 | 8 | - CPU 9 | - single GPU 10 | - single node multiple GPUs 11 | - multiple nodes 12 | 13 | You can also manage jobs with Slurm. 14 | 15 | Important: 16 | 17 | - You can change the evaluation interval during training by modifying the `train_cfg` as 18 | `train_cfg = dict(val_interval=10)`. That means evaluating the model every 10 epochs. 19 | - The default learning rate in all config files is for 8 GPUs. 20 | According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), 21 | you need to set the learning rate proportional to the total batch size if you use a different number of GPUs or images per GPU, 22 | e.g., `lr=0.01` for 8 GPUs * 1 img/gpu and `lr=0.04` for 16 GPUs * 2 imgs/gpu.
23 | - During training, log files and checkpoints will be saved to the working directory, 24 | which is specified by the CLI argument `--work-dir`. It uses `./work_dirs/CONFIG_NAME` as default. 25 | - If you want mixed-precision training, simply specify the CLI argument `--amp`. 26 | 27 | #### 1. Train on CPU 28 | 29 | By default, the model is put on a CUDA device. 30 | Only if there is no CUDA device will the model be put on the CPU. 31 | So if you want to train the model on the CPU, you need to `export CUDA_VISIBLE_DEVICES=-1` to disable GPU visibility first. 32 | More details in [MMEngine](https://github.com/open-mmlab/mmengine/blob/ca282aee9e402104b644494ca491f73d93a9544f/mmengine/runner/runner.py#L849-L850). 33 | 34 | ```shell script 35 | CUDA_VISIBLE_DEVICES=-1 python tools/train.py ${CONFIG_FILE} [optional arguments] 36 | ``` 37 | 38 | An example of training the YOLOX source model on CPU: 39 | 40 | ```shell script 41 | CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py 42 | ``` 43 | 44 | #### 2. Train on single GPU 45 | 46 | If you want to train the model on a single GPU, you can directly use `tools/train.py` as follows. 47 | 48 | ```shell script 49 | python tools/train.py ${CONFIG_FILE} [optional arguments] 50 | ``` 51 | 52 | You can use `export CUDA_VISIBLE_DEVICES=$GPU_ID` to select the GPU. 53 | 54 | An example of training the YOLOX source model on a single GPU: 55 | 56 | ```shell script 57 | CUDA_VISIBLE_DEVICES=2 python tools/train.py configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py 58 | ``` 59 | 60 | #### 3. Train on single node multiple GPUs 61 | 62 | We provide `tools/dist_train.sh` to launch training on multiple GPUs. 63 | The basic usage is as follows. 64 | 65 | ```shell script 66 | bash ./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments] 67 | ``` 68 | 69 | If you would like to launch multiple jobs on a single machine, 70 | e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs, 71 | you need to specify different ports (29500 by default) for each job to avoid communication conflicts. 72 | 73 | For example, you can set the port in commands as follows. 74 | 75 | ```shell script 76 | CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4 77 | CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4 78 | ``` 79 | 80 | An example of training the YOLOX source model on a single node with multiple GPUs: 81 | 82 | ```shell script 83 | bash ./tools/dist_train.sh configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py 8 84 | ``` 85 | 86 | #### 4. Train on multiple nodes 87 | 88 | If you launch training on multiple machines simply connected via Ethernet, you can run the following commands: 89 | 90 | On the first machine: 91 | 92 | ```shell script 93 | NNODES=2 NODE_RANK=0 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR bash tools/dist_train.sh $CONFIG $GPUS 94 | ``` 95 | 96 | On the second machine: 97 | 98 | ```shell script 99 | NNODES=2 NODE_RANK=1 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR bash tools/dist_train.sh $CONFIG $GPUS 100 | ``` 101 | 102 | It is usually slow if you do not have high-speed networking like InfiniBand. 103 | 104 | #### 5. Train with Slurm 105 | 106 | [Slurm](https://slurm.schedmd.com/) is a good job scheduling system for computing clusters. 107 | On a cluster managed by Slurm, you can use `slurm_train.sh` to spawn training jobs. 108 | It supports both single-node and multi-node training. 109 | 110 | The basic usage is as follows.
111 | 112 | ```shell script 113 | bash ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR} ${GPUS} 114 | ``` 115 | 116 | An example of training the YOLOX detector on SHIFT clear-daytime with Slurm: 117 | 118 | ```shell script 119 | PORT=29501 \ 120 | GPUS_PER_NODE=8 \ 121 | SRUN_ARGS="--quotatype=reserved" \ 122 | bash ./tools/slurm_train.sh \ 123 | mypartition \ 124 | YOLOX_shift_clear_daytime \ 125 | configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py \ 126 | ./work_dirs/YOLOX_shift_clear_daytime \ 127 | 8 128 | ``` 129 | 130 | ## Test 131 | 132 | This section will show how to test existing models on supported datasets. 133 | The following testing environments are supported: 134 | 135 | - CPU 136 | - single GPU 137 | - single node multiple GPUs 138 | - multiple nodes 139 | 140 | You can also manage jobs with Slurm. 141 | 142 | Important: 143 | 144 | - You can set the results saving path by modifying the key `outfile_prefix` in the evaluator. 145 | For example, `val_evaluator = dict(outfile_prefix='results/YOLOX_shift_from_clear_daytime')`. 146 | Otherwise, a temporary file will be created and removed after evaluation. 147 | - If you just want the formatted results without evaluation, you can set `format_only=True`. 148 | For example, `test_evaluator = dict(type='SHIFTVideoMetric', metric='bbox', outfile_prefix='results/YOLOX_shift_from_clear_daytime', format_only=True)` 149 | 150 | #### 1. Test on CPU 151 | 152 | By default, the model is put on a CUDA device. 153 | Only if there is no CUDA device will the model be put on the CPU. 154 | So if you want to test the model on the CPU, you need to `export CUDA_VISIBLE_DEVICES=-1` to disable GPU visibility first. 155 | More details in [MMEngine](https://github.com/open-mmlab/mmengine/blob/ca282aee9e402104b644494ca491f73d93a9544f/mmengine/runner/runner.py#L849-L850). 156 | 157 | ```shell script 158 | CUDA_VISIBLE_DEVICES=-1 python tools/test.py ${CONFIG_FILE} [optional arguments] 159 | ``` 160 | 161 | An example of testing the YOLOX clear-daytime model on CPU: 162 | 163 | ```shell script 164 | CUDA_VISIBLE_DEVICES=-1 python tools/test.py configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py --checkpoint checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 165 | ``` 166 | 167 | #### 2. Test on single GPU 168 | 169 | If you want to test the model on a single GPU, you can directly use `tools/test.py` as follows. 170 | 171 | ```shell script 172 | python tools/test.py ${CONFIG_FILE} [optional arguments] 173 | ``` 174 | 175 | You can use `export CUDA_VISIBLE_DEVICES=$GPU_ID` to select the GPU. 176 | 177 | An example of testing the YOLOX clear-daytime model on a single GPU: 178 | 179 | ```shell script 180 | CUDA_VISIBLE_DEVICES=2 python tools/test.py configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py --checkpoint checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 181 | ``` 182 | 183 | #### 3. Test on single node multiple GPUs 184 | 185 | We provide `tools/dist_test.sh` to launch testing on multiple GPUs. 186 | The basic usage is as follows. 187 | 188 | ```shell script 189 | bash ./tools/dist_test.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments] 190 | ``` 191 | 192 | An example of testing the YOLOX clear-daytime model on a single node with multiple GPUs: 193 | 194 | ```shell script 195 | bash ./tools/dist_test.sh configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py 8 --checkpoint checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 196 | ``` 197 | 198 | #### 4.
Test on multiple nodes 199 | 200 | You can test on multiple nodes, which is similar to "Train on multiple nodes". 201 | 202 | #### 5. Test with Slurm 203 | 204 | On a cluster managed by Slurm, you can use `slurm_test.sh` to spawn testing jobs. 205 | It supports both single-node and multi-node testing. 206 | 207 | The basic usage is as follows. 208 | 209 | ```shell script 210 | bash ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${GPUS} 211 | ``` 212 | 213 | An example of testing the YOLOX clear-daytime model with Slurm: 214 | 215 | ```shell script 216 | PORT=29501 \ 217 | GPUS_PER_NODE=8 \ 218 | SRUN_ARGS="--quotatype=reserved" \ 219 | bash ./tools/slurm_test.sh \ 220 | mypartition \ 221 | YOLOX_clear_daytime \ 222 | configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py \ 223 | 8 \ 224 | --checkpoint checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 225 | ``` -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/build.txt 2 | -r requirements/runtime.txt -------------------------------------------------------------------------------- /requirements/build.txt: -------------------------------------------------------------------------------- 1 | cython 2 | numpy -------------------------------------------------------------------------------- /requirements/mminstall.txt: -------------------------------------------------------------------------------- 1 | mmcls>=1.0.0rc0 2 | mmcv>=2.0.0rc1 3 | mmdet>=3.0.0rc0 4 | mmengine>=0.1.0 5 | mmtrack>=1.0.0rc1 -------------------------------------------------------------------------------- /requirements/runtime.txt: -------------------------------------------------------------------------------- 1 | attributee 2 | lap 3 | matplotlib 4 | mmcls>=1.0.0rc0 5 | motmetrics 6 | packaging 7 | pandas<=1.3.5 8 | pycocotools 9 | pydantic==1.9.1 10 | pytest 11 | scikit-learn 12 | scikit-image<=0.19.3 13 | scipy<=1.7.3 14 | seaborn 15 | tabulate 16 | terminaltables 17 | tqdm 18 | -------------------------------------------------------------------------------- /resources/shift-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/shift-detection-tta/a5c4fb0e906caa57c0b572d651f24f01ff6bdcdf/resources/shift-logo.png -------------------------------------------------------------------------------- /scripts/continuous/mean_teacher_adapter_yolox/slurm_test_yolox_shift_from_clear_daytime.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | TIME=12:00:00 # TIME=(24:00:00) 3 | PARTITION=gpu22 # PARTITION=(gpu16 | gpu20 | gpu22) 4 | GPUS_TYPE=a40 # GPUS_TYPE=(Quadro_RTX_8000 | a40 | a100) 5 | GPUS=4 6 | CPUS=16 7 | MEM_PER_CPU=22000 8 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 9 | CPUS_PER_TASK=${CPUS_PER_TASK:-16} 10 | SBATCH_ARGS=${SBATCH_ARGS:-""} 11 | GPUS_PER_NODE=${GPUS} 12 | CPUS_PER_TASK=${CPUS} 13 | 14 | ############### 15 | ##### Your args 16 | CONFIG=configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py 17 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 18 | WORK_DIR=work_dirs/continuous/mean_teacher_yolox/test/yolox_x_8xb4-24e_shift_from_clear_daytime 19 | 20 | declare -a CFG_OPTIONS=( 21 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results" 22 |
"test_dataloader.dataset.ann_file=data/shift/continuous/videos/1x/test/front/det_2d_cocoformat.json" 23 | "test_dataloader.dataset.pipeline.0.backend_args.tar_path=data/shift/continuous/videos/1x/test/front/img_decompressed.tar" 24 | ) 25 | ##### 26 | ############### 27 | 28 | if [ $GPUS -gt 1 ] 29 | then 30 | CMD=tools/dist_test.sh 31 | else 32 | CMD=tools/test.sh 33 | fi 34 | 35 | echo "Launching ${CMD} on ${GPUS} gpus." 36 | echo "Starting job ${JOB_NAME} from ${CONFIG} using --cfg-options ${CFG_OPTIONS[*]}" 37 | 38 | mkdir -p errors/ 39 | mkdir -p outputs/ 40 | 41 | ID=$(sbatch \ 42 | --parsable \ 43 | -t ${TIME} \ 44 | --job-name=${JOB_NAME} \ 45 | -p ${PARTITION} \ 46 | --gres=gpu:${GPUS_TYPE}:${GPUS_PER_NODE} \ 47 | -e errors/%j.log \ 48 | -o outputs/%j.log \ 49 | --mail-type=BEGIN,END,FAIL \ 50 | ${SBATCH_ARGS} \ 51 | ${CMD} \ 52 | ${CONFIG} \ 53 | ${GPUS} \ 54 | --checkpoint ${CKPT} \ 55 | --cfg-options ${CFG_OPTIONS[@]}) 56 | 57 | -------------------------------------------------------------------------------- /scripts/continuous/mean_teacher_adapter_yolox/test_yolox_shift_from_clear_daytime.sh: -------------------------------------------------------------------------------- 1 | CONFIG=configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py 2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 3 | WORK_DIR=work_dirs/continuous/mean_teacher_yolox/test/yolox_x_8xb4-24e_shift_from_clear_ndaytime 4 | 5 | declare -a CFG_OPTIONS=( 6 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results" 7 | "test_dataloader.dataset.ann_file=data/shift/continuous/videos/1x/test/front/det_2d_cocoformat.json" 8 | "test_dataloader.dataset.pipeline.0.backend_args.tar_path=data/shift/continuous/videos/1x/test/front/img_decompressed.tar" 9 | ) 10 | 11 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \ 12 | python tools/test.py \ 13 | ${CONFIG} \ 14 | --checkpoint ${CKPT} \ 15 | --cfg-options ${CFG_OPTIONS[@]} 16 | -------------------------------------------------------------------------------- /scripts/continuous/mean_teacher_adapter_yolox/val_yolox_shift_from_clear_daytime.sh: -------------------------------------------------------------------------------- 1 | CONFIG=configs/continuous/mean_teacher_adapter_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py 2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 3 | WORK_DIR=work_dirs/continuous/mean_teacher_yolox/val/yolox_x_8xb4-24e_shift_from_clear_ndaytime 4 | 5 | declare -a CFG_OPTIONS=( 6 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results" 7 | ) 8 | 9 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \ 10 | python tools/test.py \ 11 | ${CONFIG} \ 12 | --checkpoint ${CKPT} \ 13 | --cfg-options ${CFG_OPTIONS[@]} 14 | -------------------------------------------------------------------------------- /scripts/continuous/no_adap_yolox/test_yolox_shift_from_clear_daytime.sh: -------------------------------------------------------------------------------- 1 | CONFIG=configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py 2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 3 | WORK_DIR=work_dirs/continuous/no_adap_yolox/test/yolox_x_8xb4-24e_shift_from_clear_daytime 4 | 5 | declare -a CFG_OPTIONS=( 6 | 
"test_evaluator.0.outfile_prefix=${WORK_DIR}/results" 7 | "test_dataloader.dataset.ann_file=data/shift/continuous/videos/1x/test/front/det_2d_cocoformat.json" 8 | "test_dataloader.dataset.pipeline.0.backend_args.tar_path=data/shift/continuous/videos/1x/test/front/img_decompressed.tar" 9 | ) 10 | 11 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \ 12 | python tools/test.py \ 13 | ${CONFIG} \ 14 | --checkpoint ${CKPT} \ 15 | --work-dir ${WORK_DIR} \ 16 | --cfg-options ${CFG_OPTIONS[@]} -------------------------------------------------------------------------------- /scripts/continuous/no_adap_yolox/test_yolox_shift_from_clear_night.sh: -------------------------------------------------------------------------------- 1 | CONFIG=configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_night.py 2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 3 | WORK_DIR=work_dirs/continuous/no_adap_yolox/test/yolox_x_8xb4-24e_shift_from_clear_night 4 | 5 | declare -a CFG_OPTIONS=( 6 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results" 7 | "test_dataloader.dataset.ann_file=data/shift/continuous/videos/1x/test/front/det_2d_cocoformat.json" 8 | "test_dataloader.dataset.pipeline.0.backend_args.tar_path=data/shift/continuous/videos/1x/test/front/img_decompressed.tar" 9 | 10 | ) 11 | 12 | # python tools/test.py \ 13 | python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \ 14 | ${CONFIG} \ 15 | --checkpoint ${CKPT} \ 16 | --work-dir ${WORK_DIR} \ 17 | --cfg-options ${CFG_OPTIONS[@]} -------------------------------------------------------------------------------- /scripts/continuous/no_adap_yolox/val_yolox_shift_from_clear_daytime.sh: -------------------------------------------------------------------------------- 1 | CONFIG=configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_daytime.py 2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 3 | WORK_DIR=work_dirs/continuous/no_adap_yolox/val/yolox_x_8xb4-24e_shift_from_clear_daytime 4 | 5 | declare -a CFG_OPTIONS=( 6 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results" 7 | ) 8 | 9 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \ 10 | python tools/test.py \ 11 | ${CONFIG} \ 12 | --checkpoint ${CKPT} \ 13 | --work-dir ${WORK_DIR} \ 14 | --cfg-options ${CFG_OPTIONS[@]} -------------------------------------------------------------------------------- /scripts/continuous/no_adap_yolox/val_yolox_shift_from_clear_night.sh: -------------------------------------------------------------------------------- 1 | CONFIG=configs/continuous/no_adap_yolox/yolox_x_8xb4-24e_shift_from_clear_night.py 2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 3 | WORK_DIR=work_dirs/continuous/no_adap_yolox/val/yolox_x_8xb4-24e_shift_from_clear_night 4 | 5 | declare -a CFG_OPTIONS=( 6 | "test_evaluator.0.outfile_prefix=${WORK_DIR}/results" 7 | ) 8 | 9 | # python tools/test.py \ 10 | python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \ 11 | ${CONFIG} \ 12 | --checkpoint ${CKPT} \ 13 | --work-dir ${WORK_DIR} \ 14 | --cfg-options ${CFG_OPTIONS[@]} -------------------------------------------------------------------------------- /scripts/source/slurm_train_yolox_24e_shift_clear_daytime.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | TIME=12:00:00 # TIME=(24:00:00) 3 | PARTITION=gpu22 # PARTITION=(gpu16 | gpu20 | gpu22) 4 | GPUS_TYPE=a40 # GPUS_TYPE=(Quadro_RTX_8000 | a40 | a100) 5 | GPUS=4 6 | CPUS=16 7 | MEM_PER_CPU=22000 8 | GPUS_PER_NODE=${GPUS_PER_NODE:-${GPUS}} 9 | CPUS_PER_TASK=${CPUS_PER_TASK:-${CPUS}} 10 | SBATCH_ARGS=${SBATCH_ARGS:-""} 11 | JOB_NAME=${JOB_NAME:-yolox_x_8xb4-24e_shift_clear_daytime} 12 | 13 | 14 | ############### 15 | ##### Your args 16 | CONFIG=configs/source/yolox/yolox_x_8xb4-24e_shift_clear_daytime.py 17 | declare -a CFG_OPTIONS=( 18 | "data.workers_per_gpu=4" 19 | "data.samples_per_gpu=4" 20 | ) 21 | ##### 22 | ############### 23 | 24 | if [ $GPUS -gt 1 ] 25 | then 26 | CMD=tools/dist_train.sh 27 | else 28 | CMD=tools/train.sh 29 | fi 30 | 31 | echo "Launching ${CMD} on ${GPUS} gpus." 32 | echo "Starting job ${JOB_NAME} from ${CONFIG} using --cfg-options ${CFG_OPTIONS[*]}" 33 | 34 | mkdir -p errors/ 35 | mkdir -p outputs/ 36 | 37 | ID=$(sbatch \ 38 | --parsable \ 39 | -t ${TIME} \ 40 | --job-name=${JOB_NAME} \ 41 | -p ${PARTITION} \ 42 | --gres=gpu:${GPUS_TYPE}:${GPUS_PER_NODE} \ 43 | -e errors/%j.log \ 44 | -o outputs/%j.log \ 45 | --mail-type=BEGIN,END,FAIL \ 46 | ${SBATCH_ARGS} \ 47 | ${CMD} \ 48 | ${CONFIG} \ 49 | ${GPUS} \ 50 | --cfg-options ${CFG_OPTIONS[@]}) -------------------------------------------------------------------------------- /scripts/source/slurm_train_yolox_24e_shift_daytime.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | TIME=12:00:00 # TIME=(24:00:00) 3 | PARTITION=gpu22 # PARTITION=(gpu16 | gpu20 | gpu22) 4 | GPUS_TYPE=a40 # GPUS_TYPE=(Quadro_RTX_8000 | a40 | a100) 5 | GPUS=4 6 | CPUS=16 7 | MEM_PER_CPU=22000 8 | GPUS_PER_NODE=${GPUS_PER_NODE:-${GPUS}} 9 | CPUS_PER_TASK=${CPUS_PER_TASK:-${CPUS}} 10 | SBATCH_ARGS=${SBATCH_ARGS:-""} 11 | JOB_NAME=${JOB_NAME:-yolox_x_8xb4-24e_shift_daytime} 12 | 13 | 14 | ############### 15 | ##### Your args 16 | CONFIG=configs/source/yolox/yolox_x_8xb4-24e_shift_daytime.py 17 | declare -a CFG_OPTIONS=( 18 | "data.workers_per_gpu=4" 19 | "data.samples_per_gpu=4" 20 | ) 21 | ##### 22 | ############### 23 | 24 | if [ $GPUS -gt 1 ] 25 | then 26 | CMD=tools/dist_train.sh 27 | else 28 | CMD=tools/train.sh 29 | fi 30 | 31 | echo "Launching ${CMD} on ${GPUS} gpus."
32 | echo "Starting job ${JOB_NAME} from ${CONFIG} using --cfg-options ${CFG_OPTIONS[*]}" 33 | 34 | mkdir -p errors/ 35 | mkdir -p outputs/ 36 | 37 | ID=$(sbatch \ 38 | --parsable \ 39 | -t ${TIME} \ 40 | --job-name=${JOB_NAME} \ 41 | -p ${PARTITION} \ 42 | --gres=gpu:${GPUS_TYPE}:${GPUS_PER_NODE} \ 43 | -e errors/%j.log \ 44 | -o outputs/%j.log \ 45 | --mail-type=BEGIN,END,FAIL \ 46 | ${SBATCH_ARGS} \ 47 | ${CMD} \ 48 | ${CONFIG} \ 49 | ${GPUS} \ 50 | --cfg-options ${CFG_OPTIONS[@]}) -------------------------------------------------------------------------------- /scripts/source/test_yolox_shift_clear_daytime.sh: -------------------------------------------------------------------------------- 1 | CONFIG=configs/source/yolox/yolox_x_8xb4-24e_shift_clear_daytime.py 2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 3 | 4 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \ 5 | python tools/test.py \ 6 | ${CONFIG} \ 7 | --checkpoint ${CKPT} -------------------------------------------------------------------------------- /scripts/source/test_yolox_shift_clear_night.sh: -------------------------------------------------------------------------------- 1 | CONFIG=configs/source/yolox/yolox_x_8xb4-24e_shift_clear_night.py 2 | CKPT=https://dl.cv.ethz.ch/shift/challenge2023/test_time_adaptation/checkpoints/yolox_x_8xb4-24e_shift_clear_daytime.pth 3 | 4 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/test.py \ 5 | python tools/test.py \ 6 | ${CONFIG} \ 7 | --checkpoint ${CKPT} -------------------------------------------------------------------------------- /scripts/source/train_yolox_shift_clear_daytime.sh: -------------------------------------------------------------------------------- 1 | # CONFIG=configs/source/yolox/yolox_x_8xb4-24e_shift_clear_daytime.py 2 | CONFIG=configs/source/yolox/amp_yolox_x_8xb4-24e_shift_clear_daytime.py 3 | 4 | # python -m debugpy --listen $HOSTNAME:5678 --wait-for-client tools/train.py \ 5 | python tools/train.py \ 6 | ${CONFIG} -------------------------------------------------------------------------------- /scripts/tmp/edit_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | ann_file = 'data/shift/continuous/videos/1x/val/front/det_2d_cocoformat.json' 4 | dest_file = 'data/shift/continuous/videos/1x/val/front/det_2d_cocoformat_tmp.json' 5 | 6 | with open(ann_file, 'r') as fp: 7 | anns = json.load(fp) 8 | 9 | for im in anns['images']: 10 | im['file_name'] = im['file_name'].replace('_img_front.jpg', '.jpg') 11 | 12 | with open(dest_file, 'w') as fp: 13 | json.dump(anns, fp) 14 | 15 | -------------------------------------------------------------------------------- /scripts/tmp/edit_json_ids.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | ann_file = 'data/shift/continuous/videos/1x/test/front/det_2d_cocoformat.json' 4 | dest_file = 'data/shift/continuous/videos/1x/test/front/det_2d_cocoformat_tmp.json' 5 | 6 | with open(ann_file, 'r') as fp: 7 | anns = json.load(fp) 8 | 9 | print('Starting conversion') 10 | for im in anns['images']: 11 | im['file_name'] = im['file_name'].replace(str(im['frame_id']).zfill(8), str(im['frame_id']//10).zfill(8)) 12 | im['frame_id'] = im['frame_id'] // 10 13 | print('Conversion done') 14 | 15 | with open(dest_file, 'w') as fp: 16 | json.dump(anns, fp) 17 | 18 | 
-------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length = 79 3 | multi_line_output = 0 4 | extra_standard_library = setuptools 5 | known_first_party = shift_tta 6 | known_third_party = PIL,addict,cv2,lap,matplotlib,mmcls,mmcv,mmdet,mmtrack,motmetrics,numpy,packaging,pandas,pycocotools,pytest,pytorch_sphinx_theme,requests,scipy,script_utils,seaborn,tao,terminaltables,torch,tqdm 7 | no_lines_before = STDLIB,LOCALFOLDER 8 | default_section = THIRDPARTY 9 | 10 | [yapf] 11 | BASED_ON_STYLE = pep8 12 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true 13 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true 14 | 15 | [codespell] 16 | ignore-words-list = mot,gool 17 | skip = *.json -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | import sys 5 | import warnings 6 | from setuptools import find_packages, setup 7 | 8 | import torch 9 | from torch.utils.cpp_extension import (BuildExtension, CppExtension, 10 | CUDAExtension) 11 | 12 | 13 | def readme(): 14 | with open('README.md', encoding='utf-8') as f: 15 | content = f.read() 16 | return content 17 | 18 | 19 | version_file = 'shift_tta/version.py' 20 | 21 | 22 | def get_version(): 23 | with open(version_file, 'r') as f: 24 | exec(compile(f.read(), version_file, 'exec')) 25 | return locals()['__version__'] 26 | 27 | 28 | def make_cuda_ext(name, module, sources, sources_cuda=[]): 29 | 30 | define_macros = [] 31 | extra_compile_args = {'cxx': []} 32 | 33 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 34 | define_macros += [('WITH_CUDA', None)] 35 | extension = CUDAExtension 36 | extra_compile_args['nvcc'] = [ 37 | '-D__CUDA_NO_HALF_OPERATORS__', 38 | '-D__CUDA_NO_HALF_CONVERSIONS__', 39 | '-D__CUDA_NO_HALF2_OPERATORS__', 40 | ] 41 | sources += sources_cuda 42 | else: 43 | print(f'Compiling {name} without CUDA') 44 | extension = CppExtension 45 | 46 | return extension( 47 | name=f'{module}.{name}', 48 | sources=[os.path.join(*module.split('.'), p) for p in sources], 49 | define_macros=define_macros, 50 | extra_compile_args=extra_compile_args) 51 | 52 | 53 | def parse_requirements(fname='requirements.txt', with_version=True): 54 | """Parse the package dependencies listed in a requirements file but strip 55 | specific versioning information.
56 | 57 | Args: 58 | fname (str): path to requirements file 59 | with_version (bool, default=True): if True include version specs 60 | 61 | Returns: 62 | List[str]: list of requirements items 63 | 64 | CommandLine: 65 | python -c "import setup; print(setup.parse_requirements())" 66 | """ 67 | import re 68 | import sys 69 | from os.path import exists 70 | require_fpath = fname 71 | 72 | def parse_line(line): 73 | """Parse information from a line in a requirements text file.""" 74 | if line.startswith('-r '): 75 | # Allow specifying requirements in other files 76 | target = line.split(' ')[1] 77 | for info in parse_require_file(target): 78 | yield info 79 | else: 80 | info = {'line': line} 81 | if line.startswith('-e '): 82 | info['package'] = line.split('#egg=')[1] 83 | elif '@git+' in line: 84 | info['package'] = line 85 | else: 86 | # Remove versioning from the package 87 | pat = '(' + '|'.join(['>=', '==', '>']) + ')' 88 | parts = re.split(pat, line, maxsplit=1) 89 | parts = [p.strip() for p in parts] 90 | 91 | info['package'] = parts[0] 92 | if len(parts) > 1: 93 | op, rest = parts[1:] 94 | if ';' in rest: 95 | # Handle platform specific dependencies 96 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies 97 | version, platform_deps = map(str.strip, 98 | rest.split(';')) 99 | info['platform_deps'] = platform_deps 100 | else: 101 | version = rest # NOQA 102 | info['version'] = (op, version) 103 | yield info 104 | 105 | def parse_require_file(fpath): 106 | with open(fpath, 'r') as f: 107 | for line in f.readlines(): 108 | line = line.strip() 109 | if line and not line.startswith('#'): 110 | for info in parse_line(line): 111 | yield info 112 | 113 | def gen_packages_items(): 114 | if exists(require_fpath): 115 | for info in parse_require_file(require_fpath): 116 | parts = [info['package']] 117 | if with_version and 'version' in info: 118 | parts.extend(info['version']) 119 | if not sys.version.startswith('3.4'): 120 | # apparently platform_deps are broken in 3.4 121 | platform_deps = info.get('platform_deps') 122 | if platform_deps is not None: 123 | parts.append(';' + platform_deps) 124 | item = ''.join(parts) 125 | yield item 126 | 127 | packages = list(gen_packages_items()) 128 | return packages 129 | 130 | 131 | def add_mim_extension(): 132 | """Add extra files that are required to support MIM into the package. 133 | 134 | These files will be added by creating a symlink to the originals if the 135 | package is installed in `editable` mode (e.g. pip install -e .), or by 136 | copying from the originals otherwise.
137 | """ 138 | 139 | # parse installment mode 140 | if 'develop' in sys.argv: 141 | # installed by `pip install -e .` 142 | mode = 'symlink' 143 | elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv: 144 | # installed by `pip install .` 145 | # or create source distribution by `python setup.py sdist` 146 | mode = 'copy' 147 | else: 148 | return 149 | 150 | filenames = ['tools', 'configs', 'model-index.yml'] 151 | repo_path = osp.dirname(__file__) 152 | mim_path = osp.join(repo_path, 'shift_tta', '.mim') 153 | os.makedirs(mim_path, exist_ok=True) 154 | 155 | for filename in filenames: 156 | if osp.exists(filename): 157 | src_path = osp.join(repo_path, filename) 158 | tar_path = osp.join(mim_path, filename) 159 | 160 | if osp.isfile(tar_path) or osp.islink(tar_path): 161 | os.remove(tar_path) 162 | elif osp.isdir(tar_path): 163 | shutil.rmtree(tar_path) 164 | 165 | if mode == 'symlink': 166 | src_relpath = osp.relpath(src_path, osp.dirname(tar_path)) 167 | try: 168 | os.symlink(src_relpath, tar_path) 169 | except OSError: 170 | # Creating a symbolic link on windows may raise an 171 | # `OSError: [WinError 1314]` due to privilege. If 172 | # the error happens, the src file will be copied 173 | mode = 'copy' 174 | warnings.warn( 175 | f'Failed to create a symbolic link for {src_relpath}, ' 176 | f'and it will be copied to {tar_path}') 177 | else: 178 | continue 179 | 180 | if mode == 'copy': 181 | if osp.isfile(src_path): 182 | shutil.copyfile(src_path, tar_path) 183 | elif osp.isdir(src_path): 184 | shutil.copytree(src_path, tar_path) 185 | else: 186 | warnings.warn(f'Cannot copy file {src_path}.') 187 | else: 188 | raise ValueError(f'Invalid mode {mode}') 189 | 190 | 191 | if __name__ == '__main__': 192 | add_mim_extension() 193 | setup( 194 | name='shift-tta', 195 | version=get_version(), 196 | description='Test-time adaptation platform for object detection from the VIS group at ETH Zurich', 197 | long_description=readme(), 198 | long_description_content_type='text/markdown', 199 | author='Mattia Segu', 200 | author_email='mattia.segu@vision.ee.ethz.ch', 201 | keywords='computer vision, object detection, test-time adaptation', 202 | url='https://github.com/SysCV/shift-detection-tta', 203 | packages=find_packages(exclude=('configs', 'tools', 'demo')), 204 | include_package_data=True, 205 | classifiers=[ 206 | 'Development Status :: 4 - Beta', 207 | 'License :: OSI Approved :: MIT License', 208 | 'Operating System :: OS Independent', 209 | 'Programming Language :: Python :: 3', 210 | 'Programming Language :: Python :: 3.6', 211 | 'Programming Language :: Python :: 3.7', 212 | 'Programming Language :: Python :: 3.8', 213 | 'Programming Language :: Python :: 3.9', 214 | ], 215 | license='MIT License', 216 | install_requires=parse_requirements('requirements/runtime.txt'), 217 | extras_require={ 218 | 'all': parse_requirements('requirements.txt'), 219 | 'build': parse_requirements('requirements/build.txt'), 220 | 'mim': parse_requirements('requirements/mminstall.txt') 221 | }, 222 | ext_modules=[], 223 | cmdclass={'build_ext': BuildExtension}, 224 | zip_safe=False) -------------------------------------------------------------------------------- /shift_tta/.mim/configs: -------------------------------------------------------------------------------- 1 | ../../configs -------------------------------------------------------------------------------- /shift_tta/.mim/tools: -------------------------------------------------------------------------------- 1 | ../../tools 
-------------------------------------------------------------------------------- /shift_tta/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import mmcv 5 | import mmdet 6 | import mmtrack 7 | from packaging.version import parse 8 | 9 | from .version import __version__, version_info 10 | 11 | MMCV_MIN = '2.0.0rc1' 12 | MMCV_MAX = '2.0.0' 13 | 14 | MMDET_MIN = '3.0.0rc0' 15 | MMDET_MAX = '3.0.0' 16 | 17 | MMTRACK_MIN = '1.0.0rc1' 18 | MMTRACK_MAX = '1.0.0' 19 | 20 | 21 | def digit_version(version_str: str, length: int = 4): 22 | """Convert a version string into a tuple of integers. 23 | 24 | This method is usually used for comparing two versions. For pre-release 25 | versions: alpha < beta < rc. 26 | 27 | Args: 28 | version_str (str): The version string. 29 | length (int): The maximum number of version levels. Default: 4. 30 | 31 | Returns: 32 | tuple[int]: The version info in digits (integers). 33 | """ 34 | version = parse(version_str) 35 | assert version.release, f'failed to parse version {version_str}' 36 | release = list(version.release) 37 | release = release[:length] 38 | if len(release) < length: 39 | release = release + [0] * (length - len(release)) 40 | if version.is_prerelease: 41 | mapping = {'a': -3, 'b': -2, 'rc': -1} 42 | val = -4 43 | # version.pre can be None 44 | if version.pre: 45 | if version.pre[0] not in mapping: 46 | warnings.warn(f'unknown prerelease version {version.pre[0]}, ' 47 | 'version checking may go wrong') 48 | else: 49 | val = mapping[version.pre[0]] 50 | release.extend([val, version.pre[-1]]) 51 | else: 52 | release.extend([val, 0]) 53 | 54 | elif version.is_postrelease: 55 | release.extend([1, version.post]) 56 | else: 57 | release.extend([0, 0]) 58 | return tuple(release) 59 | 60 | 61 | mmcv_min_version = digit_version(MMCV_MIN) 62 | mmcv_max_version = digit_version(MMCV_MAX) 63 | mmcv_version = digit_version(mmcv.__version__) 64 | 65 | 66 | assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \ 67 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 68 | f'Please install mmcv>={MMCV_MIN}, <{MMCV_MAX}.' 69 | 70 | mmdet_min_version = digit_version(MMDET_MIN) 71 | mmdet_max_version = digit_version(MMDET_MAX) 72 | mmdet_version = digit_version(mmdet.__version__) 73 | 74 | 75 | assert (mmdet_min_version <= mmdet_version < mmdet_max_version), \ 76 | f'MMDet=={mmdet.__version__} is used but incompatible. ' \ 77 | f'Please install mmdet>={MMDET_MIN}, <{MMDET_MAX}.' 78 | 79 | mmtrack_min_version = digit_version(MMTRACK_MIN) 80 | mmtrack_max_version = digit_version(MMTRACK_MAX) 81 | mmtrack_version = digit_version(mmtrack.__version__) 82 | 83 | 84 | assert (mmtrack_min_version <= mmtrack_version < mmtrack_max_version), \ 85 | f'MMTrack=={mmtrack.__version__} is used but incompatible. ' \ 86 | f'Please install mmtrack>={MMTRACK_MIN}, <{MMTRACK_MAX}.' 
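# For illustration (hypothetical version strings), digit_version maps
# pre-releases below the corresponding final release, which is why the
# range checks above accept release candidates:
#   digit_version('1.0.0') == (1, 0, 0, 0, 0, 0)
#   digit_version('1.0.0rc1') == (1, 0, 0, 0, -1, 1)  # 'rc' maps to -1
#   digit_version('1.0.0rc1') < digit_version('1.0.0')  # True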
87 | 88 | __all__ = ['__version__', 'version_info', 'digit_version'] -------------------------------------------------------------------------------- /shift_tta/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .shift_dataset import SHIFTDataset 2 | from .utils import check_attributes 3 | 4 | __all__ = [ 5 | 'SHIFTDataset', 6 | 'check_attributes', 7 | ] -------------------------------------------------------------------------------- /shift_tta/datasets/shift_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from mmengine.fileio import FileClient 4 | from mmtrack.datasets import BaseVideoDataset 5 | from mmtrack.datasets.api_wrappers import CocoVID 6 | 7 | from shift_tta.registry import DATASETS 8 | 9 | from .utils import check_attributes 10 | 11 | 12 | @DATASETS.register_module() 13 | class SHIFTDataset(BaseVideoDataset): 14 | """Dataset class for SHIFT. 15 | Args: 16 | attributes (Optional[Dict[str, ...]]): a dictionary containing the 17 | allowed attributes. Dataset samples will be filtered based on 18 | the allowed attributes. If None, load all samples. Default: None. 19 | """ 20 | 21 | METAINFO = dict( 22 | classes = ('pedestrian', 'car', 'truck', 'bus', 'motorcycle', 'bicycle') 23 | ) 24 | 25 | def __init__(self, 26 | attributes=None, 27 | *args, 28 | **kwargs): 29 | self.attributes = attributes 30 | super().__init__(*args, **kwargs) 31 | 32 | def filter_by_attributes(self): 33 | """Filter annotations according to filter_cfg.attributes. 34 | 35 | Returns: 36 | list[int]: Filtered results. 37 | """ 38 | if self.load_as_video: 39 | valid_data_indices = self._filter_video_by_attributes() 40 | else: 41 | valid_data_indices = self._filter_image_by_attributes() 42 | 43 | return valid_data_indices 44 | 45 | def _filter_video_by_attributes(self): 46 | """Filter video annotations according to filter_cfg.attributes. 47 | 48 | Annotations are filtered based on the attributes of the first 49 | frame in the video. 50 | 51 | Returns: 52 | list[int]: Filtered results. 53 | """ 54 | file_client = FileClient.infer_client(uri=self.ann_file) 55 | with file_client.get_local_path(self.ann_file) as local_path: 56 | coco = CocoVID(local_path) 57 | 58 | valid_data_indices = [] 59 | data_id = 0 60 | vid_ids = coco.get_vid_ids() 61 | for vid_id in vid_ids: 62 | img_ids = coco.get_img_ids_from_vid(vid_id) 63 | if not len(img_ids) > 0: 64 | continue 65 | raw_img_info = coco.load_imgs([img_ids[0]])[0] 66 | if check_attributes( 67 | raw_img_info['attributes'], self.filter_cfg['attributes']): 68 | valid_data_indices.extend( 69 | list(range(data_id, data_id + len(img_ids)))) 70 | data_id += len(img_ids) 71 | 72 | set_valid_data_indices = set(self.valid_data_indices) 73 | valid_data_indices = [ 74 | id for id in valid_data_indices if id in set_valid_data_indices 75 | ] 76 | return valid_data_indices 77 | 78 | def _filter_image_by_attributes(self): 79 | """Filter image annotations according to filter_cfg.attributes. 80 | 81 | Returns: 82 | list[int]: Filtered results. 
83 | """ 84 | valid_data_indices = [] 85 | for i, data_info in enumerate(self.data_list): 86 | img_id = data_info['img_id'] 87 | if check_attributes( 88 | data_info['attributes'], self.filter_cfg['attributes'] 89 | ): 90 | valid_data_indices.append(i) 91 | 92 | set_valid_data_indices = set(self.valid_data_indices) 93 | valid_data_indices = [ 94 | id for id in valid_data_indices if id in set_valid_data_indices 95 | ] 96 | return valid_data_indices 97 | 98 | def filter_data(self) -> List[int]: 99 | """Filter annotations according to filter_cfg. 100 | 101 | Returns: 102 | list[int]: Filtered results. 103 | """ 104 | # filter data by attributes (useful for domain filtering) 105 | if self.filter_cfg is not None: 106 | self.valid_data_indices = self.filter_by_attributes() 107 | 108 | return super().filter_data() -------------------------------------------------------------------------------- /shift_tta/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .filters import check_attributes 2 | 3 | __all__ = ['check_attributes'] -------------------------------------------------------------------------------- /shift_tta/datasets/utils/filters.py: -------------------------------------------------------------------------------- 1 | """Dataset filtering utils.""" 2 | from typing import Any, Dict, List, Optional, Union 3 | 4 | 5 | def _check_attributes( 6 | attributes: Union[bool, float, str], 7 | allowed_attributes: Union[bool, float, str, List[float], List[str]], 8 | ) -> bool: 9 | """Check if attributes are allowed. 10 | Args: 11 | attributes: Attributes of current frame. 12 | allowed_attributes: Attributes allowed. 13 | Return: 14 | boolean, whether frame attributes are allowed. 15 | """ 16 | if isinstance(allowed_attributes, list): 17 | # assert frame_attributes not in allowed_attributes 18 | return attributes in allowed_attributes 19 | return attributes == allowed_attributes 20 | 21 | 22 | def check_attributes(attributes, allowed_attributes=None): 23 | """Check if a dictionary of attributes is allowed. 24 | Args: 25 | attributes (Dict[str, str]): attributes to check 26 | allowed_attributes (Dict[str, List[str]]): allowed attributes 27 | Return: 28 | boolean, whether frame attributes are allowed. 
29 | """ 30 | check = True 31 | if allowed_attributes: 32 | for key in allowed_attributes: 33 | allowed_attribute = allowed_attributes[key] 34 | check = check and _check_attributes( 35 | attributes[key], allowed_attribute 36 | ) 37 | if not check: 38 | return check 39 | return check -------------------------------------------------------------------------------- /shift_tta/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import * # noqa: F401,F403 -------------------------------------------------------------------------------- /shift_tta/evaluation/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .shift_video_metrics import SHIFTVideoMetric 2 | 3 | 4 | __all__ = ['SHIFTVideoMetric'] -------------------------------------------------------------------------------- /shift_tta/evaluation/metrics/shift_video_metrics.py: -------------------------------------------------------------------------------- 1 | 2 | import warnings 3 | from typing import Optional, Sequence 4 | 5 | from mmdet.datasets.api_wrappers import COCO 6 | from mmdet.evaluation import CocoMetric 7 | from mmdet.structures.mask import encode_mask_results 8 | from mmengine.dist import broadcast_object_list, is_main_process 9 | from mmengine.fileio import FileClient, dump, load 10 | from mmtrack.evaluation.metrics import CocoVideoMetric 11 | 12 | from scalabel.label.typing import Dataset, Frame, Label, Box3D, RLE 13 | from scalabel.label.transforms import bbox_to_box2d, coco_rle_to_rle, polygon_to_poly2ds 14 | 15 | from shift_tta.registry import METRICS 16 | 17 | 18 | @METRICS.register_module() 19 | class SHIFTVideoMetric(CocoVideoMetric): 20 | """SHIFT evaluation metric. 21 | 22 | Wraps CocoVideoMetric to implement dumping to coco json format so that 23 | it is compatible for conversion to the scalabel format. 24 | 25 | Args: 26 | to_scalabel (bool): Whether to dump results to a Scalabel-style json 27 | file. Defaults to True. 28 | """ 29 | def __init__(self, to_scalabel: bool = True, **kwargs) -> None: 30 | super().__init__(**kwargs) 31 | self.to_scalabel = to_scalabel 32 | 33 | def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: 34 | """Process one batch of data samples and predictions. The processed 35 | results should be stored in ``self.results``, which will be used to 36 | compute the metrics when all batches have been processed. 37 | 38 | Note that we only modify ``pred['pred_instances']`` in ``CocoMetric`` 39 | to ``pred['pred_det_instances']`` here. 40 | 41 | Args: 42 | data_batch (dict): A batch of data from the dataloader. 43 | data_samples (Sequence[dict]): A batch of data samples that 44 | contain annotations and predictions. 
45 | """ 46 | for data_sample in data_samples: 47 | result = dict() 48 | pred = data_sample['pred_det_instances'] 49 | result['img_id'] = data_sample['img_id'] 50 | result['img_name'] = data_sample['img_path'].split('/')[-1] 51 | result['video_name'] = data_sample['img_path'].split('/')[-2] 52 | result['bboxes'] = pred['bboxes'].cpu().numpy() 53 | result['scores'] = pred['scores'].cpu().numpy() 54 | result['labels'] = pred['labels'].cpu().numpy() 55 | # encode mask to RLE 56 | if 'masks' in pred: 57 | result['masks'] = encode_mask_results( 58 | pred['masks'].detach().cpu().numpy()) 59 | # some detectors use different scores for bbox and mask 60 | if 'mask_scores' in pred: 61 | result['mask_scores'] = pred['mask_scores'].cpu().numpy() 62 | 63 | # parse gt 64 | gt = dict() 65 | gt['width'] = data_sample['ori_shape'][1] 66 | gt['height'] = data_sample['ori_shape'][0] 67 | gt['img_id'] = data_sample['img_id'] 68 | if self._coco_api is None: 69 | assert 'instances' in data_sample, \ 70 | 'ground truth is required for evaluation when ' \ 71 | '`ann_file` is not provided' 72 | gt['anns'] = data_sample['instances'] 73 | # add converted result to the results list 74 | self.results.append((gt, result)) 75 | 76 | def results2scalabel(self, results: Sequence[dict], 77 | outfile_prefix: str) -> dict: 78 | """Dump the detection results to a Scalabel style json file. 79 | 80 | There are 3 types of results: proposals, bbox predictions, mask 81 | predictions, and they have different data types. This method will 82 | automatically recognize the type, and dump them to json files. 83 | 84 | Args: 85 | results (Sequence[dict]): Testing results of the 86 | dataset. 87 | outfile_prefix (str): The filename prefix of the json files. If the 88 | prefix is "somepath/xxx", the json files will be named 89 | "somepath/xxx.bbox.json", "somepath/xxx.segm.json", 90 | "somepath/xxx.proposal.json". 91 | 92 | Returns: 93 | dict: Possible keys are "bbox", "segm", "proposal", and 94 | values are corresponding filenames. 
95 | """ 96 | bbox_frames = [] 97 | segm_frames = [] if 'masks' in results[0] else None 98 | for idx, result in enumerate(results): 99 | image_id = result.get('img_id', idx) 100 | bboxes = result['bboxes'] 101 | scores = result['scores'] 102 | # bbox results 103 | labels = [] 104 | for i, label in enumerate(result['labels']): 105 | label = Label( 106 | id=i, 107 | # category=self.cat_ids[label], # check if this is text cat 108 | category=self.dataset_meta['classes'][label], 109 | box2d=bbox_to_box2d(self.xyxy2xywh(bboxes[i])), 110 | score=float(scores[i]), 111 | ) 112 | labels.append(label) 113 | 114 | bbox_frame = Frame( 115 | name=result["img_name"], 116 | videoName=result["video_name"], 117 | frameIndex=result["img_name"].split('.')[0].split('_')[0], 118 | labels=labels, 119 | ) 120 | bbox_frames.append(bbox_frame) 121 | 122 | if segm_frames is None: 123 | continue 124 | 125 | # segm results 126 | masks = result['masks'] 127 | mask_scores = result.get('mask_scores', scores) 128 | labels = [] 129 | for i, label in enumerate(result['labels']): 130 | if isinstance(masks[i]['counts'], bytes): 131 | masks[i]['counts'] = masks[i]['counts'].decode() 132 | 133 | label = Label( 134 | id=label["id"], 135 | category=self.dataset_meta['classes'][label], 136 | box2d=bbox_to_box2d(self.xyxy2xywh(bboxes[i])), 137 | score=float(mask_scores[i]) 138 | ) 139 | if isinstance(masks[i], list): 140 | label.poly2d = polygon_to_poly2ds(masks[i]) 141 | else: 142 | label.rle = coco_rle_to_rle(masks[i]) 143 | labels.append(label) 144 | segm_frame = Frame( 145 | name=result["img_name"], 146 | videoName=result["video_name"], 147 | frameIndex=result["img_name"].split('.')[0].split('_')[0], 148 | labels=labels, 149 | ) 150 | segm_frames.append(segm_frame) 151 | 152 | bbox_ds = Dataset(frames=bbox_frames, groups=None, config=None) 153 | 154 | result_files = dict() 155 | result_files['bbox'] = f'{outfile_prefix}.bbox.scalabel.json' 156 | result_files['proposal'] = f'{outfile_prefix}.bbox.scalabel.json' 157 | with open(result_files['bbox'], "w") as f: 158 | f.write(bbox_ds.json(exclude_unset=True)) 159 | 160 | if segm_frames is not None: 161 | segm_ds = Dataset(frames=segm_frames, groups=None, config=None) 162 | result_files['segm'] = f'{outfile_prefix}.segm.scalabel.json' 163 | with open(result_files['segm'], "w") as f: 164 | f.write(segm_ds.json(exclude_unset=True)) 165 | 166 | return result_files 167 | 168 | def results2json(self, results: Sequence[dict], 169 | outfile_prefix: str) -> dict: 170 | """Dump the detection results to a COCO / Scalabel style json file. 171 | 172 | There are 3 types of results: proposals, bbox predictions, mask 173 | predictions, and they have different data types. This method will 174 | automatically recognize the type, and dump them to json files. 175 | 176 | Args: 177 | results (Sequence[dict]): Testing results of the 178 | dataset. 179 | outfile_prefix (str): The filename prefix of the json files. If the 180 | prefix is "somepath/xxx", the json files will be named 181 | "somepath/xxx.bbox.json", "somepath/xxx.segm.json", 182 | "somepath/xxx.proposal.json". 183 | 184 | Returns: 185 | dict: Possible keys are "bbox", "segm", "proposal", and 186 | values are corresponding filenames. 
187 | """ 188 | 189 | if self.to_scalabel: 190 | _ = self.results2scalabel(results, outfile_prefix) 191 | 192 | return super().results2json(results, outfile_prefix) 193 | -------------------------------------------------------------------------------- /shift_tta/fileio/__init__.py: -------------------------------------------------------------------------------- 1 | from .backends import TarBackend, ZipBackend 2 | 3 | 4 | __all__ = [ 5 | 'TarBackend', 6 | 'ZipBackend' 7 | ] -------------------------------------------------------------------------------- /shift_tta/fileio/backends/__init__.py: -------------------------------------------------------------------------------- 1 | from .tar_backend import TarBackend 2 | from .zip_backend import ZipBackend 3 | 4 | __all__ = [ 5 | 'TarBackend', 6 | 'ZipBackend' 7 | ] -------------------------------------------------------------------------------- /shift_tta/fileio/backends/tar_backend.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Union 3 | from tarfile import TarFile 4 | 5 | import os 6 | 7 | from mmengine.fileio.backends import BaseStorageBackend, register_backend 8 | 9 | 10 | @register_backend('tar') 11 | class TarBackend(BaseStorageBackend): 12 | """Backend for loading data from .tar files. 13 | 14 | This backend works with filepaths pointing to valid .tar files. We assume 15 | that the given .tar file contains the whole dataset associated to this 16 | backend. 17 | """ 18 | 19 | def __init__(self, tar_path='', **kwargs): 20 | self.tar_path = str(tar_path) 21 | self._client = None 22 | 23 | def get(self, filepath: Union[str, Path]) -> bytes: 24 | """Get values according to the filepath. 25 | 26 | Args: 27 | filepath (str or Path): Here, filepath is the tar key. 28 | 29 | Returns: 30 | bytes: Expected bytes object. 31 | 32 | Examples: 33 | >>> backend = TarBackend('path/to/tar') 34 | >>> backend.get('key') 35 | b'hello world' 36 | """ 37 | if self._client is None: 38 | self._client = self._get_client() 39 | 40 | filepath = str(filepath) 41 | try: 42 | with self._client.extractfile(filepath) as data: 43 | data = data.read() 44 | except KeyError as e: 45 | raise ValueError( 46 | f"Value '{filepath}' not found in {self._client}!") from e 47 | return data 48 | 49 | def get_text(self, filepath, encoding=None): 50 | raise NotImplementedError 51 | 52 | def _get_client(self) -> TarFile: 53 | """Get Tar client. 54 | 55 | Returns: 56 | TarFile: the tar file. 57 | """ 58 | 59 | if not os.path.exists(self.tar_path): 60 | raise FileNotFoundError( 61 | f"Corresponding tar file not found:" f" {self.tar_path}") 62 | 63 | return TarFile(self.tar_path) 64 | 65 | def __del__(self): 66 | if self._client is not None: 67 | self._client.close() -------------------------------------------------------------------------------- /shift_tta/fileio/backends/zip_backend.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from abc import abstractmethod 4 | from zipfile import ZipFile 5 | 6 | from pathlib import Path 7 | from typing import Literal, Union 8 | 9 | from mmengine.fileio.backends import BaseStorageBackend, register_backend 10 | 11 | 12 | @register_backend('zip') 13 | class ZipBackend(BaseStorageBackend): 14 | """Backend for loading data from .zip files. 15 | 16 | This backend works with filepaths pointing to valid .zip files. 
We assume 17 | that the given .zip file contains the whole dataset associated to this 18 | backend. 19 | """ 20 | 21 | def __init__(self, zip_path: Union[str, Path] = '') -> None: 22 | self.zip_path = str(zip_path) 23 | self._client = None 24 | 25 | def get(self, filepath: Union[str, Path]) -> bytes: 26 | """Get values according to the filepath. 27 | 28 | Args: 29 | filepath (str or Path): Here, filepath is the zip key. 30 | 31 | Returns: 32 | bytes: Expected bytes object. 33 | 34 | Examples: 35 | >>> backend = ZipBackend('path/to/zip') 36 | >>> backend.get('key') 37 | b'hello world' 38 | """ 39 | 40 | if self._client is None: 41 | self._client = self._get_client('r') 42 | 43 | filepath = str(filepath) 44 | try: 45 | with self._client.open(filepath) as data: 46 | data = data.read() 47 | except KeyError as e: 48 | raise ValueError( 49 | f"Value '{filepath}' not found in {self._client}!") from e 50 | return bytes(data) 51 | 52 | def get_text(self, filepath, encoding=None): 53 | raise NotImplementedError 54 | 55 | def _get_client(self, mode: Literal["r", "w", "a", "x"]) -> ZipFile: 56 | """Get Zip client. 57 | 58 | Args: 59 | mode (str): Mode to open the file in. 60 | 61 | Returns: 62 | ZipFile: the zip file. 63 | """ 64 | assert len(mode) == 1, "Mode must be a single character for zip file." 65 | 66 | if not os.path.exists(self.zip_path): 67 | raise FileNotFoundError( 68 | f"Corresponding zip file not found:" f" {self.zip_path}") 69 | 70 | return ZipFile(self.zip_path, mode) 71 | 72 | def __del__(self): 73 | if self._client is not None: 74 | self._client.close() -------------------------------------------------------------------------------- /shift_tta/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .adapters import * # noqa: F401,F403 2 | from .detectors import * # noqa: F401,F403 3 | from .losses import * # noqa: F401,F403 4 | -------------------------------------------------------------------------------- /shift_tta/models/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_adapter import BaseAdapter 2 | from .mean_teacher_yolox_adapter import MeanTeacherYOLOXAdapter 3 | 4 | __all__ = [ 5 | 'BaseAdapter', 6 | 'MeanTeacherYOLOXAdapter' 7 | ] -------------------------------------------------------------------------------- /shift_tta/models/adapters/base_adapter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | from typing import List 4 | 5 | import torch 6 | from copy import deepcopy 7 | 8 | from mmengine.structures import InstanceData 9 | from mmtrack.structures import TrackDataSample 10 | 11 | 12 | class BaseAdapter(metaclass=ABCMeta): 13 | """Base adapter model. 14 | 15 | Args: 16 | episodic (bool, optional): If True, the model will be reset 17 | to its source state at the beginning of every new sequence. 18 | Defaults to True. 19 | """ 20 | 21 | def __init__(self, 22 | episodic: bool = True) -> None: 23 | super().__init__() 24 | self.episodic = episodic 25 | self.fp16_enabled = False 26 | 27 | self.source_model_state = None 28 | 29 | def _init_source_model_state(self, model) -> None: 30 | """Init self.source_model_state. 31 | 32 | Args: 33 | model (nn.Module): detection model.
34 | """ 35 | self.source_model_state = deepcopy(model.state_dict()) 36 | 37 | def _restore_source_model_state(self, model) -> None: 38 | """Init self.source_model_state. 39 | 40 | Args: 41 | model (nn.Module): detection model. 42 | """ 43 | 44 | if self.source_model_state is None: 45 | raise Exception("cannot reset without saved model state") 46 | model.load_state_dict(self.source_model_state, strict=True) 47 | 48 | def reset(self, model) -> None: 49 | """Reset the model state to self.source_model_state. 50 | 51 | Args: 52 | model (nn.Module): detection model. 53 | """ 54 | if self.source_model_state is None: 55 | self._init_source_model_state(model) 56 | else: 57 | self._restore_source_model_state(model) 58 | 59 | @property 60 | def with_episodic(self) -> bool: 61 | """Whether the model has to be reset at the end of every sequence.""" 62 | return True if self.episodic else False 63 | 64 | @abstractmethod 65 | def _adapt(self, *args, **kwargs): 66 | """Adapt the model.""" 67 | pass 68 | 69 | def adapt(self, model: torch.nn.Module, img: torch.Tensor, 70 | feats: List[torch.Tensor], data_sample: TrackDataSample, 71 | **kwargs) -> InstanceData: 72 | """Adapt the model. 73 | 74 | 75 | Args: 76 | model (nn.Module): detection model. 77 | img (Tensor): of shape (T, C, H, W) encoding input image. 78 | Typically these should be mean centered and std scaled. 79 | The T denotes the number of key images and usually is 1 in 80 | ByteTrack method. 81 | feats (list[Tensor]): Multi level feature maps of `img`. 82 | data_sample (:obj:`TrackDataSample`): The data sample. 83 | It includes information such as `pred_det_instances`. 84 | 85 | Returns: 86 | :obj:`InstanceData`: Detection results of the input images. 87 | Each InstanceData usually contains ``bboxes``, ``labels``, 88 | ``scores`` and ``instances_id``. 89 | """ 90 | metainfo = data_sample.metainfo 91 | bboxes = data_sample.pred_det_instances.bboxes 92 | labels = data_sample.pred_det_instances.labels 93 | scores = data_sample.pred_det_instances.scores 94 | 95 | frame_id = metainfo.get('frame_id', -1) 96 | if self.with_episodic and frame_id == 0: 97 | self.reset(model) 98 | 99 | # adapt model 100 | self._adapt() # TODO: implement your own adapt method here 101 | 102 | # update pred_det_instances 103 | pred_det_instances = InstanceData() 104 | pred_det_instances.bboxes = bboxes 105 | pred_det_instances.labels = labels 106 | pred_det_instances.scores = scores 107 | 108 | return pred_det_instances -------------------------------------------------------------------------------- /shift_tta/models/adapters/mean_teacher_yolox_adapter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import List, Optional, Tuple 3 | 4 | import torch 5 | 6 | from copy import deepcopy 7 | 8 | from mmengine.dataset import Compose 9 | from mmengine.optim import build_optim_wrapper 10 | from mmengine.structures import InstanceData 11 | from mmtrack.structures import TrackDataSample 12 | 13 | from shift_tta.registry import MODELS 14 | from .base_adapter import BaseAdapter 15 | 16 | 17 | @MODELS.register_module() 18 | class MeanTeacherYOLOXAdapter(BaseAdapter): 19 | """Mean-teacher YOLOX adapter model. 20 | 21 | Args: 22 | teacher (dict): Configuration of teacher. Defaults to None. 23 | optim_wrapper (dict): Configuration of optimizer wrapper. 24 | Defaults to None. 25 | loss (dict): Configuration of loss. Defaults to None. 
26 | pipeline (list[dict]): Configuration of image transforms. 27 | Defaults to None. 28 | """ 29 | 30 | def __init__(self, 31 | teacher: Optional[dict] = None, 32 | optim_wrapper: Optional[dict] = None, 33 | optim_steps: int = 0, 34 | loss: Optional[dict] = dict( 35 | type='YOLOXConsistencyLoss', 36 | weight=0.01, 37 | ), 38 | pipeline: Optional[List[dict]] = None, 39 | teacher_pipeline: Optional[List[dict]] = None, 40 | student_pipeline: Optional[List[dict]] = None, 41 | views: int = 1, 42 | **kwargs) -> None: 43 | super().__init__(**kwargs) 44 | 45 | self.teacher = None 46 | # keep the teacher config; the teacher is built lazily on first reset 47 | self.teacher_cfg = teacher 48 | 49 | # the optimizer wrapper is also built lazily on first reset 50 | self.optim_wrapper = None 51 | self.optim_wrapper_cfg = optim_wrapper 52 | self.optim_steps = optim_steps 53 | 54 | 55 | # build loss 56 | self.loss = MODELS.build(loss) 57 | 58 | # build image transforms 59 | self.pipeline = Compose(pipeline) 60 | self.teacher_pipeline = Compose(teacher_pipeline) 61 | self.student_pipeline = Compose(student_pipeline) 62 | self.views = views 63 | 64 | # TODO: implement param_scheduler for optimizer (e.g. lr decay) 65 | 66 | def _init_source_model_state(self, model) -> None: 67 | """Snapshot the source model state and build the teacher and optimizer. 68 | 69 | Args: 70 | model (nn.Module): detection model. 71 | """ 72 | super()._init_source_model_state(model) 73 | 74 | self.optim_wrapper = build_optim_wrapper(model, self.optim_wrapper_cfg) 75 | self.optim_wrapper_state = deepcopy(self.optim_wrapper.state_dict()) 76 | 77 | if self.teacher_cfg is not None: 78 | self.teacher_cfg['model'] = model 79 | self.teacher = MODELS.build(self.teacher_cfg) 80 | self.teacher_model_state = deepcopy(self.teacher.state_dict()) 81 | 82 | def _reset_optimizer(self) -> None: 83 | """Reset the optimizer state to the snapshot taken when the source 84 | model state was initialized. 85 | """ 86 | 87 | if self.optim_wrapper is not None: 88 | self.optim_wrapper.load_state_dict(self.optim_wrapper_state) 89 | 90 | def _restore_source_model_state(self, model) -> None: 91 | """Restore the model and teacher weights from the saved source state. 92 | 93 | Args: 94 | model (nn.Module): detection model.
95 | """ 96 | super()._restore_source_model_state(model) 97 | 98 | if self.teacher is not None: 99 | self.teacher.load_state_dict( 100 | self.teacher_model_state, strict=True) 101 | 102 | def _detect_forward(self, detector: torch.nn.Module, img: torch.Tensor, 103 | batch_data_samples: TrackDataSample, rescale: bool = True): 104 | """Detector forward pass.""" 105 | feats = detector.extract_feat(img) 106 | outs = detector.bbox_head.forward(feats) 107 | 108 | batch_img_metas = [ 109 | data_samples.metainfo for data_samples in batch_data_samples 110 | ] 111 | predictions = detector.bbox_head.predict_by_feat( 112 | *outs, batch_img_metas=batch_img_metas, rescale=rescale) 113 | det_results = detector.add_pred_to_datasample( 114 | batch_data_samples, predictions) 115 | 116 | return det_results, outs 117 | 118 | def _expand_view(self, outs: Tuple[torch.Tensor], views: int = 1): 119 | """Expand batch size of each element in outs to views.""" 120 | outs = tuple(list(o.repeat_interleave(views, dim=0) for o in out) 121 | for out in outs) 122 | return outs 123 | 124 | def _adapt(self, model: torch.nn.Module, 125 | teacher_img: torch.Tensor, 126 | student_imgs: torch.Tensor, 127 | teacher_data_samples: List[TrackDataSample], 128 | student_data_samples: List[TrackDataSample], 129 | *args, **kwargs) -> InstanceData: 130 | """Adapt the model.""" 131 | 132 | # teacher forward 133 | teacher_det_results, teacher_outs = self._detect_forward( 134 | self.teacher.module.detector, teacher_img, teacher_data_samples) 135 | teacher_outs = self._expand_view(teacher_outs, views=self.views) 136 | teacher_outs = dict( 137 | cls_score=teacher_outs[0], 138 | bbox_pred=teacher_outs[1], 139 | objectness=teacher_outs[2]) 140 | 141 | # student forward 142 | _, outs = self._detect_forward( 143 | model.detector, student_imgs, student_data_samples) 144 | outs = dict( 145 | cls_score=outs[0], 146 | bbox_pred=outs[1], 147 | objectness=outs[2]) 148 | 149 | # adapt 150 | loss = self.loss(outs, teacher_outs) 151 | loss.backward() 152 | self.optim_wrapper.step() 153 | self.optim_wrapper.zero_grad() 154 | 155 | return teacher_det_results 156 | 157 | def adapt(self, model: torch.nn.Module, img: torch.Tensor, 158 | feats: List[torch.Tensor], data_sample: TrackDataSample, 159 | **kwargs) -> InstanceData: 160 | """Adapt the model. 161 | 162 | Args: 163 | model (nn.Module): detection model. 164 | img (Tensor): of shape (T, C, H, W) encoding input image. 165 | Typically these should be mean centered and std scaled. 166 | The T denotes the number of key images and usually is 1 in 167 | ByteTrack method. 168 | feats (list[Tensor]): Multi level feature maps of `img`. 169 | data_sample (:obj:`TrackDataSample`): The data sample. 170 | It includes information such as `pred_det_instances`. 171 | 172 | Returns: 173 | :obj:`InstanceData`: Detection results of the input images. 174 | Each InstanceData usually contains ``bboxes``, ``labels``, 175 | ``scores`` and ``instances_id``. 
176 | """ 177 | metainfo = data_sample.metainfo 178 | frame_id = metainfo.get('frame_id', -1) 179 | if self.with_episodic and frame_id == 0: 180 | self.reset(model) 181 | 182 | # adapt model 183 | # TODO: apply multiple image transforms 184 | # data_sample = self.transforms(deepcopy(data_sample)) 185 | # TODO: create a batch 186 | # TODO: compute teacher prediction on clean target image 187 | # TODO: compute distill loss into augmented batch 188 | # (concat targets to batch size) 189 | 190 | # make teacher and student views 191 | results = dict(img_path=data_sample.img_path, 192 | instances=data_sample.instances, 193 | ) 194 | results = self.pipeline(results) 195 | teacher_results = self.teacher_pipeline(results) 196 | teacher_img = teacher_results['inputs']['img'].to(img) 197 | teacher_data_samples = [teacher_results['data_samples']] 198 | 199 | student_imgs = [] 200 | student_data_samples = [] 201 | for _ in range(self.views): 202 | student_results = self.student_pipeline(results) 203 | student_imgs.append(student_results['inputs']['img']) 204 | student_data_samples.append(student_results['data_samples']) 205 | student_imgs = torch.cat(student_imgs).to(img) 206 | 207 | with torch.enable_grad(): 208 | model.requires_grad_(True) 209 | model.train(True) 210 | for _ in range(self.optim_steps): 211 | outs = self._adapt( 212 | model, teacher_img, student_imgs, 213 | teacher_data_samples, student_data_samples) 214 | self.teacher.update_parameters(model) 215 | 216 | self._reset_optimizer() 217 | 218 | # update pred_det_instances 219 | pred_det_instances = outs[0].pred_instances.clone() 220 | 221 | return pred_det_instances -------------------------------------------------------------------------------- /shift_tta/models/detectors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAdaptiveDetector 2 | from .adaptive_detector import AdaptiveDetector 3 | 4 | __all__ = [ 5 | 'BaseAdaptiveDetector', 6 | 'AdaptiveDetector' 7 | ] -------------------------------------------------------------------------------- /shift_tta/models/detectors/adaptive_detector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Dict, Optional 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from mmtrack.utils import OptConfigType, OptMultiConfig, SampleList 8 | 9 | from shift_tta.registry import MODELS 10 | from .base import BaseAdaptiveDetector 11 | 12 | 13 | @MODELS.register_module() 14 | class AdaptiveDetector(BaseAdaptiveDetector): 15 | """AdaptiveDetector: baseline test-time adaptation method for object detection. 16 | 17 | Args: 18 | detector (dict): Configuration of detector. Defaults to None. 19 | adapter (dict): Configuration of adapter. Defaults to None. 20 | data_preprocessor (dict or ConfigDict, optional): The pre-process 21 | config of :class:`TrackDataPreprocessor`. it usually includes, 22 | ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. 23 | init_cfg (dict or list[dict]): Configuration of initialization. 24 | Defaults to None. 
25 | """ 26 | 27 | def __init__(self, 28 | detector: Optional[dict] = None, 29 | adapter: Optional[dict] = None, 30 | data_preprocessor: OptConfigType = None, 31 | init_cfg: OptMultiConfig = None): 32 | super().__init__(data_preprocessor, init_cfg) 33 | 34 | if detector is not None: 35 | self.detector = MODELS.build(detector) 36 | 37 | if adapter is not None: 38 | self.adapter = MODELS.build(adapter) 39 | 40 | def loss(self, inputs: Dict[str, Tensor], data_samples: SampleList, 41 | **kwargs) -> dict: 42 | """Calculate losses from a batch of inputs and data samples. 43 | 44 | Args: 45 | inputs (Dict[str, Tensor]): of shape (N, T, C, H, W) encoding 46 | input images. Typically these should be mean centered and std 47 | scaled. The N denotes batch size.The T denotes the number of 48 | key/reference frames. 49 | - img (Tensor) : The key images. 50 | - ref_img (Tensor): The reference images. 51 | data_samples (list[:obj:`TrackDataSample`]): The batch 52 | data samples. It usually includes information such 53 | as `gt_instance`. 54 | 55 | Returns: 56 | dict: A dictionary of loss components. 57 | """ 58 | # modify the inputs shape to fit mmdet 59 | img = inputs['img'] 60 | assert img.size(1) == 1 61 | # convert 'inputs' shape to (N, C, H, W) 62 | img = torch.squeeze(img, dim=1) 63 | return self.detector.loss(img, data_samples, **kwargs) 64 | 65 | def predict(self, inputs: Dict[str, Tensor], data_samples: SampleList, 66 | **kwargs) -> SampleList: 67 | """Predict results from a batch of inputs and data samples with post- 68 | processing. 69 | 70 | Args: 71 | inputs (Dict[str, Tensor]): of shape (N, T, C, H, W) encoding 72 | input images. Typically these should be mean centered and std 73 | scaled. The N denotes batch size.The T denotes the number of 74 | key/reference frames. 75 | - img (Tensor) : The key images. 76 | - ref_img (Tensor): The reference images. 77 | data_samples (list[:obj:`TrackDataSample`]): The batch 78 | data samples. It usually includes information such 79 | as `gt_instance`. 80 | 81 | Returns: 82 | SampleList: Tracking results of the input images. 83 | Each TrackDataSample usually contains ``pred_det_instances`` 84 | or ``pred_track_instances``. 85 | """ 86 | img = inputs['img'] 87 | assert img.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' 88 | assert img.size(0) == 1, \ 89 | 'AdaptiveDetector inference only support 1 batch size per gpu for now.' 90 | img = img[0] 91 | 92 | assert len(data_samples) == 1, \ 93 | 'AdaptiveDetector inference only support 1 batch size per gpu for now.' 94 | 95 | data_sample = data_samples[0] 96 | 97 | if self.with_adapter: 98 | adapted_det_instances = self.adapter.adapt( 99 | model=self, 100 | img=img, 101 | feats=None, 102 | data_sample=data_sample, 103 | **kwargs) 104 | data_sample.pred_det_instances = adapted_det_instances 105 | else: 106 | det_results = self.detector.predict(img, data_samples) 107 | assert len(det_results) == 1, 'Batch inference is not supported.' 108 | data_sample.pred_det_instances = \ 109 | det_results[0].pred_instances.clone() 110 | 111 | return [data_sample] -------------------------------------------------------------------------------- /shift_tta/models/detectors/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from abc import ABCMeta, abstractmethod 3 | from typing import Dict, List, Tuple, Union 4 | 5 | from mmengine.model import BaseModel 6 | from torch import Tensor 7 | 8 | from mmtrack.utils import (ForwardResults, OptConfigType, OptMultiConfig, 9 | OptSampleList, SampleList) 10 | 11 | 12 | class BaseAdaptiveDetector(BaseModel, metaclass=ABCMeta): 13 | """Base class for test-time adaptation of object detection. 14 | 15 | Args: 16 | data_preprocessor (dict or ConfigDict, optional): The pre-process 17 | config of :class:`TrackDataPreprocessor`. It usually includes 18 | ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. 19 | init_cfg (dict or list[dict]): Initialization config dict. 20 | """ 21 | 22 | def __init__(self, 23 | data_preprocessor: OptConfigType = None, 24 | init_cfg: OptMultiConfig = None) -> None: 25 | super().__init__( 26 | data_preprocessor=data_preprocessor, init_cfg=init_cfg) 27 | 28 | def freeze_module(self, module: Union[List[str], Tuple[str], str]) -> None: 29 | """Freeze module during training.""" 30 | if isinstance(module, str): 31 | modules = [module] 32 | else: 33 | if not (isinstance(module, list) or isinstance(module, tuple)): 34 | raise TypeError('module must be a str, a list or a tuple.') 35 | else: 36 | modules = module 37 | for module in modules: 38 | m = getattr(self, module) 39 | m.eval() 40 | for param in m.parameters(): 41 | param.requires_grad = False 42 | 43 | @property 44 | def with_detector(self) -> bool: 45 | """bool: whether the framework has a detector.""" 46 | return hasattr(self, 'detector') and self.detector is not None 47 | 48 | @property 49 | def with_adapter(self) -> bool: 50 | """bool: whether the framework has an adapter model.""" 51 | return hasattr(self, 'adapter') and self.adapter is not None 52 | 53 | def forward(self, 54 | inputs: Dict[str, Tensor], 55 | data_samples: OptSampleList = None, 56 | mode: str = 'predict', 57 | **kwargs) -> ForwardResults: 58 | """The unified entry for a forward process in both training and test. 59 | 60 | The method should accept three modes: "tensor", "predict" and "loss": 61 | 62 | - "tensor": Forward the whole network and return tensor or tuple of 63 | tensor without any post-processing, same as a common nn.Module. 64 | - "predict": Forward and return the predictions, which are fully 65 | processed to a list of :obj:`TrackDataSample`. 66 | - "loss": Forward and return a dict of losses according to the given 67 | inputs and data samples. 68 | 69 | Note that this method handles neither back propagation nor 70 | optimizer updating; both are done in :meth:`train_step`. 71 | 72 | Args: 73 | inputs (Dict[str, Tensor]): of shape (N, T, C, H, W) 74 | encoding input images. Typically these should be mean centered 75 | and std scaled. The N denotes batch size. The T denotes the 76 | number of key/reference frames. 77 | - img (Tensor): The key images. 78 | - ref_img (Tensor): The reference images. 79 | data_samples (list[:obj:`TrackDataSample`], optional): The 80 | annotation data of each sample. Defaults to None. 81 | mode (str): Return what kind of value. Defaults to 'predict'. 82 | 83 | Returns: 84 | The return type depends on ``mode``. 85 | 86 | - If ``mode="tensor"``, return a tensor or a tuple of tensor. 87 | - If ``mode="predict"``, return a list of :obj:`TrackDataSample`. 88 | - If ``mode="loss"``, return a dict of tensor.
89 | """ 90 | if mode == 'loss': 91 | return self.loss(inputs, data_samples, **kwargs) 92 | elif mode == 'predict': 93 | return self.predict(inputs, data_samples, **kwargs) 94 | elif mode == 'tensor': 95 | return self._forward(inputs, data_samples, **kwargs) 96 | else: 97 | raise RuntimeError(f'Invalid mode "{mode}". ' 98 | 'Only supports loss, predict and tensor mode') 99 | 100 | @abstractmethod 101 | def loss(self, inputs: Dict[str, Tensor], data_samples: SampleList, 102 | **kwargs) -> Union[dict, tuple]: 103 | """Calculate losses from a batch of inputs and data samples.""" 104 | pass 105 | 106 | @abstractmethod 107 | def predict(self, inputs: Dict[str, Tensor], data_samples: SampleList, 108 | **kwargs) -> SampleList: 109 | """Predict results from a batch of inputs and data samples with post- 110 | processing.""" 111 | pass 112 | 113 | def _forward(self, 114 | inputs: Dict[str, Tensor], 115 | data_samples: OptSampleList = None, 116 | **kwargs): 117 | """Network forward process. Usually includes backbone, neck and head 118 | forward without any post-processing. 119 | 120 | Args: 121 | inputs (Dict[str, Tensor]): of shape (N, T, C, H, W). 122 | data_samples (List[:obj:`TrackDataSample`], optional): The 123 | Data Samples. It usually includes information such as 124 | `gt_instance`. 125 | 126 | Returns: 127 | tuple[list]: A tuple of features from ``head`` forward. 128 | """ 129 | raise NotImplementedError( 130 | "_forward function (namely 'tensor' mode) is not supported now") -------------------------------------------------------------------------------- /shift_tta/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .yolox_consistency_loss import YOLOXConsistencyLoss 2 | 3 | __all__ = ['YOLOXConsistencyLoss'] -------------------------------------------------------------------------------- /shift_tta/models/losses/yolox_consistency_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import mse_loss 4 | 5 | from shift_tta.registry import MODELS 6 | 7 | 8 | @MODELS.register_module() 9 | class YOLOXConsistencyLoss(nn.Module): 10 | """YOLOXConsistencyLoss 11 | Args: 12 | weight (float, optional): Weight of the loss. Default to 1.0. 13 | obj_weight (float, optional): Weight of the objectness consistency loss. 14 | Default to 1.0. 15 | reg_weight (float, optional): Weight of the regression consistency loss. 16 | Default to 1.0. 17 | cls_weight (float, optional): Weight of the classification consistency loss. 18 | Default to 1.0. 19 | """ 20 | 21 | def __init__(self, 22 | weight=1.0, 23 | obj_weight=1.0, 24 | reg_weight=1.0, 25 | cls_weight=1.0, 26 | ): 27 | super(YOLOXConsistencyLoss, self).__init__() 28 | self.weight = weight 29 | self.obj_weight = obj_weight 30 | self.reg_weight = reg_weight 31 | self.cls_weight = cls_weight 32 | 33 | def forward(self, inputs, targets, **kwargs): 34 | 35 | """Forward pass. 36 | Args: 37 | inputs: Dictionary of classification scores and bounding box 38 | refinements for the sampled proposals. For cls scores, the shape is 39 | (b*n) * (cats + 1), where n is sampled proposal in each image, cats 40 | is the total number of categories without the background. For bbox 41 | preds, the shape is (b*n) * (4*cats) 42 | targets: Same output by bbox_head from the teacher output. 43 | Returns: 44 | The YOLOX consistency loss. 
58 |         """
59 |         teacher_obj = targets["objectness"]
60 |         teacher_reg = targets["bbox_pred"]
61 |         teacher_cls = targets["cls_score"]
62 | 
63 |         student_obj = inputs["objectness"]
64 |         student_reg = inputs["bbox_pred"]
65 |         student_cls = inputs["cls_score"]
66 | 
67 |         # accumulate summed squared errors and element counts over all
68 |         # feature levels, then normalize to get per-element means
69 |         obj_elements = 0
70 |         reg_elements = 0
71 |         cls_elements = 0
72 |         obj_loss = 0.
73 |         reg_loss = 0.
74 |         cls_loss = 0.
75 |         for (t_obj, t_reg, t_cls, s_obj, s_reg, s_cls) in zip(
76 |                 teacher_obj, teacher_reg, teacher_cls,
77 |                 student_obj, student_reg, student_cls,
78 |         ):
79 |             assert s_obj.shape == t_obj.shape
80 |             assert s_reg.shape == t_reg.shape
81 |             assert s_cls.shape == t_cls.shape
82 | 
83 |             _obj_loss = mse_loss(t_obj, s_obj, reduction="none")
84 |             _cls_loss = mse_loss(t_cls, s_cls, reduction="none")
85 |             _reg_loss = mse_loss(t_reg, s_reg, reduction="none")
86 | 
87 |             obj_loss += torch.sum(_obj_loss)
88 |             reg_loss += torch.sum(_reg_loss)
89 |             cls_loss += torch.sum(_cls_loss)
90 | 
91 |             obj_elements += _obj_loss.numel()
92 |             reg_elements += _reg_loss.numel()
93 |             cls_elements += _cls_loss.numel()
94 | 
95 |         obj_loss = obj_loss / obj_elements
96 |         reg_loss = reg_loss / reg_elements
97 |         cls_loss = cls_loss / cls_elements
98 | 
99 |         # combine the per-branch means with their respective weights
100 |         loss = self.obj_weight * obj_loss
101 |         loss += self.reg_weight * reg_loss
102 |         loss += self.cls_weight * cls_loss
103 | 
104 |         return self.weight * loss
--------------------------------------------------------------------------------
/shift_tta/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """shift-tta provides registry nodes to support using modules across
3 | projects. Each node is a child of the root registry in MMEngine.
4 | 
5 | More details can be found at
6 | https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
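7 | 
8 | Example:
9 |     A typical use, registering a custom module in the shift-tta scope
10 |     (a sketch; ``MyAdapter`` is a placeholder name):
11 | 
12 |     >>> from shift_tta.registry import MODELS
13 |     >>> @MODELS.register_module()
14 |     ... class MyAdapter:
15 |     ...     pass
16 |     >>> adapter = MODELS.build(dict(type='MyAdapter'))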
17 | """
18 | 
19 | from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
20 | from mmengine.registry import DATASETS as MMENGINE_DATASETS
21 | from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
22 | from mmengine.registry import HOOKS as MMENGINE_HOOKS
23 | from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS
24 | from mmengine.registry import LOOPS as MMENGINE_LOOPS
25 | from mmengine.registry import METRICS as MMENGINE_METRICS
26 | from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
27 | from mmengine.registry import MODELS as MMENGINE_MODELS
28 | from mmengine.registry import \
29 |     OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS
30 | from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS
31 | from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
32 | from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
33 | from mmengine.registry import \
34 |     RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS
35 | from mmengine.registry import RUNNERS as MMENGINE_RUNNERS
36 | from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS
37 | from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
38 | from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS
39 | from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS
40 | from mmengine.registry import \
41 |     WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS
42 | from mmengine.registry import Registry
43 | 
44 | # manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
45 | RUNNERS = Registry(
46 |     'runner', parent=MMENGINE_RUNNERS)
47 | # manage runner constructors that define how to initialize runners
48 | RUNNER_CONSTRUCTORS = Registry(
49 |     'runner constructor',
50 |     parent=MMENGINE_RUNNER_CONSTRUCTORS)
51 | # manage all kinds of loops like `EpochBasedTrainLoop`
52 | LOOPS = Registry(
53 |     'loop', parent=MMENGINE_LOOPS)
54 | # manage all kinds of hooks like `CheckpointHook`
55 | HOOKS = Registry(
56 |     'hook', parent=MMENGINE_HOOKS)
57 | 
58 | # manage data-related modules
59 | DATASETS = Registry(
60 |     'dataset', parent=MMENGINE_DATASETS)
61 | DATA_SAMPLERS = Registry(
62 |     'data sampler',
63 |     parent=MMENGINE_DATA_SAMPLERS)
64 | TRANSFORMS = Registry(
65 |     'transform',
66 |     parent=MMENGINE_TRANSFORMS)
67 | 
68 | # manage all kinds of modules inheriting `nn.Module`
69 | MODELS = Registry('model', parent=MMENGINE_MODELS)
70 | # manage all kinds of model wrappers like 'MMDistributedDataParallel'
71 | MODEL_WRAPPERS = Registry(
72 |     'model_wrapper',
73 |     parent=MMENGINE_MODEL_WRAPPERS)
74 | # manage all kinds of weight initialization modules like `Uniform`
75 | WEIGHT_INITIALIZERS = Registry(
76 |     'weight initializer',
77 |     parent=MMENGINE_WEIGHT_INITIALIZERS)
78 | 
79 | # manage all kinds of optimizers like `SGD` and `Adam`
80 | OPTIMIZERS = Registry(
81 |     'optimizer',
82 |     parent=MMENGINE_OPTIMIZERS)
83 | # manage optimizer wrapper
84 | OPTIM_WRAPPERS = Registry(
85 |     'optim_wrapper',
86 |     parent=MMENGINE_OPTIM_WRAPPERS)
87 | # manage constructors that customize the optimization hyperparameters.
88 | OPTIM_WRAPPER_CONSTRUCTORS = Registry(
89 |     'optimizer constructor',
90 |     parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
91 | # manage all kinds of parameter schedulers like `MultiStepLR`
92 | PARAM_SCHEDULERS = Registry(
93 |     'parameter scheduler',
94 |     parent=MMENGINE_PARAM_SCHEDULERS)
95 | # manage all kinds of metrics
96 | METRICS = Registry(
97 |     'metric', parent=MMENGINE_METRICS)
98 | # manage evaluator
99 | EVALUATOR = Registry(
100 |     'evaluator', parent=MMENGINE_EVALUATOR)
101 | 
102 | # manage task-specific modules like anchor generators and box coders
103 | TASK_UTILS = Registry(
104 |     'task util', parent=MMENGINE_TASK_UTILS)
105 | 
106 | # manage visualizers
107 | VISUALIZERS = Registry(
108 |     'visualizer',
109 |     parent=MMENGINE_VISUALIZERS)
110 | # manage visualizer backends
111 | VISBACKENDS = Registry(
112 |     'vis_backend',
113 |     parent=MMENGINE_VISBACKENDS)
114 | 
115 | # manage log processors
116 | LOG_PROCESSORS = Registry(
117 |     'log_processor',
118 |     parent=MMENGINE_LOG_PROCESSORS)
--------------------------------------------------------------------------------
/shift_tta/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .setup_env import register_all_modules
2 | 
3 | __all__ = [
4 |     'register_all_modules',
5 | ]
--------------------------------------------------------------------------------
/shift_tta/utils/setup_env.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import warnings
3 | 
4 | from mmtrack.utils import register_all_modules as register_all_mmtrack_modules
5 | from mmengine import DefaultScope
6 | 
7 | 
8 | def register_all_modules(init_default_scope: bool = True) -> None:
9 |     """Register all modules in shift_tta and mmtrack into the registries.
10 | 
11 |     Args:
12 |         init_default_scope (bool): Whether to initialize the shift_tta
13 |             default scope. When `init_default_scope=True`, the global default
14 |             scope will be set to `shift_tta`, and all registries will build
15 |             modules from shift_tta's registry node. To understand more about
16 |             the registry, please refer to
17 |             https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
18 |             Defaults to True.
19 |     """  # noqa
20 |     import shift_tta.fileio  # noqa: F401,F403
21 |     import shift_tta.datasets  # noqa: F401,F403
22 |     import shift_tta.evaluation  # noqa: F401,F403
23 |     import shift_tta.models  # noqa: F401,F403
24 | 
25 |     # register parent modules
26 |     register_all_mmtrack_modules(init_default_scope=False)
27 | 
28 |     if init_default_scope:
29 |         never_created = DefaultScope.get_current_instance() is None \
30 |             or not DefaultScope.check_instance_created('shift_tta')
31 |         if never_created:
32 |             DefaultScope.get_instance('shift_tta', scope_name='shift_tta')
33 |             return
34 |         current_scope = DefaultScope.get_current_instance()
35 |         if current_scope.scope_name != 'shift_tta':
36 |             warnings.warn('The current default scope '
37 |                           f'"{current_scope.scope_name}" is not "shift_tta", '
38 |                           '`register_all_modules` will force the current '
39 |                           'default scope to be "shift_tta". 
If this is not '
40 |                           'expected, please set `init_default_scope=False`.')
41 |             # avoid name conflict
42 |             new_instance_name = f'shift_tta-{datetime.datetime.now()}'
43 |             DefaultScope.get_instance(new_instance_name, scope_name='shift_tta')
--------------------------------------------------------------------------------
/shift_tta/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.0.0'
2 | 
3 | 
4 | def parse_version_info(version_str):
5 |     version_info = []
6 |     for x in version_str.split('.'):
7 |         if x.isdigit():
8 |             version_info.append(int(x))
9 |         elif x.find('rc') != -1:
10 |             patch_version = x.split('rc')
11 |             version_info.append(int(patch_version[0]))
12 |             version_info.append(f'rc{patch_version[1]}')
13 |     return tuple(version_info)
14 | 
15 | 
16 | version_info = parse_version_info(__version__)
--------------------------------------------------------------------------------
/tools/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | CONFIG=$1
4 | GPUS=$2
5 | NNODES=${NNODES:-1}
6 | NODE_RANK=${NODE_RANK:-0}
7 | PORT=${PORT:-29500}
8 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
9 | 
10 | # Make conda available
11 | eval "$(conda shell.bash hook)"
12 | # Activate a conda environment
13 | conda activate shift-tta
14 | 
15 | export MPLBACKEND=Agg
16 | 
17 | python -m torch.distributed.launch \
18 |     --nnodes=$NNODES \
19 |     --node_rank=$NODE_RANK \
20 |     --master_addr=$MASTER_ADDR \
21 |     --nproc_per_node=$GPUS \
22 |     --master_port=$PORT \
23 |     tools/test.py \
24 |     $CONFIG \
25 |     --launcher pytorch \
26 |     ${@:3}
--------------------------------------------------------------------------------
/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | CONFIG=$1
4 | GPUS=$2
5 | NNODES=${NNODES:-1}
6 | NODE_RANK=${NODE_RANK:-0}
7 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
8 | 
9 | # pick a random master port unless PORT is set explicitly
10 | PORT=${PORT:-$(shuf -i 24000-29500 -n 1)}
11 | 
12 | # Make conda available
13 | eval "$(conda shell.bash hook)"
14 | # Activate a conda environment
15 | conda activate shift-tta
16 | 
17 | export MPLBACKEND=Agg
18 | 
19 | python -m torch.distributed.launch \
20 |     --nnodes=$NNODES \
21 |     --node_rank=$NODE_RANK \
22 |     --master_addr=$MASTER_ADDR \
23 |     --nproc_per_node=$GPUS \
24 |     --master_port=$PORT \
25 |     tools/train.py \
26 |     $CONFIG \
27 |     --launcher pytorch \
28 |     ${@:3}
--------------------------------------------------------------------------------
/tools/install/setup_venv.sh:
--------------------------------------------------------------------------------
1 | 
2 | # Make conda available in non-interactive shells (as in the repo's other scripts)
3 | eval "$(conda shell.bash hook)"
4 | conda create -n shift-tta python=3.9 -y
5 | conda activate shift-tta
6 | 
7 | conda install pytorch=1.11.0 torchvision cudatoolkit=11.3 -c pytorch -y
8 | 
9 | pip install -U openmim
10 | # install mmengine from the main branch at a specific commit
11 | python -m pip install git+https://github.com/open-mmlab/mmengine.git@62f9504d701251db763f56658436fd23a586fe25
12 | mim install 'mmcv == 2.0.0rc4'
13 | mim install 'mmdet == 3.0.0rc5'
14 | # install mmclassification from dev-1.x branch at specific commit
15 | python -m pip install git+https://github.com/open-mmlab/mmclassification.git@3ff80f5047fe3f3780a05d387f913dd02999611d
16 | # install mmtracking from dev-1.x branch at specific commit
17 | python -m pip install git+https://github.com/open-mmlab/mmtracking.git@9e4cb98a3cdac749242cd8decb3a172058d4fd6e
18 | python -m pip install git+https://github.com/JonathonLuiten/TrackEval.git
19 | python -m pip install git+https://github.com/scalabel/scalabel.git
20 | python -m pip install --no-input -r requirements.txt
21 | 
22 | # install shift-detection-tta
23 | pip install --no-input -v -e .
24 | 
--------------------------------------------------------------------------------
/tools/shift/download.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | """
4 | Download script for SHIFT Dataset.
5 | 
6 | The data is released under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 License.
7 | Homepage: www.vis.xyz/shift/.
8 | (C)2022, VIS Group, ETH Zurich.
9 | 
10 | 
11 | Script usage example:
12 |     python download.py --view "[front, left_stereo]" \      # list of view abbreviations to download
13 |         --group "[img, semseg]" \                            # list of data group abbreviations to download
14 |         --split "[train, val, test]" \                       # list of splits to download
15 |         --framerate "[images, videos]" \                     # desired frame rates (images=1fps, videos=10fps)
16 |         --shift "discrete" \                                 # type of domain shift. Options: discrete, continuous/1x, continuous/10x, continuous/100x
17 |         dataset_root                                         # path where to store the downloaded data
18 | 
19 | You can set any option to "all" to download all available data for that option. For example,
20 |     python download.py --view "all" --group "[img]" --split "all" --framerate "[images]" .
21 | downloads all RGB images from the dataset.
22 | """
23 | 
24 | import argparse
25 | import logging
26 | import os
27 | import sys
28 | import tempfile
29 | 
30 | import tqdm
31 | 
32 | if sys.version_info >= (3, 6):
33 |     import urllib.request as urllib
34 | else:
35 |     import urllib
36 | 
37 | 
38 | BASE_URL = "https://dl.cv.ethz.ch/shift/"
39 | 
40 | FRAME_RATES = [("images", "images (1 fps)"), ("videos", "videos (10 fps)")]
41 | 
42 | SPLITS = [
43 |     ("train", "training set"),
44 |     ("val", "validation set"),
45 |     ("test", "testing set"),
46 | ]
47 | 
48 | VIEWS = [
49 |     ("front", "Front"),
50 |     ("left_45", "Left 45°"),
51 |     ("left_90", "Left 90°"),
52 |     ("right_45", "Right 45°"),
53 |     ("right_90", "Right 90°"),
54 |     ("left_stereo", "Front (Stereo)"),
55 |     ("center", "Center (for LiDAR)"),
56 | ]
57 | 
58 | DATA_GROUPS = [
59 |     ("img", "zip", "RGB Image"),
60 |     ("det_2d", "json", "2D Detection and Tracking"),
61 |     ("det_3d", "json", "3D Detection and Tracking"),
62 |     ("semseg", "zip", "Semantic Segmentation"),
63 |     ("det_insseg_2d", "json", "Instance Segmentation"),
64 |     ("flow", "zip", "Optical Flow"),
65 |     ("depth", "zip", "Depth Maps"),
66 |     ("seq", "csv", "Sequence Info"),
67 |     ("lidar", "zip", "LiDAR Point Cloud"),
68 | ]
69 | 
70 | 
71 | class ProgressBar(tqdm.tqdm):
72 |     def update_to(self, batch=1, batch_size=1, total=None):
73 |         if total is not None:
74 |             self.total = total
75 |         self.update(batch * batch_size - self.n)
76 | 
77 | 
78 | def setup_logger():
79 |     log_formatter = logging.Formatter(
80 |         "[%(asctime)s] SHIFT Downloader - %(levelname)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
81 |     )
82 |     logger = logging.getLogger("logger")
83 |     logger.setLevel(logging.DEBUG)
84 |     ch = logging.StreamHandler()
85 |     ch.setLevel(logging.DEBUG)
86 |     ch.setFormatter(log_formatter)
87 |     logger.addHandler(ch)
88 |     return logger
89 | 
90 | 
91 | def get_url_discrete(rate, split, view, group, ext):
92 |     url = BASE_URL + "discrete/{rate}/{split}/{view}/{group}.{ext}".format(
93 |         rate=rate, split=split, view=view, group=group, ext=ext
94 |     )
95 |     return url
96 | 
97 | 
98 | def get_url_continuous(rate, shift_length, split, view, group, ext):
99 |     url = BASE_URL + "continuous/{rate}/{shift_length}/{split}/{view}/{group}.{ext}".format(
100 |         rate=rate, shift_length=shift_length, split=split, view=view, group=group, ext=ext
101 |     )
102 |     return url
103 | 
104 | 
105 | def string_to_list(option_str):
106 |     option_str = option_str.replace(" ", "").lstrip("[").rstrip("]")
107 |     return option_str.split(",")
108 | 
109 | 
110 | def parse_options(option_str, bounds, name):
111 |     if option_str == "all":
112 |         return bounds
113 |     candidates = {}
114 |     for item in bounds:
115 |         candidates[item[0]] = item
116 |     used = []
117 |     try:
118 |         option_list = string_to_list(option_str)
119 |     except Exception as e:
120 |         logger.error("Error in parsing options. " + str(e))
121 |         sys.exit(1)
122 |     for option in option_list:
123 |         if option not in candidates:
124 |             logger.info(
125 |                 "Invalid option '{option}' for '{name}'. ".format(option=option, name=name)
126 |                 + "Please check the download document (https://www.vis.xyz/shift/download/)."
127 |             )
128 |         else:
129 |             used.append(candidates[option])
130 |     if len(used) == 0:
131 |         logger.error(
132 |             "No '{name}' is specified to download. ".format(name=name)
133 |             + "If you want to download all {name}s, please use '--{name} all'.".format(name=name)
134 |         )
135 |         sys.exit(1)
136 |     return used
137 | 
138 | 
139 | def download_file(url, out_file):
140 |     out_dir = os.path.dirname(out_file)
141 |     if not os.path.isdir(out_dir):
142 |         os.makedirs(out_dir)
143 |     if not os.path.isfile(out_file):
144 |         logger.info("downloading " + url)
145 |         fh, out_file_tmp = tempfile.mkstemp(dir=out_dir)
146 |         f = os.fdopen(fh, "w")
147 |         f.close()
148 |         filename = url.split("/")[-1]
149 |         with ProgressBar(unit="B", unit_scale=True, miniters=1, desc=filename) as t:
150 |             urllib.urlretrieve(url, out_file_tmp, reporthook=t.update_to)
151 |         os.rename(out_file_tmp, out_file)
152 |     else:
153 |         logger.warning("Skipping download of existing file " + out_file)
154 | 
155 | 
156 | def main():
157 |     parser = argparse.ArgumentParser(description="Downloads the SHIFT Dataset public release.")
158 |     parser.add_argument("out_dir", help="output directory in which to store the data.")
159 |     parser.add_argument("--split", type=str, default="", help="specific splits to download.")
160 |     parser.add_argument("--view", type=str, default="", help="specific views to download.")
161 |     parser.add_argument("--group", type=str, default="", help="specific data groups to download.")
162 |     parser.add_argument("--framerate", type=str, default="", help="specific frame rates to download.")
163 |     parser.add_argument(
164 |         "--shift",
165 |         type=str,
166 |         default="discrete",
167 |         choices=["discrete", "continuous/1x", "continuous/10x", "continuous/100x"],
168 |         help="specific shift type to download.",
169 |     )
170 |     args = parser.parse_args()
171 | 
172 |     print(
173 |         "Welcome to the SHIFT Dataset download script!\n"
174 |         "By continuing you confirm that you have agreed to SHIFT's user license.\n"
175 |     )
176 | 
177 |     frame_rates = parse_options(args.framerate, FRAME_RATES, "frame rate")
178 |     splits = parse_options(args.split, SPLITS, "split")
179 |     views = parse_options(args.view, VIEWS, "view")
180 |     data_groups = parse_options(args.group, DATA_GROUPS, "data group")
181 |     total_files = len(frame_rates) * len(splits) * len(views) * len(data_groups)
182 |     logger.info("Number of files to download: " + str(total_files))
183 | 
184 |     # parse_options returns (abbreviation, ...) tuples, so compare on the
185 |     # abbreviation field rather than on the tuples themselves
186 |     if any(group[0] == "lidar" for group in data_groups) and [v[0] for v in views] != ["center"]:
187 |         logger.error("LiDAR data is only available for the Center view!")
188 |         sys.exit(1)
189 | 
190 |     for rate, rate_name in frame_rates:
191 |         for split, split_name in splits:
192 |             for view, view_name in views:
193 |                 for group, ext, group_name in data_groups:
194 |                     if rate == "videos" and group in ["img"]:
195 |                         ext = "tar"
196 |                     if args.shift == "discrete":
197 |                         url = get_url_discrete(rate, split, view, group, ext)
198 |                         out_file = os.path.join(args.out_dir, "discrete", rate, split, view, group + "." + ext)
199 |                     else:
200 |                         shift_length = args.shift.split("/")[-1]
201 |                         url = get_url_continuous(rate, shift_length, split, view, group, ext)
202 |                         out_file = os.path.join(
203 |                             args.out_dir, "continuous", rate, shift_length, split, view, group + "." + ext
204 |                         )
205 |                     logger.info(
206 |                         "Downloading - Shift: {shift}, Framerate: {rate}, Split: {split}, View: {view}, Data group: {group}.".format(
207 |                             shift=args.shift,
208 |                             rate=rate_name,
209 |                             split=split_name,
210 |                             view=view_name,
211 |                             group=group_name,
212 |                         )
213 |                     )
214 |                     try:
215 |                         download_file(url, out_file)
216 |                     except Exception as e:
217 |                         logger.error("Error in downloading " + str(e))
218 | 
219 |     logger.info("Done!")
220 | 
221 | 
222 | if __name__ == "__main__":
223 |     logger = setup_logger()
224 |     main()
--------------------------------------------------------------------------------
/tools/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os
4 | import os.path as osp
5 | 
6 | from mmengine.config import Config, DictAction
7 | from mmengine.model import is_model_wrapper
8 | from mmengine.registry import RUNNERS
9 | from mmengine.runner import Runner
10 | 
11 | from shift_tta.utils import register_all_modules
12 | 
13 | 
14 | # TODO: support fuse_conv_bn, visualization, and format_only
15 | def parse_args():
16 |     parser = argparse.ArgumentParser(
17 |         description='shift-tta test (and eval) a model')
18 |     parser.add_argument('config', help='test config file path')
19 |     parser.add_argument('--checkpoint', help='checkpoint file')
20 |     parser.add_argument(
21 |         '--work-dir',
22 |         help='the directory to save the file containing evaluation metrics')
23 |     parser.add_argument(
24 |         '--cfg-options',
25 |         nargs='+',
26 |         action=DictAction,
27 |         help='override some settings in the used config, the key-value pair '
28 |         'in xxx=yyy format will be merged into config file. If the value to '
29 |         'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
30 |         'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 31 | 'Note that the quotation marks are necessary and that no white space ' 32 | 'is allowed.') 33 | parser.add_argument( 34 | '--launcher', 35 | choices=['none', 'pytorch', 'slurm', 'mpi'], 36 | default='none', 37 | help='job launcher') 38 | parser.add_argument('--local_rank', type=int, default=0) 39 | args = parser.parse_args() 40 | if 'LOCAL_RANK' not in os.environ: 41 | os.environ['LOCAL_RANK'] = str(args.local_rank) 42 | return args 43 | 44 | 45 | def main(): 46 | args = parse_args() 47 | 48 | # register all modules in shift-tta into the registries 49 | # do not init the default scope here because it will be init in the runner 50 | register_all_modules(init_default_scope=False) 51 | 52 | # load config 53 | cfg = Config.fromfile(args.config) 54 | cfg.launcher = args.launcher 55 | if args.cfg_options is not None: 56 | cfg.merge_from_dict(args.cfg_options) 57 | 58 | # work_dir is determined in this priority: CLI > segment in file > filename 59 | if args.work_dir is not None: 60 | # update configs according to CLI args if args.work_dir is not None 61 | cfg.work_dir = args.work_dir 62 | elif cfg.get('work_dir', None) is None: 63 | # use config filename as default work_dir if cfg.work_dir is None 64 | cfg.work_dir = osp.join('./work_dirs', 65 | osp.splitext(osp.basename(args.config))[0]) 66 | 67 | cfg.load_from = args.checkpoint 68 | 69 | # build the runner from config 70 | if 'runner_type' not in cfg: 71 | # build the default runner 72 | runner = Runner.from_cfg(cfg) 73 | else: 74 | # build customized runner from the registry 75 | # if 'runner_type' is set in the cfg 76 | runner = RUNNERS.build(cfg) 77 | 78 | if is_model_wrapper(runner.model): 79 | runner.model.module.init_weights() 80 | else: 81 | runner.model.init_weights() 82 | 83 | # start testing 84 | runner.test() 85 | 86 | 87 | if __name__ == '__main__': 88 | main() -------------------------------------------------------------------------------- /tools/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PY_ARGS=${@:3} 6 | 7 | # Make conda available 8 | eval "$(conda shell.bash hook)" 9 | # Activate a conda environment 10 | conda activate shift-tta 11 | 12 | python tools/test.py \ 13 | ${CONFIG} \ 14 | ${PY_ARGS} -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
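2 | # Example usage (a sketch; config paths are illustrative, see docs/train_test.md):
3 | #   single GPU:
4 | #     python tools/train.py configs/source/yolox/yolox_x_8xb4-24e_shift_clear_daytime.py
5 | #   multi-GPU via the wrapper script:
6 | #     bash tools/dist_train.sh configs/source/yolox/yolox_x_8xb4-24e_shift_clear_daytime.py 8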
7 | import argparse
8 | import logging
9 | import os
10 | import os.path as osp
11 | 
12 | from mmengine.config import Config, DictAction
13 | from mmengine.logging import print_log
14 | from mmengine.registry import RUNNERS
15 | from mmengine.runner import Runner
16 | 
17 | from shift_tta.utils import register_all_modules
18 | 
19 | 
20 | def parse_args():
21 |     parser = argparse.ArgumentParser(description='Train a model')
22 |     parser.add_argument('config', help='train config file path')
23 |     parser.add_argument('--work-dir', help='the dir to save logs and models')
24 |     parser.add_argument(
25 |         '--amp',
26 |         action='store_true',
27 |         default=False,
28 |         help='enable automatic-mixed-precision training')
29 |     parser.add_argument(
30 |         '--auto-scale-lr',
31 |         action='store_true',
32 |         help='enable automatically scaling LR.')
33 |     parser.add_argument(
34 |         '--resume',
35 |         action='store_true',
36 |         help='resume from the latest checkpoint in the work_dir automatically')
37 |     parser.add_argument(
38 |         '--cfg-options',
39 |         nargs='+',
40 |         action=DictAction,
41 |         help='override some settings in the used config, the key-value pair '
42 |         'in xxx=yyy format will be merged into config file. If the value to '
43 |         'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
44 |         'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
45 |         'Note that the quotation marks are necessary and that no white space '
46 |         'is allowed.')
47 |     parser.add_argument(
48 |         '--launcher',
49 |         choices=['none', 'pytorch', 'slurm', 'mpi'],
50 |         default='none',
51 |         help='job launcher')
52 |     parser.add_argument('--local_rank', type=int, default=0)
53 |     args = parser.parse_args()
54 |     if 'LOCAL_RANK' not in os.environ:
55 |         os.environ['LOCAL_RANK'] = str(args.local_rank)
56 | 
57 |     return args
58 | 
59 | 
60 | def main():
61 |     args = parse_args()
62 | 
63 |     # register all modules in shift-tta into the registries
64 |     # do not init the default scope here because it will be init in the runner
65 |     register_all_modules(init_default_scope=False)
66 | 
67 |     # load config
68 |     cfg = Config.fromfile(args.config)
69 |     cfg.launcher = args.launcher
70 |     if args.cfg_options is not None:
71 |         cfg.merge_from_dict(args.cfg_options)
72 | 
73 |     # work_dir is determined in this priority: CLI > segment in file > filename
74 |     if args.work_dir is not None:
75 |         # update configs according to CLI args if args.work_dir is not None
76 |         cfg.work_dir = args.work_dir
77 |     elif cfg.get('work_dir', None) is None:
78 |         # use config filename as default work_dir if cfg.work_dir is None
79 |         cfg.work_dir = osp.join('./work_dirs',
80 |                                 osp.splitext(osp.basename(args.config))[0])
81 | 
82 |     # enable automatic-mixed-precision training
83 |     if args.amp is True:
84 |         optim_wrapper = cfg.optim_wrapper.type
85 |         if optim_wrapper == 'AmpOptimWrapper':
86 |             print_log(
87 |                 'AMP training is already enabled in your config.',
88 |                 logger='current',
89 |                 level=logging.WARNING)
90 |         else:
91 |             assert optim_wrapper == 'OptimWrapper', (
92 |                 '`--amp` is only supported when the optimizer wrapper type is '
93 |                 f'`OptimWrapper` but got {optim_wrapper}.')
94 |             cfg.optim_wrapper.type = 'AmpOptimWrapper'
95 |             cfg.optim_wrapper.loss_scale = 'dynamic'
96 | 
97 |     # enable automatically scaling LR
98 |     if args.auto_scale_lr:
99 |         if 'auto_scale_lr' in cfg and \
100 |                 'enable' in cfg.auto_scale_lr and \
101 |                 'base_batch_size' in cfg.auto_scale_lr:
102 |             cfg.auto_scale_lr.enable = True
103 |         else:
104 |             raise RuntimeError('Cannot find "auto_scale_lr" or '
105 |                                '"auto_scale_lr.enable" or '
106 |                                '"auto_scale_lr.base_batch_size" in your'
107 |                                ' configuration file.')
108 |     cfg.resume = args.resume
109 | 
110 |     # build the runner from config
111 |     if 'runner_type' not in cfg:
112 |         # build the default runner
113 |         runner = Runner.from_cfg(cfg)
114 |     else:
115 |         # build customized runner from the registry
116 |         # if 'runner_type' is set in the cfg
117 |         runner = RUNNERS.build(cfg)
118 | 
119 |     # start training
120 |     runner.train()
121 | 
122 | 
123 | if __name__ == '__main__':
124 |     main()
--------------------------------------------------------------------------------
/tools/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | CONFIG=$1
4 | # GPUS ($2) is kept for symmetry with tools/dist_train.sh; this script runs on a single GPU
5 | GPUS=$2
6 | PY_ARGS=${@:3}
7 | 
8 | # Make conda available
9 | eval "$(conda shell.bash hook)"
10 | # Activate a conda environment
11 | conda activate shift-tta
12 | 
13 | python tools/train.py \
14 |     ${CONFIG} \
15 |     ${PY_ARGS}
--------------------------------------------------------------------------------