├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── configs ├── baseline │ ├── base.py │ ├── faster_rcnn_r101_caffe_fpn_coco_full_720k.py │ ├── faster_rcnn_r50_caffe_fpn_coco_full_720k.py │ └── faster_rcnn_r50_caffe_fpn_coco_partial_180k.py └── soft_teacher │ ├── base.py │ ├── soft_teacher_faster_rcnn_r101_caffe_fpn_coco_full_1080k.py │ ├── soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py │ ├── soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_1440k.py │ └── soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py ├── demo └── image_demo.py ├── requirements.txt ├── resources └── pipeline.png ├── setup.py ├── ssod ├── __init__.py ├── apis │ ├── __init__.py │ ├── inference.py │ └── train.py ├── core │ ├── __init__.py │ └── masks │ │ ├── __init__.py │ │ └── structures.py ├── datasets │ ├── __init__.py │ ├── builder.py │ ├── dataset_wrappers.py │ ├── pipelines │ │ ├── __init__.py │ │ ├── formatting.py │ │ ├── geo_utils.py │ │ └── rand_aug.py │ ├── pseudo_coco.py │ └── samplers │ │ ├── __init__.py │ │ └── semi_sampler.py ├── models │ ├── __init__.py │ ├── multi_stream_detector.py │ ├── soft_teacher.py │ └── utils │ │ ├── __init__.py │ │ └── bbox_utils.py ├── utils │ ├── __init__.py │ ├── exts │ │ ├── __init__.py │ │ └── optimizer_constructor.py │ ├── hooks │ │ ├── __init__.py │ │ ├── evaluation.py │ │ ├── mean_teacher.py │ │ ├── submodules_evaluation.py │ │ ├── weight_adjust.py │ │ └── weights_summary.py │ ├── logger.py │ ├── patch.py │ ├── signature.py │ ├── structure_utils.py │ └── vars.py └── version.py └── tools ├── dataset ├── prepare_coco_data.sh ├── semi_coco.py └── semi_coco.sh ├── dist_test.sh ├── dist_train.sh ├── dist_train_partially.sh ├── misc └── browse_dataset.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | #vs code 2 | .history/ 3 | .vscode 4 | .idea 5 | .history 6 | .DS_Store 7 | #python 8 | __pycache__/ 9 | */__pycache__ 10 | *.egg-info 11 | build 12 | #lib 13 | tests 14 | thirdparty 15 | thirdparty/ 16 | 17 | #develop 18 | wandb 19 | data 20 | data/ 21 | *.pkl 22 | *.pkl.json 23 | *.log.json 24 | work_dirs/ 25 | figures 26 | cp.py 27 | 28 | # Pytorch 29 | *.pth 30 | *.py~ 31 | *.sh~ 32 | launch.py 33 | 34 | #nvidia 35 | *.qdrep 36 | *.sqlite 37 | 38 | .pytest* 39 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | known_third_party = PIL,cv2,mmcv,mmdet,numpy,prettytable,setuptools,torch 3 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: 21.5b1 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/asottile/seed-isort-config 7 | rev: v2.2.0 8 | hooks: 9 | - id: seed-isort-config 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v4.0.1 12 | hooks: 13 | - id: trailing-whitespace 14 | - id: check-yaml 15 | - id: end-of-file-fixer 16 | - id: requirements-txt-fixer 17 | - id: check-merge-conflict 18 | - id: fix-encoding-pragma 19 | args: ["--remove"] 20 | - id: mixed-line-ending 21 | args: ["--fix=lf"] 22 | - repo: https://github.com/jumanjihouse/pre-commit-hooks 23 | rev: 2.1.5 24 | hooks: 25 | - id: markdownlint 26 | args: ["-r", 
"~MD002,~MD013,~MD024,~MD029,~MD033,~MD034,~MD036", "-t", "allow_different_nesting"] 27 | - repo: https://github.com/myint/docformatter 28 | rev: v1.4 29 | hooks: 30 | - id: docformatter 31 | args: ["--in-place", "--wrap-descriptions", "79"] 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Microsoft 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | pre: 2 | python -m pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html 3 | mkdir -p thirdparty 4 | git clone https://github.com/open-mmlab/mmdetection.git thirdparty/mmdetection 5 | cd thirdparty/mmdetection && python -m pip install -e . 6 | install: 7 | make pre 8 | python -m pip install -e . 
9 | clean:
10 | 	rm -rf thirdparty
11 | 	rm -r ssod.egg-info
12 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # End-to-End Semi-Supervised Object Detection with Soft Teacher
2 | 
3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/end-to-end-semi-supervised-object-detection/semi-supervised-object-detection-on-coco-1)](https://paperswithcode.com/sota/semi-supervised-object-detection-on-coco-1?p=end-to-end-semi-supervised-object-detection)
4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/end-to-end-semi-supervised-object-detection/semi-supervised-object-detection-on-coco-5)](https://paperswithcode.com/sota/semi-supervised-object-detection-on-coco-5?p=end-to-end-semi-supervised-object-detection)
5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/end-to-end-semi-supervised-object-detection/semi-supervised-object-detection-on-coco-10)](https://paperswithcode.com/sota/semi-supervised-object-detection-on-coco-10?p=end-to-end-semi-supervised-object-detection)
6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/end-to-end-semi-supervised-object-detection/semi-supervised-object-detection-on-coco-100)](https://paperswithcode.com/sota/semi-supervised-object-detection-on-coco-100?p=end-to-end-semi-supervised-object-detection)
7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/end-to-end-semi-supervised-object-detection/instance-segmentation-on-coco-minival)](https://paperswithcode.com/sota/instance-segmentation-on-coco-minival?p=end-to-end-semi-supervised-object-detection)
8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/end-to-end-semi-supervised-object-detection/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=end-to-end-semi-supervised-object-detection)
9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/end-to-end-semi-supervised-object-detection/instance-segmentation-on-coco)](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=end-to-end-semi-supervised-object-detection)
10 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/end-to-end-semi-supervised-object-detection/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=end-to-end-semi-supervised-object-detection)
11 | 
12 | By [Mengde Xu*](https://scholar.google.com/citations?user=C04zJHEAAAAJ&hl=zh-CN), [Zheng Zhang*](https://github.com/stupidZZ), [Han Hu](https://github.com/ancientmooner), [Jianfeng Wang](https://github.com/amsword), [Lijuan Wang](https://www.microsoft.com/en-us/research/people/lijuanw/), [Fangyun Wei](https://scholar.google.com.tw/citations?user=-ncz2s8AAAAJ&hl=zh-TW), [Xiang Bai](http://cloud.eic.hust.edu.cn:8071/~xbai/), [Zicheng Liu](https://www.microsoft.com/en-us/research/people/zliu/).
13 | 
14 | ![pipeline](./resources/pipeline.png)
15 | This repo is the official implementation of the ICCV 2021 paper ["End-to-End Semi-Supervised Object Detection with Soft Teacher"](https://arxiv.org/abs/2106.09018).
16 | 
17 | ## Citation
18 | 
19 | ```bib
20 | @article{xu2021end,
21 |   title={End-to-End Semi-Supervised Object Detection with Soft Teacher},
22 |   author={Xu, Mengde and Zhang, Zheng and Hu, Han and Wang, Jianfeng and Wang, Lijuan and Wei, Fangyun and Bai, Xiang and Liu, Zicheng},
23 |   journal={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
24 |   year={2021}
25 | }
26 | ```
27 | 
28 | ## Main Results
29 | 
30 | ### Partial Labeled Data
31 | 
32 | Following STAC[1], we evaluate on 5 different data splits for each setting and report the average performance over the 5 splits. The results are shown below:
33 | 
34 | #### 1% labeled data
35 | | Method | mAP | Model Weights | Config Files |
36 | | ---- | ---- | ---- | ---- |
37 | | Baseline | 10.0 | - | [Config](configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py) |
38 | | Ours (thr=5e-2) | 21.62 | [Drive](https://drive.google.com/drive/folders/1QA8sAw49DJiMHF-Cr7q0j7KgKjlJyklV?usp=sharing) | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py) |
39 | | Ours (thr=1e-3) | 22.64 | [Drive](https://drive.google.com/drive/folders/1QA8sAw49DJiMHF-Cr7q0j7KgKjlJyklV?usp=sharing) | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py) |
40 | 
41 | #### 5% labeled data
42 | | Method | mAP | Model Weights | Config Files |
43 | | ---- | ---- | ---- | ---- |
44 | | Baseline | 20.92 | - | [Config](configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py) |
45 | | Ours (thr=5e-2) | 30.42 | [Drive](https://drive.google.com/drive/folders/1FBWj5SB888m0LU_XYUOK9QEgiubSbU-8?usp=sharing) | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py) |
46 | | Ours (thr=1e-3) | 31.7 | [Drive](https://drive.google.com/drive/folders/1FBWj5SB888m0LU_XYUOK9QEgiubSbU-8?usp=sharing) | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py) |
47 | 
48 | #### 10% labeled data
49 | | Method | mAP | Model Weights | Config Files |
50 | | ---- | ---- | ---- | ---- |
51 | | Baseline | 26.94 | - | [Config](configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py) |
52 | | Ours (thr=5e-2) | 33.78 | [Drive](https://drive.google.com/drive/folders/1WyAVpfnWxEgvxCLUesxzNB81fM_de9DI?usp=sharing) | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py) |
53 | | Ours (thr=1e-3) | 34.7 | [Drive](https://drive.google.com/drive/folders/1WyAVpfnWxEgvxCLUesxzNB81fM_de9DI?usp=sharing) | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py) |
54 | 
55 | ### Full Labeled Data
56 | 
57 | #### Faster R-CNN (ResNet-50)
58 | | Model | mAP | Model Weights | Config Files |
59 | | ---- | ---- | ---- | ---- |
60 | | Baseline | 40.9 | - | [Config](configs/baseline/faster_rcnn_r50_caffe_fpn_coco_full_720k.py) |
61 | | Ours (thr=5e-2) | 44.05 | [Drive](https://drive.google.com/file/d/1QSwAcU1dpmqVkJiXufW_QaQu-puOeblG/view?usp=sharing) | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py) |
62 | | Ours (thr=1e-3) | 44.6 | [Drive](https://drive.google.com/file/d/1QSwAcU1dpmqVkJiXufW_QaQu-puOeblG/view?usp=sharing) | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py) |
63 | | Ours* (thr=5e-2) | 44.5 | - | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_1440k.py) |
64 | | Ours* (thr=1e-3) | 44.9 | - | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_1440k.py) |
65 | 
66 | #### Faster R-CNN (ResNet-101)
67 | | Model | mAP | Model Weights | Config Files |
68 | | ---- | ---- | ---- | ---- |
69 | | Baseline | 43.8 | - | [Config](configs/baseline/faster_rcnn_r101_caffe_fpn_coco_full_720k.py) |
70 | | Ours* (thr=5e-2) | 46.9 | [Drive](https://drive.google.com/file/d/1LCZpIKBt0ihnPmvvZolV-L94uIn-U7Sp/view?usp=sharing) | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r101_caffe_fpn_coco_full_1080k.py) |
71 | | Ours* (thr=1e-3) | 47.6 | [Drive](https://drive.google.com/file/d/1LCZpIKBt0ihnPmvvZolV-L94uIn-U7Sp/view?usp=sharing) | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r101_caffe_fpn_coco_full_1080k.py) |
72 | 
73 | 
74 | ### Notes
75 | - Ours* means we use a longer training schedule.
76 | - `thr` indicates `model.test_cfg.rcnn.score_thr` in the config files. This inference trick was first introduced by Instant-Teaching[2].
77 | - All models are trained on 8*V100 GPUs.
78 | 
79 | ## Usage
80 | 
81 | ### Requirements
82 | - `Ubuntu 16.04`
83 | - `Anaconda3` with `python=3.6`
84 | - `PyTorch=1.9.0`
85 | - `mmdetection=2.16.0+fe46ffe`
86 | - `mmcv=1.3.9`
87 | - `wandb=0.10.31`
88 | 
89 | #### Notes
90 | - We use [wandb](https://wandb.ai/) for visualization; if you don't want to use it, just comment out lines `273-284` in `configs/soft_teacher/base.py`.
91 | - The project should be compatible with the latest version of `mmdetection`. To switch to the same `mmdetection` version as ours, run `cd thirdparty/mmdetection && git checkout v2.16.0`.
92 | ### Installation
93 | ```
94 | make install
95 | ```
96 | 
97 | ### Data Preparation
98 | - Download the COCO dataset.
99 | - Execute the following commands to generate the data set splits:
100 | ```shell script
101 | # YOUR_DATA should be a directory containing the coco dataset.
102 | # For example:
103 | # YOUR_DATA/
104 | #   coco/
105 | #     train2017/
106 | #     val2017/
107 | #     unlabeled2017/
108 | #     annotations/
109 | ln -s ${YOUR_DATA} data
110 | bash tools/dataset/prepare_coco_data.sh conduct
111 | 
112 | ```
113 | For concrete instructions on what should be downloaded, please refer to `tools/dataset/prepare_coco_data.sh` lines [`11-24`](https://github.com/microsoft/SoftTeacher/blob/863d90a3aa98615be3d156e7d305a22c2a5075f5/tools/dataset/prepare_coco_data.sh#L11).
114 | ### Training
115 | - To train a model in the **partial labeled data** setting:
116 | ```shell script
117 | # JOB_TYPE: 'baseline' or 'semi', decides which kind of job to run
118 | # PERCENT_LABELED_DATA: 1, 5, 10. The percentage of labeled coco data in the whole training dataset.
119 | # GPU_NUM: number of gpus to run the job
120 | for FOLD in 1 2 3 4 5;
121 | do
122 |   bash tools/dist_train_partially.sh <JOB_TYPE> ${FOLD} <PERCENT_LABELED_DATA> <GPU_NUM>
123 | done
124 | ```
125 | For example, we could run the following script to train our model on 10% labeled data with 8 GPUs:
126 | 
127 | ```shell script
128 | for FOLD in 1 2 3 4 5;
129 | do
130 |   bash tools/dist_train_partially.sh semi ${FOLD} 10 8
131 | done
132 | ```
133 | 
134 | - To train a model in the **full labeled data** setting:
135 | 
136 | ```shell script
137 | bash tools/dist_train.sh <CONFIG_FILE_PATH> <GPU_NUM>
138 | ```
139 | For example, to train our `R50` model with 8 GPUs:
140 | ```shell script
141 | bash tools/dist_train.sh configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py 8
142 | ```
143 | - To train a model on a **new dataset**:
144 | 
145 |   The core idea is to convert the new dataset to coco format; details can be found in the [adding new dataset](https://github.com/open-mmlab/mmdetection/blob/master/docs/tutorials/customize_dataset.md) tutorial, and a minimal config sketch is given below.
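  For illustration, here is a minimal config sketch for training `SoftTeacher` on a custom COCO-format dataset. The dataset paths and class names below are hypothetical placeholders (not files in this repo); the override pattern mirrors `configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py`:

  ```python
  # configs/soft_teacher/soft_teacher_custom.py -- hypothetical example
  _base_ = "base.py"

  # Placeholder categories; replace with your own, in annotation category-id order.
  classes = ("cat", "dog")

  data = dict(
      samples_per_gpu=5,
      workers_per_gpu=5,
      train=dict(
          # Labeled branch: images with COCO-format annotations.
          sup=dict(
              type="CocoDataset",
              classes=classes,
              ann_file="data/my_dataset/annotations/labeled.json",
              img_prefix="data/my_dataset/images/",
          ),
          # Unlabeled branch: a COCO-format file that lists images only.
          unsup=dict(
              type="CocoDataset",
              classes=classes,
              ann_file="data/my_dataset/annotations/unlabeled.json",
              img_prefix="data/my_dataset/images/",
          ),
      ),
  )

  # Match the detector's box head to the number of custom classes.
  model = dict(roi_head=dict(bbox_head=dict(num_classes=len(classes))))
  ```

  Such a config could then be trained like any other, e.g. `bash tools/dist_train.sh configs/soft_teacher/soft_teacher_custom.py 8`.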
146 | 
147 | 
148 | 
149 | ### Evaluation
150 | ```
151 | bash tools/dist_test.sh <CONFIG_FILE_PATH> <CHECKPOINT_PATH> <GPU_NUM> --eval bbox --cfg-options model.test_cfg.rcnn.score_thr=<THR>
152 | ```
153 | ### Inference
154 | To run inference with a trained model and visualize the detection results:
155 | 
156 | ```shell script
157 | # [IMAGE_FILE_PATH]: the path of your image file in the local file system
158 | # [CONFIG_FILE]: the path of a config file
159 | # [CHECKPOINT_PATH]: the path of a trained model corresponding to the provided config file.
160 | # [OUTPUT_PATH]: the directory to save the detection results
161 | python demo/image_demo.py [IMAGE_FILE_PATH] [CONFIG_FILE] [CHECKPOINT_PATH] --output [OUTPUT_PATH]
162 | ```
163 | For example:
164 | - Inference on a single image with the provided `R50` model:
165 | ```shell script
166 | python demo/image_demo.py /tmp/tmp.png configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py work_dirs/downloaded.model --output work_dirs/
167 | ```
168 | 
169 | After the program completes, an image with the same name as the input will be saved to `work_dirs`.
170 | 
171 | - Inference on many images with the provided `R50` model:
172 | ```shell script
173 | python demo/image_demo.py '/tmp/*.jpg' configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py work_dirs/downloaded.model --output work_dirs/
174 | ```
175 | 
176 | [1] [A Simple Semi-Supervised Learning Framework for Object Detection](https://arxiv.org/pdf/2005.04757.pdf)
177 | 
178 | 
179 | [2] [Instant-Teaching: An End-to-End Semi-Supervised Object Detection Framework](https://arxiv.org/pdf/2103.11402.pdf)
180 | 
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## Security
4 | 
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 | 
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 | 
9 | ## Reporting Security Issues
10 | 
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 | 
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 | 
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 | 
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 | 
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 | 
21 | * Type of issue (e.g. 
buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /configs/baseline/base.py: -------------------------------------------------------------------------------- 1 | mmdet_base = "../../thirdparty/mmdetection/configs/_base_" 2 | _base_ = [ 3 | f"{mmdet_base}/models/faster_rcnn_r50_fpn.py", 4 | f"{mmdet_base}/datasets/coco_detection.py", 5 | f"{mmdet_base}/schedules/schedule_1x.py", 6 | f"{mmdet_base}/default_runtime.py", 7 | ] 8 | 9 | model = dict( 10 | backbone=dict( 11 | norm_cfg=dict(requires_grad=False), 12 | norm_eval=True, 13 | style="caffe", 14 | init_cfg=dict( 15 | type="Pretrained", checkpoint="open-mmlab://detectron2/resnet50_caffe" 16 | ), 17 | ) 18 | ) 19 | 20 | img_norm_cfg = dict(mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 21 | 22 | train_pipeline = [ 23 | dict(type="LoadImageFromFile"), 24 | dict(type="LoadAnnotations", with_bbox=True), 25 | dict( 26 | type="Sequential", 27 | transforms=[ 28 | dict( 29 | type="RandResize", 30 | img_scale=[(1333, 400), (1333, 1200)], 31 | multiscale_mode="range", 32 | keep_ratio=True, 33 | ), 34 | dict(type="RandFlip", flip_ratio=0.5), 35 | dict( 36 | type="OneOf", 37 | transforms=[ 38 | dict(type=k) 39 | for k in [ 40 | "Identity", 41 | "AutoContrast", 42 | "RandEqualize", 43 | "RandSolarize", 44 | "RandColor", 45 | "RandContrast", 46 | "RandBrightness", 47 | "RandSharpness", 48 | "RandPosterize", 49 | ] 50 | ], 51 | ), 52 | ], 53 | ), 54 | dict(type="Pad", size_divisor=32), 55 | dict(type="Normalize", **img_norm_cfg), 56 | dict(type="ExtraAttrs", tag="sup"), 57 | dict(type="DefaultFormatBundle"), 58 | dict( 59 | type="Collect", 60 | keys=["img", "gt_bboxes", "gt_labels"], 61 | meta_keys=( 62 | "filename", 63 | "ori_shape", 64 | "img_shape", 65 | "img_norm_cfg", 66 | "pad_shape", 67 | "scale_factor", 68 | "tag", 69 | ), 70 | ), 71 | ] 72 | 73 | test_pipeline = [ 74 | dict(type="LoadImageFromFile"), 75 | dict( 76 | type="MultiScaleFlipAug", 77 | img_scale=(1333, 800), 78 | flip=False, 79 | transforms=[ 80 | dict(type="Resize", keep_ratio=True), 81 | dict(type="RandomFlip"), 82 | dict(type="Normalize", **img_norm_cfg), 83 | dict(type="Pad", size_divisor=32), 84 | dict(type="ImageToTensor", keys=["img"]), 85 | dict(type="Collect", keys=["img"]), 86 | ], 87 | ), 88 | ] 89 | 90 | data = dict( 91 | samples_per_gpu=1, 92 | workers_per_gpu=1, 93 | train=dict(pipeline=train_pipeline), 94 | val=dict(pipeline=test_pipeline), 95 | test=dict(pipeline=test_pipeline), 96 | ) 97 
| 98 | optimizer = dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0001) 99 | lr_config = dict(step=[120000, 160000]) 100 | runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000) 101 | checkpoint_config = dict(by_epoch=False, interval=4000, max_keep_ckpts=10) 102 | evaluation = dict(interval=4000) 103 | 104 | fp16 = dict(loss_scale="dynamic") 105 | 106 | log_config = dict( 107 | interval=50, 108 | hooks=[ 109 | dict(type="TextLoggerHook", by_epoch=False), 110 | dict( 111 | type="WandbLoggerHook", 112 | init_kwargs=dict( 113 | project="pre_release", 114 | name="${cfg_name}", 115 | config=dict( 116 | work_dirs="${work_dir}", 117 | total_step="${runner.max_iters}", 118 | ), 119 | ), 120 | by_epoch=False, 121 | ), 122 | ], 123 | ) 124 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r101_caffe_fpn_coco_full_720k.py: -------------------------------------------------------------------------------- 1 | _base_ = "base.py" 2 | model = dict( 3 | backbone=dict( 4 | depth=101, 5 | init_cfg=dict(checkpoint="open-mmlab://detectron2/resnet101_caffe"), 6 | ) 7 | ) 8 | 9 | data = dict( 10 | samples_per_gpu=2, 11 | workers_per_gpu=2, 12 | train=dict( 13 | ann_file="data/coco/annotations/instances_train2017.json", 14 | img_prefix="data/coco/train2017/", 15 | ), 16 | ) 17 | 18 | optimizer = dict(lr=0.02) 19 | lr_config = dict(step=[120000 * 4, 160000 * 4]) 20 | runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000 * 4) 21 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_full_720k.py: -------------------------------------------------------------------------------- 1 | _base_ = "base.py" 2 | 3 | data = dict( 4 | samples_per_gpu=2, 5 | workers_per_gpu=2, 6 | train=dict( 7 | ann_file="data/coco/annotations/instances_train2017.json", 8 | img_prefix="data/coco/train2017/", 9 | ), 10 | ) 11 | 12 | optimizer = dict(lr=0.02) 13 | lr_config = dict(step=[120000 * 4, 160000 * 4]) 14 | runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000 * 4) 15 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py: -------------------------------------------------------------------------------- 1 | _base_ = "base.py" 2 | fold = 1 3 | percent = 1 4 | data = dict( 5 | samples_per_gpu=1, 6 | workers_per_gpu=1, 7 | train=dict( 8 | ann_file="data/coco/annotations/semi_supervised/instances_train2017.${fold}@${percent}.json", 9 | img_prefix="data/coco/train2017/", 10 | ), 11 | ) 12 | work_dir = "work_dirs/${cfg_name}/${percent}/${fold}" 13 | log_config = dict( 14 | interval=50, 15 | hooks=[ 16 | dict(type="TextLoggerHook"), 17 | dict( 18 | type="WandbLoggerHook", 19 | init_kwargs=dict( 20 | project="pre_release", 21 | name="${cfg_name}", 22 | config=dict( 23 | fold="${fold}", 24 | percent="${percent}", 25 | work_dirs="${work_dir}", 26 | total_step="${runner.max_iters}", 27 | ), 28 | ), 29 | by_epoch=False, 30 | ), 31 | ], 32 | ) 33 | -------------------------------------------------------------------------------- /configs/soft_teacher/base.py: -------------------------------------------------------------------------------- 1 | mmdet_base = "../../thirdparty/mmdetection/configs/_base_" 2 | _base_ = [ 3 | f"{mmdet_base}/models/faster_rcnn_r50_fpn.py", 4 | f"{mmdet_base}/datasets/coco_detection.py", 5 | f"{mmdet_base}/schedules/schedule_1x.py", 6 | 
f"{mmdet_base}/default_runtime.py", 7 | ] 8 | 9 | model = dict( 10 | backbone=dict( 11 | norm_cfg=dict(requires_grad=False), 12 | norm_eval=True, 13 | style="caffe", 14 | init_cfg=dict( 15 | type="Pretrained", checkpoint="open-mmlab://detectron2/resnet50_caffe" 16 | ), 17 | ) 18 | ) 19 | 20 | img_norm_cfg = dict(mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 21 | 22 | train_pipeline = [ 23 | dict(type="LoadImageFromFile"), 24 | dict(type="LoadAnnotations", with_bbox=True), 25 | dict( 26 | type="Sequential", 27 | transforms=[ 28 | dict( 29 | type="RandResize", 30 | img_scale=[(1333, 400), (1333, 1200)], 31 | multiscale_mode="range", 32 | keep_ratio=True, 33 | ), 34 | dict(type="RandFlip", flip_ratio=0.5), 35 | dict( 36 | type="OneOf", 37 | transforms=[ 38 | dict(type=k) 39 | for k in [ 40 | "Identity", 41 | "AutoContrast", 42 | "RandEqualize", 43 | "RandSolarize", 44 | "RandColor", 45 | "RandContrast", 46 | "RandBrightness", 47 | "RandSharpness", 48 | "RandPosterize", 49 | ] 50 | ], 51 | ), 52 | ], 53 | record=True, 54 | ), 55 | dict(type="Pad", size_divisor=32), 56 | dict(type="Normalize", **img_norm_cfg), 57 | dict(type="ExtraAttrs", tag="sup"), 58 | dict(type="DefaultFormatBundle"), 59 | dict( 60 | type="Collect", 61 | keys=["img", "gt_bboxes", "gt_labels"], 62 | meta_keys=( 63 | "filename", 64 | "ori_shape", 65 | "img_shape", 66 | "img_norm_cfg", 67 | "pad_shape", 68 | "scale_factor", 69 | "tag", 70 | ), 71 | ), 72 | ] 73 | 74 | strong_pipeline = [ 75 | dict( 76 | type="Sequential", 77 | transforms=[ 78 | dict( 79 | type="RandResize", 80 | img_scale=[(1333, 400), (1333, 1200)], 81 | multiscale_mode="range", 82 | keep_ratio=True, 83 | ), 84 | dict(type="RandFlip", flip_ratio=0.5), 85 | dict( 86 | type="ShuffledSequential", 87 | transforms=[ 88 | dict( 89 | type="OneOf", 90 | transforms=[ 91 | dict(type=k) 92 | for k in [ 93 | "Identity", 94 | "AutoContrast", 95 | "RandEqualize", 96 | "RandSolarize", 97 | "RandColor", 98 | "RandContrast", 99 | "RandBrightness", 100 | "RandSharpness", 101 | "RandPosterize", 102 | ] 103 | ], 104 | ), 105 | dict( 106 | type="OneOf", 107 | transforms=[ 108 | dict(type="RandTranslate", x=(-0.1, 0.1)), 109 | dict(type="RandTranslate", y=(-0.1, 0.1)), 110 | dict(type="RandRotate", angle=(-30, 30)), 111 | [ 112 | dict(type="RandShear", x=(-30, 30)), 113 | dict(type="RandShear", y=(-30, 30)), 114 | ], 115 | ], 116 | ), 117 | ], 118 | ), 119 | dict( 120 | type="RandErase", 121 | n_iterations=(1, 5), 122 | size=[0, 0.2], 123 | squared=True, 124 | ), 125 | ], 126 | record=True, 127 | ), 128 | dict(type="Pad", size_divisor=32), 129 | dict(type="Normalize", **img_norm_cfg), 130 | dict(type="ExtraAttrs", tag="unsup_student"), 131 | dict(type="DefaultFormatBundle"), 132 | dict( 133 | type="Collect", 134 | keys=["img", "gt_bboxes", "gt_labels"], 135 | meta_keys=( 136 | "filename", 137 | "ori_shape", 138 | "img_shape", 139 | "img_norm_cfg", 140 | "pad_shape", 141 | "scale_factor", 142 | "tag", 143 | "transform_matrix", 144 | ), 145 | ), 146 | ] 147 | weak_pipeline = [ 148 | dict( 149 | type="Sequential", 150 | transforms=[ 151 | dict( 152 | type="RandResize", 153 | img_scale=[(1333, 400), (1333, 1200)], 154 | multiscale_mode="range", 155 | keep_ratio=True, 156 | ), 157 | dict(type="RandFlip", flip_ratio=0.5), 158 | ], 159 | record=True, 160 | ), 161 | dict(type="Pad", size_divisor=32), 162 | dict(type="Normalize", **img_norm_cfg), 163 | dict(type="ExtraAttrs", tag="unsup_teacher"), 164 | dict(type="DefaultFormatBundle"), 165 | dict( 166 | 
type="Collect", 167 | keys=["img", "gt_bboxes", "gt_labels"], 168 | meta_keys=( 169 | "filename", 170 | "ori_shape", 171 | "img_shape", 172 | "img_norm_cfg", 173 | "pad_shape", 174 | "scale_factor", 175 | "tag", 176 | "transform_matrix", 177 | ), 178 | ), 179 | ] 180 | unsup_pipeline = [ 181 | dict(type="LoadImageFromFile"), 182 | # dict(type="LoadAnnotations", with_bbox=True), 183 | # generate fake labels for data format compatibility 184 | dict(type="PseudoSamples", with_bbox=True), 185 | dict( 186 | type="MultiBranch", unsup_student=strong_pipeline, unsup_teacher=weak_pipeline 187 | ), 188 | ] 189 | 190 | test_pipeline = [ 191 | dict(type="LoadImageFromFile"), 192 | dict( 193 | type="MultiScaleFlipAug", 194 | img_scale=(1333, 800), 195 | flip=False, 196 | transforms=[ 197 | dict(type="Resize", keep_ratio=True), 198 | dict(type="RandomFlip"), 199 | dict(type="Normalize", **img_norm_cfg), 200 | dict(type="Pad", size_divisor=32), 201 | dict(type="ImageToTensor", keys=["img"]), 202 | dict(type="Collect", keys=["img"]), 203 | ], 204 | ), 205 | ] 206 | data = dict( 207 | samples_per_gpu=None, 208 | workers_per_gpu=None, 209 | train=dict( 210 | _delete_=True, 211 | type="SemiDataset", 212 | sup=dict( 213 | type="CocoDataset", 214 | ann_file=None, 215 | img_prefix=None, 216 | pipeline=train_pipeline, 217 | ), 218 | unsup=dict( 219 | type="CocoDataset", 220 | ann_file=None, 221 | img_prefix=None, 222 | pipeline=unsup_pipeline, 223 | filter_empty_gt=False, 224 | ), 225 | ), 226 | val=dict(pipeline=test_pipeline), 227 | test=dict(pipeline=test_pipeline), 228 | sampler=dict( 229 | train=dict( 230 | type="SemiBalanceSampler", 231 | sample_ratio=[1, 4], 232 | by_prob=True, 233 | # at_least_one=True, 234 | epoch_length=7330, 235 | ) 236 | ), 237 | ) 238 | 239 | semi_wrapper = dict( 240 | type="SoftTeacher", 241 | model="${model}", 242 | train_cfg=dict( 243 | use_teacher_proposal=False, 244 | pseudo_label_initial_score_thr=0.5, 245 | rpn_pseudo_threshold=0.9, 246 | cls_pseudo_threshold=0.9, 247 | reg_pseudo_threshold=0.02, 248 | jitter_times=10, 249 | jitter_scale=0.06, 250 | min_pseduo_box_size=0, 251 | unsup_weight=4.0, 252 | ), 253 | test_cfg=dict(inference_on="student"), 254 | ) 255 | 256 | custom_hooks = [ 257 | dict(type="NumClassCheckHook"), 258 | dict(type="WeightSummary"), 259 | dict(type="MeanTeacher", momentum=0.999, interval=1, warm_up=0), 260 | ] 261 | evaluation = dict(type="SubModulesDistEvalHook", interval=4000) 262 | optimizer = dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0001) 263 | lr_config = dict(step=[120000, 160000]) 264 | runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000) 265 | checkpoint_config = dict(by_epoch=False, interval=4000, max_keep_ckpts=20) 266 | 267 | fp16 = dict(loss_scale="dynamic") 268 | 269 | log_config = dict( 270 | interval=50, 271 | hooks=[ 272 | dict(type="TextLoggerHook", by_epoch=False), 273 | dict( 274 | type="WandbLoggerHook", 275 | init_kwargs=dict( 276 | project="pre_release", 277 | name="${cfg_name}", 278 | config=dict( 279 | work_dirs="${work_dir}", 280 | total_step="${runner.max_iters}", 281 | ), 282 | ), 283 | by_epoch=False, 284 | ), 285 | ], 286 | ) 287 | -------------------------------------------------------------------------------- /configs/soft_teacher/soft_teacher_faster_rcnn_r101_caffe_fpn_coco_full_1080k.py: -------------------------------------------------------------------------------- 1 | _base_ = "soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py" 2 | model = dict( 3 | backbone=dict( 4 | depth=101, 
5 | init_cfg=dict(checkpoint="open-mmlab://detectron2/resnet101_caffe"), 6 | ) 7 | ) 8 | 9 | lr_config = dict(step=[120000 * 6, 160000 * 6]) 10 | runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000 * 6) 11 | -------------------------------------------------------------------------------- /configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py: -------------------------------------------------------------------------------- 1 | _base_ = "base.py" 2 | 3 | data = dict( 4 | samples_per_gpu=5, 5 | workers_per_gpu=5, 6 | train=dict( 7 | sup=dict( 8 | type="CocoDataset", 9 | ann_file="data/coco/annotations/semi_supervised/instances_train2017.${fold}@${percent}.json", 10 | img_prefix="data/coco/train2017/", 11 | ), 12 | unsup=dict( 13 | type="CocoDataset", 14 | ann_file="data/coco/annotations/semi_supervised/instances_train2017.${fold}@${percent}-unlabeled.json", 15 | img_prefix="data/coco/train2017/", 16 | ), 17 | ), 18 | sampler=dict( 19 | train=dict( 20 | sample_ratio=[1, 4], 21 | ) 22 | ), 23 | ) 24 | 25 | fold = 1 26 | percent = 1 27 | 28 | work_dir = "work_dirs/${cfg_name}/${percent}/${fold}" 29 | log_config = dict( 30 | interval=50, 31 | hooks=[ 32 | dict(type="TextLoggerHook"), 33 | dict( 34 | type="WandbLoggerHook", 35 | init_kwargs=dict( 36 | project="pre_release", 37 | name="${cfg_name}", 38 | config=dict( 39 | fold="${fold}", 40 | percent="${percent}", 41 | work_dirs="${work_dir}", 42 | total_step="${runner.max_iters}", 43 | ), 44 | ), 45 | by_epoch=False, 46 | ), 47 | ], 48 | ) 49 | -------------------------------------------------------------------------------- /configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_1440k.py: -------------------------------------------------------------------------------- 1 | _base_ = "soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py" 2 | 3 | 4 | lr_config = dict(step=[120000 * 8, 160000 * 8]) 5 | runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000 * 8) 6 | -------------------------------------------------------------------------------- /configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py: -------------------------------------------------------------------------------- 1 | _base_="base.py" 2 | 3 | data = dict( 4 | samples_per_gpu=8, 5 | workers_per_gpu=8, 6 | train=dict( 7 | 8 | sup=dict( 9 | 10 | ann_file="data/coco/annotations/instances_train2017.json", 11 | img_prefix="data/coco/train2017/", 12 | 13 | ), 14 | unsup=dict( 15 | 16 | ann_file="data/coco/annotations/instances_unlabeled2017.json", 17 | img_prefix="data/coco/unlabeled2017/", 18 | 19 | ), 20 | ), 21 | sampler=dict( 22 | train=dict( 23 | sample_ratio=[1, 1], 24 | ) 25 | ), 26 | ) 27 | 28 | semi_wrapper = dict( 29 | train_cfg=dict( 30 | unsup_weight=2.0, 31 | ) 32 | ) 33 | 34 | lr_config = dict(step=[120000 * 4, 160000 * 4]) 35 | runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000 * 4) 36 | 37 | -------------------------------------------------------------------------------- /demo/image_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | # Modified from thirdparty/mmdetection/demo/image_demo.py 3 | import asyncio 4 | import glob 5 | import os 6 | from argparse import ArgumentParser 7 | 8 | from mmcv import Config 9 | from mmdet.apis import async_inference_detector, inference_detector, show_result_pyplot 10 | 11 | from ssod.apis.inference import init_detector, save_result 12 | from ssod.utils import patch_config 13 | 14 | 15 | def parse_args(): 16 | parser = ArgumentParser() 17 | parser.add_argument("img", help="Image file") 18 | parser.add_argument("config", help="Config file") 19 | parser.add_argument("checkpoint", help="Checkpoint file") 20 | parser.add_argument("--device", default="cuda:0", help="Device used for inference") 21 | parser.add_argument( 22 | "--score-thr", type=float, default=0.3, help="bbox score threshold" 23 | ) 24 | parser.add_argument( 25 | "--async-test", 26 | action="store_true", 27 | help="whether to set async options for async inference.", 28 | ) 29 | parser.add_argument( 30 | "--output", 31 | type=str, 32 | default=None, 33 | help="specify the directory to save visualization results.", 34 | ) 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | def main(args): 40 | cfg = Config.fromfile(args.config) 41 | # Not affect anything, just avoid index error 42 | cfg.work_dir = "./work_dirs" 43 | cfg = patch_config(cfg) 44 | # build the model from a config file and a checkpoint file 45 | model = init_detector(cfg, args.checkpoint, device=args.device) 46 | imgs = glob.glob(args.img) 47 | for img in imgs: 48 | # test a single image 49 | result = inference_detector(model, img) 50 | # show the results 51 | if args.output is None: 52 | show_result_pyplot(model, img, result, score_thr=args.score_thr) 53 | else: 54 | out_file_path = os.path.join(args.output, os.path.basename(img)) 55 | print(f"Save results to {out_file_path}") 56 | save_result( 57 | model, img, result, score_thr=args.score_thr, out_file=out_file_path 58 | ) 59 | 60 | 61 | async def async_main(args): 62 | cfg = Config.fromfile(args.config) 63 | # Not affect anything, just avoid index error 64 | cfg.work_dir = "./work_dirs" 65 | cfg = patch_config(cfg) 66 | # build the model from a config file and a checkpoint file 67 | model = init_detector(cfg, args.checkpoint, device=args.device) 68 | # test a single image 69 | args.img = glob.glob(args.img) 70 | tasks = asyncio.create_task(async_inference_detector(model, args.img)) 71 | result = await asyncio.gather(tasks) 72 | # show the results 73 | for img, pred in zip(args.img, result): 74 | if args.output is None: 75 | show_result_pyplot(model, img, pred, score_thr=args.score_thr) 76 | else: 77 | out_file_path = os.path.join(args.output, os.path.basename(img)) 78 | print(f"Save results to {out_file_path}") 79 | save_result( 80 | model, img, pred, score_thr=args.score_thr, out_file=out_file_path 81 | ) 82 | 83 | 84 | if __name__ == "__main__": 85 | args = parse_args() 86 | if args.async_test: 87 | asyncio.run(async_main(args)) 88 | else: 89 | main(args) 90 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | mmcv-full 4 | wandb 5 | prettytable 6 | -------------------------------------------------------------------------------- /resources/pipeline.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/SoftTeacher/0423bb2fe43f4fbe0e6704b1011d865cc3d2fdbe/resources/pipeline.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from setuptools import find_packages, setup 4 | 5 | 6 | def get_version(): 7 | version_file = "ssod/version.py" 8 | with open(version_file, "r") as f: 9 | exec(compile(f.read(), version_file, "exec")) 10 | return locals()["__version__"] 11 | 12 | 13 | def parse_requirements(fname="requirements.txt", with_version=True): 14 | """Parse the package dependencies listed in a requirements file but strips 15 | specific versioning information. 16 | 17 | Args: 18 | fname (str): path to requirements file 19 | with_version (bool, default=False): if True include version specs 20 | Returns: 21 | List[str]: list of requirements items 22 | CommandLine: 23 | python -c "import setup; print(setup.parse_requirements())" 24 | """ 25 | import sys 26 | from os.path import exists 27 | 28 | require_fpath = fname 29 | 30 | def parse_line(line): 31 | """Parse information from a line in a requirements text file.""" 32 | if line.startswith("-r "): 33 | # Allow specifying requirements in other files 34 | target = line.split(" ")[1] 35 | for info in parse_require_file(target): 36 | yield info 37 | else: 38 | info = {"line": line} 39 | if line.startswith("-e "): 40 | info["package"] = line.split("#egg=")[1] 41 | else: 42 | # Remove versioning from the package 43 | pat = "(" + "|".join([">=", "==", ">"]) + ")" 44 | parts = re.split(pat, line, maxsplit=1) 45 | parts = [p.strip() for p in parts] 46 | 47 | info["package"] = parts[0] 48 | if len(parts) > 1: 49 | op, rest = parts[1:] 50 | if ";" in rest: 51 | # Handle platform specific dependencies 52 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies 53 | version, platform_deps = map(str.strip, rest.split(";")) 54 | info["platform_deps"] = platform_deps 55 | else: 56 | version = rest # NOQA 57 | info["version"] = (op, version) 58 | yield info 59 | 60 | def parse_require_file(fpath): 61 | with open(fpath, "r") as f: 62 | for line in f.readlines(): 63 | line = line.strip() 64 | if line and not line.startswith("#"): 65 | for info in parse_line(line): 66 | yield info 67 | 68 | def gen_packages_items(): 69 | if exists(require_fpath): 70 | for info in parse_require_file(require_fpath): 71 | parts = [info["package"]] 72 | if with_version and "version" in info: 73 | parts.extend(info["version"]) 74 | if not sys.version.startswith("3.4"): 75 | # apparently package_deps are broken in 3.4 76 | platform_deps = info.get("platform_deps") 77 | if platform_deps is not None: 78 | parts.append(";" + platform_deps) 79 | item = "".join(parts) 80 | yield item 81 | 82 | packages = list(gen_packages_items()) 83 | return packages 84 | 85 | 86 | if __name__ == "__main__": 87 | install_requires = parse_requirements() 88 | setup( 89 | name="ssod", 90 | version=get_version(), 91 | description="Semi-Supervised Object Detection Benchmark", 92 | author="someone", 93 | author_email="someone", 94 | packages=find_packages(exclude=("configs", "tools", "demo")), 95 | install_requires=install_requires, 96 | include_package_data=True, 97 | ext_modules=[], 98 | zip_safe=False, 99 | ) 100 | -------------------------------------------------------------------------------- /ssod/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .models import * 2 | -------------------------------------------------------------------------------- /ssod/apis/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import get_root_logger, set_random_seed, train_detector 2 | 3 | __all__ = ["get_root_logger", "set_random_seed", "train_detector"] 4 | -------------------------------------------------------------------------------- /ssod/apis/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import mmcv 5 | from mmcv.runner import load_checkpoint 6 | 7 | from mmdet.core import get_classes 8 | from mmdet.models import build_detector 9 | 10 | 11 | def init_detector(config, checkpoint=None, device="cuda:0", cfg_options=None): 12 | """Initialize a detector from config file. 13 | 14 | Args: 15 | config (str or :obj:`mmcv.Config`): Config file path or the config 16 | object. 17 | checkpoint (str, optional): Checkpoint path. If left as None, the model 18 | will not load any weights. 19 | cfg_options (dict): Options to override some settings in the used 20 | config. 21 | 22 | Returns: 23 | nn.Module: The constructed detector. 24 | """ 25 | if isinstance(config, str): 26 | config = mmcv.Config.fromfile(config) 27 | elif not isinstance(config, mmcv.Config): 28 | raise TypeError( 29 | "config must be a filename or Config object, " f"but got {type(config)}" 30 | ) 31 | if cfg_options is not None: 32 | config.merge_from_dict(cfg_options) 33 | config.model.train_cfg = None 34 | 35 | if hasattr(config.model, "model"): 36 | config.model.model.pretrained = None 37 | config.model.model.train_cfg = None 38 | else: 39 | config.model.pretrained = None 40 | 41 | model = build_detector(config.model, test_cfg=config.get("test_cfg")) 42 | if checkpoint is not None: 43 | map_loc = "cpu" if device == "cpu" else None 44 | checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc) 45 | if "CLASSES" in checkpoint.get("meta", {}): 46 | model.CLASSES = checkpoint["meta"]["CLASSES"] 47 | else: 48 | warnings.simplefilter("once") 49 | warnings.warn( 50 | "Class names are not saved in the checkpoint's " 51 | "meta data, use COCO classes by default." 52 | ) 53 | model.CLASSES = get_classes("coco") 54 | model.cfg = config # save the config in the model for convenience 55 | model.to(device) 56 | model.eval() 57 | return model 58 | 59 | 60 | def save_result(model, img, result, score_thr=0.3, out_file="res.png"): 61 | """Save the detection results on the image. 62 | 63 | Args: 64 | model (nn.Module): The loaded detector. 65 | img (str or np.ndarray): Image filename or loaded image. 66 | result (tuple[list] or list): The detection result, can be either 67 | (bbox, segm) or just bbox. 68 | score_thr (float): The threshold to visualize the bboxes and masks. 
69 | out_file (str): Specifies where to save the visualization result 70 | """ 71 | if hasattr(model, "module"): 72 | model = model.module 73 | model.show_result( 74 | img, 75 | result, 76 | score_thr=score_thr, 77 | show=False, 78 | out_file=out_file, 79 | bbox_color=(72, 101, 241), 80 | text_color=(72, 101, 241), 81 | ) 82 | -------------------------------------------------------------------------------- /ssod/apis/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 7 | from mmcv.runner import ( 8 | HOOKS, 9 | DistSamplerSeedHook, 10 | EpochBasedRunner, 11 | Fp16OptimizerHook, 12 | OptimizerHook, 13 | build_optimizer, 14 | build_runner, 15 | ) 16 | from mmcv.runner.hooks import HOOKS 17 | from mmcv.utils import build_from_cfg 18 | from mmdet.core import EvalHook 19 | from mmdet.datasets import build_dataset, replace_ImageToTensor 20 | 21 | from ssod.datasets import build_dataloader 22 | from ssod.utils import find_latest_checkpoint, get_root_logger, patch_runner 23 | from ssod.utils.hooks import DistEvalHook 24 | 25 | 26 | def set_random_seed(seed, deterministic=False): 27 | """Set random seed. 28 | 29 | Args: 30 | seed (int): Seed to be used. 31 | deterministic (bool): Whether to set the deterministic option for 32 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 33 | to True and `torch.backends.cudnn.benchmark` to False. 34 | Default: False. 35 | """ 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed_all(seed) 40 | if deterministic: 41 | torch.backends.cudnn.deterministic = True 42 | torch.backends.cudnn.benchmark = False 43 | 44 | 45 | def train_detector( 46 | model, dataset, cfg, distributed=False, validate=False, timestamp=None, meta=None 47 | ): 48 | logger = get_root_logger(log_level=cfg.log_level) 49 | 50 | # prepare data loaders 51 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 52 | if "imgs_per_gpu" in cfg.data: 53 | logger.warning( 54 | '"imgs_per_gpu" is deprecated in MMDet V2.0. 
' 55 | 'Please use "samples_per_gpu" instead' 56 | ) 57 | if "samples_per_gpu" in cfg.data: 58 | logger.warning( 59 | f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and ' 60 | f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"' 61 | f"={cfg.data.imgs_per_gpu} is used in this experiment" 62 | ) 63 | else: 64 | logger.warning( 65 | 'Automatically set "samples_per_gpu"="imgs_per_gpu"=' 66 | f"{cfg.data.imgs_per_gpu} in this experiment" 67 | ) 68 | cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu 69 | 70 | data_loaders = [ 71 | build_dataloader( 72 | ds, 73 | cfg.data.samples_per_gpu, 74 | cfg.data.workers_per_gpu, 75 | # cfg.gpus will be ignored if distributed 76 | len(cfg.gpu_ids), 77 | dist=distributed, 78 | seed=cfg.seed, 79 | sampler_cfg=cfg.data.get("sampler", {}).get("train", {}), 80 | ) 81 | for ds in dataset 82 | ] 83 | 84 | # put model on gpus 85 | if distributed: 86 | find_unused_parameters = cfg.get("find_unused_parameters", False) 87 | # Sets the `find_unused_parameters` parameter in 88 | # torch.nn.parallel.DistributedDataParallel 89 | model = MMDistributedDataParallel( 90 | model.cuda(), 91 | device_ids=[torch.cuda.current_device()], 92 | broadcast_buffers=False, 93 | find_unused_parameters=find_unused_parameters, 94 | ) 95 | else: 96 | model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) 97 | 98 | # build runner 99 | optimizer = build_optimizer(model, cfg.optimizer) 100 | 101 | if "runner" not in cfg: 102 | cfg.runner = {"type": "EpochBasedRunner", "max_epochs": cfg.total_epochs} 103 | warnings.warn( 104 | "config is now expected to have a `runner` section, " 105 | "please set `runner` in your config.", 106 | UserWarning, 107 | ) 108 | else: 109 | if "total_epochs" in cfg: 110 | assert cfg.total_epochs == cfg.runner.max_epochs 111 | 112 | runner = build_runner( 113 | cfg.runner, 114 | default_args=dict( 115 | model=model, 116 | optimizer=optimizer, 117 | work_dir=cfg.work_dir, 118 | logger=logger, 119 | meta=meta, 120 | ), 121 | ) 122 | 123 | # an ugly workaround to make .log and .log.json filenames the same 124 | runner.timestamp = timestamp 125 | 126 | # fp16 setting 127 | fp16_cfg = cfg.get("fp16", None) 128 | if fp16_cfg is not None: 129 | optimizer_config = Fp16OptimizerHook( 130 | **cfg.optimizer_config, **fp16_cfg, distributed=distributed 131 | ) 132 | elif distributed and "type" not in cfg.optimizer_config: 133 | optimizer_config = OptimizerHook(**cfg.optimizer_config) 134 | else: 135 | optimizer_config = cfg.optimizer_config 136 | 137 | # register hooks 138 | runner.register_training_hooks( 139 | cfg.lr_config, 140 | optimizer_config, 141 | cfg.checkpoint_config, 142 | cfg.log_config, 143 | cfg.get("momentum_config", None), 144 | ) 145 | if distributed: 146 | if isinstance(runner, EpochBasedRunner): 147 | runner.register_hook(DistSamplerSeedHook()) 148 | 149 | # register eval hooks 150 | if validate: 151 | # Support batch_size > 1 in validation 152 | val_samples_per_gpu = cfg.data.val.pop("samples_per_gpu", 1) 153 | if val_samples_per_gpu > 1: 154 | # Replace 'ImageToTensor' to 'DefaultFormatBundle' 155 | cfg.data.val.pipeline = replace_ImageToTensor(cfg.data.val.pipeline) 156 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 157 | val_dataloader = build_dataloader( 158 | val_dataset, 159 | samples_per_gpu=val_samples_per_gpu, 160 | workers_per_gpu=cfg.data.workers_per_gpu, 161 | dist=distributed, 162 | shuffle=False, 163 | ) 164 | eval_cfg = cfg.get("evaluation", {}) 165 | eval_cfg["by_epoch"] = eval_cfg.get( 166 | 
"by_epoch", cfg.runner["type"] != "IterBasedRunner" 167 | ) 168 | if "type" not in eval_cfg: 169 | eval_hook = DistEvalHook if distributed else EvalHook 170 | eval_hook = eval_hook(val_dataloader, **eval_cfg) 171 | 172 | else: 173 | eval_hook = build_from_cfg( 174 | eval_cfg, HOOKS, default_args=dict(dataloader=val_dataloader) 175 | ) 176 | 177 | runner.register_hook(eval_hook, priority=80) 178 | 179 | # user-defined hooks 180 | if cfg.get("custom_hooks", None): 181 | custom_hooks = cfg.custom_hooks 182 | assert isinstance( 183 | custom_hooks, list 184 | ), f"custom_hooks expect list type, but got {type(custom_hooks)}" 185 | for hook_cfg in cfg.custom_hooks: 186 | assert isinstance(hook_cfg, dict), ( 187 | "Each item in custom_hooks expects dict type, but got " 188 | f"{type(hook_cfg)}" 189 | ) 190 | hook_cfg = hook_cfg.copy() 191 | priority = hook_cfg.pop("priority", "NORMAL") 192 | hook = build_from_cfg(hook_cfg, HOOKS) 193 | runner.register_hook(hook, priority=priority) 194 | 195 | runner = patch_runner(runner) 196 | resume_from = None 197 | if cfg.get("auto_resume", True): 198 | resume_from = find_latest_checkpoint(cfg.work_dir) 199 | if resume_from is not None: 200 | cfg.resume_from = resume_from 201 | 202 | if cfg.resume_from: 203 | runner.resume(cfg.resume_from) 204 | elif cfg.load_from: 205 | runner.load_checkpoint(cfg.load_from) 206 | runner.run(data_loaders, cfg.workflow) 207 | -------------------------------------------------------------------------------- /ssod/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .masks import TrimapMasks 2 | -------------------------------------------------------------------------------- /ssod/core/masks/__init__.py: -------------------------------------------------------------------------------- 1 | from .structures import TrimapMasks 2 | -------------------------------------------------------------------------------- /ssod/core/masks/structures.py: -------------------------------------------------------------------------------- 1 | """ 2 | Designed for pseudo masks. 3 | In a `TrimapMasks`, part of the mask can be ignored when computing the loss. 4 | """ 5 | import numpy as np 6 | import torch 7 | from mmcv.ops.roi_align import roi_align 8 | from mmdet.core import BitmapMasks 9 | 10 | 11 | class TrimapMasks(BitmapMasks): 12 | def __init__(self, masks, height, width, ignore_value=255): 13 | """ 14 | Args: 15 | ignore_value: flag value marking mask regions to ignore in loss computation. 
16 | See `mmdet.core.BitmapMasks` for more information 17 | """ 18 | super().__init__(masks, height, width) 19 | self.ignore_value = ignore_value 20 | 21 | def crop_and_resize( 22 | self, bboxes, out_shape, inds, device="cpu", interpolation="bilinear" 23 | ): 24 | """See :func:`BaseInstanceMasks.crop_and_resize`.""" 25 | if len(self.masks) == 0: 26 | empty_masks = np.empty((0, *out_shape), dtype=np.uint8) 27 | return BitmapMasks(empty_masks, *out_shape) 28 | 29 | # convert bboxes to tensor 30 | if isinstance(bboxes, np.ndarray): 31 | bboxes = torch.from_numpy(bboxes).to(device=device) 32 | if isinstance(inds, np.ndarray): 33 | inds = torch.from_numpy(inds).to(device=device) 34 | 35 | num_bbox = bboxes.shape[0] 36 | fake_inds = torch.arange(num_bbox, device=device).to(dtype=bboxes.dtype)[ 37 | :, None 38 | ] 39 | rois = torch.cat([fake_inds, bboxes], dim=1) # Nx5 40 | rois = rois.to(device=device) 41 | if num_bbox > 0: 42 | gt_masks_th = ( 43 | torch.from_numpy(self.masks) 44 | .to(device) 45 | .index_select(0, inds) 46 | .to(dtype=rois.dtype) 47 | ) 48 | targets = roi_align( 49 | gt_masks_th[:, None, :, :], rois, out_shape, 1.0, 0, "avg", True 50 | ).squeeze(1) 51 | # for a mask: 52 | # value<0.5 -> background, 53 | # 0.5<=value<=1 -> foreground 54 | # value>1 -> ignored area 55 | resized_masks = (targets >= 0.5).float() 56 | resized_masks[targets > 1] = self.ignore_value 57 | resized_masks = resized_masks.cpu().numpy() 58 | else: 59 | resized_masks = [] 60 | return BitmapMasks(resized_masks, *out_shape) 61 | -------------------------------------------------------------------------------- /ssod/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from mmdet.datasets import build_dataset 2 | 3 | from .builder import build_dataloader 4 | from .dataset_wrappers import SemiDataset 5 | from .pipelines import * 6 | from .pseudo_coco import PseudoCocoDataset 7 | from .samplers import DistributedGroupSemiBalanceSampler 8 | 9 | __all__ = [ 10 | "PseudoCocoDataset", 11 | "build_dataloader", 12 | "build_dataset", 13 | "SemiDataset", 14 | "DistributedGroupSemiBalanceSampler", 15 | ] 16 | -------------------------------------------------------------------------------- /ssod/datasets/builder.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping, Sequence 2 | from functools import partial 3 | 4 | import torch 5 | from mmcv.parallel import DataContainer 6 | from mmcv.runner import get_dist_info 7 | from mmcv.utils import Registry, build_from_cfg 8 | from mmdet.datasets.builder import worker_init_fn 9 | from mmdet.datasets.samplers import ( 10 | DistributedGroupSampler, 11 | DistributedSampler, 12 | GroupSampler, 13 | ) 14 | from torch.nn import functional as F 15 | from torch.utils.data import DataLoader 16 | from torch.utils.data.dataloader import default_collate 17 | 18 | SAMPLERS = Registry("sampler") 19 | 20 | SAMPLERS.register_module(module=DistributedGroupSampler) 21 | SAMPLERS.register_module(module=DistributedSampler) 22 | SAMPLERS.register_module(module=GroupSampler) 23 | 24 | 25 | def build_sampler(cfg, dist=False, group=False, default_args=None): 26 | if cfg and ("type" in cfg): 27 | sampler_type = cfg.get("type") 28 | else: 29 | sampler_type = default_args.get("type") 30 | if group: 31 | sampler_type = "Group" + sampler_type 32 | if dist: 33 | sampler_type = "Distributed" + sampler_type 34 | 35 | if cfg: 36 | cfg.update(type=sampler_type) 37 | else: 38 | cfg = 
dict(type=sampler_type) 39 | 40 | return build_from_cfg(cfg, SAMPLERS, default_args) 41 | 42 | 43 | def build_dataloader( 44 | dataset, 45 | samples_per_gpu, 46 | workers_per_gpu, 47 | num_gpus=1, 48 | dist=True, 49 | shuffle=True, 50 | seed=None, 51 | sampler_cfg=None, 52 | **kwargs, 53 | ): 54 | rank, world_size = get_dist_info() 55 | default_sampler_cfg = dict(type="Sampler", dataset=dataset) 56 | if shuffle: 57 | default_sampler_cfg.update(samples_per_gpu=samples_per_gpu) 58 | else: 59 | default_sampler_cfg.update(shuffle=False) 60 | if dist: 61 | default_sampler_cfg.update(num_replicas=world_size, rank=rank, seed=seed) 62 | sampler = build_sampler(sampler_cfg, dist, shuffle, default_sampler_cfg) 63 | 64 | batch_size = samples_per_gpu 65 | num_workers = workers_per_gpu 66 | else: 67 | sampler = ( 68 | build_sampler(sampler_cfg, default_args=default_sampler_cfg) 69 | if shuffle 70 | else None 71 | ) 72 | batch_size = num_gpus * samples_per_gpu 73 | num_workers = num_gpus * workers_per_gpu 74 | 75 | init_fn = ( 76 | partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) 77 | if seed is not None 78 | else None 79 | ) 80 | 81 | data_loader = DataLoader( 82 | dataset, 83 | batch_size=batch_size, 84 | sampler=sampler, 85 | num_workers=num_workers, 86 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu, flatten=True), 87 | pin_memory=False, 88 | worker_init_fn=init_fn, 89 | **kwargs, 90 | ) 91 | return data_loader 92 | 93 | 94 | def collate(batch, samples_per_gpu=1, flatten=False): 95 | """Puts each data field into a tensor/DataContainer with outer dimension 96 | batch size. 97 | 98 | Extend default_collate to add support for 99 | :type:`~mmcv.parallel.DataContainer`. There are 3 cases. 100 | 101 | 1. cpu_only = True, e.g., meta data 102 | 2. cpu_only = False, stack = True, e.g., images tensors 103 | 3. 
cpu_only = False, stack = False, e.g., gt bboxes 104 | """ 105 | if not isinstance(batch, Sequence): 106 | raise TypeError(f"{batch.dtype} is not supported.") 107 | 108 | if isinstance(batch[0], DataContainer): 109 | stacked = [] 110 | if batch[0].cpu_only: 111 | for i in range(0, len(batch), samples_per_gpu): 112 | stacked.append( 113 | [sample.data for sample in batch[i : i + samples_per_gpu]] 114 | ) 115 | return DataContainer( 116 | stacked, batch[0].stack, batch[0].padding_value, cpu_only=True 117 | ) 118 | elif batch[0].stack: 119 | for i in range(0, len(batch), samples_per_gpu): 120 | assert isinstance(batch[i].data, torch.Tensor) 121 | 122 | if batch[i].pad_dims is not None: 123 | ndim = batch[i].dim() 124 | assert ndim > batch[i].pad_dims 125 | max_shape = [0 for _ in range(batch[i].pad_dims)] 126 | for dim in range(1, batch[i].pad_dims + 1): 127 | max_shape[dim - 1] = batch[i].size(-dim) 128 | for sample in batch[i : i + samples_per_gpu]: 129 | for dim in range(0, ndim - batch[i].pad_dims): 130 | assert batch[i].size(dim) == sample.size(dim) 131 | for dim in range(1, batch[i].pad_dims + 1): 132 | max_shape[dim - 1] = max( 133 | max_shape[dim - 1], sample.size(-dim) 134 | ) 135 | padded_samples = [] 136 | for sample in batch[i : i + samples_per_gpu]: 137 | pad = [0 for _ in range(batch[i].pad_dims * 2)] 138 | for dim in range(1, batch[i].pad_dims + 1): 139 | pad[2 * dim - 1] = max_shape[dim - 1] - sample.size(-dim) 140 | padded_samples.append( 141 | F.pad(sample.data, pad, value=sample.padding_value) 142 | ) 143 | stacked.append(default_collate(padded_samples)) 144 | elif batch[i].pad_dims is None: 145 | stacked.append( 146 | default_collate( 147 | [sample.data for sample in batch[i : i + samples_per_gpu]] 148 | ) 149 | ) 150 | else: 151 | raise ValueError("pad_dims should be either None or integers (1-3)") 152 | 153 | else: 154 | for i in range(0, len(batch), samples_per_gpu): 155 | stacked.append( 156 | [sample.data for sample in batch[i : i + samples_per_gpu]] 157 | ) 158 | return DataContainer(stacked, batch[0].stack, batch[0].padding_value) 159 | elif any([isinstance(b, Sequence) for b in batch]): 160 | if flatten: 161 | flattened = [] 162 | for b in batch: 163 | if isinstance(b, Sequence): 164 | flattened.extend(b) 165 | else: 166 | flattened.extend([b]) 167 | return collate(flattened, len(flattened)) 168 | else: 169 | transposed = zip(*batch) 170 | return [collate(samples, samples_per_gpu) for samples in transposed] 171 | elif isinstance(batch[0], Mapping): 172 | return { 173 | key: collate([d[key] for d in batch], samples_per_gpu) for key in batch[0] 174 | } 175 | else: 176 | return default_collate(batch) 177 | -------------------------------------------------------------------------------- /ssod/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | from mmdet.datasets import DATASETS, ConcatDataset, build_dataset 2 | 3 | 4 | @DATASETS.register_module() 5 | class SemiDataset(ConcatDataset): 6 | """Wrapper for semisupervised od.""" 7 | 8 | def __init__(self, sup: dict, unsup: dict, **kwargs): 9 | super().__init__([build_dataset(sup), build_dataset(unsup)], **kwargs) 10 | 11 | @property 12 | def sup(self): 13 | return self.datasets[0] 14 | 15 | @property 16 | def unsup(self): 17 | return self.datasets[1] 18 | -------------------------------------------------------------------------------- /ssod/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.formatting import * 2 | from .rand_aug import * 3 | -------------------------------------------------------------------------------- /ssod/datasets/pipelines/formatting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mmdet.datasets import PIPELINES 3 | from mmdet.datasets.pipelines.formating import Collect 4 | 5 | from ssod.core import TrimapMasks 6 | 7 | 8 | @PIPELINES.register_module() 9 | class ExtraAttrs(object): 10 | def __init__(self, **attrs): 11 | self.attrs = attrs 12 | 13 | def __call__(self, results): 14 | for k, v in self.attrs.items(): 15 | assert k not in results 16 | results[k] = v 17 | return results 18 | 19 | 20 | @PIPELINES.register_module() 21 | class ExtraCollect(Collect): 22 | def __init__(self, *args, extra_meta_keys=[], **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.meta_keys = self.meta_keys + tuple(extra_meta_keys) 25 | 26 | 27 | @PIPELINES.register_module() 28 | class PseudoSamples(object): 29 | def __init__( 30 | self, with_bbox=False, with_mask=False, with_seg=False, fill_value=255 31 | ): 32 | """ 33 | Replace the gt labels in the original data with empty fake labels, or add extra fake labels for unlabeled data. 34 | This removes the effect of labeled annotations and keeps the elements of unlabeled samples aligned with labeled ones. 35 | Args: 36 | with_bbox: whether to add empty ``gt_bboxes`` and ``gt_labels``. 37 | with_mask: whether to add all-ignored ``gt_masks``. 38 | with_seg: whether to add an all-ignored ``gt_semantic_seg`` map. 39 | fill_value: value marking ignored pixels in the fake masks/segmentation. 40 | """ 41 | self.with_bbox = with_bbox 42 | self.with_mask = with_mask 43 | self.with_seg = with_seg 44 | self.fill_value = fill_value 45 | 46 | def __call__(self, results): 47 | if self.with_bbox: 48 | results["gt_bboxes"] = np.zeros((0, 4)) 49 | results["gt_labels"] = np.zeros((0,)) 50 | if "bbox_fields" not in results: 51 | results["bbox_fields"] = [] 52 | if "gt_bboxes" not in results["bbox_fields"]: 53 | results["bbox_fields"].append("gt_bboxes") 54 | if self.with_mask: 55 | num_inst = len(results["gt_bboxes"]) 56 | h, w = results["img"].shape[:2] 57 | results["gt_masks"] = TrimapMasks( 58 | [ 59 | self.fill_value * np.ones((h, w), dtype=np.uint8) 60 | for _ in range(num_inst) 61 | ], 62 | h, 63 | w, 64 | ) 65 | 66 | if "mask_fields" not in results: 67 | results["mask_fields"] = [] 68 | if "gt_masks" not in results["mask_fields"]: 69 | results["mask_fields"].append("gt_masks") 70 | if self.with_seg: 71 | results["gt_semantic_seg"] = self.fill_value * np.ones( 72 | results["img"].shape[:2], dtype=np.uint8 73 | ) 74 | if "seg_fields" not in results: 75 | results["seg_fields"] = [] 76 | if "gt_semantic_seg" not in results["seg_fields"]: 77 | results["seg_fields"].append("gt_semantic_seg") 78 | return results 79 | -------------------------------------------------------------------------------- /ssod/datasets/pipelines/geo_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Record the geometric transformation information used in the augmentation in a transformation matrix.
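A minimal sketch of how the matrix accumulates (values are hypothetical; the ops are the `_get_*_matrix` classmethods defined below). Each call to ``GeometricTransformationBase.apply`` left-multiplies the new 3x3 matrix onto the accumulated one, so ``results["transform_matrix"]`` always maps homogeneous coordinates of the original image to the current image:

>>> results = {}
>>> GeometricTransformationBase.apply(results, "shift", dx=2, dy=3)
>>> GeometricTransformationBase.apply(results, "scale", sx=0.5, sy=0.5)
>>> # now results["transform_matrix"] == scale_matrix @ shift_matrix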
3 | """ 4 | import numpy as np 5 | 6 | 7 | class GeometricTransformationBase(object): 8 | @classmethod 9 | def inverse(cls, results): 10 | # compute the inverse 11 | return results["transform_matrix"].I # 3x3 12 | 13 | @classmethod 14 | def apply(self, results, operator, **kwargs): 15 | trans_matrix = getattr(self, f"_get_{operator}_matrix")(**kwargs) 16 | if "transform_matrix" not in results: 17 | results["transform_matrix"] = trans_matrix 18 | else: 19 | base_transformation = results["transform_matrix"] 20 | results["transform_matrix"] = np.dot(trans_matrix, base_transformation) 21 | 22 | @classmethod 23 | def apply_cv2_matrix(self, results, cv2_matrix): 24 | if cv2_matrix.shape[0] == 2: 25 | mat = np.concatenate( 26 | [cv2_matrix, np.array([0, 0, 1]).reshape((1, 3))], axis=0 27 | ) 28 | else: 29 | mat = cv2_matrix 30 | base_transformation = results["transform_matrix"] 31 | results["transform_matrix"] = np.dot(mat, base_transformation) 32 | return results 33 | 34 | @classmethod 35 | def _get_rotate_matrix(cls, degree=None, cv2_rotation_matrix=None, inverse=False): 36 | # TODO: this is rotated by zero point 37 | if degree is None and cv2_rotation_matrix is None: 38 | raise ValueError( 39 | "At least one of degree or rotation matrix should be provided" 40 | ) 41 | if degree: 42 | if inverse: 43 | degree = -degree 44 | rad = degree * np.pi / 180 45 | sin_a = np.sin(rad) 46 | cos_a = np.cos(rad) 47 | return np.array([[cos_a, sin_a, 0], [-sin_a, cos_a, 0], [0, 0, 1]]) # 2x3 48 | else: 49 | mat = np.concatenate( 50 | [cv2_rotation_matrix, np.array([0, 0, 1]).reshape((1, 3))], axis=0 51 | ) 52 | if inverse: 53 | mat = mat * np.array([[1, -1, -1], [-1, 1, -1], [1, 1, 1]]) 54 | return mat 55 | 56 | @classmethod 57 | def _get_shift_matrix(cls, dx=0, dy=0, inverse=False): 58 | if inverse: 59 | dx = -dx 60 | dy = -dy 61 | return np.array([[1, 0, dx], [0, 1, dy], [0, 0, 1]]) 62 | 63 | @classmethod 64 | def _get_shear_matrix( 65 | cls, degree=None, magnitude=None, direction="horizontal", inverse=False 66 | ): 67 | if magnitude is None: 68 | assert degree is not None 69 | rad = degree * np.pi / 180 70 | magnitude = np.tan(rad) 71 | 72 | if inverse: 73 | magnitude = -magnitude 74 | if direction == "horizontal": 75 | shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0], [0, 0, 1]]) 76 | else: 77 | shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0], [0, 0, 1]]) 78 | return shear_matrix 79 | 80 | @classmethod 81 | def _get_flip_matrix(cls, shape, direction="horizontal", inverse=False): 82 | h, w = shape 83 | if direction == "horizontal": 84 | flip_matrix = np.float32([[-1, 0, w], [0, 1, 0], [0, 0, 1]]) 85 | else: 86 | flip_matrix = np.float32([[1, 0, 0], [0, h - 1, 0], [0, 0, 1]]) 87 | return flip_matrix 88 | 89 | @classmethod 90 | def _get_scale_matrix(cls, sx, sy, inverse=False): 91 | if inverse: 92 | sx = 1 / sx 93 | sy = 1 / sy 94 | return np.float32([[sx, 0, 0], [0, sy, 0], [0, 0, 1]]) 95 | -------------------------------------------------------------------------------- /ssod/datasets/pipelines/rand_aug.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/google-research/ssl_detection/blob/master/detection/utils/augmentation.py. 
3 | """ 4 | import copy 5 | 6 | import cv2 7 | import mmcv 8 | import numpy as np 9 | from PIL import Image, ImageEnhance, ImageOps 10 | from mmcv.image.colorspace import bgr2rgb, rgb2bgr 11 | from mmdet.core.mask import BitmapMasks, PolygonMasks 12 | from mmdet.datasets import PIPELINES 13 | from mmdet.datasets.pipelines import Compose as BaseCompose 14 | from mmdet.datasets.pipelines import transforms 15 | 16 | from .geo_utils import GeometricTransformationBase as GTrans 17 | 18 | PARAMETER_MAX = 10 19 | 20 | 21 | def int_parameter(level, maxval, max_level=None): 22 | if max_level is None: 23 | max_level = PARAMETER_MAX 24 | return int(level * maxval / max_level) 25 | 26 | 27 | def float_parameter(level, maxval, max_level=None): 28 | if max_level is None: 29 | max_level = PARAMETER_MAX 30 | return float(level) * maxval / max_level 31 | 32 | 33 | class RandAug(object): 34 | """refer to https://github.com/google-research/ssl_detection/blob/00d52272f 35 | 61b56eade8d5ace18213cba6c74f6d8/detection/utils/augmentation.py#L240.""" 36 | 37 | def __init__( 38 | self, 39 | prob: float = 1.0, 40 | magnitude: int = 10, 41 | random_magnitude: bool = True, 42 | record: bool = False, 43 | magnitude_limit: int = 10, 44 | ): 45 | assert 0 <= prob <= 1, f"probability should be in (0,1) but get {prob}" 46 | assert ( 47 | magnitude <= PARAMETER_MAX 48 | ), f"magnitude should be small than max value {PARAMETER_MAX} but get {magnitude}" 49 | 50 | self.prob = prob 51 | self.magnitude = magnitude 52 | self.magnitude_limit = magnitude_limit 53 | self.random_magnitude = random_magnitude 54 | self.record = record 55 | self.buffer = None 56 | 57 | def __call__(self, results): 58 | if np.random.random() < self.prob: 59 | magnitude = self.magnitude 60 | if self.random_magnitude: 61 | magnitude = np.random.randint(1, magnitude) 62 | if self.record: 63 | if "aug_info" not in results: 64 | results["aug_info"] = [] 65 | results["aug_info"].append(self.get_aug_info(magnitude=magnitude)) 66 | results = self.apply(results, magnitude) 67 | # clear buffer 68 | return results 69 | 70 | def apply(self, results, magnitude: int = None): 71 | raise NotImplementedError() 72 | 73 | def __repr__(self): 74 | return f"{self.__class__.__name__}(prob={self.prob},magnitude={self.magnitude},max_magnitude={self.magnitude_limit},random_magnitude={self.random_magnitude})" 75 | 76 | def get_aug_info(self, **kwargs): 77 | aug_info = dict(type=self.__class__.__name__) 78 | aug_info.update( 79 | dict( 80 | prob=1.0, 81 | random_magnitude=False, 82 | record=False, 83 | magnitude=self.magnitude, 84 | ) 85 | ) 86 | aug_info.update(kwargs) 87 | return aug_info 88 | 89 | def enable_record(self, mode: bool = True): 90 | self.record = mode 91 | 92 | 93 | @PIPELINES.register_module() 94 | class Identity(RandAug): 95 | def apply(self, results, magnitude: int = None): 96 | return results 97 | 98 | 99 | @PIPELINES.register_module() 100 | class AutoContrast(RandAug): 101 | def apply(self, results, magnitude=None): 102 | for key in results.get("img_fields", ["img"]): 103 | img = bgr2rgb(results[key]) 104 | results[key] = rgb2bgr( 105 | np.asarray(ImageOps.autocontrast(Image.fromarray(img)), dtype=img.dtype) 106 | ) 107 | return results 108 | 109 | 110 | @PIPELINES.register_module() 111 | class RandEqualize(RandAug): 112 | def apply(self, results, magnitude=None): 113 | for key in results.get("img_fields", ["img"]): 114 | img = bgr2rgb(results[key]) 115 | results[key] = rgb2bgr( 116 | np.asarray(ImageOps.equalize(Image.fromarray(img)), dtype=img.dtype) 117 | ) 
118 | return results 119 | 120 | 121 | @PIPELINES.register_module() 122 | class RandSolarize(RandAug): 123 | def apply(self, results, magnitude=None): 124 | for key in results.get("img_fields", ["img"]): 125 | img = results[key] 126 | results[key] = mmcv.solarize( 127 | img, min(int_parameter(magnitude, 256, self.magnitude_limit), 255) 128 | ) 129 | return results 130 | 131 | 132 | def _enhancer_impl(enhancer): 133 | """Sets level to be between 0.1 and 1.8 for ImageEnhance transforms of 134 | PIL.""" 135 | 136 | def impl(pil_img, level, max_level=None): 137 | v = float_parameter(level, 1.8, max_level) + 0.1 # going to 0 just destroys it 138 | return enhancer(pil_img).enhance(v) 139 | 140 | return impl 141 | 142 | 143 | class RandEnhance(RandAug): 144 | op = None 145 | 146 | def apply(self, results, magnitude=None): 147 | for key in results.get("img_fields", ["img"]): 148 | img = bgr2rgb(results[key]) 149 | 150 | results[key] = rgb2bgr( 151 | np.asarray( 152 | _enhancer_impl(self.op)( 153 | Image.fromarray(img), magnitude, self.magnitude_limit 154 | ), 155 | dtype=img.dtype, 156 | ) 157 | ) 158 | return results 159 | 160 | 161 | @PIPELINES.register_module() 162 | class RandColor(RandEnhance): 163 | op = ImageEnhance.Color 164 | 165 | 166 | @PIPELINES.register_module() 167 | class RandContrast(RandEnhance): 168 | op = ImageEnhance.Contrast 169 | 170 | 171 | @PIPELINES.register_module() 172 | class RandBrightness(RandEnhance): 173 | op = ImageEnhance.Brightness 174 | 175 | 176 | @PIPELINES.register_module() 177 | class RandSharpness(RandEnhance): 178 | op = ImageEnhance.Sharpness 179 | 180 | 181 | @PIPELINES.register_module() 182 | class RandPosterize(RandAug): 183 | def apply(self, results, magnitude=None): 184 | for key in results.get("img_fields", ["img"]): 185 | img = bgr2rgb(results[key]) 186 | magnitude = int_parameter(magnitude, 4, self.magnitude_limit) 187 | results[key] = rgb2bgr( 188 | np.asarray( 189 | ImageOps.posterize(Image.fromarray(img), 4 - magnitude), 190 | dtype=img.dtype, 191 | ) 192 | ) 193 | return results 194 | 195 | 196 | @PIPELINES.register_module() 197 | class Sequential(BaseCompose): 198 | def __init__(self, transforms, record: bool = False): 199 | super().__init__(transforms) 200 | self.record = record 201 | self.enable_record(record) 202 | 203 | def enable_record(self, mode: bool = True): 204 | # enable children to record 205 | self.record = mode 206 | for transform in self.transforms: 207 | transform.enable_record(mode) 208 | 209 | 210 | @PIPELINES.register_module() 211 | class OneOf(Sequential): 212 | def __init__(self, transforms, record: bool = False): 213 | self.transforms = [] 214 | for trans in transforms: 215 | if isinstance(trans, list): 216 | self.transforms.append(Sequential(trans)) 217 | else: 218 | assert isinstance(trans, dict) 219 | self.transforms.append(Sequential([trans])) 220 | self.enable_record(record) 221 | 222 | def __call__(self, results): 223 | transform = np.random.choice(self.transforms) 224 | return transform(results) 225 | 226 | 227 | @PIPELINES.register_module() 228 | class ShuffledSequential(Sequential): 229 | def __call__(self, data): 230 | order = np.random.permutation(len(self.transforms)) 231 | for idx in order: 232 | t = self.transforms[idx] 233 | data = t(data) 234 | if data is None: 235 | return None 236 | return data 237 | 238 | 239 | """ 240 | Geometric Augmentation. 
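(A hedged illustration with assumed values, not taken from the repo's configs: each op below accepts either a fixed magnitude or a ``(min, max)`` range, e.g. ``dict(type="RandTranslate", x=(-0.1, 0.1), record=True)`` samples a horizontal shift ratio uniformly from that range and, when ``record=True``, appends a replayable op to ``results["aug_info"]`` and accumulates the shift into ``results["transform_matrix"]``.)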
Modified from thirdparty/mmdetection/mmdet/datasets/pipelines/auto_augment.py 241 | """ 242 | 243 | 244 | def bbox2fields(): 245 | """The key correspondence from bboxes to labels, masks and 246 | segmentations.""" 247 | bbox2label = {"gt_bboxes": "gt_labels", "gt_bboxes_ignore": "gt_labels_ignore"} 248 | bbox2mask = {"gt_bboxes": "gt_masks", "gt_bboxes_ignore": "gt_masks_ignore"} 249 | bbox2seg = { 250 | "gt_bboxes": "gt_semantic_seg", 251 | } 252 | return bbox2label, bbox2mask, bbox2seg 253 | 254 | 255 | class GeometricAugmentation(object): 256 | def __init__( 257 | self, 258 | img_fill_val=125, 259 | seg_ignore_label=255, 260 | min_size=0, 261 | prob: float = 1.0, 262 | random_magnitude: bool = True, 263 | record: bool = False, 264 | ): 265 | if isinstance(img_fill_val, (float, int)): 266 | img_fill_val = tuple([float(img_fill_val)] * 3) 267 | elif isinstance(img_fill_val, tuple): 268 | assert len(img_fill_val) == 3, "img_fill_val as tuple must have 3 elements." 269 | img_fill_val = tuple([float(val) for val in img_fill_val]) 270 | assert np.all( 271 | [0 <= val <= 255 for val in img_fill_val] 272 | ), "all elements of img_fill_val should be in the range [0, 255]." 273 | self.img_fill_val = img_fill_val 274 | self.seg_ignore_label = seg_ignore_label 275 | self.min_size = min_size 276 | self.prob = prob 277 | self.random_magnitude = random_magnitude 278 | self.record = record 279 | 280 | def __call__(self, results): 281 | if np.random.random() < self.prob: 282 | magnitude: dict = self.get_magnitude(results) 283 | if self.record: 284 | if "aug_info" not in results: 285 | results["aug_info"] = [] 286 | results["aug_info"].append(self.get_aug_info(**magnitude)) 287 | results = self.apply(results, **magnitude) 288 | self._filter_invalid(results, min_size=self.min_size) 289 | return results 290 | 291 | def get_magnitude(self, results) -> dict: 292 | raise NotImplementedError() 293 | 294 | def apply(self, results, **kwargs): 295 | raise NotImplementedError() 296 | 297 | def enable_record(self, mode: bool = True): 298 | self.record = mode 299 | 300 | def get_aug_info(self, **kwargs): 301 | aug_info = dict(type=self.__class__.__name__) 302 | aug_info.update( 303 | dict( 304 | # make op deterministic 305 | prob=1.0, 306 | random_magnitude=False, 307 | record=False, 308 | img_fill_val=self.img_fill_val, 309 | seg_ignore_label=self.seg_ignore_label, 310 | min_size=self.min_size, 311 | ) 312 | ) 313 | aug_info.update(kwargs) 314 | return aug_info 315 | 316 | def _filter_invalid(self, results, min_size=0): 317 | """Filter bboxes and masks too small or translated out of image.""" 318 | if min_size is None: 319 | return results 320 | bbox2label, bbox2mask, _ = bbox2fields() 321 | for key in results.get("bbox_fields", []): 322 | bbox_w = results[key][:, 2] - results[key][:, 0] 323 | bbox_h = results[key][:, 3] - results[key][:, 1] 324 | valid_inds = (bbox_w > min_size) & (bbox_h > min_size) 325 | valid_inds = np.nonzero(valid_inds)[0] 326 | results[key] = results[key][valid_inds] 327 | # label fields. e.g. gt_labels and gt_labels_ignore 328 | label_key = bbox2label.get(key) 329 | if label_key in results: 330 | results[label_key] = results[label_key][valid_inds] 331 | # mask fields, e.g.
gt_masks and gt_masks_ignore 332 | mask_key = bbox2mask.get(key) 333 | if mask_key in results: 334 | results[mask_key] = results[mask_key][valid_inds] 335 | return results 336 | 337 | def __repr__(self): 338 | return f"""{self.__class__.__name__}( 339 | img_fill_val={self.img_fill_val}, 340 | seg_ignore_label={self.seg_ignore_label}, 341 | min_size={self.min_size}, 342 | prob: float = {self.prob}, 343 | random_magnitude: bool = {self.random_magnitude}, 344 | )""" 345 | 346 | 347 | @PIPELINES.register_module() 348 | class RandTranslate(GeometricAugmentation): 349 | def __init__(self, x=None, y=None, **kwargs): 350 | super().__init__(**kwargs) 351 | self.x = x 352 | self.y = y 353 | if self.x is None and self.y is None: 354 | self.prob = 0.0 355 | 356 | def get_magnitude(self, results): 357 | magnitude = {} 358 | if self.random_magnitude: 359 | if isinstance(self.x, (list, tuple)): 360 | assert len(self.x) == 2 361 | x = np.random.random() * (self.x[1] - self.x[0]) + self.x[0] 362 | magnitude["x"] = x 363 | if isinstance(self.y, (list, tuple)): 364 | assert len(self.y) == 2 365 | y = np.random.random() * (self.y[1] - self.y[0]) + self.y[0] 366 | magnitude["y"] = y 367 | else: 368 | if self.x is not None: 369 | assert isinstance(self.x, (int, float)) 370 | magnitude["x"] = self.x 371 | if self.y is not None: 372 | assert isinstance(self.y, (int, float)) 373 | magnitude["y"] = self.y 374 | return magnitude 375 | 376 | def apply(self, results, x=None, y=None): 377 | # ratio to pixel 378 | h, w, c = results["img_shape"] 379 | if x is not None: 380 | x = w * x 381 | if y is not None: 382 | y = h * y 383 | if x is not None: 384 | # translate horizontally 385 | self._translate(results, x) 386 | if y is not None: 387 | # translate vertically 388 | self._translate(results, y, direction="vertical") 389 | return results 390 | 391 | def _translate(self, results, offset, direction="horizontal"): 392 | if self.record: 393 | GTrans.apply( 394 | results, 395 | "shift", 396 | dx=offset if direction == "horizontal" else 0, 397 | dy=offset if direction == "vertical" else 0, 398 | ) 399 | self._translate_img(results, offset, direction=direction) 400 | self._translate_bboxes(results, offset, direction=direction) 401 | # fill_val defaults to 0 for BitmapMasks and None for PolygonMasks. 402 | self._translate_masks(results, offset, direction=direction) 403 | self._translate_seg( 404 | results, offset, fill_val=self.seg_ignore_label, direction=direction 405 | ) 406 | 407 | def _translate_img(self, results, offset, direction="horizontal"): 408 | for key in results.get("img_fields", ["img"]): 409 | img = results[key].copy() 410 | results[key] = mmcv.imtranslate( 411 | img, offset, direction, self.img_fill_val 412 | ).astype(img.dtype) 413 | 414 | def _translate_bboxes(self, results, offset, direction="horizontal"): 415 | """Shift bboxes horizontally or vertically, according to offset.""" 416 | h, w, c = results["img_shape"] 417 | for key in results.get("bbox_fields", []): 418 | min_x, min_y, max_x, max_y = np.split( 419 | results[key], results[key].shape[-1], axis=-1 420 | ) 421 | if direction == "horizontal": 422 | min_x = np.maximum(0, min_x + offset) 423 | max_x = np.minimum(w, max_x + offset) 424 | elif direction == "vertical": 425 | min_y = np.maximum(0, min_y + offset) 426 | max_y = np.minimum(h, max_y + offset) 427 | 428 | # boxes translated outside of the image will be filtered out along with 429 | # the corresponding masks, by invoking ``_filter_invalid``.
430 | results[key] = np.concatenate([min_x, min_y, max_x, max_y], axis=-1) 431 | 432 | def _translate_masks(self, results, offset, direction="horizontal", fill_val=0): 433 | """Translate masks horizontally or vertically.""" 434 | h, w, c = results["img_shape"] 435 | for key in results.get("mask_fields", []): 436 | masks = results[key] 437 | results[key] = masks.translate((h, w), offset, direction, fill_val) 438 | 439 | def _translate_seg(self, results, offset, direction="horizontal", fill_val=255): 440 | """Translate segmentation maps horizontally or vertically.""" 441 | for key in results.get("seg_fields", []): 442 | seg = results[key].copy() 443 | results[key] = mmcv.imtranslate(seg, offset, direction, fill_val).astype( 444 | seg.dtype 445 | ) 446 | 447 | def __repr__(self): 448 | repr_str = super().__repr__() 449 | return ("\n").join( 450 | repr_str.split("\n")[:-1] 451 | + [f"x={self.x}", f"y={self.y}"] 452 | + repr_str.split("\n")[-1:] 453 | ) 454 | 455 | 456 | @PIPELINES.register_module() 457 | class RandRotate(GeometricAugmentation): 458 | def __init__(self, angle=None, center=None, scale=1, **kwargs): 459 | super().__init__(**kwargs) 460 | self.angle = angle 461 | self.center = center 462 | self.scale = scale 463 | if self.angle is None: 464 | self.prob = 0.0 465 | 466 | def get_magnitude(self, results): 467 | magnitude = {} 468 | if self.random_magnitude: 469 | if isinstance(self.angle, (list, tuple)): 470 | assert len(self.angle) == 2 471 | angle = ( 472 | np.random.random() * (self.angle[1] - self.angle[0]) + self.angle[0] 473 | ) 474 | magnitude["angle"] = angle 475 | else: 476 | if self.angle is not None: 477 | assert isinstance(self.angle, (int, float)) 478 | magnitude["angle"] = self.angle 479 | 480 | return magnitude 481 | 482 | def apply(self, results, angle: float = None): 483 | h, w = results["img"].shape[:2] 484 | center = self.center 485 | if center is None: 486 | center = ((w - 1) * 0.5, (h - 1) * 0.5) 487 | self._rotate_img(results, angle, center, self.scale) 488 | rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale) 489 | if self.record: 490 | GTrans.apply(results, "rotate", cv2_rotation_matrix=rotate_matrix) 491 | self._rotate_bboxes(results, rotate_matrix) 492 | self._rotate_masks(results, angle, center, self.scale, fill_val=0) 493 | self._rotate_seg( 494 | results, angle, center, self.scale, fill_val=self.seg_ignore_label 495 | ) 496 | return results 497 | 498 | def _rotate_img(self, results, angle, center=None, scale=1.0): 499 | """Rotate the image. 500 | 501 | Args: 502 | results (dict): Result dict from loading pipeline. 503 | angle (float): Rotation angle in degrees, positive values 504 | mean clockwise rotation. Same in ``mmcv.imrotate``. 505 | center (tuple[float], optional): Center point (w, h) of the 506 | rotation. Same in ``mmcv.imrotate``. 507 | scale (int | float): Isotropic scale factor. Same in 508 | ``mmcv.imrotate``. 
509 | """ 510 | for key in results.get("img_fields", ["img"]): 511 | img = results[key].copy() 512 | img_rotated = mmcv.imrotate( 513 | img, angle, center, scale, border_value=self.img_fill_val 514 | ) 515 | results[key] = img_rotated.astype(img.dtype) 516 | 517 | def _rotate_bboxes(self, results, rotate_matrix): 518 | """Rotate the bboxes.""" 519 | h, w, c = results["img_shape"] 520 | for key in results.get("bbox_fields", []): 521 | min_x, min_y, max_x, max_y = np.split( 522 | results[key], results[key].shape[-1], axis=-1 523 | ) 524 | coordinates = np.stack( 525 | [[min_x, min_y], [max_x, min_y], [min_x, max_y], [max_x, max_y]] 526 | ) # [4, 2, nb_bbox, 1] 527 | # pad 1 to convert from format [x, y] to homogeneous 528 | # coordinates format [x, y, 1] 529 | coordinates = np.concatenate( 530 | ( 531 | coordinates, 532 | np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype), 533 | ), 534 | axis=1, 535 | ) # [4, 3, nb_bbox, 1] 536 | coordinates = coordinates.transpose((2, 0, 1, 3)) # [nb_bbox, 4, 3, 1] 537 | rotated_coords = np.matmul(rotate_matrix, coordinates) # [nb_bbox, 4, 2, 1] 538 | rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2] 539 | min_x, min_y = ( 540 | np.min(rotated_coords[:, :, 0], axis=1), 541 | np.min(rotated_coords[:, :, 1], axis=1), 542 | ) 543 | max_x, max_y = ( 544 | np.max(rotated_coords[:, :, 0], axis=1), 545 | np.max(rotated_coords[:, :, 1], axis=1), 546 | ) 547 | min_x, min_y = ( 548 | np.clip(min_x, a_min=0, a_max=w), 549 | np.clip(min_y, a_min=0, a_max=h), 550 | ) 551 | max_x, max_y = ( 552 | np.clip(max_x, a_min=min_x, a_max=w), 553 | np.clip(max_y, a_min=min_y, a_max=h), 554 | ) 555 | results[key] = np.stack([min_x, min_y, max_x, max_y], axis=-1).astype( 556 | results[key].dtype 557 | ) 558 | 559 | def _rotate_masks(self, results, angle, center=None, scale=1.0, fill_val=0): 560 | """Rotate the masks.""" 561 | h, w, c = results["img_shape"] 562 | for key in results.get("mask_fields", []): 563 | masks = results[key] 564 | results[key] = masks.rotate((h, w), angle, center, scale, fill_val) 565 | 566 | def _rotate_seg(self, results, angle, center=None, scale=1.0, fill_val=255): 567 | """Rotate the segmentation map.""" 568 | for key in results.get("seg_fields", []): 569 | seg = results[key].copy() 570 | results[key] = mmcv.imrotate( 571 | seg, angle, center, scale, border_value=fill_val 572 | ).astype(seg.dtype) 573 | 574 | def __repr__(self): 575 | repr_str = super().__repr__() 576 | return ("\n").join( 577 | repr_str.split("\n")[:-1] 578 | + [f"angle={self.angle}", f"center={self.center}", f"scale={self.scale}"] 579 | + repr_str.split("\n")[-1:] 580 | ) 581 | 582 | 583 | @PIPELINES.register_module() 584 | class RandShear(GeometricAugmentation): 585 | def __init__(self, x=None, y=None, interpolation="bilinear", **kwargs): 586 | super().__init__(**kwargs) 587 | self.x = x 588 | self.y = y 589 | self.interpolation = interpolation 590 | if self.x is None and self.y is None: 591 | self.prob = 0.0 592 | 593 | def get_magnitude(self, results): 594 | magnitude = {} 595 | if self.random_magnitude: 596 | if isinstance(self.x, (list, tuple)): 597 | assert len(self.x) == 2 598 | x = np.random.random() * (self.x[1] - self.x[0]) + self.x[0] 599 | magnitude["x"] = x 600 | if isinstance(self.y, (list, tuple)): 601 | assert len(self.y) == 2 602 | y = np.random.random() * (self.y[1] - self.y[0]) + self.y[0] 603 | magnitude["y"] = y 604 | else: 605 | if self.x is not None: 606 | assert isinstance(self.x, (int, float)) 607 | magnitude["x"] = self.x 608 | if self.y is not 
None: 609 | assert isinstance(self.y, (int, float)) 610 | magnitude["y"] = self.y 611 | return magnitude 612 | 613 | def apply(self, results, x=None, y=None): 614 | if x is not None: 615 | # shear horizontally 616 | self._shear(results, np.tanh(-x * np.pi / 180)) 617 | if y is not None: 618 | # shear vertically 619 | self._shear(results, np.tanh(y * np.pi / 180), direction="vertical") 620 | return results 621 | 622 | def _shear(self, results, magnitude, direction="horizontal"): 623 | if self.record: 624 | GTrans.apply(results, "shear", magnitude=magnitude, direction=direction) 625 | self._shear_img(results, magnitude, direction, interpolation=self.interpolation) 626 | self._shear_bboxes(results, magnitude, direction=direction) 627 | # fill_val defaults to 0 for BitmapMasks and None for PolygonMasks. 628 | self._shear_masks( 629 | results, magnitude, direction=direction, interpolation=self.interpolation 630 | ) 631 | self._shear_seg( 632 | results, 633 | magnitude, 634 | direction=direction, 635 | interpolation=self.interpolation, 636 | fill_val=self.seg_ignore_label, 637 | ) 638 | 639 | def _shear_img( 640 | self, results, magnitude, direction="horizontal", interpolation="bilinear" 641 | ): 642 | """Shear the image. 643 | 644 | Args: 645 | results (dict): Result dict from loading pipeline. 646 | magnitude (int | float): The magnitude used for shear. 647 | direction (str): The direction for shear, either "horizontal" 648 | or "vertical". 649 | interpolation (str): Same as in :func:`mmcv.imshear`. 650 | """ 651 | for key in results.get("img_fields", ["img"]): 652 | img = results[key] 653 | img_sheared = mmcv.imshear( 654 | img, 655 | magnitude, 656 | direction, 657 | border_value=self.img_fill_val, 658 | interpolation=interpolation, 659 | ) 660 | results[key] = img_sheared.astype(img.dtype) 661 | 662 | def _shear_bboxes(self, results, magnitude, direction="horizontal"): 663 | """Shear the bboxes.""" 664 | h, w, c = results["img_shape"] 665 | if direction == "horizontal": 666 | shear_matrix = np.stack([[1, magnitude], [0, 1]]).astype( 667 | np.float32 668 | ) # [2, 2] 669 | else: 670 | shear_matrix = np.stack([[1, 0], [magnitude, 1]]).astype(np.float32) 671 | for key in results.get("bbox_fields", []): 672 | min_x, min_y, max_x, max_y = np.split( 673 | results[key], results[key].shape[-1], axis=-1 674 | ) 675 | coordinates = np.stack( 676 | [[min_x, min_y], [max_x, min_y], [min_x, max_y], [max_x, max_y]] 677 | ) # [4, 2, nb_box, 1] 678 | coordinates = ( 679 | coordinates[..., 0].transpose((2, 1, 0)).astype(np.float32) 680 | ) # [nb_box, 2, 4] 681 | new_coords = np.matmul( 682 | shear_matrix[None, :, :], coordinates 683 | ) # [nb_box, 2, 4] 684 | min_x = np.min(new_coords[:, 0, :], axis=-1) 685 | min_y = np.min(new_coords[:, 1, :], axis=-1) 686 | max_x = np.max(new_coords[:, 0, :], axis=-1) 687 | max_y = np.max(new_coords[:, 1, :], axis=-1) 688 | min_x = np.clip(min_x, a_min=0, a_max=w) 689 | min_y = np.clip(min_y, a_min=0, a_max=h) 690 | max_x = np.clip(max_x, a_min=min_x, a_max=w) 691 | max_y = np.clip(max_y, a_min=min_y, a_max=h) 692 | results[key] = np.stack([min_x, min_y, max_x, max_y], axis=-1).astype( 693 | results[key].dtype 694 | ) 695 | 696 | def _shear_masks( 697 | self, 698 | results, 699 | magnitude, 700 | direction="horizontal", 701 | fill_val=0, 702 | interpolation="bilinear", 703 | ): 704 | """Shear the masks.""" 705 | h, w, c = results["img_shape"] 706 | for key in results.get("mask_fields", []): 707 | masks = results[key] 708 | results[key] = masks.shear( 709 | (h, w),
710 | magnitude, 711 | direction, 712 | border_value=fill_val, 713 | interpolation=interpolation, 714 | ) 715 | 716 | def _shear_seg( 717 | self, 718 | results, 719 | magnitude, 720 | direction="horizontal", 721 | fill_val=255, 722 | interpolation="bilinear", 723 | ): 724 | """Shear the segmentation maps.""" 725 | for key in results.get("seg_fields", []): 726 | seg = results[key] 727 | results[key] = mmcv.imshear( 728 | seg, 729 | magnitude, 730 | direction, 731 | border_value=fill_val, 732 | interpolation=interpolation, 733 | ).astype(seg.dtype) 734 | 735 | def __repr__(self): 736 | repr_str = super().__repr__() 737 | return ("\n").join( 738 | repr_str.split("\n")[:-1] 739 | + [f"x_magnitude={self.x}", f"y_magnitude={self.y}"] 740 | + repr_str.split("\n")[-1:] 741 | ) 742 | 743 | 744 | @PIPELINES.register_module() 745 | class RandErase(GeometricAugmentation): 746 | def __init__( 747 | self, 748 | n_iterations=None, 749 | size=None, 750 | squared: bool = True, 751 | patches=None, 752 | **kwargs, 753 | ): 754 | kwargs.update(min_size=None) 755 | super().__init__(**kwargs) 756 | self.n_iterations = n_iterations 757 | self.size = size 758 | self.squared = squared 759 | self.patches = patches 760 | 761 | def get_magnitude(self, results): 762 | magnitude = {} 763 | if self.random_magnitude: 764 | n_iterations = self._get_erase_cycle() 765 | patches = [] 766 | h, w, c = results["img_shape"] 767 | for i in range(n_iterations): 768 | # random sample patch size in the image 769 | ph, pw = self._get_patch_size(h, w) 770 | # random sample patch left top in the image 771 | px, py = np.random.randint(0, w - pw), np.random.randint(0, h - ph) 772 | patches.append([px, py, px + pw, py + ph]) 773 | magnitude["patches"] = patches 774 | else: 775 | assert self.patches is not None 776 | magnitude["patches"] = self.patches 777 | 778 | return magnitude 779 | 780 | def _get_erase_cycle(self): 781 | if isinstance(self.n_iterations, int): 782 | n_iterations = self.n_iterations 783 | else: 784 | assert ( 785 | isinstance(self.n_iterations, (tuple, list)) 786 | and len(self.n_iterations) == 2 787 | ) 788 | n_iterations = np.random.randint(*self.n_iterations) 789 | return n_iterations 790 | 791 | def _get_patch_size(self, h, w): 792 | if isinstance(self.size, float): 793 | assert 0 < self.size < 1 794 | return int(self.size * h), int(self.size * w) 795 | else: 796 | assert isinstance(self.size, (tuple, list)) 797 | assert len(self.size) == 2 798 | assert 0 <= self.size[0] < 1 and 0 <= self.size[1] < 1 799 | w_ratio = np.random.random() * (self.size[1] - self.size[0]) + self.size[0] 800 | h_ratio = w_ratio 801 | 802 | if not self.squared: 803 | h_ratio = ( 804 | np.random.random() * (self.size[1] - self.size[0]) + self.size[0] 805 | ) 806 | return int(h_ratio * h), int(w_ratio * w) 807 | 808 | def apply(self, results, patches: list): 809 | for patch in patches: 810 | self._erase_image(results, patch, fill_val=self.img_fill_val) 811 | self._erase_mask(results, patch) 812 | self._erase_seg(results, patch, fill_val=self.seg_ignore_label) 813 | return results 814 | 815 | def _erase_image(self, results, patch, fill_val=128): 816 | for key in results.get("img_fields", ["img"]): 817 | tmp = results[key].copy() 818 | x1, y1, x2, y2 = patch 819 | tmp[y1:y2, x1:x2, :] = fill_val 820 | results[key] = tmp 821 | 822 | def _erase_mask(self, results, patch, fill_val=0): 823 | for key in results.get("mask_fields", []): 824 | masks = results[key] 825 | if isinstance(masks, PolygonMasks): 826 | # convert mask to bitmask 827 | masks = 
masks.to_bitmap() 828 | x1, y1, x2, y2 = patch 829 | tmp = masks.masks.copy() 830 | tmp[:, y1:y2, x1:x2] = fill_val 831 | masks = BitmapMasks(tmp, masks.height, masks.width) 832 | results[key] = masks 833 | 834 | def _erase_seg(self, results, patch, fill_val=0): 835 | for key in results.get("seg_fields", []): 836 | seg = results[key].copy() 837 | x1, y1, x2, y2 = patch 838 | seg[y1:y2, x1:x2] = fill_val 839 | results[key] = seg 840 | 841 | 842 | @PIPELINES.register_module() 843 | class RecomputeBox(object): 844 | def __init__(self, record=False): 845 | self.record = record 846 | 847 | def __call__(self, results): 848 | if self.record: 849 | if "aug_info" not in results: 850 | results["aug_info"] = [] 851 | results["aug_info"].append(dict(type="RecomputeBox")) 852 | _, bbox2mask, _ = bbox2fields() 853 | for key in results.get("bbox_fields", []): 854 | mask_key = bbox2mask.get(key) 855 | if mask_key in results: 856 | masks = results[mask_key] 857 | results[key] = self._recompute_bbox(masks) 858 | return results 859 | 860 | def enable_record(self, mode: bool = True): 861 | self.record = mode 862 | 863 | def _recompute_bbox(self, masks): 864 | boxes = np.zeros((masks.masks.shape[0], 4), dtype=np.float32) 865 | x_any = np.any(masks.masks, axis=1) 866 | y_any = np.any(masks.masks, axis=2) 867 | for idx in range(masks.masks.shape[0]): 868 | x = np.where(x_any[idx, :])[0] 869 | y = np.where(y_any[idx, :])[0] 870 | if len(x) > 0 and len(y) > 0: 871 | boxes[idx, :] = np.array( 872 | [x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=np.float32 873 | ) 874 | return boxes 875 | 876 | 877 | # TODO: Implement Augmentation Inside Box 878 | 879 | 880 | @PIPELINES.register_module() 881 | class RandResize(transforms.Resize): 882 | def __init__(self, record=False, **kwargs): 883 | super().__init__(**kwargs) 884 | self.record = record 885 | 886 | def __call__(self, results): 887 | results = super().__call__(results) 888 | if self.record: 889 | scale_factor = results["scale_factor"] 890 | GTrans.apply(results, "scale", sx=scale_factor[0], sy=scale_factor[1]) 891 | 892 | if "aug_info" not in results: 893 | results["aug_info"] = [] 894 | new_h, new_w = results["img"].shape[:2] 895 | results["aug_info"].append( 896 | dict( 897 | type=self.__class__.__name__, 898 | record=False, 899 | img_scale=(new_w, new_h), 900 | keep_ratio=False, 901 | bbox_clip_border=self.bbox_clip_border, 902 | backend=self.backend, 903 | ) 904 | ) 905 | return results 906 | 907 | def enable_record(self, mode: bool = True): 908 | self.record = mode 909 | 910 | 911 | @PIPELINES.register_module() 912 | class RandFlip(transforms.RandomFlip): 913 | def __init__(self, record=False, **kwargs): 914 | super().__init__(**kwargs) 915 | self.record = record 916 | 917 | def __call__(self, results): 918 | results = super().__call__(results) 919 | if self.record: 920 | if "aug_info" not in results: 921 | results["aug_info"] = [] 922 | if results["flip"]: 923 | GTrans.apply( 924 | results, 925 | "flip", 926 | direction=results["flip_direction"], 927 | shape=results["img_shape"][:2], 928 | ) 929 | results["aug_info"].append( 930 | dict( 931 | type=self.__class__.__name__, 932 | record=False, 933 | flip_ratio=1.0, 934 | direction=results["flip_direction"], 935 | ) 936 | ) 937 | else: 938 | results["aug_info"].append( 939 | dict( 940 | type=self.__class__.__name__, 941 | record=False, 942 | flip_ratio=0.0, 943 | direction="vertical", 944 | ) 945 | ) 946 | return results 947 | 948 | def enable_record(self, mode: bool = True): 949 | self.record = mode 950 | 951 | 952 |
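# A hedged usage sketch (branch names and inner transforms are illustrative
# assumptions, not copied from the repo's configs): ``MultiBranch`` below builds one
# ``Compose`` per keyword argument and returns one transformed copy of the sample per
# branch, which is how a weakly augmented teacher view and a strongly augmented
# student view of the same unlabeled image can be produced:
#
#   unsup_branches = dict(
#       type="MultiBranch",
#       unsup_teacher=[dict(type="RandFlip", flip_ratio=0.5, record=True)],
#       unsup_student=[
#           dict(type="RandFlip", flip_ratio=0.5, record=True),
#           dict(type="RandSolarize", prob=0.25, magnitude=6, record=True),
#       ],
#   )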
@PIPELINES.register_module() 953 | class MultiBranch(object): 954 | def __init__(self, **transform_group): 955 | self.transform_group = {k: BaseCompose(v) for k, v in transform_group.items()} 956 | 957 | def __call__(self, results): 958 | multi_results = [] 959 | for k, v in self.transform_group.items(): 960 | res = v(copy.deepcopy(results)) 961 | if res is None: 962 | return None 963 | # res["img_metas"]["tag"] = k 964 | multi_results.append(res) 965 | return multi_results 966 | -------------------------------------------------------------------------------- /ssod/datasets/pseudo_coco.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | 4 | from mmdet.datasets import DATASETS, CocoDataset 5 | from mmdet.datasets.api_wrappers import COCO 6 | 7 | 8 | @DATASETS.register_module() 9 | class PseudoCocoDataset(CocoDataset): 10 | def __init__( 11 | self, 12 | ann_file, 13 | pseudo_ann_file, 14 | pipeline, 15 | confidence_threshold=0.9, 16 | classes=None, 17 | data_root=None, 18 | img_prefix="", 19 | seg_prefix=None, 20 | proposal_file=None, 21 | test_mode=False, 22 | filter_empty_gt=True, 23 | ): 24 | self.confidence_threshold = confidence_threshold 25 | self.pseudo_ann_file = pseudo_ann_file 26 | 27 | super().__init__( 28 | ann_file, 29 | pipeline, 30 | classes, 31 | data_root, 32 | img_prefix, 33 | seg_prefix, 34 | proposal_file, 35 | test_mode=test_mode, 36 | filter_empty_gt=filter_empty_gt, 37 | ) 38 | 39 | def load_pesudo_targets(self, pseudo_ann_file): 40 | with open(pseudo_ann_file) as f: 41 | pesudo_anns = json.load(f) 42 | print(f"loading {len(pesudo_anns)} results") 43 | 44 | def _add_attr(dict_terms, **kwargs): 45 | new_dict = copy.copy(dict_terms) 46 | new_dict.update(**kwargs) 47 | return new_dict 48 | 49 | def _compute_area(bbox): 50 | _, _, w, h = bbox 51 | return w * h 52 | 53 | pesudo_anns = [ 54 | _add_attr(ann, id=i, area=_compute_area(ann["bbox"])) 55 | for i, ann in enumerate(pesudo_anns) 56 | if ann["score"] > self.confidence_threshold 57 | ] 58 | print( 59 | f"With {len(pesudo_anns)} results over threshold {self.confidence_threshold}" 60 | ) 61 | 62 | return pesudo_anns 63 | 64 | def load_annotations(self, ann_file): 65 | """Load annotation from COCO style annotation file. 66 | 67 | Args: 68 | ann_file (str): Path of annotation file. 69 | Returns: 70 | list[dict]: Annotation info from COCO api. 
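Note (an illustrative sketch; the exact file is produced elsewhere): each entry of ``pseudo_ann_file`` is expected to be a COCO-style detection result such as ``{"image_id": 42, "category_id": 3, "bbox": [x, y, w, h], "score": 0.97}``; ``load_pesudo_targets`` keeps only entries whose ``score`` is above ``confidence_threshold`` and fills in ``id`` and ``area`` (``w * h``) itself.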
71 | """ 72 | pesudo_anns = self.load_pesudo_targets(self.pseudo_ann_file) 73 | self.coco = COCO(ann_file) 74 | self.coco.dataset["annotations"] = pesudo_anns 75 | self.coco.createIndex() 76 | 77 | self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES) 78 | self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} 79 | self.img_ids = self.coco.get_img_ids() 80 | data_infos = [] 81 | for i in self.img_ids: 82 | info = self.coco.load_imgs([i])[0] 83 | info["filename"] = info["file_name"] 84 | data_infos.append(info) 85 | 86 | return data_infos 87 | -------------------------------------------------------------------------------- /ssod/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .semi_sampler import DistributedGroupSemiBalanceSampler 2 | __all__ = [ 3 | "DistributedGroupSemiBalanceSampler", 4 | ] 5 | -------------------------------------------------------------------------------- /ssod/datasets/samplers/semi_sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import numpy as np 4 | import torch 5 | from mmcv.runner import get_dist_info 6 | from torch.utils.data import Sampler, WeightedRandomSampler 7 | 8 | from ..builder import SAMPLERS 9 | 10 | 11 | @SAMPLERS.register_module() 12 | class DistributedGroupSemiBalanceSampler(Sampler): 13 | def __init__( 14 | self, 15 | dataset, 16 | by_prob=False, 17 | epoch_length=7330, 18 | sample_ratio=None, 19 | samples_per_gpu=1, 20 | num_replicas=None, 21 | rank=None, 22 | **kwargs 23 | ): 24 | # check to avoid some problem 25 | assert samples_per_gpu > 1, "samples_per_gpu should be greater than 1." 26 | _rank, _num_replicas = get_dist_info() 27 | if num_replicas is None: 28 | num_replicas = _num_replicas 29 | if rank is None: 30 | rank = _rank 31 | 32 | self.dataset = dataset 33 | self.samples_per_gpu = samples_per_gpu 34 | self.num_replicas = num_replicas 35 | self.rank = rank 36 | self.epoch = 0 37 | self.by_prob = by_prob 38 | 39 | assert hasattr(self.dataset, "flag") 40 | self.flag = self.dataset.flag 41 | self.group_sizes = np.bincount(self.flag) 42 | self.num_samples = 0 43 | self.cumulative_sizes = dataset.cumulative_sizes 44 | # decide the frequency to sample each kind of datasets 45 | if not isinstance(sample_ratio, list): 46 | sample_ratio = [sample_ratio] * len(self.cumulative_sizes) 47 | self.sample_ratio = sample_ratio 48 | self.sample_ratio = [ 49 | int(sr / min(self.sample_ratio)) for sr in self.sample_ratio 50 | ] 51 | self.size_of_dataset = [] 52 | cumulative_sizes = [0] + self.cumulative_sizes 53 | 54 | for i, _ in enumerate(self.group_sizes): 55 | size_of_dataset = 0 56 | cur_group_inds = np.where(self.flag == i)[0] 57 | for j in range(len(self.cumulative_sizes)): 58 | cur_group_cur_dataset = np.where( 59 | np.logical_and( 60 | cur_group_inds > cumulative_sizes[j], 61 | cur_group_inds < cumulative_sizes[j + 1], 62 | ) 63 | )[0] 64 | size_per_dataset = len(cur_group_cur_dataset) 65 | size_of_dataset = max( 66 | size_of_dataset, np.ceil(size_per_dataset / self.sample_ratio[j]) 67 | ) 68 | 69 | self.size_of_dataset.append( 70 | int(np.ceil(size_of_dataset / self.samples_per_gpu / self.num_replicas)) 71 | * self.samples_per_gpu 72 | ) 73 | for j in range(len(self.cumulative_sizes)): 74 | self.num_samples += self.size_of_dataset[-1] * self.sample_ratio[j] 75 | 76 | self.total_size = self.num_samples * self.num_replicas 77 | group_factor = [g / sum(self.group_sizes) 
for g in self.group_sizes] 78 | self.epoch_length = [int(np.round(gf * epoch_length)) for gf in group_factor] 79 | self.epoch_length[-1] = epoch_length - sum(self.epoch_length[:-1]) 80 | 81 | def __iter__(self): 82 | # deterministically shuffle based on epoch 83 | g = torch.Generator() 84 | g.manual_seed(self.epoch) 85 | indices = [] 86 | cumulative_sizes = [0] + self.cumulative_sizes 87 | for i, size in enumerate(self.group_sizes): 88 | if size > 0: 89 | indice = np.where(self.flag == i)[0] 90 | assert len(indice) == size 91 | indice_per_dataset = [] 92 | 93 | for j in range(len(self.cumulative_sizes)): 94 | indice_per_dataset.append( 95 | indice[ 96 | np.where( 97 | np.logical_and( 98 | indice >= cumulative_sizes[j], 99 | indice < cumulative_sizes[j + 1], 100 | ) 101 | )[0] 102 | ] 103 | ) 104 | 105 | shuffled_indice_per_dataset = [ 106 | s[list(torch.randperm(int(s.shape[0]), generator=g).numpy())] 107 | for s in indice_per_dataset 108 | ] 109 | # split into 110 | total_indice = [] 111 | batch_idx = 0 112 | # pdb.set_trace() 113 | while batch_idx < self.epoch_length[i] * self.num_replicas: 114 | ratio = [x / sum(self.sample_ratio) for x in self.sample_ratio] 115 | if self.by_prob: 116 | indicator = list( 117 | WeightedRandomSampler( 118 | ratio, 119 | self.samples_per_gpu, 120 | replacement=True, 121 | generator=g, 122 | ) 123 | ) 124 | unique, counts = np.unique(indicator, return_counts=True) 125 | ratio = [0] * len(shuffled_indice_per_dataset) 126 | for u, c in zip(unique, counts): 127 | ratio[u] = c 128 | assert len(ratio) == 2, "Only two set is supported" 129 | if ratio[0] == 0: 130 | ratio[0] = 1 131 | ratio[1] -= 1 132 | elif ratio[1] == 0: 133 | ratio[1] = 1 134 | ratio[0] -= 1 135 | 136 | ratio = [r / sum(ratio) for r in ratio] 137 | 138 | # num of each dataset 139 | ratio = [int(r * self.samples_per_gpu) for r in ratio] 140 | 141 | ratio[-1] = self.samples_per_gpu - sum(ratio[:-1]) 142 | selected = [] 143 | # print(ratio) 144 | for j in range(len(shuffled_indice_per_dataset)): 145 | if len(shuffled_indice_per_dataset[j]) < ratio[j]: 146 | shuffled_indice_per_dataset[j] = np.concatenate( 147 | ( 148 | shuffled_indice_per_dataset[j], 149 | indice_per_dataset[j][ 150 | list( 151 | torch.randperm( 152 | int(indice_per_dataset[j].shape[0]), 153 | generator=g, 154 | ).numpy() 155 | ) 156 | ], 157 | ) 158 | ) 159 | 160 | selected.append(shuffled_indice_per_dataset[j][: ratio[j]]) 161 | shuffled_indice_per_dataset[j] = shuffled_indice_per_dataset[j][ 162 | ratio[j] : 163 | ] 164 | selected = np.concatenate(selected) 165 | total_indice.append(selected) 166 | batch_idx += 1 167 | # print(self.size_of_dataset) 168 | indice = np.concatenate(total_indice) 169 | indices.append(indice) 170 | indices = np.concatenate(indices) # k 171 | indices = [ 172 | indices[j] 173 | for i in list( 174 | torch.randperm( 175 | len(indices) // self.samples_per_gpu, 176 | generator=g, 177 | ) 178 | ) 179 | for j in range( 180 | i * self.samples_per_gpu, 181 | (i + 1) * self.samples_per_gpu, 182 | ) 183 | ] 184 | 185 | offset = len(self) * self.rank 186 | indices = indices[offset : offset + len(self)] 187 | assert len(indices) == len(self) 188 | return iter(indices) 189 | 190 | def __len__(self): 191 | return sum(self.epoch_length) * self.samples_per_gpu 192 | 193 | def set_epoch(self, epoch): 194 | self.epoch = epoch 195 | 196 | # duplicated, implement it by weight instead of sampling 197 | # def update_sample_ratio(self): 198 | # if self.dynamic_step is not None: 199 | # self.sample_ratio = [d(self.epoch) for 
d in self.dynamic] 200 | -------------------------------------------------------------------------------- /ssod/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .soft_teacher import SoftTeacher -------------------------------------------------------------------------------- /ssod/models/multi_stream_detector.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from mmdet.models import BaseDetector, TwoStageDetector 3 | 4 | 5 | class MultiSteamDetector(BaseDetector): 6 | def __init__( 7 | self, model: Dict[str, TwoStageDetector], train_cfg=None, test_cfg=None 8 | ): 9 | super(MultiSteamDetector, self).__init__() 10 | self.submodules = list(model.keys()) 11 | for k, v in model.items(): 12 | setattr(self, k, v) 13 | 14 | self.train_cfg = train_cfg 15 | self.test_cfg = test_cfg 16 | self.inference_on = self.test_cfg.get("inference_on", self.submodules[0]) 17 | 18 | def model(self, **kwargs) -> TwoStageDetector: 19 | if "submodule" in kwargs: 20 | assert ( 21 | kwargs["submodule"] in self.submodules 22 | ), "Detector does not contain submodule {}".format(kwargs["submodule"]) 23 | model: TwoStageDetector = getattr(self, kwargs["submodule"]) 24 | else: 25 | model: TwoStageDetector = getattr(self, self.inference_on) 26 | return model 27 | 28 | def freeze(self, model_ref: str): 29 | assert model_ref in self.submodules 30 | model = getattr(self, model_ref) 31 | model.eval() 32 | for param in model.parameters(): 33 | param.requires_grad = False 34 | 35 | def forward_test(self, imgs, img_metas, **kwargs): 36 | 37 | return self.model(**kwargs).forward_test(imgs, img_metas, **kwargs) 38 | 39 | async def aforward_test(self, *, img, img_metas, **kwargs): 40 | return self.model(**kwargs).aforward_test(img, img_metas, **kwargs) 41 | 42 | def extract_feat(self, imgs): 43 | return self.model().extract_feat(imgs) 44 | 45 | async def aforward_test(self, *, img, img_metas, **kwargs): 46 | return self.model(**kwargs).aforward_test(img, img_metas, **kwargs) 47 | 48 | def aug_test(self, imgs, img_metas, **kwargs): 49 | return self.model(**kwargs).aug_test(imgs, img_metas, **kwargs) 50 | 51 | def simple_test(self, img, img_metas, **kwargs): 52 | return self.model(**kwargs).simple_test(img, img_metas, **kwargs) 53 | 54 | async def async_simple_test(self, img, img_metas, **kwargs): 55 | return self.model(**kwargs).async_simple_test(img, img_metas, **kwargs) 56 | 57 | def show_result(self, *args, **kwargs): 58 | self.model().CLASSES = self.CLASSES 59 | return self.model().show_result(*args, **kwargs) 60 | -------------------------------------------------------------------------------- /ssod/models/soft_teacher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmcv.runner.fp16_utils import force_fp32 3 | from mmdet.core import bbox2roi, multi_apply 4 | from mmdet.models import DETECTORS, build_detector 5 | 6 | from ssod.utils.structure_utils import dict_split, weighted_loss 7 | from ssod.utils import log_image_with_boxes, log_every_n 8 | 9 | from .multi_stream_detector import MultiSteamDetector 10 | from .utils import Transform2D, filter_invalid 11 | 12 | 13 | @DETECTORS.register_module() 14 | class SoftTeacher(MultiSteamDetector): 15 | def __init__(self, model: dict, train_cfg=None, test_cfg=None): 16 | super(SoftTeacher, self).__init__( 17 | dict(teacher=build_detector(model), student=build_detector(model)), 18 | train_cfg=train_cfg, 
19 | test_cfg=test_cfg, 20 | ) 21 | if train_cfg is not None: 22 | self.freeze("teacher") 23 | self.unsup_weight = self.train_cfg.unsup_weight 24 | 25 | def forward_train(self, img, img_metas, **kwargs): 26 | super().forward_train(img, img_metas, **kwargs) 27 | kwargs.update({"img": img}) 28 | kwargs.update({"img_metas": img_metas}) 29 | kwargs.update({"tag": [meta["tag"] for meta in img_metas]}) 30 | data_groups = dict_split(kwargs, "tag") 31 | for _, v in data_groups.items(): 32 | v.pop("tag") 33 | 34 | loss = {} 35 | #! Warning: because supervised and unsupervised losses are logged under different names, 36 | #! at least one sample from each group must be present on every gpu. 37 | #! If only one image fits per gpu, return the summed loss and log the per-group losses 38 | #! with the logger instead; otherwise DDP will try to sync tensors that do not exist. 39 | if "sup" in data_groups: 40 | gt_bboxes = data_groups["sup"]["gt_bboxes"] 41 | log_every_n( 42 | {"sup_gt_num": sum([len(bbox) for bbox in gt_bboxes]) / len(gt_bboxes)} 43 | ) 44 | sup_loss = self.student.forward_train(**data_groups["sup"]) 45 | sup_loss = {"sup_" + k: v for k, v in sup_loss.items()} 46 | loss.update(**sup_loss) 47 | if "unsup_student" in data_groups: 48 | unsup_loss = weighted_loss( 49 | self.forward_unsup_train( 50 | data_groups["unsup_teacher"], data_groups["unsup_student"] 51 | ), 52 | weight=self.unsup_weight, 53 | ) 54 | unsup_loss = {"unsup_" + k: v for k, v in unsup_loss.items()} 55 | loss.update(**unsup_loss) 56 | 57 | return loss 58 | 59 | def forward_unsup_train(self, teacher_data, student_data): 60 | # reorder the teacher inputs to match the student's image order 61 | tnames = [meta["filename"] for meta in teacher_data["img_metas"]] 62 | snames = [meta["filename"] for meta in student_data["img_metas"]] 63 | tidx = [tnames.index(name) for name in snames] 64 | with torch.no_grad(): 65 | teacher_info = self.extract_teacher_info( 66 | teacher_data["img"][ 67 | torch.Tensor(tidx).to(teacher_data["img"].device).long() 68 | ], 69 | [teacher_data["img_metas"][idx] for idx in tidx], 70 | [teacher_data["proposals"][idx] for idx in tidx] 71 | if ("proposals" in teacher_data) 72 | and (teacher_data["proposals"] is not None) 73 | else None, 74 | ) 75 | student_info = self.extract_student_info(**student_data) 76 | 77 | return self.compute_pseudo_label_loss(student_info, teacher_info) 78 | 79 | def compute_pseudo_label_loss(self, student_info, teacher_info): 80 | M = self._get_trans_mat( 81 | teacher_info["transform_matrix"], student_info["transform_matrix"] 82 | ) 83 | 84 | pseudo_bboxes = self._transform_bbox( 85 | teacher_info["det_bboxes"], 86 | M, 87 | [meta["img_shape"] for meta in student_info["img_metas"]], 88 | ) 89 | pseudo_labels = teacher_info["det_labels"] 90 | loss = {} 91 | rpn_loss, proposal_list = self.rpn_loss( 92 | student_info["rpn_out"], 93 | pseudo_bboxes, 94 | student_info["img_metas"], 95 | student_info=student_info, 96 | ) 97 | loss.update(rpn_loss) 98 | if proposal_list is not None: 99 | student_info["proposals"] = proposal_list 100 | if self.train_cfg.use_teacher_proposal: 101 | proposals = self._transform_bbox( 102 | teacher_info["proposals"], 103 | M, 104 | [meta["img_shape"] for meta in student_info["img_metas"]], 105 | ) 106 | else: 107 | proposals = student_info["proposals"] 108 | 109 | loss.update( 110 | self.unsup_rcnn_cls_loss( 111 | student_info["backbone_feature"], 112 | student_info["img_metas"], 113 | proposals, 114 | pseudo_bboxes, 115 | pseudo_labels, 116 |
teacher_info["transform_matrix"], 117 | student_info["transform_matrix"], 118 | teacher_info["img_metas"], 119 | teacher_info["backbone_feature"], 120 | student_info=student_info, 121 | ) 122 | ) 123 | loss.update( 124 | self.unsup_rcnn_reg_loss( 125 | student_info["backbone_feature"], 126 | student_info["img_metas"], 127 | proposals, 128 | pseudo_bboxes, 129 | pseudo_labels, 130 | student_info=student_info, 131 | ) 132 | ) 133 | return loss 134 | 135 | def rpn_loss( 136 | self, 137 | rpn_out, 138 | pseudo_bboxes, 139 | img_metas, 140 | gt_bboxes_ignore=None, 141 | student_info=None, 142 | **kwargs, 143 | ): 144 | if self.student.with_rpn: 145 | gt_bboxes = [] 146 | for bbox in pseudo_bboxes: 147 | bbox, _, _ = filter_invalid( 148 | bbox[:, :4], 149 | score=bbox[ 150 | :, 4 151 | ], # TODO: replace with the foreground score; this is currently the classification score 152 | thr=self.train_cfg.rpn_pseudo_threshold, 153 | min_size=self.train_cfg.min_pseduo_box_size, 154 | ) 155 | gt_bboxes.append(bbox) 156 | log_every_n( 157 | {"rpn_gt_num": sum([len(bbox) for bbox in gt_bboxes]) / len(gt_bboxes)} 158 | ) 159 | loss_inputs = rpn_out + [[bbox.float() for bbox in gt_bboxes], img_metas] 160 | losses = self.student.rpn_head.loss( 161 | *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore 162 | ) 163 | proposal_cfg = self.student.train_cfg.get( 164 | "rpn_proposal", self.student.test_cfg.rpn 165 | ) 166 | proposal_list = self.student.rpn_head.get_bboxes( 167 | *rpn_out, img_metas=img_metas, cfg=proposal_cfg 168 | ) 169 | log_image_with_boxes( 170 | "rpn", 171 | student_info["img"][0], 172 | pseudo_bboxes[0][:, :4], 173 | bbox_tag="rpn_pseudo_label", 174 | scores=pseudo_bboxes[0][:, 4], 175 | interval=500, 176 | img_norm_cfg=student_info["img_metas"][0]["img_norm_cfg"], 177 | ) 178 | return losses, proposal_list 179 | else: 180 | return {}, None 181 | 182 | def unsup_rcnn_cls_loss( 183 | self, 184 | feat, 185 | img_metas, 186 | proposal_list, 187 | pseudo_bboxes, 188 | pseudo_labels, 189 | teacher_transMat, 190 | student_transMat, 191 | teacher_img_metas, 192 | teacher_feat, 193 | student_info=None, 194 | **kwargs, 195 | ): 196 | gt_bboxes, gt_labels, _ = multi_apply( 197 | filter_invalid, 198 | [bbox[:, :4] for bbox in pseudo_bboxes], 199 | pseudo_labels, 200 | [bbox[:, 4] for bbox in pseudo_bboxes], 201 | thr=self.train_cfg.cls_pseudo_threshold, 202 | ) 203 | log_every_n( 204 | {"rcnn_cls_gt_num": sum([len(bbox) for bbox in gt_bboxes]) / len(gt_bboxes)} 205 | ) 206 | sampling_results = self.get_sampling_result( 207 | img_metas, 208 | proposal_list, 209 | gt_bboxes, 210 | gt_labels, 211 | ) 212 | selected_bboxes = [res.bboxes[:, :4] for res in sampling_results] 213 | rois = bbox2roi(selected_bboxes) 214 | bbox_results = self.student.roi_head._bbox_forward(feat, rois) 215 | bbox_targets = self.student.roi_head.bbox_head.get_targets( 216 | sampling_results, gt_bboxes, gt_labels, self.student.train_cfg.rcnn 217 | ) 218 | M = self._get_trans_mat(student_transMat, teacher_transMat) 219 | aligned_proposals = self._transform_bbox( 220 | selected_bboxes, 221 | M, 222 | [meta["img_shape"] for meta in teacher_img_metas], 223 | ) 224 | with torch.no_grad(): # let the teacher score the student's sampled proposals; used below to soft-weight background targets 225 | _, _scores = self.teacher.roi_head.simple_test_bboxes( 226 | teacher_feat, 227 | teacher_img_metas, 228 | aligned_proposals, 229 | None, 230 | rescale=False, 231 | ) 232 | bg_score = torch.cat([_score[:, -1] for _score in _scores]) 233 | assigned_label, _, _, _ = bbox_targets 234 | neg_inds = assigned_label == self.student.roi_head.bbox_head.num_classes 235 |
bbox_targets[1][neg_inds] = bg_score[neg_inds].detach() 236 | loss = self.student.roi_head.bbox_head.loss( 237 | bbox_results["cls_score"], 238 | bbox_results["bbox_pred"], 239 | rois, 240 | *bbox_targets, 241 | reduction_override="none", 242 | ) 243 | loss["loss_cls"] = loss["loss_cls"].sum() / max(bbox_targets[1].sum(), 1.0) 244 | loss["loss_bbox"] = loss["loss_bbox"].sum() / max( 245 | bbox_targets[1].size()[0], 1.0 246 | ) 247 | if len(gt_bboxes[0]) > 0: 248 | log_image_with_boxes( 249 | "rcnn_cls", 250 | student_info["img"][0], 251 | gt_bboxes[0], 252 | bbox_tag="pseudo_label", 253 | labels=gt_labels[0], 254 | class_names=self.CLASSES, 255 | interval=500, 256 | img_norm_cfg=student_info["img_metas"][0]["img_norm_cfg"], 257 | ) 258 | return loss 259 | 260 | def unsup_rcnn_reg_loss( 261 | self, 262 | feat, 263 | img_metas, 264 | proposal_list, 265 | pseudo_bboxes, 266 | pseudo_labels, 267 | student_info=None, 268 | **kwargs, 269 | ): 270 | gt_bboxes, gt_labels, _ = multi_apply( 271 | filter_invalid, 272 | [bbox[:, :4] for bbox in pseudo_bboxes], 273 | pseudo_labels, 274 | [-bbox[:, 5:].mean(dim=-1) for bbox in pseudo_bboxes], 275 | thr=-self.train_cfg.reg_pseudo_threshold, 276 | ) 277 | log_every_n( 278 | {"rcnn_reg_gt_num": sum([len(bbox) for bbox in gt_bboxes]) / len(gt_bboxes)} 279 | ) 280 | loss_bbox = self.student.roi_head.forward_train( 281 | feat, img_metas, proposal_list, gt_bboxes, gt_labels, **kwargs 282 | )["loss_bbox"] 283 | if len(gt_bboxes[0]) > 0: 284 | log_image_with_boxes( 285 | "rcnn_reg", 286 | student_info["img"][0], 287 | gt_bboxes[0], 288 | bbox_tag="pseudo_label", 289 | labels=gt_labels[0], 290 | class_names=self.CLASSES, 291 | interval=500, 292 | img_norm_cfg=student_info["img_metas"][0]["img_norm_cfg"], 293 | ) 294 | return {"loss_bbox": loss_bbox} 295 | 296 | def get_sampling_result( 297 | self, 298 | img_metas, 299 | proposal_list, 300 | gt_bboxes, 301 | gt_labels, 302 | gt_bboxes_ignore=None, 303 | **kwargs, 304 | ): 305 | num_imgs = len(img_metas) 306 | if gt_bboxes_ignore is None: 307 | gt_bboxes_ignore = [None for _ in range(num_imgs)] 308 | sampling_results = [] 309 | for i in range(num_imgs): 310 | assign_result = self.student.roi_head.bbox_assigner.assign( 311 | proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i], gt_labels[i] 312 | ) 313 | sampling_result = self.student.roi_head.bbox_sampler.sample( 314 | assign_result, 315 | proposal_list[i], 316 | gt_bboxes[i], 317 | gt_labels[i], 318 | ) 319 | sampling_results.append(sampling_result) 320 | return sampling_results 321 | 322 | @force_fp32(apply_to=["bboxes", "trans_mat"]) 323 | def _transform_bbox(self, bboxes, trans_mat, max_shape): 324 | bboxes = Transform2D.transform_bboxes(bboxes, trans_mat, max_shape) 325 | return bboxes 326 | 327 | @force_fp32(apply_to=["a", "b"]) 328 | def _get_trans_mat(self, a, b): 329 | return [bt @ at.inverse() for bt, at in zip(b, a)] 330 | 331 | def extract_student_info(self, img, img_metas, proposals=None, **kwargs): 332 | student_info = {} 333 | student_info["img"] = img 334 | feat = self.student.extract_feat(img) 335 | student_info["backbone_feature"] = feat 336 | if self.student.with_rpn: 337 | rpn_out = self.student.rpn_head(feat) 338 | student_info["rpn_out"] = list(rpn_out) 339 | student_info["img_metas"] = img_metas 340 | student_info["proposals"] = proposals 341 | student_info["transform_matrix"] = [ 342 | torch.from_numpy(meta["transform_matrix"]).float().to(feat[0][0].device) 343 | for meta in img_metas 344 | ] 345 | return student_info 346 | 347 | def 
extract_teacher_info(self, img, img_metas, proposals=None, **kwargs): 348 | teacher_info = {} 349 | feat = self.teacher.extract_feat(img) 350 | teacher_info["backbone_feature"] = feat 351 | if proposals is None: 352 | proposal_cfg = self.teacher.train_cfg.get( 353 | "rpn_proposal", self.teacher.test_cfg.rpn 354 | ) 355 | rpn_out = list(self.teacher.rpn_head(feat)) 356 | proposal_list = self.teacher.rpn_head.get_bboxes( 357 | *rpn_out, img_metas=img_metas, cfg=proposal_cfg 358 | ) 359 | else: 360 | proposal_list = proposals 361 | teacher_info["proposals"] = proposal_list 362 | 363 | proposal_list, proposal_label_list = self.teacher.roi_head.simple_test_bboxes( 364 | feat, img_metas, proposal_list, self.teacher.test_cfg.rcnn, rescale=False 365 | ) 366 | 367 | proposal_list = [p.to(feat[0].device) for p in proposal_list] 368 | proposal_list = [ 369 | p if p.shape[0] > 0 else p.new_zeros(0, 5) for p in proposal_list 370 | ] 371 | proposal_label_list = [p.to(feat[0].device) for p in proposal_label_list] 372 | # filter invalid box roughly 373 | if isinstance(self.train_cfg.pseudo_label_initial_score_thr, float): 374 | thr = self.train_cfg.pseudo_label_initial_score_thr 375 | else: 376 | # TODO: use dynamic threshold 377 | raise NotImplementedError("Dynamic Threshold is not implemented yet.") 378 | proposal_list, proposal_label_list, _ = list( 379 | zip( 380 | *[ 381 | filter_invalid( 382 | proposal, 383 | proposal_label, 384 | proposal[:, -1], 385 | thr=thr, 386 | min_size=self.train_cfg.min_pseduo_box_size, 387 | ) 388 | for proposal, proposal_label in zip( 389 | proposal_list, proposal_label_list 390 | ) 391 | ] 392 | ) 393 | ) 394 | det_bboxes = proposal_list 395 | reg_unc = self.compute_uncertainty_with_aug( 396 | feat, img_metas, proposal_list, proposal_label_list 397 | ) 398 | det_bboxes = [ 399 | torch.cat([bbox, unc], dim=-1) for bbox, unc in zip(det_bboxes, reg_unc) 400 | ] 401 | det_labels = proposal_label_list 402 | teacher_info["det_bboxes"] = det_bboxes 403 | teacher_info["det_labels"] = det_labels 404 | teacher_info["transform_matrix"] = [ 405 | torch.from_numpy(meta["transform_matrix"]).float().to(feat[0][0].device) 406 | for meta in img_metas 407 | ] 408 | teacher_info["img_metas"] = img_metas 409 | return teacher_info 410 | 411 | def compute_uncertainty_with_aug( 412 | self, feat, img_metas, proposal_list, proposal_label_list 413 | ): 414 | auged_proposal_list = self.aug_box( 415 | proposal_list, self.train_cfg.jitter_times, self.train_cfg.jitter_scale 416 | ) 417 | # flatten 418 | auged_proposal_list = [ 419 | auged.reshape(-1, auged.shape[-1]) for auged in auged_proposal_list 420 | ] 421 | 422 | bboxes, _ = self.teacher.roi_head.simple_test_bboxes( 423 | feat, 424 | img_metas, 425 | auged_proposal_list, 426 | None, 427 | rescale=False, 428 | ) 429 | reg_channel = max([bbox.shape[-1] for bbox in bboxes]) // 4 430 | bboxes = [ 431 | bbox.reshape(self.train_cfg.jitter_times, -1, bbox.shape[-1]) 432 | if bbox.numel() > 0 433 | else bbox.new_zeros(self.train_cfg.jitter_times, 0, 4 * reg_channel).float() 434 | for bbox in bboxes 435 | ] 436 | 437 | box_unc = [bbox.std(dim=0) for bbox in bboxes] 438 | bboxes = [bbox.mean(dim=0) for bbox in bboxes] 439 | # scores = [score.mean(dim=0) for score in scores] 440 | if reg_channel != 1: 441 | bboxes = [ 442 | bbox.reshape(bbox.shape[0], reg_channel, 4)[ 443 | torch.arange(bbox.shape[0]), label 444 | ] 445 | for bbox, label in zip(bboxes, proposal_label_list) 446 | ] 447 | box_unc = [ 448 | unc.reshape(unc.shape[0], reg_channel, 4)[ 449 | 
torch.arange(unc.shape[0]), label 450 | ] 451 | for unc, label in zip(box_unc, proposal_label_list) 452 | ] 453 | 454 | box_shape = [(bbox[:, 2:4] - bbox[:, :2]).clamp(min=1.0) for bbox in bboxes] 455 | # relative unc 456 | box_unc = [ 457 | unc / wh[:, None, :].expand(-1, 2, 2).reshape(-1, 4) 458 | if wh.numel() > 0 459 | else unc 460 | for unc, wh in zip(box_unc, box_shape) 461 | ] 462 | return box_unc 463 | 464 | @staticmethod 465 | def aug_box(boxes, times=1, frac=0.06): 466 | def _aug_single(box): 467 | # random translate 468 | # TODO: random flip or something 469 | box_scale = box[:, 2:4] - box[:, :2] 470 | box_scale = ( 471 | box_scale.clamp(min=1)[:, None, :].expand(-1, 2, 2).reshape(-1, 4) 472 | ) 473 | aug_scale = box_scale * frac # [n,4] 474 | 475 | offset = ( 476 | torch.randn(times, box.shape[0], 4, device=box.device) 477 | * aug_scale[None, ...] 478 | ) 479 | new_box = box.clone()[None, ...].expand(times, box.shape[0], -1) 480 | return torch.cat( 481 | [new_box[:, :, :4].clone() + offset, new_box[:, :, 4:]], dim=-1 482 | ) 483 | 484 | return [_aug_single(box) for box in boxes] 485 | 486 | def _load_from_state_dict( 487 | self, 488 | state_dict, 489 | prefix, 490 | local_metadata, 491 | strict, 492 | missing_keys, 493 | unexpected_keys, 494 | error_msgs, 495 | ): 496 | if not any(["student" in key or "teacher" in key for key in state_dict.keys()]): 497 | keys = list(state_dict.keys()) 498 | state_dict.update({"teacher." + k: state_dict[k] for k in keys}) 499 | state_dict.update({"student." + k: state_dict[k] for k in keys}) 500 | for k in keys: 501 | state_dict.pop(k) 502 | 503 | return super()._load_from_state_dict( 504 | state_dict, 505 | prefix, 506 | local_metadata, 507 | strict, 508 | missing_keys, 509 | unexpected_keys, 510 | error_msgs, 511 | ) 512 | -------------------------------------------------------------------------------- /ssod/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .bbox_utils import Transform2D, filter_invalid 2 | -------------------------------------------------------------------------------- /ssod/models/utils/bbox_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections.abc import Sequence 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | from mmdet.core.mask.structures import BitmapMasks 8 | from torch.nn import functional as F 9 | 10 | 11 | def bbox2points(box): 12 | min_x, min_y, max_x, max_y = torch.split(box[:, :4], [1, 1, 1, 1], dim=1) 13 | 14 | return torch.cat( 15 | [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y], dim=1 16 | ).reshape( 17 | -1, 2 18 | ) # n*4,2 19 | 20 | 21 | def points2bbox(point, max_w, max_h): 22 | point = point.reshape(-1, 4, 2) 23 | if point.size()[0] > 0: 24 | min_xy = point.min(dim=1)[0] 25 | max_xy = point.max(dim=1)[0] 26 | xmin = min_xy[:, 0].clamp(min=0, max=max_w) 27 | ymin = min_xy[:, 1].clamp(min=0, max=max_h) 28 | xmax = max_xy[:, 0].clamp(min=0, max=max_w) 29 | ymax = max_xy[:, 1].clamp(min=0, max=max_h) 30 | min_xy = torch.stack([xmin, ymin], dim=1) 31 | max_xy = torch.stack([xmax, ymax], dim=1) 32 | return torch.cat([min_xy, max_xy], dim=1) # n,4 33 | else: 34 | return point.new_zeros(0, 4) 35 | 36 | 37 | def check_is_tensor(obj): 38 | """Checks whether the supplied object is a tensor.""" 39 | if not isinstance(obj, torch.Tensor): 40 | raise TypeError("Input type is not a torch.Tensor. 
Got {}".format(type(obj))) 41 | 42 | 43 | def normal_transform_pixel( 44 | height: int, 45 | width: int, 46 | eps: float = 1e-14, 47 | device: Optional[torch.device] = None, 48 | dtype: Optional[torch.dtype] = None, 49 | ) -> torch.Tensor: 50 | tr_mat = torch.tensor( 51 | [[1.0, 0.0, -1.0], [0.0, 1.0, -1.0], [0.0, 0.0, 1.0]], 52 | device=device, 53 | dtype=dtype, 54 | ) # 3x3 55 | 56 | # prevent divide by zero bugs 57 | width_denom: float = eps if width == 1 else width - 1.0 58 | height_denom: float = eps if height == 1 else height - 1.0 59 | 60 | tr_mat[0, 0] = tr_mat[0, 0] * 2.0 / width_denom 61 | tr_mat[1, 1] = tr_mat[1, 1] * 2.0 / height_denom 62 | 63 | return tr_mat.unsqueeze(0) # 1x3x3 64 | 65 | 66 | def normalize_homography( 67 | dst_pix_trans_src_pix: torch.Tensor, 68 | dsize_src: Tuple[int, int], 69 | dsize_dst: Tuple[int, int], 70 | ) -> torch.Tensor: 71 | check_is_tensor(dst_pix_trans_src_pix) 72 | 73 | if not ( 74 | len(dst_pix_trans_src_pix.shape) == 3 75 | and dst_pix_trans_src_pix.shape[-2:] == (3, 3) 76 | ): 77 | raise ValueError( 78 | "Input dst_pix_trans_src_pix must be a Bx3x3 tensor. Got {}".format( 79 | dst_pix_trans_src_pix.shape 80 | ) 81 | ) 82 | 83 | # source and destination sizes 84 | src_h, src_w = dsize_src 85 | dst_h, dst_w = dsize_dst 86 | 87 | # compute the transformation pixel/norm for src/dst 88 | src_norm_trans_src_pix: torch.Tensor = normal_transform_pixel(src_h, src_w).to( 89 | dst_pix_trans_src_pix 90 | ) 91 | src_pix_trans_src_norm = torch.inverse(src_norm_trans_src_pix.float()).to( 92 | src_norm_trans_src_pix.dtype 93 | ) 94 | dst_norm_trans_dst_pix: torch.Tensor = normal_transform_pixel(dst_h, dst_w).to( 95 | dst_pix_trans_src_pix 96 | ) 97 | 98 | # compute chain transformations 99 | dst_norm_trans_src_norm: torch.Tensor = dst_norm_trans_dst_pix @ ( 100 | dst_pix_trans_src_pix @ src_pix_trans_src_norm 101 | ) 102 | return dst_norm_trans_src_norm 103 | 104 | 105 | def warp_affine( 106 | src: torch.Tensor, 107 | M: torch.Tensor, 108 | dsize: Tuple[int, int], 109 | mode: str = "bilinear", 110 | padding_mode: str = "zeros", 111 | align_corners: Optional[bool] = None, 112 | ) -> torch.Tensor: 113 | if not isinstance(src, torch.Tensor): 114 | raise TypeError( 115 | "Input src type is not a torch.Tensor. Got {}".format(type(src)) 116 | ) 117 | 118 | if not isinstance(M, torch.Tensor): 119 | raise TypeError("Input M type is not a torch.Tensor. Got {}".format(type(M))) 120 | 121 | if not len(src.shape) == 4: 122 | raise ValueError("Input src must be a BxCxHxW tensor. Got {}".format(src.shape)) 123 | 124 | if not (len(M.shape) == 3 and M.shape[-2:] == (3, 3)): 125 | raise ValueError("Input M must be a Bx3x3 tensor. Got {}".format(M.shape)) 126 | 127 | # TODO: remove the statement below in kornia v0.6 128 | if align_corners is None: 129 | message: str = ( 130 | "The align_corners default value has been changed. By default now is set True " 131 | "in order to match cv2.warpAffine."
132 | ) 133 | warnings.warn(message) 134 | # set default value for align corners 135 | align_corners = True 136 | 137 | B, C, H, W = src.size() 138 | 139 | # M is already a Bx3x3 homography here; map it into the normalized [-1, 1] grid space 140 | 141 | dst_norm_trans_src_norm: torch.Tensor = normalize_homography(M, (H, W), dsize) 142 | 143 | src_norm_trans_dst_norm = torch.inverse(dst_norm_trans_src_norm.float()) 144 | 145 | grid = F.affine_grid( 146 | src_norm_trans_dst_norm[:, :2, :], 147 | [B, C, dsize[0], dsize[1]], 148 | align_corners=align_corners, 149 | ) 150 | 151 | return F.grid_sample( 152 | src.float(), 153 | grid, 154 | align_corners=align_corners, 155 | mode=mode, 156 | padding_mode=padding_mode, 157 | ).to(src.dtype) 158 | 159 | 160 | class Transform2D: 161 | @staticmethod 162 | def transform_bboxes(bbox, M, out_shape): 163 | if isinstance(bbox, Sequence): 164 | assert len(bbox) == len(M) 165 | return [ 166 | Transform2D.transform_bboxes(b, m, o) 167 | for b, m, o in zip(bbox, M, out_shape) 168 | ] 169 | else: 170 | if bbox.shape[0] == 0: 171 | return bbox 172 | score = None 173 | if bbox.shape[1] > 4: 174 | score = bbox[:, 4:] 175 | points = bbox2points(bbox[:, :4]) 176 | points = torch.cat( 177 | [points, points.new_ones(points.shape[0], 1)], dim=1 178 | ) # n,3 179 | points = torch.matmul(M, points.t()).t() 180 | points = points[:, :2] / points[:, 2:3] 181 | bbox = points2bbox(points, out_shape[1], out_shape[0]) 182 | if score is not None: 183 | return torch.cat([bbox, score], dim=1) 184 | return bbox 185 | 186 | @staticmethod 187 | def transform_masks( 188 | mask: Union[BitmapMasks, List[BitmapMasks]], 189 | M: Union[torch.Tensor, List[torch.Tensor]], 190 | out_shape: Union[list, List[list]], 191 | ): 192 | if isinstance(mask, Sequence): 193 | assert len(mask) == len(M) 194 | return [ 195 | Transform2D.transform_masks(b, m, o) 196 | for b, m, o in zip(mask, M, out_shape) 197 | ] 198 | else: 199 | if mask.masks.shape[0] == 0: 200 | return BitmapMasks(np.zeros((0, *out_shape)), *out_shape) 201 | mask_tensor = ( 202 | torch.from_numpy(mask.masks[:, None, ...]).to(M.device).to(M.dtype) 203 | ) 204 | return BitmapMasks( 205 | warp_affine( 206 | mask_tensor, 207 | M[None, ...].expand(mask.masks.shape[0], -1, -1), 208 | out_shape, 209 | ) 210 | .squeeze(1) 211 | .cpu() 212 | .numpy(), 213 | out_shape[0], 214 | out_shape[1], 215 | ) 216 | 217 | @staticmethod 218 | def transform_image(img, M, out_shape): 219 | if isinstance(img, Sequence): 220 | assert len(img) == len(M) 221 | return [ 222 | Transform2D.transform_image(b, m, shape) 223 | for b, m, shape in zip(img, M, out_shape) 224 | ] 225 | else: 226 | if img.dim() == 2: 227 | img = img[None, None, ...] 228 | elif img.dim() == 3: 229 | img = img[None, ...]
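# warp_affine above expects a BxCxHxW tensor and a Bx3x3 matrix; "nearest" sampling below keeps discrete pixel values intact.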
230 | 231 | return ( 232 | warp_affine(img.float(), M[None, ...], out_shape, mode="nearest") 233 | .squeeze() 234 | .to(img.dtype) 235 | ) 236 | 237 | 238 | def filter_invalid(bbox, label=None, score=None, mask=None, thr=0.0, min_size=0): 239 | if score is not None: 240 | valid = score > thr 241 | bbox = bbox[valid] 242 | if label is not None: 243 | label = label[valid] 244 | if mask is not None: 245 | mask = BitmapMasks(mask.masks[valid.cpu().numpy()], mask.height, mask.width) 246 | if min_size is not None: 247 | bw = bbox[:, 2] - bbox[:, 0] 248 | bh = bbox[:, 3] - bbox[:, 1] 249 | valid = (bw > min_size) & (bh > min_size) 250 | bbox = bbox[valid] 251 | if label is not None: 252 | label = label[valid] 253 | if mask is not None: 254 | mask = BitmapMasks(mask.masks[valid.cpu().numpy()], mask.height, mask.width) 255 | return bbox, label, mask 256 | -------------------------------------------------------------------------------- /ssod/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .exts import NamedOptimizerConstructor 2 | from .hooks import Weighter, MeanTeacher, WeightSummary, SubModulesDistEvalHook 3 | from .logger import get_root_logger, log_every_n, log_image_with_boxes 4 | from .patch import patch_config, patch_runner, find_latest_checkpoint 5 | 6 | 7 | __all__ = [ 8 | "get_root_logger", 9 | "log_every_n", 10 | "log_image_with_boxes", 11 | "patch_config", 12 | "patch_runner", 13 | "find_latest_checkpoint", 14 | "Weighter", 15 | "MeanTeacher", 16 | "WeightSummary", 17 | "SubModulesDistEvalHook", 18 | "NamedOptimizerConstructor", 19 | ] 20 | -------------------------------------------------------------------------------- /ssod/utils/exts/__init__.py: -------------------------------------------------------------------------------- 1 | from .optimizer_constructor import NamedOptimizerConstructor 2 | -------------------------------------------------------------------------------- /ssod/utils/exts/optimizer_constructor.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | from torch.nn import GroupNorm, LayerNorm 5 | 6 | from mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg 7 | from mmcv.utils.ext_loader import check_ops_exist 8 | from mmcv.runner.optimizer.builder import OPTIMIZER_BUILDERS, OPTIMIZERS 9 | from mmcv.runner.optimizer import DefaultOptimizerConstructor 10 | 11 | 12 | @OPTIMIZER_BUILDERS.register_module() 13 | class NamedOptimizerConstructor(DefaultOptimizerConstructor): 14 | """Main difference from the default constructor: 15 | 16 | 1) Adds a name to each param group 17 | """ 18 | 19 | def add_params(self, params, module, prefix="", is_dcn_module=None): 20 | """Add all parameters of module to the params list. 21 | 22 | The parameters of the given module will be added to the list of param 23 | groups, with specific rules defined by paramwise_cfg. 24 | 25 | Args: 26 | params (list[dict]): A list of param groups, it will be modified 27 | in place. 28 | module (nn.Module): The module to be added. 29 | prefix (str): The prefix of the module 30 | is_dcn_module (int|float|None): If the current module is a 31 | submodule of DCN, `is_dcn_module` will be passed to 32 | control conv_offset layer's learning rate. Defaults to None.
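Example (illustrative sketch, not part of the original docstring; exact names depend on the wrapped model): a produced group looks like ``{"params": [...], "name": "teacher.backbone.conv1.weight"}``, plus ``lr``/``weight_decay`` overrides when paramwise rules apply, which is what lets hooks such as WeightSummary look up per-parameter settings by name.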
33 | """ 34 | # get param-wise options 35 | custom_keys = self.paramwise_cfg.get("custom_keys", {}) 36 | # sort alphabetically first, then by key length in reverse, so the longest matching key wins 37 | sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) 38 | 39 | bias_lr_mult = self.paramwise_cfg.get("bias_lr_mult", 1.0) 40 | bias_decay_mult = self.paramwise_cfg.get("bias_decay_mult", 1.0) 41 | norm_decay_mult = self.paramwise_cfg.get("norm_decay_mult", 1.0) 42 | dwconv_decay_mult = self.paramwise_cfg.get("dwconv_decay_mult", 1.0) 43 | bypass_duplicate = self.paramwise_cfg.get("bypass_duplicate", False) 44 | dcn_offset_lr_mult = self.paramwise_cfg.get("dcn_offset_lr_mult", 1.0) 45 | 46 | # special rules for norm layers and depth-wise conv layers 47 | is_norm = isinstance(module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) 48 | is_dwconv = ( 49 | isinstance(module, torch.nn.Conv2d) and module.in_channels == module.groups 50 | ) 51 | 52 | for name, param in module.named_parameters(recurse=False): 53 | param_group = {"params": [param], "name": f"{prefix}.{name}"} 54 | if not param.requires_grad: 55 | params.append(param_group) 56 | continue 57 | if bypass_duplicate and self._is_in(param_group, params): 58 | warnings.warn( 59 | f"{prefix} is duplicate. It is skipped since " 60 | f"bypass_duplicate={bypass_duplicate}" 61 | ) 62 | continue 63 | # if the parameter matches one of the custom keys, ignore other rules 64 | is_custom = False 65 | for key in sorted_keys: 66 | if key in f"{prefix}.{name}": 67 | is_custom = True 68 | lr_mult = custom_keys[key].get("lr_mult", 1.0) 69 | param_group["lr"] = self.base_lr * lr_mult 70 | if self.base_wd is not None: 71 | decay_mult = custom_keys[key].get("decay_mult", 1.0) 72 | param_group["weight_decay"] = self.base_wd * decay_mult 73 | break 74 | 75 | if not is_custom: 76 | # bias_lr_mult affects all bias parameters 77 | # except for norm.bias and dcn.conv_offset.bias 78 | if name == "bias" and not (is_norm or is_dcn_module): 79 | param_group["lr"] = self.base_lr * bias_lr_mult 80 | 81 | if ( 82 | prefix.find("conv_offset") != -1 83 | and is_dcn_module 84 | and isinstance(module, torch.nn.Conv2d) 85 | ): 86 | # deal with both dcn_offset's bias & weight 87 | param_group["lr"] = self.base_lr * dcn_offset_lr_mult 88 | 89 | # apply weight decay policies 90 | if self.base_wd is not None: 91 | # norm decay 92 | if is_norm: 93 | param_group["weight_decay"] = self.base_wd * norm_decay_mult 94 | # depth-wise conv 95 | elif is_dwconv: 96 | param_group["weight_decay"] = self.base_wd * dwconv_decay_mult 97 | # bias lr and decay 98 | elif name == "bias" and not is_dcn_module: 99 | # TODO: bias_decay_mult currently also affects DCN bias 100 | param_group["weight_decay"] = self.base_wd * bias_decay_mult 101 | params.append(param_group) 102 | 103 | if check_ops_exist(): 104 | from mmcv.ops import DeformConv2d, ModulatedDeformConv2d 105 | 106 | is_dcn_module = isinstance(module, (DeformConv2d, ModulatedDeformConv2d)) 107 | else: 108 | is_dcn_module = False 109 | for child_name, child_mod in module.named_children(): 110 | child_prefix = f"{prefix}.{child_name}" if prefix else child_name 111 | self.add_params( 112 | params, child_mod, prefix=child_prefix, is_dcn_module=is_dcn_module 113 | ) 114 | -------------------------------------------------------------------------------- /ssod/utils/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .weight_adjust import Weighter 2 | from .mean_teacher import MeanTeacher 3 | from
.weights_summary import WeightSummary 4 | from .evaluation import DistEvalHook 5 | from .submodules_evaluation import SubModulesDistEvalHook # ,SubModulesEvalHook 6 | 7 | 8 | __all__ = [ 9 | "Weighter", 10 | "MeanTeacher", 11 | "DistEvalHook", 12 | "SubModulesDistEvalHook", 13 | "WeightSummary", 14 | ] 15 | -------------------------------------------------------------------------------- /ssod/utils/hooks/evaluation.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch.distributed as dist 4 | from mmcv.runner.hooks import LoggerHook, WandbLoggerHook 5 | from mmdet.core import DistEvalHook as BaseDistEvalHook 6 | from torch.nn.modules.batchnorm import _BatchNorm 7 | 8 | 9 | class DistEvalHook(BaseDistEvalHook): 10 | def after_train_iter(self, runner): 11 | """Called after every training iter to evaluate the results.""" 12 | if not self.by_epoch and self._should_evaluate(runner): 13 | for hook in runner._hooks: 14 | if isinstance(hook, WandbLoggerHook): 15 | _commit_state = hook.commit 16 | hook.commit = False 17 | if isinstance(hook, LoggerHook): 18 | hook.after_train_iter(runner) 19 | if isinstance(hook, WandbLoggerHook): 20 | hook.commit = _commit_state 21 | runner.log_buffer.clear() 22 | 23 | self._do_evaluate(runner) 24 | 25 | def _do_evaluate(self, runner): 26 | """perform evaluation and save ckpt.""" 27 | # Synchronization of BatchNorm's buffer (running_mean 28 | # and running_var) is not supported in the DDP of pytorch, 29 | # which may cause the inconsistent performance of models in 30 | # different ranks, so we broadcast BatchNorm's buffers 31 | # of rank 0 to other ranks to avoid this. 32 | if self.broadcast_bn_buffer: 33 | model = runner.model 34 | for name, module in model.named_modules(): 35 | if isinstance(module, _BatchNorm) and module.track_running_stats: 36 | dist.broadcast(module.running_var, 0) 37 | dist.broadcast(module.running_mean, 0) 38 | 39 | if not self._should_evaluate(runner): 40 | return 41 | 42 | tmpdir = self.tmpdir 43 | if tmpdir is None: 44 | tmpdir = osp.join(runner.work_dir, ".eval_hook") 45 | 46 | from mmdet.apis import multi_gpu_test 47 | 48 | results = multi_gpu_test( 49 | runner.model, self.dataloader, tmpdir=tmpdir, gpu_collect=self.gpu_collect 50 | ) 51 | if runner.rank == 0: 52 | print("\n") 53 | # runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 54 | key_score = self.evaluate(runner, results) 55 | 56 | if self.save_best: 57 | self._save_ckpt(runner, key_score) 58 | -------------------------------------------------------------------------------- /ssod/utils/hooks/mean_teacher.py: -------------------------------------------------------------------------------- 1 | from mmcv.parallel import is_module_wrapper 2 | from mmcv.runner.hooks import HOOKS, Hook 3 | from bisect import bisect_right 4 | from ..logger import log_every_n 5 | 6 | 7 | @HOOKS.register_module() 8 | class MeanTeacher(Hook): 9 | def __init__( 10 | self, 11 | momentum=0.999, 12 | interval=1, 13 | warm_up=100, 14 | decay_intervals=None, 15 | decay_factor=0.1, 16 | ): 17 | assert momentum >= 0 and momentum <= 1 18 | self.momentum = momentum 19 | assert isinstance(interval, int) and interval > 0 20 | self.warm_up = warm_up 21 | self.interval = interval 22 | assert isinstance(decay_intervals, list) or decay_intervals is None 23 | self.decay_intervals = decay_intervals 24 | self.decay_factor = decay_factor 25 | 26 | def before_run(self, runner): 27 | model = runner.model 28 | if 
is_module_wrapper(model): 29 | model = model.module 30 | assert hasattr(model, "teacher") 31 | assert hasattr(model, "student") 32 | # only do this at the initial step 33 | if runner.iter == 0: 34 | log_every_n("Clone all parameters of student to teacher...") 35 | self.momentum_update(model, 0) 36 | 37 | def before_train_iter(self, runner): 38 | """Update the EMA parameters every self.interval iterations.""" 39 | curr_step = runner.iter 40 | if curr_step % self.interval != 0: 41 | return 42 | model = runner.model 43 | if is_module_wrapper(model): 44 | model = model.module 45 | # Warm up the momentum to account for instability at the beginning of training 46 | momentum = min( 47 | self.momentum, 1 - (1 + self.warm_up) / (curr_step + 1 + self.warm_up) 48 | ) 49 | runner.log_buffer.output["ema_momentum"] = momentum 50 | self.momentum_update(model, momentum) 51 | 52 | def after_train_iter(self, runner): 53 | curr_step = runner.iter 54 | if self.decay_intervals is None: 55 | return 56 | self.momentum = 1 - (1 - self.momentum) * self.decay_factor ** bisect_right( 57 | self.decay_intervals, curr_step 58 | ) 59 | 60 | def momentum_update(self, model, momentum): 61 | for (src_name, src_parm), (tgt_name, tgt_parm) in zip( 62 | model.student.named_parameters(), model.teacher.named_parameters() 63 | ): 64 | tgt_parm.data.mul_(momentum).add_(src_parm.data, alpha=1 - momentum) 65 | -------------------------------------------------------------------------------- /ssod/utils/hooks/submodules_evaluation.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch.distributed as dist 4 | from mmcv.parallel import is_module_wrapper 5 | from mmcv.runner.hooks import HOOKS, LoggerHook, WandbLoggerHook 6 | from mmdet.core import DistEvalHook 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | 9 | 10 | @HOOKS.register_module() 11 | class SubModulesDistEvalHook(DistEvalHook): 12 | def __init__(self, *args, evaluated_modules=None, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | self.evaluated_modules = evaluated_modules 15 | 16 | def before_run(self, runner): 17 | if is_module_wrapper(runner.model): 18 | model = runner.model.module 19 | else: 20 | model = runner.model 21 | assert hasattr(model, "submodules") 22 | assert hasattr(model, "inference_on") 23 | 24 | def after_train_iter(self, runner): 25 | """Called after every training iter to evaluate the results.""" 26 | if not self.by_epoch and self._should_evaluate(runner): 27 | for hook in runner._hooks: 28 | if isinstance(hook, WandbLoggerHook): 29 | _commit_state = hook.commit 30 | hook.commit = False 31 | if isinstance(hook, LoggerHook): 32 | hook.after_train_iter(runner) 33 | if isinstance(hook, WandbLoggerHook): 34 | hook.commit = _commit_state 35 | runner.log_buffer.clear() 36 | 37 | self._do_evaluate(runner) 38 | 39 | def _do_evaluate(self, runner): 40 | """perform evaluation and save ckpt.""" 41 | # Synchronization of BatchNorm's buffer (running_mean 42 | # and running_var) is not supported in the DDP of pytorch, 43 | # which may cause the inconsistent performance of models in 44 | # different ranks, so we broadcast BatchNorm's buffers 45 | # of rank 0 to other ranks to avoid this.
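# Unlike the plain DistEvalHook, this hook then evaluates each requested submodule (e.g. teacher and student) in turn by switching inference_on before each test pass; see the loop below.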
46 | 47 | if self.broadcast_bn_buffer: 48 | model = runner.model 49 | for name, module in model.named_modules(): 50 | if isinstance(module, _BatchNorm) and module.track_running_stats: 51 | dist.broadcast(module.running_var, 0) 52 | dist.broadcast(module.running_mean, 0) 53 | 54 | if not self._should_evaluate(runner): 55 | return 56 | 57 | tmpdir = self.tmpdir 58 | if tmpdir is None: 59 | tmpdir = osp.join(runner.work_dir, ".eval_hook") 60 | 61 | if is_module_wrapper(runner.model): 62 | model_ref = runner.model.module 63 | else: 64 | model_ref = runner.model 65 | if not self.evaluated_modules: 66 | submodules = model_ref.submodules 67 | else: 68 | submodules = self.evaluated_modules 69 | key_scores = [] 70 | from mmdet.apis import multi_gpu_test 71 | 72 | for submodule in submodules: 73 | # switch the submodule used for inference 74 | model_ref.inference_on = submodule 75 | results = multi_gpu_test( 76 | runner.model, 77 | self.dataloader, 78 | tmpdir=tmpdir, 79 | gpu_collect=self.gpu_collect, 80 | ) 81 | if runner.rank == 0: 82 | key_score = self.evaluate(runner, results, prefix=submodule) 83 | if key_score is not None: 84 | key_scores.append(key_score) 85 | 86 | if runner.rank == 0: 87 | runner.log_buffer.ready = True 88 | if len(key_scores) == 0: 89 | key_scores = [None] 90 | best_score = key_scores[0] 91 | for key_score in key_scores: 92 | if hasattr(self, "compare_func") and self.compare_func( 93 | key_score, best_score 94 | ): 95 | best_score = key_score 96 | 97 | print("\n") 98 | # runner.log_buffer.output["eval_iter_num"] = len(self.dataloader) 99 | if self.save_best: 100 | self._save_ckpt(runner, best_score) 101 | 102 | def evaluate(self, runner, results, prefix=""): 103 | """Evaluate the results. 104 | 105 | Args: 106 | runner (:obj:`mmcv.Runner`): The underlying training runner. 107 | results (list): Output results.
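prefix (str): Prefix prepended to each metric name in the log buffer, so per-submodule scores are reported separately.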
108 | """ 109 | eval_res = self.dataloader.dataset.evaluate( 110 | results, logger=runner.logger, **self.eval_kwargs 111 | ) 112 | for name, val in eval_res.items(): 113 | runner.log_buffer.output[(".").join([prefix, name])] = val 114 | 115 | if self.save_best is not None: 116 | if self.key_indicator == "auto": 117 | # infer from eval_results 118 | self._init_rule(self.rule, list(eval_res.keys())[0]) 119 | return eval_res[self.key_indicator] 120 | 121 | return None 122 | -------------------------------------------------------------------------------- /ssod/utils/hooks/weight_adjust.py: -------------------------------------------------------------------------------- 1 | from mmcv.parallel import is_module_wrapper 2 | from mmcv.runner.hooks import HOOKS, Hook 3 | from bisect import bisect_right 4 | 5 | 6 | @HOOKS.register_module() 7 | class Weighter(Hook): 8 | def __init__( 9 | self, 10 | steps=None, 11 | vals=None, 12 | name=None, 13 | ): 14 | self.steps = steps 15 | self.vals = vals 16 | self.name = name 17 | if self.name is not None: 18 | assert self.steps is not None 19 | assert self.vals is not None 20 | assert len(self.vals) == len(self.steps) + 1 21 | 22 | def before_train_iter(self, runner): 23 | curr_step = runner.iter 24 | if self.name is None: 25 | return 26 | model = runner.model 27 | if is_module_wrapper(model): 28 | model = model.module 29 | assert hasattr(model, self.name) 30 | self.steps = [s if s > 0 else runner.max_iters - s for s in self.steps] 31 | runner.log_buffer.output[self.name] = self.vals[ 32 | bisect_right(self.steps, curr_step) 33 | ] 34 | 35 | setattr(model, self.name, runner.log_buffer.output[self.name]) 36 | -------------------------------------------------------------------------------- /ssod/utils/hooks/weights_summary.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch.distributed as dist 4 | from mmcv.parallel import is_module_wrapper 5 | from mmcv.runner.hooks import HOOKS, Hook 6 | from ..logger import get_root_logger 7 | from prettytable import PrettyTable 8 | 9 | 10 | def bool2str(input): 11 | if input: 12 | return "Y" 13 | else: 14 | return "N" 15 | 16 | 17 | def unknown(): 18 | return "-" 19 | 20 | 21 | def shape_str(size): 22 | size = [str(s) for s in size] 23 | return "X".join(size) 24 | 25 | 26 | def min_max_str(input): 27 | return "Min:{:.3f} Max:{:.3f}".format(input.min(), input.max()) 28 | 29 | 30 | def construct_params_dict(input): 31 | assert isinstance(input, list) 32 | param_dict = {} 33 | for group in input: 34 | if "name" in group: 35 | param_dict[group["name"]] = group 36 | return param_dict 37 | 38 | 39 | def max_match_sub_str(strs, sub_str): 40 | # find most related str for sub_str 41 | matched = None 42 | for child in strs: 43 | if len(child) <= len(sub_str): 44 | if child == sub_str: 45 | return child 46 | elif sub_str[: len(child)] == child: 47 | if matched is None or len(matched) < len(child): 48 | matched = child 49 | return matched 50 | 51 | 52 | def get_optim(optimizer, params_dict, name, key): 53 | rel_name = max_match_sub_str(list(params_dict.keys()), name) 54 | if rel_name is not None: 55 | return params_dict[rel_name][key] 56 | else: 57 | if key in optimizer.defaults: 58 | return optimizer.defaults[key] 59 | 60 | 61 | @HOOKS.register_module() 62 | class WeightSummary(Hook): 63 | def before_run(self, runner): 64 | if runner.rank != 0: 65 | return 66 | if is_module_wrapper(runner.model): 67 | model = runner.model.module 68 | else: 69 | model = 
runner.model 70 | weight_summaries = self.collect_model_info(model, optimizer=runner.optimizer) 71 | logger = get_root_logger() 72 | logger.info(weight_summaries) 73 | 74 | @staticmethod 75 | def collect_model_info(model, optimizer=None, rich_text=False): 76 | param_groups = None 77 | if optimizer is not None: 78 | param_groups = construct_params_dict(optimizer.param_groups) 79 | 80 | if not rich_text: 81 | table = PrettyTable( 82 | ["Name", "Optimized", "Shape", "Value Scale [Min,Max]", "Lr", "Wd"] 83 | ) 84 | for name, param in model.named_parameters(): 85 | table.add_row( 86 | [ 87 | name, 88 | bool2str(param.requires_grad), 89 | shape_str(param.size()), 90 | min_max_str(param), 91 | unknown() 92 | if param_groups is None 93 | else get_optim(optimizer, param_groups, name, "lr"), 94 | unknown() 95 | if param_groups is None 96 | else get_optim(optimizer, param_groups, name, "weight_decay"), 97 | ] 98 | ) 99 | return "\n" + table.get_string(title="Model Information") 100 | else: 101 | pass # rich-text output is not implemented 102 | -------------------------------------------------------------------------------- /ssod/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from collections import Counter 5 | from typing import Tuple 6 | 7 | import mmcv 8 | import numpy as np 9 | import torch 10 | from mmcv.runner.dist_utils import get_dist_info 11 | from mmcv.utils import get_logger 12 | from mmdet.core.visualization import imshow_det_bboxes 13 | 14 | try: 15 | import wandb 16 | except ImportError: 17 | wandb = None 18 | 19 | _log_counter = Counter() 20 | 21 | 22 | def get_root_logger(log_file=None, log_level=logging.INFO): 23 | """Get root logger. 24 | 25 | Args: 26 | log_file (str, optional): File path of log. Defaults to None. 27 | log_level (int, optional): The level of logger. 28 | Defaults to logging.INFO.
29 | 30 | Returns: 31 | :obj:`logging.Logger`: The obtained logger 32 | """ 33 | logger = get_logger(name="mmdet.ssod", log_file=log_file, log_level=log_level) 34 | logger.propagate = False 35 | return logger 36 | 37 | 38 | def _find_caller(): 39 | frame = sys._getframe(2) 40 | while frame: 41 | code = frame.f_code 42 | if os.path.join("utils", "logger.") not in code.co_filename: 43 | mod_name = frame.f_globals["__name__"] 44 | if mod_name == "__main__": 45 | mod_name = r"ssod" 46 | return mod_name, (code.co_filename, frame.f_lineno, code.co_name) 47 | frame = frame.f_back 48 | 49 | 50 | def convert_box(tag, boxes, box_labels, class_labels, std, scores=None): 51 | if isinstance(std, int): 52 | std = [std, std] 53 | if len(std) != 4: 54 | std = std[::-1] * 2 55 | std = boxes.new_tensor(std).reshape(1, 4) 56 | wandb_box = {} 57 | boxes = boxes / std 58 | boxes = boxes.detach().cpu().numpy().tolist() 59 | box_labels = box_labels.detach().cpu().numpy().tolist() 60 | class_labels = {k: class_labels[k] for k in range(len(class_labels))} 61 | wandb_box["class_labels"] = class_labels 62 | assert len(boxes) == len(box_labels) 63 | if scores is not None: 64 | scores = scores.detach().cpu().numpy().tolist() 65 | box_data = [ 66 | dict( 67 | position=dict(minX=box[0], minY=box[1], maxX=box[2], maxY=box[3]), 68 | class_id=label, 69 | scores=dict(cls=scores[i]), 70 | ) 71 | for i, (box, label) in enumerate(zip(boxes, box_labels)) 72 | ] 73 | else: 74 | box_data = [ 75 | dict( 76 | position=dict(minX=box[0], minY=box[1], maxX=box[2], maxY=box[3]), 77 | class_id=label, 78 | ) 79 | for i, (box, label) in enumerate(zip(boxes, box_labels)) 80 | ] 81 | 82 | wandb_box["box_data"] = box_data 83 | return {tag: wandb.data_types.BoundingBoxes2D(wandb_box, tag)} 84 | 85 | 86 | def color_transform(img_tensor, mean, std, to_rgb=False): 87 | img_np = img_tensor.detach().cpu().numpy().transpose((1, 2, 0)).astype(np.float32) 88 | return mmcv.imdenormalize(img_np, mean, std, to_bgr=not to_rgb) 89 | 90 | 91 | def log_image_with_boxes( 92 | tag: str, 93 | image: torch.Tensor, 94 | bboxes: torch.Tensor, 95 | bbox_tag: str = None, 96 | labels: torch.Tensor = None, 97 | scores: torch.Tensor = None, 98 | class_names: Tuple[str] = None, 99 | filename: str = None, 100 | img_norm_cfg: dict = None, 101 | backend: str = "auto", 102 | interval: int = 50, 103 | ): 104 | rank, _ = get_dist_info() 105 | if rank != 0: 106 | return 107 | _, key = _find_caller() 108 | _log_counter[key] += 1 109 | if not (interval == 1 or _log_counter[key] % interval == 1): 110 | return 111 | if backend == "auto": 112 | if (wandb is None) or (wandb.run is None): 113 | backend = "file" 114 | else: 115 | backend = "wandb" 116 | 117 | if backend == "wandb": 118 | if wandb is None: 119 | raise ImportError("wandb is not installed") 120 | assert ( 121 | wandb.run is not None 122 | ), "wandb has not been initialized, call `wandb.init` first" 123 | 124 | elif backend != "file": 125 | raise TypeError("backend must be file or wandb") 126 | 127 | if filename is None: 128 | filename = f"{_log_counter[key]}.jpg" 129 | if bbox_tag is None: # fall back to a default tag 130 | bbox_tag = "vis" 131 | if img_norm_cfg is not None: 132 | image = color_transform(image, **img_norm_cfg) 133 | if labels is None: 134 | labels = bboxes.new_zeros(bboxes.shape[0]).long() 135 | class_names = ["foreground"] 136 | if backend == "wandb": 137 | im = {} 138 | im["data_or_path"] = image 139 | im["boxes"] = convert_box( 140 | bbox_tag, bboxes, labels, class_names, scores=scores, std=image.shape[:2] 141 | ) 142 |
wandb.log({tag: wandb.Image(**im)}, commit=False) 143 | elif backend == "file": 144 | root_dir = os.environ.get("WORK_DIR", ".") 145 | 146 | imshow_det_bboxes( 147 | image, 148 | bboxes.cpu().detach().numpy(), 149 | labels.cpu().detach().numpy(), 150 | class_names=class_names, 151 | show=False, 152 | out_file=os.path.join(root_dir, tag, bbox_tag, filename), 153 | ) 154 | else: 155 | raise TypeError("backend must be file or wandb") 156 | 157 | 158 | def log_every_n(msg, n: int = 50, level: int = logging.DEBUG, backend="auto"): 159 | """ 160 | Args: 161 | msg (Any): message string or dict of scalars to log. 162 | n (int): log only once every ``n`` calls from a given call site. 163 | level (int): logging level for the text logger. 164 | backend (str): currently unused; dicts go to wandb when it is initialized. 165 | """ 166 | caller_module, key = _find_caller() 167 | _log_counter[key] += 1 168 | if n == 1 or _log_counter[key] % n == 1: 169 | if isinstance(msg, dict) and (wandb is not None) and (wandb.run is not None): 170 | wandb.log(msg, commit=False) 171 | else: 172 | get_root_logger().log(level, msg) 173 | -------------------------------------------------------------------------------- /ssod/utils/patch.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | import shutil 5 | import types 6 | 7 | from mmcv.runner import BaseRunner, EpochBasedRunner, IterBasedRunner 8 | from mmcv.utils import Config 9 | 10 | from .signature import parse_method_info 11 | from .vars import resolve 12 | 13 | 14 | def find_latest_checkpoint(path, ext="pth"): 15 | if not osp.exists(path): 16 | return None 17 | if osp.exists(osp.join(path, f"latest.{ext}")): 18 | return osp.join(path, f"latest.{ext}") 19 | 20 | checkpoints = glob.glob(osp.join(path, f"*.{ext}")) 21 | if len(checkpoints) == 0: 22 | return None 23 | latest = -1 24 | latest_path = None 25 | for checkpoint in checkpoints: 26 | count = int(osp.basename(checkpoint).split("_")[-1].split(".")[0]) 27 | if count > latest: 28 | latest = count 29 | latest_path = checkpoint 30 | return latest_path 31 | 32 | 33 | def patch_checkpoint(runner: BaseRunner): 34 | # patch save_checkpoint 35 | old_save_checkpoint = runner.save_checkpoint 36 | params = parse_method_info(old_save_checkpoint) 37 | default_tmpl = params["filename_tmpl"].default 38 | 39 | def save_checkpoint(self, out_dir, **kwargs): 40 | create_symlink = kwargs.get("create_symlink", True) 41 | filename_tmpl = kwargs.get("filename_tmpl", default_tmpl) 42 | # disable the runner's own symlink and copy the checkpoint instead (works on filesystems without symlink support) 43 | kwargs.update(create_symlink=False) 44 | old_save_checkpoint(out_dir, **kwargs) 45 | if create_symlink: 46 | dst_file = osp.join(out_dir, "latest.pth") 47 | if isinstance(self, EpochBasedRunner): 48 | filename = filename_tmpl.format(self.epoch + 1) 49 | elif isinstance(self, IterBasedRunner): 50 | filename = filename_tmpl.format(self.iter + 1) 51 | else: 52 | raise NotImplementedError() 53 | filepath = osp.join(out_dir, filename) 54 | shutil.copy(filepath, dst_file) 55 | 56 | runner.save_checkpoint = types.MethodType(save_checkpoint, runner) 57 | return runner 58 | 59 | 60 | def patch_runner(runner): 61 | runner = patch_checkpoint(runner) 62 | return runner 63 | 64 | 65 | def setup_env(cfg): 66 | os.environ["WORK_DIR"] = cfg.work_dir 67 | 68 | 69 | def patch_config(cfg): 70 | 71 | cfg_dict = super(Config, cfg).__getattribute__("_cfg_dict").to_dict() 72 | cfg_dict["cfg_name"] = osp.splitext(osp.basename(cfg.filename))[0] 73 | cfg_dict = resolve(cfg_dict) 74 | cfg = Config(cfg_dict, filename=cfg.filename) 75 | # wrap for semi 76 | if cfg.get("semi_wrapper", None) is not None: 77 | cfg.model = cfg.semi_wrapper 78 |
cfg.pop("semi_wrapper") 79 | # enable environment variables 80 | setup_env(cfg) 81 | return cfg 82 | -------------------------------------------------------------------------------- /ssod/utils/signature.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | 4 | def parse_method_info(method): 5 | sig = inspect.signature(method) 6 | params = sig.parameters 7 | return params 8 | -------------------------------------------------------------------------------- /ssod/utils/structure_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import Counter 3 | from collections.abc import Mapping, Sequence 4 | from numbers import Number 5 | from typing import Dict, List 6 | import numpy as np 7 | import torch 8 | from mmdet.core.mask.structures import BitmapMasks 9 | from torch.nn import functional as F 10 | 11 | _step_counter = Counter() 12 | 13 | 14 | def list_concat(data_list: List[list]): 15 | if isinstance(data_list[0], torch.Tensor): 16 | return torch.cat(data_list) 17 | else: 18 | endpoint = [d for d in data_list[0]] 19 | 20 | for i in range(1, len(data_list)): 21 | endpoint.extend(data_list[i]) 22 | return endpoint 23 | 24 | 25 | def sequence_concat(a, b): 26 | if isinstance(a, Sequence) and isinstance(b, Sequence): 27 | return a + b 28 | else: 29 | return None 30 | 31 | 32 | def dict_concat(dicts: List[Dict[str, list]]): 33 | return {k: list_concat([d[k] for d in dicts]) for k in dicts[0].keys()} 34 | 35 | 36 | def dict_fuse(obj_list, reference_obj): 37 | if isinstance(reference_obj, torch.Tensor): 38 | return torch.stack(obj_list) 39 | return obj_list 40 | 41 | 42 | def dict_select(dict1: Dict[str, list], key: str, value: str): 43 | flag = [v == value for v in dict1[key]] 44 | return { 45 | k: dict_fuse([vv for vv, ff in zip(v, flag) if ff], v) for k, v in dict1.items() 46 | } 47 | 48 | 49 | def dict_split(dict1, key): 50 | group_names = list(set(dict1[key])) 51 | dict_groups = {k: dict_select(dict1, key, k) for k in group_names} 52 | 53 | return dict_groups 54 | 55 | 56 | def dict_sum(a, b): 57 | if isinstance(a, dict): 58 | assert isinstance(b, dict) 59 | return {k: dict_sum(v, b[k]) for k, v in a.items()} 60 | elif isinstance(a, list): 61 | assert len(a) == len(b) 62 | return [dict_sum(aa, bb) for aa, bb in zip(a, b)] 63 | else: 64 | return a + b 65 | 66 | 67 | def zero_like(tensor_pack, prefix=""): 68 | if isinstance(tensor_pack, Sequence): 69 | return [zero_like(t) for t in tensor_pack] 70 | elif isinstance(tensor_pack, Mapping): 71 | return {prefix + k: zero_like(v) for k, v in tensor_pack.items()} 72 | elif isinstance(tensor_pack, torch.Tensor): 73 | return tensor_pack.new_zeros(tensor_pack.shape) 74 | elif isinstance(tensor_pack, np.ndarray): 75 | return np.zeros_like(tensor_pack) 76 | else: 77 | warnings.warn("Unexpected data type {}".format(type(tensor_pack))) 78 | return 0 79 | 80 | 81 | def pad_stack(tensors, shape, pad_value=255): 82 | tensors = torch.stack( 83 | [ 84 | F.pad( 85 | tensor, 86 | pad=[0, shape[1] - tensor.shape[1], 0, shape[0] - tensor.shape[0]], 87 | value=pad_value, 88 | ) 89 | for tensor in tensors 90 | ] 91 | ) 92 | return tensors 93 | 94 | 95 | def result2bbox(result): 96 | num_class = len(result) 97 | 98 | bbox = np.concatenate(result) 99 | if bbox.shape[0] == 0: 100 | label = np.zeros(0, dtype=np.uint8) 101 | else: 102 | label = np.concatenate( 103 | [[i] * len(result[i]) for i in range(num_class) if len(result[i]) > 0] 104 | ).reshape((-1,)) 105 |
return bbox, label 106 | 107 | 108 | def result2mask(result): 109 | num_class = len(result) 110 | mask = [np.stack(result[i]) for i in range(num_class) if len(result[i]) > 0] 111 | if len(mask) > 0: 112 | mask = np.concatenate(mask) 113 | else: 114 | mask = np.zeros((0, 1, 1)) 115 | return BitmapMasks(mask, mask.shape[1], mask.shape[2]), None 116 | 117 | 118 | def sequence_mul(obj, multiplier): 119 | if isinstance(obj, Sequence): 120 | return [o * multiplier for o in obj] 121 | else: 122 | return obj * multiplier 123 | 124 | 125 | def is_match(word, word_list): 126 | for keyword in word_list: 127 | if keyword in word: 128 | return True 129 | return False 130 | 131 | 132 | def weighted_loss(loss: dict, weight, ignore_keys=[], warmup=0): 133 | _step_counter["weight"] += 1 134 | lambda_weight = ( 135 | lambda x: x * (_step_counter["weight"] - 1) / warmup 136 | if _step_counter["weight"] <= warmup 137 | else x 138 | ) 139 | if isinstance(weight, Mapping): 140 | for k, v in weight.items(): 141 | for name, loss_item in loss.items(): 142 | if (k in name) and ("loss" in name): 143 | loss[name] = sequence_mul(loss[name], lambda_weight(v)) 144 | elif isinstance(weight, Number): 145 | for name, loss_item in loss.items(): 146 | if "loss" in name: 147 | if not is_match(name, ignore_keys): 148 | loss[name] = sequence_mul(loss[name], lambda_weight(weight)) 149 | else: 150 | loss[name] = sequence_mul(loss[name], 0.0) 151 | else: 152 | raise NotImplementedError() 153 | return loss 154 | -------------------------------------------------------------------------------- /ssod/utils/vars.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Union 3 | 4 | pattern = re.compile(r"\$\{[a-zA-Z\d_.]*\}") 5 | 6 | 7 | def get_value(cfg: dict, chained_key: str): 8 | keys = chained_key.split(".") 9 | if len(keys) == 1: 10 | return cfg[keys[0]] 11 | else: 12 | return get_value(cfg[keys[0]], ".".join(keys[1:])) 13 | 14 | 15 | def resolve(cfg: Union[dict, list], base=None): 16 | if base is None: 17 | base = cfg 18 | if isinstance(cfg, dict): 19 | return {k: resolve(v, base) for k, v in cfg.items()} 20 | elif isinstance(cfg, list): 21 | return [resolve(v, base) for v in cfg] 22 | elif isinstance(cfg, tuple): 23 | return tuple([resolve(v, base) for v in cfg]) 24 | elif isinstance(cfg, str): 25 | # resolve ${...} references against the config root 26 | var_names = pattern.findall(cfg) 27 | if len(var_names) == 1 and len(cfg) == len(var_names[0]): 28 | return get_value(base, var_names[0][2:-1]) 29 | else: 30 | vars = [get_value(base, name[2:-1]) for name in var_names] 31 | for name, var in zip(var_names, vars): 32 | cfg = cfg.replace(name, str(var)) 33 | return cfg 34 | else: 35 | return cfg 36 | -------------------------------------------------------------------------------- /ssod/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | 3 | __all__ = ["__version__"] 4 | -------------------------------------------------------------------------------- /tools/dataset/prepare_coco_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | help() { 5 | echo "Usage: $0 [option...] download|conduct|full" 6 | echo "download download coco dataset" 7 | echo "conduct conduct data split for semi-supervised training and evaluation" 8 | echo "option:" 9 | echo " -r, --root [PATH] select the root path of dataset.
The default dataset root is ssod/data" 10 | } 11 | download() { 12 | mkdir -p coco 13 | cd coco 14 | for split in train2017 val2017 unlabeled2017; 15 | do 16 | wget http://images.cocodataset.org/zips/${split}.zip; 17 | unzip ${split}.zip 18 | done 19 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip 20 | unzip annotations_trainval2017.zip 21 | wget http://images.cocodataset.org/annotations/image_info_unlabeled2017.zip 22 | unzip image_info_unlabeled2017.zip 23 | cd .. 24 | } 25 | conduct() { 26 | OFFSET=$RANDOM 27 | for percent in 1 5 10; do 28 | for fold in 1 2 3 4 5; do 29 | python tools/dataset/semi_coco.py --percent ${percent} --seed ${fold} --data-dir "${data_root}"/coco --seed-offset ${OFFSET} 30 | done 31 | done 32 | } 33 | 34 | data_root=data 35 | ROOT=$(dirname "$0")/../.. 36 | 37 | cd "${ROOT}" 38 | 39 | case $1 in 40 | -r | --root) 41 | data_root=$2 42 | shift 2 43 | ;; 44 | esac 45 | mkdir -p ${data_root} 46 | case $1 in 47 | download) 48 | cd ${data_root} 49 | download 50 | ;; 51 | conduct) 52 | conduct 53 | ;; 54 | full) 55 | cd ${data_root} 56 | download 57 | cd .. 58 | conduct 59 | ;; 60 | *) 61 | help 62 | exit 0 63 | ;; 64 | esac 65 | -------------------------------------------------------------------------------- /tools/dataset/semi_coco.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """Generate labeled and unlabeled splits of the COCO train set.
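A seed-controlled random subset of `percent`% of the training images keeps its annotations as the labeled split; the remaining images form the unlabeled split, and both are written to annotations/semi_supervised.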
17 | 18 | Example: 19 | python tools/dataset/semi_coco.py --percent 10 --seed 1 --data-dir data/coco 20 | """ 21 | 22 | import argparse 23 | import numpy as np 24 | import json 25 | import os 26 | 27 | 28 | def prepare_coco_data(seed=1, percent=10.0, version=2017, seed_offset=0): 29 | """Prepare COCO dataset for semi-supervised learning 30 | Args: 31 | seed: random seed for the dataset split 32 | percent: percentage of images kept as labeled data 33 | version: COCO dataset version; seed_offset: extra offset folded into the numpy seed 34 | """ 35 | 36 | def _save_anno(name, images, annotations): 37 | """Save annotation.""" 38 | print( 39 | ">> Processing dataset: saving {}.json ({} images {} annotations)".format( 40 | name, len(images), len(annotations) 41 | ) 42 | ) 43 | new_anno = {} 44 | new_anno["images"] = images 45 | new_anno["annotations"] = annotations 46 | new_anno["licenses"] = anno["licenses"] 47 | new_anno["categories"] = anno["categories"] 48 | new_anno["info"] = anno["info"] 49 | path = "{}/{}".format(COCOANNODIR, "semi_supervised") 50 | if not os.path.exists(path): 51 | os.makedirs(path, exist_ok=True) 52 | 53 | with open( 54 | "{root}/{folder}/{save_name}.json".format( 55 | save_name=name, root=COCOANNODIR, folder="semi_supervised" 56 | ), 57 | "w", 58 | ) as f: 59 | json.dump(new_anno, f) 60 | print( 61 | ">> Data {}.json saved ({} images {} annotations)".format( 62 | name, len(images), len(annotations) 63 | ) 64 | ) 65 | 66 | np.random.seed(seed + seed_offset) 67 | COCOANNODIR = os.path.join(DATA_DIR, "annotations") 68 | 69 | anno = json.load( 70 | open(os.path.join(COCOANNODIR, "instances_train{}.json".format(version))) 71 | ) 72 | 73 | image_list = anno["images"] 74 | labeled_tot = int(percent / 100.0 * len(image_list)) 75 | labeled_ind = np.random.choice( 76 | range(len(image_list)), size=labeled_tot, replace=False 77 | ) 78 | labeled_id = [] 79 | labeled_images = [] 80 | unlabeled_images = [] 81 | labeled_ind = set(labeled_ind) 82 | for i in range(len(image_list)): 83 | if i in labeled_ind: 84 | labeled_images.append(image_list[i]) 85 | labeled_id.append(image_list[i]["id"]) 86 | else: 87 | unlabeled_images.append(image_list[i]) 88 | 89 | # get all annotations of labeled images 90 | labeled_id = set(labeled_id) 91 | labeled_annotations = [] 92 | unlabeled_annotations = [] 93 | for an in anno["annotations"]: 94 | if an["image_id"] in labeled_id: 95 | labeled_annotations.append(an) 96 | else: 97 | unlabeled_annotations.append(an) 98 | 99 | # save labeled and unlabeled splits 100 | save_name = "instances_train{version}.{seed}@{tot}".format( 101 | version=version, seed=seed, tot=int(percent) 102 | ) 103 | _save_anno(save_name, labeled_images, labeled_annotations) 104 | save_name = "instances_train{version}.{seed}@{tot}-unlabeled".format( 105 | version=version, seed=seed, tot=int(percent) 106 | ) 107 | _save_anno(save_name, unlabeled_images, unlabeled_annotations) 108 | # construct the extra unlabeled set from the 120k unlabeled2017 images 109 | unlabeled_ann_file = os.path.join(COCOANNODIR, "instances_unlabeled{}.json".format(version)) 110 | if not os.path.exists(unlabeled_ann_file): 111 | unlabeled_info = json.load( 112 | open(os.path.join(COCOANNODIR, "image_info_unlabeled{}.json".format(version))) 113 | ) 114 | unlabeled_info["annotations"] = [] 115 | unlabeled_info["categories"] = anno["categories"] 116 | print(">> Data {}.json saved ({} images {} annotations)".format(unlabeled_ann_file, len(unlabeled_info["images"]), len(unlabeled_info["annotations"]))) 117 | with open(unlabeled_ann_file, "w") as f: json.dump(unlabeled_info, f) 118 | 119 | if __name__ == "__main__": 120 | 
121 | parser = argparse.ArgumentParser() 122 | 
parser.add_argument("--data-dir", type=str) 123 | parser.add_argument("--percent", type=float, default=10) 124 | parser.add_argument("--version", type=int, default=2017) 125 | parser.add_argument("--seed", type=int, help="seed", default=1) 126 | parser.add_argument("--seed-offset", type=int, default=0) 127 | args = parser.parse_args() 128 | print(args) 129 | DATA_DIR = args.data_dir 130 | prepare_coco_data(args.seed, args.percent, args.version, args.seed_offset) 131 | -------------------------------------------------------------------------------- /tools/dataset/semi_coco.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | OFFSET=$RANDOM 4 | for percent in 1 5 10; do 5 | for fold in 1 2 3 4 5; do 6 | $(dirname "$0")/coco_semi.py --percent ${percent} --seed ${fold} --data-dir $1 --seed-offset ${OFFSET} 7 | done 8 | done 9 | -------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29500} 7 | 8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 11 | -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-29500} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /tools/dist_train_partially.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | 4 | TYPE=$1 5 | FOLD=$2 6 | PERCENT=$3 7 | GPUS=$4 8 | PORT=${PORT:-29500} 9 | 10 | 11 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH 12 | 13 | if [[ ${TYPE} == 'baseline' ]]; then 14 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 15 | $(dirname "$0")/train.py configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py --launcher pytorch \ 16 | --cfg-options fold=${FOLD} percent=${PERCENT} ${@:5} 17 | else 18 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 19 | $(dirname "$0")/train.py configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py --launcher pytorch \ 20 | --cfg-options fold=${FOLD} percent=${PERCENT} ${@:5} 21 | fi 22 | -------------------------------------------------------------------------------- /tools/misc/browse_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import mmcv 6 | import torch 7 | from mmcv import Config, DictAction 8 | from mmdet.core.utils import mask2ndarray 9 | from mmdet.core.visualization import imshow_det_bboxes 10 | 11 | from ssod.datasets import build_dataset 12 | from ssod.models.utils import Transform2D 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description="Browse a dataset") 17 | parser.add_argument("config", help="train config file path") 18 | 
parser.add_argument( 19 | "--skip-type", 20 | type=str, 21 | nargs="+", 22 | default=["DefaultFormatBundle", "Normalize", "Collect"], 23 | help="pipeline steps to skip when visualizing (formatting-only steps by default)", 24 | ) 25 | parser.add_argument( 26 | "--output-dir", 27 | default=None, 28 | type=str, 29 | help="directory to save the rendered images when no display is available", 30 | ) 31 | parser.add_argument("--not-show", default=False, action="store_true") 32 | parser.add_argument( 33 | "--show-interval", type=float, default=2, help="interval between displayed images, in seconds" 34 | ) 35 | parser.add_argument( 36 | "--cfg-options", 37 | nargs="+", 38 | action=DictAction, 39 | help="override some settings in the used config, the key-value pair " 40 | "in xxx=yyy format will be merged into config file. If the value to " 41 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 42 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 43 | "Note that the quotation marks are necessary and that no white space " 44 | "is allowed.", 45 | ) 46 | args = parser.parse_args() 47 | return args 48 | 49 | 50 | def remove_pipe(pipelines, skip_type): 51 | if isinstance(pipelines, list): 52 | new_pipelines = [] 53 | for pipe in pipelines: 54 | pipe = remove_pipe(pipe, skip_type) 55 | if pipe is not None: 56 | new_pipelines.append(pipe) 57 | return new_pipelines 58 | elif isinstance(pipelines, dict): 59 | if pipelines["type"] in skip_type: 60 | return None 61 | elif pipelines["type"] == "MultiBranch":  # filter each branch of a multi-branch pipeline recursively 62 | new_pipelines = {} 63 | for k, v in pipelines.items(): 64 | if k != "type": 65 | new_pipelines[k] = remove_pipe(v, skip_type) 66 | else: 67 | new_pipelines[k] = v 68 | return new_pipelines 69 | else: 70 | return pipelines 71 | else: 72 | raise NotImplementedError() 73 | 74 | 75 | def retrieve_data_cfg(config_path, skip_type, cfg_options): 76 | cfg = Config.fromfile(config_path) 77 | if cfg_options is not None: 78 | cfg.merge_from_dict(cfg_options) 79 | # import modules from string list. 
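# e.g. a config may declare custom_imports = dict(imports=["ssod"], allow_failed_imports=False) (an illustrative value, not one shipped with these configs), which is unpacked into mmcv's import_modules_from_strings below.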
80 | if cfg.get("custom_imports", None): 81 | from mmcv.utils import import_modules_from_strings 82 | 83 | import_modules_from_strings(**cfg["custom_imports"]) 84 | train_data_cfg = cfg.data.train 85 | while "dataset" in train_data_cfg: 86 | train_data_cfg = train_data_cfg["dataset"] 87 | train_data_cfg["pipeline"] = remove_pipe(train_data_cfg["pipeline"], skip_type) 88 | return cfg 89 | 90 | 91 | def main(): 92 | args = parse_args() 93 | cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options) 94 | 95 | dataset = build_dataset(cfg.data.train) 96 | 97 | progress_bar = mmcv.ProgressBar(len(dataset)) 98 | 99 | for item in dataset: 100 | if not isinstance(item, list): 101 | item = [item] 102 | bboxes = [] 103 | labels = [] 104 | trans_mats = [] 105 | out_shapes = [] 106 | for it in item: 107 | trans_matrix = it["transform_matrix"] 108 | bbox = it["gt_bboxes"] 109 | trans_mats.append(trans_matrix) 110 | bboxes.append(bbox) 111 | labels.append(it["gt_labels"]) 112 | out_shapes.append(it["img_shape"]) 113 | 114 | filename = ( 115 | os.path.join(args.output_dir, Path(it["filename"]).name) 116 | if args.output_dir is not None 117 | else None 118 | ) 119 | 120 | gt_masks = it.get("gt_masks", None) 121 | if gt_masks is not None: 122 | gt_masks = mask2ndarray(gt_masks) 123 | 124 | imshow_det_bboxes( 125 | it["img"], 126 | it["gt_bboxes"], 127 | it["gt_labels"], 128 | gt_masks, 129 | class_names=dataset.CLASSES, 130 | show=not args.not_show, 131 | wait_time=args.show_interval, 132 | out_file=filename, 133 | bbox_color=(255, 102, 61), 134 | text_color=(255, 102, 61), 135 | ) 136 | 137 | if len(trans_mats) == 2: 138 | # check consistency between the two augmented views: map the boxes of view 1 back through view 1's transform and forward through view 0's 139 | transed_bboxes = Transform2D.transform_bboxes( 140 | torch.from_numpy(bboxes[1]).float(), 141 | torch.from_numpy(trans_mats[0]).float() 142 | @ torch.from_numpy(trans_mats[1]).float().inverse(), 143 | out_shapes[0], 144 | ) 145 | img = imshow_det_bboxes( 146 | item[0]["img"], 147 | item[0]["gt_bboxes"], 148 | item[0]["gt_labels"], 149 | class_names=dataset.CLASSES, 150 | show=False, 151 | wait_time=args.show_interval, 152 | out_file=None, 153 | bbox_color=(255, 102, 61), 154 | text_color=(255, 102, 61), 155 | ) 156 | imshow_det_bboxes( 157 | img, 158 | transed_bboxes.numpy(), 159 | labels[1], 160 | class_names=dataset.CLASSES, 161 | show=True, 162 | wait_time=args.show_interval, 163 | out_file=None, 164 | bbox_color=(0, 0, 255), 165 | text_color=(0, 0, 255), 166 | thickness=5, 167 | ) 168 | 169 | progress_bar.update() 170 | 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import time 5 | import warnings 6 | 7 | import mmcv 8 | import torch 9 | from mmcv import Config, DictAction 10 | from mmcv.cnn import fuse_conv_bn 11 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 12 | from mmcv.runner import get_dist_info, init_dist, load_checkpoint, wrap_fp16_model 13 | from mmdet.apis import multi_gpu_test, single_gpu_test 14 | from mmdet.datasets import build_dataloader, build_dataset, replace_ImageToTensor 15 | from mmdet.models import build_detector 16 | 17 | from ssod.utils import patch_config 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description="MMDet test (and eval) a model") 22 | parser.add_argument("config", help="test config file 
path") 23 | parser.add_argument("checkpoint", help="checkpoint file") 24 | parser.add_argument( 25 | "--work-dir", 26 | help="the directory to save the file containing evaluation metrics", 27 | ) 28 | parser.add_argument("--out", help="output result file in pickle format") 29 | parser.add_argument( 30 | "--fuse-conv-bn", 31 | action="store_true", 32 | help="Whether to fuse conv and bn, this will slightly increase" 33 | "the inference speed", 34 | ) 35 | parser.add_argument( 36 | "--format-only", 37 | action="store_true", 38 | help="Format the output results without perform evaluation. It is" 39 | "useful when you want to format the result to a specific format and " 40 | "submit it to the test server", 41 | ) 42 | parser.add_argument( 43 | "--eval", 44 | type=str, 45 | nargs="+", 46 | help='evaluation metrics, which depends on the dataset, e.g., "bbox",' 47 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC', 48 | ) 49 | parser.add_argument("--show", action="store_true", help="show results") 50 | parser.add_argument( 51 | "--show-dir", help="directory where painted images will be saved" 52 | ) 53 | parser.add_argument( 54 | "--show-score-thr", 55 | type=float, 56 | default=0.3, 57 | help="score threshold (default: 0.3)", 58 | ) 59 | parser.add_argument( 60 | "--gpu-collect", 61 | action="store_true", 62 | help="whether to use gpu to collect results.", 63 | ) 64 | parser.add_argument( 65 | "--tmpdir", 66 | help="tmp directory used for collecting results from multiple " 67 | "workers, available when gpu-collect is not specified", 68 | ) 69 | parser.add_argument( 70 | "--cfg-options", 71 | nargs="+", 72 | action=DictAction, 73 | help="override some settings in the used config, the key-value pair " 74 | "in xxx=yyy format will be merged into config file. If the value to " 75 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 76 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 77 | "Note that the quotation marks are necessary and that no white space " 78 | "is allowed.", 79 | ) 80 | parser.add_argument( 81 | "--options", 82 | nargs="+", 83 | action=DictAction, 84 | help="custom options for evaluation, the key-value pair in xxx=yyy " 85 | "format will be kwargs for dataset.evaluate() function (deprecate), " 86 | "change to --eval-options instead.", 87 | ) 88 | parser.add_argument( 89 | "--eval-options", 90 | nargs="+", 91 | action=DictAction, 92 | help="custom options for evaluation, the key-value pair in xxx=yyy " 93 | "format will be kwargs for dataset.evaluate() function", 94 | ) 95 | parser.add_argument( 96 | "--launcher", 97 | choices=["none", "pytorch", "slurm", "mpi"], 98 | default="none", 99 | help="job launcher", 100 | ) 101 | parser.add_argument("--local_rank", type=int, default=0) 102 | args = parser.parse_args() 103 | if "LOCAL_RANK" not in os.environ: 104 | os.environ["LOCAL_RANK"] = str(args.local_rank) 105 | 106 | if args.options and args.eval_options: 107 | raise ValueError( 108 | "--options and --eval-options cannot be both " 109 | "specified, --options is deprecated in favor of --eval-options" 110 | ) 111 | if args.options: 112 | warnings.warn("--options is deprecated in favor of --eval-options") 113 | args.eval_options = args.options 114 | return args 115 | 116 | 117 | def main(): 118 | args = parse_args() 119 | 120 | assert args.out or args.eval or args.format_only or args.show or args.show_dir, ( 121 | "Please specify at least one operation (save/eval/format/show the " 122 | 'results / save the results) with the argument "--out", "--eval"' 123 | ', "--format-only", "--show" or "--show-dir"' 124 | ) 125 | 126 | if args.eval and args.format_only: 127 | raise ValueError("--eval and --format_only cannot be both specified") 128 | 129 | if args.out is not None and not args.out.endswith((".pkl", ".pickle")): 130 | raise ValueError("The output file must be a pkl file.") 131 | 132 | cfg = Config.fromfile(args.config) 133 | if args.cfg_options is not None: 134 | cfg.merge_from_dict(args.cfg_options) 135 | # import modules from string list. 
136 | if cfg.get("custom_imports", None): 137 | from mmcv.utils import import_modules_from_strings 138 | 139 | import_modules_from_strings(**cfg["custom_imports"]) 140 | # set cudnn_benchmark 141 | if cfg.get("cudnn_benchmark", False): 142 | torch.backends.cudnn.benchmark = True 143 | # fix issue mentioned in https://github.com/microsoft/SoftTeacher/issues/111 144 | if "pretrained" in cfg.model: 145 | cfg.model.pretrained = None 146 | if cfg.model.get("neck"): 147 | if isinstance(cfg.model.neck, list): 148 | for neck_cfg in cfg.model.neck: 149 | if neck_cfg.get("rfp_backbone"): 150 | if neck_cfg.rfp_backbone.get("pretrained"): 151 | neck_cfg.rfp_backbone.pretrained = None 152 | elif cfg.model.neck.get("rfp_backbone"): 153 | if cfg.model.neck.rfp_backbone.get("pretrained"): 154 | cfg.model.neck.rfp_backbone.pretrained = None 155 | 156 | # in case the test dataset is concatenated 157 | samples_per_gpu = 1 158 | if isinstance(cfg.data.test, dict): 159 | cfg.data.test.test_mode = True 160 | samples_per_gpu = cfg.data.test.pop("samples_per_gpu", 1) 161 | if samples_per_gpu > 1: 162 | # Replace 'ImageToTensor' with 'DefaultFormatBundle' 163 | cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) 164 | elif isinstance(cfg.data.test, list): 165 | for ds_cfg in cfg.data.test: 166 | ds_cfg.test_mode = True 167 | samples_per_gpu = max( 168 | [ds_cfg.pop("samples_per_gpu", 1) for ds_cfg in cfg.data.test] 169 | ) 170 | if samples_per_gpu > 1: 171 | for ds_cfg in cfg.data.test: 172 | ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline) 173 | 174 | # init distributed env first, since logger depends on the dist info. 175 | if args.launcher == "none": 176 | distributed = False 177 | else: 178 | distributed = True 179 | init_dist(args.launcher, **cfg.dist_params) 180 | 181 | rank, _ = get_dist_info() 182 | # a work_dir is only created (and metrics dumped) when --work-dir is given 183 | if args.work_dir is not None and rank == 0: 184 | cfg.work_dir = args.work_dir 185 | mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) 186 | timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 187 | json_file = osp.join(args.work_dir, f"eval_{timestamp}.json") 188 | elif cfg.get("work_dir", None) is None: 189 | cfg.work_dir = osp.join( 190 | "./work_dirs", osp.splitext(osp.basename(args.config))[0] 191 | ) 192 | cfg = patch_config(cfg) 193 | # build the dataloader 194 | dataset = build_dataset(cfg.data.test) 195 | data_loader = build_dataloader( 196 | dataset, 197 | samples_per_gpu=samples_per_gpu, 198 | workers_per_gpu=cfg.data.workers_per_gpu, 199 | dist=distributed, 200 | shuffle=False, 201 | ) 202 | 203 | # build the model and load checkpoint 204 | cfg.model.train_cfg = None 205 | model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg")) 206 | fp16_cfg = cfg.get("fp16", None) 207 | if fp16_cfg is not None: 208 | wrap_fp16_model(model) 209 | checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu") 210 | if args.fuse_conv_bn: 211 | model = fuse_conv_bn(model) 212 | # old versions did not save class info in checkpoints; this workaround is 213 | # for backward compatibility 214 | if "CLASSES" in checkpoint.get("meta", {}): 215 | model.CLASSES = checkpoint["meta"]["CLASSES"] 216 | else: 217 | model.CLASSES = dataset.CLASSES 218 | 219 | if not distributed: 220 | model = MMDataParallel(model, device_ids=[0]) 221 | outputs = single_gpu_test( 222 | model, data_loader, args.show, args.show_dir, args.show_score_thr 223 | ) 224 | else: 225 | model = MMDistributedDataParallel( 226 | model.cuda(), 227 | 
device_ids=[torch.cuda.current_device()], 228 | broadcast_buffers=False, 229 | ) 230 | outputs = multi_gpu_test(model, data_loader, args.tmpdir, args.gpu_collect) 231 | 232 | rank, _ = get_dist_info() 233 | if rank == 0: 234 | if args.out: 235 | print(f"\nwriting results to {args.out}") 236 | mmcv.dump(outputs, args.out) 237 | kwargs = {} if args.eval_options is None else args.eval_options 238 | if args.format_only: 239 | dataset.format_results(outputs, **kwargs) 240 | if args.eval: 241 | eval_kwargs = cfg.get("evaluation", {}).copy() 242 | # strip EvalHook-only keys before calling dataset.evaluate() 243 | for key in [ 244 | "type", 245 | "interval", 246 | "tmpdir", 247 | "start", 248 | "gpu_collect", 249 | "save_best", 250 | "rule", 251 | ]: 252 | eval_kwargs.pop(key, None) 253 | eval_kwargs.update(dict(metric=args.eval, **kwargs)) 254 | metric = dataset.evaluate(outputs, **eval_kwargs) 255 | print(metric) 256 | metric_dict = dict(config=args.config, metric=metric) 257 | if args.work_dir is not None and rank == 0: 258 | mmcv.dump(metric_dict, json_file) 259 | 260 | 261 | if __name__ == "__main__": 262 | main() 263 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | import os.path as osp 5 | import time 6 | import warnings 7 | 8 | 9 | import mmcv 10 | import torch 11 | from mmcv import Config, DictAction 12 | from mmcv.runner import get_dist_info, init_dist 13 | from mmcv.utils import get_git_hash 14 | from mmdet import __version__ 15 | from mmdet.models import build_detector 16 | from mmdet.utils import collect_env 17 | 18 | from ssod.apis import get_root_logger, set_random_seed, train_detector 19 | from ssod.datasets import build_dataset 20 | from ssod.utils import patch_config 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="Train a detector") 25 | parser.add_argument("config", help="train config file path") 26 | parser.add_argument("--work-dir", help="the dir to save logs and models") 27 | parser.add_argument("--resume-from", help="the checkpoint file to resume from") 28 | parser.add_argument( 29 | "--no-validate", 30 | action="store_true", 31 | help="do not evaluate the checkpoint during training", 32 | ) 33 | group_gpus = parser.add_mutually_exclusive_group() 34 | group_gpus.add_argument( 35 | "--gpus", 36 | type=int, 37 | help="number of gpus to use " "(only applicable to non-distributed training)", 38 | ) 39 | group_gpus.add_argument( 40 | "--gpu-ids", 41 | type=int, 42 | nargs="+", 43 | help="ids of gpus to use " "(only applicable to non-distributed training)", 44 | ) 45 | parser.add_argument("--seed", type=int, default=None, help="random seed") 46 | parser.add_argument( 47 | "--deterministic", 48 | action="store_true", 49 | help="whether to set deterministic options for CUDNN backend.", 50 | ) 51 | parser.add_argument( 52 | "--options", 53 | nargs="+", 54 | action=DictAction, 55 | help="override some settings in the used config, the key-value pair " 56 | "in xxx=yyy format will be merged into config file (deprecated), " 57 | "change to --cfg-options instead.", 58 | ) 59 | parser.add_argument( 60 | "--cfg-options", 61 | nargs="+", 62 | action=DictAction, 63 | help="override some settings in the used config, the key-value pair " 64 | "in xxx=yyy format will be merged into config file. 
If the value to " 65 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 66 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 67 | "Note that the quotation marks are necessary and that no white space " 68 | "is allowed.", 69 | ) 70 | parser.add_argument( 71 | "--launcher", 72 | choices=["none", "pytorch", "slurm", "mpi"], 73 | default="none", 74 | help="job launcher", 75 | ) 76 | parser.add_argument("--local_rank", type=int, default=0) 77 | args = parser.parse_args() 78 | if "LOCAL_RANK" not in os.environ: 79 | os.environ["LOCAL_RANK"] = str(args.local_rank) 80 | 81 | if args.options and args.cfg_options: 82 | raise ValueError( 83 | "--options and --cfg-options cannot be both " 84 | "specified, --options is deprecated in favor of --cfg-options" 85 | ) 86 | if args.options: 87 | warnings.warn("--options is deprecated in favor of --cfg-options") 88 | args.cfg_options = args.options 89 | 90 | return args 91 | 92 | 93 | def main(): 94 | args = parse_args() 95 | 96 | cfg = Config.fromfile(args.config) 97 | if args.cfg_options is not None: 98 | cfg.merge_from_dict(args.cfg_options) 99 | # import modules from string list. 100 | if cfg.get("custom_imports", None): 101 | from mmcv.utils import import_modules_from_strings 102 | 103 | import_modules_from_strings(**cfg["custom_imports"]) 104 | # set cudnn_benchmark 105 | if cfg.get("cudnn_benchmark", False): 106 | torch.backends.cudnn.benchmark = True 107 | 108 | # work_dir is determined in this priority: CLI > segment in file > filename 109 | if args.work_dir is not None: 110 | # update configs according to CLI args if args.work_dir is not None 111 | cfg.work_dir = args.work_dir 112 | elif cfg.get("work_dir", None) is None: 113 | # use config filename as default work_dir if cfg.work_dir is None 114 | cfg.work_dir = osp.join( 115 | "./work_dirs", osp.splitext(osp.basename(args.config))[0] 116 | ) 117 | cfg = patch_config(cfg) 118 | if args.resume_from is not None: 119 | cfg.resume_from = args.resume_from 120 | if args.gpu_ids is not None: 121 | cfg.gpu_ids = args.gpu_ids 122 | else: 123 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 124 | 125 | # init distributed env first, since logger depends on the dist info. 
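# (In this repo that path is exercised by tools/dist_train.sh, which runs this script under torch.distributed.launch with --launcher pytorch; see that script above.)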
126 | if args.launcher == "none": 127 | distributed = False 128 | else: 129 | distributed = True 130 | init_dist(args.launcher, **cfg.dist_params) 131 | # re-set gpu_ids with distributed training mode 132 | _, world_size = get_dist_info() 133 | cfg.gpu_ids = range(world_size) 134 | # create work_dir 135 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 136 | # dump config 137 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 138 | # init the logger before other steps 139 | timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 140 | log_file = osp.join(cfg.work_dir, f"{timestamp}.log") 141 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 142 | 143 | # init the meta dict to record some important information such as 144 | # environment info and seed, which will be logged 145 | meta = dict() 146 | # log env info 147 | env_info_dict = collect_env() 148 | env_info = "\n".join([(f"{k}: {v}") for k, v in env_info_dict.items()]) 149 | dash_line = "-" * 60 + "\n" 150 | 151 | logger.info("Environment info:\n" + dash_line + env_info + "\n" + dash_line) 152 | meta["env_info"] = env_info 153 | meta["config"] = cfg.pretty_text 154 | # log some basic info 155 | logger.info(f"Distributed training: {distributed}") 156 | logger.info(f"Config:\n{cfg.pretty_text}") 157 | 158 | # set random seeds 159 | if args.seed is not None: 160 | logger.info( 161 | f"Set random seed to {args.seed}, " f"deterministic: {args.deterministic}" 162 | ) 163 | set_random_seed(args.seed, deterministic=args.deterministic) 164 | cfg.seed = args.seed 165 | meta["seed"] = args.seed 166 | meta["exp_name"] = osp.basename(args.config) 167 | 168 | model = build_detector( 169 | cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg") 170 | ) 171 | model.init_weights() 172 | 173 | datasets = [build_dataset(cfg.data.train)] 174 | if len(cfg.workflow) == 2: 175 | val_dataset = copy.deepcopy(cfg.data.val) 176 | val_dataset.pipeline = cfg.data.train.pipeline 177 | datasets.append(build_dataset(val_dataset)) 178 | if cfg.checkpoint_config is not None: 179 | # save mmdet version, config file content and class names in 180 | # checkpoints as meta data 181 | cfg.checkpoint_config.meta = dict( 182 | mmdet_version=__version__ + get_git_hash()[:7], CLASSES=datasets[0].CLASSES 183 | ) 184 | # add an attribute for visualization convenience 185 | model.CLASSES = datasets[0].CLASSES 186 | train_detector( 187 | model, 188 | datasets, 189 | cfg, 190 | distributed=distributed, 191 | validate=(not args.no_validate), 192 | timestamp=timestamp, 193 | meta=meta, 194 | ) 195 | 196 | 197 | if __name__ == "__main__": 198 | main() 199 | --------------------------------------------------------------------------------