├── .github └── ISSUE_TEMPLATE │ ├── bug-report-chinese.yml │ ├── bug-report-english.yml │ ├── feature-request-chinese.yml │ ├── feature-request-english.yml │ ├── question-chinese.yml │ └── question-english.yml ├── .gitignore ├── LICENSE ├── README.md ├── configs ├── dab_def_detr_pp │ └── dab_def_detr_pp_resnet50_800_1333.py ├── deformable_detr_pp │ └── def_detr_pp_resnet_800_1333.py ├── dino_pp │ └── dino_pp_resnet50_800_1333.py ├── dn_def_detr_pp │ └── dn_def_detr_pp_resnet50_800_1333.py ├── relation_detr │ ├── relation_detr_convnext_l_800_1333.py │ ├── relation_detr_focalnet_large_lrf_fl4_1200_2000.py │ ├── relation_detr_focalnet_large_lrf_fl4_800_1333.py │ ├── relation_detr_resnet50_800_1333.py │ └── relation_detr_swin_l_800_1333.py └── train_config.py ├── datasets └── coco.py ├── images └── convergence_curve.png ├── inference.ipynb ├── inference.py ├── main.py ├── models ├── backbones │ ├── base_backbone.py │ ├── convnext.py │ ├── focalnet.py │ ├── resnet.py │ ├── swin.py │ └── vit.py ├── bricks │ ├── base_transformer.py │ ├── basic.py │ ├── dab_transformer.py │ ├── deform_conv2d_pack.py │ ├── deformable_transformer.py │ ├── denoising.py │ ├── dino_transformer.py │ ├── dn_transformer.py │ ├── losses.py │ ├── misc.py │ ├── ms_deform_attn.py │ ├── ops │ │ ├── __init__.py │ │ └── cuda │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ └── ms_deform_im2col_cuda.cuh │ ├── position_encoding.py │ ├── post_process.py │ ├── relation_transformer.py │ └── set_criterion.py ├── detectors │ ├── base_detector.py │ ├── dab_deformable_detr.py │ ├── deformable_detr.py │ ├── dino.py │ ├── dn_deformable_detr.py │ └── relation_detr.py ├── matcher │ └── hungarian_matcher.py └── necks │ └── channel_mapper.py ├── optimizer └── param_dict.py ├── requirements.txt ├── test.py ├── tools ├── benchmark_model.py ├── pytorch2onnx.py └── visualize_datasets.py ├── transforms ├── __init__.py ├── _functional_pil.py ├── _functional_tensor.py ├── _functional_video.py ├── _presets.py ├── _transforms_video.py ├── _utils.py ├── album_transform.py ├── albumentations_warpper.py ├── autoaugment.py ├── convert_coco_polys_to_mask.py ├── crop.py ├── functional.py ├── functional_pil.py ├── functional_tensor.py ├── mix_transform.py ├── presets.py ├── simple_copy_paste.py ├── transforms.py ├── utils.py └── v2 │ ├── __init__.py │ ├── _augment.py │ ├── _auto_augment.py │ ├── _color.py │ ├── _container.py │ ├── _deprecated.py │ ├── _geometry.py │ ├── _meta.py │ ├── _misc.py │ ├── _temporal.py │ ├── _transform.py │ ├── _type_conversion.py │ ├── _utils.py │ ├── functional │ ├── __init__.py │ ├── _augment.py │ ├── _color.py │ ├── _deprecated.py │ ├── _geometry.py │ ├── _meta.py │ ├── _misc.py │ ├── _temporal.py │ ├── _type_conversion.py │ └── _utils.py │ └── utils.py ├── util ├── coco_eval.py ├── coco_utils.py ├── collate_fn.py ├── collect_env.py ├── datapoints.py ├── engine.py ├── group_by_aspect_ratio.py ├── lazy_load.py ├── logger.py ├── misc.py ├── tune_mode_convbn.py ├── utils.py └── visualize.py └── visualization └── mc_distribution.ipynb /.github/ISSUE_TEMPLATE/bug-report-chinese.yml: -------------------------------------------------------------------------------- 1 | name: 提交Bug报告 (中文) 2 | description: 提交Bug来帮助我们改进代码. 
3 | title: "[Bug]: " 4 | labels: ["bug", "triage"] 5 | body: 6 | 7 | - type: textarea 8 | attributes: 9 | label: Bug 10 | description: 提供和bug相关的带有错误消息的输出或屏幕截图。 11 | placeholder: | 12 | 请点击本仓库主页右上角的star以支持本项目。 13 | 14 | 请提供尽可能多的信息(屏幕截图、日志、命令等),以帮助我们解决问题。 15 | validations: 16 | required: true 17 | 18 | - type: textarea 19 | attributes: 20 | label: 环境信息 21 | description: 请提供出现该错误时的您的软件和硬件环境信息。 22 | placeholder: | 23 | 将命令行输出/日志最开始的环境信息粘贴到此处,类似这样: 24 | ``` 25 | sys.platform linux 26 | Python 3.8.18 | packaged by conda-forge | (default, Dec 23 2023, 17:21:28) [GCC 12.3.0] 27 | numpy 1.24.4 28 | PyTorch 1.12.1+cu113 @/home/ubuntu22/anaconda3/envs/sl/lib/python3.8/site-packages/torch 29 | PyTorch debug build False 30 | torch._C._GLIBCXX_USE_CXX11_ABI False 31 | GPU available Yes 32 | GPU 0 NVIDIA GeForce RTX 3090 (arch=8.6) 33 | ``` 34 | validations: 35 | required: false 36 | 37 | - type: textarea 38 | attributes: 39 | label: 补充信息 40 | description: 其他您想补充的信息请在此处填写。 41 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report-english.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report (English) 2 | description: Report a bug to help us improve the code. 3 | title: "[Bug]: " 4 | labels: ["bug", "triage"] 5 | body: 6 | 7 | - type: textarea 8 | attributes: 9 | label: Bug 10 | description: Provide console output with error messages and/or screenshots of the bug. 11 | placeholder: | 12 | Star the repo to help more people discover this project. 13 | 14 | Include as much information as possible (screenshots, logs, tracebacks etc.) to receive the most helpful response. 15 | validations: 16 | required: true 17 | 18 | - type: textarea 19 | attributes: 20 | label: Environment 21 | description: Please specify the software and hardware you used to produce the bug. 22 | placeholder: | 23 | Paste environment information from the beginning of console output, i.e.: 24 | ``` 25 | sys.platform linux 26 | Python 3.8.18 | packaged by conda-forge | (default, Dec 23 2023, 17:21:28) [GCC 12.3.0] 27 | numpy 1.24.4 28 | PyTorch 1.12.1+cu113 @/home/ubuntu22/anaconda3/envs/sl/lib/python3.8/site-packages/torch 29 | PyTorch debug build False 30 | torch._C._GLIBCXX_USE_CXX11_ABI False 31 | GPU available Yes 32 | GPU 0 NVIDIA GeForce RTX 3090 (arch=8.6) 33 | ``` 34 | validations: 35 | required: false 36 | 37 | - type: textarea 38 | attributes: 39 | label: Additional 40 | description: Anything else you would like to share? 41 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request-chinese.yml: -------------------------------------------------------------------------------- 1 | name: 功能请求 (中文) 2 | description: 增加新功能的建议 3 | # title: "[Bug]: " 4 | labels: ["enhancement"] 5 | body: 6 | 7 | - type: textarea 8 | attributes: 9 | label: 功能描述 10 | description: 简短地描述您想要的功能。 11 | placeholder: | 12 | 请点击本仓库主页右上角的star以支持本项目。 13 | 14 | 想让本项目增加什么新功能? 15 | validations: 16 | required: true 17 | 18 | - type: textarea 19 | attributes: 20 | label: 用例 21 | description: 为您想要的功能提供一个例子,这会帮助我们理解和优先实现功能。 22 | placeholder: | 23 | 如何使用这个功能,谁会使用它? 
24 | validations: 25 | required: false 26 | 27 | - type: textarea 28 | attributes: 29 | label: 补充信息 30 | description: 其他您想补充的信息请在此处填写。 31 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request-english.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request (English) 2 | description: Do you want a new feature? 3 | # title: "[Bug]: " 4 | labels: ["enhancement"] 5 | body: 6 | 7 | - type: textarea 8 | attributes: 9 | label: Description 10 | description: A short description of your feature. 11 | placeholder: | 12 | Star the repo to help more people discover this project. 13 | 14 | What new feature would you like to see in this project? 15 | validations: 16 | required: true 17 | 18 | - type: textarea 19 | attributes: 20 | label: Use case 21 | description: | 22 | Describe the use case of your feature request. It will help us understand and prioritize the feature request. 23 | placeholder: | 24 | How would this feature be used, and who would use it? 25 | 26 | - type: textarea 27 | attributes: 28 | label: Additional 29 | description: Anything else you would like to share? 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question-chinese.yml: -------------------------------------------------------------------------------- 1 | name: 提问问题 (中文) 2 | description: 提问和本项目相关的问题. 3 | # title: "[Bug]: " 4 | labels: ["question"] 5 | body: 6 | 7 | - type: textarea 8 | attributes: 9 | label: Question 10 | description: 请描述您的问题 11 | placeholder: | 12 | 请点击本仓库主页右上角的star以支持本项目。 13 | 14 | 请提供尽可能多的信息(屏幕截图、日志、命令等),以帮助我们进行回复。 15 | validations: 16 | required: true 17 | 18 | - type: textarea 19 | attributes: 20 | label: 补充信息 21 | description: 其他您想补充的信息请在此处填写。 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question-english.yml: -------------------------------------------------------------------------------- 1 | name: Question (English) 2 | description: Ask a question about the project. 3 | # title: "[Bug]: " 4 | labels: ["question"] 5 | body: 6 | 7 | - type: textarea 8 | attributes: 9 | label: Question 10 | description: What is your question? 11 | placeholder: | 12 | Star the repo to help more people discover this project. 13 | 14 | Include as much information as possible (screenshots, logs, tracebacks etc.) to receive the most helpful response. 15 | validations: 16 | required: true 17 | 18 | - type: textarea 19 | attributes: 20 | label: Additional 21 | description: Anything else you would like to share? 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | 165 | # project-related 166 | checkpoints/ -------------------------------------------------------------------------------- /configs/dab_def_detr_pp/dab_def_detr_pp_resnet50_800_1333.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision.ops import FrozenBatchNorm2d 3 | 4 | from models.backbones.resnet import ResNetBackbone 5 | from models.bricks.dab_transformer import ( 6 | DabTransformer, 7 | DabTransformerDecoder, 8 | DabTransformerDecoderLayer, 9 | DabTransformerEncoder, 10 | DabTransformerEncoderLayer, 11 | ) 12 | from models.bricks.position_encoding import PositionEmbeddingSine 13 | from models.bricks.post_process import PostProcess 14 | from models.bricks.set_criterion import SetCriterion 15 | from models.detectors.dab_deformable_detr import DabDeformableDETR 16 | from models.matcher.hungarian_matcher import HungarianMatcher 17 | from models.necks.channel_mapper import ChannelMapper 18 | 19 | # most changed parameters 20 | embed_dim = 256 21 | num_classes = 91 22 | num_queries = 300 23 | num_feature_levels = 4 24 | transformer_enc_layers = 6 25 | transformer_dec_layers = 6 26 | num_heads = 8 27 | dim_feedforward = 2048 28 | 29 | # instantiate model components 30 | position_embedding = PositionEmbeddingSine(embed_dim // 2, temperature=10000, normalize=True, offset=-0.5) 31 | 32 | backbone = ResNetBackbone( 33 | "resnet50", norm_layer=FrozenBatchNorm2d, return_indices=(1, 2, 3), freeze_indices=(0,) 34 | ) 35 | 36 | neck = ChannelMapper( 37 | in_channels=backbone.num_channels, 38 | out_channels=embed_dim, 39 | num_outs=num_feature_levels, 40 | ) 41 | 42 | transformer = DabTransformer( 43 | encoder=DabTransformerEncoder( 44 | encoder_layer=DabTransformerEncoderLayer( 45 | embed_dim=embed_dim, 46 | n_heads=num_heads, 47 | dropout=0.0, 48 | activation=nn.ReLU(inplace=True), 49 | n_levels=num_feature_levels, 50 | n_points=4, 51 | d_ffn=dim_feedforward, 52 | ), 53 | num_layers=transformer_enc_layers, 54 | ), 55 | decoder=DabTransformerDecoder( 56 | decoder_layer=DabTransformerDecoderLayer( 57 | embed_dim=embed_dim, 58 | n_heads=num_heads, 59 | dropout=0.0, 60 | activation=nn.ReLU(inplace=True), 61 | n_levels=num_feature_levels, 62 | n_points=4, 63 | d_ffn=dim_feedforward, 64 | ), 65 | num_layers=transformer_dec_layers, 66 | num_classes=num_classes, 67 | ), 68 | num_classes=num_classes, 69 | num_feature_levels=num_feature_levels, 70 | two_stage_num_proposals=num_queries, 71 | ) 72 | 73 | matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2, focal_alpha=0.25, focal_gamma=2.0) 74 | 75 | weight_dict = {"loss_class": 1, "loss_bbox": 5, "loss_giou": 2} 76 | aux_weight_dict = {} 77 | for i in range(transformer.decoder.num_layers - 1): 78 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 79 | weight_dict.update(aux_weight_dict) 80 | weight_dict.update({"loss_class_enc": 1, "loss_bbox_enc": 5, "loss_giou_enc": 2}) 81 | 82 | criterion = SetCriterion( 83 | num_classes=num_classes, matcher=matcher, weight_dict=weight_dict, alpha=0.25, gamma=2.0 84 | ) 85 | postprocessor = PostProcess(select_box_nums_for_evaluation=300) 86 | 87 | # combine above components to instantiate the model 88 | model = DabDeformableDETR( 89 | backbone=backbone, 90 | neck=neck, 91 | position_embedding=position_embedding, 92 | transformer=transformer, 93 | criterion=criterion, 94 | postprocessor=postprocessor, 95 | num_classes=num_classes, 96 | min_size=800, 97 | max_size=1333, 98 | ) 99 | 
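# A minimal usage sketch for the config above. Assumptions: it is run from the
# repository root, and the checkpoint path below is a hypothetical placeholder.
# The config file builds the complete detector as a module-level `model` object,
# so it can be loaded the same way inference.py does via util.lazy_load.Config.
from util.lazy_load import Config
from util.utils import load_checkpoint, load_state_dict

config = Config("configs/dab_def_detr_pp/dab_def_detr_pp_resnet50_800_1333.py")
model = config.model.eval()  # the config instantiates DabDeformableDETR directly

# Optionally restore trained weights, using the same calls as inference.py.
checkpoint = load_checkpoint("checkpoints/dab_def_detr_pp/best_ap.pth")  # hypothetical path
if isinstance(checkpoint, dict) and "model" in checkpoint:
    checkpoint = checkpoint["model"]
load_state_dict(model, checkpoint)

# Quick sanity check on the assembled model.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {num_params / 1e6:.1f}M")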
-------------------------------------------------------------------------------- /configs/deformable_detr_pp/def_detr_pp_resnet_800_1333.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision.ops import FrozenBatchNorm2d 3 | 4 | from models.backbones.resnet import ResNetBackbone 5 | from models.bricks.deformable_transformer import ( 6 | DeformableTransformer, 7 | DeformableTransformerDecoder, 8 | DeformableTransformerDecoderLayer, 9 | DeformableTransformerEncoder, 10 | DeformableTransformerEncoderLayer, 11 | ) 12 | from models.bricks.position_encoding import PositionEmbeddingSine 13 | from models.bricks.post_process import PostProcess 14 | from models.bricks.set_criterion import SetCriterion 15 | from models.detectors.deformable_detr import DeformableDETR 16 | from models.matcher.hungarian_matcher import HungarianMatcher 17 | from models.necks.channel_mapper import ChannelMapper 18 | 19 | # mostly changed parameters 20 | embed_dim = 256 21 | num_classes = 91 22 | num_queries = 300 23 | num_feature_levels = 4 24 | transformer_enc_layers = 6 25 | transformer_dec_layers = 6 26 | num_heads = 8 27 | dim_feedforward = 2048 28 | 29 | # instantiate model components 30 | position_embedding = PositionEmbeddingSine( 31 | embed_dim // 2, temperature=10000, normalize=True, offset=-0.5 32 | ) 33 | 34 | backbone = ResNetBackbone( 35 | "resnet50", norm_layer=FrozenBatchNorm2d, return_indices=(1, 2, 3), freeze_indices=(0,) 36 | ) 37 | 38 | neck = ChannelMapper( 39 | in_channels=backbone.num_channels, 40 | out_channels=embed_dim, 41 | num_outs=num_feature_levels, 42 | ) 43 | 44 | transformer = DeformableTransformer( 45 | encoder=DeformableTransformerEncoder( 46 | encoder_layer=DeformableTransformerEncoderLayer( 47 | embed_dim=embed_dim, 48 | n_heads=num_heads, 49 | dropout=0.0, 50 | activation=nn.ReLU(inplace=True), 51 | n_levels=num_feature_levels, 52 | n_points=4, 53 | d_ffn=dim_feedforward, 54 | ), 55 | num_layers=transformer_enc_layers, 56 | ), 57 | decoder=DeformableTransformerDecoder( 58 | decoder_layer=DeformableTransformerDecoderLayer( 59 | embed_dim=embed_dim, 60 | n_heads=num_heads, 61 | dropout=0.0, 62 | activation=nn.ReLU(inplace=True), 63 | n_levels=num_feature_levels, 64 | n_points=4, 65 | d_ffn=dim_feedforward, 66 | ), 67 | num_layers=transformer_dec_layers, 68 | num_classes=num_classes, 69 | ), 70 | num_classes=num_classes, 71 | num_feature_levels=num_feature_levels, 72 | two_stage_num_proposals=num_queries, 73 | ) 74 | 75 | matcher = HungarianMatcher( 76 | cost_class=2, cost_bbox=5, cost_giou=2, focal_alpha=0.25, focal_gamma=2.0 77 | ) 78 | 79 | # construct weight_dic 80 | weight_dict = {"loss_class": 1, "loss_bbox": 5, "loss_giou": 2} 81 | aux_weight_dict = {} 82 | for i in range(transformer.decoder.num_layers - 1): 83 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 84 | weight_dict.update(aux_weight_dict) 85 | weight_dict.update({"loss_class_enc": 1, "loss_bbox_enc": 5, "loss_giou_enc": 2}) 86 | 87 | criterion = SetCriterion( 88 | num_classes=num_classes, 89 | matcher=matcher, 90 | weight_dict=weight_dict, 91 | alpha=0.25, 92 | gamma=2.0, 93 | two_stage_binary_cls=True, 94 | ) 95 | postprocessor = PostProcess(select_box_nums_for_evaluation=300) 96 | 97 | # combine above components to instantiate the model 98 | model = DeformableDETR( 99 | backbone=backbone, 100 | neck=neck, 101 | position_embedding=position_embedding, 102 | transformer=transformer, 103 | criterion=criterion, 104 | 
postprocessor=postprocessor, 105 | num_classes=num_classes, 106 | min_size=800, 107 | max_size=1333, 108 | ) 109 | -------------------------------------------------------------------------------- /configs/dino_pp/dino_pp_resnet50_800_1333.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision.ops import FrozenBatchNorm2d 3 | 4 | from models.backbones.resnet import ResNetBackbone 5 | from models.bricks.dino_transformer import ( 6 | DINOTransformer, 7 | DINOTransformerDecoder, 8 | DINOTransformerDecoderLayer, 9 | DINOTransformerEncoder, 10 | DINOTransformerEncoderLayer, 11 | ) 12 | from models.bricks.position_encoding import PositionEmbeddingSine 13 | from models.bricks.post_process import PostProcess 14 | from models.bricks.set_criterion import SetCriterion 15 | from models.detectors.dino import DINO 16 | from models.matcher.hungarian_matcher import HungarianMatcher 17 | from models.necks.channel_mapper import ChannelMapper 18 | 19 | # mostly changed parameters 20 | embed_dim = 256 21 | num_classes = 91 22 | num_queries = 900 23 | num_feature_levels = 4 24 | transformer_enc_layers = 6 25 | transformer_dec_layers = 6 26 | num_heads = 8 27 | dim_feedforward = 2048 28 | 29 | # instantiate model components 30 | position_embedding = PositionEmbeddingSine(embed_dim // 2, temperature=10000, normalize=True, offset=-0.5) 31 | 32 | backbone = ResNetBackbone( 33 | arch="resnet50", norm_layer=FrozenBatchNorm2d, return_indices=(1, 2, 3), freeze_indices=(0,) 34 | ) 35 | 36 | neck = ChannelMapper( 37 | in_channels=backbone.num_channels, 38 | out_channels=embed_dim, 39 | num_outs=num_feature_levels, 40 | ) 41 | 42 | transformer = DINOTransformer( 43 | encoder=DINOTransformerEncoder( 44 | encoder_layer=DINOTransformerEncoderLayer( 45 | embed_dim=embed_dim, 46 | n_heads=num_heads, 47 | dropout=0.0, 48 | activation=nn.ReLU(inplace=True), 49 | n_levels=num_feature_levels, 50 | n_points=4, 51 | d_ffn=dim_feedforward, 52 | ), 53 | num_layers=transformer_enc_layers, 54 | ), 55 | decoder=DINOTransformerDecoder( 56 | decoder_layer=DINOTransformerDecoderLayer( 57 | embed_dim=embed_dim, 58 | n_heads=num_heads, 59 | dropout=0.0, 60 | activation=nn.ReLU(inplace=True), 61 | n_levels=num_feature_levels, 62 | n_points=4, 63 | d_ffn=dim_feedforward, 64 | ), 65 | num_layers=transformer_dec_layers, 66 | num_classes=num_classes, 67 | ), 68 | num_classes=num_classes, 69 | num_feature_levels=num_feature_levels, 70 | two_stage_num_proposals=num_queries, 71 | ) 72 | 73 | matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2, focal_alpha=0.25, focal_gamma=2.0) 74 | 75 | weight_dict = {"loss_class": 1, "loss_bbox": 5, "loss_giou": 2} 76 | weight_dict.update({"loss_class_dn": 1, "loss_bbox_dn": 5, "loss_giou_dn": 2}) 77 | weight_dict.update({ 78 | k + f"_{i}": v 79 | for i in range(transformer_dec_layers - 1) 80 | for k, v in weight_dict.items() 81 | }) 82 | weight_dict.update({"loss_class_enc": 1, "loss_bbox_enc": 5, "loss_giou_enc": 2}) 83 | 84 | criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, alpha=0.25, gamma=2.0) 85 | postprocessor = PostProcess(select_box_nums_for_evaluation=300) 86 | 87 | # combine above components to instantiate the model 88 | model = DINO( 89 | backbone=backbone, 90 | neck=neck, 91 | position_embedding=position_embedding, 92 | transformer=transformer, 93 | criterion=criterion, 94 | postprocessor=postprocessor, 95 | num_classes=num_classes, 96 | num_queries=num_queries, 97 | min_size=800, 
98 | max_size=1333, 99 | ) 100 | -------------------------------------------------------------------------------- /configs/dn_def_detr_pp/dn_def_detr_pp_resnet50_800_1333.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision.ops import FrozenBatchNorm2d 3 | 4 | from models.backbones.resnet import ResNetBackbone 5 | from models.bricks.dn_transformer import ( 6 | DNTransformer, 7 | DNTransformerDecoder, 8 | DNTransformerDecoderLayer, 9 | DNTransformerEncoder, 10 | DNTransformerEncoderLayer, 11 | ) 12 | from models.bricks.position_encoding import PositionEmbeddingSine 13 | from models.bricks.post_process import PostProcess 14 | from models.bricks.set_criterion import SetCriterion 15 | from models.detectors.dn_deformable_detr import DNDeformableDETR 16 | from models.matcher.hungarian_matcher import HungarianMatcher 17 | from models.necks.channel_mapper import ChannelMapper 18 | 19 | # most changed parameters 20 | embed_dim = 256 21 | num_classes = 91 22 | num_queries = 300 23 | num_feature_levels = 4 24 | transformer_enc_layers = 6 25 | transformer_dec_layers = 6 26 | num_heads = 8 27 | dim_feedforward = 2048 28 | 29 | # instantiate model components 30 | position_embedding = PositionEmbeddingSine( 31 | embed_dim // 2, temperature=10000, normalize=True, offset=-0.5 32 | ) 33 | 34 | backbone = ResNetBackbone( 35 | "resnet50", norm_layer=FrozenBatchNorm2d, return_indices=(1, 2, 3), freeze_indices=(0,) 36 | ) 37 | 38 | neck = ChannelMapper( 39 | in_channels=backbone.num_channels, 40 | out_channels=embed_dim, 41 | num_outs=num_feature_levels, 42 | ) 43 | 44 | transformer = DNTransformer( 45 | encoder=DNTransformerEncoder( 46 | encoder_layer=DNTransformerEncoderLayer( 47 | embed_dim=embed_dim, 48 | n_heads=num_heads, 49 | dropout=0.0, 50 | activation=nn.ReLU(inplace=True), 51 | n_levels=num_feature_levels, 52 | n_points=4, 53 | d_ffn=dim_feedforward, 54 | ), 55 | num_layers=transformer_enc_layers, 56 | ), 57 | decoder=DNTransformerDecoder( 58 | decoder_layer=DNTransformerDecoderLayer( 59 | embed_dim=embed_dim, 60 | n_heads=num_heads, 61 | dropout=0.0, 62 | activation=nn.ReLU(inplace=True), 63 | n_levels=num_feature_levels, 64 | n_points=4, 65 | d_ffn=dim_feedforward, 66 | ), 67 | num_layers=transformer_dec_layers, 68 | num_classes=num_classes, 69 | ), 70 | num_classes=num_classes, 71 | num_feature_levels=num_feature_levels, 72 | two_stage_num_proposals=num_queries, 73 | ) 74 | 75 | matcher = HungarianMatcher( 76 | cost_class=2, cost_bbox=5, cost_giou=2, focal_alpha=0.25, focal_gamma=2.0 77 | ) 78 | 79 | # construct weight_dict 80 | weight_dict = {"loss_class": 1, "loss_bbox": 5, "loss_giou": 2} 81 | weight_dict.update({"loss_class_dn": 1, "loss_bbox_dn": 5, "loss_giou_dn": 2}) 82 | for i in range(transformer.decoder.num_layers - 1): 83 | weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 84 | criterion = SetCriterion( 85 | num_classes=num_classes, matcher=matcher, weight_dict=weight_dict, alpha=0.25, gamma=2.0 86 | ) 87 | postprocessor = PostProcess(select_box_nums_for_evaluation=300) 88 | 89 | # combine above components to instantiate the model 90 | model = DNDeformableDETR( 91 | backbone=backbone, 92 | neck=neck, 93 | position_embedding=position_embedding, 94 | transformer=transformer, 95 | criterion=criterion, 96 | postprocessor=postprocessor, 97 | num_classes=num_classes, 98 | num_queries=num_queries, 99 | min_size=800, 100 | max_size=1333, 101 | ) 102 | 
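# The weight_dict blocks repeated in these configs expand three base loss weights
# into per-decoder-layer ("_{i}"), denoising ("_dn") and encoder ("_enc") entries
# so SetCriterion can weight every auxiliary output. A short standalone sketch of
# that expansion; the 3-layer decoder is hypothetical (to keep the printout small)
# and the weights 1/5/2 are the values used throughout these configs.
transformer_dec_layers = 3

weight_dict = {"loss_class": 1, "loss_bbox": 5, "loss_giou": 2}
weight_dict.update({"loss_class_dn": 1, "loss_bbox_dn": 5, "loss_giou_dn": 2})

# Collect per-layer copies in a separate dict first (as the deformable and
# relation configs do); updating weight_dict inside the loop would also
# re-suffix keys that were added in earlier iterations.
aux_weight_dict = {}
for i in range(transformer_dec_layers - 1):
    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
weight_dict.update({"loss_class_enc": 1, "loss_bbox_enc": 5, "loss_giou_enc": 2})

# Result: loss_class, loss_class_dn, loss_class_0, loss_class_dn_0, loss_class_1,
# loss_class_dn_1, loss_class_enc (and the same pattern for loss_bbox/loss_giou).
print(sorted(weight_dict.keys()))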
-------------------------------------------------------------------------------- /configs/relation_detr/relation_detr_convnext_l_800_1333.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from models.backbones.convnext import ConvNeXtBackbone 4 | from models.bricks.position_encoding import PositionEmbeddingSine 5 | from models.bricks.post_process import PostProcess 6 | from models.bricks.relation_transformer import ( 7 | RelationTransformer, 8 | RelationTransformerDecoder, 9 | RelationTransformerDecoderLayer, 10 | RelationTransformerEncoder, 11 | RelationTransformerEncoderLayer, 12 | ) 13 | from models.bricks.set_criterion import HybridSetCriterion 14 | from models.detectors.relation_detr import RelationDETR 15 | from models.matcher.hungarian_matcher import HungarianMatcher 16 | from models.necks.channel_mapper import ChannelMapper 17 | 18 | # mostly changed parameters 19 | embed_dim = 256 20 | num_classes = 91 21 | num_queries = 900 22 | hybrid_num_proposals = 1500 23 | hybrid_assign = 6 24 | num_feature_levels = 4 25 | transformer_enc_layers = 6 26 | transformer_dec_layers = 6 27 | num_heads = 8 28 | dim_feedforward = 2048 29 | 30 | # instantiate model components 31 | position_embedding = PositionEmbeddingSine( 32 | embed_dim // 2, temperature=10000, normalize=True, offset=-0.5 33 | ) 34 | 35 | backbone = ConvNeXtBackbone("conv_l", return_indices=(1, 2, 3), freeze_indices=(0,)) 36 | 37 | neck = ChannelMapper( 38 | in_channels=backbone.num_channels, 39 | out_channels=embed_dim, 40 | num_outs=num_feature_levels, 41 | ) 42 | 43 | transformer = RelationTransformer( 44 | encoder=RelationTransformerEncoder( 45 | encoder_layer=RelationTransformerEncoderLayer( 46 | embed_dim=embed_dim, 47 | n_heads=num_heads, 48 | dropout=0.0, 49 | activation=nn.ReLU(inplace=True), 50 | n_levels=num_feature_levels, 51 | n_points=4, 52 | d_ffn=dim_feedforward, 53 | ), 54 | num_layers=transformer_enc_layers, 55 | ), 56 | decoder=RelationTransformerDecoder( 57 | decoder_layer=RelationTransformerDecoderLayer( 58 | embed_dim=embed_dim, 59 | n_heads=num_heads, 60 | dropout=0.0, 61 | activation=nn.ReLU(inplace=True), 62 | n_levels=num_feature_levels, 63 | n_points=4, 64 | d_ffn=dim_feedforward, 65 | ), 66 | num_layers=transformer_dec_layers, 67 | num_classes=num_classes, 68 | ), 69 | num_classes=num_classes, 70 | num_feature_levels=num_feature_levels, 71 | two_stage_num_proposals=num_queries, 72 | hybrid_num_proposals=hybrid_num_proposals, 73 | ) 74 | 75 | matcher = HungarianMatcher( 76 | cost_class=2, cost_bbox=5, cost_giou=2, focal_alpha=0.25, focal_gamma=2.0 77 | ) 78 | 79 | # construct weight_dict for loss 80 | weight_dict = {"loss_class": 1, "loss_bbox": 5, "loss_giou": 2} 81 | weight_dict.update({"loss_class_dn": 1, "loss_bbox_dn": 5, "loss_giou_dn": 2}) 82 | aux_weight_dict = {} 83 | for i in range(transformer.decoder.num_layers - 1): 84 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 85 | weight_dict.update(aux_weight_dict) 86 | weight_dict.update({"loss_class_enc": 1, "loss_bbox_enc": 5, "loss_giou_enc": 2}) 87 | weight_dict.update({k + "_hybrid": v for k, v in weight_dict.items()}) 88 | 89 | criterion = HybridSetCriterion( 90 | num_classes=num_classes, matcher=matcher, weight_dict=weight_dict, alpha=0.25, gamma=2.0 91 | ) 92 | postprocessor = PostProcess(select_box_nums_for_evaluation=300) 93 | 94 | # combine above components to instantiate the model 95 | model = RelationDETR( 96 | backbone=backbone, 97 | neck=neck, 
98 | position_embedding=position_embedding, 99 | transformer=transformer, 100 | criterion=criterion, 101 | postprocessor=postprocessor, 102 | num_classes=num_classes, 103 | num_queries=num_queries, 104 | hybrid_assign=hybrid_assign, 105 | denoising_nums=100, 106 | min_size=800, 107 | max_size=1333, 108 | ) 109 | -------------------------------------------------------------------------------- /configs/relation_detr/relation_detr_focalnet_large_lrf_fl4_1200_2000.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from models.backbones.focalnet import FocalNetBackbone 4 | from models.bricks.position_encoding import PositionEmbeddingSine 5 | from models.bricks.post_process import PostProcess 6 | from models.bricks.relation_transformer import ( 7 | RelationTransformer, 8 | RelationTransformerDecoder, 9 | RelationTransformerEncoder, 10 | RelationTransformerEncoderLayer, 11 | RelationTransformerDecoderLayer, 12 | ) 13 | from models.bricks.set_criterion import HybridSetCriterion 14 | from models.detectors.relation_detr import RelationDETR 15 | from models.matcher.hungarian_matcher import HungarianMatcher 16 | from models.necks.channel_mapper import ChannelMapper 17 | 18 | # mostly changed parameters 19 | embed_dim = 256 20 | num_classes = 91 21 | num_queries = 900 22 | hybrid_num_proposals = 1500 23 | hybrid_assign = 6 24 | num_feature_levels = 5 25 | transformer_enc_layers = 6 26 | transformer_dec_layers = 6 27 | num_heads = 8 28 | dim_feedforward = 2048 29 | 30 | # instantiate model components 31 | position_embedding = PositionEmbeddingSine( 32 | embed_dim // 2, temperature=10000, normalize=True, offset=-0.5 33 | ) 34 | 35 | backbone = FocalNetBackbone("focalnet_large_lrf_fl4", weights=False, return_indices=(0, 1, 2, 3)) 36 | 37 | neck = ChannelMapper(backbone.num_channels, out_channels=embed_dim, num_outs=num_feature_levels) 38 | 39 | transformer = RelationTransformer( 40 | encoder=RelationTransformerEncoder( 41 | encoder_layer=RelationTransformerEncoderLayer( 42 | embed_dim=embed_dim, 43 | n_heads=num_heads, 44 | dropout=0.0, 45 | activation=nn.ReLU(inplace=True), 46 | n_levels=num_feature_levels, 47 | n_points=4, 48 | d_ffn=dim_feedforward, 49 | ), 50 | num_layers=transformer_enc_layers, 51 | ), 52 | decoder=RelationTransformerDecoder( 53 | decoder_layer=RelationTransformerDecoderLayer( 54 | embed_dim=embed_dim, 55 | n_heads=num_heads, 56 | dropout=0.0, 57 | activation=nn.ReLU(inplace=True), 58 | n_levels=num_feature_levels, 59 | n_points=4, 60 | d_ffn=dim_feedforward, 61 | ), 62 | num_layers=transformer_dec_layers, 63 | num_classes=num_classes, 64 | ), 65 | num_classes=num_classes, 66 | num_feature_levels=num_feature_levels, 67 | two_stage_num_proposals=num_queries, 68 | hybrid_num_proposals=hybrid_num_proposals, 69 | ) 70 | 71 | matcher = HungarianMatcher( 72 | cost_class=2, cost_bbox=5, cost_giou=2, focal_alpha=0.25, focal_gamma=2.0 73 | ) 74 | 75 | # construct weight_dict for loss 76 | weight_dict = {"loss_class": 1, "loss_bbox": 5, "loss_giou": 2} 77 | weight_dict.update({"loss_class_dn": 1, "loss_bbox_dn": 5, "loss_giou_dn": 2}) 78 | aux_weight_dict = {} 79 | for i in range(transformer.decoder.num_layers - 1): 80 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 81 | weight_dict.update(aux_weight_dict) 82 | weight_dict.update({"loss_class_enc": 1, "loss_bbox_enc": 5, "loss_giou_enc": 2}) 83 | weight_dict.update({k + "_hybrid": v for k, v in weight_dict.items()}) 84 | 85 | criterion = 
HybridSetCriterion( 86 | num_classes=num_classes, matcher=matcher, weight_dict=weight_dict, alpha=0.25, gamma=2.0 87 | ) 88 | postprocessor = PostProcess(select_box_nums_for_evaluation=300) 89 | 90 | # combine above components to instantiate the model 91 | model = RelationDETR( 92 | backbone=backbone, 93 | neck=neck, 94 | position_embedding=position_embedding, 95 | transformer=transformer, 96 | criterion=criterion, 97 | postprocessor=postprocessor, 98 | num_classes=num_classes, 99 | num_queries=num_queries, 100 | hybrid_assign=hybrid_assign, 101 | denoising_nums=1000, 102 | min_size=1200, 103 | max_size=2000, 104 | ) 105 | -------------------------------------------------------------------------------- /configs/relation_detr/relation_detr_focalnet_large_lrf_fl4_800_1333.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from models.backbones.focalnet import FocalNetBackbone 4 | from models.bricks.position_encoding import PositionEmbeddingSine 5 | from models.bricks.post_process import PostProcess 6 | from models.bricks.relation_transformer import ( 7 | RelationTransformer, 8 | RelationTransformerDecoder, 9 | RelationTransformerEncoder, 10 | RelationTransformerEncoderLayer, 11 | RelationTransformerDecoderLayer, 12 | ) 13 | from models.bricks.set_criterion import HybridSetCriterion 14 | from models.detectors.relation_detr import RelationDETR 15 | from models.matcher.hungarian_matcher import HungarianMatcher 16 | from models.necks.channel_mapper import ChannelMapper 17 | 18 | # mostly changed parameters 19 | embed_dim = 256 20 | num_classes = 91 21 | num_queries = 900 22 | hybrid_num_proposals = 1500 23 | hybrid_assign = 6 24 | num_feature_levels = 5 25 | transformer_enc_layers = 6 26 | transformer_dec_layers = 6 27 | num_heads = 8 28 | dim_feedforward = 2048 29 | 30 | # instantiate model components 31 | position_embedding = PositionEmbeddingSine( 32 | embed_dim // 2, temperature=10000, normalize=True, offset=-0.5 33 | ) 34 | 35 | backbone = FocalNetBackbone("focalnet_large_lrf_fl4", weights=False, return_indices=(0, 1, 2, 3)) 36 | 37 | neck = ChannelMapper(backbone.num_channels, out_channels=embed_dim, num_outs=num_feature_levels) 38 | 39 | transformer = RelationTransformer( 40 | encoder=RelationTransformerEncoder( 41 | encoder_layer=RelationTransformerEncoderLayer( 42 | embed_dim=embed_dim, 43 | n_heads=num_heads, 44 | dropout=0.0, 45 | activation=nn.ReLU(inplace=True), 46 | n_levels=num_feature_levels, 47 | n_points=4, 48 | d_ffn=dim_feedforward, 49 | ), 50 | num_layers=transformer_enc_layers, 51 | ), 52 | decoder=RelationTransformerDecoder( 53 | decoder_layer=RelationTransformerDecoderLayer( 54 | embed_dim=embed_dim, 55 | n_heads=num_heads, 56 | dropout=0.0, 57 | activation=nn.ReLU(inplace=True), 58 | n_levels=num_feature_levels, 59 | n_points=4, 60 | d_ffn=dim_feedforward, 61 | ), 62 | num_layers=transformer_dec_layers, 63 | num_classes=num_classes, 64 | ), 65 | num_classes=num_classes, 66 | num_feature_levels=num_feature_levels, 67 | two_stage_num_proposals=num_queries, 68 | hybrid_num_proposals=hybrid_num_proposals, 69 | ) 70 | 71 | matcher = HungarianMatcher( 72 | cost_class=2, cost_bbox=5, cost_giou=2, focal_alpha=0.25, focal_gamma=2.0 73 | ) 74 | 75 | # construct weight_dict for loss 76 | weight_dict = {"loss_class": 1, "loss_bbox": 5, "loss_giou": 2} 77 | weight_dict.update({"loss_class_dn": 1, "loss_bbox_dn": 5, "loss_giou_dn": 2}) 78 | aux_weight_dict = {} 79 | for i in range(transformer.decoder.num_layers - 1): 
80 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 81 | weight_dict.update(aux_weight_dict) 82 | weight_dict.update({"loss_class_enc": 1, "loss_bbox_enc": 5, "loss_giou_enc": 2}) 83 | weight_dict.update({k + "_hybrid": v for k, v in weight_dict.items()}) 84 | 85 | criterion = HybridSetCriterion( 86 | num_classes=num_classes, matcher=matcher, weight_dict=weight_dict, alpha=0.25, gamma=2.0 87 | ) 88 | postprocessor = PostProcess(select_box_nums_for_evaluation=300) 89 | 90 | # combine above components to instantiate the model 91 | model = RelationDETR( 92 | backbone=backbone, 93 | neck=neck, 94 | position_embedding=position_embedding, 95 | transformer=transformer, 96 | criterion=criterion, 97 | postprocessor=postprocessor, 98 | num_classes=num_classes, 99 | num_queries=num_queries, 100 | hybrid_assign=hybrid_assign, 101 | denoising_nums=100, 102 | min_size=800, 103 | max_size=1333, 104 | ) 105 | -------------------------------------------------------------------------------- /configs/relation_detr/relation_detr_resnet50_800_1333.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from models.backbones.resnet import ResNetBackbone 4 | from models.bricks.misc import FrozenBatchNorm2d 5 | from models.bricks.position_encoding import PositionEmbeddingSine 6 | from models.bricks.post_process import PostProcess 7 | from models.bricks.relation_transformer import ( 8 | RelationTransformer, 9 | RelationTransformerDecoder, 10 | RelationTransformerDecoderLayer, 11 | RelationTransformerEncoder, 12 | RelationTransformerEncoderLayer, 13 | ) 14 | from models.bricks.set_criterion import HybridSetCriterion 15 | from models.detectors.relation_detr import RelationDETR 16 | from models.matcher.hungarian_matcher import HungarianMatcher 17 | from models.necks.channel_mapper import ChannelMapper 18 | 19 | # mostly changed parameters 20 | embed_dim = 256 21 | num_classes = 91 22 | num_queries = 900 23 | hybrid_num_proposals = 1500 24 | hybrid_assign = 6 25 | num_feature_levels = 4 26 | transformer_enc_layers = 6 27 | transformer_dec_layers = 6 28 | num_heads = 8 29 | dim_feedforward = 2048 30 | 31 | # instantiate model components 32 | position_embedding = PositionEmbeddingSine( 33 | embed_dim // 2, temperature=10000, normalize=True, offset=-0.5 34 | ) 35 | 36 | backbone = ResNetBackbone( 37 | "resnet50", norm_layer=FrozenBatchNorm2d, return_indices=(1, 2, 3), freeze_indices=(0,) 38 | ) 39 | 40 | neck = ChannelMapper( 41 | in_channels=backbone.num_channels, 42 | out_channels=embed_dim, 43 | num_outs=num_feature_levels, 44 | ) 45 | 46 | transformer = RelationTransformer( 47 | encoder=RelationTransformerEncoder( 48 | encoder_layer=RelationTransformerEncoderLayer( 49 | embed_dim=embed_dim, 50 | n_heads=num_heads, 51 | dropout=0.0, 52 | activation=nn.ReLU(inplace=True), 53 | n_levels=num_feature_levels, 54 | n_points=4, 55 | d_ffn=dim_feedforward, 56 | ), 57 | num_layers=transformer_enc_layers, 58 | ), 59 | decoder=RelationTransformerDecoder( 60 | decoder_layer=RelationTransformerDecoderLayer( 61 | embed_dim=embed_dim, 62 | n_heads=num_heads, 63 | dropout=0.0, 64 | activation=nn.ReLU(inplace=True), 65 | n_levels=num_feature_levels, 66 | n_points=4, 67 | d_ffn=dim_feedforward, 68 | ), 69 | num_layers=transformer_dec_layers, 70 | num_classes=num_classes, 71 | ), 72 | num_classes=num_classes, 73 | num_feature_levels=num_feature_levels, 74 | two_stage_num_proposals=num_queries, 75 | hybrid_num_proposals=hybrid_num_proposals, 76 | ) 77 
| 78 | matcher = HungarianMatcher( 79 | cost_class=2, cost_bbox=5, cost_giou=2, focal_alpha=0.25, focal_gamma=2.0 80 | ) 81 | 82 | # construct weight_dict for loss 83 | weight_dict = {"loss_class": 1, "loss_bbox": 5, "loss_giou": 2} 84 | weight_dict.update({"loss_class_dn": 1, "loss_bbox_dn": 5, "loss_giou_dn": 2}) 85 | aux_weight_dict = {} 86 | for i in range(transformer.decoder.num_layers - 1): 87 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 88 | weight_dict.update(aux_weight_dict) 89 | weight_dict.update({"loss_class_enc": 1, "loss_bbox_enc": 5, "loss_giou_enc": 2}) 90 | weight_dict.update({k + "_hybrid": v for k, v in weight_dict.items()}) 91 | 92 | criterion = HybridSetCriterion( 93 | num_classes=num_classes, matcher=matcher, weight_dict=weight_dict, alpha=0.25, gamma=2.0 94 | ) 95 | postprocessor = PostProcess(select_box_nums_for_evaluation=300) 96 | 97 | # combine above components to instantiate the model 98 | model = RelationDETR( 99 | backbone=backbone, 100 | neck=neck, 101 | position_embedding=position_embedding, 102 | transformer=transformer, 103 | criterion=criterion, 104 | postprocessor=postprocessor, 105 | num_classes=num_classes, 106 | num_queries=num_queries, 107 | hybrid_assign=hybrid_assign, 108 | denoising_nums=100, 109 | min_size=800, 110 | max_size=1333, 111 | ) 112 | -------------------------------------------------------------------------------- /configs/relation_detr/relation_detr_swin_l_800_1333.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from models.backbones.swin import SwinTransformerBackbone 4 | from models.bricks.position_encoding import PositionEmbeddingSine 5 | from models.bricks.post_process import PostProcess 6 | from models.bricks.relation_transformer import ( 7 | RelationTransformer, 8 | RelationTransformerDecoder, 9 | RelationTransformerDecoderLayer, 10 | RelationTransformerEncoder, 11 | RelationTransformerEncoderLayer, 12 | ) 13 | from models.bricks.set_criterion import HybridSetCriterion 14 | from models.detectors.relation_detr import RelationDETR 15 | from models.matcher.hungarian_matcher import HungarianMatcher 16 | from models.necks.channel_mapper import ChannelMapper 17 | 18 | # mostly changed parameters 19 | embed_dim = 256 20 | num_classes = 91 21 | num_queries = 900 22 | hybrid_num_proposals = 1500 23 | hybrid_assign = 6 24 | num_feature_levels = 4 25 | transformer_enc_layers = 6 26 | transformer_dec_layers = 6 27 | num_heads = 8 28 | dim_feedforward = 2048 29 | 30 | # instantiate model components 31 | position_embedding = PositionEmbeddingSine( 32 | embed_dim // 2, temperature=10000, normalize=True, offset=-0.5 33 | ) 34 | 35 | backbone = SwinTransformerBackbone(arch="swin_l", return_indices=(1, 2, 3), freeze_indices=(0,)) 36 | 37 | neck = ChannelMapper( 38 | in_channels=backbone.num_channels, 39 | out_channels=embed_dim, 40 | num_outs=num_feature_levels, 41 | ) 42 | 43 | transformer = RelationTransformer( 44 | encoder=RelationTransformerEncoder( 45 | encoder_layer=RelationTransformerEncoderLayer( 46 | embed_dim=embed_dim, 47 | n_heads=num_heads, 48 | dropout=0.0, 49 | activation=nn.ReLU(inplace=True), 50 | n_levels=num_feature_levels, 51 | n_points=4, 52 | d_ffn=dim_feedforward, 53 | ), 54 | num_layers=transformer_enc_layers, 55 | ), 56 | decoder=RelationTransformerDecoder( 57 | decoder_layer=RelationTransformerDecoderLayer( 58 | embed_dim=embed_dim, 59 | n_heads=num_heads, 60 | dropout=0.0, 61 | activation=nn.ReLU(inplace=True), 62 | 
n_levels=num_feature_levels, 63 | n_points=4, 64 | d_ffn=dim_feedforward, 65 | ), 66 | num_layers=transformer_dec_layers, 67 | num_classes=num_classes, 68 | ), 69 | num_classes=num_classes, 70 | num_feature_levels=num_feature_levels, 71 | two_stage_num_proposals=num_queries, 72 | hybrid_num_proposals=hybrid_num_proposals, 73 | ) 74 | 75 | matcher = HungarianMatcher( 76 | cost_class=2, cost_bbox=5, cost_giou=2, focal_alpha=0.25, focal_gamma=2.0 77 | ) 78 | 79 | # construct weight_dict for loss 80 | weight_dict = {"loss_class": 1, "loss_bbox": 5, "loss_giou": 2} 81 | weight_dict.update({"loss_class_dn": 1, "loss_bbox_dn": 5, "loss_giou_dn": 2}) 82 | aux_weight_dict = {} 83 | for i in range(transformer.decoder.num_layers - 1): 84 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 85 | weight_dict.update(aux_weight_dict) 86 | weight_dict.update({"loss_class_enc": 1, "loss_bbox_enc": 5, "loss_giou_enc": 2}) 87 | weight_dict.update({k + "_hybrid": v for k, v in weight_dict.items()}) 88 | 89 | criterion = HybridSetCriterion( 90 | num_classes=num_classes, matcher=matcher, weight_dict=weight_dict, alpha=0.25, gamma=2.0 91 | ) 92 | postprocessor = PostProcess(select_box_nums_for_evaluation=300) 93 | 94 | # combine above components to instantiate the model 95 | model = RelationDETR( 96 | backbone=backbone, 97 | neck=neck, 98 | position_embedding=position_embedding, 99 | transformer=transformer, 100 | criterion=criterion, 101 | postprocessor=postprocessor, 102 | num_classes=num_classes, 103 | num_queries=num_queries, 104 | hybrid_assign=hybrid_assign, 105 | denoising_nums=100, 106 | min_size=800, 107 | max_size=1333, 108 | ) 109 | -------------------------------------------------------------------------------- /configs/train_config.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | 3 | from datasets.coco import CocoDetection 4 | from transforms import presets 5 | from optimizer import param_dict 6 | 7 | # Commonly changed training configurations 8 | num_epochs = 12 # train epochs 9 | batch_size = 2 # total_batch_size = #GPU x batch_size 10 | num_workers = 4 # workers for pytorch DataLoader 11 | pin_memory = True # whether pin_memory for pytorch DataLoader 12 | print_freq = 50 # frequency to print logs 13 | starting_epoch = 0 14 | max_norm = 0.1 # clip gradient norm 15 | 16 | output_dir = None # path to save checkpoints, default for None: checkpoints/{model_name} 17 | find_unused_parameters = False # useful for debugging distributed training 18 | 19 | # define dataset for train 20 | coco_path = "data/coco" # /PATH/TO/YOUR/COCODIR 21 | train_dataset = CocoDetection( 22 | img_folder=f"{coco_path}/train2017", 23 | ann_file=f"{coco_path}/annotations/instances_train2017.json", 24 | transforms=presets.detr, # see transforms/presets to choose a transform 25 | train=True, 26 | ) 27 | test_dataset = CocoDetection( 28 | img_folder=f"{coco_path}/val2017", 29 | ann_file=f"{coco_path}/annotations/instances_val2017.json", 30 | transforms=None, # the eval_transform is integrated in the model 31 | ) 32 | 33 | # model config to train 34 | model_path = "configs/relation_detr/relation_detr_resnet50_800_1333.py" 35 | 36 | # specify a checkpoint folder to resume, or a pretrained ".pth" to finetune, for example: 37 | # checkpoints/relation_detr_resnet50_800_1333/train/2024-03-22-09_38_50 38 | # checkpoints/relation_detr_resnet50_800_1333/train/2024-03-22-09_38_50/best_ap.pth 39 | resume_from_checkpoint = None 40 | 41 | learning_rate = 
1e-4 # initial learning rate 42 | optimizer = optim.AdamW(lr=learning_rate, weight_decay=1e-4, betas=(0.9, 0.999)) 43 | lr_scheduler = optim.lr_scheduler.MultiStepLR(milestones=[10], gamma=0.1) 44 | 45 | # This define parameter groups with different learning rate 46 | param_dicts = param_dict.finetune_backbone_and_linear_projection(lr=learning_rate) 47 | -------------------------------------------------------------------------------- /images/convergence_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiuqhou/Relation-DETR/b485955c72452788240600da6d0f0b8cc49f33c7/images/convergence_curve.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from functools import partial 4 | from test import create_test_data_loader 5 | from typing import Dict, List, Tuple 6 | 7 | import accelerate 8 | import cv2 9 | import numpy as np 10 | import torch 11 | import torch.utils.data as data 12 | from accelerate import Accelerator 13 | from PIL import Image 14 | from tqdm import tqdm 15 | 16 | from util.lazy_load import Config 17 | from util.logger import setup_logger 18 | from util.utils import load_checkpoint, load_state_dict 19 | from util.visualize import plot_bounding_boxes_on_image 20 | 21 | 22 | def is_image(file_path): 23 | try: 24 | img = Image.open(file_path) 25 | img.close() 26 | return True 27 | except: 28 | return False 29 | 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser(description="Inference a detector") 33 | 34 | # dataset parameters 35 | parser.add_argument("--image-dir", type=str, required=True) 36 | parser.add_argument("--workers", type=int, default=2) 37 | 38 | # model parameters 39 | parser.add_argument("--model-config", type=str, required=True) 40 | parser.add_argument("--checkpoint", type=str, required=True) 41 | 42 | # visualization parameters 43 | parser.add_argument("--show-dir", type=str, default=None) 44 | parser.add_argument("--show-conf", type=float, default=0.5) 45 | 46 | # plot parameters 47 | parser.add_argument("--font-scale", type=float, default=1.0) 48 | parser.add_argument("--box-thick", type=int, default=1) 49 | parser.add_argument("--fill-alpha", type=float, default=0.2) 50 | parser.add_argument("--text-box-color", type=int, nargs="+", default=(255, 255, 255)) 51 | parser.add_argument("--text-font-color", type=int, nargs="+", default=None) 52 | parser.add_argument("--text-alpha", type=float, default=1.0) 53 | 54 | # engine parameters 55 | parser.add_argument("--seed", type=int, default=42) 56 | 57 | args = parser.parse_args() 58 | return args 59 | 60 | 61 | class InferenceDataset(data.Dataset): 62 | def __init__(self, root): 63 | self.images = [os.path.join(root, img) for img in os.listdir(root)] 64 | self.images = [img for img in self.images if is_image(img)] 65 | assert len(self.images) > 0, "No images found" 66 | 67 | def __len__(self): 68 | return len(self.images) 69 | 70 | def __getitem__(self, index): 71 | cv2.setNumThreads(0) 72 | cv2.ocl.setUseOpenCL(False) 73 | image = cv2.imdecode(np.fromfile(self.images[index], dtype=np.uint8), -1) 74 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).transpose(2, 0, 1) 75 | return torch.tensor(image) 76 | 77 | 78 | def inference(): 79 | args = parse_args() 80 | 81 | # set fixed seed and deterministic_algorithms 82 | accelerator = Accelerator() 83 | 
accelerate.utils.set_seed(args.seed, device_specific=False) 84 | torch.backends.cudnn.benchmark = False 85 | torch.backends.cudnn.deterministic = True 86 | # deterministic in low version pytorch leads to RuntimeError 87 | # torch.use_deterministic_algorithms(True, warn_only=True) 88 | 89 | # setup logger 90 | for logger_name in ["py.warnings", "accelerate", os.path.basename(os.getcwd())]: 91 | setup_logger(distributed_rank=accelerator.local_process_index, name=logger_name) 92 | 93 | dataset = InferenceDataset(args.image_dir) 94 | data_loader = create_test_data_loader( 95 | dataset, accelerator=accelerator, batch_size=1, num_workers=args.workers 96 | ) 97 | 98 | # get inference results from model output 99 | model = Config(args.model_config).model.eval() 100 | checkpoint = load_checkpoint(args.checkpoint) 101 | if isinstance(checkpoint, Dict) and "model" in checkpoint: 102 | checkpoint = checkpoint["model"] 103 | load_state_dict(model, checkpoint) 104 | model = accelerator.prepare_model(model) 105 | 106 | with torch.inference_mode(): 107 | predictions = [] 108 | for index, images in enumerate(tqdm(data_loader)): 109 | prediction = model(images)[0] 110 | 111 | # change torch.Tensor to CPU 112 | for key in prediction: 113 | prediction[key] = prediction[key].to("cpu", non_blocking=True) 114 | image_name = data_loader.dataset.images[index] 115 | image = images[0].to("cpu", non_blocking=True) 116 | prediction = {"image_name": image_name, "image": image, "output": prediction} 117 | predictions.append(prediction) 118 | 119 | # save visualization results 120 | if args.show_dir: 121 | os.makedirs(args.show_dir, exist_ok=True) 122 | 123 | # create a dummy dataset for visualization with multi-workers 124 | data_loader = create_test_data_loader( 125 | predictions, accelerator=accelerator, batch_size=1, num_workers=args.workers 126 | ) 127 | data_loader.collate_fn = partial(_visualize_batch_for_infer, classes=model.CLASSES, **vars(args)) 128 | [None for _ in tqdm(data_loader)] 129 | 130 | 131 | def _visualize_batch_for_infer( 132 | batch: Tuple[Dict], 133 | classes: List[str], 134 | show_conf: float = 0.0, 135 | show_dir: str = None, 136 | font_scale: float = 1.0, 137 | box_thick: int = 3, 138 | fill_alpha: float = 0.2, 139 | text_box_color: Tuple[int] = (255, 255, 255), 140 | text_font_color: Tuple[int] = None, 141 | text_alpha: float = 0.5, 142 | **kwargs, # Not useful 143 | ): 144 | image_name, image, output = batch[0].values() 145 | # plot bounding boxes on image 146 | image = image.numpy().transpose(1, 2, 0) 147 | image = plot_bounding_boxes_on_image( 148 | image=image, 149 | boxes=output["boxes"], 150 | labels=output["labels"], 151 | scores=output.get("scores", None), 152 | classes=classes, 153 | show_conf=show_conf, 154 | font_scale=font_scale, 155 | box_thick=box_thick, 156 | fill_alpha=fill_alpha, 157 | text_box_color=text_box_color, 158 | text_font_color=text_font_color, 159 | text_alpha=text_alpha, 160 | ) 161 | cv2.imwrite(os.path.join(show_dir, os.path.basename(image_name)), image[:, :, ::-1]) 162 | 163 | 164 | if __name__ == "__main__": 165 | inference() 166 | -------------------------------------------------------------------------------- /models/backbones/base_backbone.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import inspect 3 | import logging 4 | import os 5 | from typing import Dict 6 | 7 | from omegaconf import DictConfig 8 | from torch import nn 9 | 10 | from util.utils import load_state_dict as _load_state_dict 11 | 12 | 
13 | class BaseBackbone: 14 | @staticmethod 15 | def load_state_dict(model: nn.Module, state_dict: Dict): 16 | if state_dict is None: 17 | return 18 | assert isinstance(state_dict, Dict), "state_dict must be OrderedDict." 19 | _load_state_dict(model, state_dict) 20 | 21 | @staticmethod 22 | def freeze_module(module: nn.Module): 23 | module.eval() 24 | for param in module.parameters(): 25 | param.requires_grad = False 26 | 27 | def get_instantiate_config(self, func_name, arch, extra_params): 28 | # log some necessary information about backbone 29 | logger = logging.getLogger(os.path.basename(os.getcwd()) + "." + __name__) 30 | assert arch is None or arch in self.model_arch, \ 31 | f"Expected architecture in {self.model_arch.keys()} but got {arch}" 32 | logger.info(f"Backbone architecture: {arch}") 33 | 34 | # merge parameters from self.arch with extra params 35 | model_config = copy.deepcopy(self.model_arch[arch]) if arch is not None else {} 36 | for name, param in inspect.signature(func_name).parameters.items(): 37 | # get default, current and modified params 38 | default = param.default if param.default is not inspect.Parameter.empty else None 39 | modified_param = extra_params.get(name, None) 40 | if isinstance(model_config, Dict): 41 | cur_param = model_config.get(name, None) 42 | elif isinstance(model_config, DictConfig): 43 | cur_param = getattr(model_config, name, None) 44 | else: 45 | cur_param = None 46 | 47 | # choose the high-prior parameter 48 | if cur_param is not None: 49 | default = cur_param 50 | if modified_param is not None: 51 | default = modified_param 52 | 53 | # replace parameters in model_config 54 | if isinstance(model_config, Dict): 55 | model_config[name] = default 56 | elif isinstance(model_config, DictConfig): 57 | setattr(model_config, name, default) 58 | else: 59 | raise TypeError("Only Dict and DictConfig supported.") 60 | 61 | return model_config 62 | -------------------------------------------------------------------------------- /models/bricks/base_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | 5 | 6 | class DETRBaseTransformer(nn.Module): 7 | """A base class that contains some methods commonly used in DETR transformer, 8 | such as DeformableTransformer, DabTransformer, DINOTransformer, AlignTransformer. 
9 | 10 | """ 11 | def __init__(self, num_feature_levels, embed_dim): 12 | super().__init__() 13 | self.embed_dim = embed_dim 14 | self.num_feature_levels = num_feature_levels 15 | 16 | @staticmethod 17 | def flatten_multi_level(multi_level_elements): 18 | multi_level_elements = torch.cat( 19 | tensors=[e.flatten(-2) for e in multi_level_elements], dim=-1 20 | ) # (b, [c], s) 21 | if multi_level_elements.ndim == 3: 22 | multi_level_elements.transpose_(1, 2) 23 | return multi_level_elements 24 | 25 | def multi_level_misc(self, multi_level_masks): 26 | if torchvision._is_tracing(): 27 | # torch.Tensor.shape exports not well for ONNX 28 | # use operators.shape_as_tensor istead 29 | from torch.onnx import operators 30 | spatial_shapes = [operators.shape_as_tensor(m)[-2:] for m in multi_level_masks] 31 | spatial_shapes = torch.stack(spatial_shapes).to(multi_level_masks[0].device) 32 | else: 33 | spatial_shapes = [m.shape[-2:] for m in multi_level_masks] 34 | spatial_shapes = multi_level_masks[0].new_tensor(spatial_shapes, dtype=torch.int64) 35 | level_start_index = torch.cat( 36 | (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]) 37 | ) 38 | valid_ratios = self.multi_level_valid_ratios(multi_level_masks) 39 | return spatial_shapes, level_start_index, valid_ratios 40 | 41 | @staticmethod 42 | def get_valid_ratios(mask): 43 | b, h, w = mask.shape 44 | if h == 0 or w == 0: # for empty Tensor 45 | return mask.new_ones((b, 2)).float() 46 | valid_h = torch.sum(~mask[:, :, 0], 1) 47 | valid_w = torch.sum(~mask[:, 0, :], 1) 48 | valid_ratio_h = valid_h.float() / h 49 | valid_ratio_w = valid_w.float() / w 50 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) # [n, 2] 51 | return valid_ratio 52 | 53 | def multi_level_valid_ratios(self, multi_level_masks): 54 | return torch.stack([self.get_valid_ratios(m) for m in multi_level_masks], 1) 55 | 56 | @staticmethod 57 | def get_full_reference_points(spatial_shapes, valid_ratios): 58 | reference_points_list = [] 59 | for lvl, (h, w) in enumerate(spatial_shapes): 60 | ref_y, ref_x = torch.meshgrid( 61 | torch.arange(0.5, h + 0.5, device=spatial_shapes.device), 62 | torch.arange(0.5, w + 0.5, device=spatial_shapes.device), 63 | indexing="ij", 64 | ) 65 | ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * h) 66 | ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * w) 67 | ref = torch.stack((ref_x, ref_y), -1) # [n, h*w, 2] 68 | reference_points_list.append(ref) 69 | reference_points = torch.cat(reference_points_list, 1) # [n, s, 2] 70 | return reference_points 71 | 72 | def get_reference(self, spatial_shapes, valid_ratios): 73 | # get full_reference_points, should be transferred using valid_ratios 74 | full_reference_points = self.get_full_reference_points(spatial_shapes, valid_ratios) 75 | reference_points = full_reference_points[:, :, None] * valid_ratios[:, None] 76 | # get proposals, reuse full_reference_points to speed up 77 | level_wh = full_reference_points.new_tensor([[i] for i in range(spatial_shapes.shape[0])]) 78 | level_wh = 0.05 * 2.0**level_wh.repeat_interleave(spatial_shapes.prod(-1), 0) 79 | level_wh = level_wh.expand_as(full_reference_points) 80 | proposals = torch.cat([full_reference_points, level_wh], -1) 81 | return reference_points, proposals 82 | 83 | 84 | class MultiLevelTransformer(DETRBaseTransformer): 85 | """A base class that contains methods based on level_embeds.""" 86 | def __init__(self, num_feature_levels, embed_dim): 87 | super().__init__(num_feature_levels, 
embed_dim) 88 | self.level_embeds = nn.Parameter(torch.Tensor(num_feature_levels, embed_dim)) 89 | self._init_weights_detr_transformer() 90 | 91 | def _init_weights_detr_transformer(self): 92 | nn.init.normal_(self.level_embeds) 93 | 94 | def get_lvl_pos_embed(self, multi_level_pos_embeds): 95 | multi_level_pos_embeds = [ 96 | p + l.view(1, -1, 1, 1) for p, l in zip(multi_level_pos_embeds, self.level_embeds) 97 | ] 98 | return self.flatten_multi_level(multi_level_pos_embeds) 99 | 100 | 101 | class TwostageTransformer(MultiLevelTransformer): 102 | """A base class that contains some methods commonly used in two-stage transformer, 103 | such as DeformableTransformer, DabTransformer, DINOTransformer, AlignTransformer. 104 | 105 | """ 106 | def __init__(self, num_feature_levels, embed_dim): 107 | super().__init__(num_feature_levels, embed_dim) 108 | self.enc_output = nn.Linear(embed_dim, embed_dim) 109 | self.enc_output_norm = nn.LayerNorm(embed_dim) 110 | self._init_weights_two_stage_transformer() 111 | 112 | def _init_weights_two_stage_transformer(self): 113 | nn.init.xavier_uniform_(self.enc_output.weight) 114 | nn.init.constant_(self.enc_output.bias, 0.0) 115 | 116 | def get_encoder_output(self, memory, proposals, memory_padding_mask): 117 | output_proposals_valid = ((proposals > 0.01) & (proposals < 0.99)).all(-1, keepdim=True) 118 | proposals = torch.log(proposals / (1 - proposals)) # inverse_sigmoid 119 | invalid = memory_padding_mask.unsqueeze(-1) | ~output_proposals_valid 120 | proposals.masked_fill_(invalid, float("inf")) 121 | 122 | output_memory = memory * (~memory_padding_mask.unsqueeze(-1)) * (output_proposals_valid) 123 | output_memory = self.enc_output_norm(self.enc_output(output_memory)) 124 | return output_memory, proposals 125 | -------------------------------------------------------------------------------- /models/bricks/dab_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from models.bricks.base_transformer import TwostageTransformer 7 | from models.bricks.basic import MLP 8 | from models.bricks.dino_transformer import DINOTransformerEncoder 9 | from models.bricks.dn_transformer import DNTransformerDecoder 10 | from models.bricks.relation_transformer import ( 11 | RelationTransformerDecoderLayer, 12 | RelationTransformerEncoderLayer, 13 | ) 14 | 15 | 16 | class DabTransformer(TwostageTransformer): 17 | def __init__( 18 | self, 19 | encoder: nn.Module, 20 | decoder: nn.Module, 21 | num_classes: int, 22 | num_feature_levels: int = 4, 23 | two_stage_num_proposals: int = 300, 24 | ): 25 | super().__init__(num_feature_levels, encoder.embed_dim) 26 | # model parameters 27 | self.two_stage_num_proposals = two_stage_num_proposals 28 | self.num_classes = num_classes 29 | 30 | # model structure 31 | self.encoder = encoder 32 | self.decoder = decoder 33 | self.encoder_class_head = nn.Linear(self.embed_dim, num_classes) 34 | self.encoder_bbox_head = MLP(self.embed_dim, self.embed_dim, 4, 3) 35 | 36 | self.init_weights() 37 | 38 | def init_weights(self): 39 | # initilize encoder and hybrid classification layers 40 | prior_prob = 0.01 41 | bias_value = -math.log((1 - prior_prob) / prior_prob) 42 | nn.init.constant_(self.encoder_class_head.bias, bias_value) 43 | # initiailize encoder and hybrid regression layers 44 | nn.init.constant_(self.encoder_bbox_head.layers[-1].weight, 0.0) 45 | nn.init.constant_(self.encoder_bbox_head.layers[-1].bias, 0.0) 46 | 47 | def 
forward(self, multi_level_feats, multi_level_masks, multi_level_pos_embeds): 48 | # get input for encoder 49 | feat_flatten = self.flatten_multi_level(multi_level_feats) 50 | mask_flatten = self.flatten_multi_level(multi_level_masks) 51 | lvl_pos_embed_flatten = self.get_lvl_pos_embed(multi_level_pos_embeds) 52 | spatial_shapes, level_start_index, valid_ratios = self.multi_level_misc(multi_level_masks) 53 | reference_points, proposals = self.get_reference(spatial_shapes, valid_ratios) 54 | 55 | # encoder 56 | memory = self.encoder( 57 | query=feat_flatten, 58 | query_pos=lvl_pos_embed_flatten, 59 | spatial_shapes=spatial_shapes, 60 | level_start_index=level_start_index, 61 | query_key_padding_mask=mask_flatten, 62 | reference_points=reference_points, 63 | ) 64 | 65 | # get encoder output, classes and coordinates 66 | output_memory, output_proposals = self.get_encoder_output(memory, proposals, mask_flatten) 67 | enc_outputs_class = self.encoder_class_head(output_memory) 68 | enc_outputs_coord = self.encoder_bbox_head(output_memory) + output_proposals 69 | enc_outputs_coord = enc_outputs_coord.sigmoid() 70 | 71 | # get topk output classes and coordinates 72 | topk, num_classes = self.two_stage_num_proposals, self.num_classes 73 | topk_index = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1].unsqueeze(-1) 74 | enc_outputs_class = enc_outputs_class.gather(1, topk_index.expand(-1, -1, num_classes)) 75 | enc_outputs_coord = enc_outputs_coord.gather(1, topk_index.expand(-1, -1, 4)) 76 | 77 | # get query(target) and reference points 78 | target = torch.gather(output_memory, 1, topk_index.expand(-1, -1, self.embed_dim)).detach() 79 | reference_points = enc_outputs_coord.detach() 80 | 81 | # decoder 82 | outputs_classes, outputs_coords = self.decoder( 83 | query=target, 84 | value=memory, 85 | key_padding_mask=mask_flatten, 86 | reference_points=reference_points, 87 | spatial_shapes=spatial_shapes, 88 | level_start_index=level_start_index, 89 | valid_ratios=valid_ratios, 90 | ) 91 | 92 | return outputs_classes, outputs_coords, enc_outputs_class, enc_outputs_coord 93 | 94 | 95 | DabTransformerEncoderLayer = RelationTransformerEncoderLayer 96 | DabTransformerEncoder = DINOTransformerEncoder 97 | DabTransformerDecoderLayer = RelationTransformerDecoderLayer 98 | DabTransformerDecoder = DNTransformerDecoder # NOTE: equivalent under two-stage settings 99 | -------------------------------------------------------------------------------- /models/bricks/deform_conv2d_pack.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple, Union 3 | 4 | import torch 5 | from torch import nn 6 | from torchvision.ops import DeformConv2d 7 | 8 | 9 | class DeformConv2dPack(nn.Module): 10 | """This is a pack of deformable convolution that can be used as normal convolution""" 11 | 12 | def __init__( 13 | self, 14 | in_channels: int, 15 | out_channels: int, 16 | kernel_size: Union[int, Tuple[int]], 17 | stride: int = 1, 18 | padding: int = 0, 19 | dilation: int = 1, 20 | groups: int = 1, 21 | bias: Union[bool, str] = True, 22 | ): 23 | super().__init__() 24 | if isinstance(kernel_size, int): 25 | kernel_size = ( 26 | kernel_size, 27 | kernel_size, 28 | ) 29 | self.in_channels = in_channels 30 | self.kernel_size = kernel_size 31 | self.groups = groups 32 | self.conv_offset = nn.Conv2d( 33 | in_channels, 34 | groups * 2 * kernel_size[0] * kernel_size[1], 35 | kernel_size=kernel_size, 36 | stride=stride, 37 | padding=padding, 38 | dilation=dilation, 
39 | groups=groups, # Don't know whether to add groups here 40 | bias=True, 41 | ) 42 | self.conv_mask = nn.Conv2d( 43 | in_channels, 44 | groups * kernel_size[0] * kernel_size[1], 45 | kernel_size=kernel_size, 46 | stride=stride, 47 | padding=padding, 48 | dilation=dilation, 49 | groups=groups, 50 | bias=True, 51 | ) 52 | self.deform_conv2d = DeformConv2d( 53 | in_channels=in_channels, 54 | out_channels=out_channels, 55 | kernel_size=kernel_size, 56 | stride=stride, 57 | padding=padding, 58 | dilation=dilation, 59 | groups=groups, 60 | bias=bias, 61 | ) 62 | self.init_weights() 63 | 64 | def init_weights(self) -> None: 65 | self.conv_offset.weight.data.zero_() 66 | self.conv_mask.weight.data.zero_() 67 | self.conv_offset.bias.data.zero_() 68 | self.conv_mask.bias.data.zero_() 69 | n = self.in_channels 70 | for k in self.kernel_size: 71 | n *= k 72 | stdv = 1.0 / math.sqrt(n) 73 | self.deform_conv2d.weight.data.uniform_(-stdv, stdv) 74 | if self.deform_conv2d.bias is not None: 75 | self.deform_conv2d.bias.data.zero_() 76 | 77 | def forward(self, x: torch.Tensor) -> torch.Tensor: 78 | offset = self.conv_offset(x) 79 | mask = torch.sigmoid(self.conv_mask(x)) 80 | out = self.deform_conv2d(x, offset, mask) 81 | return out 82 | -------------------------------------------------------------------------------- /models/bricks/losses.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | 3 | 4 | def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): 5 | prob = inputs.sigmoid() 6 | target_score = targets.to(inputs.dtype) 7 | weight = (1 - alpha) * prob**gamma * (1 - targets) + targets * alpha * (1 - prob)**gamma 8 | # according to original implementation, sigmoid_focal_loss keep gradient on weight 9 | loss = F.binary_cross_entropy_with_logits(inputs, target_score, reduction="none") 10 | loss = loss * weight 11 | # we use sum/num to replace mean to avoid NaN 12 | return (loss.sum(1) / max(loss.shape[1], 1)).sum() / num_boxes 13 | 14 | 15 | def vari_sigmoid_focal_loss(inputs, targets, gt_score, num_boxes, alpha: float = 0.25, gamma: float = 2): 16 | prob = inputs.sigmoid().detach() # pytorch version of RT-DETR has detach while paddle version not 17 | target_score = targets * gt_score.unsqueeze(-1) 18 | weight = (1 - alpha) * prob.pow(gamma) * (1 - targets) + target_score 19 | loss = F.binary_cross_entropy_with_logits(inputs, target_score, weight=weight, reduction="none") 20 | # we use sum/num to replace mean to avoid NaN 21 | return (loss.sum(1) / max(loss.shape[1], 1)).sum() / num_boxes 22 | 23 | 24 | def ia_bce_loss(inputs, targets, gt_score, num_boxes, k: float = 0.25, alpha: float = 0, gamma: float = 2): 25 | prob = inputs.sigmoid().detach() 26 | # calculate iou_aware_score and constrain the value following original implementation 27 | iou_aware_score = prob**k * gt_score.unsqueeze(-1)**(1 - k) 28 | iou_aware_score = iou_aware_score.clamp(min=0.01) 29 | target_score = targets * iou_aware_score 30 | weight = (1 - alpha) * prob.pow(gamma) * (1 - targets) + targets 31 | loss = F.binary_cross_entropy_with_logits(inputs, target_score, weight=weight, reduction="none") 32 | # we use sum/num to replace mean to avoid NaN 33 | return (loss.sum(1) / max(loss.shape[1], 1)).sum() / num_boxes 34 | -------------------------------------------------------------------------------- /models/bricks/ops/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xiuqhou/Relation-DETR/b485955c72452788240600da6d0f0b8cc49f33c7/models/bricks/ops/__init__.py -------------------------------------------------------------------------------- /models/bricks/position_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import functools 3 | from typing import Tuple, Union 4 | 5 | import torch 6 | from torch import Tensor, nn 7 | 8 | 9 | class PositionEmbeddingSine(nn.Module): 10 | """Sinusoidal position embedding used in DETR model. See `End-to-End Object Detection 11 | with Transformers `_ for more details. 12 | 13 | :param num_pos_feats: The feature dimension for each position along x-axis or y-axis. 14 | The final returned dimension for each position is 2 times of the input value, 15 | defaults to 64 16 | :param temperature: The temperature used for scaling the position embedding, defaults to 10000 17 | :param normalize: Whether to normalize the position embedding, defaults to False 18 | :param scale: A scale factor that scales the position embedding, which is used only when 19 | `normalize` is True, defaults to 2*math.pi 20 | :param eps: A value added to the denominator for numerical stability, defaults to 1e-6 21 | :param offset: An offset added to embed, defaults to 0.0 22 | """ 23 | def __init__( 24 | self, 25 | num_pos_feats=64, 26 | temperature: Union[int, Tuple[int, int]] = 10000, 27 | normalize=False, 28 | scale=2 * math.pi, 29 | eps=1e-6, 30 | offset=0.0, 31 | ): 32 | super().__init__() 33 | assert isinstance(temperature, int) or len(temperature) == 2, \ 34 | "Only support (t_x, t_y) or an integer t for temperature" 35 | 36 | self.num_pos_feats = num_pos_feats 37 | self.temperature = temperature 38 | self.normalize = normalize 39 | self.scale = scale 40 | self.eps = eps 41 | self.offset = offset 42 | 43 | def get_dim_t(self, device: torch.device): 44 | if isinstance(self.temperature, int): 45 | dim_t = get_dim_t(self.num_pos_feats, self.temperature, device) 46 | return dim_t, dim_t 47 | return (get_dim_t(self.num_pos_feats, t, device) for t in self.temperature) 48 | 49 | def forward(self, mask: Tensor): 50 | not_mask = (~mask).int() # onnx export does not support cumsum on bool tensor 51 | y_embed = not_mask.cumsum(1) 52 | x_embed = not_mask.cumsum(2) 53 | if self.normalize: 54 | y_embed = (y_embed + self.offset) / (y_embed[:, -1:, :] + self.eps) * self.scale 55 | x_embed = (x_embed + self.offset) / (x_embed[:, :, -1:] + self.eps) * self.scale 56 | else: 57 | # RT-DETR uses unnormalized encoding with index from 0 58 | y_embed = y_embed + self.offset 59 | x_embed = x_embed + self.offset 60 | 61 | dim_tx, dim_ty = self.get_dim_t(mask.device) 62 | 63 | pos_x = x_embed.unsqueeze(-1) / dim_tx 64 | pos_y = y_embed.unsqueeze(-1) / dim_ty 65 | pos_x = torch.stack((pos_x.sin(), pos_x.cos()), dim=-1).flatten(-2) 66 | pos_y = torch.stack((pos_y.sin(), pos_y.cos()), dim=-1).flatten(-2) 67 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 68 | return pos 69 | 70 | 71 | class PositionEmbeddingLearned(nn.Module): 72 | """Absolute pos embedding, learned.""" 73 | def __init__(self, num_embeddings: int = 50, num_pos_feats: int = 256): 74 | super().__init__() 75 | self.row_embed = nn.Embedding(num_embeddings, num_pos_feats) 76 | self.col_embed = nn.Embedding(num_embeddings, num_pos_feats) 77 | self.reset_parameters() 78 | 79 | def reset_parameters(self): 80 | nn.init.uniform_(self.row_embed.weight) 81 | nn.init.uniform_(self.col_embed.weight) 82 | 83 | def 
forward(self, mask: Tensor): 84 | h, w = mask.shape[-2:] 85 | i = torch.arange(w, device=mask.device) 86 | j = torch.arange(h, device=mask.device) 87 | x_emb = self.col_embed(i) 88 | y_emb = self.row_embed(j) 89 | pos = ( 90 | torch.cat( 91 | [ 92 | x_emb.unsqueeze(0).repeat(h, 1, 1), 93 | y_emb.unsqueeze(1).repeat(1, w, 1), 94 | ], 95 | dim=-1, 96 | ).permute(2, 0, 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1) 97 | ) 98 | return pos 99 | 100 | 101 | @functools.lru_cache # use lru_cache to avoid redundant calculation for dim_t 102 | def get_dim_t(num_pos_feats: int, temperature: int, device: torch.device): 103 | dim_t = torch.arange(num_pos_feats // 2, dtype=torch.float32, device=device) 104 | dim_t = temperature**(dim_t * 2 / num_pos_feats) 105 | return dim_t # (0, 2, 4, ..., ⌊n/2⌋*2) 106 | 107 | def exchange_xy_fn(pos_res): 108 | index = torch.cat([ 109 | torch.arange(1, -1, -1, device=pos_res.device), 110 | torch.arange(2, pos_res.shape[-2], device=pos_res.device), 111 | ]) 112 | pos_res = torch.index_select(pos_res, -2, index) 113 | return pos_res 114 | 115 | def get_sine_pos_embed( 116 | pos_tensor: Tensor, 117 | num_pos_feats: int = 128, 118 | temperature: int = 10000, 119 | scale: float = 2 * math.pi, 120 | exchange_xy: bool = True, 121 | ) -> Tensor: 122 | """Generate sine position embedding for a position tensor 123 | 124 | :param pos_tensor: shape as (..., 2*n). 125 | :param num_pos_feats: projected shape for each float in the tensor, defaults to 128 126 | :param temperature: the temperature used for scaling the position embedding, defaults to 10000 127 | :param exchange_xy: exchange pos x and pos. For example, 128 | input tensor is [x, y], the results will be [pos(y), pos(x)], defaults to True 129 | :return: position embedding with shape (None, n * num_pos_feats) 130 | """ 131 | dim_t = get_dim_t(num_pos_feats, temperature, pos_tensor.device) 132 | 133 | pos_res = pos_tensor.unsqueeze(-1) * scale / dim_t 134 | pos_res = torch.stack((pos_res.sin(), pos_res.cos()), dim=-1).flatten(-2) 135 | if exchange_xy: 136 | pos_res = exchange_xy_fn(pos_res) 137 | pos_res = pos_res.flatten(-2) 138 | return pos_res 139 | -------------------------------------------------------------------------------- /models/bricks/post_process.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torchvision.ops import boxes as box_ops 6 | 7 | 8 | class PostProcess(nn.Module): 9 | """This module converts the model's output into the format expected by the coco api""" 10 | def __init__( 11 | self, 12 | select_box_nums_for_evaluation=100, 13 | nms_iou_threshold=-1, 14 | confidence_score=-1, 15 | ): 16 | super().__init__() 17 | self.select_box_nums_for_evaluation = select_box_nums_for_evaluation 18 | self.nms_iou_threshold = nms_iou_threshold 19 | self.confidence_score = confidence_score 20 | 21 | @torch.no_grad() 22 | def forward(self, outputs, target_sizes): 23 | out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] 24 | 25 | assert len(out_logits) == len(target_sizes) 26 | assert target_sizes.shape[1] == 2 27 | 28 | prob = out_logits.sigmoid() 29 | topk_values, topk_indexes = torch.topk( 30 | prob.view(out_logits.shape[0], -1), 31 | self.select_box_nums_for_evaluation, 32 | dim=1, 33 | ) 34 | scores = topk_values 35 | topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="trunc") 36 | labels = topk_indexes % out_logits.shape[2] 37 | 
boxes = box_ops._box_cxcywh_to_xyxy(out_bbox) 38 | boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) 39 | 40 | # and from relative [0, 1] to absolute [0, height] coordinates 41 | img_h, img_w = target_sizes.unbind(1) 42 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 43 | boxes = boxes * scale_fct[:, None, :] 44 | 45 | item_indice = None 46 | # filter low-confidence predictions 47 | if self.confidence_score > 0: 48 | item_indice = [score > self.confidence_score for score in scores] 49 | 50 | # filter overlap predictions 51 | if self.nms_iou_threshold > 0: 52 | nms_indice = [ 53 | box_ops.nms(box, score, iou_threshold=self.nms_iou_threshold) 54 | for box, score in zip(boxes, scores) 55 | ] 56 | nms_binary_indice = [torch.zeros_like(item_index, dtype=torch.bool) for item_index in item_indice] 57 | for nms_binary_index, nms_index in zip(nms_binary_indice, nms_indice): 58 | nms_binary_index[nms_index] = True 59 | item_indice = [ 60 | item_index & nms_binary_index 61 | for item_index, nms_binary_index in zip(item_indice, nms_binary_indice) 62 | ] 63 | 64 | if item_indice is not None: 65 | scores = [score[item_index] for score, item_index in zip(scores, item_indice)] 66 | boxes = [box[item_index] for box, item_index in zip(boxes, item_indice)] 67 | labels = [label[item_index] for label, item_index in zip(labels, item_indice)] 68 | 69 | if torchvision._is_tracing(): 70 | # avoid interation warning during ONNX export 71 | scores, labels, boxes = map(lambda x: x.unbind(0), (scores, labels, boxes)) 72 | results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] 73 | 74 | return results 75 | 76 | 77 | class SegmentationPostProcess(nn.Module): 78 | @torch.no_grad() 79 | def forward(self, outputs, target_sizes, input_sizes, batched_input_size): 80 | out_logits, out_bbox, out_mask = ( 81 | outputs["pred_logits"], 82 | outputs["pred_boxes"], 83 | outputs["pred_masks"], 84 | ) 85 | 86 | assert len(out_logits) == len(target_sizes) 87 | assert len(batched_input_size) == 2 88 | 89 | # we average queries of the same class to get onehot segmentation image 90 | out_class = out_logits.argmax(-1) 91 | num_class = out_logits.shape[-1] 92 | result_masks = [] 93 | for image_id in range(len(out_logits)): 94 | result_masks_per_image = [] 95 | for cur_class in range(num_class): 96 | class_index = out_class[image_id] == cur_class 97 | mask_per_class = out_mask[image_id][class_index].sigmoid() 98 | if mask_per_class.numel() == 0: 99 | mask_per_class = mask_per_class.new_zeros((1, *mask_per_class.shape[-2:])) 100 | mask_per_class = mask_per_class.mean(0) 101 | result_masks_per_image.append(mask_per_class) 102 | result_masks_per_image = torch.stack(result_masks_per_image, 0) 103 | result_masks.append(result_masks_per_image) 104 | result_masks = torch.stack(result_masks, 0) 105 | 106 | # upsample masks with 1/4 resolution to input image shapes 107 | result_masks = F.interpolate( 108 | result_masks, 109 | size=batched_input_size, 110 | mode="bilinear", 111 | align_corners=False, 112 | ) 113 | 114 | # resize masks to original shapes and transform onehot into class 115 | mask_results = [] 116 | for mask, (height, width), (out_height, out_width) in zip( 117 | result_masks, 118 | input_sizes, 119 | target_sizes, 120 | ): 121 | mask = F.interpolate( 122 | mask[None, :, :height, :width], 123 | size=(out_height, out_width), 124 | mode="bilinear", 125 | align_corners=False, 126 | )[0] 127 | mask_results.append({"masks": mask.argmax(0)}) 128 | 129 | return 
mask_results 130 | -------------------------------------------------------------------------------- /models/detectors/dab_deformable_detr.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from torch import Tensor, nn 4 | 5 | from models.detectors.base_detector import DETRDetector 6 | 7 | 8 | class DabDeformableDETR(DETRDetector): 9 | def __init__( 10 | self, 11 | backbone: nn.Module, 12 | neck: nn.Module, 13 | position_embedding: nn.Module, 14 | transformer: nn.Module, 15 | criterion: nn.Module, 16 | postprocessor: nn.Module, 17 | num_classes: int, 18 | min_size: int = None, 19 | max_size: int = None, 20 | ): 21 | super().__init__(min_size, max_size) 22 | # define model parameters 23 | self.num_classes = num_classes 24 | 25 | # define model strctures 26 | self.backbone = backbone 27 | self.neck = neck 28 | self.position_embedding = position_embedding 29 | self.transformer = transformer 30 | self.criterion = criterion 31 | self.postprocessor = postprocessor 32 | 33 | def forward(self, images: List[Tensor], targets: List[Dict] = None): 34 | # get original image sizes, used for postprocess 35 | original_image_sizes = self.query_original_sizes(images) 36 | images, targets, mask = self.preprocess(images, targets) 37 | 38 | # get multi-level features, masks, and pos_embeds 39 | multi_levels = self.get_multi_levels(images, mask) 40 | multi_level_feats, multi_level_masks, multi_level_pos_embeds = multi_levels 41 | 42 | # feed into transformer 43 | outputs_class, outputs_coord, enc_class, enc_coord = self.transformer( 44 | multi_level_feats, multi_level_masks, multi_level_pos_embeds 45 | ) 46 | 47 | # prepare for loss computation 48 | output = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} 49 | output["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord) 50 | output["enc_outputs"] = {"pred_logits": enc_class, "pred_boxes": enc_coord} 51 | 52 | if self.training: 53 | # compute loss 54 | loss_dict = self.criterion(output, targets) 55 | 56 | # loss reweighting 57 | weight_dict = self.criterion.weight_dict 58 | loss_dict = dict((k, loss_dict[k] * weight_dict[k]) 59 | for k in loss_dict.keys() 60 | if k in weight_dict) 61 | return loss_dict 62 | 63 | detections = self.postprocessor(output, original_image_sizes) 64 | return detections 65 | -------------------------------------------------------------------------------- /models/detectors/deformable_detr.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from torch import Tensor, nn 4 | 5 | from models.detectors.base_detector import DETRDetector 6 | 7 | 8 | class DeformableDETR(DETRDetector): 9 | def __init__( 10 | self, 11 | backbone: nn.Module, 12 | neck: nn.Module, 13 | position_embedding: nn.Module, 14 | transformer: nn.Module, 15 | criterion: nn.Module, 16 | postprocessor: nn.Module, 17 | num_classes: int, 18 | min_size: int = None, 19 | max_size: int = None, 20 | ): 21 | super().__init__(min_size, max_size) 22 | # NOTE: we only suppoert DeformableDETR with two-stage and box refinement 23 | 24 | # define model parameters 25 | self.num_classes = num_classes 26 | 27 | # define model structures 28 | self.backbone = backbone 29 | self.neck = neck 30 | self.position_embedding = position_embedding 31 | self.transformer = transformer 32 | self.criterion = criterion 33 | self.postprocessor = postprocessor 34 | 35 | def forward(self, images: List[Tensor], targets: List[Dict] = 
None): 36 | # get original image sizes, used for postprocess 37 | original_image_sizes = self.query_original_sizes(images) 38 | images, targets, mask = self.preprocess(images, targets) 39 | 40 | # get multi-level features, masks, and pos_embeds 41 | multi_levels = self.get_multi_levels(images, mask) 42 | multi_level_feats, multi_level_masks, multi_level_pos_embeds = multi_levels 43 | 44 | outputs_class, outputs_coord, enc_class, enc_coord = self.transformer( 45 | multi_level_feats, multi_level_masks, multi_level_pos_embeds 46 | ) 47 | 48 | # prepare for loss computation 49 | output = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} 50 | output["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord) 51 | output["enc_outputs"] = {"pred_logits": enc_class, "pred_boxes": enc_coord} 52 | 53 | if self.training: 54 | # compute loss 55 | loss_dict = self.criterion(output, targets) 56 | 57 | # loss reweighting 58 | weight_dict = self.criterion.weight_dict 59 | loss_dict = dict((k, loss_dict[k] * weight_dict[k]) for k in loss_dict.keys() if k in weight_dict) 60 | return loss_dict 61 | 62 | detections = self.postprocessor(output, original_image_sizes) 63 | return detections 64 | -------------------------------------------------------------------------------- /models/detectors/dino.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from torch import Tensor, nn 4 | 5 | from models.bricks.denoising import GenerateCDNQueries 6 | from models.detectors.base_detector import DNDETRDetector 7 | 8 | 9 | class DINO(DNDETRDetector): 10 | def __init__( 11 | # model structure 12 | self, 13 | backbone: nn.Module, 14 | neck: nn.Module, 15 | position_embedding: nn.Module, 16 | transformer: nn.Module, 17 | criterion: nn.Module, 18 | postprocessor: nn.Module, 19 | # model parameters 20 | num_classes: int, 21 | num_queries: int = 900, 22 | denoising_nums: int = 100, 23 | # model variants 24 | min_size: int = None, 25 | max_size: int = None, 26 | ): 27 | super().__init__(min_size, max_size) 28 | # define model parameters 29 | self.num_classes = num_classes 30 | embed_dim = transformer.embed_dim 31 | 32 | # define model structures 33 | self.backbone = backbone 34 | self.neck = neck 35 | self.position_embedding = position_embedding 36 | self.transformer = transformer 37 | self.criterion = criterion 38 | self.postprocessor = postprocessor 39 | self.denoising_generator = GenerateCDNQueries( 40 | num_queries=num_queries, 41 | num_classes=num_classes, 42 | label_embed_dim=embed_dim, 43 | denoising_nums=denoising_nums, 44 | label_noise_prob=0.5, 45 | box_noise_scale=1.0, 46 | ) 47 | 48 | def forward(self, images: List[Tensor], targets: List[Dict] = None): 49 | # get original image sizes, used for postprocess 50 | original_image_sizes = self.query_original_sizes(images) 51 | images, targets, mask = self.preprocess(images, targets) 52 | 53 | # get multi-level features, masks, and pos_embeds 54 | multi_levels = self.get_multi_levels(images, mask) 55 | multi_level_feats, multi_level_masks, multi_level_pos_embeds = multi_levels 56 | 57 | if self.training: 58 | # collect ground truth for denoising generation 59 | gt_labels_list = [t["labels"] for t in targets] 60 | gt_boxes_list = [t["boxes"] for t in targets] 61 | noised_results = self.denoising_generator(gt_labels_list, gt_boxes_list) 62 | noised_label_query = noised_results[0] 63 | noised_box_query = noised_results[1] 64 | attn_mask = noised_results[2] 65 | denoising_groups = 
noised_results[3] 66 | max_gt_num_per_image = noised_results[4] 67 | else: 68 | noised_label_query = None 69 | noised_box_query = None 70 | attn_mask = None 71 | denoising_groups = None 72 | max_gt_num_per_image = None 73 | 74 | # feed into transformer 75 | outputs_class, outputs_coord, enc_class, enc_coord = self.transformer( 76 | multi_level_feats, 77 | multi_level_masks, 78 | multi_level_pos_embeds, 79 | noised_label_query, 80 | noised_box_query, 81 | attn_mask=attn_mask, 82 | ) 83 | # hack implementation for distributed training 84 | outputs_class[0] += self.denoising_generator.label_encoder.weight[0, 0] * 0.0 85 | 86 | # denoising postprocessing 87 | if denoising_groups is not None and max_gt_num_per_image is not None: 88 | dn_metas = { 89 | "denoising_groups": denoising_groups, 90 | "max_gt_num_per_image": max_gt_num_per_image, 91 | } 92 | outputs_class, outputs_coord = self.dn_post_process( 93 | outputs_class, outputs_coord, dn_metas 94 | ) 95 | 96 | # prepare for loss computation 97 | output = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} 98 | output["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord) 99 | output["enc_outputs"] = {"pred_logits": enc_class, "pred_boxes": enc_coord} 100 | 101 | if self.training: 102 | # compute loss 103 | loss_dict = self.criterion(output, targets) 104 | dn_losses = self.compute_dn_loss(dn_metas, targets) 105 | loss_dict.update(dn_losses) 106 | 107 | # loss reweighting 108 | weight_dict = self.criterion.weight_dict 109 | loss_dict = dict((k, loss_dict[k] * weight_dict[k]) 110 | for k in loss_dict.keys() 111 | if k in weight_dict) 112 | return loss_dict 113 | 114 | detections = self.postprocessor(output, original_image_sizes) 115 | return detections 116 | -------------------------------------------------------------------------------- /models/detectors/dn_deformable_detr.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from torch import Tensor, nn 4 | 5 | from models.bricks.denoising import GenerateDNQueries 6 | from models.detectors.base_detector import DNDETRDetector 7 | 8 | 9 | class DNDeformableDETR(DNDETRDetector): 10 | def __init__( 11 | self, 12 | backbone: nn.Module, 13 | neck: nn.Module, 14 | position_embedding: nn.Module, 15 | transformer: nn.Module, 16 | criterion: nn.Module, 17 | postprocessor: nn.Module, 18 | num_classes: int, 19 | num_queries: int = 300, 20 | denoising_groups: int = 5, 21 | min_size: int = None, 22 | max_size: int = None, 23 | ): 24 | super().__init__(min_size, max_size) 25 | # NOTE: Acording to authentic and detrex implementation, DN-Def-DETR has no two-stage setting. 26 | # DN-Def-DETR with two-stage settings is equaivalent to DINO without look-forward-twice and CDN. 
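        # The denoising pipeline here mirrors DINO and Relation-DETR in this repo: during
        # training, forward() passes the ground-truth labels and boxes to
        # self.denoising_generator and unpacks the returned tuple positionally:
        #   noised_results[0] -> noised label (content) queries
        #   noised_results[1] -> noised box queries
        #   noised_results[2] -> attention mask that stops the matching queries from
        #                        attending to the noised queries and isolates the
        #                        denoising groups from each other (see denoising.py)
        #   noised_results[3] -> number of denoising groups
        #   noised_results[4] -> max number of ground-truth boxes per image
        # The last two entries are packed into dn_metas and later consumed by
        # dn_post_process and compute_dn_loss.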
27 | 28 | # define model parameters 29 | self.num_classes = num_classes 30 | self.num_queries = num_queries 31 | embed_dim = transformer.embed_dim 32 | 33 | # define model structures 34 | self.backbone = backbone 35 | self.neck = neck 36 | self.position_embedding = position_embedding 37 | self.transformer = transformer 38 | self.criterion = criterion 39 | self.postprocessor = postprocessor 40 | self.denoising_generator = GenerateDNQueries( 41 | num_queries=num_queries, 42 | num_classes=num_classes, 43 | label_embed_dim=embed_dim, 44 | denoising_groups=denoising_groups, 45 | label_noise_prob=0.2, 46 | box_noise_scale=0.4, 47 | with_indicator=True, 48 | ) 49 | 50 | def forward(self, images: List[Tensor], targets: List[Dict] = None): 51 | # get original image sizes, used for postprocess 52 | original_image_sizes = self.query_original_sizes(images) 53 | images, targets, mask = self.preprocess(images, targets) 54 | 55 | # get multi-level features, masks, and pos_embeds 56 | multi_levels = self.get_multi_levels(images, mask) 57 | multi_level_feats, multi_level_masks, multi_level_pos_embeds = multi_levels 58 | 59 | if self.training: 60 | # collect ground truth for denoising generation 61 | gt_labels_list = [t["labels"] for t in targets] 62 | gt_boxes_list = [t["boxes"] for t in targets] 63 | noised_results = self.denoising_generator(gt_labels_list, gt_boxes_list) 64 | noised_label_queries = noised_results[0] 65 | noised_box_queries = noised_results[1] 66 | attn_mask = noised_results[2] 67 | denoising_groups = noised_results[3] 68 | max_gt_num_per_image = noised_results[4] 69 | else: 70 | noised_label_queries = None 71 | noised_box_queries = None 72 | attn_mask = None 73 | denoising_groups = None 74 | max_gt_num_per_image = None 75 | 76 | # feed into transformer 77 | outputs_class, outputs_coord = self.transformer( 78 | multi_level_feats, 79 | multi_level_masks, 80 | multi_level_pos_embeds, 81 | noised_label_queries, 82 | noised_box_queries, 83 | attn_mask=attn_mask, 84 | ) 85 | 86 | # hack implementation for distributed training 87 | outputs_class[0] += self.denoising_generator.label_encoder.weight[0, 0] * 0.0 88 | 89 | # denoising postprocessing 90 | if denoising_groups is not None and max_gt_num_per_image is not None: 91 | dn_metas = { 92 | "denoising_groups": denoising_groups, 93 | "max_gt_num_per_image": max_gt_num_per_image 94 | } 95 | outputs_class, outputs_coord = self.dn_post_process( 96 | outputs_class, outputs_coord, dn_metas 97 | ) 98 | 99 | # prepare for loss computation 100 | output = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} 101 | output["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord) 102 | 103 | if self.training: 104 | # matching loss 105 | loss_dict = self.criterion(output, targets) 106 | 107 | # denoising training loss 108 | dn_losses = self.compute_dn_loss(dn_metas, targets) 109 | loss_dict.update(dn_losses) 110 | 111 | # loss reweighting 112 | weight_dict = self.criterion.weight_dict 113 | loss_dict = dict((k, loss_dict[k] * weight_dict[k]) 114 | for k in loss_dict.keys() 115 | if k in weight_dict) 116 | return loss_dict 117 | 118 | detections = self.postprocessor(output, original_image_sizes) 119 | return detections 120 | -------------------------------------------------------------------------------- /models/detectors/relation_detr.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Dict, List 3 | 4 | from torch import Tensor, nn 5 | 6 | from 
models.bricks.denoising import GenerateCDNQueries 7 | from models.detectors.base_detector import DNDETRDetector 8 | 9 | 10 | class RelationDETR(DNDETRDetector): 11 | def __init__( 12 | # model structure 13 | self, 14 | backbone: nn.Module, 15 | neck: nn.Module, 16 | position_embedding: nn.Module, 17 | transformer: nn.Module, 18 | criterion: nn.Module, 19 | postprocessor: nn.Module, 20 | # model parameters 21 | num_classes: int, 22 | num_queries: int = 900, 23 | hybrid_assign: int = 6, 24 | denoising_nums: int = 100, 25 | # model variants 26 | min_size: int = None, 27 | max_size: int = None, 28 | ): 29 | super().__init__(min_size, max_size) 30 | # define model parameters 31 | self.num_classes = num_classes 32 | embed_dim = transformer.embed_dim 33 | self.hybrid_assign = hybrid_assign 34 | 35 | # define model structures 36 | self.backbone = backbone 37 | self.neck = neck 38 | self.position_embedding = position_embedding 39 | self.transformer = transformer 40 | self.criterion = criterion 41 | self.postprocessor = postprocessor 42 | self.denoising_generator = GenerateCDNQueries( 43 | num_queries=num_queries, 44 | num_classes=num_classes, 45 | label_embed_dim=embed_dim, 46 | denoising_nums=denoising_nums, 47 | label_noise_prob=0.5, 48 | box_noise_scale=1.0, 49 | ) 50 | 51 | def forward(self, images: List[Tensor], targets: List[Dict] = None): 52 | # get original image sizes, used for postprocess 53 | original_image_sizes = self.query_original_sizes(images) 54 | images, targets, mask = self.preprocess(images, targets) 55 | 56 | # get multi-level features, masks, and pos_embeds 57 | multi_levels = self.get_multi_levels(images, mask) 58 | multi_level_feats, multi_level_masks, multi_level_pos_embeds = multi_levels 59 | 60 | if self.training: 61 | # collect ground truth for denoising generation 62 | gt_labels_list = [t["labels"] for t in targets] 63 | gt_boxes_list = [t["boxes"] for t in targets] 64 | noised_results = self.denoising_generator(gt_labels_list, gt_boxes_list) 65 | noised_label_queries = noised_results[0] 66 | noised_box_queries = noised_results[1] 67 | attn_mask = noised_results[2] 68 | denoising_groups = noised_results[3] 69 | max_gt_num_per_image = noised_results[4] 70 | else: 71 | noised_label_queries = None 72 | noised_box_queries = None 73 | attn_mask = None 74 | denoising_groups = None 75 | max_gt_num_per_image = None 76 | 77 | # feed into transformer 78 | ( 79 | outputs_class, 80 | outputs_coord, 81 | enc_class, 82 | enc_coord, 83 | hybrid_class, 84 | hybrid_coord, 85 | hybrid_enc_class, 86 | hybrid_enc_coord, 87 | ) = self.transformer( 88 | multi_level_feats, 89 | multi_level_masks, 90 | multi_level_pos_embeds, 91 | noised_label_queries, 92 | noised_box_queries, 93 | attn_mask=attn_mask, 94 | ) 95 | 96 | # hack implemantation for distributed training 97 | outputs_class[0] += self.denoising_generator.label_encoder.weight[0, 0] * 0.0 98 | 99 | # denoising postprocessing 100 | if denoising_groups is not None and max_gt_num_per_image is not None: 101 | dn_metas = { 102 | "denoising_groups": denoising_groups, 103 | "max_gt_num_per_image": max_gt_num_per_image 104 | } 105 | outputs_class, outputs_coord = self.dn_post_process( 106 | outputs_class, outputs_coord, dn_metas 107 | ) 108 | 109 | # prepare for loss computation 110 | output = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} 111 | output["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord) 112 | output["enc_outputs"] = {"pred_logits": enc_class, "pred_boxes": enc_coord} 113 | 114 | if 
self.training: 115 | # prepare for hybrid loss computation 116 | hybrid_metas = {"pred_logits": hybrid_class[-1], "pred_boxes": hybrid_coord[-1]} 117 | hybrid_metas["aux_outputs"] = self._set_aux_loss(hybrid_class, hybrid_coord) 118 | hybrid_metas["enc_outputs"] = { 119 | "pred_logits": hybrid_enc_class, 120 | "pred_boxes": hybrid_enc_coord 121 | } 122 | 123 | # compute loss 124 | loss_dict = self.criterion(output, targets) 125 | dn_losses = self.compute_dn_loss(dn_metas, targets) 126 | loss_dict.update(dn_losses) 127 | 128 | # compute hybrid loss 129 | multi_targets = copy.deepcopy(targets) 130 | for t in multi_targets: 131 | t["boxes"] = t["boxes"].repeat(self.hybrid_assign, 1) 132 | t["labels"] = t["labels"].repeat(self.hybrid_assign) 133 | hybrid_losses = self.criterion(hybrid_metas, multi_targets) 134 | loss_dict.update({k + "_hybrid": v for k, v in hybrid_losses.items()}) 135 | 136 | # loss reweighting 137 | weight_dict = self.criterion.weight_dict 138 | loss_dict = dict((k, loss_dict[k] * weight_dict[k]) 139 | for k in loss_dict.keys() 140 | if k in weight_dict) 141 | return loss_dict 142 | 143 | detections = self.postprocessor(output, original_image_sizes) 144 | return detections 145 | -------------------------------------------------------------------------------- /models/matcher/hungarian_matcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.optimize import linear_sum_assignment 3 | from torch import Tensor, nn 4 | from torchvision.ops.boxes import _box_cxcywh_to_xyxy, generalized_box_iou 5 | 6 | 7 | class HungarianMatcher(nn.Module): 8 | """This class implements the Hungarian matching algorithm for bipartite graphs. It matches predicted bounding 9 | boxes to ground truth boxes based on the minimum cost assignment. The cost is computed as a weighted sum of 10 | classification, bounding box, and generalized intersection over union (IoU) costs. The focal loss is used to 11 | weigh the classification cost. The HungarianMatcher class can be used in single or mixed assignment modes. 12 | The mixed assignment modes is introduced in `Align-DETR `_. 13 | 14 | :param cost_class: The weight of the classification cost, defaults to 1 15 | :param cost_bbox: The weight of the bounding box cost, defaults to 1 16 | :param cost_giou: The weight of the generalized IoU cost, defaults to 1 17 | :param focal_alpha: The alpha parameter of the focal loss, defaults to 0.25 18 | :param focal_gamma: The gamma parameter of the focal loss, defaults to 2.0 19 | :param mixed_match: If True, mixed assignment is used, defaults to False 20 | """ 21 | def __init__( 22 | self, 23 | cost_class: float = 1, 24 | cost_bbox: float = 1, 25 | cost_giou: float = 1, 26 | focal_alpha: float = 0.25, 27 | focal_gamma: float = 2.0, 28 | mixed_match: bool = False, 29 | ): 30 | super().__init__() 31 | 32 | self.cost_class = cost_class 33 | self.cost_bbox = cost_bbox 34 | self.cost_giou = cost_giou 35 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 36 | 37 | self.focal_alpha = focal_alpha 38 | self.focal_gamma = focal_gamma 39 | self.mixed_match = mixed_match 40 | 41 | def calculate_class_cost(self, pred_logits, gt_labels, **kwargs): 42 | out_prob = pred_logits.sigmoid() 43 | 44 | # Compute the classification cost. 
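        # The cost uses the same focal form as the classification loss:
        #   pos(p) = -alpha       * (1 - p) ** gamma * log(p + 1e-6)
        #   neg(p) = -(1 - alpha) * p ** gamma       * log(1 - p + 1e-6)
        # The matrix entry for a (query, ground-truth) pair is pos(p_c) - neg(p_c),
        # evaluated at the predicted probability p_c of that ground-truth class, so
        # queries that are already confident in the correct class are cheap to match.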
45 | neg_cost_class = -(1 - self.focal_alpha) * out_prob**self.focal_gamma * (1 - out_prob + 1e-6).log() 46 | pos_cost_class = -self.focal_alpha * (1 - out_prob)**self.focal_gamma * (out_prob + 1e-6).log() 47 | cost_class = pos_cost_class[:, gt_labels] - neg_cost_class[:, gt_labels] 48 | 49 | return cost_class 50 | 51 | def calculate_bbox_cost(self, pred_boxes, gt_boxes, **kwargs): 52 | # Compute the L1 cost between boxes 53 | cost_bbox = torch.cdist(pred_boxes, gt_boxes, p=1) 54 | return cost_bbox 55 | 56 | def calculate_giou_cost(self, pred_boxes, gt_boxes, **kwargs): 57 | # Compute the giou cost betwen boxes 58 | cost_giou = -generalized_box_iou(_box_cxcywh_to_xyxy(pred_boxes), _box_cxcywh_to_xyxy(gt_boxes)) 59 | return cost_giou 60 | 61 | @torch.no_grad() 62 | def calculate_cost(self, pred_boxes: Tensor, pred_logits: Tensor, gt_boxes: Tensor, gt_labels: Tensor): 63 | # Calculate class, bbox and giou cost 64 | cost_class = self.calculate_class_cost(pred_logits, gt_labels) 65 | cost_bbox = self.calculate_bbox_cost(pred_boxes, gt_boxes) 66 | cost_giou = self.calculate_giou_cost(pred_boxes, gt_boxes) 67 | 68 | # Final cost matrix 69 | c = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 70 | return c 71 | 72 | @torch.no_grad() 73 | def forward( 74 | self, pred_boxes: Tensor, pred_logits: Tensor, gt_boxes: Tensor, gt_labels: Tensor, gt_copy: int = 1 75 | ): 76 | c = self.calculate_cost(pred_boxes, pred_logits, gt_boxes, gt_labels) 77 | 78 | # single assignment 79 | if not self.mixed_match: 80 | indices = linear_sum_assignment(c.cpu()) 81 | return torch.as_tensor(indices[0]), torch.as_tensor(indices[1]) 82 | 83 | # mixed assignment, used in AlignDETR 84 | gt_size = c.size(-1) 85 | num_queries = len(c) 86 | gt_copy = min(int(num_queries * 0.5 / gt_size), gt_copy) if gt_size > 0 else gt_copy 87 | src_ind, tgt_ind = linear_sum_assignment(c.cpu().repeat(1, gt_copy)) 88 | tgt_ind = tgt_ind % gt_size 89 | tgt_ind, ind = torch.as_tensor(tgt_ind, dtype=torch.int64).sort() 90 | src_ind = torch.as_tensor(src_ind, dtype=torch.int64)[ind].view(-1) 91 | return src_ind, tgt_ind 92 | -------------------------------------------------------------------------------- /models/necks/channel_mapper.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import List 3 | 4 | from torch import nn 5 | from models.bricks.misc import Conv2dNormActivation 6 | 7 | 8 | class ChannelMapper(nn.Module): 9 | def __init__( 10 | self, 11 | in_channels: List[int], 12 | out_channels: int, 13 | num_outs: int, 14 | kernel_size: int = 1, 15 | stride: int = 1, 16 | groups: int = 1, 17 | norm_layer=partial(nn.GroupNorm, 32), 18 | activation_layer: nn.Module = None, 19 | dilation: int = 1, 20 | inplace: bool = True, 21 | bias: bool = None, 22 | ): 23 | self.in_channels = in_channels 24 | super().__init__() 25 | self.convs = nn.ModuleList() 26 | self.num_channels = [out_channels] * num_outs 27 | for in_channel in in_channels: 28 | self.convs.append( 29 | Conv2dNormActivation( 30 | in_channels=in_channel, 31 | out_channels=out_channels, 32 | kernel_size=kernel_size, 33 | stride=stride, 34 | padding=(kernel_size - 1) // 2, 35 | bias=bias, 36 | groups=groups, 37 | dilation=dilation, 38 | norm_layer=norm_layer, 39 | activation_layer=activation_layer, 40 | inplace=inplace, 41 | ) 42 | ) 43 | for _ in range(num_outs - len(in_channels)): 44 | self.convs.append( 45 | Conv2dNormActivation( 46 | in_channels=in_channel, 47 | 
out_channels=out_channels, 48 | kernel_size=3, 49 | stride=2, 50 | padding=1, 51 | bias=bias, 52 | groups=groups, 53 | dilation=dilation, 54 | norm_layer=norm_layer, 55 | activation_layer=activation_layer, 56 | inplace=inplace, 57 | ) 58 | ) 59 | in_channel = out_channels 60 | 61 | self.init_weights() 62 | 63 | def init_weights(self): 64 | # initialize modules 65 | for layer in self.modules(): 66 | if isinstance(layer, nn.Conv2d): 67 | nn.init.xavier_uniform_(layer.weight, gain=1) 68 | if layer.bias is not None: 69 | nn.init.constant_(layer.bias, 0) 70 | 71 | def forward(self, inputs): 72 | """Map inputs into specific embed_dim. If the length of inputs is smaller 73 | than convolution modules, only the last `len(inputs) + extra` will be used 74 | for mapping. 75 | 76 | :param inputs: dict containing torch.Tensor. 77 | :return: a list of mapped torch.Tensor 78 | """ 79 | inputs = list(inputs.values()) 80 | assert len(inputs) <= len(self.in_channels) 81 | start = len(self.in_channels) - len(inputs) 82 | convs = self.convs[start:] 83 | outs = [convs[i](inputs[i]) for i in range(len(inputs))] 84 | for i in range(len(inputs), len(convs)): 85 | if i == len(inputs): 86 | outs.append(convs[i](inputs[-1])) 87 | else: 88 | outs.append(convs[i](outs[-1])) 89 | return outs 90 | -------------------------------------------------------------------------------- /optimizer/param_dict.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | from torch import nn 4 | 5 | 6 | def match_name_keywords(name: str, name_keywords: Union[Tuple, List, str]): 7 | if isinstance(name_keywords, str): 8 | name_keywords = [name_keywords] 9 | for b in name_keywords: 10 | if b in name: 11 | return True 12 | return False 13 | 14 | def basic_param(model, lr): 15 | return [{"params": [p for p in model.parameters() if p.requires_grad], "lr": lr}] 16 | 17 | def finetune_backbone_param(model, lr): 18 | return [ 19 | { 20 | "params": [ 21 | p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad 22 | ] 23 | }, 24 | { 25 | "params": [ 26 | p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad 27 | ], 28 | "lr": lr * 0.1, 29 | }, 30 | ] 31 | 32 | 33 | def finetune_backbone_with_no_norm_weight_decay(model, lr): 34 | norm_classes = ( 35 | nn.modules.batchnorm._BatchNorm, 36 | nn.LayerNorm, 37 | nn.GroupNorm, 38 | nn.modules.instancenorm._InstanceNorm, 39 | nn.LocalResponseNorm, 40 | ) 41 | backbone_norm = [] 42 | other_norm = [] 43 | backbone = [] 44 | other = [] 45 | for name, module in model.named_modules(): 46 | if next(module.children(), None): 47 | if "backbone" in name: 48 | backbone.extend(p for p in module.parameters(recurse=False) if p.requires_grad) 49 | else: 50 | other.extend(p for p in module.parameters(recurse=False) if p.requires_grad) 51 | elif isinstance(module, norm_classes): 52 | if "backbone" in name: 53 | backbone_norm.extend(p for p in module.parameters() if p.requires_grad) 54 | else: 55 | other_norm.extend(p for p in module.parameters() if p.requires_grad) 56 | else: 57 | if "backbone" in name: 58 | backbone.extend(p for p in module.parameters() if p.requires_grad) 59 | else: 60 | other.extend(p for p in module.parameters() if p.requires_grad) 61 | return [ 62 | { 63 | "params": other, 64 | }, 65 | { 66 | "params": backbone_norm, 67 | "lr": lr * 0.1, 68 | "weight_decay": 0, 69 | }, 70 | { 71 | "params": other_norm, 72 | "weight_decay": 0, 73 | }, 74 | { 75 | "params": backbone, 76 | 
"lr": lr * 0.1, 77 | }, 78 | ] 79 | 80 | 81 | def finetune_backbone_and_linear_projection(model, lr): 82 | linear_keywords = ("reference_points", "sampling_offsets") 83 | norm_bias_keywords = ("norm", "bias") 84 | backbone = [] 85 | backbone_norm = [] 86 | linear_projection = [] 87 | linear_projection_norm = [] 88 | other = [] 89 | other_norm = [] 90 | for name, parameters in model.named_parameters(): 91 | if not parameters.requires_grad: 92 | continue 93 | if ( 94 | match_name_keywords(name, "backbone") 95 | and not match_name_keywords(name, linear_keywords) 96 | and match_name_keywords(name, norm_bias_keywords) 97 | ): 98 | backbone_norm.append(parameters) 99 | elif ( 100 | match_name_keywords(name, "backbone") 101 | and not match_name_keywords(name, linear_keywords) 102 | and not match_name_keywords(name, norm_bias_keywords) 103 | ): 104 | backbone.append(parameters) 105 | elif ( 106 | not match_name_keywords(name, "backbone") 107 | and match_name_keywords(name, linear_keywords) 108 | and match_name_keywords(name, norm_bias_keywords) 109 | ): 110 | linear_projection_norm.append(parameters) 111 | elif ( 112 | not match_name_keywords(name, "backbone") 113 | and match_name_keywords(name, linear_keywords) 114 | and not match_name_keywords(name, norm_bias_keywords) 115 | ): 116 | linear_projection.append(parameters) 117 | elif match_name_keywords(name, norm_bias_keywords): 118 | other_norm.append(parameters) 119 | else: 120 | other.append(parameters) 121 | 122 | return [ 123 | { 124 | "params": other, 125 | }, 126 | { 127 | "params": backbone, 128 | "lr": lr * 0.1, 129 | }, 130 | { 131 | "params": backbone_norm, 132 | "lr": lr * 0.1, 133 | "weight_decay": 0, 134 | }, 135 | { 136 | "params": linear_projection, 137 | "lr": lr * 0.1, 138 | }, 139 | { 140 | "params": linear_projection_norm, 141 | "lr": lr * 0.1, 142 | "weight_decay": 0, 143 | }, 144 | { 145 | "params": other_norm, 146 | "weight_decay": 0, 147 | }, 148 | ] 149 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | fvcore 3 | astunparse 4 | albumentations 5 | omegaconf 6 | pycocotools 7 | terminaltables 8 | tensorboard 9 | iopath 10 | ninja 11 | yapf -------------------------------------------------------------------------------- /tools/benchmark_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | import torch 7 | from fvcore.nn import FlopCountAnalysis, flop_count_table 8 | from tqdm import tqdm 9 | 10 | sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) 11 | from util.lazy_load import Config 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser(description="Benchmarking a model") 16 | 17 | parser.add_argument("--model-config", type=str, required=True) 18 | parser.add_argument("--shape", type=int, nargs="+", default=(1333, 800), help="input image size") 19 | parser.add_argument("--repeat", type=int, default=50) 20 | parser.add_argument("--device", default="cuda") 21 | 22 | args = parser.parse_args() 23 | return args 24 | 25 | 26 | def get_flops(): 27 | args = parse_args() 28 | # initialize model 29 | model = Config(args.model_config).model 30 | model.eval_transform = None 31 | model.eval().to(args.device) 32 | # test FLOPs 33 | image = torch.randn(3, args.shape[0], args.shape[1]).to(args.device) 34 | flops = FlopCountAnalysis(model, 
((image,),)) 35 | print(flop_count_table(flops)) 36 | # test memory allocation 37 | print(f"Memory allocation {torch.cuda.memory_allocated() / 1024**3} GB") 38 | print(f"Max memory allocation {torch.cuda.max_memory_allocated() / 1024**3} GB") 39 | # test model parameters 40 | print(f"Model parameters {sum(p.numel() for p in model.parameters()) / 1024**3} GB") 41 | 42 | # test inference time 43 | print("warm up...") 44 | with torch.inference_mode(): 45 | for _ in range(10): 46 | _ = model((image,)) 47 | torch.cuda.synchronize() 48 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 49 | timings = np.zeros((args.repeat, 1)) 50 | print("testing inference time...") 51 | with torch.inference_mode(): 52 | for rep in tqdm(range(args.repeat)): 53 | starter.record() 54 | _ = model((image,)) 55 | ender.record() 56 | torch.cuda.synchronize() 57 | curr_time = starter.elapsed_time(ender) 58 | timings[rep] = curr_time 59 | 60 | avg = timings.sum() / rep 61 | print(f"avg inference time per image = {avg / 1000}") 62 | 63 | 64 | if __name__ == "__main__": 65 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 66 | get_flops() 67 | -------------------------------------------------------------------------------- /tools/pytorch2onnx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import warnings 5 | from typing import Dict, List, Tuple 6 | 7 | import numpy as np 8 | import onnx 9 | import torch 10 | from torch import Tensor 11 | 12 | sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) 13 | from util import utils 14 | from util.lazy_load import Config 15 | 16 | 17 | class ONNXDetector: 18 | def __init__(self, onnx_file): 19 | import onnxruntime 20 | self.session = onnxruntime.InferenceSession( 21 | onnx_file, providers=["CUDAExecutionProvider", "CPUExecutionProvider"] 22 | ) 23 | self.io_binding = self.session.io_binding() 24 | self.is_cuda_available = onnxruntime.get_device() == "GPU" 25 | 26 | def __call__(self, images: List[Tensor], targets: List[Dict] = None): 27 | if targets is not None: 28 | warnings.warn("Currently ONNXDetector only support inference, targets will be ignored") 29 | assert len(images) == 1, "Currently ONNXDetector only support batch_size=1 for inference" 30 | assert images[0].ndim == 3, "Each image must be with three dimensions of C, H, W" 31 | if isinstance(images, (List, Tuple)): 32 | images = torch.stack(images) 33 | 34 | # set io binding for inputs/outputs 35 | device_type = images.device.type if self.is_cuda_available else "cpu" 36 | if not self.is_cuda_available: 37 | images = images.cpu() 38 | self.io_binding.bind_input( 39 | name="images", 40 | device_type=device_type, 41 | device_id=0, 42 | element_type=np.float32, 43 | shape=images.shape, 44 | buffer_ptr=images.data_ptr(), 45 | ) 46 | for output in self.session.get_outputs(): 47 | self.io_binding.bind_output(output.name) 48 | 49 | # run session to get outputs 50 | self.session.run_with_iobinding(self.io_binding) 51 | detections = self.io_binding.copy_outputs_to_cpu() 52 | return detections 53 | 54 | 55 | def parse_args(): 56 | parser = argparse.ArgumentParser(description="Convert a pytorch model to ONNX model") 57 | 58 | # model parameters 59 | parser.add_argument("--model-config", type=str, default=None) 60 | parser.add_argument("--checkpoint", type=str, default=None) 61 | parser.add_argument("--shape", type=int, nargs="+", default=(1333, 800)) 62 | 63 | # 
save parameters 64 | parser.add_argument("--save-file", type=str, required=True) 65 | 66 | # onnx parameters 67 | parser.add_argument("--opset-version", type=int, default=17) 68 | parser.add_argument("--dynamic-export", type=bool, default=True) 69 | parser.add_argument("--simplify", action="store_true") 70 | parser.add_argument("--verify", action="store_true") 71 | 72 | args = parser.parse_args() 73 | return args 74 | 75 | 76 | def pytorch2onnx(): 77 | # get args from parser 78 | args = parse_args() 79 | model = Config(args.model_config).model 80 | model.eval() 81 | if args.checkpoint: 82 | checkpoint = torch.load(args.checkpoint, map_location="cpu") 83 | utils.load_state_dict(model, checkpoint["model"] if "model" in checkpoint else checkpoint) 84 | image = torch.randn(1, 3, args.shape[0], args.shape[1]) 85 | 86 | if args.dynamic_export: 87 | dynamic_axes = { 88 | "images": { 89 | 0: "batch", 90 | 2: "height", 91 | 3: "width", 92 | }, 93 | } 94 | else: 95 | dynamic_axes = None 96 | torch.onnx.export( 97 | model=model, 98 | args=image, 99 | f=args.save_file, 100 | input_names=["images"], 101 | output_names=["scores", "labels", "boxes"], 102 | dynamic_axes=dynamic_axes, 103 | opset_version=args.opset_version, 104 | ) 105 | 106 | if args.simplify: 107 | import onnxsim 108 | model_ops, check_ok = onnxsim.simplify(args.save_file) 109 | if check_ok: 110 | onnx.save(model_ops, args.save_file) 111 | print(f"Successfully simplified ONNX model: {args.save_file}") 112 | else: 113 | warnings.warn("Failed to simplify ONNX model.") 114 | print(f"Successfully exported ONNX model: {args.save_file}") 115 | 116 | if args.verify: 117 | # check by onnx 118 | onnx_model = onnx.load(args.save_file) 119 | onnx.checker.check_model(onnx_model) 120 | 121 | # check onnx results and pytorch results 122 | onnx_model = ONNXDetector(args.save_file) 123 | onnx_results = onnx_model(image) 124 | pytorch_results = list(model(image)[0].values()) 125 | err_msg = "The numerical values are different between Pytorch and ONNX" 126 | err_msg += "But it does not necessarily mean the exported ONNX is problematic." 
127 | for onnx_res, pytorch_res in zip(onnx_results, pytorch_results): 128 | np.testing.assert_allclose(onnx_res, pytorch_res, rtol=1e-3, atol=1e-5, err_msg=err_msg) 129 | print("The numerical values are the same between Pytorch and ONNX") 130 | 131 | 132 | if __name__ == "__main__": 133 | pytorch2onnx() 134 | -------------------------------------------------------------------------------- /tools/visualize_datasets.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | from torch.utils import data 6 | 7 | sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) 8 | from datasets.coco import CocoDetection 9 | from transforms import presets 10 | from transforms import v2 as T 11 | from util.collate_fn import collate_fn 12 | from util.logger import setup_logger 13 | from util.misc import fixed_generator, seed_worker 14 | from util.visualize import visualize_coco_bounding_boxes 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description="Visualize a datasets") 19 | 20 | # dataset parameters 21 | parser.add_argument("--coco-img", type=str, required=True) 22 | parser.add_argument("--coco-ann", type=str, required=True) 23 | parser.add_argument("--transform", type=str, default=None) 24 | parser.add_argument("--workers", type=int, default=2) 25 | 26 | # visualize parameters 27 | parser.add_argument("--show-dir", type=str, default=None, required=True) 28 | parser.add_argument("--show-conf", type=float, default=0.5) 29 | 30 | # plot parameters 31 | parser.add_argument("--font-scale", type=float, default=1.0) 32 | parser.add_argument("--box-thick", type=int, default=1) 33 | parser.add_argument("--fill-alpha", type=float, default=0.2) 34 | parser.add_argument("--text-box-color", type=int, nargs="+", default=(255, 255, 255)) 35 | parser.add_argument("--text-font-color", type=int, nargs="+", default=None) 36 | parser.add_argument("--text-alpha", type=float, default=1.0) 37 | 38 | # engine parameters 39 | parser.add_argument("--seed", type=int, default=42) 40 | 41 | args = parser.parse_args() 42 | 43 | return args 44 | 45 | 46 | def visualize_datasets(): 47 | args = parse_args() 48 | 49 | # setup logger 50 | for logger_name in ["py.warnings", "accelerate", os.path.basename(os.getcwd())]: 51 | setup_logger(name=logger_name) 52 | 53 | # remove the ConvertDtype and Normalize for visualization 54 | if args.transform: 55 | transform = getattr(presets, args.transform) 56 | transform = remove_cvtdtype_normalize(transform) 57 | else: 58 | transform = None 59 | 60 | # plot annotations for each image 61 | if args.show_dir: 62 | dataset = CocoDetection(img_folder=args.coco_img, ann_file=args.coco_ann, transforms=transform) 63 | data_loader = data.DataLoader( 64 | dataset, 65 | 1, 66 | shuffle=False, 67 | num_workers=args.workers, 68 | worker_init_fn=seed_worker, 69 | generator=fixed_generator(), 70 | collate_fn=collate_fn, 71 | ) 72 | visualize_coco_bounding_boxes( 73 | data_loader=data_loader, 74 | show_conf=args.show_conf, 75 | show_dir=args.show_dir, 76 | font_scale=args.font_scale, 77 | box_thick=args.box_thick, 78 | fill_alpha=args.fill_alpha, 79 | text_box_color=args.text_box_color, 80 | text_font_color=args.text_font_color, 81 | text_alpha=args.text_alpha, 82 | ) 83 | 84 | 85 | def remove_cvtdtype_normalize(transform): 86 | if isinstance(transform, T.Compose): 87 | transform = [remove_cvtdtype_normalize(trans) for trans in transform.transforms] 88 | transform = [trans for trans in transform if 
trans is not None] 89 | return T.Compose(transform) 90 | if isinstance(transform, (T.ConvertDtype, T.Normalize)): 91 | return None 92 | return transform 93 | 94 | 95 | if __name__ == "__main__": 96 | visualize_datasets() 97 | -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | from .autoaugment import * 3 | from .convert_coco_polys_to_mask import ConvertCocoPolysToMask 4 | from .simple_copy_paste import SimpleCopyPaste 5 | -------------------------------------------------------------------------------- /transforms/_functional_video.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | 5 | 6 | warnings.warn( 7 | "The 'torchvision.transforms._functional_video' module is deprecated since 0.12 and will be removed in the future. " 8 | "Please use the 'torchvision.transforms.functional' module instead." 9 | ) 10 | 11 | 12 | def _is_tensor_video_clip(clip): 13 | if not torch.is_tensor(clip): 14 | raise TypeError("clip should be Tensor. Got %s" % type(clip)) 15 | 16 | if not clip.ndimension() == 4: 17 | raise ValueError("clip should be 4D. Got %dD" % clip.dim()) 18 | 19 | return True 20 | 21 | 22 | def crop(clip, i, j, h, w): 23 | """ 24 | Args: 25 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 26 | """ 27 | if len(clip.size()) != 4: 28 | raise ValueError("clip should be a 4D tensor") 29 | return clip[..., i : i + h, j : j + w] 30 | 31 | 32 | def resize(clip, target_size, interpolation_mode): 33 | if len(target_size) != 2: 34 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") 35 | return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) 36 | 37 | 38 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): 39 | """ 40 | Do spatial cropping and resizing to the video clip 41 | Args: 42 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 43 | i (int): i in (i,j) i.e coordinates of the upper left corner. 44 | j (int): j in (i,j) i.e coordinates of the upper left corner. 45 | h (int): Height of the cropped region. 46 | w (int): Width of the cropped region. 47 | size (tuple(int, int)): height and width of resized clip 48 | Returns: 49 | clip (torch.tensor): Resized and cropped clip. 
Size is (C, T, H, W) 50 | """ 51 | if not _is_tensor_video_clip(clip): 52 | raise ValueError("clip should be a 4D torch.tensor") 53 | clip = crop(clip, i, j, h, w) 54 | clip = resize(clip, size, interpolation_mode) 55 | return clip 56 | 57 | 58 | def center_crop(clip, crop_size): 59 | if not _is_tensor_video_clip(clip): 60 | raise ValueError("clip should be a 4D torch.tensor") 61 | h, w = clip.size(-2), clip.size(-1) 62 | th, tw = crop_size 63 | if h < th or w < tw: 64 | raise ValueError("height and width must be no smaller than crop_size") 65 | 66 | i = int(round((h - th) / 2.0)) 67 | j = int(round((w - tw) / 2.0)) 68 | return crop(clip, i, j, th, tw) 69 | 70 | 71 | def to_tensor(clip): 72 | """ 73 | Convert tensor data type from uint8 to float, divide value by 255.0 and 74 | permute the dimensions of clip tensor 75 | Args: 76 | clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) 77 | Return: 78 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) 79 | """ 80 | _is_tensor_video_clip(clip) 81 | if not clip.dtype == torch.uint8: 82 | raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) 83 | return clip.float().permute(3, 0, 1, 2) / 255.0 84 | 85 | 86 | def normalize(clip, mean, std, inplace=False): 87 | """ 88 | Args: 89 | clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) 90 | mean (tuple): pixel RGB mean. Size is (3) 91 | std (tuple): pixel standard deviation. Size is (3) 92 | Returns: 93 | normalized clip (torch.tensor): Size is (C, T, H, W) 94 | """ 95 | if not _is_tensor_video_clip(clip): 96 | raise ValueError("clip should be a 4D torch.tensor") 97 | if not inplace: 98 | clip = clip.clone() 99 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) 100 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) 101 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 102 | return clip 103 | 104 | 105 | def hflip(clip): 106 | """ 107 | Args: 108 | clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) 109 | Returns: 110 | flipped clip (torch.tensor): Size is (C, T, H, W) 111 | """ 112 | if not _is_tensor_video_clip(clip): 113 | raise ValueError("clip should be a 4D torch.tensor") 114 | return clip.flip(-1) 115 | -------------------------------------------------------------------------------- /transforms/_transforms_video.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numbers 4 | import random 5 | import warnings 6 | 7 | from transforms import RandomCrop, RandomResizedCrop 8 | 9 | from . import _functional_video as F 10 | 11 | 12 | __all__ = [ 13 | "RandomCropVideo", 14 | "RandomResizedCropVideo", 15 | "CenterCropVideo", 16 | "NormalizeVideo", 17 | "ToTensorVideo", 18 | "RandomHorizontalFlipVideo", 19 | ] 20 | 21 | 22 | warnings.warn( 23 | "The 'torchvision.transforms._transforms_video' module is deprecated since 0.12 and will be removed in the future. " 24 | "Please use the 'torchvision.transforms' module instead." 25 | ) 26 | 27 | 28 | class RandomCropVideo(RandomCrop): 29 | def __init__(self, size): 30 | if isinstance(size, numbers.Number): 31 | self.size = (int(size), int(size)) 32 | else: 33 | self.size = size 34 | 35 | def __call__(self, clip): 36 | """ 37 | Args: 38 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 39 | Returns: 40 | torch.tensor: randomly cropped/resized video clip. 
41 | size is (C, T, OH, OW) 42 | """ 43 | i, j, h, w = self.get_params(clip, self.size) 44 | return F.crop(clip, i, j, h, w) 45 | 46 | def __repr__(self) -> str: 47 | return f"{self.__class__.__name__}(size={self.size})" 48 | 49 | 50 | class RandomResizedCropVideo(RandomResizedCrop): 51 | def __init__( 52 | self, 53 | size, 54 | scale=(0.08, 1.0), 55 | ratio=(3.0 / 4.0, 4.0 / 3.0), 56 | interpolation_mode="bilinear", 57 | ): 58 | if isinstance(size, tuple): 59 | if len(size) != 2: 60 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 61 | self.size = size 62 | else: 63 | self.size = (size, size) 64 | 65 | self.interpolation_mode = interpolation_mode 66 | self.scale = scale 67 | self.ratio = ratio 68 | 69 | def __call__(self, clip): 70 | """ 71 | Args: 72 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 73 | Returns: 74 | torch.tensor: randomly cropped/resized video clip. 75 | size is (C, T, H, W) 76 | """ 77 | i, j, h, w = self.get_params(clip, self.scale, self.ratio) 78 | return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode) 79 | 80 | def __repr__(self) -> str: 81 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})" 82 | 83 | 84 | class CenterCropVideo: 85 | def __init__(self, crop_size): 86 | if isinstance(crop_size, numbers.Number): 87 | self.crop_size = (int(crop_size), int(crop_size)) 88 | else: 89 | self.crop_size = crop_size 90 | 91 | def __call__(self, clip): 92 | """ 93 | Args: 94 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 95 | Returns: 96 | torch.tensor: central cropping of video clip. Size is 97 | (C, T, crop_size, crop_size) 98 | """ 99 | return F.center_crop(clip, self.crop_size) 100 | 101 | def __repr__(self) -> str: 102 | return f"{self.__class__.__name__}(crop_size={self.crop_size})" 103 | 104 | 105 | class NormalizeVideo: 106 | """ 107 | Normalize the video clip by mean subtraction and division by standard deviation 108 | Args: 109 | mean (3-tuple): pixel RGB mean 110 | std (3-tuple): pixel RGB standard deviation 111 | inplace (boolean): whether do in-place normalization 112 | """ 113 | 114 | def __init__(self, mean, std, inplace=False): 115 | self.mean = mean 116 | self.std = std 117 | self.inplace = inplace 118 | 119 | def __call__(self, clip): 120 | """ 121 | Args: 122 | clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W) 123 | """ 124 | return F.normalize(clip, self.mean, self.std, self.inplace) 125 | 126 | def __repr__(self) -> str: 127 | return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" 128 | 129 | 130 | class ToTensorVideo: 131 | """ 132 | Convert tensor data type from uint8 to float, divide value by 255.0 and 133 | permute the dimensions of clip tensor 134 | """ 135 | 136 | def __init__(self): 137 | pass 138 | 139 | def __call__(self, clip): 140 | """ 141 | Args: 142 | clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) 143 | Return: 144 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) 145 | """ 146 | return F.to_tensor(clip) 147 | 148 | def __repr__(self) -> str: 149 | return self.__class__.__name__ 150 | 151 | 152 | class RandomHorizontalFlipVideo: 153 | """ 154 | Flip the video clip along the horizontal direction with a given probability 155 | Args: 156 | p (float): probability of the clip being flipped. 
Default value is 0.5 157 | """ 158 | 159 | def __init__(self, p=0.5): 160 | self.p = p 161 | 162 | def __call__(self, clip): 163 | """ 164 | Args: 165 | clip (torch.tensor): Size is (C, T, H, W) 166 | Return: 167 | clip (torch.tensor): Size is (C, T, H, W) 168 | """ 169 | if random.random() < self.p: 170 | clip = F.hflip(clip) 171 | return clip 172 | 173 | def __repr__(self) -> str: 174 | return f"{self.__class__.__name__}(p={self.p})" 175 | -------------------------------------------------------------------------------- /transforms/_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import numbers 3 | from collections import defaultdict 4 | from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union 5 | 6 | from util import datapoints 7 | from util.datapoints import _FillType, _FillTypeJIT 8 | 9 | from transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 10 | 11 | 12 | def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]: 13 | if not isinstance(arg, (float, Sequence)): 14 | raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}") 15 | if isinstance(arg, Sequence) and len(arg) != req_size: 16 | raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}") 17 | if isinstance(arg, Sequence): 18 | for element in arg: 19 | if not isinstance(element, float): 20 | raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}") 21 | 22 | if isinstance(arg, float): 23 | arg = [float(arg), float(arg)] 24 | if isinstance(arg, (list, tuple)) and len(arg) == 1: 25 | arg = [arg[0], arg[0]] 26 | return arg 27 | 28 | 29 | def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None: 30 | if isinstance(fill, dict): 31 | for key, value in fill.items(): 32 | # Check key for type 33 | _check_fill_arg(value) 34 | if isinstance(fill, defaultdict) and callable(fill.default_factory): 35 | default_value = fill.default_factory() 36 | _check_fill_arg(default_value) 37 | else: 38 | if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): 39 | raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.") 40 | 41 | 42 | T = TypeVar("T") 43 | 44 | 45 | def _default_arg(value: T) -> T: 46 | return value 47 | 48 | 49 | def _get_defaultdict(default: T) -> Dict[Any, T]: 50 | # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle. 
51 | # If it were possible, we could replace this with `defaultdict(lambda: default)` 52 | return defaultdict(functools.partial(_default_arg, default)) 53 | 54 | 55 | def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT: 56 | # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 57 | # So, we can't reassign fill to 0 58 | # if fill is None: 59 | # fill = 0 60 | if fill is None: 61 | return fill 62 | 63 | if not isinstance(fill, (int, float)): 64 | fill = [float(v) for v in list(fill)] 65 | return fill # type: ignore[return-value] 66 | 67 | 68 | def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]: 69 | _check_fill_arg(fill) 70 | 71 | if isinstance(fill, dict): 72 | for k, v in fill.items(): 73 | fill[k] = _convert_fill_arg(v) 74 | if isinstance(fill, defaultdict) and callable(fill.default_factory): 75 | default_value = fill.default_factory() 76 | sanitized_default = _convert_fill_arg(default_value) 77 | fill.default_factory = functools.partial(_default_arg, sanitized_default) 78 | return fill # type: ignore[return-value] 79 | 80 | return _get_defaultdict(_convert_fill_arg(fill)) 81 | 82 | 83 | def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: 84 | if not isinstance(padding, (numbers.Number, tuple, list)): 85 | raise TypeError("Got inappropriate padding arg") 86 | 87 | if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]: 88 | raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") 89 | 90 | 91 | # TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums) 92 | # https://github.com/pytorch/vision/issues/6250 93 | def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None: 94 | if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: 95 | raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") 96 | -------------------------------------------------------------------------------- /transforms/albumentations_warpper.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from util import datapoints 8 | 9 | 10 | class AlbumentationsWrapper(nn.Module): 11 | def __init__(self, albumentation_transforms): 12 | """ 13 | 14 | :param albumentation_transforms: albumentations transformation for data augmentation. 
For example: 15 | """ 16 | super().__init__() 17 | self.albumentation_transforms = albumentation_transforms 18 | 19 | def forward(self, input: Any) -> Any: 20 | # get image, box, mask, label from input 21 | labels = input[-1] 22 | not_allowed_data = list( 23 | filter( 24 | lambda x: not isinstance(x, (datapoints.Image, datapoints.BoundingBox, datapoints.Mask)), 25 | input, 26 | ) 27 | ) 28 | not_allowed_data_type = set(list(map(lambda x: type(x), not_allowed_data))) 29 | if len(not_allowed_data) != 1: 30 | warnings.warn( 31 | f"current we only support images, bounding boxes and masks" 32 | f"transformation for albumentations, but got {not_allowed_data_type}" 33 | ) 34 | images = list(filter(lambda x: isinstance(x, datapoints.Image), input)) 35 | boxes = list(filter(lambda x: isinstance(x, datapoints.BoundingBox), input)) 36 | masks = list(filter(lambda x: isinstance(x, datapoints.Mask), input)) 37 | if len(images) != 1 or len(boxes) != 1: 38 | raise ValueError 39 | 40 | # prepare albumentations input format 41 | images = images[0].data.numpy().transpose(1, 2, 0) 42 | boxes = boxes[0].data.numpy() 43 | keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1]) # TODO: change into a function 44 | input_dict = { 45 | "image": images, 46 | "bboxes": boxes[keep], 47 | "labels": labels.numpy()[keep], 48 | } 49 | if len(masks) != 0: 50 | masks = masks[0].data.numpy() 51 | if masks.ndim == 3: 52 | masks = masks.transpose(1, 2, 0)[keep] 53 | input_dict.update({"mask": masks}) 54 | 55 | # perform albumentations transforms 56 | transformed = self.albumentation_transforms(**input_dict) 57 | images, boxes, labels = ( 58 | transformed["image"], 59 | transformed["bboxes"], 60 | transformed["labels"], 61 | ) 62 | if "mask" in transformed: 63 | masks = transformed["mask"] 64 | if masks.ndim == 3: 65 | masks = masks.transpose(2, 0, 1) 66 | masks = datapoints.Mask(masks) 67 | else: 68 | masks = None 69 | 70 | # prepare output data format 71 | images = datapoints.Image(images.transpose(2, 0, 1)) 72 | boxes = datapoints.BoundingBox( 73 | torch.as_tensor(boxes).reshape(-1, 4), # in case of empty boxes after transforms 74 | dtype=torch.float, 75 | format=datapoints.BoundingBoxFormat.XYXY, 76 | spatial_size=images.shape[-2:], 77 | ) 78 | output = [images, boxes] 79 | if masks is not None: 80 | output.append(masks) 81 | labels = torch.as_tensor(labels, dtype=torch.long) 82 | output.append(labels) 83 | return tuple(output) 84 | 85 | def __str__(self): 86 | return str(self.albumentation_transforms) 87 | -------------------------------------------------------------------------------- /transforms/convert_coco_polys_to_mask.py: -------------------------------------------------------------------------------- 1 | import PIL.Image 2 | import numpy as np 3 | import torch 4 | from pycocotools import mask as coco_mask 5 | 6 | 7 | class ConvertCocoPolysToMask(object): 8 | def __init__(self, return_masks=False): 9 | self.return_masks = return_masks 10 | 11 | def __call__(self, image_target_tuple): 12 | image, target = image_target_tuple 13 | if isinstance(image, (torch.Tensor, np.ndarray)): 14 | assert len(image.shape) == 3, "only one image is accepted" 15 | assert image.shape[-3] in [1, 3], "channels of images must be 1 or 3" 16 | _, h, w = image.shape 17 | elif isinstance(image, PIL.Image.Image): 18 | w, h = image.size 19 | else: 20 | raise TypeError( 21 | f"Now only torch.Tensor, PIL.Image.Image and np.ndarray " 22 | f"of an image is accepted but got type {type(image)}" 23 | ) 24 | 25 | anno = target["annotations"] 26 
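        # Each entry of the COCO "annotations" list is a dict in COCO format with keys
        # such as "bbox" ([x, y, width, height] in absolute pixels), "category_id",
        # "area" and "iscrowd"; crowd objects are dropped below and the boxes are
        # converted to XYXY and clamped to the image bounds.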
| 27 | anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0] 28 | 29 | boxes = [obj["bbox"] for obj in anno] 30 | # guard against no boxes via resizing 31 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 32 | boxes[:, 2:] += boxes[:, :2] 33 | boxes[:, 0::2].clamp_(min=0, max=w) 34 | boxes[:, 1::2].clamp_(min=0, max=h) 35 | 36 | classes = [obj["category_id"] for obj in anno] 37 | classes = torch.as_tensor(classes, dtype=torch.int64) 38 | 39 | masks = None 40 | if self.return_masks: 41 | segmentations = [obj["segmentation"] for obj in anno] 42 | masks = convert_coco_poly_to_mask(segmentations, h, w) 43 | 44 | keypoints = None 45 | if anno and "keypoints" in anno[0]: 46 | keypoints = [obj["keypoints"] for obj in anno] 47 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 48 | num_keypoints = keypoints.shape[0] 49 | if num_keypoints: 50 | keypoints = keypoints.view(num_keypoints, -1, 3) 51 | 52 | # adapt to result_file 53 | scores = None 54 | if anno and "score" in anno[0]: 55 | scores = [obj["score"] for obj in anno] 56 | scores = torch.as_tensor(scores, dtype=torch.float32) 57 | 58 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 59 | boxes = boxes[keep] 60 | classes = classes[keep] 61 | if masks is not None: 62 | masks = masks[keep] 63 | if keypoints is not None: 64 | keypoints = keypoints[keep] 65 | if scores is not None: 66 | scores = scores[keep] 67 | 68 | target = {"boxes": boxes, "labels": classes, "image_id": target["image_id"]} 69 | if masks is not None: 70 | target["masks"] = masks 71 | if keypoints is not None: 72 | target["keypoints"] = keypoints 73 | if scores is not None: 74 | target["scores"] = scores 75 | 76 | # for conversion to coco api 77 | area = torch.tensor([obj["area"] for obj in anno], dtype=torch.float32) 78 | iscrowd = torch.tensor( 79 | [obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno], dtype=torch.long 80 | ) 81 | target["area"] = area[keep] 82 | target["iscrowd"] = iscrowd[keep] 83 | 84 | return image, target 85 | 86 | 87 | def convert_coco_poly_to_mask(segmentations, height, width): 88 | masks = [] 89 | for polygons in segmentations: 90 | rles = coco_mask.frPyObjects(polygons, height, width) 91 | mask = coco_mask.decode(rles) 92 | if len(mask.shape) < 3: 93 | mask = mask[..., None] 94 | mask = torch.as_tensor(mask, dtype=torch.uint8) 95 | mask = mask.any(dim=2) 96 | masks.append(mask) 97 | if masks: 98 | masks = torch.stack(masks, dim=0) 99 | else: 100 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 101 | return masks 102 | -------------------------------------------------------------------------------- /transforms/crop.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, Dict, List 3 | 4 | import torch 5 | from torchvision.ops import box_iou 6 | 7 | from transforms.v2 import Transform 8 | from transforms.v2 import functional as F 9 | from transforms.v2.utils import query_bounding_box, query_spatial_size 10 | from util import datapoints 11 | 12 | 13 | class RandomSizeCrop(Transform): 14 | def __init__(self, min_size: int, max_size: int): 15 | super().__init__() 16 | self.min_size = min_size 17 | self.max_size = max_size 18 | 19 | def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: 20 | orig_h, orig_w = query_spatial_size(flat_inputs) 21 | crop_h = random.randint(self.min_size, min(orig_h, self.max_size)) 22 | crop_w = random.randint(self.min_size, min(orig_w, self.max_size)) 23 | 
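        # NOTE: crop_h and crop_w are sampled independently (each bounded by the image
        # size and max_size), so the aspect ratio of the crop is not preserved.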
24 | # get crop region 25 | top = torch.randint(0, orig_h - crop_h + 1, size=(1,)).item() 26 | left = torch.randint(0, orig_w - crop_w + 1, size=(1,)).item() 27 | 28 | return {"left": left, "top": top, "height": crop_h, "width": crop_w} 29 | 30 | def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: 31 | return F.crop(inpt, **params) 32 | 33 | 34 | class BoxCenteredRandomSizeCrop(Transform): 35 | def __init__(self, min_size: int, max_size: int, sampler_options=None, trials: int = 40): 36 | super().__init__() 37 | self.min_size = min_size 38 | self.max_size = max_size 39 | self.trials = trials 40 | if sampler_options is None: 41 | sampler_options = [0.1, 0.3, 0.5, 0.7, 0.9, 1.0] 42 | self.options = sampler_options 43 | 44 | def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: 45 | orig_h, orig_w = query_spatial_size(flat_inputs) 46 | bboxes = query_bounding_box(flat_inputs) 47 | best_iou = 0 48 | for _ in range(self.trials): 49 | idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) 50 | min_jaccard_overlap = self.options[idx] 51 | crop_h = random.randint(self.min_size, min(orig_h, self.max_size)) 52 | crop_w = random.randint(self.min_size, min(orig_w, self.max_size)) 53 | 54 | # get crop region 55 | top = torch.randint(0, orig_h - crop_h + 1, size=(1,)).item() 56 | left = torch.randint(0, orig_w - crop_w + 1, size=(1,)).item() 57 | right = left + crop_w 58 | bottom = top + crop_h 59 | 60 | # check for any valid boxes with centers within the crop area 61 | xyxy_bboxes = F.convert_format_bounding_box( 62 | bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY 63 | ) 64 | cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2]) 65 | cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3]) 66 | is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom) 67 | if not is_within_crop_area.any(): 68 | continue 69 | 70 | xyxy_bboxes = xyxy_bboxes[is_within_crop_area] 71 | ious = box_iou( 72 | xyxy_bboxes, 73 | xyxy_bboxes.new_tensor([[left, top, right, bottom]]), 74 | ) 75 | cur_region = dict( 76 | top=top, 77 | left=left, 78 | height=crop_h, 79 | width=crop_w, 80 | is_within_crop_area=is_within_crop_area, 81 | ) 82 | 83 | if ious.max() > best_iou: 84 | best_region = cur_region 85 | 86 | if ious.max() < min_jaccard_overlap: 87 | continue 88 | 89 | return cur_region 90 | 91 | return best_region 92 | 93 | def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: 94 | if len(params) < 1: 95 | return inpt 96 | 97 | output = F.crop( 98 | inpt, 99 | top=params["top"], 100 | left=params["left"], 101 | height=params["height"], 102 | width=params["width"], 103 | ) 104 | 105 | if isinstance(output, datapoints.BoundingBox): 106 | # We "mark" the invalid boxes as degenreate, and they can be 107 | # removed by a later call to SanitizeBoundingBox() 108 | output[~params["is_within_crop_area"]] = 0 109 | return output 110 | -------------------------------------------------------------------------------- /transforms/functional_pil.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from transforms._functional_pil import * # noqa 4 | 5 | warnings.warn( 6 | "The torchvision.transforms.functional_pil module is deprecated " 7 | "in 0.15 and will be **removed in 0.17**. Please don't rely on it. " 8 | "You probably just need to use APIs in " 9 | "torchvision.transforms.functional or in " 10 | "torchvision.transforms.v2.functional." 
11 | ) 12 | -------------------------------------------------------------------------------- /transforms/functional_tensor.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from transforms._functional_tensor import * # noqa 4 | 5 | warnings.warn( 6 | "The torchvision.transforms.functional_tensor module is deprecated " 7 | "in 0.15 and will be **removed in 0.17**. Please don't rely on it. " 8 | "You probably just need to use APIs in " 9 | "torchvision.transforms.functional or in " 10 | "torchvision.transforms.v2.functional." 11 | ) 12 | -------------------------------------------------------------------------------- /transforms/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Callable, List, Tuple, Type, Union, Sequence 4 | 5 | import PIL.Image 6 | from util import datapoints 7 | 8 | from transforms.v2.functional import get_dimensions, get_spatial_size, is_simple_tensor 9 | 10 | 11 | def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: 12 | if not seq: 13 | return "" 14 | if len(seq) == 1: 15 | return f"'{seq[0]}'" 16 | 17 | head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'" 18 | tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'" 19 | 20 | return head + tail 21 | 22 | 23 | def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox: 24 | bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBox)] 25 | if not bounding_boxes: 26 | raise TypeError("No bounding box was found in the sample") 27 | elif len(bounding_boxes) > 1: 28 | raise ValueError("Found multiple bounding boxes in the sample") 29 | return bounding_boxes.pop() 30 | 31 | 32 | def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: 33 | chws = { 34 | tuple(get_dimensions(inpt)) 35 | for inpt in flat_inputs 36 | if isinstance(inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video)) or is_simple_tensor(inpt) 37 | } 38 | if not chws: 39 | raise TypeError("No image or video was found in the sample") 40 | elif len(chws) > 1: 41 | raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}") 42 | c, h, w = chws.pop() 43 | return c, h, w 44 | 45 | 46 | def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]: 47 | sizes = { 48 | tuple(get_spatial_size(inpt)) 49 | for inpt in flat_inputs 50 | if isinstance( 51 | inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video, datapoints.Mask, datapoints.BoundingBox) 52 | ) 53 | or is_simple_tensor(inpt) 54 | } 55 | if not sizes: 56 | raise TypeError("No image, video, mask or bounding box was found in the sample") 57 | elif len(sizes) > 1: 58 | raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}") 59 | h, w = sizes.pop() 60 | return h, w 61 | 62 | 63 | def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool: 64 | for type_or_check in types_or_checks: 65 | if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj): 66 | return True 67 | return False 68 | 69 | 70 | def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: 71 | for inpt in flat_inputs: 72 | if check_type(inpt, types_or_checks): 73 | return True 74 | return False 75 | 76 | 77 | def has_all(flat_inputs: List[Any], *types_or_checks: 
Union[Type, Callable[[Any], bool]]) -> bool: 78 | for type_or_check in types_or_checks: 79 | for inpt in flat_inputs: 80 | if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt): 81 | break 82 | else: 83 | return False 84 | return True 85 | -------------------------------------------------------------------------------- /transforms/v2/__init__.py: -------------------------------------------------------------------------------- 1 | from . import functional, utils # usort: skip 2 | 3 | from ._transform import Transform # usort: skip 4 | 5 | from ._augment import RandomErasing 6 | from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide 7 | from ._color import ( 8 | ColorJitter, 9 | Grayscale, 10 | RandomAdjustSharpness, 11 | RandomAutocontrast, 12 | RandomEqualize, 13 | RandomGrayscale, 14 | RandomInvert, 15 | RandomPhotometricDistort, 16 | RandomPosterize, 17 | RandomSolarize, 18 | ) 19 | from ._container import Compose, RandomApply, RandomChoice, RandomOrder 20 | from ._geometry import ( 21 | CenterCrop, 22 | ElasticTransform, 23 | FiveCrop, 24 | Pad, 25 | RandomAffine, 26 | RandomCrop, 27 | RandomHorizontalFlip, 28 | RandomIoUCrop, 29 | RandomPerspective, 30 | RandomResize, 31 | RandomResizedCrop, 32 | RandomRotation, 33 | RandomShortestSize, 34 | RandomVerticalFlip, 35 | RandomZoomOut, 36 | Resize, 37 | ScaleJitter, 38 | TenCrop, 39 | ) 40 | from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype 41 | from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype 42 | from ._temporal import UniformTemporalSubsample 43 | from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage 44 | 45 | from ._deprecated import ToTensor # usort: skip 46 | -------------------------------------------------------------------------------- /transforms/v2/_augment.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import warnings 4 | from typing import Any, Dict, List, Tuple, Union 5 | 6 | import PIL.Image 7 | import torch 8 | from util import datapoints 9 | import transforms as _transforms 10 | from transforms.v2 import functional as F 11 | 12 | from ._transform import _RandomApplyTransform 13 | from .utils import is_simple_tensor, query_chw 14 | 15 | 16 | class RandomErasing(_RandomApplyTransform): 17 | """[BETA] Randomly select a rectangle region in the input image or video and erase its pixels. 18 | 19 | .. v2betastatus:: RandomErasing transform 20 | 21 | This transform does not support PIL Image. 22 | 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 23 | 24 | Args: 25 | p (float, optional): probability that the random erasing operation will be performed. 26 | scale (tuple of float, optional): range of proportion of erased area against input image. 27 | ratio (tuple of float, optional): range of aspect ratio of erased area. 28 | value (number or tuple of numbers): erasing value. Default is 0. If a single int, it is used to 29 | erase all pixels. If a tuple of length 3, it is used to erase 30 | R, G, B channels respectively. 31 | If a str of 'random', erasing each pixel with random values. 32 | inplace (bool, optional): boolean to make this transform inplace. Default set to False. 33 | 34 | Returns: 35 | Erased input. 
36 | 37 | Example: 38 | >>> from torchvision.transforms import v2 as transforms 39 | >>> 40 | >>> transform = transforms.Compose([ 41 | >>> transforms.RandomHorizontalFlip(), 42 | >>> transforms.PILToTensor(), 43 | >>> transforms.ConvertImageDtype(torch.float), 44 | >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 45 | >>> transforms.RandomErasing(), 46 | >>> ]) 47 | """ 48 | 49 | _v1_transform_cls = _transforms.RandomErasing 50 | 51 | def _extract_params_for_v1_transform(self) -> Dict[str, Any]: 52 | return dict( 53 | super()._extract_params_for_v1_transform(), 54 | value="random" if self.value is None else self.value, 55 | ) 56 | 57 | _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video) 58 | 59 | def __init__( 60 | self, 61 | p: float = 0.5, 62 | scale: Tuple[float, float] = (0.02, 0.33), 63 | ratio: Tuple[float, float] = (0.3, 3.3), 64 | value: float = 0.0, 65 | inplace: bool = False, 66 | ): 67 | super().__init__(p=p) 68 | if not isinstance(value, (numbers.Number, str, tuple, list)): 69 | raise TypeError("Argument value should be either a number or str or a sequence") 70 | if isinstance(value, str) and value != "random": 71 | raise ValueError("If value is str, it should be 'random'") 72 | if not isinstance(scale, (tuple, list)): 73 | raise TypeError("Scale should be a sequence") 74 | if not isinstance(ratio, (tuple, list)): 75 | raise TypeError("Ratio should be a sequence") 76 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 77 | warnings.warn("Scale and ratio should be of kind (min, max)") 78 | if scale[0] < 0 or scale[1] > 1: 79 | raise ValueError("Scale should be between 0 and 1") 80 | self.scale = scale 81 | self.ratio = ratio 82 | if isinstance(value, (int, float)): 83 | self.value = [float(value)] 84 | elif isinstance(value, str): 85 | self.value = None 86 | elif isinstance(value, (list, tuple)): 87 | self.value = [float(v) for v in value] 88 | else: 89 | self.value = value 90 | self.inplace = inplace 91 | 92 | self._log_ratio = torch.log(torch.tensor(self.ratio)) 93 | 94 | def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: 95 | img_c, img_h, img_w = query_chw(flat_inputs) 96 | 97 | if self.value is not None and not (len(self.value) in (1, img_c)): 98 | raise ValueError( 99 | f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" 100 | ) 101 | 102 | area = img_h * img_w 103 | 104 | log_ratio = self._log_ratio 105 | for _ in range(10): 106 | erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() 107 | aspect_ratio = torch.exp( 108 | torch.empty(1).uniform_( 109 | log_ratio[0], # type: ignore[arg-type] 110 | log_ratio[1], # type: ignore[arg-type] 111 | ) 112 | ).item() 113 | 114 | h = int(round(math.sqrt(erase_area * aspect_ratio))) 115 | w = int(round(math.sqrt(erase_area / aspect_ratio))) 116 | if not (h < img_h and w < img_w): 117 | continue 118 | 119 | if self.value is None: 120 | v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() 121 | else: 122 | v = torch.tensor(self.value)[:, None, None] 123 | 124 | i = torch.randint(0, img_h - h + 1, size=(1,)).item() 125 | j = torch.randint(0, img_w - w + 1, size=(1,)).item() 126 | break 127 | else: 128 | i, j, h, w, v = 0, 0, img_h, img_w, None 129 | 130 | return dict(i=i, j=j, h=h, w=w, v=v) 131 | 132 | def _transform( 133 | self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any] 134 | ) -> Union[datapoints._ImageType, datapoints._VideoType]: 
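        # "v" comes from _get_params above and is None when no valid erase region was
        # found within the 10 sampling trials; in that case the input is returned
        # unchanged instead of being erased.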
135 | if params["v"] is not None: 136 | inpt = F.erase(inpt, **params, inplace=self.inplace) 137 | 138 | return inpt 139 | -------------------------------------------------------------------------------- /transforms/v2/_container.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Sequence, Union 2 | 3 | import torch 4 | 5 | from torch import nn 6 | import transforms as _transforms 7 | from transforms.v2 import Transform 8 | 9 | 10 | class Compose(Transform): 11 | """[BETA] Composes several transforms together. 12 | 13 | .. v2betastatus:: Compose transform 14 | 15 | This transform does not support torchscript. 16 | Please, see the note below. 17 | 18 | Args: 19 | transforms (list of ``Transform`` objects): list of transforms to compose. 20 | 21 | Example: 22 | >>> transforms.Compose([ 23 | >>> transforms.CenterCrop(10), 24 | >>> transforms.PILToTensor(), 25 | >>> transforms.ConvertImageDtype(torch.float), 26 | >>> ]) 27 | 28 | .. note:: 29 | In order to script the transformations, please use ``torch.nn.Sequential`` as below. 30 | 31 | >>> transforms = torch.nn.Sequential( 32 | >>> transforms.CenterCrop(10), 33 | >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 34 | >>> ) 35 | >>> scripted_transforms = torch.jit.script(transforms) 36 | 37 | Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require 38 | `lambda` functions or ``PIL.Image``. 39 | 40 | """ 41 | 42 | def __init__(self, transforms: Sequence[Callable]) -> None: 43 | super().__init__() 44 | if not isinstance(transforms, Sequence): 45 | raise TypeError("Argument transforms should be a sequence of callables") 46 | self.transforms = transforms 47 | 48 | def forward(self, *inputs: Any) -> Any: 49 | sample = inputs if len(inputs) > 1 else inputs[0] 50 | for transform in self.transforms: 51 | sample = transform(sample) 52 | return sample 53 | 54 | def extra_repr(self) -> str: 55 | format_string = [] 56 | for t in self.transforms: 57 | format_string.append(f" {t}") 58 | return "\n".join(format_string) 59 | 60 | 61 | class RandomApply(Transform): 62 | """[BETA] Apply randomly a list of transformations with a given probability. 63 | 64 | .. v2betastatus:: RandomApply transform 65 | 66 | .. note:: 67 | In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of 68 | transforms as shown below: 69 | 70 | >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ 71 | >>> transforms.ColorJitter(), 72 | >>> ]), p=0.3) 73 | >>> scripted_transforms = torch.jit.script(transforms) 74 | 75 | Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require 76 | `lambda` functions or ``PIL.Image``. 
77 | 78 | Args: 79 | transforms (sequence or torch.nn.Module): list of transformations 80 | p (float): probability of applying the list of transforms 81 | """ 82 | 83 | _v1_transform_cls = _transforms.RandomApply 84 | 85 | def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None: 86 | super().__init__() 87 | 88 | if not isinstance(transforms, (Sequence, nn.ModuleList)): 89 | raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`") 90 | self.transforms = transforms 91 | 92 | if not (0.0 <= p <= 1.0): 93 | raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].") 94 | self.p = p 95 | 96 | def _extract_params_for_v1_transform(self) -> Dict[str, Any]: 97 | return {"transforms": self.transforms, "p": self.p} 98 | 99 | def forward(self, *inputs: Any) -> Any: 100 | sample = inputs if len(inputs) > 1 else inputs[0] 101 | 102 | if torch.rand(1) >= self.p: 103 | return sample 104 | 105 | for transform in self.transforms: 106 | sample = transform(sample) 107 | return sample 108 | 109 | def extra_repr(self) -> str: 110 | format_string = [] 111 | for t in self.transforms: 112 | format_string.append(f" {t}") 113 | return "\n".join(format_string) 114 | 115 | 116 | class RandomChoice(Transform): 117 | """[BETA] Apply single transformation randomly picked from a list. 118 | 119 | .. v2betastatus:: RandomChoice transform 120 | 121 | This transform does not support torchscript. 122 | 123 | Args: 124 | transforms (sequence or torch.nn.Module): list of transformations 125 | p (list of floats or None, optional): probability of each transform being picked. 126 | If ``p`` doesn't sum to 1, it is automatically normalized. If ``None`` 127 | (default), all transforms have the same probability. 128 | """ 129 | 130 | def __init__( 131 | self, 132 | transforms: Sequence[Callable], 133 | p: Optional[List[float]] = None, 134 | ) -> None: 135 | if not isinstance(transforms, Sequence): 136 | raise TypeError("Argument transforms should be a sequence of callables") 137 | 138 | if p is None: 139 | p = [1] * len(transforms) 140 | elif len(p) != len(transforms): 141 | raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}") 142 | 143 | super().__init__() 144 | 145 | self.transforms = transforms 146 | total = sum(p) 147 | self.p = [prob / total for prob in p] 148 | 149 | def forward(self, *inputs: Any) -> Any: 150 | idx = int(torch.multinomial(torch.tensor(self.p), 1)) 151 | transform = self.transforms[idx] 152 | return transform(*inputs) 153 | 154 | 155 | class RandomOrder(Transform): 156 | """[BETA] Apply a list of transformations in a random order. 157 | 158 | .. v2betastatus:: RandomOrder transform 159 | 160 | This transform does not support torchscript. 
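
    Example:
        >>> # minimal sketch, assuming this module is imported as ``transforms``;
        >>> # the inner transforms are illustrative
        >>> transform = transforms.RandomOrder([
        >>>     transforms.RandomHorizontalFlip(p=0.5),
        >>>     transforms.RandomAutocontrast(),
        >>> ])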
161 | 162 | Args: 163 | transforms (sequence or torch.nn.Module): list of transformations 164 | """ 165 | 166 | def __init__(self, transforms: Sequence[Callable]) -> None: 167 | if not isinstance(transforms, Sequence): 168 | raise TypeError("Argument transforms should be a sequence of callables") 169 | super().__init__() 170 | self.transforms = transforms 171 | 172 | def forward(self, *inputs: Any) -> Any: 173 | sample = inputs if len(inputs) > 1 else inputs[0] 174 | for idx in torch.randperm(len(self.transforms)): 175 | transform = self.transforms[idx] 176 | sample = transform(sample) 177 | return sample 178 | -------------------------------------------------------------------------------- /transforms/v2/_deprecated.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Dict, Union 3 | 4 | import numpy as np 5 | import PIL.Image 6 | import torch 7 | from transforms import functional as _F 8 | 9 | from transforms.v2 import Transform 10 | 11 | 12 | class ToTensor(Transform): 13 | """[BETA] Convert a PIL Image or ndarray to tensor and scale the values accordingly. 14 | 15 | .. v2betastatus:: ToTensor transform 16 | 17 | .. warning:: 18 | :class:`v2.ToTensor` is deprecated and will be removed in a future release. 19 | Please use instead ``transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])``. 20 | 21 | This transform does not support torchscript. 22 | 23 | 24 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 25 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] 26 | if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 27 | or if the numpy.ndarray has dtype = np.uint8 28 | 29 | In the other cases, tensors are returned without scaling. 30 | 31 | .. note:: 32 | Because the input image is scaled to [0.0, 1.0], this transformation should not be used when 33 | transforming target image masks. See the `references`_ for implementing the transforms for image masks. 34 | 35 | .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation 36 | """ 37 | 38 | _transformed_types = (PIL.Image.Image, np.ndarray) 39 | 40 | def __init__(self) -> None: 41 | warnings.warn( 42 | "The transform `ToTensor()` is deprecated and will be removed in a future release. " 43 | "Instead, please use `transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])`." 44 | ) 45 | super().__init__() 46 | 47 | def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor: 48 | return _F.to_tensor(inpt) 49 | -------------------------------------------------------------------------------- /transforms/v2/_meta.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Union 2 | 3 | import torch 4 | 5 | from util import datapoints 6 | import transforms as _transforms 7 | from transforms.v2 import functional as F, Transform 8 | 9 | from .utils import is_simple_tensor 10 | 11 | 12 | class ConvertBoundingBoxFormat(Transform): 13 | """[BETA] Convert bounding box coordinates to the given ``format``, eg from "CXCYWH" to "XYXY". 14 | 15 | .. v2betastatus:: ConvertBoundingBoxFormat transform 16 | 17 | Args: 18 | format (str or datapoints.BoundingBoxFormat): output bounding box format. 19 | Possible values are defined by :class:`~torchvision.datapoints.BoundingBoxFormat` and 20 | string values match the enums, e.g. 
"XYXY" or "XYWH" etc. 21 | """ 22 | 23 | _transformed_types = (datapoints.BoundingBox,) 24 | 25 | def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None: 26 | super().__init__() 27 | if isinstance(format, str): 28 | format = datapoints.BoundingBoxFormat[format] 29 | self.format = format 30 | 31 | def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox: 32 | return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value] 33 | 34 | 35 | class ConvertDtype(Transform): 36 | """[BETA] Convert input image or video to the given ``dtype`` and scale the values accordingly. 37 | 38 | .. v2betastatus:: ConvertDtype transform 39 | 40 | This function does not support PIL Image. 41 | 42 | Args: 43 | dtype (torch.dtype): Desired data type of the output 44 | 45 | .. note:: 46 | 47 | When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. 48 | If converted back and forth, this mismatch has no effect. 49 | 50 | Raises: 51 | RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as 52 | well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to 53 | overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range 54 | of the integer ``dtype``. 55 | """ 56 | 57 | _v1_transform_cls = _transforms.ConvertImageDtype 58 | 59 | _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) 60 | 61 | def __init__(self, dtype: torch.dtype = torch.float32) -> None: 62 | super().__init__() 63 | self.dtype = dtype 64 | 65 | def _transform( 66 | self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any] 67 | ) -> Union[datapoints._TensorImageType, datapoints._TensorVideoType]: 68 | return F.convert_dtype(inpt, self.dtype) 69 | 70 | 71 | # We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is 72 | # prevalent and well understood. Thus, we just alias it without deprecating the old name. 73 | ConvertImageDtype = ConvertDtype 74 | 75 | 76 | class ClampBoundingBox(Transform): 77 | """[BETA] Clamp bounding boxes to their corresponding image dimensions. 78 | 79 | The clamping is done according to the bounding boxes' ``spatial_size`` meta-data. 80 | 81 | .. v2betastatus:: ClampBoundingBox transform 82 | 83 | """ 84 | 85 | _transformed_types = (datapoints.BoundingBox,) 86 | 87 | def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox: 88 | return F.clamp_bounding_box(inpt) # type: ignore[return-value] 89 | -------------------------------------------------------------------------------- /transforms/v2/_temporal.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from util import datapoints 4 | from transforms.v2 import functional as F, Transform 5 | 6 | from transforms.v2.utils import is_simple_tensor 7 | 8 | 9 | class UniformTemporalSubsample(Transform): 10 | """[BETA] Uniformly subsample ``num_samples`` indices from the temporal dimension of the video. 11 | 12 | .. v2betastatus:: UniformTemporalSubsample transform 13 | 14 | Videos are expected to be of shape ``[..., T, C, H, W]`` where ``T`` denotes the temporal dimension. 
15 | 16 | When ``num_samples`` is larger than the size of temporal dimension of the video, it 17 | will sample frames based on nearest neighbor interpolation. 18 | 19 | Args: 20 | num_samples (int): The number of equispaced samples to be selected 21 | """ 22 | 23 | _transformed_types = (is_simple_tensor, datapoints.Video) 24 | 25 | def __init__(self, num_samples: int): 26 | super().__init__() 27 | self.num_samples = num_samples 28 | 29 | def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType: 30 | return F.uniform_temporal_subsample(inpt, self.num_samples) 31 | -------------------------------------------------------------------------------- /transforms/v2/_type_conversion.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Union 2 | 3 | import numpy as np 4 | import PIL.Image 5 | import torch 6 | 7 | from util import datapoints 8 | from transforms.v2 import functional as F, Transform 9 | 10 | from transforms.v2.utils import is_simple_tensor 11 | 12 | 13 | class PILToTensor(Transform): 14 | """[BETA] Convert a PIL Image to a tensor of the same type - this does not scale values. 15 | 16 | .. v2betastatus:: PILToTensor transform 17 | 18 | This transform does not support torchscript. 19 | 20 | Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). 21 | """ 22 | 23 | _transformed_types = (PIL.Image.Image,) 24 | 25 | def _transform(self, inpt: PIL.Image.Image, params: Dict[str, Any]) -> torch.Tensor: 26 | return F.pil_to_tensor(inpt) 27 | 28 | 29 | class ToImageTensor(Transform): 30 | """[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image` 31 | ; this does not scale values. 32 | 33 | .. v2betastatus:: ToImageTensor transform 34 | 35 | This transform does not support torchscript. 36 | """ 37 | 38 | _transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray) 39 | 40 | def _transform( 41 | self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] 42 | ) -> datapoints.Image: 43 | return F.to_image_tensor(inpt) 44 | 45 | 46 | class ToImagePIL(Transform): 47 | """[BETA] Convert a tensor or an ndarray to PIL Image - this does not scale values. 48 | 49 | .. v2betastatus:: ToImagePIL transform 50 | 51 | This transform does not support torchscript. 52 | 53 | Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape 54 | H x W x C to a PIL Image while preserving the value range. 55 | 56 | Args: 57 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 58 | If ``mode`` is ``None`` (default) there are some assumptions made about the input data: 59 | - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. 60 | - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. 61 | - If the input has 2 channels, the ``mode`` is assumed to be ``LA``. 62 | - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``, 63 | ``short``). 64 | 65 | .. 
_PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes 66 | """ 67 | 68 | _transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray) 69 | 70 | def __init__(self, mode: Optional[str] = None) -> None: 71 | super().__init__() 72 | self.mode = mode 73 | 74 | def _transform( 75 | self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] 76 | ) -> PIL.Image.Image: 77 | return F.to_image_pil(inpt, mode=self.mode) 78 | 79 | 80 | # We changed the name to align them with the new naming scheme. Still, `ToPILImage` is 81 | # prevalent and well understood. Thus, we just alias it without deprecating the old name. 82 | ToPILImage = ToImagePIL 83 | -------------------------------------------------------------------------------- /transforms/v2/_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import numbers 3 | from collections import defaultdict 4 | from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union 5 | 6 | from util import datapoints 7 | from util.datapoints import _FillType, _FillTypeJIT 8 | 9 | from transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 10 | 11 | 12 | def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]: 13 | if not isinstance(arg, (float, Sequence)): 14 | raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}") 15 | if isinstance(arg, Sequence) and len(arg) != req_size: 16 | raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}") 17 | if isinstance(arg, Sequence): 18 | for element in arg: 19 | if not isinstance(element, float): 20 | raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}") 21 | 22 | if isinstance(arg, float): 23 | arg = [float(arg), float(arg)] 24 | if isinstance(arg, (list, tuple)) and len(arg) == 1: 25 | arg = [arg[0], arg[0]] 26 | return arg 27 | 28 | 29 | def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None: 30 | if isinstance(fill, dict): 31 | for key, value in fill.items(): 32 | # Check key for type 33 | _check_fill_arg(value) 34 | if isinstance(fill, defaultdict) and callable(fill.default_factory): 35 | default_value = fill.default_factory() 36 | _check_fill_arg(default_value) 37 | else: 38 | if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): 39 | raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.") 40 | 41 | 42 | T = TypeVar("T") 43 | 44 | 45 | def _default_arg(value: T) -> T: 46 | return value 47 | 48 | 49 | def _get_defaultdict(default: T) -> Dict[Any, T]: 50 | # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle. 
51 | # If it were possible, we could replace this with `defaultdict(lambda: default)` 52 | return defaultdict(functools.partial(_default_arg, default)) 53 | 54 | 55 | def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT: 56 | # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 57 | # So, we can't reassign fill to 0 58 | # if fill is None: 59 | # fill = 0 60 | if fill is None: 61 | return fill 62 | 63 | if not isinstance(fill, (int, float)): 64 | fill = [float(v) for v in list(fill)] 65 | return fill # type: ignore[return-value] 66 | 67 | 68 | def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]: 69 | _check_fill_arg(fill) 70 | 71 | if isinstance(fill, dict): 72 | for k, v in fill.items(): 73 | fill[k] = _convert_fill_arg(v) 74 | if isinstance(fill, defaultdict) and callable(fill.default_factory): 75 | default_value = fill.default_factory() 76 | sanitized_default = _convert_fill_arg(default_value) 77 | fill.default_factory = functools.partial(_default_arg, sanitized_default) 78 | return fill # type: ignore[return-value] 79 | 80 | return _get_defaultdict(_convert_fill_arg(fill)) 81 | 82 | 83 | def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: 84 | if not isinstance(padding, (numbers.Number, tuple, list)): 85 | raise TypeError("Got inappropriate padding arg") 86 | 87 | if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]: 88 | raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") 89 | 90 | 91 | # TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums) 92 | # https://github.com/pytorch/vision/issues/6250 93 | def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None: 94 | if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: 95 | raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") 96 | -------------------------------------------------------------------------------- /transforms/v2/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from transforms import InterpolationMode # usort: skip 2 | 3 | from ._utils import is_simple_tensor # usort: skip 4 | 5 | from ._meta import ( 6 | clamp_bounding_box, 7 | convert_format_bounding_box, 8 | convert_dtype_image_tensor, 9 | convert_dtype, 10 | convert_dtype_video, 11 | convert_image_dtype, 12 | get_dimensions_image_tensor, 13 | get_dimensions_image_pil, 14 | get_dimensions, 15 | get_num_frames_video, 16 | get_num_frames, 17 | get_image_num_channels, 18 | get_num_channels_image_tensor, 19 | get_num_channels_image_pil, 20 | get_num_channels_video, 21 | get_num_channels, 22 | get_spatial_size_bounding_box, 23 | get_spatial_size_image_tensor, 24 | get_spatial_size_image_pil, 25 | get_spatial_size_mask, 26 | get_spatial_size_video, 27 | get_spatial_size, 28 | ) # usort: skip 29 | 30 | from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video 31 | from ._color import ( 32 | adjust_brightness, 33 | adjust_brightness_image_pil, 34 | adjust_brightness_image_tensor, 35 | adjust_brightness_video, 36 | adjust_contrast, 37 | adjust_contrast_image_pil, 38 | adjust_contrast_image_tensor, 39 | adjust_contrast_video, 40 | adjust_gamma, 41 | adjust_gamma_image_pil, 42 | adjust_gamma_image_tensor, 43 | adjust_gamma_video, 44 | adjust_hue, 45 | adjust_hue_image_pil, 46 | 
adjust_hue_image_tensor, 47 | adjust_hue_video, 48 | adjust_saturation, 49 | adjust_saturation_image_pil, 50 | adjust_saturation_image_tensor, 51 | adjust_saturation_video, 52 | adjust_sharpness, 53 | adjust_sharpness_image_pil, 54 | adjust_sharpness_image_tensor, 55 | adjust_sharpness_video, 56 | autocontrast, 57 | autocontrast_image_pil, 58 | autocontrast_image_tensor, 59 | autocontrast_video, 60 | equalize, 61 | equalize_image_pil, 62 | equalize_image_tensor, 63 | equalize_video, 64 | invert, 65 | invert_image_pil, 66 | invert_image_tensor, 67 | invert_video, 68 | posterize, 69 | posterize_image_pil, 70 | posterize_image_tensor, 71 | posterize_video, 72 | rgb_to_grayscale, 73 | rgb_to_grayscale_image_pil, 74 | rgb_to_grayscale_image_tensor, 75 | solarize, 76 | solarize_image_pil, 77 | solarize_image_tensor, 78 | solarize_video, 79 | ) 80 | from ._geometry import ( 81 | affine, 82 | affine_bounding_box, 83 | affine_image_pil, 84 | affine_image_tensor, 85 | affine_mask, 86 | affine_video, 87 | center_crop, 88 | center_crop_bounding_box, 89 | center_crop_image_pil, 90 | center_crop_image_tensor, 91 | center_crop_mask, 92 | center_crop_video, 93 | crop, 94 | crop_bounding_box, 95 | crop_image_pil, 96 | crop_image_tensor, 97 | crop_mask, 98 | crop_video, 99 | elastic, 100 | elastic_bounding_box, 101 | elastic_image_pil, 102 | elastic_image_tensor, 103 | elastic_mask, 104 | elastic_transform, 105 | elastic_video, 106 | five_crop, 107 | five_crop_image_pil, 108 | five_crop_image_tensor, 109 | five_crop_video, 110 | hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file 111 | horizontal_flip, 112 | horizontal_flip_bounding_box, 113 | horizontal_flip_image_pil, 114 | horizontal_flip_image_tensor, 115 | horizontal_flip_mask, 116 | horizontal_flip_video, 117 | pad, 118 | pad_bounding_box, 119 | pad_image_pil, 120 | pad_image_tensor, 121 | pad_mask, 122 | pad_video, 123 | perspective, 124 | perspective_bounding_box, 125 | perspective_image_pil, 126 | perspective_image_tensor, 127 | perspective_mask, 128 | perspective_video, 129 | resize, 130 | resize_bounding_box, 131 | resize_image_pil, 132 | resize_image_tensor, 133 | resize_mask, 134 | resize_video, 135 | resized_crop, 136 | resized_crop_bounding_box, 137 | resized_crop_image_pil, 138 | resized_crop_image_tensor, 139 | resized_crop_mask, 140 | resized_crop_video, 141 | rotate, 142 | rotate_bounding_box, 143 | rotate_image_pil, 144 | rotate_image_tensor, 145 | rotate_mask, 146 | rotate_video, 147 | ten_crop, 148 | ten_crop_image_pil, 149 | ten_crop_image_tensor, 150 | ten_crop_video, 151 | vertical_flip, 152 | vertical_flip_bounding_box, 153 | vertical_flip_image_pil, 154 | vertical_flip_image_tensor, 155 | vertical_flip_mask, 156 | vertical_flip_video, 157 | vflip, 158 | ) 159 | from ._misc import ( 160 | gaussian_blur, 161 | gaussian_blur_image_pil, 162 | gaussian_blur_image_tensor, 163 | gaussian_blur_video, 164 | normalize, 165 | normalize_image_tensor, 166 | normalize_video, 167 | ) 168 | from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video 169 | from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image 170 | 171 | from ._deprecated import get_image_size, to_grayscale, to_tensor # usort: skip 172 | -------------------------------------------------------------------------------- /transforms/v2/functional/_augment.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import PIL.Image 4 | 
5 | import torch 6 | from util import datapoints 7 | from transforms.functional import pil_to_tensor, to_pil_image 8 | 9 | from ._utils import is_simple_tensor 10 | 11 | 12 | def erase_image_tensor( 13 | image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False 14 | ) -> torch.Tensor: 15 | if not inplace: 16 | image = image.clone() 17 | 18 | image[..., i : i + h, j : j + w] = v 19 | return image 20 | 21 | 22 | @torch.jit.unused 23 | def erase_image_pil( 24 | image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False 25 | ) -> PIL.Image.Image: 26 | t_img = pil_to_tensor(image) 27 | output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace) 28 | return to_pil_image(output, mode=image.mode) 29 | 30 | 31 | def erase_video( 32 | video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False 33 | ) -> torch.Tensor: 34 | return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) 35 | 36 | 37 | def erase( 38 | inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], 39 | i: int, 40 | j: int, 41 | h: int, 42 | w: int, 43 | v: torch.Tensor, 44 | inplace: bool = False, 45 | ) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]: 46 | if torch.jit.is_scripting() or is_simple_tensor(inpt): 47 | return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) 48 | elif isinstance(inpt, datapoints.Image): 49 | output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace) 50 | return datapoints.Image.wrap_like(inpt, output) 51 | elif isinstance(inpt, datapoints.Video): 52 | output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace) 53 | return datapoints.Video.wrap_like(inpt, output) 54 | elif isinstance(inpt, PIL.Image.Image): 55 | return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) 56 | else: 57 | raise TypeError( 58 | f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " 59 | f"but got {type(inpt)} instead." 60 | ) 61 | -------------------------------------------------------------------------------- /transforms/v2/functional/_deprecated.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, List, Union 3 | 4 | import PIL.Image 5 | import torch 6 | 7 | from util import datapoints 8 | from transforms import functional as _F 9 | 10 | 11 | @torch.jit.unused 12 | def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: 13 | call = ", num_output_channels=3" if num_output_channels == 3 else "" 14 | replacement = "convert_color_space(..., color_space=datapoints.ColorSpace.GRAY)" 15 | if num_output_channels == 3: 16 | replacement = f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB)" 17 | warnings.warn( 18 | f"The function `to_grayscale(...{call})` is deprecated and will be removed in a future release. " 19 | f"Instead, please use `{replacement}`.", 20 | ) 21 | 22 | return _F.to_grayscale(inpt, num_output_channels=num_output_channels) 23 | 24 | 25 | @torch.jit.unused 26 | def to_tensor(inpt: Any) -> torch.Tensor: 27 | warnings.warn( 28 | "The function `to_tensor(...)` is deprecated and will be removed in a future release. " 29 | "Instead, please use `to_image_tensor(...)` followed by `convert_image_dtype(...)`."
30 | ) 31 | return _F.to_tensor(inpt) 32 | 33 | 34 | def get_image_size(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]: 35 | warnings.warn( 36 | "The function `get_image_size(...)` is deprecated and will be removed in a future release. " 37 | "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`." 38 | ) 39 | return _F.get_image_size(inpt) 40 | -------------------------------------------------------------------------------- /transforms/v2/functional/_temporal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from util import datapoints 4 | 5 | from ._utils import is_simple_tensor 6 | 7 | 8 | def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor: 9 | # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 10 | t_max = video.shape[-4] - 1 11 | indices = torch.linspace(0, t_max, num_samples, device=video.device).long() 12 | return torch.index_select(video, -4, indices) 13 | 14 | 15 | def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT: 16 | if torch.jit.is_scripting() or is_simple_tensor(inpt): 17 | return uniform_temporal_subsample_video(inpt, num_samples) 18 | elif isinstance(inpt, datapoints.Video): 19 | output = uniform_temporal_subsample_video(inpt.as_subclass(torch.Tensor), num_samples) 20 | return datapoints.Video.wrap_like(inpt, output) 21 | else: 22 | raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.") 23 | -------------------------------------------------------------------------------- /transforms/v2/functional/_type_conversion.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import PIL.Image 5 | import torch 6 | from util import datapoints 7 | from transforms import functional as _F 8 | 9 | 10 | @torch.jit.unused 11 | def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image: 12 | if isinstance(inpt, np.ndarray): 13 | output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous() 14 | elif isinstance(inpt, PIL.Image.Image): 15 | output = pil_to_tensor(inpt) 16 | elif isinstance(inpt, torch.Tensor): 17 | output = inpt 18 | else: 19 | raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.") 20 | return datapoints.Image(output) 21 | 22 | 23 | to_image_pil = _F.to_pil_image 24 | pil_to_tensor = _F.pil_to_tensor 25 | 26 | # We changed the names to align them with the new naming scheme. Still, `to_pil_image` is 27 | # prevalent and well understood. Thus, we just alias it without deprecating the old name. 
28 | to_pil_image = to_image_pil 29 | -------------------------------------------------------------------------------- /transforms/v2/functional/_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from util.datapoints import Datapoint 5 | 6 | 7 | def is_simple_tensor(inpt: Any) -> bool: 8 | return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint) 9 | -------------------------------------------------------------------------------- /transforms/v2/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Callable, List, Tuple, Type, Union, Sequence 4 | 5 | import PIL.Image 6 | from util import datapoints 7 | 8 | from transforms.v2.functional import get_dimensions, get_spatial_size, is_simple_tensor 9 | 10 | 11 | def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: 12 | if not seq: 13 | return "" 14 | if len(seq) == 1: 15 | return f"'{seq[0]}'" 16 | 17 | head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'" 18 | tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'" 19 | 20 | return head + tail 21 | 22 | 23 | def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox: 24 | bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBox)] 25 | if not bounding_boxes: 26 | raise TypeError("No bounding box was found in the sample") 27 | elif len(bounding_boxes) > 1: 28 | raise ValueError("Found multiple bounding boxes in the sample") 29 | return bounding_boxes.pop() 30 | 31 | 32 | def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: 33 | chws = { 34 | tuple(get_dimensions(inpt)) 35 | for inpt in flat_inputs 36 | if isinstance(inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video)) or is_simple_tensor(inpt) 37 | } 38 | if not chws: 39 | raise TypeError("No image or video was found in the sample") 40 | elif len(chws) > 1: 41 | raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}") 42 | c, h, w = chws.pop() 43 | return c, h, w 44 | 45 | 46 | def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]: 47 | sizes = { 48 | tuple(get_spatial_size(inpt)) 49 | for inpt in flat_inputs 50 | if isinstance( 51 | inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video, datapoints.Mask, datapoints.BoundingBox) 52 | ) 53 | or is_simple_tensor(inpt) 54 | } 55 | if not sizes: 56 | raise TypeError("No image, video, mask or bounding box was found in the sample") 57 | elif len(sizes) > 1: 58 | raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}") 59 | h, w = sizes.pop() 60 | return h, w 61 | 62 | 63 | def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool: 64 | for type_or_check in types_or_checks: 65 | if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj): 66 | return True 67 | return False 68 | 69 | 70 | def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: 71 | for inpt in flat_inputs: 72 | if check_type(inpt, types_or_checks): 73 | return True 74 | return False 75 | 76 | 77 | def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: 78 | for type_or_check in types_or_checks: 79 | for inpt in flat_inputs: 80 | if isinstance(inpt, 
type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt): 81 | break 82 | else: 83 | return False 84 | return True 85 | -------------------------------------------------------------------------------- /util/coco_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.utils.data 5 | import torchvision 6 | from pycocotools import mask as coco_mask 7 | from pycocotools.coco import COCO 8 | from tqdm import tqdm 9 | 10 | 11 | class FilterAndRemapCocoCategories(object): 12 | def __init__(self, categories, remap=True): 13 | self.categories = categories 14 | self.remap = remap 15 | 16 | def __call__(self, image, target): 17 | anno = target["annotations"] 18 | anno = [obj for obj in anno if obj["category_id"] in self.categories] 19 | if not self.remap: 20 | target["annotations"] = anno 21 | return image, target 22 | anno = copy.deepcopy(anno) 23 | for obj in anno: 24 | obj["category_id"] = self.categories.index(obj["category_id"]) 25 | target["annotations"] = anno 26 | return image, target 27 | 28 | 29 | def convert_to_coco_api(ds): 30 | coco_ds = COCO() 31 | ann_id = 0 32 | dataset = {"images": [], "categories": [], "annotations": []} 33 | categories = set() 34 | for img_idx in tqdm(range(len(ds))): 35 | # find better way to get target 36 | # targets = ds.get_annotations(img_idx) 37 | img, targets = ds[img_idx] 38 | image_id = targets["image_id"].item() 39 | img_dict = {} 40 | img_dict["id"] = image_id 41 | img_dict["height"] = img.shape[-2] 42 | img_dict["width"] = img.shape[-1] 43 | dataset["images"].append(img_dict) 44 | bboxes = targets["boxes"] 45 | bboxes[:, 2:] -= bboxes[:, :2] 46 | bboxes = bboxes.tolist() 47 | labels = targets["labels"].tolist() 48 | areas = targets["area"].tolist() 49 | iscrowd = targets["iscrowd"].tolist() 50 | if "masks" in targets: 51 | masks = targets["masks"] 52 | # make masks Fortran contiguous for coco_mask 53 | masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1) 54 | if "keypoints" in targets: 55 | keypoints = targets["keypoints"] 56 | keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() 57 | num_objs = len(bboxes) 58 | for i in range(num_objs): 59 | ann = {} 60 | ann["image_id"] = image_id 61 | ann["bbox"] = bboxes[i] 62 | ann["category_id"] = labels[i] 63 | categories.add(labels[i]) 64 | ann["area"] = areas[i] 65 | ann["iscrowd"] = iscrowd[i] 66 | ann["id"] = ann_id 67 | if "masks" in targets: 68 | ann["segmentation"] = coco_mask.encode(masks[i].numpy()) 69 | if "keypoints" in targets: 70 | ann["keypoints"] = keypoints[i] 71 | ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3]) 72 | dataset["annotations"].append(ann) 73 | ann_id += 1 74 | dataset["categories"] = [{"id": i} for i in sorted(categories)] 75 | coco_ds.dataset = dataset 76 | coco_ds.createIndex() 77 | return coco_ds 78 | 79 | 80 | def get_coco_api_from_dataset(dataset): 81 | for _ in range(10): 82 | if isinstance(dataset, torchvision.datasets.CocoDetection): 83 | break 84 | if isinstance(dataset, torch.utils.data.Subset): 85 | dataset = dataset.dataset 86 | if isinstance(dataset, torchvision.datasets.CocoDetection): 87 | return dataset.coco 88 | return convert_to_coco_api(dataset) 89 | -------------------------------------------------------------------------------- /util/collate_fn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transforms import InterpolationMode 4 | from 
transforms.simple_copy_paste import SimpleCopyPaste 5 | from util.misc import to_device 6 | 7 | 8 | def collate_fn(batch): 9 | return tuple(zip(*batch)) 10 | 11 | 12 | def copypaste_collate_fn(batch): 13 | copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR) 14 | return copypaste(*collate_fn(batch)) 15 | 16 | 17 | class DataPrefetcher: 18 | def __init__(self, loader, device): 19 | self.loader = iter(loader) 20 | self.device = device 21 | if torch.cuda.is_available(): 22 | self.stream = torch.cuda.Stream() 23 | self.preload() 24 | 25 | def preload(self): 26 | try: 27 | self.next_batch = next(self.loader) 28 | except StopIteration: 29 | self.next_batch = None 30 | return 31 | 32 | if torch.cuda.is_available(): 33 | with torch.cuda.stream(self.stream): 34 | self.next_batch = to_device(self.next_batch, self.device) 35 | else: 36 | self.next_batch = to_device(self.next_batch, self.device) 37 | 38 | # With Amp, it isn't necessary to manually convert data to half. 39 | # if args.fp16: 40 | # self.next_input = self.next_input.half() 41 | # else: 42 | # self.next_input = self.next_input.float() 43 | 44 | def next(self): 45 | if torch.cuda.is_available(): 46 | torch.cuda.current_stream().wait_stream(self.stream) 47 | batch = self.next_batch 48 | self.preload() 49 | return batch 50 | -------------------------------------------------------------------------------- /util/collect_env.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import re 4 | import subprocess 5 | import sys 6 | from collections import defaultdict 7 | 8 | import numpy as np 9 | import PIL 10 | import torch 11 | import torchvision 12 | from tabulate import tabulate 13 | 14 | 15 | def collect_torch_env(): 16 | try: 17 | import torch.__config__ 18 | 19 | return torch.__config__.show() 20 | except ImportError: 21 | # compatible with older versions of pytorch 22 | from torch.utils.collect_env import get_pretty_env_info 23 | 24 | return get_pretty_env_info() 25 | 26 | 27 | def detect_compute_compatibility(CUDA_HOME, so_file): 28 | try: 29 | cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump") 30 | if os.path.isfile(cuobjdump): 31 | output = subprocess.check_output( 32 | "'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True 33 | ) 34 | output = output.decode("utf-8").strip().split("\n") 35 | arch = [] 36 | for line in output: 37 | line = re.findall(r"\.sm_([0-9]*)\.", line)[0] 38 | arch.append(".".join(line)) 39 | arch = sorted(set(arch)) 40 | return ", ".join(arch) 41 | else: 42 | return so_file + "; cannot find cuobjdump" 43 | except Exception: 44 | # unhandled failure 45 | return so_file 46 | 47 | 48 | def collect_env_info(): 49 | has_gpu = torch.cuda.is_available() # true for both CUDA & ROCM 50 | torch_version = torch.__version__ 51 | 52 | # NOTE that CUDA_HOME/ROCM_HOME could be None even when CUDA runtime libs are functional 53 | from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME 54 | 55 | has_rocm = False 56 | if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None): 57 | has_rocm = True 58 | has_cuda = has_gpu and (not has_rocm) 59 | 60 | data = [] 61 | data.append(("sys.platform", sys.platform)) # check-template.yml depends on it 62 | data.append(("Python", sys.version.replace("\n", ""))) 63 | data.append(("numpy", np.__version__)) 64 | 65 | data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__))) 66 | data.append(("PyTorch debug build", 
torch.version.debug)) 67 | try: 68 | data.append(("torch._C._GLIBCXX_USE_CXX11_ABI", torch._C._GLIBCXX_USE_CXX11_ABI)) 69 | except Exception: 70 | pass 71 | 72 | if not has_gpu: 73 | has_gpu_text = "No: torch.cuda.is_available() == False" 74 | else: 75 | has_gpu_text = "Yes" 76 | data.append(("GPU available", has_gpu_text)) 77 | if has_gpu: 78 | devices = defaultdict(list) 79 | for k in range(torch.cuda.device_count()): 80 | cap = ".".join((str(x) for x in torch.cuda.get_device_capability(k))) 81 | name = torch.cuda.get_device_name(k) + f" (arch={cap})" 82 | devices[name].append(str(k)) 83 | for name, devids in devices.items(): 84 | data.append(("GPU " + ",".join(devids), name)) 85 | 86 | if has_rocm: 87 | msg = " - invalid!" if not (ROCM_HOME and os.path.isdir(ROCM_HOME)) else "" 88 | data.append(("ROCM_HOME", str(ROCM_HOME) + msg)) 89 | else: 90 | try: 91 | from torch.utils.collect_env import get_nvidia_driver_version 92 | from torch.utils.collect_env import run as _run 93 | 94 | data.append(("Driver version", get_nvidia_driver_version(_run))) 95 | except Exception: 96 | pass 97 | msg = " - invalid!" if not (CUDA_HOME and os.path.isdir(CUDA_HOME)) else "" 98 | data.append(("CUDA_HOME", str(CUDA_HOME) + msg)) 99 | 100 | cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) 101 | if cuda_arch_list: 102 | data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list)) 103 | data.append(("Pillow", PIL.__version__)) 104 | 105 | try: 106 | data.append(( 107 | "torchvision", 108 | str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__), 109 | )) 110 | if has_cuda: 111 | try: 112 | torchvision_C = importlib.util.find_spec("torchvision._C").origin 113 | msg = detect_compute_compatibility(CUDA_HOME, torchvision_C) 114 | data.append(("torchvision arch flags", msg)) 115 | except (ImportError, AttributeError): 116 | data.append(("torchvision._C", "Not found")) 117 | except AttributeError: 118 | data.append(("torchvision", "unknown")) 119 | 120 | try: 121 | import fvcore 122 | 123 | data.append(("fvcore", fvcore.__version__)) 124 | except (ImportError, AttributeError): 125 | pass 126 | 127 | try: 128 | import iopath 129 | 130 | data.append(("iopath", iopath.__version__)) 131 | except (ImportError, AttributeError): 132 | pass 133 | 134 | try: 135 | import cv2 136 | 137 | data.append(("cv2", cv2.__version__)) 138 | except (ImportError, AttributeError): 139 | data.append(("cv2", "Not found")) 140 | env_str = tabulate(data) + "\n" 141 | env_str += collect_torch_env() 142 | return env_str 143 | -------------------------------------------------------------------------------- /util/logger.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import functools 3 | import logging 4 | import os 5 | import re 6 | import sys 7 | 8 | from accelerate.logging import get_logger 9 | from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler 10 | from iopath.common.file_io import PathManager as PathManagerBase 11 | from termcolor import colored 12 | 13 | PathManager = PathManagerBase() 14 | PathManager.register_handler(HTTPURLHandler()) 15 | PathManager.register_handler(OneDrivePathHandler()) 16 | 17 | 18 | def create_logger(output_dir=None, dist_rank=0): 19 | logger = logging.getLogger() 20 | logger.setLevel(logging.INFO) 21 | logger.propagate = False 22 | 23 | fmt = "[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s" 24 | color_fmt = colored("[%(asctime)s %(name)s]", "green") 25 | color_fmt += 
colored("(%(filename)s %(lineno)d)", "yellow") 26 | color_fmt += ": %(levelname)s %(message)s" 27 | 28 | # create console handlers for master process 29 | if dist_rank == 0: 30 | console_handler = logging.StreamHandler(sys.stdout) 31 | console_handler.setLevel(logging.DEBUG) 32 | console_handler.setFormatter(logging.Formatter(fmt=color_fmt, datefmt="%Y-%m-%d %H:%M:%S")) 33 | logger.addHandler(console_handler) 34 | 35 | # create file handlers 36 | if output_dir: 37 | file_handler = logging.FileHandler(os.path.join(output_dir, "training.log"), mode="a") 38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt="%Y-%m-%d %H:%M:%S")) 39 | logger.addHandler(file_handler) 40 | 41 | return logger 42 | 43 | 44 | class _ColorfulFormatter(logging.Formatter): 45 | def __init__(self, *args, **kwargs): 46 | self._root_name = kwargs.pop("root_name") + "." 47 | self._abbrev_name = kwargs.pop("abbrev_name", "") 48 | if len(self._abbrev_name): 49 | self._abbrev_name = self._abbrev_name + "." 50 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 51 | 52 | def formatMessage(self, record): 53 | record.name = record.name.replace(self._root_name, self._abbrev_name) 54 | log = super(_ColorfulFormatter, self).formatMessage(record) 55 | if record.levelno == logging.WARNING: 56 | prefix = colored("WARNING", "red", attrs=["blink"]) 57 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 58 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 59 | else: 60 | return log 61 | return prefix + " " + log 62 | 63 | 64 | class ColorFilter(logging.Filter): 65 | def filter(self, record): 66 | message = record.getMessage() 67 | # matching colored patterns 68 | pattern = re.compile(r'\x1b\[[0-9;]*m') 69 | if pattern.search(message): 70 | record.msg = pattern.sub('', message) 71 | return True 72 | 73 | 74 | @functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers 75 | def setup_logger( 76 | output=None, 77 | distributed_rank=0, 78 | *, 79 | color=True, 80 | name="detection", 81 | abbrev_name=None, 82 | enable_propagation: bool = False, 83 | configure_stdout: bool = True, 84 | ): 85 | """Initialize the detection logger and set its verbosity level to "DEBUG" 86 | 87 | :param output: a file name or a directory to save log. If None, will not save log file. 88 | If ends with ".txt" or ".log", assumed to be a file name, defaults to None 89 | :param distributed_rank: rank number id in distributed training, defaults to 0 90 | :param color: whether to show colored logging information, defaults to True 91 | :param name: the root module name of this logger, defaults to "detection" 92 | :param abbrev_name: an abbreviation of the module, to avoid long names in logs. 93 | Set to "" to not log the root module in logs. 
By default, will abbreviate "detection" 94 | to "det" and leave other modules unchanged, defaults to None 95 | :param enable_propagation: whether to propagate logs to the parent logger, defaults to False 96 | :param configure_stdout: whether to configure logging to stdout, defaults to True 97 | """ 98 | logger_adapter = get_logger(name, "DEBUG") 99 | logger = logger_adapter.logger 100 | logger.propagate = enable_propagation 101 | 102 | if abbrev_name is None: 103 | abbrev_name = name.replace(os.path.basename(os.getcwd()), "det") 104 | 105 | plain_formatter = logging.Formatter( 106 | "[%(asctime)s %(name)s] %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 107 | ) 108 | # stdout logging: master only 109 | if configure_stdout and distributed_rank == 0: 110 | ch = logging.StreamHandler(stream=sys.stdout) 111 | ch.setLevel(logging.DEBUG) 112 | if color: 113 | formatter = _ColorfulFormatter( 114 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 115 | datefmt="%Y-%m-%d %H:%M:%S", 116 | root_name=name, 117 | abbrev_name=str(abbrev_name), 118 | ) 119 | else: 120 | formatter = plain_formatter 121 | ch.setFormatter(formatter) 122 | logger.addHandler(ch) 123 | 124 | # file logging: all workers 125 | if output is not None: 126 | if output.endswith(".txt") or output.endswith(".log"): 127 | filename = output 128 | else: 129 | filename = os.path.join(output, "log.log") 130 | if distributed_rank > 0: 131 | filename = filename.replace(".", "_rank{}".format(distributed_rank) + ".") 132 | os.makedirs(os.path.dirname(filename), exist_ok=True) 133 | 134 | fh = logging.StreamHandler(_cached_log_stream(filename)) 135 | fh.addFilter(ColorFilter()) 136 | fh.setLevel(logging.DEBUG) 137 | fh.setFormatter(plain_formatter) 138 | logger.addHandler(fh) 139 | 140 | return logger_adapter 141 | 142 | 143 | # cache the opened file object, so that different calls to `setup_logger` 144 | # with the same file name can safely write to the same file. 145 | @functools.lru_cache(maxsize=None) 146 | def _cached_log_stream(filename): 147 | # use a larger write buffer if writing to cloud storage (see helper below) 148 | io = PathManager.open(filename, "a", buffering=_get_log_stream_buffer_size(filename)) 149 | atexit.register(io.close) 150 | return io 151 | 152 | 153 | def _get_log_stream_buffer_size(filename: str) -> int: 154 | if "://" not in filename: 155 | # Local file, no extra caching is necessary 156 | return -1 157 | # Remote file requires a larger cache to avoid many small writes. 158 | return 1024 * 1024 159 | -------------------------------------------------------------------------------- /util/tune_mode_convbn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def efficient_conv_bn_eval_forward( 6 | bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor 7 | ): 8 | """ 9 | Implementation based on https://arxiv.org/abs/2305.11624 10 | "Tune-Mode ConvBN Blocks For Efficient Transfer Learning" 11 | It leverages the associative law between convolution and affine transform, 12 | i.e., normalize (weight conv feature) = (normalize weight) conv feature. 13 | It works for Eval mode of ConvBN blocks during validation, and can be used 14 | for training as well. It reduces memory and computation cost. 15 | Args: 16 | bn (_BatchNorm): a BatchNorm module. 17 | conv (nn._ConvNd): a conv module 18 | x (torch.Tensor): Input feature map.
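Example (a minimal eval-mode sketch; the layer sizes and input shape are illustrative assumptions, not values used elsewhere in this file):
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()
    bn = nn.BatchNorm2d(8).eval()
    x = torch.randn(1, 3, 32, 32)
    out = efficient_conv_bn_eval_forward(bn, conv, x)
    # `out` matches bn(conv(x)) up to floating point error, with the BN
    # statistics folded into the conv weight and bias on the fly.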
19 | """ 20 | # These lines of code are designed to deal with various cases 21 | # like bn without affine transform, and conv without bias 22 | weight_on_the_fly = conv.weight 23 | if conv.bias is not None: 24 | bias_on_the_fly = conv.bias 25 | else: 26 | bias_on_the_fly = torch.zeros_like(bn.running_var) 27 | 28 | if bn.weight is not None: 29 | bn_weight = bn.weight 30 | else: 31 | bn_weight = torch.ones_like(bn.running_var) 32 | 33 | if bn.bias is not None: 34 | bn_bias = bn.bias 35 | else: 36 | bn_bias = torch.zeros_like(bn.running_var) 37 | 38 | # shape of [C_out, 1, 1, 1] in Conv2d 39 | weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape([-1] + [1] * 40 | (len(conv.weight.shape) - 1)) 41 | # shape of [C_out, 1, 1, 1] in Conv2d 42 | coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff 43 | 44 | # shape of [C_out, C_in, k, k] in Conv2d 45 | weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly 46 | # shape of [C_out] in Conv2d 47 | bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() *\ 48 | (bias_on_the_fly - bn.running_mean) 49 | 50 | return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly) 51 | 52 | 53 | def efficient_conv_bn_eval_control( 54 | bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor 55 | ): 56 | """This function controls whether to use `efficient_conv_bn_eval_forward`. 57 | 58 | If the following `bn` is in `eval` mode, then we turn on the special 59 | `efficient_conv_bn_eval_forward`. 60 | """ 61 | if not bn.training: 62 | # bn in eval mode 63 | output = efficient_conv_bn_eval_forward(bn, conv, x) 64 | return output 65 | else: 66 | conv_out = conv._conv_forward(x, conv.weight, conv.bias) 67 | return bn(conv_out) 68 | 69 | 70 | def efficient_conv_bn_eval_graph_transform(fx_model): 71 | """Find consecutive conv+bn calls in the graph, inplace modify the graph 72 | with the fused operation.""" 73 | modules = dict(fx_model.named_modules()) 74 | 75 | patterns = [(torch.nn.modules.conv._ConvNd, torch.nn.modules.batchnorm._BatchNorm)] 76 | 77 | pairs = [] 78 | # Iterate through nodes in the graph to find ConvBN blocks 79 | for node in fx_model.graph.nodes: 80 | # If our current node isn't calling a Module then we can ignore it. 81 | if node.op != 'call_module': 82 | continue 83 | target_module = modules[node.target] 84 | found_pair = False 85 | for conv_class, bn_class in patterns: 86 | if isinstance(target_module, bn_class): 87 | source_module = modules[node.args[0].target] 88 | if isinstance(source_module, conv_class): 89 | found_pair = True 90 | # Not a conv-BN pattern or output of conv is used by other nodes 91 | if not found_pair or len(node.args[0].users) > 1: 92 | continue 93 | 94 | # Find a pair of conv and bn computation nodes to optimize 95 | conv_node = node.args[0] 96 | bn_node = node 97 | pairs.append([conv_node, bn_node]) 98 | 99 | for conv_node, bn_node in pairs: 100 | # set insertion point 101 | fx_model.graph.inserting_before(conv_node) 102 | # create `get_attr` node to access modules 103 | # note that we directly call `create_node` to fill the `name` 104 | # argument. `fx_model.graph.get_attr` and 105 | # `fx_model.graph.call_function` does not allow the `name` argument. 
106 | conv_get_node = fx_model.graph.create_node( 107 | op='get_attr', target=conv_node.target, name='get_conv' 108 | ) 109 | bn_get_node = fx_model.graph.create_node( 110 | op='get_attr', target=bn_node.target, name='get_bn' 111 | ) 112 | # prepare args for the fused function 113 | args = (bn_get_node, conv_get_node, conv_node.args[0]) 114 | # create a new node 115 | new_node = fx_model.graph.create_node( 116 | op='call_function', 117 | target=efficient_conv_bn_eval_control, 118 | args=args, 119 | name='efficient_conv_bn_eval' 120 | ) 121 | # this node replaces the original conv + bn, and therefore 122 | # should replace the uses of bn_node 123 | bn_node.replace_all_uses_with(new_node) 124 | # take care of the deletion order: 125 | # delete bn_node first, and then conv_node 126 | fx_model.graph.erase_node(bn_node) 127 | fx_model.graph.erase_node(conv_node) 128 | 129 | # regenerate the code 130 | fx_model.graph.lint() 131 | fx_model.recompile() 132 | 133 | 134 | def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module): 135 | import torch.fx as fx 136 | 137 | # currently we use `fx.symbolic_trace` to trace models. 138 | # in the future, we might turn to pytorch 2.0 compile infrastructure to 139 | # get the `fx.GraphModule` IR. Nonetheless, the graph transform function 140 | # can remain unchanged. We just need to change the way 141 | # we get `fx.GraphModule`. 142 | fx_model: fx.GraphModule = fx.symbolic_trace(model) 143 | efficient_conv_bn_eval_graph_transform(fx_model) 144 | model.forward = fx_model.forward 145 | --------------------------------------------------------------------------------
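A brief usage sketch for the Conv+BN folding utility above (the toy model below is an illustrative assumption; the helper requires a model that `torch.fx.symbolic_trace` can trace, and the import path assumes the repository root is on `PYTHONPATH`):

```python
import torch
import torch.nn as nn

from util.tune_mode_convbn import turn_on_efficient_conv_bn_eval_for_single_model

# A small traceable model containing a Conv -> BN pair (illustrative only).
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
).eval()

# Rewrite the traced graph so that eval-mode Conv+BN pairs are routed through
# `efficient_conv_bn_eval_control`, folding BN statistics into the conv on the fly.
turn_on_efficient_conv_bn_eval_for_single_model(model)

with torch.no_grad():
    out = model(torch.randn(1, 3, 64, 64))  # numerically matches the unfused model in eval mode
```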