├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── configs ├── mtlora │ └── tiny_448 │ │ ├── mtlora_plus_tiny_448_r16_scale4.yaml │ │ ├── mtlora_plus_tiny_448_r16_scale4_pertask.yaml │ │ ├── mtlora_plus_tiny_448_r32_scale4_pertask.yaml │ │ ├── mtlora_plus_tiny_448_r4_scale4.yaml │ │ ├── mtlora_plus_tiny_448_r64_scale4_pertask.yaml │ │ ├── mtlora_plus_tiny_448_r8_scale4.yaml │ │ ├── mtlora_tiny_448_r16_scale4_pertask.yaml │ │ ├── mtlora_tiny_448_r32_scale4_pertask.yaml │ │ └── mtlora_tiny_448_r64_scale4_pertask.yaml └── swin │ └── swin_tiny_patch4_window7_448.yaml ├── data ├── __init__.py ├── base.py ├── build.py ├── cached_image_folder.py ├── custom_transforms.py ├── data_simmim_ft.py ├── data_simmim_pt.py ├── db_info │ ├── context_classes.json │ ├── nyu_classes.json │ ├── pascal_map.npy │ └── pascal_part.json ├── helpers.py ├── imagenet22k_dataset.py ├── map22kto1k.txt ├── mtl_ds.py ├── samplers.py └── zipreader.py ├── evaluation ├── eval_depth.py ├── eval_edge.py ├── eval_human_parts.py ├── eval_normals.py ├── eval_normals_v1.py ├── eval_normals_v2.py ├── eval_sal.py ├── eval_sal_beta.py ├── eval_sal_no_beta.py ├── eval_semseg.py ├── evaluate_utils.py └── jaccard.py ├── kernels └── window_process │ ├── setup.py │ ├── swin_window_process.cpp │ ├── swin_window_process_kernel.cu │ ├── unit_test.py │ └── window_process.py ├── logger.py ├── lr_scheduler.py ├── main.py ├── models ├── __init__.py ├── aspp.py ├── aspp_single.py ├── base_decode_head.py ├── build.py ├── lora.py ├── seg_hrnet.py ├── segformer.py ├── swin_mtl.py ├── swin_transformer.py ├── swin_transformer_mtlora.py ├── transformer_head.py └── updecoder.py ├── mtl_loss_schemes.py ├── optimizer.py ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # launch bash 7 | *.sh 8 | # nsight system report files 9 | *.nsys-rep 10 | *.sqlite 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | output/ 137 | 138 | *.out 139 | *.log 140 | *.csv 141 | *.png 142 | *.pth 143 | *.pt 144 | *.ptl 145 | *.dec 146 | wandb/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) SCALE Lab, Brown University. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MTLoRA: A Low-Rank Adaptation Approach for Efficient Multi-Task Learning 2 | 3 | ## Introduction 4 | 5 | This is the official implementation of the paper: **MTLoRA: A Low-Rank Adaptation Approach for Efficient Multi-Task Learning** developed at [Brown University SCALE lab](https://scale-lab.github.io). 6 | 7 | This repository provides a Python-based implementation of MTLoRA including [`MTLoRALinear`](models/lora.py) (the main module) and MTL architectures. 8 | 9 | The repository is built on top of [Swin-Transformer](https://github.com/microsoft/Swin-Transformer) and uses some modules from [Multi-Task-Learning-PyTorch](https://github.com/SimonVandenhende/Multi-Task-Learning-PyTorch). 10 | 11 | 12 | ## How to Run 13 | 14 | Running the MTLoRA code is very similar to Swin's codebase: 15 | 16 | 1. **Clone the repository** 17 | ```bash 18 | git clone https://github.com/scale-lab/MTLoRA.git 19 | cd MTLoRA 20 | ``` 21 | 22 | 2. 
**Install the prerequisites** 23 | - Install `PyTorch>=1.12.0` and `torchvision>=0.13.0` with `CUDA>=11.6` 24 | - Install dependencies: `pip install -r requirements.txt` 25 | 26 | 3. **Run the code** 27 | ```python 28 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --cfg configs/mtlora/tiny_448/.yaml --pascal --tasks semseg,normals,sal,human_parts --batch-size 32 --ckpt-freq=20 --epoch=300 --resume-backbone 29 | ``` 30 | Swin variants and their weights can be found at the official [Swin Transformer repository](https://github.com/microsoft/Swin-Transformer). 31 | 32 | The outputs will be saved in `output/` folder unless overridden by the argument `--output`. 33 | 34 | 4. **Using the pre-trained model** 35 | 36 | You can download the model weights from the following [link](https://drive.google.com/file/d/1AzzOgX6X0VFKyXUBXhwlgmba5NbPUq3m/view?usp=drive_link). 37 | 38 | To run and evaluate the pre-trained model (assuming the model weight file is at `./mtlora.pth`), use `--eval` and `--resume ` as follows: 39 | ```python 40 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --cfg configs/mtlora/tiny_448/mtlora_tiny_448_r64_scale4_pertask.yaml --pascal --tasks semseg,normals,sal,human_parts --batch-size 32 --resume ./mtlora.pth --eval 41 | ``` 42 | 43 | ## Authorship 44 | Since the release commit is squashed, the GitHub contributors tab doesn't reflect the authors' contributions. The following authors contributed equally to this codebase: 45 | - [Ahmed Agiza](https://github.com/ahmed-agiza) 46 | - [Marina Neseem](https://github.com/marina-neseem) 47 | 48 | ## Citation 49 | If you find MTLoRA helpful in your research, please cite our paper: 50 | ``` 51 | @inproceedings{agiza2024mtlora, 52 | title={MTLoRA: Low-Rank Adaptation Approach for Efficient Multi-Task Learning}, 53 | author={Agiza, Ahmed and Neseem, Marina and Reda, Sherief}, 54 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 55 | pages={16196--16205}, 56 | year={2024} 57 | } 58 | ``` 59 | 60 | ## License 61 | MIT License. 
See [LICENSE](LICENSE) file 62 | -------------------------------------------------------------------------------- /configs/mtlora/tiny_448/mtlora_plus_tiny_448_r16_scale4.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 448 3 | MODEL: 4 | TYPE: swin 5 | NAME: mtlora_plus_tiny_448_r16_scale4 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [2, 2, 6, 2] 10 | NUM_HEADS: [3, 6, 12, 24] 11 | WINDOW_SIZE: 7 12 | MTLORA: 13 | ENABLED: True 14 | R: [16, 16, 16, 16] 15 | SHARED_SCALE: [4.0] 16 | TASK_SCALE: [4.0] 17 | DROPOUT: [0.05, 0.05, 0.05, 0.05] 18 | TRAINABLE_SCALE_SHARED: False 19 | TRAINABLE_SCALE_PER_TASK: False 20 | INTERMEDIATE_SPECIALIZATION: False 21 | FREEZE_PRETRAINED: True 22 | SPLIT_QKV: False 23 | QKV_ENABLED: True 24 | PROJ_ENABLED: True 25 | FC1_ENABLED: True 26 | FC2_ENABLED: True 27 | DOWNSAMPLER_ENABLED: True 28 | DECODER_HEAD: 29 | semseg: hrnet 30 | normals: hrnet 31 | sal: hrnet 32 | human_parts: hrnet 33 | edge: hrnet 34 | depth: hrnet 35 | -------------------------------------------------------------------------------- /configs/mtlora/tiny_448/mtlora_plus_tiny_448_r16_scale4_pertask.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 448 3 | MODEL: 4 | TYPE: swin 5 | NAME: mtlora_plus_tiny_448_r16_scale4_pertask 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [2, 2, 6, 2] 10 | NUM_HEADS: [3, 6, 12, 24] 11 | WINDOW_SIZE: 7 12 | MTLORA: 13 | ENABLED: True 14 | R: [16, 16, 16, 16] 15 | SHARED_SCALE: [4.0] 16 | TASK_SCALE: [4.0] 17 | DROPOUT: [0.05, 0.05, 0.05, 0.05] 18 | TRAINABLE_SCALE_SHARED: False 19 | TRAINABLE_SCALE_PER_TASK: False 20 | INTERMEDIATE_SPECIALIZATION: False 21 | FREEZE_PRETRAINED: True 22 | SPLIT_QKV: False 23 | QKV_ENABLED: True 24 | PROJ_ENABLED: True 25 | FC1_ENABLED: True 26 | FC2_ENABLED: True 27 | DOWNSAMPLER_ENABLED: True 28 | R_PER_TASK: 29 | semseg: [4] 30 | normals: [4] 31 | sal: [4] 32 | human_parts: [4] 33 | edge: [4] 34 | depth: [4] 35 | shared: [16] 36 | DECODER_HEAD: 37 | semseg: hrnet 38 | normals: hrnet 39 | sal: hrnet 40 | human_parts: hrnet 41 | edge: hrnet 42 | depth: hrnet 43 | -------------------------------------------------------------------------------- /configs/mtlora/tiny_448/mtlora_plus_tiny_448_r32_scale4_pertask.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 448 3 | MODEL: 4 | TYPE: swin 5 | NAME: mtlora_plus_tiny_448_r32_scale4_pertask 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [2, 2, 6, 2] 10 | NUM_HEADS: [3, 6, 12, 24] 11 | WINDOW_SIZE: 7 12 | MTLORA: 13 | ENABLED: True 14 | R: [32, 32, 32, 32] 15 | SHARED_SCALE: [4.0] 16 | TASK_SCALE: [4.0] 17 | DROPOUT: [0.05, 0.05, 0.05, 0.05] 18 | TRAINABLE_SCALE_SHARED: False 19 | TRAINABLE_SCALE_PER_TASK: False 20 | INTERMEDIATE_SPECIALIZATION: False 21 | FREEZE_PRETRAINED: True 22 | SPLIT_QKV: False 23 | QKV_ENABLED: True 24 | PROJ_ENABLED: True 25 | FC1_ENABLED: True 26 | FC2_ENABLED: True 27 | DOWNSAMPLER_ENABLED: True 28 | R_PER_TASK: 29 | semseg: [4] 30 | normals: [4] 31 | sal: [4] 32 | human_parts: [4] 33 | edge: [4] 34 | depth: [4] 35 | shared: [32] 36 | DECODER_HEAD: 37 | semseg: hrnet 38 | normals: hrnet 39 | sal: hrnet 40 | human_parts: hrnet 41 | edge: hrnet 42 | depth: hrnet 43 | -------------------------------------------------------------------------------- /configs/mtlora/tiny_448/mtlora_plus_tiny_448_r4_scale4.yaml: 
-------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 448 3 | MODEL: 4 | TYPE: swin 5 | NAME: mtlora_plus_tiny_448_r4_scale4 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [2, 2, 6, 2] 10 | NUM_HEADS: [3, 6, 12, 24] 11 | WINDOW_SIZE: 7 12 | MTLORA: 13 | ENABLED: True 14 | R: [4, 4, 4, 4] 15 | SHARED_SCALE: [4.0] 16 | TASK_SCALE: [4.0] 17 | DROPOUT: [0.05, 0.05, 0.05, 0.05] 18 | TRAINABLE_SCALE_SHARED: False 19 | TRAINABLE_SCALE_PER_TASK: False 20 | INTERMEDIATE_SPECIALIZATION: False 21 | FREEZE_PRETRAINED: True 22 | SPLIT_QKV: False 23 | QKV_ENABLED: True 24 | PROJ_ENABLED: True 25 | FC1_ENABLED: True 26 | FC2_ENABLED: True 27 | DOWNSAMPLER_ENABLED: True 28 | DECODER_HEAD: 29 | semseg: hrnet 30 | normals: hrnet 31 | sal: hrnet 32 | human_parts: hrnet 33 | edge: hrnet 34 | depth: hrnet 35 | -------------------------------------------------------------------------------- /configs/mtlora/tiny_448/mtlora_plus_tiny_448_r64_scale4_pertask.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 448 3 | MODEL: 4 | TYPE: swin 5 | NAME: mtlora_plus_tiny_448_r64_scale4_pertask 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [2, 2, 6, 2] 10 | NUM_HEADS: [3, 6, 12, 24] 11 | WINDOW_SIZE: 7 12 | MTLORA: 13 | ENABLED: True 14 | R: [64, 64, 64, 64] 15 | SHARED_SCALE: [4.0] 16 | TASK_SCALE: [4.0] 17 | DROPOUT: [0.05, 0.05, 0.05, 0.05] 18 | TRAINABLE_SCALE_SHARED: False 19 | TRAINABLE_SCALE_PER_TASK: False 20 | INTERMEDIATE_SPECIALIZATION: False 21 | FREEZE_PRETRAINED: True 22 | SPLIT_QKV: False 23 | QKV_ENABLED: True 24 | PROJ_ENABLED: True 25 | FC1_ENABLED: True 26 | FC2_ENABLED: True 27 | DOWNSAMPLER_ENABLED: True 28 | R_PER_TASK: 29 | semseg: [4] 30 | normals: [4] 31 | sal: [4] 32 | human_parts: [4] 33 | edge: [4] 34 | depth: [4] 35 | shared: [64] 36 | DECODER_HEAD: 37 | semseg: hrnet 38 | normals: hrnet 39 | sal: hrnet 40 | human_parts: hrnet 41 | edge: hrnet 42 | depth: hrnet 43 | -------------------------------------------------------------------------------- /configs/mtlora/tiny_448/mtlora_plus_tiny_448_r8_scale4.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 448 3 | MODEL: 4 | TYPE: swin 5 | NAME: mtlora_plus_tiny_448_r8_scale4 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [2, 2, 6, 2] 10 | NUM_HEADS: [3, 6, 12, 24] 11 | WINDOW_SIZE: 7 12 | MTLORA: 13 | ENABLED: True 14 | R: [8, 8, 8, 8] 15 | SHARED_SCALE: [4.0] 16 | TASK_SCALE: [4.0] 17 | DROPOUT: [0.05, 0.05, 0.05, 0.05] 18 | TRAINABLE_SCALE_SHARED: False 19 | TRAINABLE_SCALE_PER_TASK: False 20 | INTERMEDIATE_SPECIALIZATION: False 21 | FREEZE_PRETRAINED: True 22 | SPLIT_QKV: False 23 | QKV_ENABLED: True 24 | PROJ_ENABLED: True 25 | FC1_ENABLED: True 26 | FC2_ENABLED: True 27 | DOWNSAMPLER_ENABLED: True 28 | DECODER_HEAD: 29 | semseg: hrnet 30 | normals: hrnet 31 | sal: hrnet 32 | human_parts: hrnet 33 | edge: hrnet 34 | depth: hrnet 35 | -------------------------------------------------------------------------------- /configs/mtlora/tiny_448/mtlora_tiny_448_r16_scale4_pertask.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 448 3 | MODEL: 4 | TYPE: swin 5 | NAME: mtlora_tiny_448_r16_scale4_pertask 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [2, 2, 6, 2] 10 | NUM_HEADS: [3, 6, 12, 24] 11 | WINDOW_SIZE: 7 12 | MTLORA: 13 | ENABLED: True 14 | R: [16, 
16, 16, 16] 15 | SHARED_SCALE: [4.0] 16 | TASK_SCALE: [4.0] 17 | DROPOUT: [0.05, 0.05, 0.05, 0.05] 18 | TRAINABLE_SCALE_SHARED: False 19 | TRAINABLE_SCALE_PER_TASK: False 20 | INTERMEDIATE_SPECIALIZATION: False 21 | FREEZE_PRETRAINED: True 22 | SPLIT_QKV: False 23 | QKV_ENABLED: True 24 | PROJ_ENABLED: True 25 | FC1_ENABLED: True 26 | FC2_ENABLED: True 27 | DOWNSAMPLER_ENABLED: False 28 | R_PER_TASK: 29 | semseg: [4] 30 | normals: [4] 31 | sal: [4] 32 | human_parts: [4] 33 | edge: [4] 34 | depth: [4] 35 | shared: [16] 36 | DECODER_HEAD: 37 | semseg: hrnet 38 | normals: hrnet 39 | sal: hrnet 40 | human_parts: hrnet 41 | edge: hrnet 42 | depth: hrnet 43 | -------------------------------------------------------------------------------- /configs/mtlora/tiny_448/mtlora_tiny_448_r32_scale4_pertask.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 448 3 | MODEL: 4 | TYPE: swin 5 | NAME: mtlora_tiny_448_r32_scale4_pertask 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [2, 2, 6, 2] 10 | NUM_HEADS: [3, 6, 12, 24] 11 | WINDOW_SIZE: 7 12 | MTLORA: 13 | ENABLED: True 14 | R: [32, 32, 32, 32] 15 | SHARED_SCALE: [4.0] 16 | TASK_SCALE: [4.0] 17 | DROPOUT: [0.05, 0.05, 0.05, 0.05] 18 | TRAINABLE_SCALE_SHARED: False 19 | TRAINABLE_SCALE_PER_TASK: False 20 | INTERMEDIATE_SPECIALIZATION: False 21 | FREEZE_PRETRAINED: True 22 | SPLIT_QKV: False 23 | QKV_ENABLED: True 24 | PROJ_ENABLED: True 25 | FC1_ENABLED: True 26 | FC2_ENABLED: True 27 | DOWNSAMPLER_ENABLED: False 28 | R_PER_TASK: 29 | semseg: [4] 30 | normals: [4] 31 | sal: [4] 32 | human_parts: [4] 33 | edge: [4] 34 | depth: [4] 35 | shared: [32] 36 | DECODER_HEAD: 37 | semseg: hrnet 38 | normals: hrnet 39 | sal: hrnet 40 | human_parts: hrnet 41 | edge: hrnet 42 | depth: hrnet 43 | -------------------------------------------------------------------------------- /configs/mtlora/tiny_448/mtlora_tiny_448_r64_scale4_pertask.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 448 3 | MODEL: 4 | TYPE: swin 5 | NAME: mtlora_tiny_448_r64_scale4_pertask 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [2, 2, 6, 2] 10 | NUM_HEADS: [3, 6, 12, 24] 11 | WINDOW_SIZE: 7 12 | MTLORA: 13 | ENABLED: True 14 | R: [64, 64, 64, 64] 15 | SHARED_SCALE: [4.0] 16 | TASK_SCALE: [4.0] 17 | DROPOUT: [0.05, 0.05, 0.05, 0.05] 18 | TRAINABLE_SCALE_SHARED: False 19 | TRAINABLE_SCALE_PER_TASK: False 20 | INTERMEDIATE_SPECIALIZATION: False 21 | FREEZE_PRETRAINED: True 22 | SPLIT_QKV: False 23 | QKV_ENABLED: True 24 | PROJ_ENABLED: True 25 | FC1_ENABLED: True 26 | FC2_ENABLED: True 27 | DOWNSAMPLER_ENABLED: False 28 | R_PER_TASK: 29 | semseg: [4] 30 | normals: [4] 31 | sal: [4] 32 | human_parts: [4] 33 | edge: [4] 34 | depth: [4] 35 | shared: [64] 36 | DECODER_HEAD: 37 | semseg: hrnet 38 | normals: hrnet 39 | sal: hrnet 40 | human_parts: hrnet 41 | edge: hrnet 42 | depth: hrnet 43 | -------------------------------------------------------------------------------- /configs/swin/swin_tiny_patch4_window7_448.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 448 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_tiny_patch4_window7_448 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [2, 2, 6, 2] 10 | NUM_HEADS: [3, 6, 12, 24] 11 | WINDOW_SIZE: 7 12 | DECODER_HEAD: 13 | semseg: hrnet 14 | normals: hrnet 15 | sal: hrnet 16 | human_parts: hrnet 17 | edge: hrnet 18 | depth: hrnet 19 
| -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader as _build_loader, build_nyud, build_pascal 2 | from .data_simmim_pt import build_loader_simmim 3 | from .data_simmim_ft import build_loader_finetune 4 | 5 | 6 | def build_loader(config, simmim=False, is_pretrain=False, val_only=False): 7 | if config.get('DATA', {}).get('NYUD', False): 8 | return build_nyud(config) 9 | if config.get('DATA', {}).get('PASCAL', False): 10 | return build_pascal(config, val_only) 11 | if not simmim: 12 | return _build_loader(config) 13 | if is_pretrain: 14 | return build_loader_simmim(config) 15 | else: 16 | return build_loader_finetune(config) 17 | -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # Built upon Swin Transformer (https://github.com/microsoft/Swin-Transformer) 5 | # 6 | # Original file: 7 | # Copyright (c) 2021 Microsoft 8 | # Licensed under the MIT License 9 | # Written by Ze Liu 10 | # 11 | # Modifications: 12 | # Copyright (c) 2024 SCALE Lab, Brown University 13 | # Licensed under the MIT License (see LICENSE for details) 14 | # -------------------------------------------------------- 15 | 16 | 17 | import os 18 | import torch 19 | import numpy as np 20 | import torch.distributed as dist 21 | from torchvision import datasets, transforms 22 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 23 | from timm.data import Mixup 24 | from timm.data import create_transform 25 | 26 | from .cached_image_folder import CachedImageFolder 27 | from .imagenet22k_dataset import IN22KDATASET 28 | from .samplers import SubsetRandomSampler 29 | from data.mtl_ds import get_mtl_train_dataset, get_mtl_train_dataloader, get_mtl_val_dataset, get_mtl_val_dataloader, get_transformations 30 | import re 31 | 32 | 33 | try: 34 | from torchvision.transforms import InterpolationMode 35 | 36 | def _pil_interp(method): 37 | if method == 'bicubic': 38 | return InterpolationMode.BICUBIC 39 | elif method == 'lanczos': 40 | return InterpolationMode.LANCZOS 41 | elif method == 'hamming': 42 | return InterpolationMode.HAMMING 43 | else: 44 | # default bilinear, do we want to allow nearest? 
45 | return InterpolationMode.BILINEAR 46 | 47 | import timm.data.transforms as timm_transforms 48 | 49 | timm_transforms._pil_interp = _pil_interp 50 | except: 51 | from timm.data.transforms import _pil_interp 52 | 53 | 54 | def build_loader(config): 55 | config.defrost() 56 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset( 57 | is_train=True, config=config) 58 | config.freeze() 59 | print( 60 | f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 61 | dataset_val, _ = build_dataset(is_train=False, config=config) 62 | print( 63 | f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") 64 | 65 | num_tasks = dist.get_world_size() 66 | global_rank = dist.get_rank() 67 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 68 | indices = np.arange(dist.get_rank(), len( 69 | dataset_train), dist.get_world_size()) 70 | sampler_train = SubsetRandomSampler(indices) 71 | else: 72 | sampler_train = torch.utils.data.DistributedSampler( 73 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 74 | ) 75 | 76 | if config.TEST.SEQUENTIAL: 77 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 78 | else: 79 | sampler_val = torch.utils.data.distributed.DistributedSampler( 80 | dataset_val, shuffle=config.TEST.SHUFFLE 81 | ) 82 | 83 | data_loader_train = torch.utils.data.DataLoader( 84 | dataset_train, sampler=sampler_train, 85 | batch_size=config.DATA.BATCH_SIZE, 86 | num_workers=config.DATA.NUM_WORKERS, 87 | pin_memory=config.DATA.PIN_MEMORY, 88 | drop_last=True, 89 | ) 90 | 91 | data_loader_val = torch.utils.data.DataLoader( 92 | dataset_val, sampler=sampler_val, 93 | batch_size=config.DATA.BATCH_SIZE, 94 | shuffle=False, 95 | num_workers=config.DATA.NUM_WORKERS, 96 | pin_memory=config.DATA.PIN_MEMORY, 97 | drop_last=False 98 | ) 99 | 100 | # setup mixup / cutmix 101 | mixup_fn = None 102 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 103 | if mixup_active: 104 | mixup_fn = Mixup( 105 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 106 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 107 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 108 | 109 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 110 | 111 | 112 | def build_dataset(is_train, config): 113 | transform = build_transform(is_train, config) 114 | if config.DATA.DATASET == 'imagenet': 115 | prefix = 'train' if is_train else 'val' 116 | if config.DATA.ZIP_MODE: 117 | ann_file = prefix + "_map.txt" 118 | prefix = prefix + ".zip@/" 119 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 120 | cache_mode=config.DATA.CACHE_MODE if is_train else 'part') 121 | else: 122 | root = os.path.join(config.DATA.DATA_PATH, prefix) 123 | dataset = datasets.ImageFolder(root, transform=transform) 124 | nb_classes = 1000 125 | elif config.DATA.DATASET == 'imagenet22K': 126 | prefix = 'ILSVRC2011fall_whole' 127 | if is_train: 128 | ann_file = prefix + "_map_train.txt" 129 | else: 130 | ann_file = prefix + "_map_val.txt" 131 | dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform) 132 | nb_classes = 21841 133 | else: 134 | raise NotImplementedError("We only support ImageNet Now.") 135 | 136 | return dataset, nb_classes 137 | 138 | 139 | def build_transform(is_train, config): 140 | resize_im = config.DATA.IMG_SIZE > 32 141 | if is_train: 142 | # this should always dispatch to transforms_imagenet_train 143 | transform = create_transform( 144 | input_size=config.DATA.IMG_SIZE, 145 | is_training=True, 146 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 147 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 148 | re_prob=config.AUG.REPROB, 149 | re_mode=config.AUG.REMODE, 150 | re_count=config.AUG.RECOUNT, 151 | interpolation=config.DATA.INTERPOLATION, 152 | ) 153 | if not resize_im: 154 | # replace RandomResizedCropAndInterpolation with 155 | # RandomCrop 156 | transform.transforms[0] = transforms.RandomCrop( 157 | config.DATA.IMG_SIZE, padding=4) 158 | return transform 159 | 160 | t = [] 161 | if resize_im: 162 | if config.TEST.CROP: 163 | size = int((256 / 224) * config.DATA.IMG_SIZE) 164 | t.append( 165 | transforms.Resize(size, interpolation=_pil_interp( 166 | config.DATA.INTERPOLATION)), 167 | # to maintain same ratio w.r.t. 
224 images 168 | ) 169 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 170 | else: 171 | t.append( 172 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 173 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 174 | ) 175 | 176 | t.append(transforms.ToTensor()) 177 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 178 | return transforms.Compose(t) 179 | 180 | 181 | def build_mtl(config, db_name="NYUD", val_only=False): 182 | config.defrost() 183 | if val_only: 184 | _, val_transforms = get_transformations(db_name, config.TASKS_CONFIG) 185 | dataset_val = get_mtl_val_dataset(db_name, config, val_transforms) 186 | data_loader_val = get_mtl_val_dataloader(config, dataset_val) 187 | return dataset_val, data_loader_val 188 | 189 | print(f"Loading {db_name} dataset") 190 | print("===============") 191 | 192 | train_transforms, val_transforms = get_transformations( 193 | db_name, config.TASKS_CONFIG) 194 | dataset_train = get_mtl_train_dataset( 195 | db_name, config, train_transforms) 196 | dataset_val = get_mtl_val_dataset(db_name, config, val_transforms) 197 | data_loader_train = get_mtl_train_dataloader(config, dataset_train) 198 | data_loader_val = get_mtl_val_dataloader(config, dataset_val) 199 | mixup_fn = None 200 | 201 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 202 | 203 | 204 | def build_nyud(config): 205 | return build_mtl(config, 'NYUD') 206 | 207 | 208 | def build_pascal(config, val_only): 209 | return build_mtl(config, 'PASCALContext', val_only=val_only) 210 | -------------------------------------------------------------------------------- /data/cached_image_folder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import io 9 | import os 10 | import time 11 | import torch.distributed as dist 12 | import torch.utils.data as data 13 | from PIL import Image 14 | 15 | from .zipreader import is_zip_path, ZipReader 16 | 17 | 18 | def has_file_allowed_extension(filename, extensions): 19 | """Checks if a file is an allowed extension. 
20 | Args: 21 | filename (string): path to a file 22 | Returns: 23 | bool: True if the filename ends with a known image extension 24 | """ 25 | filename_lower = filename.lower() 26 | return any(filename_lower.endswith(ext) for ext in extensions) 27 | 28 | 29 | def find_classes(dir): 30 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 31 | classes.sort() 32 | class_to_idx = {classes[i]: i for i in range(len(classes))} 33 | return classes, class_to_idx 34 | 35 | 36 | def make_dataset(dir, class_to_idx, extensions): 37 | images = [] 38 | dir = os.path.expanduser(dir) 39 | for target in sorted(os.listdir(dir)): 40 | d = os.path.join(dir, target) 41 | if not os.path.isdir(d): 42 | continue 43 | 44 | for root, _, fnames in sorted(os.walk(d)): 45 | for fname in sorted(fnames): 46 | if has_file_allowed_extension(fname, extensions): 47 | path = os.path.join(root, fname) 48 | item = (path, class_to_idx[target]) 49 | images.append(item) 50 | 51 | return images 52 | 53 | 54 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 55 | images = [] 56 | with open(ann_file, "r") as f: 57 | contents = f.readlines() 58 | for line_str in contents: 59 | path_contents = [c for c in line_str.split('\t')] 60 | im_file_name = path_contents[0] 61 | class_index = int(path_contents[1]) 62 | 63 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 64 | item = (os.path.join(img_prefix, im_file_name), class_index) 65 | 66 | images.append(item) 67 | 68 | return images 69 | 70 | 71 | class DatasetFolder(data.Dataset): 72 | """A generic data loader where the samples are arranged in this way: :: 73 | root/class_x/xxx.ext 74 | root/class_x/xxy.ext 75 | root/class_x/xxz.ext 76 | root/class_y/123.ext 77 | root/class_y/nsdf3.ext 78 | root/class_y/asd932_.ext 79 | Args: 80 | root (string): Root directory path. 81 | loader (callable): A function to load a sample given its path. 82 | extensions (list[string]): A list of allowed extensions. 83 | transform (callable, optional): A function/transform that takes in 84 | a sample and returns a transformed version. 85 | E.g, ``transforms.RandomCrop`` for images. 86 | target_transform (callable, optional): A function/transform that takes 87 | in the target and transforms it. 
88 | Attributes: 89 | samples (list): List of (sample path, class_index) tuples 90 | """ 91 | 92 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 93 | cache_mode="no"): 94 | # image folder mode 95 | if ann_file == '': 96 | _, class_to_idx = find_classes(root) 97 | samples = make_dataset(root, class_to_idx, extensions) 98 | # zip mode 99 | else: 100 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 101 | os.path.join(root, img_prefix), 102 | extensions) 103 | 104 | if len(samples) == 0: 105 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + 106 | "Supported extensions are: " + ",".join(extensions))) 107 | 108 | self.root = root 109 | self.loader = loader 110 | self.extensions = extensions 111 | 112 | self.samples = samples 113 | self.labels = [y_1k for _, y_1k in samples] 114 | self.classes = list(set(self.labels)) 115 | 116 | self.transform = transform 117 | self.target_transform = target_transform 118 | 119 | self.cache_mode = cache_mode 120 | if self.cache_mode != "no": 121 | self.init_cache() 122 | 123 | def init_cache(self): 124 | assert self.cache_mode in ["part", "full"] 125 | n_sample = len(self.samples) 126 | global_rank = dist.get_rank() 127 | world_size = dist.get_world_size() 128 | 129 | samples_bytes = [None for _ in range(n_sample)] 130 | start_time = time.time() 131 | for index in range(n_sample): 132 | if index % (n_sample // 10) == 0: 133 | t = time.time() - start_time 134 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') 135 | start_time = time.time() 136 | path, target = self.samples[index] 137 | if self.cache_mode == "full": 138 | samples_bytes[index] = (ZipReader.read(path), target) 139 | elif self.cache_mode == "part" and index % world_size == global_rank: 140 | samples_bytes[index] = (ZipReader.read(path), target) 141 | else: 142 | samples_bytes[index] = (path, target) 143 | self.samples = samples_bytes 144 | 145 | def __getitem__(self, index): 146 | """ 147 | Args: 148 | index (int): Index 149 | Returns: 150 | tuple: (sample, target) where target is class_index of the target class. 
151 | """ 152 | path, target = self.samples[index] 153 | sample = self.loader(path) 154 | if self.transform is not None: 155 | sample = self.transform(sample) 156 | if self.target_transform is not None: 157 | target = self.target_transform(target) 158 | 159 | return sample, target 160 | 161 | def __len__(self): 162 | return len(self.samples) 163 | 164 | def __repr__(self): 165 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 166 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 167 | fmt_str += ' Root Location: {}\n'.format(self.root) 168 | tmp = ' Transforms (if any): ' 169 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 170 | tmp = ' Target Transforms (if any): ' 171 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 172 | return fmt_str 173 | 174 | 175 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 176 | 177 | 178 | def pil_loader(path): 179 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 180 | if isinstance(path, bytes): 181 | img = Image.open(io.BytesIO(path)) 182 | elif is_zip_path(path): 183 | data = ZipReader.read(path) 184 | img = Image.open(io.BytesIO(data)) 185 | else: 186 | with open(path, 'rb') as f: 187 | img = Image.open(f) 188 | return img.convert('RGB') 189 | return img.convert('RGB') 190 | 191 | 192 | def accimage_loader(path): 193 | import accimage 194 | try: 195 | return accimage.Image(path) 196 | except IOError: 197 | # Potentially a decoding problem, fall back to PIL.Image 198 | return pil_loader(path) 199 | 200 | 201 | def default_img_loader(path): 202 | from torchvision import get_image_backend 203 | if get_image_backend() == 'accimage': 204 | return accimage_loader(path) 205 | else: 206 | return pil_loader(path) 207 | 208 | 209 | class CachedImageFolder(DatasetFolder): 210 | """A generic data loader where the images are arranged in this way: :: 211 | root/dog/xxx.png 212 | root/dog/xxy.png 213 | root/dog/xxz.png 214 | root/cat/123.png 215 | root/cat/nsdf3.png 216 | root/cat/asd932_.png 217 | Args: 218 | root (string): Root directory path. 219 | transform (callable, optional): A function/transform that takes in an PIL image 220 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 221 | target_transform (callable, optional): A function/transform that takes in the 222 | target and transforms it. 223 | loader (callable, optional): A function to load an image given its path. 224 | Attributes: 225 | imgs (list): List of (image path, class_index) tuples 226 | """ 227 | 228 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 229 | loader=default_img_loader, cache_mode="no"): 230 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 231 | ann_file=ann_file, img_prefix=img_prefix, 232 | transform=transform, target_transform=target_transform, 233 | cache_mode=cache_mode) 234 | self.imgs = self.samples 235 | 236 | def __getitem__(self, index): 237 | """ 238 | Args: 239 | index (int): Index 240 | Returns: 241 | tuple: (image, target) where target is class_index of the target class. 
242 | """ 243 | path, target = self.samples[index] 244 | image = self.loader(path) 245 | if self.transform is not None: 246 | img = self.transform(image) 247 | else: 248 | img = image 249 | if self.target_transform is not None: 250 | target = self.target_transform(target) 251 | 252 | return img, target 253 | -------------------------------------------------------------------------------- /data/data_simmim_ft.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Zhenda Xie 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch.distributed as dist 10 | from torch.utils.data import DataLoader, DistributedSampler 11 | from torchvision import datasets, transforms 12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 13 | from timm.data import Mixup 14 | from timm.data import create_transform 15 | from timm.data.transforms import _pil_interp 16 | 17 | 18 | def build_loader_finetune(config): 19 | config.defrost() 20 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 21 | config.freeze() 22 | dataset_val, _ = build_dataset(is_train=False, config=config) 23 | 24 | num_tasks = dist.get_world_size() 25 | global_rank = dist.get_rank() 26 | sampler_train = DistributedSampler( 27 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 28 | ) 29 | sampler_val = DistributedSampler( 30 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False 31 | ) 32 | 33 | data_loader_train = DataLoader( 34 | dataset_train, sampler=sampler_train, 35 | batch_size=config.DATA.BATCH_SIZE, 36 | num_workers=config.DATA.NUM_WORKERS, 37 | pin_memory=config.DATA.PIN_MEMORY, 38 | drop_last=True, 39 | ) 40 | 41 | data_loader_val = DataLoader( 42 | dataset_val, sampler=sampler_val, 43 | batch_size=config.DATA.BATCH_SIZE, 44 | num_workers=config.DATA.NUM_WORKERS, 45 | pin_memory=config.DATA.PIN_MEMORY, 46 | drop_last=False, 47 | ) 48 | 49 | # setup mixup / cutmix 50 | mixup_fn = None 51 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 52 | if mixup_active: 53 | mixup_fn = Mixup( 54 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 55 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 56 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 57 | 58 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 59 | 60 | 61 | def build_dataset(is_train, config): 62 | transform = build_transform(is_train, config) 63 | 64 | if config.DATA.DATASET == 'imagenet': 65 | prefix = 'train' if is_train else 'val' 66 | root = os.path.join(config.DATA.DATA_PATH, prefix) 67 | dataset = datasets.ImageFolder(root, transform=transform) 68 | nb_classes = 1000 69 | else: 70 | raise NotImplementedError("We only support ImageNet Now.") 71 | 72 | return dataset, nb_classes 73 | 74 | 75 | def build_transform(is_train, config): 76 | resize_im = config.DATA.IMG_SIZE > 32 77 | if is_train: 78 | # this should always dispatch to transforms_imagenet_train 79 | transform = create_transform( 80 | input_size=config.DATA.IMG_SIZE, 81 | is_training=True, 82 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 83 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 84 | re_prob=config.AUG.REPROB, 85 | re_mode=config.AUG.REMODE, 86 | re_count=config.AUG.RECOUNT, 87 | interpolation=config.DATA.INTERPOLATION, 88 | ) 89 | if not resize_im: 90 | # replace RandomResizedCropAndInterpolation with 91 | # RandomCrop 92 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 93 | return transform 94 | 95 | t = [] 96 | if resize_im: 97 | if config.TEST.CROP: 98 | size = int((256 / 224) * config.DATA.IMG_SIZE) 99 | t.append( 100 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 101 | # to maintain same ratio w.r.t. 
224 images 102 | ) 103 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 104 | else: 105 | t.append( 106 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 107 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 108 | ) 109 | 110 | t.append(transforms.ToTensor()) 111 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 112 | return transforms.Compose(t) 113 | -------------------------------------------------------------------------------- /data/data_simmim_pt.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Zhenda Xie 6 | # -------------------------------------------------------- 7 | 8 | import math 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import torchvision.transforms as T 15 | from torch.utils.data import DataLoader, DistributedSampler 16 | from torch.utils.data._utils.collate import default_collate 17 | from torchvision.datasets import ImageFolder 18 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | 20 | 21 | class MaskGenerator: 22 | def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6): 23 | self.input_size = input_size 24 | self.mask_patch_size = mask_patch_size 25 | self.model_patch_size = model_patch_size 26 | self.mask_ratio = mask_ratio 27 | 28 | assert self.input_size % self.mask_patch_size == 0 29 | assert self.mask_patch_size % self.model_patch_size == 0 30 | 31 | self.rand_size = self.input_size // self.mask_patch_size 32 | self.scale = self.mask_patch_size // self.model_patch_size 33 | 34 | self.token_count = self.rand_size ** 2 35 | self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) 36 | 37 | def __call__(self): 38 | mask_idx = np.random.permutation(self.token_count)[:self.mask_count] 39 | mask = np.zeros(self.token_count, dtype=int) 40 | mask[mask_idx] = 1 41 | 42 | mask = mask.reshape((self.rand_size, self.rand_size)) 43 | mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) 44 | 45 | return mask 46 | 47 | 48 | class SimMIMTransform: 49 | def __init__(self, config): 50 | self.transform_img = T.Compose([ 51 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 52 | T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. 
/ 3.)), 53 | T.RandomHorizontalFlip(), 54 | T.ToTensor(), 55 | T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)), 56 | ]) 57 | 58 | if config.MODEL.TYPE in ['swin', 'swinv2']: 59 | model_patch_size=config.MODEL.SWIN.PATCH_SIZE 60 | else: 61 | raise NotImplementedError 62 | 63 | self.mask_generator = MaskGenerator( 64 | input_size=config.DATA.IMG_SIZE, 65 | mask_patch_size=config.DATA.MASK_PATCH_SIZE, 66 | model_patch_size=model_patch_size, 67 | mask_ratio=config.DATA.MASK_RATIO, 68 | ) 69 | 70 | def __call__(self, img): 71 | img = self.transform_img(img) 72 | mask = self.mask_generator() 73 | 74 | return img, mask 75 | 76 | 77 | def collate_fn(batch): 78 | if not isinstance(batch[0][0], tuple): 79 | return default_collate(batch) 80 | else: 81 | batch_num = len(batch) 82 | ret = [] 83 | for item_idx in range(len(batch[0][0])): 84 | if batch[0][0][item_idx] is None: 85 | ret.append(None) 86 | else: 87 | ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)])) 88 | ret.append(default_collate([batch[i][1] for i in range(batch_num)])) 89 | return ret 90 | 91 | 92 | def build_loader_simmim(config): 93 | transform = SimMIMTransform(config) 94 | dataset = ImageFolder(config.DATA.DATA_PATH, transform) 95 | 96 | sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) 97 | dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn) 98 | 99 | return dataloader -------------------------------------------------------------------------------- /data/db_info/context_classes.json: -------------------------------------------------------------------------------- 1 | {"accordion": 1, "aeroplane": 2, "air conditioner": 3, "antenna": 4, "artillery": 5, "ashtray": 6, "atrium": 7, "baby carriage": 8, "bag": 9, "ball": 10, "balloon": 11, "bamboo weaving": 12, "barrel": 13, "baseball bat": 14, "basket": 15, "basketball backboard": 16, "bathtub": 17, "bed": 18, "bedclothes": 19, "beer": 20, "bell": 21, "bench": 22, "bicycle": 23, "binoculars": 24, "bird": 25, "bird cage": 26, "bird feeder": 27, "bird nest": 28, "blackboard": 29, "board": 30, "boat": 31, "bone": 32, "book": 33, "bottle": 34, "bottle opener": 35, "bowl": 36, "box": 37, "bracelet": 38, "brick": 39, "bridge": 40, "broom": 41, "brush": 42, "bucket": 43, "building": 44, "bus": 45, "cabinet": 46, "cabinet door": 47, "cage": 48, "cake": 49, "calculator": 50, "calendar": 51, "camel": 52, "camera": 53, "camera lens": 54, "can": 55, "candle": 56, "candle holder": 57, "cap": 58, "car": 59, "card": 60, "cart": 61, "case": 62, "casette recorder": 63, "cash register": 64, "cat": 65, "cd": 66, "cd player": 67, "ceiling": 68, "cell phone": 69, "cello": 70, "chain": 71, "chair": 72, "chessboard": 73, "chicken": 74, "chopstick": 75, "clip": 76, "clippers": 77, "clock": 78, "closet": 79, "cloth": 80, "clothes tree": 81, "coffee": 82, "coffee machine": 83, "comb": 84, "computer": 85, "concrete": 86, "cone": 87, "container": 88, "control booth": 89, "controller": 90, "cooker": 91, "copying machine": 92, "coral": 93, "cork": 94, "corkscrew": 95, "counter": 96, "court": 97, "cow": 98, "crabstick": 99, "crane": 100, "crate": 101, "cross": 102, "crutch": 103, "cup": 104, "curtain": 105, "cushion": 106, "cutting board": 107, "dais": 108, "disc": 109, "disc case": 110, "dishwasher": 111, "dock": 112, "dog": 113, "dolphin": 114, "door": 115, "drainer": 116, 
"dray": 117, "drink dispenser": 118, "drinking machine": 119, "drop": 120, "drug": 121, "drum": 122, "drum kit": 123, "duck": 124, "dumbbell": 125, "earphone": 126, "earrings": 127, "egg": 128, "electric fan": 129, "electric iron": 130, "electric pot": 131, "electric saw": 132, "electronic keyboard": 133, "engine": 134, "envelope": 135, "equipment": 136, "escalator": 137, "exhibition booth": 138, "extinguisher": 139, "eyeglass": 140, "fan": 141, "faucet": 142, "fax machine": 143, "fence": 144, "ferris wheel": 145, "fire extinguisher": 146, "fire hydrant": 147, "fire place": 148, "fish": 149, "fish tank": 150, "fishbowl": 151, "fishing net": 152, "fishing pole": 153, "flag": 154, "flagstaff": 155, "flame": 156, "flashlight": 157, "floor": 158, "flower": 159, "fly": 160, "foam": 161, "food": 162, "footbridge": 163, "forceps": 164, "fork": 165, "forklift": 166, "fountain": 167, "fox": 168, "frame": 169, "fridge": 170, "frog": 171, "fruit": 172, "funnel": 173, "furnace": 174, "game controller": 175, "game machine": 176, "gas cylinder": 177, "gas hood": 178, "gas stove": 179, "gift box": 180, "glass": 181, "glass marble": 182, "globe": 183, "glove": 184, "goal": 185, "grandstand": 186, "grass": 187, "gravestone": 188, "ground": 189, "guardrail": 190, "guitar": 191, "gun": 192, "hammer": 193, "hand cart": 194, "handle": 195, "handrail": 196, "hanger": 197, "hard disk drive": 198, "hat": 199, "hay": 200, "headphone": 201, "heater": 202, "helicopter": 203, "helmet": 204, "holder": 205, "hook": 206, "horse": 207, "horse-drawn carriage": 208, "hot-air balloon": 209, "hydrovalve": 210, "ice": 211, "inflator pump": 212, "ipod": 213, "iron": 214, "ironing board": 215, "jar": 216, "kart": 217, "kettle": 218, "key": 219, "keyboard": 220, "kitchen range": 221, "kite": 222, "knife": 223, "knife block": 224, "ladder": 225, "ladder truck": 226, "ladle": 227, "laptop": 228, "leaves": 229, "lid": 230, "life buoy": 231, "light": 232, "light bulb": 233, "lighter": 234, "line": 235, "lion": 236, "lobster": 237, "lock": 238, "machine": 239, "mailbox": 240, "mannequin": 241, "map": 242, "mask": 243, "mat": 244, "match book": 245, "mattress": 246, "menu": 247, "metal": 248, "meter box": 249, "microphone": 250, "microwave": 251, "mirror": 252, "missile": 253, "model": 254, "money": 255, "monkey": 256, "mop": 257, "motorbike": 258, "mountain": 259, "mouse": 260, "mouse pad": 261, "musical instrument": 262, "napkin": 263, "net": 264, "newspaper": 265, "oar": 266, "ornament": 267, "outlet": 268, "oven": 269, "oxygen bottle": 270, "pack": 271, "pan": 272, "paper": 273, "paper box": 274, "paper cutter": 275, "parachute": 276, "parasol": 277, "parterre": 278, "patio": 279, "pelage": 280, "pen": 281, "pen container": 282, "pencil": 283, "person": 284, "photo": 285, "piano": 286, "picture": 287, "pig": 288, "pillar": 289, "pillow": 290, "pipe": 291, "pitcher": 292, "plant": 293, "plastic": 294, "plate": 295, "platform": 296, "player": 297, "playground": 298, "pliers": 299, "plume": 300, "poker": 301, "poker chip": 302, "pole": 303, "pool table": 304, "postcard": 305, "poster": 306, "pot": 307, "pottedplant": 308, "printer": 309, "projector": 310, "pumpkin": 311, "rabbit": 312, "racket": 313, "radiator": 314, "radio": 315, "rail": 316, "rake": 317, "ramp": 318, "range hood": 319, "receiver": 320, "recorder": 321, "recreational machines": 322, "remote control": 323, "road": 324, "robot": 325, "rock": 326, "rocket": 327, "rocking horse": 328, "rope": 329, "rug": 330, "ruler": 331, "runway": 332, "saddle": 333, "sand": 334, 
"saw": 335, "scale": 336, "scanner": 337, "scissors": 338, "scoop": 339, "screen": 340, "screwdriver": 341, "sculpture": 342, "scythe": 343, "sewer": 344, "sewing machine": 345, "shed": 346, "sheep": 347, "shell": 348, "shelves": 349, "shoe": 350, "shopping cart": 351, "shovel": 352, "sidecar": 353, "sidewalk": 354, "sign": 355, "signal light": 356, "sink": 357, "skateboard": 358, "ski": 359, "sky": 360, "sled": 361, "slippers": 362, "smoke": 363, "snail": 364, "snake": 365, "snow": 366, "snowmobiles": 367, "sofa": 368, "spanner": 369, "spatula": 370, "speaker": 371, "speed bump": 372, "spice container": 373, "spoon": 374, "sprayer": 375, "squirrel": 376, "stage": 377, "stair": 378, "stapler": 379, "stick": 380, "sticky note": 381, "stone": 382, "stool": 383, "stove": 384, "straw": 385, "stretcher": 386, "sun": 387, "sunglass": 388, "sunshade": 389, "surveillance camera": 390, "swan": 391, "sweeper": 392, "swim ring": 393, "swimming pool": 394, "swing": 395, "switch": 396, "table": 397, "tableware": 398, "tank": 399, "tap": 400, "tape": 401, "tarp": 402, "telephone": 403, "telephone booth": 404, "tent": 405, "tire": 406, "toaster": 407, "toilet": 408, "tong": 409, "tool": 410, "toothbrush": 411, "towel": 412, "toy": 413, "toy car": 414, "track": 415, "train": 416, "trampoline": 417, "trash bin": 418, "tray": 419, "tree": 420, "tricycle": 421, "tripod": 422, "trophy": 423, "truck": 424, "tube": 425, "turtle": 426, "tvmonitor": 427, "tweezers": 428, "typewriter": 429, "umbrella": 430, "unknown": 431, "vacuum cleaner": 432, "vending machine": 433, "video camera": 434, "video game console": 435, "video player": 436, "video tape": 437, "violin": 438, "wakeboard": 439, "wall": 440, "wallet": 441, "wardrobe": 442, "washing machine": 443, "watch": 444, "water": 445, "water dispenser": 446, "water pipe": 447, "water skate board": 448, "watermelon": 449, "whale": 450, "wharf": 451, "wheel": 452, "wheelchair": 453, "window": 454, "window blinds": 455, "wineglass": 456, "wire": 457, "wood": 458, "wool": 459} -------------------------------------------------------------------------------- /data/db_info/pascal_map.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scale-lab/MTLoRA/333a3750568f2ac9b8313fa6ff5dd86f517be13f/data/db_info/pascal_map.npy -------------------------------------------------------------------------------- /data/db_info/pascal_part.json: -------------------------------------------------------------------------------- 1 | { 2 | "1": {"body": 1 , "engine_1": 11, "engine_10": 20, "engine_2": 12, "engine_3": 13, "engine_4": 14, "engine_5": 15, "engine_6": 16, "engine_7": 17, "engine_8": 18, "engine_9": 19, "lwing": 3, "rwing": 4, "stern": 2, "tail": 5, "wheel_1": 21, "wheel_10": 30, "wheel_2": 22, "wheel_3": 23, "wheel_4": 24, "wheel_5": 25, "wheel_6": 26, "wheel_7": 27, "wheel_8": 28, "wheel_9": 29}, 3 | "2": {"bwheel": 2 , "chainwheel": 5, "fwheel": 1, "handlebar": 4, "headlight_1": 11, "headlight_10": 20, "headlight_2": 12, "headlight_3": 13, "headlight_4": 14, "headlight_5": 15, "headlight_6": 16, "headlight_7": 17, "headlight_8": 18, "headlight_9": 19, "saddle": 3}, 4 | "3": {"beak": 4 , "head": 1, "leye": 2, "lfoot": 10, "lleg": 9, "lwing": 7, "neck": 6, "reye": 3, "rfoot": 12, "rleg": 11, "rwing": 8, "tail": 13, "torso": 5}, 5 | "4": {}, 6 | "5": {"body": 2 , "cap": 1}, 7 | "6": {"backside": 4 , "bliplate": 9, "door_1": 11, "door_10": 20, "door_2": 12, "door_3": 13, "door_4": 14, "door_5": 15, "door_6": 16, 
"door_7": 17, "door_8": 18, "door_9": 19, "fliplate": 8, "frontside": 1, "headlight_1": 31, "headlight_10": 40, "headlight_2": 32, "headlight_3": 33, "headlight_4": 34, "headlight_5": 35, "headlight_6": 36, "headlight_7": 37, "headlight_8": 38, "headlight_9": 39, "leftmirror": 6, "leftside": 2, "rightmirror": 7, "rightside": 3, "roofside": 5, "wheel_1": 21, "wheel_10": 30, "wheel_2": 22, "wheel_3": 23, "wheel_4": 24, "wheel_5": 25, "wheel_6": 26, "wheel_7": 27, "wheel_8": 28, "wheel_9": 29, "window_1": 41, "window_10": 50, "window_11": 51, "window_12": 52, "window_13": 53, "window_14": 54, "window_15": 55, "window_16": 56, "window_17": 57, "window_18": 58, "window_19": 59, "window_2": 42, "window_20": 60, "window_3": 43, "window_4": 44, "window_5": 45, "window_6": 46, "window_7": 47, "window_8": 48, "window_9": 49}, 8 | "7": {"backside": 4 , "bliplate": 9, "door_1": 11, "door_10": 20, "door_2": 12, "door_3": 13, "door_4": 14, "door_5": 15, "door_6": 16, "door_7": 17, "door_8": 18, "door_9": 19, "fliplate": 8, "frontside": 1, "headlight_1": 31, "headlight_10": 40, "headlight_2": 32, "headlight_3": 33, "headlight_4": 34, "headlight_5": 35, "headlight_6": 36, "headlight_7": 37, "headlight_8": 38, "headlight_9": 39, "leftmirror": 6, "leftside": 2, "rightmirror": 7, "rightside": 3, "roofside": 5, "wheel_1": 21, "wheel_10": 30, "wheel_2": 22, "wheel_3": 23, "wheel_4": 24, "wheel_5": 25, "wheel_6": 26, "wheel_7": 27, "wheel_8": 28, "wheel_9": 29, "window_1": 41, "window_10": 50, "window_11": 51, "window_12": 52, "window_13": 53, "window_14": 54, "window_15": 55, "window_16": 56, "window_17": 57, "window_18": 58, "window_19": 59, "window_2": 42, "window_20": 60, "window_3": 43, "window_4": 44, "window_5": 45, "window_6": 46, "window_7": 47, "window_8": 48, "window_9": 49}, 9 | "8": {"head": 1 , "lbleg": 13, "lbpa": 14, "lear": 4, "leye": 2, "lfleg": 9, "lfpa": 10, "neck": 8, "nose": 6, "rbleg": 15, "rbpa": 16, "rear": 5, "reye": 3, "rfleg": 11, "rfpa": 12, "tail": 17, "torso": 7}, 10 | "9": {}, 11 | "10": {"head": 1 , "lblleg": 16, "lbuleg": 15, "lear": 4, "leye": 2, "lflleg": 12, "lfuleg": 11, "lhorn": 7, "muzzle": 6, "neck": 10, "rblleg": 18, "rbuleg": 17, "rear": 5, "reye": 3, "rflleg": 14, "rfuleg": 13, "rhorn": 8, "tail": 19, "torso": 9}, 12 | "11": {}, 13 | "12": {"head": 1 , "lbleg": 13, "lbpa": 14, "lear": 4, "leye": 2, "lfleg": 9, "lfpa": 10, "muzzle": 20, "neck": 8, "nose": 6, "rbleg": 15, "rbpa": 16, "rear": 5, "reye": 3, "rfleg": 11, "rfpa": 12, "tail": 17, "torso": 7}, 14 | "13": {"head": 1 , "lbho": 32, "lblleg": 16, "lbuleg": 15, "lear": 4, "leye": 2, "lfho": 30, "lflleg": 12, "lfuleg": 11, "muzzle": 6, "neck": 10, "rbho": 33, "rblleg": 18, "rbuleg": 17, "rear": 5, "reye": 3, "rfho": 31, "rflleg": 14, "rfuleg": 13, "tail": 19, "torso": 9}, 15 | "14": {"bwheel": 2 , "fwheel": 1, "handlebar": 3, "headlight_1": 11, "headlight_10": 20, "headlight_2": 12, "headlight_3": 13, "headlight_4": 14, "headlight_5": 15, "headlight_6": 16, "headlight_7": 17, "headlight_8": 18, "headlight_9": 19, "saddle": 4}, 16 | "15": {"hair": 10 , "head": 1, "lear": 4, "lebrow": 6, "leye": 2, "lfoot": 21, "lhand": 15, "llarm": 13, "llleg": 19, "luarm": 14, "luleg": 20, "mouth": 9, "neck": 12, "nose": 8, "rear": 5, "rebrow": 7, "reye": 3, "rfoot": 24, "rhand": 18, "rlarm": 16, "rlleg": 22, "ruarm": 17, "ruleg": 23, "torso": 11}, 17 | "16": {"plant": 2 , "pot": 1}, 18 | "17": {"head": 1 , "lblleg": 16, "lbuleg": 15, "lear": 4, "leye": 2, "lflleg": 12, "lfuleg": 11, "lhorn": 7, "muzzle": 6, "neck": 10, "rblleg": 
18, "rbuleg": 17, "rear": 5, "reye": 3, "rflleg": 14, "rfuleg": 13, "rhorn": 8, "tail": 19, "torso": 9}, 19 | "18": {}, 20 | "19": {"cbackside_1": 61 , "cbackside_10": 70, "cbackside_2": 62, "cbackside_3": 63, "cbackside_4": 64, "cbackside_5": 65, "cbackside_6": 66, "cbackside_7": 67, "cbackside_8": 68, "cbackside_9": 69, "cfrontside_1": 31, "cfrontside_10": 40, "cfrontside_2": 32, "cfrontside_3": 33, "cfrontside_4": 34, "cfrontside_5": 35, "cfrontside_6": 36, "cfrontside_7": 37, "cfrontside_8": 38, "cfrontside_9": 39, "cleftside_1": 41, "cleftside_10": 50, "cleftside_2": 42, "cleftside_3": 43, "cleftside_4": 44, "cleftside_5": 45, "cleftside_6": 46, "cleftside_7": 47, "cleftside_8": 48, "cleftside_9": 49, "coach_1": 21, "coach_10": 30, "coach_2": 22, "coach_3": 23, "coach_4": 24, "coach_5": 25, "coach_6": 26, "coach_7": 27, "coach_8": 28, "coach_9": 29, "crightside_1": 51, "crightside_10": 60, "crightside_2": 52, "crightside_3": 53, "crightside_4": 54, "crightside_5": 55, "crightside_6": 56, "crightside_7": 57, "crightside_8": 58, "crightside_9": 59, "croofside_1": 71, "croofside_10": 80, "croofside_2": 72, "croofside_3": 73, "croofside_4": 74, "croofside_5": 75, "croofside_6": 76, "croofside_7": 77, "croofside_8": 78, "croofside_9": 79, "hbackside": 5, "head": 1, "headlight_1": 11, "headlight_10": 20, "headlight_2": 12, "headlight_3": 13, "headlight_4": 14, "headlight_5": 15, "headlight_6": 16, "headlight_7": 17, "headlight_8": 18, "headlight_9": 19, "hfrontside": 2, "hleftside": 3, "hrightside": 4, "hroofside": 6}, 21 | "20": {"screen": 1 } 22 | } 23 | -------------------------------------------------------------------------------- /data/helpers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # License: Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/astmt/) 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | # -------------------------------------------------------- 13 | 14 | 15 | import torch 16 | import cv2 17 | import numpy as np 18 | 19 | 20 | def tens2image(tens): 21 | """Converts tensor with 2 or 3 dimensions to numpy array""" 22 | im = tens.numpy() 23 | 24 | if im.shape[0] == 1: 25 | im = np.squeeze(im, axis=0) 26 | 27 | if im.ndim == 3: 28 | im = im.transpose((1, 2, 0)) 29 | 30 | return im 31 | 32 | 33 | def pascal_color_map(N=256, normalized=False): 34 | """ 35 | Python implementation of the color map function for the PASCAL VOC data set. 
36 | Official Matlab version can be found in the PASCAL VOC devkit 37 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 38 | """ 39 | 40 | def bitget(byteval, idx): 41 | return (byteval & (1 << idx)) != 0 42 | 43 | dtype = 'float32' if normalized else 'uint8' 44 | cmap = np.zeros((N, 3), dtype=dtype) 45 | for i in range(N): 46 | r = g = b = 0 47 | c = i 48 | for j in range(8): 49 | r = r | (bitget(c, 0) << 7 - j) 50 | g = g | (bitget(c, 1) << 7 - j) 51 | b = b | (bitget(c, 2) << 7 - j) 52 | c = c >> 3 53 | 54 | cmap[i] = np.array([r, g, b]) 55 | 56 | cmap = cmap / 255 if normalized else cmap 57 | return cmap 58 | 59 | 60 | def fixed_resize(sample, resolution, flagval=None): 61 | """ 62 | Fixed resize to 63 | resolution (tuple): resize image to size specified by tuple eg. (512, 512). 64 | resolution (int): bring smaller side to resolution eg. image of shape 321 x 481 -> 512 x 767 65 | """ 66 | if flagval is None: 67 | if ((sample == 0) | (sample == 1)).all(): 68 | flagval = cv2.INTER_NEAREST 69 | else: 70 | flagval = cv2.INTER_CUBIC 71 | 72 | if isinstance(resolution, int): 73 | tmp = [resolution, resolution] 74 | tmp[int(np.argmax(sample.shape[:2]))] = int( 75 | round(float(resolution) / np.min(sample.shape[:2]) * np.max(sample.shape[:2]))) 76 | resolution = tuple(tmp) 77 | 78 | if sample.ndim == 2 or (sample.ndim == 3 and sample.shape[2] == 3): 79 | sample = cv2.resize(sample, resolution[::-1], interpolation=flagval) 80 | else: 81 | tmp = sample 82 | sample = np.zeros( 83 | np.append(resolution, tmp.shape[2]), dtype=float) 84 | for ii in range(sample.shape[2]): 85 | sample[:, :, ii] = cv2.resize( 86 | tmp[:, :, ii], resolution[::-1], interpolation=flagval) 87 | return sample 88 | 89 | 90 | def im_normalize(im, max_value=1): 91 | """ 92 | Normalize image to range 0 - max_value 93 | """ 94 | imn = max_value * (im - im.min()) / max((im.max() - im.min()), 1e-8) 95 | return imn 96 | 97 | 98 | def generate_param_report(logfile, param): 99 | log_file = open(logfile, 'w') 100 | for key, val in param.items(): 101 | log_file.write(key + ':' + str(val) + '\n') 102 | log_file.close() 103 | 104 | 105 | def ind2sub(array_shape, inds): 106 | rows, cols = [], [] 107 | for k in range(len(inds)): 108 | if inds[k] == 0: 109 | continue 110 | cols.append((inds[k].astype('int') // array_shape[1])) 111 | rows.append((inds[k].astype('int') % array_shape[1])) 112 | return rows, cols 113 | -------------------------------------------------------------------------------- /data/imagenet22k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch.utils.data as data 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import warnings 8 | 9 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 10 | 11 | 12 | class IN22KDATASET(data.Dataset): 13 | def __init__(self, root, ann_file='', transform=None, target_transform=None): 14 | super(IN22KDATASET, self).__init__() 15 | 16 | self.data_path = root 17 | self.ann_path = os.path.join(self.data_path, ann_file) 18 | self.transform = transform 19 | self.target_transform = target_transform 20 | # id & label: https://github.com/google-research/big_transfer/issues/7 21 | # total: 21843; only 21841 class have images: map 21841->9205; 21842->15027 22 | self.database = json.load(open(self.ann_path)) 23 | 24 | def _load_image(self, path): 25 | try: 26 | im = Image.open(path) 27 | except: 28 | print("ERROR IMG LOADED: ", path) 29 | random_img = 
np.random.rand(224, 224, 3) * 255 30 | im = Image.fromarray(np.uint8(random_img)) 31 | return im 32 | 33 | def __getitem__(self, index): 34 | """ 35 | Args: 36 | index (int): Index 37 | Returns: 38 | tuple: (image, target) where target is class_index of the target class. 39 | """ 40 | idb = self.database[index] 41 | 42 | # images 43 | images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB') 44 | if self.transform is not None: 45 | images = self.transform(images) 46 | 47 | # target 48 | target = int(idb[1]) 49 | if self.target_transform is not None: 50 | target = self.target_transform(target) 51 | 52 | return images, target 53 | 54 | def __len__(self): 55 | return len(self.database) 56 | -------------------------------------------------------------------------------- /data/map22kto1k.txt: -------------------------------------------------------------------------------- 1 | 359 2 | 368 3 | 460 4 | 475 5 | 486 6 | 492 7 | 496 8 | 514 9 | 516 10 | 525 11 | 547 12 | 548 13 | 556 14 | 563 15 | 575 16 | 641 17 | 648 18 | 723 19 | 733 20 | 765 21 | 801 22 | 826 23 | 852 24 | 858 25 | 878 26 | 896 27 | 900 28 | 905 29 | 908 30 | 910 31 | 935 32 | 946 33 | 947 34 | 994 35 | 999 36 | 1003 37 | 1005 38 | 1010 39 | 1027 40 | 1029 41 | 1048 42 | 1055 43 | 1064 44 | 1065 45 | 1069 46 | 1075 47 | 1079 48 | 1081 49 | 1085 50 | 1088 51 | 1093 52 | 1106 53 | 1143 54 | 1144 55 | 1145 56 | 1147 57 | 1168 58 | 1171 59 | 1178 60 | 1187 61 | 1190 62 | 1197 63 | 1205 64 | 1216 65 | 1223 66 | 1230 67 | 1236 68 | 1241 69 | 1245 70 | 1257 71 | 1259 72 | 1260 73 | 1267 74 | 1268 75 | 1269 76 | 1271 77 | 1272 78 | 1273 79 | 1277 80 | 1303 81 | 1344 82 | 1349 83 | 1355 84 | 1357 85 | 1384 86 | 1388 87 | 1391 88 | 1427 89 | 1429 90 | 1432 91 | 1437 92 | 1450 93 | 1461 94 | 1462 95 | 1474 96 | 1502 97 | 1503 98 | 1512 99 | 1552 100 | 1555 101 | 1577 102 | 1584 103 | 1587 104 | 1589 105 | 1599 106 | 1615 107 | 1616 108 | 1681 109 | 1692 110 | 1701 111 | 1716 112 | 1729 113 | 1757 114 | 1759 115 | 1764 116 | 1777 117 | 1786 118 | 1822 119 | 1841 120 | 1842 121 | 1848 122 | 1850 123 | 1856 124 | 1860 125 | 1861 126 | 1864 127 | 1876 128 | 1897 129 | 1898 130 | 1910 131 | 1913 132 | 1918 133 | 1922 134 | 1928 135 | 1932 136 | 1935 137 | 1947 138 | 1951 139 | 1953 140 | 1970 141 | 1977 142 | 1979 143 | 2001 144 | 2017 145 | 2067 146 | 2081 147 | 2087 148 | 2112 149 | 2128 150 | 2135 151 | 2147 152 | 2174 153 | 2175 154 | 2176 155 | 2177 156 | 2178 157 | 2181 158 | 2183 159 | 2184 160 | 2187 161 | 2189 162 | 2190 163 | 2191 164 | 2192 165 | 2193 166 | 2197 167 | 2202 168 | 2203 169 | 2206 170 | 2208 171 | 2209 172 | 2211 173 | 2212 174 | 2213 175 | 2214 176 | 2215 177 | 2216 178 | 2217 179 | 2219 180 | 2222 181 | 2223 182 | 2224 183 | 2225 184 | 2226 185 | 2227 186 | 2228 187 | 2229 188 | 2230 189 | 2236 190 | 2238 191 | 2240 192 | 2241 193 | 2242 194 | 2243 195 | 2244 196 | 2245 197 | 2247 198 | 2248 199 | 2249 200 | 2250 201 | 2251 202 | 2252 203 | 2255 204 | 2256 205 | 2257 206 | 2262 207 | 2263 208 | 2264 209 | 2265 210 | 2266 211 | 2268 212 | 2270 213 | 2271 214 | 2272 215 | 2273 216 | 2275 217 | 2276 218 | 2279 219 | 2280 220 | 2281 221 | 2282 222 | 2285 223 | 2289 224 | 2292 225 | 2295 226 | 2296 227 | 2297 228 | 2298 229 | 2299 230 | 2300 231 | 2301 232 | 2302 233 | 2303 234 | 2304 235 | 2305 236 | 2306 237 | 2309 238 | 2310 239 | 2312 240 | 2313 241 | 2314 242 | 2315 243 | 2316 244 | 2318 245 | 2319 246 | 2321 247 | 2322 248 | 2326 249 | 2329 250 | 2330 251 | 2331 252 | 2332 253 | 2334 254 | 2335 255 | 
2336 256 | 2337 257 | 2338 258 | 2339 259 | 2341 260 | 2342 261 | 2343 262 | 2344 263 | 2346 264 | 2348 265 | 2349 266 | 2351 267 | 2352 268 | 2353 269 | 2355 270 | 2357 271 | 2358 272 | 2359 273 | 2360 274 | 2364 275 | 2365 276 | 2368 277 | 2369 278 | 2377 279 | 2382 280 | 2383 281 | 2385 282 | 2397 283 | 2398 284 | 2400 285 | 2402 286 | 2405 287 | 2412 288 | 2421 289 | 2428 290 | 2431 291 | 2432 292 | 2433 293 | 2436 294 | 2441 295 | 2445 296 | 2450 297 | 2453 298 | 2454 299 | 2465 300 | 2469 301 | 2532 302 | 2533 303 | 2538 304 | 2544 305 | 2547 306 | 2557 307 | 2565 308 | 2578 309 | 2612 310 | 2658 311 | 2702 312 | 2722 313 | 2731 314 | 2738 315 | 2741 316 | 2747 317 | 2810 318 | 2818 319 | 2833 320 | 2844 321 | 2845 322 | 2867 323 | 2874 324 | 2882 325 | 2884 326 | 2888 327 | 2889 328 | 3008 329 | 3012 330 | 3019 331 | 3029 332 | 3033 333 | 3042 334 | 3091 335 | 3106 336 | 3138 337 | 3159 338 | 3164 339 | 3169 340 | 3280 341 | 3296 342 | 3311 343 | 3318 344 | 3320 345 | 3324 346 | 3330 347 | 3366 348 | 3375 349 | 3381 350 | 3406 351 | 3419 352 | 3432 353 | 3434 354 | 3435 355 | 3493 356 | 3495 357 | 3503 358 | 3509 359 | 3511 360 | 3513 361 | 3517 362 | 3521 363 | 3526 364 | 3546 365 | 3554 366 | 3600 367 | 3601 368 | 3606 369 | 3612 370 | 3613 371 | 3616 372 | 3622 373 | 3623 374 | 3627 375 | 3632 376 | 3634 377 | 3636 378 | 3638 379 | 3644 380 | 3646 381 | 3649 382 | 3650 383 | 3651 384 | 3656 385 | 3663 386 | 3673 387 | 3674 388 | 3689 389 | 3690 390 | 3702 391 | 3733 392 | 3769 393 | 3971 394 | 3974 395 | 4065 396 | 4068 397 | 4073 398 | 4102 399 | 4136 400 | 4140 401 | 4151 402 | 4159 403 | 4165 404 | 4207 405 | 4219 406 | 4226 407 | 4249 408 | 4256 409 | 4263 410 | 4270 411 | 4313 412 | 4321 413 | 4378 414 | 4386 415 | 4478 416 | 4508 417 | 4512 418 | 4536 419 | 4542 420 | 4550 421 | 4560 422 | 4562 423 | 4570 424 | 4571 425 | 4572 426 | 4583 427 | 4588 428 | 4594 429 | 4604 430 | 4608 431 | 4623 432 | 4634 433 | 4636 434 | 4646 435 | 4651 436 | 4652 437 | 4686 438 | 4688 439 | 4691 440 | 4699 441 | 4724 442 | 4727 443 | 4737 444 | 4770 445 | 4774 446 | 4789 447 | 4802 448 | 4807 449 | 4819 450 | 4880 451 | 4886 452 | 4908 453 | 4927 454 | 4931 455 | 4936 456 | 4964 457 | 4976 458 | 4993 459 | 5028 460 | 5033 461 | 5043 462 | 5046 463 | 5096 464 | 5111 465 | 5114 466 | 5131 467 | 5132 468 | 5183 469 | 5199 470 | 5235 471 | 5275 472 | 5291 473 | 5293 474 | 5294 475 | 5343 476 | 5360 477 | 5362 478 | 5364 479 | 5390 480 | 5402 481 | 5418 482 | 5428 483 | 5430 484 | 5437 485 | 5443 486 | 5473 487 | 5484 488 | 5486 489 | 5505 490 | 5507 491 | 5508 492 | 5510 493 | 5567 494 | 5578 495 | 5580 496 | 5584 497 | 5606 498 | 5613 499 | 5629 500 | 5672 501 | 5676 502 | 5692 503 | 5701 504 | 5760 505 | 5769 506 | 5770 507 | 5779 508 | 5814 509 | 5850 510 | 5871 511 | 5893 512 | 5911 513 | 5949 514 | 5954 515 | 6005 516 | 6006 517 | 6012 518 | 6017 519 | 6023 520 | 6024 521 | 6040 522 | 6050 523 | 6054 524 | 6087 525 | 6105 526 | 6157 527 | 6235 528 | 6237 529 | 6256 530 | 6259 531 | 6286 532 | 6291 533 | 6306 534 | 6339 535 | 6341 536 | 6343 537 | 6379 538 | 6383 539 | 6393 540 | 6405 541 | 6479 542 | 6511 543 | 6517 544 | 6541 545 | 6561 546 | 6608 547 | 6611 548 | 6615 549 | 6678 550 | 6682 551 | 6707 552 | 6752 553 | 6798 554 | 6850 555 | 6880 556 | 6885 557 | 6890 558 | 6920 559 | 6981 560 | 7000 561 | 7009 562 | 7038 563 | 7049 564 | 7050 565 | 7052 566 | 7073 567 | 7078 568 | 7098 569 | 7111 570 | 7165 571 | 7198 572 | 7204 573 | 7280 574 | 7283 575 | 7286 576 | 7287 577 | 7293 578 | 
7294 579 | 7305 580 | 7318 581 | 7341 582 | 7346 583 | 7354 584 | 7382 585 | 7427 586 | 7428 587 | 7435 588 | 7445 589 | 7450 590 | 7455 591 | 7467 592 | 7469 593 | 7497 594 | 7502 595 | 7506 596 | 7514 597 | 7523 598 | 7651 599 | 7661 600 | 7664 601 | 7672 602 | 7679 603 | 7685 604 | 7696 605 | 7730 606 | 7871 607 | 7873 608 | 7895 609 | 7914 610 | 7915 611 | 7920 612 | 7934 613 | 7935 614 | 7949 615 | 8009 616 | 8036 617 | 8051 618 | 8065 619 | 8074 620 | 8090 621 | 8112 622 | 8140 623 | 8164 624 | 8168 625 | 8178 626 | 8182 627 | 8198 628 | 8212 629 | 8216 630 | 8230 631 | 8242 632 | 8288 633 | 8289 634 | 8295 635 | 8318 636 | 8352 637 | 8368 638 | 8371 639 | 8375 640 | 8376 641 | 8401 642 | 8416 643 | 8419 644 | 8436 645 | 8460 646 | 8477 647 | 8478 648 | 8482 649 | 8498 650 | 8500 651 | 8539 652 | 8543 653 | 8552 654 | 8555 655 | 8580 656 | 8584 657 | 8586 658 | 8594 659 | 8598 660 | 8601 661 | 8606 662 | 8610 663 | 8611 664 | 8622 665 | 8627 666 | 8639 667 | 8649 668 | 8650 669 | 8653 670 | 8654 671 | 8667 672 | 8672 673 | 8673 674 | 8674 675 | 8676 676 | 8684 677 | 8720 678 | 8723 679 | 8750 680 | 8753 681 | 8801 682 | 8815 683 | 8831 684 | 8835 685 | 8842 686 | 8845 687 | 8858 688 | 8897 689 | 8916 690 | 8951 691 | 8954 692 | 8959 693 | 8970 694 | 8976 695 | 8981 696 | 8983 697 | 8989 698 | 8991 699 | 8993 700 | 9019 701 | 9039 702 | 9042 703 | 9043 704 | 9056 705 | 9057 706 | 9070 707 | 9087 708 | 9098 709 | 9106 710 | 9130 711 | 9131 712 | 9155 713 | 9171 714 | 9183 715 | 9198 716 | 9199 717 | 9201 718 | 9204 719 | 9212 720 | 9221 721 | 9225 722 | 9229 723 | 9250 724 | 9260 725 | 9271 726 | 9279 727 | 9295 728 | 9300 729 | 9310 730 | 9322 731 | 9345 732 | 9352 733 | 9376 734 | 9377 735 | 9382 736 | 9392 737 | 9401 738 | 9405 739 | 9441 740 | 9449 741 | 9464 742 | 9475 743 | 9502 744 | 9505 745 | 9514 746 | 9515 747 | 9545 748 | 9567 749 | 9576 750 | 9608 751 | 9609 752 | 9624 753 | 9633 754 | 9639 755 | 9643 756 | 9656 757 | 9674 758 | 9740 759 | 9752 760 | 9760 761 | 9767 762 | 9778 763 | 9802 764 | 9820 765 | 9839 766 | 9879 767 | 9924 768 | 9956 769 | 9961 770 | 9963 771 | 9970 772 | 9997 773 | 10010 774 | 10031 775 | 10040 776 | 10052 777 | 10073 778 | 10075 779 | 10078 780 | 10094 781 | 10097 782 | 10109 783 | 10118 784 | 10121 785 | 10124 786 | 10158 787 | 10226 788 | 10276 789 | 10304 790 | 10307 791 | 10314 792 | 10315 793 | 10332 794 | 10337 795 | 10338 796 | 10413 797 | 10423 798 | 10451 799 | 10463 800 | 10465 801 | 10487 802 | 10519 803 | 10522 804 | 10523 805 | 10532 806 | 10534 807 | 10535 808 | 10551 809 | 10559 810 | 10574 811 | 10583 812 | 10586 813 | 10589 814 | 10612 815 | 10626 816 | 10635 817 | 10638 818 | 10677 819 | 10683 820 | 10726 821 | 10776 822 | 10782 823 | 10783 824 | 10807 825 | 10837 826 | 10840 827 | 10848 828 | 10859 829 | 10871 830 | 10881 831 | 10884 832 | 10908 833 | 10914 834 | 10921 835 | 10936 836 | 10947 837 | 10951 838 | 10952 839 | 10957 840 | 10999 841 | 11003 842 | 11018 843 | 11023 844 | 11025 845 | 11027 846 | 11045 847 | 11055 848 | 11095 849 | 11110 850 | 11137 851 | 5564 852 | 11168 853 | 11186 854 | 11221 855 | 11223 856 | 11242 857 | 11255 858 | 11259 859 | 11279 860 | 11306 861 | 11311 862 | 11331 863 | 11367 864 | 11377 865 | 11389 866 | 11392 867 | 11401 868 | 11407 869 | 11437 870 | 11449 871 | 11466 872 | 11469 873 | 11473 874 | 11478 875 | 11483 876 | 11484 877 | 11507 878 | 11536 879 | 11558 880 | 11566 881 | 11575 882 | 11584 883 | 11594 884 | 11611 885 | 11612 886 | 11619 887 | 11621 888 | 11640 889 | 11643 890 | 11664 
891 | 11674 892 | 11689 893 | 11709 894 | 11710 895 | 11716 896 | 11721 897 | 11726 898 | 11729 899 | 11743 900 | 11760 901 | 11771 902 | 11837 903 | 11839 904 | 11856 905 | 11876 906 | 11878 907 | 11884 908 | 11889 909 | 11896 910 | 11917 911 | 11923 912 | 11930 913 | 11944 914 | 11952 915 | 11980 916 | 11984 917 | 12214 918 | 12229 919 | 12239 920 | 12241 921 | 12242 922 | 12247 923 | 12283 924 | 12349 925 | 12369 926 | 12373 927 | 12422 928 | 12560 929 | 12566 930 | 12575 931 | 12688 932 | 12755 933 | 12768 934 | 12778 935 | 12780 936 | 12812 937 | 12832 938 | 12835 939 | 12836 940 | 12843 941 | 12847 942 | 12849 943 | 12850 944 | 12856 945 | 12858 946 | 12873 947 | 12938 948 | 12971 949 | 13017 950 | 13038 951 | 13046 952 | 13059 953 | 13085 954 | 13086 955 | 13088 956 | 13094 957 | 13134 958 | 13182 959 | 13230 960 | 13406 961 | 13444 962 | 13614 963 | 13690 964 | 13698 965 | 13709 966 | 13749 967 | 13804 968 | 13982 969 | 14051 970 | 14059 971 | 14219 972 | 14246 973 | 14256 974 | 14264 975 | 14294 976 | 14324 977 | 14367 978 | 14389 979 | 14394 980 | 14438 981 | 14442 982 | 14965 983 | 15732 984 | 16744 985 | 18037 986 | 18205 987 | 18535 988 | 18792 989 | 19102 990 | 20019 991 | 20462 992 | 21026 993 | 21045 994 | 21163 995 | 21171 996 | 21181 997 | 21196 998 | 21200 999 | 21369 1000 | 21817 -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | 11 | class SubsetRandomSampler(torch.utils.data.Sampler): 12 | r"""Samples elements randomly from a given list of indices, without replacement. 
13 | 14 | Arguments: 15 | indices (sequence): a sequence of indices 16 | """ 17 | 18 | def __init__(self, indices): 19 | self.epoch = 0 20 | self.indices = indices 21 | 22 | def __iter__(self): 23 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 24 | 25 | def __len__(self): 26 | return len(self.indices) 27 | 28 | def set_epoch(self, epoch): 29 | self.epoch = epoch 30 | -------------------------------------------------------------------------------- /data/zipreader.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import zipfile 10 | import io 11 | import numpy as np 12 | from PIL import Image 13 | from PIL import ImageFile 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | def is_zip_path(img_or_path): 19 | """judge if this is a zip path""" 20 | return '.zip@' in img_or_path 21 | 22 | 23 | class ZipReader(object): 24 | """A class to read zipped files""" 25 | zip_bank = dict() 26 | 27 | def __init__(self): 28 | super(ZipReader, self).__init__() 29 | 30 | @staticmethod 31 | def get_zipfile(path): 32 | zip_bank = ZipReader.zip_bank 33 | if path not in zip_bank: 34 | zfile = zipfile.ZipFile(path, 'r') 35 | zip_bank[path] = zfile 36 | return zip_bank[path] 37 | 38 | @staticmethod 39 | def split_zip_style_path(path): 40 | pos_at = path.index('@') 41 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path 42 | 43 | zip_path = path[0: pos_at] 44 | folder_path = path[pos_at + 1:] 45 | folder_path = str.strip(folder_path, '/') 46 | return zip_path, folder_path 47 | 48 | @staticmethod 49 | def list_folder(path): 50 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 51 | 52 | zfile = ZipReader.get_zipfile(zip_path) 53 | folder_list = [] 54 | for file_foler_name in zfile.namelist(): 55 | file_foler_name = str.strip(file_foler_name, '/') 56 | if file_foler_name.startswith(folder_path) and \ 57 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \ 58 | file_foler_name != folder_path: 59 | if len(folder_path) == 0: 60 | folder_list.append(file_foler_name) 61 | else: 62 | folder_list.append(file_foler_name[len(folder_path) + 1:]) 63 | 64 | return folder_list 65 | 66 | @staticmethod 67 | def list_files(path, extension=None): 68 | if extension is None: 69 | extension = ['.*'] 70 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 71 | 72 | zfile = ZipReader.get_zipfile(zip_path) 73 | file_lists = [] 74 | for file_foler_name in zfile.namelist(): 75 | file_foler_name = str.strip(file_foler_name, '/') 76 | if file_foler_name.startswith(folder_path) and \ 77 | str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 78 | if len(folder_path) == 0: 79 | file_lists.append(file_foler_name) 80 | else: 81 | file_lists.append(file_foler_name[len(folder_path) + 1:]) 82 | 83 | return file_lists 84 | 85 | @staticmethod 86 | def read(path): 87 | zip_path, path_img = ZipReader.split_zip_style_path(path) 88 | zfile = ZipReader.get_zipfile(zip_path) 89 | data = zfile.read(path_img) 90 | return data 91 | 92 | @staticmethod 93 | def imread(path): 94 | zip_path, path_img = ZipReader.split_zip_style_path(path) 95 | zfile = ZipReader.get_zipfile(zip_path) 96 | data = zfile.read(path_img) 97 | try: 98 | im = 
Image.open(io.BytesIO(data)) 99 | except: 100 | print("ERROR IMG LOADED: ", path_img) 101 | random_img = np.random.rand(224, 224, 3) * 255 102 | im = Image.fromarray(np.uint8(random_img)) 103 | return im 104 | -------------------------------------------------------------------------------- /evaluation/eval_depth.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # License: Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/astmt/) 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | 13 | import warnings 14 | import cv2 15 | import os.path 16 | import numpy as np 17 | import glob 18 | import torch 19 | import json 20 | import scipy.io as sio 21 | 22 | 23 | def eval_depth(loader, folder): 24 | 25 | total_rmses = 0.0 26 | total_log_rmses = 0.0 27 | n_valid = 0.0 28 | 29 | for i, sample in enumerate(loader): 30 | 31 | if i % 500 == 0: 32 | print('Evaluating depth: {} of {} objects'.format(i, len(loader))) 33 | 34 | # Load result 35 | filename = os.path.join(folder, sample['meta']['image'] + '.mat') 36 | pred = sio.loadmat(filename)['depth'].astype(float) 37 | label = sample['depth'] 38 | 39 | if pred.shape != label.shape: 40 | warnings.warn( 41 | 'Prediction and ground truth have different size. Resizing Prediction..') 42 | pred = cv2.resize( 43 | pred, label.shape[::-1], interpolation=cv2.INTER_LINEAR) 44 | 45 | valid_mask = (label != 0) 46 | n_valid += np.sum(valid_mask) 47 | 48 | label[label == 0] = 1e-9 # Avoid overflow/underflow 49 | pred[pred <= 0] = 1e-9 50 | 51 | log_rmse_tmp = (np.log(label[valid_mask]) - 52 | np.log(pred[valid_mask])) ** 2 53 | total_log_rmses += np.sum(log_rmse_tmp) 54 | 55 | rmse_tmp = (label[valid_mask] - pred[valid_mask]) ** 2 56 | total_rmses += np.sum(rmse_tmp) 57 | 58 | eval_result = dict() 59 | eval_result['rmse'] = np.sqrt(total_rmses / n_valid) 60 | eval_result['log_rmse'] = np.sqrt(total_log_rmses / n_valid) 61 | 62 | return eval_result 63 | 64 | 65 | class DepthMeter(object): 66 | def __init__(self): 67 | self.total_rmses = 0.0 68 | self.total_log_rmses = 0.0 69 | self.n_valid = 0.0 70 | 71 | @torch.no_grad() 72 | def update(self, pred, gt): 73 | pred, gt = pred.squeeze(), gt.squeeze() 74 | 75 | # Determine valid mask 76 | mask = (gt != 255).bool() 77 | self.n_valid += mask.float().sum().item() # Valid pixels per image 78 | 79 | # Only positive depth values are possible 80 | pred = torch.clamp(pred, min=1e-9) 81 | 82 | # Per pixel rmse and log-rmse. 
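        # Squared errors are summed across batches here; get_score() later takes
        # sqrt(total / n_valid), so nothing is averaged per image at this point.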
83 | log_rmse_tmp = torch.pow(torch.log(gt) - torch.log(pred), 2) 84 | log_rmse_tmp = torch.masked_select(log_rmse_tmp, mask) 85 | self.total_log_rmses += log_rmse_tmp.sum().item() 86 | 87 | rmse_tmp = torch.pow(gt - pred, 2) 88 | rmse_tmp = torch.masked_select(rmse_tmp, mask) 89 | self.total_rmses += rmse_tmp.sum().item() 90 | 91 | def reset(self): 92 | self.rmses = [] 93 | self.log_rmses = [] 94 | 95 | def get_score(self, verbose=True): 96 | eval_result = dict() 97 | eval_result['rmse'] = np.sqrt(self.total_rmses / self.n_valid) 98 | eval_result['log_rmse'] = np.sqrt(self.total_log_rmses / self.n_valid) 99 | 100 | if verbose: 101 | print('Results for depth prediction') 102 | for x in eval_result: 103 | spaces = '' 104 | for j in range(0, 15 - len(x)): 105 | spaces += ' ' 106 | print('{0:s}{1:s}{2:.4f}'.format(x, spaces, eval_result[x])) 107 | 108 | return eval_result 109 | 110 | 111 | def eval_depth_predictions(database, save_dir, overfit=False): 112 | 113 | # Dataloaders 114 | if database == 'NYUD': 115 | from data.nyud import NYUD_MT 116 | gt_set = 'val' 117 | db = NYUD_MT(split=gt_set, do_depth=True, overfit=overfit) 118 | 119 | else: 120 | raise NotImplementedError 121 | 122 | base_name = database + '_' + 'test' + '_depth' 123 | fname = os.path.join(save_dir, base_name + '.json') 124 | 125 | # Eval the model 126 | print('Evaluate the saved images (depth)') 127 | eval_results = eval_depth(db, os.path.join(save_dir, 'depth')) 128 | with open(fname, 'w') as f: 129 | json.dump(eval_results, f) 130 | 131 | # Print results 132 | print('Results for Depth Estimation') 133 | for x in eval_results: 134 | spaces = '' 135 | for j in range(0, 15 - len(x)): 136 | spaces += ' ' 137 | print('{0:s}{1:s}{2:.4f}'.format(x, spaces, eval_results[x])) 138 | 139 | return eval_results 140 | -------------------------------------------------------------------------------- /evaluation/eval_edge.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 7 | # Written by Simon Vandenhende 8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | # -------------------------------------------------------- 13 | 14 | import os 15 | import glob 16 | import json 17 | import torch 18 | import numpy as np 19 | from utils import mkdir_if_missing 20 | from mtl_loss_schemes import BalancedCrossEntropyLoss 21 | 22 | 23 | class EdgeMeter(object): 24 | def __init__(self, pos_weight): 25 | self.loss = 0 26 | self.n = 0 27 | self.loss_function = BalancedCrossEntropyLoss( 28 | size_average=True, pos_weight=pos_weight) 29 | 30 | @torch.no_grad() 31 | def update(self, pred, gt): 32 | gt = gt.squeeze() 33 | pred = pred.float().squeeze() / 255. 
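        # Edge predictions are expected as 0-255 maps (get_output() in
        # evaluation/evaluate_utils.py emits 255 * sigmoid(logits)), so rescale
        # to [0, 1] before the balanced cross-entropy loss below.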
34 | loss = self.loss_function(pred, gt).item() 35 | numel = gt.numel() 36 | self.n += numel 37 | self.loss += numel * loss 38 | 39 | def reset(self): 40 | self.loss = 0 41 | self.n = 0 42 | 43 | def get_score(self, verbose=True): 44 | eval_dict = {'loss': self.loss / self.n} 45 | 46 | if verbose: 47 | print('\nEdge Detection Evaluation') 48 | print('Edge Detection Loss %.3f' % (eval_dict['loss'])) 49 | 50 | return eval_dict 51 | -------------------------------------------------------------------------------- /evaluation/eval_human_parts.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # License: Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/astmt/) 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | 13 | import warnings 14 | import cv2 15 | import glob 16 | import json 17 | import os.path 18 | import numpy as np 19 | import torch 20 | from PIL import Image 21 | 22 | PART_CATEGORY_NAMES = ['background', 'head', 23 | 'torso', 'uarm', 'larm', 'uleg', 'lleg'] 24 | 25 | 26 | def eval_human_parts(loader, folder, n_parts=6): 27 | 28 | tp = [0] * (n_parts + 1) 29 | fp = [0] * (n_parts + 1) 30 | fn = [0] * (n_parts + 1) 31 | 32 | counter = 0 33 | for i, sample in enumerate(loader): 34 | 35 | if i % 500 == 0: 36 | print('Evaluating: {} of {} objects'.format(i, len(loader))) 37 | 38 | if 'human_parts' not in sample: 39 | continue 40 | 41 | # Check for valid pixels 42 | gt = sample['human_parts'] 43 | uniq = np.unique(gt) 44 | if len(uniq) == 1 and (uniq[0] == 255 or uniq[0] == 0): 45 | continue 46 | 47 | # Load result 48 | filename = os.path.join(folder, sample['meta']['image'] + '.png') 49 | mask = np.array(Image.open(filename)).astype(float) 50 | 51 | # Case of a binary (probability) result 52 | if n_parts == 1: 53 | mask = (mask > 0.5 * 255).astype(float) 54 | 55 | counter += 1 56 | valid = (gt != 255) 57 | 58 | if mask.shape != gt.shape: 59 | warnings.warn( 60 | 'Prediction and ground truth have different size. 
Resizing Prediction..') 61 | mask = cv2.resize( 62 | mask, gt.shape[::-1], interpolation=cv2.INTER_NEAREST) 63 | 64 | # TP, FP, and FN evaluation 65 | for i_part in range(0, n_parts + 1): 66 | tmp_gt = (gt == i_part) 67 | tmp_pred = (mask == i_part) 68 | tp[i_part] += np.sum(tmp_gt & tmp_pred & (valid)) 69 | fp[i_part] += np.sum(~tmp_gt & tmp_pred & (valid)) 70 | fn[i_part] += np.sum(tmp_gt & ~tmp_pred & (valid)) 71 | 72 | print('Successful evaluation for {} images'.format(counter)) 73 | jac = [0] * (n_parts + 1) 74 | for i_part in range(0, n_parts + 1): 75 | jac[i_part] = float( 76 | tp[i_part]) / max(float(tp[i_part] + fp[i_part] + fn[i_part]), 1e-8) 77 | 78 | # Write results 79 | eval_result = dict() 80 | eval_result['jaccards_all_categs'] = jac 81 | eval_result['mIoU'] = np.mean(jac) 82 | 83 | return eval_result 84 | 85 | 86 | class HumanPartsMeter(object): 87 | def __init__(self, database): 88 | assert (database == 'PASCALContext') 89 | self.database = database 90 | self.cat_names = PART_CATEGORY_NAMES 91 | self.n_parts = 6 92 | self.tp = [0] * (self.n_parts + 1) 93 | self.fp = [0] * (self.n_parts + 1) 94 | self.fn = [0] * (self.n_parts + 1) 95 | 96 | @torch.no_grad() 97 | def update(self, pred, gt): 98 | pred, gt = pred.squeeze(), gt.squeeze() 99 | valid = (gt != 255) 100 | 101 | for i_part in range(self.n_parts + 1): 102 | tmp_gt = (gt == i_part) 103 | tmp_pred = (pred == i_part) 104 | self.tp[i_part] += torch.sum(tmp_gt & tmp_pred & (valid)).item() 105 | self.fp[i_part] += torch.sum(~tmp_gt & tmp_pred & (valid)).item() 106 | self.fn[i_part] += torch.sum(tmp_gt & ~tmp_pred & (valid)).item() 107 | 108 | def reset(self): 109 | self.tp = [0] * (self.n_parts + 1) 110 | self.fp = [0] * (self.n_parts + 1) 111 | self.fn = [0] * (self.n_parts + 1) 112 | 113 | def get_score(self, verbose=True): 114 | jac = [0] * (self.n_parts + 1) 115 | for i_part in range(0, self.n_parts + 1): 116 | jac[i_part] = float( 117 | self.tp[i_part]) / max(float(self.tp[i_part] + self.fp[i_part] + self.fn[i_part]), 1e-8) 118 | 119 | eval_result = dict() 120 | eval_result['jaccards_all_categs'] = jac 121 | eval_result['mIoU'] = np.mean(jac) 122 | 123 | print('\nHuman Parts mIoU: {0:.4f}\n'.format( 124 | 100 * eval_result['mIoU'])) 125 | class_IoU = jac 126 | for i in range(len(class_IoU)): 127 | spaces = '' 128 | for j in range(0, 15 - len(self.cat_names[i])): 129 | spaces += ' ' 130 | print('{0:s}{1:s}{2:.4f}'.format( 131 | self.cat_names[i], spaces, 100 * class_IoU[i])) 132 | 133 | return eval_result 134 | 135 | 136 | def eval_human_parts_predictions(database, save_dir, overfit=False): 137 | """ Evaluate the human parts predictions that are stored in the save dir """ 138 | 139 | # Dataloaders 140 | if database == 'PASCALContext': 141 | from data.pascal_context import PASCALContext 142 | gt_set = 'val' 143 | db = PASCALContext(split=gt_set, do_edge=False, do_human_parts=True, do_semseg=False, 144 | do_normals=False, do_sal=False, overfit=overfit) 145 | 146 | else: 147 | raise NotImplementedError 148 | 149 | base_name = database + '_' + 'test' + '_human_parts' 150 | fname = os.path.join(save_dir, base_name + '.json') 151 | 152 | # Eval the model 153 | print('Evaluate the saved images (human parts)') 154 | eval_results = eval_human_parts(db, os.path.join(save_dir, 'human_parts')) 155 | with open(fname, 'w') as f: 156 | json.dump(eval_results, f) 157 | 158 | # Print Results 159 | class_IoU = eval_results['jaccards_all_categs'] 160 | mIoU = eval_results['mIoU'] 161 | 162 | print('\nHuman Parts mIoU: 
{0:.4f}\n'.format(100 * mIoU)) 163 | for i in range(len(class_IoU)): 164 | spaces = '' 165 | for j in range(0, 15 - len(PART_CATEGORY_NAMES[i])): 166 | spaces += ' ' 167 | print('{0:s}{1:s}{2:.4f}'.format( 168 | PART_CATEGORY_NAMES[i], spaces, 100 * class_IoU[i])) 169 | 170 | return eval_results 171 | -------------------------------------------------------------------------------- /evaluation/eval_normals.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # Copyright (c) 2024 SCALE Lab, Brown University 5 | # Licensed under the MIT License (see LICENSE for details). 6 | # -------------------------------------------------------- 7 | 8 | 9 | import torch 10 | 11 | from evaluation.eval_normals_v1 import NormalsMeterV1 12 | from evaluation.eval_normals_v2 import NormalsMeterV2 13 | 14 | 15 | class NormalsMeter(object): 16 | def __init__(self): 17 | self.v1 = NormalsMeterV1() 18 | self.v2 = NormalsMeterV2() 19 | 20 | @torch.no_grad() 21 | def update(self, pred, gt): 22 | self.v1.update(pred.clone(), gt.clone()) 23 | self.v2.update(pred, gt) 24 | 25 | def reset(self): 26 | self.v1.reset() 27 | self.v2.reset() 28 | 29 | def get_score(self, verbose=True): 30 | eval_v1 = self.v1.get_score(verbose=False) 31 | eval_v2 = self.v2.get_score(verbose=False) 32 | eval_result = { 33 | 'mean': eval_v1['mean'], 34 | 'rmse': eval_v1['rmse'], 35 | 'mean_v2': eval_v2['mean'], 36 | 'rmse_v2': eval_v2['rmse'], 37 | } 38 | 39 | if verbose: 40 | print('\nResults for Surface Normal Estimation') 41 | print('mean: {:.4f}'.format(eval_v1['mean'])) 42 | print('rmse: {:.4f}'.format(eval_v1['rmse'])) 43 | print('mean_v2: {:.4f}'.format(eval_v2['mean'])) 44 | print('rmse_v2: {:.4f}'.format(eval_v2['rmse'])) 45 | 46 | return eval_result 47 | -------------------------------------------------------------------------------- /evaluation/eval_normals_v1.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # License: Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/astmt/) 7 | # Copyright (c) Facebook, Inc. and its affiliates. 
8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | 13 | import numpy as np 14 | import math 15 | import torch 16 | 17 | 18 | def normal_ize(arr): 19 | arr_norm = np.linalg.norm(arr, ord=2, axis=2)[..., np.newaxis] + 1e-12 20 | return arr / arr_norm 21 | 22 | 23 | class NormalsMeterV1(object): 24 | def __init__(self): 25 | self.eval_dict = {'mean': 0., 'rmse': 0., 26 | '11.25': 0., '22.5': 0., '30': 0., 'n': 0} 27 | 28 | @torch.no_grad() 29 | def update(self, pred, gt): 30 | # Performance measurement happens in pixel wise fashion (Same as code from ASTMT (above)) 31 | pred = 2 * pred / 255 - 1 32 | pred = pred.permute(0, 3, 1, 2) # [B, C, H, W] 33 | valid_mask = (gt != 255) 34 | invalid_mask = (gt == 255) 35 | 36 | # Put zeros where mask is invalid 37 | pred[invalid_mask] = 0.0 38 | gt[invalid_mask] = 0.0 39 | 40 | # Calculate difference expressed in degrees 41 | deg_diff_tmp = ( 42 | 180 / math.pi) * (torch.acos(torch.clamp(torch.sum(pred * gt, 1), min=-1, max=1))) 43 | deg_diff_tmp = torch.masked_select(deg_diff_tmp, valid_mask[:, 0]) 44 | 45 | self.eval_dict['mean'] += torch.sum(deg_diff_tmp).item() 46 | self.eval_dict['rmse'] += torch.sum( 47 | torch.sqrt(torch.pow(deg_diff_tmp, 2))).item() 48 | self.eval_dict['11.25'] += torch.sum( 49 | (deg_diff_tmp < 11.25).float()).item() * 100 50 | self.eval_dict['22.5'] += torch.sum( 51 | (deg_diff_tmp < 22.5).float()).item() * 100 52 | self.eval_dict['30'] += torch.sum((deg_diff_tmp < 53 | 30).float()).item() * 100 54 | self.eval_dict['n'] += deg_diff_tmp.numel() 55 | 56 | def reset(self): 57 | self.eval_dict = {'mean': 0., 'rmse': 0., 58 | '11.25': 0., '22.5': 0., '30': 0., 'n': 0} 59 | 60 | def get_score(self, verbose=True): 61 | eval_result = dict() 62 | eval_result['mean'] = self.eval_dict['mean'] / self.eval_dict['n'] 63 | eval_result['rmse'] = self.eval_dict['mean'] / self.eval_dict['n'] 64 | eval_result['11.25'] = self.eval_dict['11.25'] / self.eval_dict['n'] 65 | eval_result['22.5'] = self.eval_dict['22.5'] / self.eval_dict['n'] 66 | eval_result['30'] = self.eval_dict['30'] / self.eval_dict['n'] 67 | 68 | if verbose: 69 | print('Results for Surface Normal Estimation') 70 | for x in eval_result: 71 | spaces = "" 72 | for j in range(0, 15 - len(x)): 73 | spaces += ' ' 74 | print('{0:s}{1:s}{2:.4f}'.format(x, spaces, eval_result[x])) 75 | 76 | return eval_result 77 | -------------------------------------------------------------------------------- /evaluation/eval_normals_v2.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # License: Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/astmt/) 7 | # Copyright (c) Facebook, Inc. and its affiliates. 
8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | 13 | import torch 14 | 15 | 16 | def normalize_tensor(input_tensor, dim): 17 | norm = torch.norm(input_tensor, p='fro', dim=dim, keepdim=True) 18 | zero_mask = (norm == 0) 19 | norm[zero_mask] = 1 20 | out = input_tensor.div(norm) 21 | out[zero_mask.expand_as(out)] = 0 22 | return out 23 | 24 | 25 | class NormalsMeterV2(object): 26 | def __init__(self, ignore_index=255): 27 | self.sum_deg_diff = 0 28 | self.total = 0 29 | self.ignore_index = ignore_index 30 | 31 | @torch.no_grad() 32 | def update(self, pred, gt): 33 | pred = pred.permute(0, 3, 1, 2) 34 | pred = 2 * pred / 255 - 1 35 | valid_mask = (gt != self.ignore_index).all(dim=1) 36 | 37 | pred = normalize_tensor(pred, dim=1) 38 | gt = normalize_tensor(gt, dim=1) 39 | deg_diff = torch.rad2deg( 40 | 2 * torch.atan2(torch.norm(pred - gt, dim=1), torch.norm(pred + gt, dim=1))) 41 | deg_diff = torch.masked_select(deg_diff, valid_mask) 42 | 43 | self.sum_deg_diff += torch.sum(deg_diff).cpu().item() 44 | self.total += deg_diff.numel() 45 | 46 | def get_score(self, verbose=False): 47 | eval_result = dict() 48 | eval_result['mean'] = self.sum_deg_diff / self.total 49 | eval_result['rmse'] = self.sum_deg_diff / self.total 50 | 51 | if verbose: 52 | print('Results for Surface Normal Estimation') 53 | print('mean: {:.3f}'.format(eval_result['mean'])) 54 | print('rmse: {:.3f}'.format(eval_result['rmse'])) 55 | 56 | return eval_result 57 | -------------------------------------------------------------------------------- /evaluation/eval_sal.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # License: Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/astmt/) 7 | # Copyright (c) Facebook, Inc. and its affiliates. 
8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | 13 | import numpy as np 14 | import torch 15 | from evaluation.eval_sal_no_beta import SaliencyMeterWithNoBeta 16 | from evaluation.eval_sal_beta import SaliencyMeterWithBeta 17 | 18 | import evaluation.jaccard as evaluation 19 | 20 | 21 | class SaliencyMeter(object): 22 | def __init__(self, ignore_index=255, threshold_step=0.05, beta_squared=0.3): 23 | self.no_beta = SaliencyMeterWithNoBeta() 24 | self.with_beta = SaliencyMeterWithBeta( 25 | ignore_index=ignore_index, threshold_step=threshold_step, beta_squared=beta_squared) 26 | 27 | @torch.no_grad() 28 | def update(self, pred, gt): 29 | self.no_beta.update(pred, gt) 30 | self.with_beta.update(pred, gt) 31 | 32 | def reset(self): 33 | self.no_beta.reset() 34 | self.with_beta.reset() 35 | 36 | def get_score(self, verbose=True): 37 | no_beta_result = self.no_beta.get_score(verbose=False) 38 | with_beta_result = self.with_beta.get_score(verbose=False) 39 | eval_result = { 40 | 'Beta maxF': with_beta_result['maxF'], 41 | 'maxF': no_beta_result['maxF'], 42 | 'mIoU': no_beta_result['mIoU'], 43 | } 44 | 45 | if verbose: 46 | print('\nResults for Saliency Estimation') 47 | print('Beta maxF: {:.3f}'.format(100.0 * with_beta_result['maxF'])) 48 | print('maxF: {:.3f}'.format(100.0 * no_beta_result['maxF'])) 49 | print('mIoU: {:.3f}'.format(100.0 * no_beta_result['mIoU'])) 50 | 51 | return eval_result 52 | -------------------------------------------------------------------------------- /evaluation/eval_sal_beta.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # License: Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/astmt/) 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | 13 | import torch 14 | from torch import nn 15 | 16 | import evaluation.jaccard as evaluation 17 | 18 | 19 | class SaliencyMeterWithBeta(object): 20 | def __init__(self, ignore_index=255, threshold_step=0.05, beta_squared=0.3): 21 | self.ignore_index = ignore_index 22 | self.beta_squared = beta_squared 23 | self.thresholds = torch.arange(threshold_step, 1, threshold_step) 24 | self.true_positives = torch.zeros(len(self.thresholds)) 25 | self.predicted_positives = torch.zeros(len(self.thresholds)) 26 | self.actual_positives = torch.zeros(len(self.thresholds)) 27 | 28 | @torch.no_grad() 29 | def update(self, preds, target): 30 | """ 31 | Update state with predictions and targets. 32 | 33 | Args: 34 | preds: Predictions from model [B, H, W] 35 | target: Ground truth values 36 | """ 37 | preds = preds.float() / 255. 
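        # In the PerformanceMeter path, saliency predictions arrive as 0-255 maps
        # (as emitted by get_output()); rescale to [0, 1] here before the shape
        # handling and per-threshold true/false-positive counts below.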
38 | 39 | if target.shape[1] == 1 and len(target.shape) == 4: 40 | target = target.squeeze(1) 41 | if len(preds.shape) == 2: 42 | preds = preds.unsqueeze(0) 43 | 44 | # assert preds.shape == target.shape, f"preds shape {preds.shape} does not match target shape {target.shape}" 45 | 46 | if len(preds.shape) == len(target.shape) + 1: 47 | assert preds.shape[1] == 2 48 | # two class probabilites 49 | preds = nn.functional.softmax(preds, dim=1)[:, 1, :, :] 50 | else: 51 | # squash logits into probabilities 52 | preds = torch.sigmoid(preds) 53 | 54 | if not len(preds.shape) == len(target.shape): 55 | raise ValueError( 56 | f"preds and target must have same number of dimensions, or preds one more, but got {preds.shape} and {target.shape}") 57 | 58 | valid_mask = (target != self.ignore_index) 59 | 60 | for idx, thresh in enumerate(self.thresholds): 61 | # threshold probablities 62 | f_preds = (preds >= thresh).long() 63 | f_target = target.long() 64 | 65 | f_preds = torch.masked_select(f_preds, valid_mask) 66 | f_target = torch.masked_select(f_target, valid_mask) 67 | 68 | self.true_positives[idx] += torch.sum(f_preds * f_target).cpu() 69 | self.predicted_positives[idx] += torch.sum(f_preds).cpu() 70 | self.actual_positives[idx] += torch.sum(f_target).cpu() 71 | 72 | def get_score(self, verbose=True): 73 | """ 74 | Computes F-scores over state and returns the max. 75 | """ 76 | precision = self.true_positives.float() / self.predicted_positives 77 | recall = self.true_positives.float() / self.actual_positives 78 | 79 | num = (1 + self.beta_squared) * precision * recall 80 | denom = self.beta_squared * precision + recall 81 | 82 | # For the rest we need to take care of instances where the denom can be 0 83 | # for some classes which will produce nans for that class 84 | fscore = num / denom 85 | fscore[fscore != fscore] = 0 86 | 87 | eval_result = {'maxF': fscore.max().item()} 88 | if verbose: 89 | print('Results for Saliency Estimation') 90 | print('maxF: {:.3f}'.format(100.0 * eval_result['maxF'])) 91 | return eval_result 92 | -------------------------------------------------------------------------------- /evaluation/eval_sal_no_beta.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # License: Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/astmt/) 7 | # Copyright (c) Facebook, Inc. and its affiliates. 8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | 13 | import warnings 14 | import cv2 15 | import os.path 16 | import numpy as np 17 | import glob 18 | import json 19 | import torch 20 | from PIL import Image 21 | 22 | import evaluation.jaccard as evaluation 23 | 24 | 25 | class SaliencyMeterWithNoBeta(object): 26 | def __init__(self): 27 | self.mask_thres = np.linspace(0.2, 0.9, 15) # As below 28 | self.all_jacards = [] 29 | self.prec = [] 30 | self.rec = [] 31 | 32 | @torch.no_grad() 33 | def update(self, pred, gt): 34 | # Predictions and ground-truth 35 | b = pred.size(0) 36 | pred = pred.float().squeeze() / 255. 
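        # The ground truth is already a binary mask (see the note on the removed
        # ASTMT binarization further down), so only the 0-255 prediction map is
        # rescaled and then swept over the thresholds in self.mask_thres.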
37 | gt = gt.squeeze().cpu().numpy() 38 | 39 | # Allocate memory for batch results 40 | jaccards = np.zeros((b, len(self.mask_thres))) 41 | prec = np.zeros((b, len(self.mask_thres))) 42 | rec = np.zeros((b, len(self.mask_thres))) 43 | 44 | for j, thres in enumerate(self.mask_thres): 45 | # gt_eval = (gt > thres).cpu().numpy() # Removed this from ASTMT code. GT is already binarized. 46 | mask_eval = (pred > thres).cpu().numpy() 47 | for i in range(b): 48 | jaccards[i, j] = evaluation.jaccard(gt[i], mask_eval[i]) 49 | prec[i, j], rec[i, j] = evaluation.precision_recall( 50 | gt[i], mask_eval[i]) 51 | 52 | self.all_jacards.append(jaccards) 53 | self.prec.append(prec) 54 | self.rec.append(rec) 55 | 56 | def reset(self): 57 | self.all_jacards = [] 58 | self.prec = [] 59 | self.rec = [] 60 | 61 | def get_score(self, verbose=True): 62 | eval_result = dict() 63 | 64 | # Concatenate batched results 65 | eval_result['all_jaccards'] = np.concatenate(self.all_jacards) 66 | eval_result['prec'] = np.concatenate(self.prec) 67 | eval_result['rec'] = np.concatenate(self.rec) 68 | 69 | # Average for each threshold 70 | eval_result['mIoUs'] = np.mean(eval_result['all_jaccards'], 0) 71 | 72 | eval_result['mPrec'] = np.mean(eval_result['prec'], 0) 73 | eval_result['mRec'] = np.mean(eval_result['rec'], 0) 74 | eval_result['F'] = 2 * eval_result['mPrec'] * eval_result['mRec'] / \ 75 | (eval_result['mPrec'] + eval_result['mRec'] + 1e-12) 76 | 77 | # Maximum of averages (maxF, maxmIoU) 78 | eval_result['mIoU'] = np.max(eval_result['mIoUs']) 79 | eval_result['maxF'] = np.max(eval_result['F']) 80 | 81 | eval_result = {x: eval_result[x].tolist() for x in eval_result} 82 | 83 | if verbose: 84 | # Print the results 85 | print('Results for Saliency Estimation') 86 | print('mIoU: {0:.3f}'.format(100.0 * eval_result['mIoU'])) 87 | print('maxF: {0:.3f}'.format(100.0 * eval_result['maxF'])) 88 | 89 | return eval_result 90 | -------------------------------------------------------------------------------- /evaluation/eval_semseg.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # License: Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/astmt/) 7 | # Copyright (c) Facebook, Inc. and its affiliates. 
8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | 13 | import warnings 14 | import cv2 15 | import os.path 16 | import glob 17 | import json 18 | import numpy as np 19 | import torch 20 | from PIL import Image 21 | 22 | 23 | VOC_CATEGORY_NAMES = ['background', 24 | 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 25 | 'bus', 'car', 'cat', 'chair', 'cow', 26 | 'diningtable', 'dog', 'horse', 'motorbike', 'person', 27 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] 28 | 29 | 30 | NYU_CATEGORY_NAMES = ['wall', 'floor', 'cabinet', 'bed', 'chair', 31 | 'sofa', 'table', 'door', 'window', 'bookshelf', 32 | 'picture', 'counter', 'blinds', 'desk', 'shelves', 33 | 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 34 | 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 35 | 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 36 | 'person', 'night stand', 'toilet', 'sink', 'lamp', 37 | 'bathtub', 'bag', 'otherstructure', 'otherfurniture', 'otherprop'] 38 | 39 | 40 | def eval_semseg(loader, folder, n_classes=20, has_bg=True, ignore_index=255): 41 | 42 | n_classes = n_classes + int(has_bg) 43 | 44 | # Iterate 45 | tp = [0] * n_classes 46 | fp = [0] * n_classes 47 | fn = [0] * n_classes 48 | 49 | for i, sample in enumerate(loader): 50 | 51 | if i % 500 == 0: 52 | print('Evaluating: {} of {} objects'.format(i, len(loader))) 53 | 54 | # Load result 55 | filename = os.path.join(folder, sample['meta']['image'] + '.png') 56 | mask = np.array(Image.open(filename)).astype(float) 57 | 58 | gt = sample['semseg'] 59 | valid = (gt != ignore_index) 60 | 61 | if mask.shape != gt.shape: 62 | warnings.warn( 63 | 'Prediction and ground truth have different size. 
Resizing Prediction..') 64 | mask = cv2.resize( 65 | mask, gt.shape[::-1], interpolation=cv2.INTER_NEAREST) 66 | 67 | # TP, FP, and FN evaluation 68 | for i_part in range(0, n_classes): 69 | tmp_gt = (gt == i_part) 70 | tmp_pred = (mask == i_part) 71 | tp[i_part] += np.sum(tmp_gt & tmp_pred & valid) 72 | fp[i_part] += np.sum(~tmp_gt & tmp_pred & valid) 73 | fn[i_part] += np.sum(tmp_gt & ~tmp_pred & valid) 74 | 75 | jac = [0] * n_classes 76 | for i_part in range(0, n_classes): 77 | jac[i_part] = float( 78 | tp[i_part]) / max(float(tp[i_part] + fp[i_part] + fn[i_part]), 1e-8) 79 | 80 | # Write results 81 | eval_result = dict() 82 | eval_result['jaccards_all_categs'] = jac 83 | eval_result['mIoU'] = np.mean(jac) 84 | 85 | return eval_result 86 | 87 | 88 | class SemsegMeter(object): 89 | def __init__(self, database, config): 90 | if database == 'PASCALContext': 91 | n_classes = 20 92 | cat_names = VOC_CATEGORY_NAMES 93 | has_bg = True 94 | ignore_index = 255 95 | elif database == 'NYUD': 96 | n_classes = 40 97 | cat_names = NYU_CATEGORY_NAMES 98 | has_bg = False 99 | ignore_index = 255 100 | else: 101 | raise NotImplementedError 102 | self.ignore_index = ignore_index 103 | self.n_classes = n_classes + int(has_bg) 104 | self.cat_names = cat_names 105 | self.tp = [0] * self.n_classes 106 | self.fp = [0] * self.n_classes 107 | self.fn = [0] * self.n_classes 108 | 109 | @torch.no_grad() 110 | def update(self, pred, gt): 111 | pred = pred.squeeze() 112 | gt = gt.squeeze() 113 | valid = (gt != self.ignore_index) 114 | 115 | for i_part in range(0, self.n_classes): 116 | tmp_gt = (gt == i_part) 117 | tmp_pred = (pred == i_part) 118 | self.tp[i_part] += torch.sum(tmp_gt & tmp_pred & valid).item() 119 | self.fp[i_part] += torch.sum(~tmp_gt & tmp_pred & valid).item() 120 | self.fn[i_part] += torch.sum(tmp_gt & ~tmp_pred & valid).item() 121 | 122 | def reset(self): 123 | self.tp = [0] * self.n_classes 124 | self.fp = [0] * self.n_classes 125 | self.fn = [0] * self.n_classes 126 | 127 | def get_score(self, verbose=True): 128 | jac = [0] * self.n_classes 129 | for i_part in range(self.n_classes): 130 | jac[i_part] = float( 131 | self.tp[i_part]) / max(float(self.tp[i_part] + self.fp[i_part] + self.fn[i_part]), 1e-8) 132 | 133 | eval_result = dict() 134 | eval_result['jaccards_all_categs'] = jac 135 | eval_result['mIoU'] = np.mean(jac) 136 | 137 | if verbose: 138 | print('\nSemantic Segmentation mIoU: {0:.4f}\n'.format( 139 | 100 * eval_result['mIoU'])) 140 | class_IoU = eval_result['jaccards_all_categs'] 141 | for i in range(len(class_IoU)): 142 | spaces = '' 143 | for j in range(0, 20 - len(self.cat_names[i])): 144 | spaces += ' ' 145 | print('{0:s}{1:s}{2:.4f}'.format( 146 | self.cat_names[i], spaces, 100 * class_IoU[i])) 147 | 148 | return eval_result 149 | 150 | 151 | def eval_semseg_predictions(database, save_dir, overfit=False): 152 | """ Evaluate the segmentation maps that are stored in the save dir """ 153 | 154 | # Dataloaders 155 | if database == 'PASCALContext': 156 | from data.pascal_context import PASCALContext 157 | n_classes = 20 158 | cat_names = VOC_CATEGORY_NAMES 159 | has_bg = True 160 | gt_set = 'val' 161 | db = PASCALContext(split=gt_set, do_edge=False, do_human_parts=False, do_semseg=True, 162 | do_normals=False, overfit=overfit) 163 | ignore_index = 255 164 | 165 | elif database == 'NYUD': 166 | from data.nyud import NYUD_MT 167 | n_classes = 40 168 | cat_names = NYU_CATEGORY_NAMES 169 | has_bg = False 170 | gt_set = 'val' 171 | db = NYUD_MT(split=gt_set, do_semseg=True, 
overfit=overfit) 172 | ignore_index = 255 173 | 174 | else: 175 | raise NotImplementedError 176 | 177 | base_name = database + '_' + 'test' + '_semseg' 178 | fname = os.path.join(save_dir, base_name + '.json') 179 | 180 | # Eval the model 181 | print('Evaluate the saved images (semseg)') 182 | eval_results = eval_semseg(db, os.path.join( 183 | save_dir, 'semseg'), n_classes=n_classes, has_bg=has_bg, ignore_index=ignore_index) 184 | with open(fname, 'w') as f: 185 | json.dump(eval_results, f) 186 | 187 | # Print results 188 | class_IoU = eval_results['jaccards_all_categs'] 189 | mIoU = eval_results['mIoU'] 190 | 191 | print('\nSemantic Segmentation mIoU: {0:.4f}\n'.format(100 * mIoU)) 192 | for i in range(len(class_IoU)): 193 | spaces = '' 194 | for j in range(0, 15 - len(cat_names[i])): 195 | spaces += ' ' 196 | print('{0:s}{1:s}{2:.4f}'.format( 197 | cat_names[i], spaces, 100 * class_IoU[i])) 198 | 199 | return eval_results 200 | -------------------------------------------------------------------------------- /evaluation/evaluate_utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 7 | # Written by Simon Vandenhende 8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | # -------------------------------------------------------- 13 | 14 | import os 15 | import numpy as np 16 | import torch 17 | import torch.nn.functional as F 18 | 19 | 20 | def get_output(output, task): 21 | output = output.permute(0, 2, 3, 1) 22 | 23 | if task == 'normals': 24 | output = (F.normalize(output, p=2, dim=3) + 1.0) * 255 / 2.0 25 | 26 | elif task in {'semseg', 'human_parts'}: 27 | _, output = torch.max(output, dim=3) 28 | 29 | elif task in {'edge', 'sal'}: 30 | output = torch.squeeze(255 * 1 / (1 + torch.exp(-output))) 31 | 32 | elif task in {'depth'}: 33 | pass 34 | 35 | else: 36 | raise ValueError('Select one of the valid tasks') 37 | 38 | return output 39 | 40 | 41 | class PerformanceMeter(object): 42 | """ A general performance meter which shows performance across one or more tasks """ 43 | 44 | def __init__(self, config, db_name="NYUD"): 45 | self.database = db_name 46 | self.tasks = config.TASKS 47 | self.meters = {t: get_single_task_meter(config, 48 | t, self.database) for t in self.tasks} 49 | 50 | def reset(self): 51 | for t in self.tasks: 52 | self.meters[t].reset() 53 | 54 | def update(self, pred, gt): 55 | for t in self.tasks: 56 | self.meters[t].update(pred[t], gt[t]) 57 | 58 | def get_score(self, verbose=True): 59 | eval_dict = {} 60 | for t in self.tasks: 61 | eval_dict[t] = self.meters[t].get_score(verbose) 62 | 63 | return eval_dict 64 | 65 | 66 | def calculate_multi_task_performance(eval_dict, single_task_dict): 67 | assert (set(eval_dict.keys()) == set(single_task_dict.keys())) 68 | tasks = eval_dict.keys() 69 | num_tasks = len(tasks) 70 | mtl_performance = 0.0 71 | 72 | for task in tasks: 73 | mtl = eval_dict[task] 74 | stl = single_task_dict[task] 75 | 76 | if task == 'depth': # rmse lower is better 77 | mtl_performance -= (mtl['rmse'] - stl['rmse'])/stl['rmse'] 78 | 79 | elif task in ['semseg', 'sal', 'human_parts']: # mIoU higher is better 80 | mtl_performance += (mtl['mIoU'] - stl['mIoU'])/stl['mIoU'] 81 | 82 | 
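        # The remaining branches follow the same signed relative-difference pattern:
        # lower-is-better metrics (normals mean error, like depth rmse above) subtract
        # the relative change, higher-is-better metrics (edge odsF) add it; the mean
        # over tasks is returned at the end.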
elif task == 'normals': # mean error lower is better 83 | mtl_performance -= (mtl['mean'] - stl['mean'])/stl['mean'] 84 | 85 | elif task == 'edge': # odsF higher is better 86 | mtl_performance += (mtl['odsF'] - stl['odsF'])/stl['odsF'] 87 | 88 | else: 89 | raise NotImplementedError 90 | 91 | return mtl_performance / num_tasks 92 | 93 | # TODO change database to handle more datasets 94 | 95 | 96 | def get_single_task_meter(config, task, database="NYUD"): 97 | """ Retrieve a meter to measure the single-task performance """ 98 | if task == 'semseg': 99 | from evaluation.eval_semseg import SemsegMeter 100 | return SemsegMeter(database, config) 101 | 102 | elif task == 'human_parts': 103 | from evaluation.eval_human_parts import HumanPartsMeter 104 | return HumanPartsMeter(database) 105 | 106 | elif task == 'normals': 107 | from evaluation.eval_normals import NormalsMeter 108 | return NormalsMeter() 109 | 110 | elif task == 'sal': 111 | from evaluation.eval_sal import SaliencyMeter 112 | return SaliencyMeter() 113 | 114 | elif task == 'depth': 115 | from evaluation.eval_depth import DepthMeter 116 | return DepthMeter() 117 | 118 | # Single task performance meter uses the loss (True evaluation is based on seism evaluation) 119 | elif task == 'edge': 120 | from evaluation.eval_edge import EdgeMeter 121 | # TODO: get edge_w from task config 122 | return EdgeMeter(pos_weight=0.95) 123 | # return EdgeMeter() 124 | 125 | else: 126 | raise NotImplementedError 127 | -------------------------------------------------------------------------------- /evaluation/jaccard.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # License: Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/astmt/) 7 | # Copyright (c) Facebook, Inc. and its affiliates. 
8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | 13 | import numpy as np 14 | 15 | 16 | def jaccard(gt, pred, void_pixels=None): 17 | 18 | assert (gt.shape == pred.shape) 19 | 20 | if void_pixels is None: 21 | void_pixels = np.zeros_like(gt) 22 | assert (void_pixels.shape == gt.shape) 23 | 24 | gt = gt.astype(bool) 25 | pred = pred.astype(bool) 26 | void_pixels = void_pixels.astype(bool) 27 | if np.isclose(np.sum(gt & np.logical_not(void_pixels)), 0) and np.isclose(np.sum(pred & np.logical_not(void_pixels)), 0): 28 | return 1 29 | 30 | else: 31 | return np.sum(((gt & pred) & np.logical_not(void_pixels))) / \ 32 | np.sum(((gt | pred) & np.logical_not(void_pixels)), dtype=float) 33 | 34 | 35 | def precision_recall(gt, pred, void_pixels=None): 36 | 37 | if void_pixels is None: 38 | void_pixels = np.zeros_like(gt) 39 | 40 | gt = gt.astype(bool) 41 | pred = pred.astype(bool) 42 | void_pixels = void_pixels.astype(bool) 43 | 44 | tp = ((pred & gt) & ~void_pixels).sum() 45 | fn = ((~pred & gt) & ~void_pixels).sum() 46 | 47 | fp = ((pred & ~gt) & ~void_pixels).sum() 48 | 49 | prec = tp / (tp + fp + 1e-12) 50 | rec = tp / (tp + fn + 1e-12) 51 | 52 | return prec, rec 53 | -------------------------------------------------------------------------------- /kernels/window_process/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup(name='swin_window_process', 6 | ext_modules=[ 7 | CUDAExtension('swin_window_process', [ 8 | 'swin_window_process.cpp', 9 | 'swin_window_process_kernel.cu', 10 | ]) 11 | ], 12 | cmdclass={'build_ext': BuildExtension}) -------------------------------------------------------------------------------- /kernels/window_process/swin_window_process.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include 18 | #include 19 | 20 | 21 | at::Tensor roll_and_window_partition_forward_cuda( 22 | at::Tensor & input, 23 | //at::Tensor & output, 24 | const int B, 25 | const int H, 26 | const int W, 27 | const int C, 28 | const int shift_size, 29 | const int window_size); 30 | 31 | 32 | at::Tensor roll_and_window_partition_backward_cuda( 33 | at::Tensor & grad_in, 34 | //at::Tensor & grad_out, 35 | const int B, 36 | const int H, 37 | const int W, 38 | const int C, 39 | const int shift_size, 40 | const int window_size); 41 | 42 | 43 | at::Tensor window_merge_and_roll_forward_cuda( 44 | at::Tensor & input, 45 | //at::Tensor & output, 46 | const int B, 47 | const int H, 48 | const int W, 49 | const int C, 50 | const int shift_size, 51 | const int window_size); 52 | 53 | at::Tensor window_merge_and_roll_backward_cuda( 54 | at::Tensor & grad_in, 55 | //at::Tensor & grad_out, 56 | const int B, 57 | const int H, 58 | const int W, 59 | const int C, 60 | const int shift_size, 61 | const int window_size); 62 | 63 | 64 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 65 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 66 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 67 | 68 | 69 | 70 | at::Tensor roll_and_window_partition_forward( 71 | at::Tensor & input, 72 | //at::Tensor & output, 73 | const int B, 74 | const int H, 75 | const int W, 76 | const int C, 77 | const int shift_size, 78 | const int window_size){ 79 | CHECK_INPUT(input); 80 | return roll_and_window_partition_forward_cuda(input, B, H, W, C, shift_size, window_size); 81 | } 82 | 83 | 84 | at::Tensor roll_and_window_partition_backward( 85 | at::Tensor & grad_in, 86 | //at::Tensor & grad_out, 87 | const int B, 88 | const int H, 89 | const int W, 90 | const int C, 91 | const int shift_size, 92 | const int window_size){ 93 | CHECK_INPUT(grad_in); 94 | return roll_and_window_partition_backward_cuda(grad_in, B, H, W, C, shift_size, window_size); 95 | } 96 | 97 | 98 | at::Tensor window_merge_and_roll_forward( 99 | at::Tensor & input, 100 | //at::Tensor & output, 101 | const int B, 102 | const int H, 103 | const int W, 104 | const int C, 105 | const int shift_size, 106 | const int window_size){ 107 | CHECK_INPUT(input); 108 | return window_merge_and_roll_forward_cuda(input, B, H, W, C, shift_size, window_size); 109 | } 110 | 111 | 112 | at::Tensor window_merge_and_roll_backward( 113 | at::Tensor & grad_in, 114 | //at::Tensor & grad_out, 115 | const int B, 116 | const int H, 117 | const int W, 118 | const int C, 119 | const int shift_size, 120 | const int window_size){ 121 | CHECK_INPUT(grad_in); 122 | return window_merge_and_roll_backward_cuda(grad_in, B, H, W, C, shift_size, window_size); 123 | } 124 | 125 | 126 | 127 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 128 | m.def("roll_and_window_partition_forward", &roll_and_window_partition_forward, "torch.roll and window_partition."); 129 | m.def("roll_and_window_partition_backward", &roll_and_window_partition_backward, "torch.roll and window_partition."); 130 | m.def("window_merge_and_roll_forward", &window_merge_and_roll_forward, "window merge and torch.roll."); 131 | m.def("window_merge_and_roll_backward", &window_merge_and_roll_backward, "window merge and torch.roll."); 132 | } -------------------------------------------------------------------------------- /kernels/window_process/swin_window_process_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 
2 | * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | int best_block_dim(int feat_dim){ 25 | int best_dim; 26 | if (feat_dim < 384){ 27 | best_dim = 64; 28 | } 29 | else{ 30 | if (feat_dim < 1024){ 31 | best_dim = 128; 32 | } 33 | else{ 34 | best_dim = 256; 35 | } 36 | } 37 | return best_dim; 38 | } 39 | 40 | 41 | template 42 | __global__ void roll_and_window_partition_forward_cuda_kernel( 43 | T* input, 44 | T* output, 45 | const int B, 46 | const int H, 47 | const int W, 48 | const int C, 49 | const int shift_size, 50 | const int window_size, 51 | const int nH, 52 | const int nW){ 53 | // start 54 | //bool qual = threadIdx.x < C; 55 | int index = threadIdx.x; 56 | int offset; 57 | for (int i = index; i < C; i += blockDim.x) { 58 | offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize 59 | int input_offset = blockIdx.z / (nH * nW) * H * W * C + 60 | (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y - shift_size + H) % H * W * C + 61 | (blockIdx.z % nW * window_size + blockIdx.x - shift_size + W) % W * C + 62 | i; 63 | output[offset] = (T)(__ldg(input + input_offset)); 64 | } 65 | } 66 | 67 | 68 | template 69 | __global__ void roll_and_window_partition_backward_cuda_kernel( 70 | T* grad_in, 71 | T* grad_out, 72 | const int B, 73 | const int H, 74 | const int W, 75 | const int C, 76 | const int shift_size, 77 | const int window_size, 78 | const int nH, 79 | const int nW){ 80 | // start 81 | int index = threadIdx.x; 82 | int offset; 83 | for (int i = index; i < C; i += blockDim.x) { 84 | offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize 85 | int input_offset = 86 | (blockIdx.z * nH * nW + (blockIdx.y + shift_size + H) % H / window_size * nW + (blockIdx.x + shift_size + W) % W / window_size) * window_size * window_size * C + 87 | (blockIdx.y + shift_size + H ) % H % window_size * window_size * C + 88 | (blockIdx.x + shift_size + W ) % W % window_size * C + 89 | i; 90 | grad_out[offset] = (T)(__ldg(grad_in + input_offset)); 91 | } 92 | } 93 | 94 | 95 | template 96 | __global__ void window_merge_and_roll_forward_cuda_kernel( 97 | T* input, 98 | T* output, 99 | const int B, 100 | const int H, 101 | const int W, 102 | const int C, 103 | const int shift_size, 104 | const int window_size, 105 | const int nH, 106 | const int nW){ 107 | // start 108 | int index = threadIdx.x; 109 | int offset; 110 | for (int i = index; i < C; i += blockDim.x) { 111 | offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize 112 | int input_offset = 113 | (blockIdx.z * nH * nW + (blockIdx.y - shift_size + H) % H / window_size * nH + (blockIdx.x - shift_size + W) % W / window_size) * window_size * window_size * C + 114 | (blockIdx.y - shift_size + H) % window_size * 
window_size * C + 115 | (blockIdx.x - shift_size + W) % window_size * C + 116 | i; 117 | output[offset] = (T)(__ldg(input + input_offset)); 118 | } 119 | } 120 | 121 | 122 | 123 | template 124 | __global__ void window_merge_and_roll_backward_cuda_kernel( 125 | T* grad_in, 126 | T* grad_out, 127 | const int B, 128 | const int H, 129 | const int W, 130 | const int C, 131 | const int shift_size, 132 | const int window_size, 133 | const int nH, 134 | const int nW){ 135 | // start 136 | int index = threadIdx.x; 137 | int offset; 138 | for (int i = index; i < C; i += blockDim.x) { 139 | offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize 140 | int input_offset = 141 | (blockIdx.z / (nH * nW)) * H * W * C + 142 | (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y + shift_size + H) % H * W * C + 143 | (blockIdx.z % nW * window_size + blockIdx.x + shift_size + W) % W * C + 144 | i; 145 | grad_out[offset] = (T)(__ldg(grad_in + input_offset)); 146 | } 147 | } 148 | 149 | // input: [B, H, W, C] 150 | // output: [B*nH*nW, window_size, window_size, C] 151 | at::Tensor roll_and_window_partition_forward_cuda( 152 | at::Tensor & input, 153 | //at::Tensor & output, 154 | const int B, 155 | const int H, 156 | const int W, 157 | const int C, 158 | const int shift_size, 159 | const int window_size){ 160 | 161 | int nH = H / window_size; 162 | int nW = W / window_size; 163 | 164 | dim3 grid(window_size, window_size, B * nH * nW); 165 | //dim3 block((C + 31) / 32 * 32); 166 | int blocknum = best_block_dim(C); 167 | dim3 block(blocknum); 168 | 169 | at::Tensor output; 170 | if (input.scalar_type() == torch::kFloat16){ 171 | output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true)); 172 | } 173 | else{ 174 | output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true)); 175 | } 176 | 177 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "roll_and_window_partition_forward_cuda_kernel", ([&] { 178 | roll_and_window_partition_forward_cuda_kernel<<>>( 179 | input.data(), 180 | output.data(), 181 | B, 182 | H, 183 | W, 184 | C, 185 | shift_size, 186 | window_size, 187 | nH, 188 | nW); 189 | })); 190 | return output; 191 | } 192 | 193 | 194 | // grad_in: [B*nH*nW, window_size, window_size, C] 195 | // grad_out: [B, H, W, C] 196 | at::Tensor roll_and_window_partition_backward_cuda( 197 | at::Tensor & grad_in, 198 | const int B, 199 | const int H, 200 | const int W, 201 | const int C, 202 | const int shift_size, 203 | const int window_size){ 204 | 205 | int nH = H / window_size; 206 | int nW = W / window_size; 207 | 208 | dim3 grid(W, H, B); 209 | //dim3 block((C + 31) / 32 * 32); 210 | int blocknum = best_block_dim(C); 211 | dim3 block(blocknum); 212 | 213 | at::Tensor grad_out; 214 | if (grad_in.scalar_type() == torch::kFloat16){ 215 | grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false)); 216 | } 217 | else{ 218 | grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); 219 | } 220 | 221 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.type(), "roll_and_window_partition_backward_cuda_kernel", ([&] { 222 | roll_and_window_partition_backward_cuda_kernel<<>>( 223 | grad_in.data(), 224 | grad_out.data(), 225 | B, 226 | H, 227 | W, 228 | C, 229 | shift_size, 230 | window_size, 231 | nH, 232 | nW); 233 | })); 234 | 
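    // Launch geometry recap: grid = (W, H, B) maps one thread block to each pixel of
    // the merged [B, H, W, C] gradient, and each block's threads stride over the C
    // channels in steps of blockDim.x (chosen by best_block_dim(C)).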
return grad_out; 235 | } 236 | 237 | 238 | // input: [B*nH*nW, window_size, window_size, C] 239 | // output: [B, H, W, C] 240 | at::Tensor window_merge_and_roll_forward_cuda( 241 | at::Tensor & input, 242 | //at::Tensor & output, 243 | const int B, 244 | const int H, 245 | const int W, 246 | const int C, 247 | const int shift_size, 248 | const int window_size){ 249 | 250 | int nH = H / window_size; 251 | int nW = W / window_size; 252 | 253 | dim3 grid(W, H, B); 254 | //dim3 block((C + 31) / 32 * 32); 255 | int blocknum = best_block_dim(C); 256 | dim3 block(blocknum); 257 | 258 | //generate output tensor inside 259 | at::Tensor output; 260 | if (input.scalar_type() == torch::kFloat16){ 261 | output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true)); 262 | } 263 | else{ 264 | output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true)); 265 | } 266 | 267 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "window_merge_and_roll_forward_cuda_kernel", ([&] { 268 | window_merge_and_roll_forward_cuda_kernel<<>>( 269 | input.data(), 270 | output.data(), 271 | B, 272 | H, 273 | W, 274 | C, 275 | shift_size, 276 | window_size, 277 | nH, 278 | nW); 279 | })); 280 | return output; 281 | } 282 | 283 | 284 | at::Tensor window_merge_and_roll_backward_cuda( 285 | at::Tensor & grad_in, 286 | const int B, 287 | const int H, 288 | const int W, 289 | const int C, 290 | const int shift_size, 291 | const int window_size){ 292 | 293 | int nH = H / window_size; 294 | int nW = W / window_size; 295 | 296 | dim3 grid(window_size, window_size, B * nH * nW); 297 | //dim3 block((C + 31) / 32 * 32); 298 | int blocknum = best_block_dim(C); 299 | dim3 block(blocknum); 300 | 301 | at::Tensor grad_out; 302 | if (grad_in.scalar_type() == torch::kFloat16){ 303 | grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false)); 304 | } 305 | else{ 306 | grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); 307 | } 308 | 309 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.type(), "window_merge_and_roll_backward_cuda_kernel", ([&] { 310 | window_merge_and_roll_backward_cuda_kernel<<>>( 311 | grad_in.data(), 312 | grad_out.data(), 313 | B, 314 | H, 315 | W, 316 | C, 317 | shift_size, 318 | window_size, 319 | nH, 320 | nW); 321 | })); 322 | return grad_out; 323 | } -------------------------------------------------------------------------------- /kernels/window_process/unit_test.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fused kernel for window process for SwinTransformer 3 | # Copyright (c) 2022 Nvidia 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import torch 8 | import swin_window_process 9 | import random 10 | import time 11 | import unittest 12 | 13 | 14 | class WindowProcess(torch.autograd.Function): 15 | @staticmethod 16 | def forward(ctx, input, B, H, W, C, shift_size, window_size): 17 | output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size) 18 | 19 | ctx.B = B 20 | ctx.H = H 21 | ctx.W = W 22 | ctx.C = C 23 | ctx.shift_size = shift_size 24 | ctx.window_size = window_size 25 | return output 26 | 27 | @staticmethod 28 | def 
backward(ctx, grad_in): 29 | B = ctx.B 30 | H = ctx.H 31 | W = ctx.W 32 | C = ctx.C 33 | shift_size = ctx.shift_size 34 | window_size = ctx.window_size 35 | 36 | grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size) 37 | return grad_out, None, None, None, None, None, None, None 38 | 39 | 40 | class WindowProcessReverse(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, B, H, W, C, shift_size, window_size): 43 | output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size) 44 | 45 | ctx.B = B 46 | ctx.H = H 47 | ctx.W = W 48 | ctx.C = C 49 | ctx.shift_size = shift_size 50 | ctx.window_size = window_size 51 | 52 | return output 53 | 54 | @staticmethod 55 | def backward(ctx, grad_in): 56 | B = ctx.B 57 | H = ctx.H 58 | W = ctx.W 59 | C = ctx.C 60 | shift_size = ctx.shift_size 61 | window_size = ctx.window_size 62 | 63 | grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size) 64 | return grad_out, None, None, None, None, None, None, None 65 | 66 | 67 | def window_partition(x, window_size): 68 | """ 69 | Args: 70 | x: (B, H, W, C) 71 | window_size (int): window size 72 | Returns: 73 | windows: (num_windows*B, window_size, window_size, C) 74 | """ 75 | B, H, W, C = x.shape 76 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 77 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 78 | return windows 79 | 80 | def window_reverse(windows, window_size, H, W): 81 | """ 82 | Args: 83 | windows: (num_windows*B, window_size, window_size, C) 84 | window_size (int): Window size 85 | H (int): Height of image 86 | W (int): Width of image 87 | Returns: 88 | x: (B, H, W, C) 89 | """ 90 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 91 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 92 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 93 | return x 94 | 95 | 96 | def pyt_forward(x, shift_size, window_size): 97 | # x in shape(B, H, W, C) 98 | # cyclic shift 99 | if shift_size > 0: 100 | shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) 101 | else: 102 | shifted_x = x 103 | # partition windows 104 | x_windows = window_partition(shifted_x, window_size) 105 | return x_windows 106 | 107 | 108 | def reverse_pyt_forward(attn_windows, shift_size, window_size, H, W): 109 | # x in shape(B*nH*nW, window_size, window_size, C) 110 | shifted_x = window_reverse(attn_windows, window_size, H, W) 111 | if shift_size > 0: 112 | x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) 113 | else: 114 | x = shifted_x 115 | return x 116 | 117 | 118 | def copy_one_tensor(input, requires_grad=True): 119 | input1 = input.clone().detach().requires_grad_(requires_grad).cuda() 120 | return input1 121 | 122 | class Test_WindowProcess(unittest.TestCase): 123 | def setUp(self): 124 | self.B = 192 125 | self.H = 56 126 | self.W = 56 127 | self.C = 96 128 | self.shift_size = 2 129 | self.window_size = 7 130 | self.nH = self.H // self.window_size 131 | self.nW = self.W // self.window_size 132 | 133 | def test_roll_and_window_partition_forward(self, dtype=torch.float32): 134 | input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() 135 | 136 | input1 = copy_one_tensor(input, True) 137 | input2 = copy_one_tensor(input, True) 138 | 139 | with torch.no_grad(): 
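            # Forward-only correctness check: the reference PyTorch path and the fused
            # CUDA op are compared under no_grad and must match bit-exactly via
            # torch.equal; gradients are exercised separately in the *_backward tests.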
140 | # ori 141 | expected = pyt_forward(input1, self.shift_size, self.window_size) 142 | # fused kernel 143 | fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size) 144 | 145 | self.assertTrue(torch.equal(expected, fused_output)) 146 | #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) 147 | 148 | def test_roll_and_window_partition_backward(self, dtype=torch.float32): 149 | input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() 150 | d_loss_tensor = torch.randn((self.B*self.nW*self.nH, self.window_size, self.window_size, self.C), dtype=dtype).cuda() 151 | 152 | input1 = copy_one_tensor(input, True) 153 | input2 = copy_one_tensor(input, True) 154 | 155 | # ori 156 | expected = pyt_forward(input1, self.shift_size, self.window_size) 157 | expected.backward(d_loss_tensor) 158 | # fused kernel 159 | fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size) 160 | fused_output.backward(d_loss_tensor) 161 | 162 | self.assertTrue(torch.equal(expected, fused_output)) 163 | #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) 164 | 165 | def test_window_merge_and_roll_forward(self, dtype=torch.float32): 166 | input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() 167 | 168 | input1 = copy_one_tensor(input, True) 169 | input2 = copy_one_tensor(input, True) 170 | 171 | with torch.no_grad(): 172 | # ori 173 | expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) 174 | # fused kernel 175 | fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) 176 | 177 | self.assertTrue(torch.equal(expected, fused_output)) 178 | #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) 179 | 180 | 181 | def test_window_merge_and_roll_backward(self, dtype=torch.float32): 182 | input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() 183 | d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() 184 | 185 | input1 = copy_one_tensor(input, True) 186 | input2 = copy_one_tensor(input, True) 187 | 188 | # ori 189 | expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) 190 | expected.backward(d_loss_tensor) 191 | # fused kernel 192 | fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) 193 | fused_output.backward(d_loss_tensor) 194 | 195 | self.assertTrue(torch.equal(expected, fused_output)) 196 | #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) 197 | 198 | def test_forward_backward_speed(self, dtype=torch.float32, times=1000): 199 | input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() 200 | d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() 201 | 202 | input1 = copy_one_tensor(input, True) 203 | input2 = copy_one_tensor(input, True) 204 | 205 | # SwinTransformer official 206 | def run_pyt(t=1000): 207 | for _ in range(t): 208 | expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) 209 | expected.backward(d_loss_tensor) 210 | 
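        # Benchmark methodology: each path runs `times` forward+backward iterations,
        # bracketed by torch.cuda.synchronize() so the wall-clock deltas measure
        # completed GPU work; the test then asserts the fused op path is faster.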
211 | # my op 212 | def run_fusedop(t=1000): 213 | for _ in range(t): 214 | fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) 215 | fused_output.backward(d_loss_tensor) 216 | 217 | torch.cuda.synchronize() 218 | t1 = time.time() 219 | run_pyt(t=times) 220 | torch.cuda.synchronize() 221 | t2 = time.time() 222 | run_fusedop(t=times) 223 | torch.cuda.synchronize() 224 | t3 = time.time() 225 | self.assertTrue((t3 - t2) < (t2 - t1)) 226 | 227 | print('Run {} times'.format(times)) 228 | print('Original time cost: {}'.format(t2 - t1)) 229 | print('Fused op time cost: {}'.format(t3 - t2)) 230 | 231 | def test_roll_and_window_partition_forward_fp16(self, dtype=torch.float16): 232 | self.test_roll_and_window_partition_forward(dtype=dtype) 233 | 234 | def test_roll_and_window_partition_backward_fp16(self, dtype=torch.float16): 235 | self.test_roll_and_window_partition_backward(dtype=dtype) 236 | 237 | def test_window_merge_and_roll_forward_fp16(self, dtype=torch.float16): 238 | self.test_window_merge_and_roll_forward(dtype=dtype) 239 | 240 | def test_window_merge_and_roll_backward_fp16(self, dtype=torch.float16): 241 | self.test_window_merge_and_roll_backward(dtype=dtype) 242 | 243 | def test_forward_backward_speed_fp16(self, dtype=torch.float16, times=1000): 244 | self.test_forward_backward_speed(dtype=dtype, times=times) 245 | 246 | 247 | if __name__ == '__main__': 248 | print('Pass only two tensors are exactly the same (using torch.equal).\n') 249 | torch.manual_seed(0) 250 | unittest.main(verbosity=2) 251 | -------------------------------------------------------------------------------- /kernels/window_process/window_process.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fused kernel for window process for SwinTransformer 3 | # Copyright (c) 2022 Nvidia 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import torch 8 | import swin_window_process 9 | 10 | 11 | class WindowProcess(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, input, B, H, W, C, shift_size, window_size): 14 | output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size) 15 | 16 | ctx.B = B 17 | ctx.H = H 18 | ctx.W = W 19 | ctx.C = C 20 | ctx.shift_size = shift_size 21 | ctx.window_size = window_size 22 | return output 23 | 24 | @staticmethod 25 | def backward(ctx, grad_in): 26 | B = ctx.B 27 | H = ctx.H 28 | W = ctx.W 29 | C = ctx.C 30 | shift_size = ctx.shift_size 31 | window_size = ctx.window_size 32 | 33 | grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size) 34 | return grad_out, None, None, None, None, None, None, None 35 | 36 | 37 | class WindowProcessReverse(torch.autograd.Function): 38 | @staticmethod 39 | def forward(ctx, input, B, H, W, C, shift_size, window_size): 40 | output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size) 41 | 42 | ctx.B = B 43 | ctx.H = H 44 | ctx.W = W 45 | ctx.C = C 46 | ctx.shift_size = shift_size 47 | ctx.window_size = window_size 48 | 49 | return output 50 | 51 | @staticmethod 52 | def backward(ctx, grad_in): 53 | B = ctx.B 54 | H = ctx.H 55 | W = ctx.W 56 | C = ctx.C 57 | shift_size = ctx.shift_size 58 | window_size = ctx.window_size 59 | 60 | #grad_out = 
ctx.saved_tensors[0] 61 | #grad_out = torch.zeros((B, H, W, C), dtype=dtype).cuda() 62 | grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size) 63 | return grad_out, None, None, None, None, None, None, None 64 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.INFO) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + \ 26 | ': %(levelname)s %(message)s' 27 | 28 | # create console handlers for master process 29 | if dist_rank == 0: 30 | console_handler = logging.StreamHandler(sys.stdout) 31 | console_handler.setLevel(logging.DEBUG) 32 | console_handler.setFormatter( 33 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 34 | logger.addHandler(console_handler) 35 | 36 | # create file handlers 37 | file_handler = logging.FileHandler(os.path.join( 38 | output_dir, f'log_rank{dist_rank}.txt'), mode='a') 39 | file_handler.setLevel(logging.DEBUG) 40 | file_handler.setFormatter(logging.Formatter( 41 | fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 42 | logger.addHandler(file_handler) 43 | 44 | return logger 45 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # Built upon Swin Transformer (https://github.com/microsoft/Swin-Transformer) 5 | # 6 | # Original file: 7 | # Copyright (c) 2021 Microsoft 8 | # Licensed under the MIT License 9 | # Written by Ze Liu 10 | # 11 | # Modifications: 12 | # Copyright (c) 2024 SCALE Lab, Brown University 13 | # Licensed under the MIT License (see LICENSE for details) 14 | # -------------------------------------------------------- 15 | 16 | 17 | import bisect 18 | 19 | import torch 20 | from timm.scheduler.cosine_lr import CosineLRScheduler 21 | from timm.scheduler.step_lr import StepLRScheduler 22 | from timm.scheduler.scheduler import Scheduler 23 | 24 | 25 | def build_scheduler(config, optimizer, n_iter_per_epoch): 26 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 27 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 28 | decay_steps = int( 29 | config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 30 | multi_steps = [ 31 | i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS] 32 | lr_scheduler = None 33 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 34 | lr_scheduler = CosineLRScheduler( 35 | optimizer, 36 | t_initial=( 37 | num_steps - warmup_steps) if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX else num_steps, 38 | 
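            # With WARMUP_PREFIX enabled, warmup iterations form a separate prefix
            # phase, so the cosine decay horizon above excludes the warmup steps;
            # otherwise they are counted inside t_initial.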
# t_mul=1., 39 | lr_min=config.TRAIN.MIN_LR, 40 | warmup_lr_init=config.TRAIN.WARMUP_LR, 41 | warmup_t=warmup_steps, 42 | cycle_limit=1, 43 | t_in_epochs=False, 44 | warmup_prefix=config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX, 45 | ) 46 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 47 | lr_scheduler = LinearLRScheduler( 48 | optimizer, 49 | t_initial=num_steps, 50 | lr_min_rate=0.01, 51 | warmup_lr_init=config.TRAIN.WARMUP_LR, 52 | warmup_t=warmup_steps, 53 | t_in_epochs=False, 54 | ) 55 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 56 | lr_scheduler = StepLRScheduler( 57 | optimizer, 58 | decay_t=decay_steps, 59 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 60 | warmup_lr_init=config.TRAIN.WARMUP_LR, 61 | warmup_t=warmup_steps, 62 | t_in_epochs=False, 63 | ) 64 | elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep': 65 | lr_scheduler = MultiStepLRScheduler( 66 | optimizer, 67 | milestones=multi_steps, 68 | gamma=config.TRAIN.LR_SCHEDULER.GAMMA, 69 | warmup_lr_init=config.TRAIN.WARMUP_LR, 70 | warmup_t=warmup_steps, 71 | t_in_epochs=False, 72 | ) 73 | 74 | return lr_scheduler 75 | 76 | 77 | class LinearLRScheduler(Scheduler): 78 | def __init__(self, 79 | optimizer: torch.optim.Optimizer, 80 | t_initial: int, 81 | lr_min_rate: float, 82 | warmup_t=0, 83 | warmup_lr_init=0., 84 | t_in_epochs=True, 85 | noise_range_t=None, 86 | noise_pct=0.67, 87 | noise_std=1.0, 88 | noise_seed=42, 89 | initialize=True, 90 | ) -> None: 91 | super().__init__( 92 | optimizer, param_group_field="lr", 93 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 94 | initialize=initialize) 95 | 96 | self.t_initial = t_initial 97 | self.lr_min_rate = lr_min_rate 98 | self.warmup_t = warmup_t 99 | self.warmup_lr_init = warmup_lr_init 100 | self.t_in_epochs = t_in_epochs 101 | if self.warmup_t: 102 | self.warmup_steps = [(v - warmup_lr_init) / 103 | self.warmup_t for v in self.base_values] 104 | super().update_groups(self.warmup_lr_init) 105 | else: 106 | self.warmup_steps = [1 for _ in self.base_values] 107 | 108 | def _get_lr(self, t): 109 | if t < self.warmup_t: 110 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 111 | else: 112 | t = t - self.warmup_t 113 | total_t = self.t_initial - self.warmup_t 114 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) 115 | for v in self.base_values] 116 | return lrs 117 | 118 | def get_epoch_values(self, epoch: int): 119 | if self.t_in_epochs: 120 | return self._get_lr(epoch) 121 | else: 122 | return None 123 | 124 | def get_update_values(self, num_updates: int): 125 | if not self.t_in_epochs: 126 | return self._get_lr(num_updates) 127 | else: 128 | return None 129 | 130 | 131 | class MultiStepLRScheduler(Scheduler): 132 | def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None: 133 | super().__init__(optimizer, param_group_field="lr") 134 | 135 | self.milestones = milestones 136 | self.gamma = gamma 137 | self.warmup_t = warmup_t 138 | self.warmup_lr_init = warmup_lr_init 139 | self.t_in_epochs = t_in_epochs 140 | if self.warmup_t: 141 | self.warmup_steps = [(v - warmup_lr_init) / 142 | self.warmup_t for v in self.base_values] 143 | super().update_groups(self.warmup_lr_init) 144 | else: 145 | self.warmup_steps = [1 for _ in self.base_values] 146 | 147 | assert self.warmup_t <= min(self.milestones) 148 | 149 | def _get_lr(self, t): 150 | if t < self.warmup_t: 151 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 152 | 
else: 153 | lrs = [v * (self.gamma ** bisect.bisect_right(self.milestones, t)) 154 | for v in self.base_values] 155 | return lrs 156 | 157 | def get_epoch_values(self, epoch: int): 158 | if self.t_in_epochs: 159 | return self._get_lr(epoch) 160 | else: 161 | return None 162 | 163 | def get_update_values(self, num_updates: int): 164 | if not self.t_in_epochs: 165 | return self._get_lr(num_updates) 166 | else: 167 | return None 168 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model, build_mtl_model -------------------------------------------------------------------------------- /models/aspp.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 7 | # Written by Simon Vandenhende 8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | # -------------------------------------------------------- 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | 19 | class DeepLabHead(nn.Sequential): 20 | def __init__(self, in_channels, num_classes): 21 | super(DeepLabHead, self).__init__( 22 | ASPP(in_channels, [12, 24, 36]), 23 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 24 | nn.BatchNorm2d(256), 25 | nn.ReLU(), 26 | nn.Conv2d(256, num_classes, 1) 27 | ) 28 | 29 | 30 | class ASPPConv(nn.Sequential): 31 | def __init__(self, in_channels, out_channels, dilation): 32 | modules = [ 33 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, 34 | dilation=dilation, bias=False), 35 | nn.BatchNorm2d(out_channels), 36 | nn.ReLU() 37 | ] 38 | super(ASPPConv, self).__init__(*modules) 39 | 40 | 41 | class ASPPPooling(nn.Sequential): 42 | def __init__(self, in_channels, out_channels): 43 | super(ASPPPooling, self).__init__( 44 | nn.AdaptiveAvgPool2d(1), 45 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 46 | nn.BatchNorm2d(out_channels), 47 | nn.ReLU()) 48 | 49 | def forward(self, x): 50 | size = x.shape[-2:] 51 | x = super(ASPPPooling, self).forward(x) 52 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 53 | 54 | 55 | class ASPP(nn.Module): 56 | def __init__(self, in_channels, atrous_rates): 57 | super(ASPP, self).__init__() 58 | 59 | if isinstance(in_channels, (list, tuple)): 60 | in_channels = sum(in_channels) 61 | 62 | out_channels = 256 63 | modules = [] 64 | modules.append(nn.Sequential( 65 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 66 | nn.BatchNorm2d(out_channels), 67 | nn.ReLU())) 68 | 69 | rate1, rate2, rate3 = tuple(atrous_rates) 70 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 71 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 72 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 73 | modules.append(ASPPPooling(in_channels, out_channels)) 74 | 75 | self.convs = nn.ModuleList(modules) 76 | 77 | self.project = nn.Sequential( 78 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 79 | nn.BatchNorm2d(out_channels), 80 | nn.ReLU(), 81 | nn.Dropout(0.5)) 82 | 83 | def forward(self, x): 84 | x0_h, x0_w = x[0].size(2), x[0].size(3) 85 | x1 = 
F.interpolate(x[1], (x0_h, x0_w), mode='bilinear') 86 | x2 = F.interpolate(x[2], (x0_h, x0_w), mode='bilinear') 87 | x3 = F.interpolate(x[3], (x0_h, x0_w), mode='bilinear') 88 | 89 | x = torch.cat([x[0], x1, x2, x3], 1) 90 | res = [] 91 | for conv in self.convs: 92 | res.append(conv(x)) 93 | res = torch.cat(res, dim=1) 94 | return self.project(res) 95 | -------------------------------------------------------------------------------- /models/aspp_single.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 7 | # Written by Simon Vandenhende 8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | # -------------------------------------------------------- 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | 19 | class DeepLabHead(nn.Sequential): 20 | def __init__(self, in_channels, num_classes): 21 | super(DeepLabHead, self).__init__( 22 | ASPP(in_channels, [12, 24, 36]), 23 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 24 | nn.BatchNorm2d(256), 25 | nn.ReLU(), 26 | nn.Conv2d(256, num_classes, 1) 27 | ) 28 | 29 | 30 | class ASPPConv(nn.Sequential): 31 | def __init__(self, in_channels, out_channels, dilation): 32 | modules = [ 33 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, 34 | dilation=dilation, bias=False), 35 | nn.BatchNorm2d(out_channels), 36 | nn.ReLU() 37 | ] 38 | super(ASPPConv, self).__init__(*modules) 39 | 40 | 41 | class ASPPPooling(nn.Sequential): 42 | def __init__(self, in_channels, out_channels): 43 | super(ASPPPooling, self).__init__( 44 | nn.AdaptiveAvgPool2d(1), 45 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 46 | nn.BatchNorm2d(out_channels), 47 | nn.ReLU()) 48 | 49 | def forward(self, x): 50 | size = x.shape[-2:] 51 | x = super(ASPPPooling, self).forward(x) 52 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 53 | 54 | 55 | class ASPP(nn.Module): 56 | def __init__(self, in_channels, atrous_rates): 57 | super(ASPP, self).__init__() 58 | 59 | if isinstance(in_channels, (list, tuple)): 60 | in_channels = sum(in_channels) 61 | 62 | out_channels = 256 63 | modules = [] 64 | modules.append(nn.Sequential( 65 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 66 | nn.BatchNorm2d(out_channels), 67 | nn.ReLU())) 68 | 69 | rate1, rate2, rate3 = tuple(atrous_rates) 70 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 71 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 72 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 73 | modules.append(ASPPPooling(in_channels, out_channels)) 74 | 75 | self.convs = nn.ModuleList(modules) 76 | 77 | self.project = nn.Sequential( 78 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 79 | nn.BatchNorm2d(out_channels), 80 | nn.ReLU(), 81 | nn.Dropout(0.5)) 82 | 83 | def forward(self, x): 84 | res = [] 85 | for conv in self.convs: 86 | res.append(conv(x)) 87 | res = torch.cat(res, dim=1) 88 | return self.project(res) 89 | -------------------------------------------------------------------------------- /models/base_decode_head.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, 
abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmengine.model import normal_init 6 | import torch.nn.functional as F 7 | 8 | 9 | def resize(input, 10 | size=None, 11 | scale_factor=None, 12 | mode='nearest', 13 | align_corners=None, 14 | warning=True): 15 | if warning: 16 | if size is not None and align_corners: 17 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 18 | output_h, output_w = tuple(int(x) for x in size) 19 | if output_h > input_h or output_w > output_h: 20 | if ((output_h > 1 and output_w > 1 and input_h > 1 21 | and input_w > 1) and (output_h - 1) % (input_h - 1) 22 | and (output_w - 1) % (input_w - 1)): 23 | warnings.warn( 24 | f'When align_corners={align_corners}, ' 25 | 'the output would more aligned if ' 26 | f'input size {(input_h, input_w)} is `x+1` and ' 27 | f'out size {(output_h, output_w)} is `nx+1`') 28 | if isinstance(size, torch.Size): 29 | size = tuple(int(x) for x in size) 30 | return F.interpolate(input, size, scale_factor, mode, align_corners) 31 | 32 | 33 | class BaseDecodeHead(nn.Module, metaclass=ABCMeta): 34 | """Base class for BaseDecodeHead. 35 | 36 | Args: 37 | in_channels (int|Sequence[int]): Input channels. 38 | channels (int): Channels after modules, before conv_seg. 39 | num_classes (int): Number of classes. 40 | dropout_ratio (float): Ratio of dropout layer. Default: 0.1. 41 | conv_cfg (dict|None): Config of conv layers. Default: None. 42 | norm_cfg (dict|None): Config of norm layers. Default: None. 43 | act_cfg (dict): Config of activation layers. 44 | Default: dict(type='ReLU') 45 | in_index (int|Sequence[int]): Input feature index. Default: -1 46 | input_transform (str|None): Transformation type of input features. 47 | Options: 'resize_concat', 'multiple_select', None. 48 | 'resize_concat': Multiple feature maps will be resize to the 49 | same size as first one and than concat together. 50 | Usually used in FCN head of HRNet. 51 | 'multiple_select': Multiple feature maps will be bundle into 52 | a list and passed into decode head. 53 | None: Only one select feature map is allowed. 54 | Default: None. 55 | loss_decode (dict): Config of decode loss. 56 | Default: dict(type='CrossEntropyLoss'). 57 | ignore_index (int | None): The label index to be ignored. When using 58 | masked BCE loss, ignore_index should be set to None. Default: 255 59 | sampler (dict|None): The config of segmentation map sampler. 60 | Default: None. 61 | align_corners (bool): align_corners argument of F.interpolate. 62 | Default: False. 
63 | """ 64 | 65 | def __init__(self, 66 | in_channels, 67 | channels, 68 | *, 69 | num_classes, 70 | dropout_ratio=0.1, 71 | conv_cfg=None, 72 | norm_cfg=None, 73 | act_cfg=dict(type='ReLU'), 74 | in_index=-1, 75 | input_transform=None, 76 | # loss_decode=dict( 77 | # type='CrossEntropyLoss', 78 | # use_sigmoid=False, 79 | # loss_weight=1.0), 80 | decoder_params=None, 81 | ignore_index=255, 82 | sampler=None, 83 | align_corners=False): 84 | super(BaseDecodeHead, self).__init__() 85 | self._init_inputs(in_channels, in_index, input_transform) 86 | self.channels = channels 87 | self.num_classes = num_classes 88 | self.dropout_ratio = dropout_ratio 89 | self.conv_cfg = conv_cfg 90 | self.norm_cfg = norm_cfg 91 | self.act_cfg = act_cfg 92 | self.in_index = in_index 93 | # self.loss_decode = build_loss(loss_decode) 94 | self.ignore_index = ignore_index 95 | self.align_corners = align_corners 96 | 97 | if sampler is not None: 98 | self.sampler = build_pixel_sampler(sampler, context=self) 99 | else: 100 | self.sampler = None 101 | 102 | self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) 103 | if dropout_ratio > 0: 104 | self.dropout = nn.Dropout2d(dropout_ratio) 105 | else: 106 | self.dropout = None 107 | self.fp16_enabled = False 108 | 109 | def extra_repr(self): 110 | """Extra repr.""" 111 | s = f'input_transform={self.input_transform}, ' \ 112 | f'ignore_index={self.ignore_index}, ' \ 113 | f'align_corners={self.align_corners}' 114 | return s 115 | 116 | def _init_inputs(self, in_channels, in_index, input_transform): 117 | """Check and initialize input transforms. 118 | 119 | The in_channels, in_index and input_transform must match. 120 | Specifically, when input_transform is None, only single feature map 121 | will be selected. So in_channels and in_index must be of type int. 122 | When input_transform 123 | 124 | Args: 125 | in_channels (int|Sequence[int]): Input channels. 126 | in_index (int|Sequence[int]): Input feature index. 127 | input_transform (str|None): Transformation type of input features. 128 | Options: 'resize_concat', 'multiple_select', None. 129 | 'resize_concat': Multiple feature maps will be resize to the 130 | same size as first one and than concat together. 131 | Usually used in FCN head of HRNet. 132 | 'multiple_select': Multiple feature maps will be bundle into 133 | a list and passed into decode head. 134 | None: Only one select feature map is allowed. 135 | """ 136 | if input_transform is not None: 137 | assert input_transform in ['resize_concat', 'multiple_select'] 138 | self.input_transform = input_transform 139 | self.in_index = in_index 140 | if input_transform is not None: 141 | assert isinstance(in_channels, (list, tuple)) 142 | assert isinstance(in_index, (list, tuple)) 143 | assert len(in_channels) == len(in_index) 144 | if input_transform == 'resize_concat': 145 | self.in_channels = sum(in_channels) 146 | else: 147 | self.in_channels = in_channels 148 | else: 149 | assert isinstance(in_channels, int) 150 | assert isinstance(in_index, int) 151 | self.in_channels = in_channels 152 | 153 | def init_weights(self): 154 | """Initialize weights of classification layer.""" 155 | normal_init(self.conv_seg, mean=0, std=0.01) 156 | 157 | def _transform_inputs(self, inputs): 158 | """Transform inputs for decoder. 159 | 160 | Args: 161 | inputs (list[Tensor]): List of multi-level img features. 
162 | 163 | Returns: 164 | Tensor: The transformed inputs 165 | """ 166 | 167 | if self.input_transform == 'resize_concat': 168 | inputs = [inputs[i] for i in self.in_index] 169 | upsampled_inputs = [ 170 | resize( 171 | input=x, 172 | size=inputs[0].shape[2:], 173 | mode='bilinear', 174 | align_corners=self.align_corners) for x in inputs 175 | ] 176 | inputs = torch.cat(upsampled_inputs, dim=1) 177 | elif self.input_transform == 'multiple_select': 178 | inputs = [inputs[i] for i in self.in_index] 179 | else: 180 | inputs = inputs[self.in_index] 181 | 182 | return inputs 183 | 184 | @abstractmethod 185 | def forward(self, inputs): 186 | """Placeholder of forward function.""" 187 | pass 188 | 189 | def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): 190 | """Forward function for training. 191 | Args: 192 | inputs (list[Tensor]): List of multi-level img features. 193 | img_metas (list[dict]): List of image info dict where each dict 194 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 195 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 196 | For details on the values of these keys see 197 | `mmseg/datasets/pipelines/formatting.py:Collect`. 198 | gt_semantic_seg (Tensor): Semantic segmentation masks 199 | used if the architecture supports semantic segmentation task. 200 | train_cfg (dict): The training config. 201 | 202 | Returns: 203 | dict[str, Tensor]: a dictionary of loss components 204 | """ 205 | seg_logits = self.forward(inputs) 206 | losses = self.losses(seg_logits, gt_semantic_seg) 207 | return losses 208 | 209 | def forward_test(self, inputs, img_metas, test_cfg): 210 | """Forward function for testing. 211 | 212 | Args: 213 | inputs (list[Tensor]): List of multi-level img features. 214 | img_metas (list[dict]): List of image info dict where each dict 215 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 216 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 217 | For details on the values of these keys see 218 | `mmseg/datasets/pipelines/formatting.py:Collect`. 219 | test_cfg (dict): The testing config. 220 | 221 | Returns: 222 | Tensor: Output segmentation map. 
223 | """ 224 | return self.forward(inputs) 225 | 226 | def cls_seg(self, feat): 227 | """Classify each pixel.""" 228 | if self.dropout is not None: 229 | feat = self.dropout(feat) 230 | output = self.conv_seg(feat) 231 | return output 232 | 233 | ''' 234 | @force_fp32(apply_to=('seg_logit', )) 235 | def losses(self, seg_logit, seg_label): 236 | """Compute segmentation loss.""" 237 | loss = dict() 238 | seg_logit = resize( 239 | input=seg_logit, 240 | size=seg_label.shape[2:], 241 | mode='bilinear', 242 | align_corners=self.align_corners) 243 | if self.sampler is not None: 244 | seg_weight = self.sampler.sample(seg_logit, seg_label) 245 | else: 246 | seg_weight = None 247 | seg_label = seg_label.squeeze(1) 248 | loss['loss_seg'] = self.loss_decode( 249 | seg_logit, 250 | seg_label, 251 | weight=seg_weight, 252 | ignore_index=self.ignore_index) 253 | loss['acc_seg'] = accuracy(seg_logit, seg_label) 254 | return loss 255 | ''' 256 | -------------------------------------------------------------------------------- /models/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # Built upon Swin Transformer (https://github.com/microsoft/Swin-Transformer) 5 | # 6 | # Original file: 7 | # Copyright (c) 2021 Microsoft 8 | # Licensed under the MIT License 9 | # Written by Ze Liu 10 | # 11 | # Modifications: 12 | # Copyright (c) 2024 SCALE Lab, Brown University 13 | # Licensed under the MIT License (see LICENSE for details) 14 | # -------------------------------------------------------- 15 | 16 | 17 | from .swin_transformer_mtlora import SwinTransformerMTLoRA 18 | from .swin_transformer import SwinTransformer 19 | from .swin_mtl import MultiTaskSwin 20 | 21 | 22 | def build_model(config, is_pretrain=False): 23 | model_type = config.MODEL.TYPE 24 | 25 | # accelerate layernorm 26 | if config.FUSED_LAYERNORM: 27 | try: 28 | import apex as amp 29 | layernorm = amp.normalization.FusedLayerNorm 30 | except: 31 | layernorm = None 32 | print("To use FusedLayerNorm, please install apex.") 33 | else: 34 | import torch.nn as nn 35 | layernorm = nn.LayerNorm 36 | 37 | if model_type == 'swin': 38 | if config.MODEL.MTLORA.ENABLED: 39 | model = SwinTransformerMTLoRA(img_size=config.DATA.IMG_SIZE, 40 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 41 | in_chans=config.MODEL.SWIN.IN_CHANS, 42 | num_classes=config.MODEL.NUM_CLASSES, 43 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 44 | depths=config.MODEL.SWIN.DEPTHS, 45 | num_heads=config.MODEL.SWIN.NUM_HEADS, 46 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 47 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 48 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 49 | qk_scale=config.MODEL.SWIN.QK_SCALE, 50 | drop_rate=config.MODEL.DROP_RATE, 51 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 52 | ape=config.MODEL.SWIN.APE, 53 | norm_layer=layernorm, 54 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 55 | use_checkpoint=config.TRAIN.USE_CHECKPOINT, 56 | fused_window_process=config.FUSED_WINDOW_PROCESS, 57 | tasks=config.TASKS, 58 | mtlora=config.MODEL.MTLORA) 59 | else: 60 | model = SwinTransformer(img_size=config.DATA.IMG_SIZE, 61 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 62 | in_chans=config.MODEL.SWIN.IN_CHANS, 63 | num_classes=config.MODEL.NUM_CLASSES, 64 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 65 | depths=config.MODEL.SWIN.DEPTHS, 66 | num_heads=config.MODEL.SWIN.NUM_HEADS, 67 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 68 | 
mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 69 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 70 | qk_scale=config.MODEL.SWIN.QK_SCALE, 71 | drop_rate=config.MODEL.DROP_RATE, 72 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 73 | ape=config.MODEL.SWIN.APE, 74 | norm_layer=layernorm, 75 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 76 | use_checkpoint=config.TRAIN.USE_CHECKPOINT, 77 | fused_window_process=config.FUSED_WINDOW_PROCESS) 78 | else: 79 | raise NotImplementedError(f"Unkown model: {model_type}") 80 | 81 | return model 82 | 83 | 84 | def build_mtl_model(backbone, config): 85 | model = MultiTaskSwin(backbone, config) 86 | return model 87 | -------------------------------------------------------------------------------- /models/segformer.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # --------------------------------------------------------------- 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch 9 | from mmcv.cnn import ConvModule 10 | 11 | from .base_decode_head import BaseDecodeHead 12 | import torch.nn.functional as F 13 | 14 | 15 | def resize(input, 16 | size=None, 17 | scale_factor=None, 18 | mode='nearest', 19 | align_corners=None, 20 | warning=False): 21 | if warning: 22 | if size is not None and align_corners: 23 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 24 | output_h, output_w = tuple(int(x) for x in size) 25 | if output_h > input_h or output_w > output_h: 26 | if ((output_h > 1 and output_w > 1 and input_h > 1 27 | and input_w > 1) and (output_h - 1) % (input_h - 1) 28 | and (output_w - 1) % (input_w - 1)): 29 | warnings.warn( 30 | f'When align_corners={align_corners}, ' 31 | 'the output would more aligned if ' 32 | f'input size {(input_h, input_w)} is `x+1` and ' 33 | f'out size {(output_h, output_w)} is `nx+1`') 34 | if isinstance(size, torch.Size): 35 | size = tuple(int(x) for x in size) 36 | return F.interpolate(input, size, scale_factor, mode, align_corners) 37 | 38 | 39 | class MLP(nn.Module): 40 | """ 41 | Linear Embedding 42 | """ 43 | 44 | def __init__(self, input_dim=2048, embed_dim=768): 45 | super().__init__() 46 | self.proj = nn.Linear(input_dim, embed_dim) 47 | 48 | def forward(self, x): 49 | x = x.flatten(2).transpose(1, 2) 50 | x = self.proj(x) 51 | return x 52 | 53 | 54 | class SegFormerHead(BaseDecodeHead): 55 | """ 56 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 57 | """ 58 | 59 | def __init__(self, **kwargs): 60 | super(SegFormerHead, self).__init__( 61 | input_transform='multiple_select', in_index=[0, 1, 2, 3], **kwargs) 62 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels 63 | 64 | embedding_dim = kwargs['channels'] 65 | 66 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) 67 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) 68 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) 69 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) 70 | 71 | self.linear_fuse = ConvModule( 72 | in_channels=embedding_dim*4, 73 | out_channels=embedding_dim, 74 | kernel_size=1, 75 | norm_cfg=dict(type='SyncBN', requires_grad=True) 76 | ) 77 | 78 | self.linear_pred = nn.Conv2d( 79 | embedding_dim, self.num_classes, kernel_size=1) 80 | 81 | def 
forward(self, inputs): 82 | x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32 83 | c1, c2, c3, c4 = x 84 | 85 | ############## MLP decoder on C1-C4 ########### 86 | n, _, h, w = c4.shape 87 | 88 | _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape( 89 | n, -1, c4.shape[2], c4.shape[3]) 90 | _c4 = resize(_c4, size=c1.size()[2:], 91 | mode='bilinear', align_corners=False) 92 | 93 | _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape( 94 | n, -1, c3.shape[2], c3.shape[3]) 95 | _c3 = resize(_c3, size=c1.size()[2:], 96 | mode='bilinear', align_corners=False) 97 | 98 | _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape( 99 | n, -1, c2.shape[2], c2.shape[3]) 100 | _c2 = resize(_c2, size=c1.size()[2:], 101 | mode='bilinear', align_corners=False) 102 | 103 | _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape( 104 | n, -1, c1.shape[2], c1.shape[3]) 105 | 106 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 107 | 108 | x = self.dropout(_c) 109 | x = self.linear_pred(x) 110 | 111 | return x 112 | -------------------------------------------------------------------------------- /models/swin_mtl.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # Copyright (c) 2024 SCALE Lab, Brown University 5 | # Licensed under the MIT License (see LICENSE for details). 6 | # -------------------------------------------------------- 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import types 13 | 14 | 15 | def get_head(task, backbone_channels, num_outputs, config=None, multiscale=True): 16 | """Return the decoder head""" 17 | head_type = config.MODEL.DECODER_HEAD.get(task, "hrnet") 18 | 19 | if head_type == "hrnet": 20 | print( 21 | f"Using hrnet for task {task} with backbone channels {backbone_channels}") 22 | from models.seg_hrnet import HighResolutionHead 23 | 24 | return HighResolutionHead(backbone_channels, num_outputs) 25 | elif head_type == "updecoder": 26 | print(f"Using updecoder for task {task}") 27 | from models.updecoder import Decoder 28 | 29 | return Decoder( 30 | backbone_channels, 31 | num_outputs, 32 | args=types.SimpleNamespace( 33 | **{ 34 | "num_deconv": 3, 35 | "num_filters": [32, 32, 32], 36 | "deconv_kernels": [2, 2, 2], 37 | } 38 | ), 39 | ) 40 | elif head_type == "segformer": 41 | print( 42 | f"Using segformer for task {task} with {config.MODEL.SEGFORMER_CHANNELS} channels" 43 | ) 44 | from models.segformer import SegFormerHead 45 | 46 | return SegFormerHead( 47 | in_channels=backbone_channels, 48 | channels=config.MODEL.SEGFORMER_CHANNELS, 49 | num_classes=num_outputs, 50 | ) 51 | else: 52 | if not multiscale: 53 | from models.aspp_single import DeepLabHead 54 | else: 55 | from models.aspp import DeepLabHead 56 | print(f"Using ASPP for task {task}") 57 | return DeepLabHead(backbone_channels, num_outputs) 58 | 59 | 60 | class DecoderGroup(nn.Module): 61 | def __init__(self, tasks, num_outputs, channels, out_size, config, multiscale=True): 62 | super(DecoderGroup, self).__init__() 63 | self.tasks = tasks 64 | self.num_outputs = num_outputs 65 | self.channels = channels 66 | self.decoders = nn.ModuleDict() 67 | self.out_size = out_size 68 | self.multiscale = multiscale 69 | for task in self.tasks: 70 | self.decoders[task] = get_head( 71 | task, 72 | self.channels, 73 | self.num_outputs[task], 74 | config=config, 75 | multiscale=self.multiscale, 76 | ) 77 | 78 | def 
forward(self, x): 79 | result = { 80 | task: F.interpolate( 81 | self.decoders[task](x[task]), self.out_size, mode="bilinear" 82 | ) 83 | for task in self.tasks 84 | } 85 | return result 86 | 87 | 88 | class Downsampler(nn.Module): 89 | def __init__(self, dims, channels, input_res, bias=False, enabled=True): 90 | super(Downsampler, self).__init__() 91 | self.dims = dims 92 | self.input_res = input_res 93 | self.enabled = enabled 94 | if self.enabled: 95 | self.downsample_0 = torch.nn.Conv2d( 96 | dims[0], channels[0], 1, bias=bias) 97 | self.downsample_1 = torch.nn.Conv2d( 98 | dims[1], channels[1], 1, bias=bias) 99 | self.downsample_2 = torch.nn.Conv2d( 100 | dims[2], channels[2], 1, bias=bias) 101 | self.downsample_3 = torch.nn.Conv2d( 102 | dims[3], channels[3], 1, bias=bias) 103 | 104 | def forward(self, x): 105 | s_3 = ( 106 | x[3] 107 | .view(-1, self.input_res[3], self.input_res[3], self.dims[3]) 108 | .permute(0, 3, 1, 2) 109 | ) 110 | 111 | s_2 = ( 112 | x[2] 113 | .view(-1, self.input_res[2], self.input_res[2], self.dims[2]) 114 | .permute(0, 3, 1, 2) 115 | ) 116 | s_1 = ( 117 | x[1] 118 | .view(-1, self.input_res[1], self.input_res[1], self.dims[1]) 119 | .permute(0, 3, 1, 2) 120 | ) 121 | s_0 = ( 122 | x[0] 123 | .view(-1, self.input_res[0], self.input_res[0], self.dims[0]) 124 | .permute(0, 3, 1, 2) 125 | ) 126 | 127 | if self.enabled: 128 | return [ 129 | self.downsample_0(s_0), 130 | self.downsample_1(s_1), 131 | self.downsample_2(s_2), 132 | self.downsample_3(s_3), 133 | ] 134 | else: 135 | return [s_0, s_1, s_2, s_3] 136 | 137 | 138 | class MultiTaskSwin(nn.Module): 139 | def __init__(self, encoder, config): 140 | super(MultiTaskSwin, self).__init__() 141 | 142 | self.backbone = encoder 143 | self.num_outputs = config.TASKS_CONFIG.ALL_TASKS.NUM_OUTPUT 144 | self.tasks = config.TASKS 145 | if hasattr(self.backbone, "patch_embed"): 146 | patches_resolution = self.backbone.patch_embed.patches_resolution 147 | self.embed_dim = self.backbone.embed_dim 148 | num_layers = self.backbone.num_layers 149 | self.dims = [ 150 | int((self.embed_dim * 2 ** ((i + 1) if i < num_layers - 1 else i))) 151 | for i in range(num_layers) 152 | ] 153 | self.input_res = [ 154 | patches_resolution[0] // (2 ** 155 | ((i + 1) if i < num_layers - 1 else i)) 156 | for i in range(num_layers) 157 | ] 158 | self.window_size = self.backbone.layers[0].blocks[0].window_size 159 | self.img_size = self.backbone.patch_embed.img_size 160 | else: 161 | self.input_res = [28, 14, 7, 7] 162 | 163 | self.dims = [192, 384, 768, 768] 164 | self.window_size = config.MODEL.SWIN.WINDOW_SIZE 165 | self.img_size = config.DATA.IMG_SIZE 166 | 167 | self.channels = ( 168 | config.MODEL.DECODER_CHANNELS 169 | if config.MODEL.DECODER_DOWNSAMPLER 170 | else self.dims 171 | ) 172 | self.mtlora = config.MODEL.MTLORA 173 | if self.mtlora.ENABLED: 174 | self.downsampler = nn.ModuleDict( 175 | { 176 | task: Downsampler( 177 | dims=self.dims, 178 | channels=self.channels, 179 | input_res=self.input_res, 180 | bias=False, 181 | ) 182 | for task in self.tasks 183 | } 184 | ) 185 | else: 186 | self.downsampler = Downsampler( 187 | dims=self.dims, 188 | channels=self.channels, 189 | input_res=self.input_res, 190 | bias=False, 191 | ) 192 | 193 | self.per_task_downsampler = config.MODEL.PER_TASK_DOWNSAMPLER 194 | if self.per_task_downsampler: 195 | self.downsampler = nn.ModuleDict( 196 | { 197 | task: Downsampler( 198 | dims=self.dims, 199 | channels=self.channels, 200 | input_res=self.input_res, 201 | bias=False, 202 | 
enabled=config.MODEL.DECODER_DOWNSAMPLER, 203 | ) 204 | for task in self.tasks 205 | } 206 | ) 207 | else: 208 | self.downsampler = Downsampler( 209 | dims=self.dims, 210 | channels=self.channels, 211 | input_res=self.input_res, 212 | bias=False, 213 | ) 214 | self.decoders = DecoderGroup( 215 | self.tasks, 216 | self.num_outputs, 217 | channels=self.channels, 218 | out_size=self.img_size, 219 | config=config, 220 | multiscale=True, 221 | ) 222 | 223 | def forward(self, x): 224 | shared_representation = self.backbone(x, return_stages=True) 225 | 226 | if self.mtlora.ENABLED: 227 | shared_ft = {task: [] for task in self.tasks} 228 | 229 | for _, tasks_shared_rep in shared_representation: 230 | for task, shared_rep in tasks_shared_rep.items(): 231 | shared_ft[task].append(shared_rep) 232 | for task in self.tasks: 233 | shared_ft[task] = self.downsampler[task](shared_ft[task]) 234 | else: 235 | if self.per_task_downsampler: 236 | shared_ft = { 237 | task: self.downsampler[task](shared_representation) 238 | for task in self.tasks 239 | } 240 | else: 241 | shared_representation = self.downsampler(shared_representation) 242 | shared_ft = { 243 | task: shared_representation for task in self.tasks} 244 | 245 | result = self.decoders(shared_ft) 246 | return result 247 | 248 | def freeze_all(self): 249 | for param in self.parameters(): 250 | param.requires_grad = False 251 | 252 | def unfreeze_all(self): 253 | for param in self.parameters(): 254 | param.requires_grad = True 255 | 256 | def freeze_task(self, task): 257 | for param in self.decoders[task].parameters(): 258 | param.requires_grad = False 259 | 260 | def unfreeze_task(self, task): 261 | for param in self.decoders[task].parameters(): 262 | param.requires_grad = True 263 | 264 | def freeze_backbone(self): 265 | for param in self.backbone.parameters(): 266 | param.requires_grad = False 267 | 268 | def unfreeze_backbone(self): 269 | for param in self.backbone.parameters(): 270 | param.requires_grad = True 271 | -------------------------------------------------------------------------------- /models/transformer_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .swin_transformer import SwinTransformer, PatchMerging, PatchEmbed 5 | 6 | 7 | class UpSample(nn.Module): 8 | def __init__(self, embed_dim): 9 | super().__init__() 10 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 11 | self.proj = nn.Conv2d(embed_dim, embed_dim, 1) 12 | 13 | def forward(self, x): 14 | B, H, W, C = x.shape 15 | x = x.permute(0, 3, 1, 2) 16 | x = self.upsample(x) 17 | x = self.proj(x) 18 | _, _, H, W = x.shape 19 | x = x.permute(0, 2, 3, 1) 20 | return x 21 | 22 | 23 | class SwinDecoderHead(SwinTransformer): 24 | def __init__(self, *args, **kwargs): 25 | super().__init__(*args, **kwargs) 26 | # Override downsample layers with upsampling 27 | self.downsample = nn.ModuleList([ 28 | UpSample(embed_dim=kwargs['embed_dim'] * 2**i) for i in range(self.num_layers) 29 | ]) 30 | 31 | def forward_features(self, x, return_stages=False, flatten_ft=True): 32 | return_stages = False 33 | flatten_ft = True 34 | 35 | x0_h, x0_w = x[0].size(2), x[0].size(3) 36 | x1 = F.interpolate(x[1], (x0_h, x0_w), mode='bilinear') 37 | x2 = F.interpolate(x[2], (x0_h, x0_w), mode='bilinear') 38 | x3 = F.interpolate(x[3], (x0_h, x0_w), mode='bilinear') 39 | 40 | x = torch.cat([x[0], x1, x2, x3], 1) 41 | 42 | x = self.patch_embed(x) 43 | if self.ape: 44 | x = x + 
self.absolute_pos_embed 45 | x = self.pos_drop(x) 46 | if return_stages: 47 | out = [] 48 | for layer in self.layers: 49 | x = layer(x) 50 | if return_stages: 51 | out.append(x) 52 | return x 53 | -------------------------------------------------------------------------------- /models/updecoder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # Copyright (c) 2024 SCALE Lab, Brown University 5 | # Licensed under the MIT License (see LICENSE for details). 6 | # -------------------------------------------------------- 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer 15 | from mmengine.model import constant_init, normal_init 16 | 17 | 18 | class Decoder(nn.Module): 19 | def __init__(self, in_channels, out_channels, args): 20 | super().__init__() 21 | self.deconv = args.num_deconv 22 | self.in_channels = sum(in_channels) 23 | 24 | # import pdb; pdb.set_trace() 25 | self.deconv_layers = self._make_deconv_layer( 26 | args.num_deconv, 27 | args.num_filters, 28 | args.deconv_kernels, 29 | ) 30 | 31 | conv_layers = [] 32 | conv_layers.append( 33 | build_conv_layer( 34 | dict(type='Conv2d'), 35 | in_channels=args.num_filters[-1], 36 | out_channels=out_channels, 37 | kernel_size=3, 38 | stride=1, 39 | padding=1)) 40 | conv_layers.append( 41 | build_norm_layer(dict(type='BN'), out_channels)[1]) 42 | conv_layers.append(nn.ReLU(inplace=True)) 43 | self.conv_layers = nn.Sequential(*conv_layers) 44 | 45 | self.up = nn.Upsample( 46 | scale_factor=2, mode='bilinear', align_corners=False) 47 | 48 | def forward(self, x): 49 | x0_h, x0_w = x[0].size(2), x[0].size(3) 50 | x1 = F.interpolate(x[1], (x0_h, x0_w), mode='bilinear') 51 | x2 = F.interpolate(x[2], (x0_h, x0_w), mode='bilinear') 52 | x3 = F.interpolate(x[3], (x0_h, x0_w), mode='bilinear') 53 | 54 | conv_feats = torch.cat([x[0], x1, x2, x3], 1) 55 | out = self.deconv_layers(conv_feats) 56 | out = self.conv_layers(out) 57 | 58 | out = self.up(out) 59 | out = self.up(out) 60 | 61 | return out 62 | 63 | def _make_deconv_layer(self, num_layers, num_filters, num_kernels): 64 | layers = [] 65 | in_planes = self.in_channels 66 | for i in range(num_layers): 67 | kernel, padding, output_padding = \ 68 | self._get_deconv_cfg(num_kernels[i]) 69 | 70 | planes = num_filters[i] 71 | layers.append( 72 | build_upsample_layer( 73 | dict(type='deconv'), 74 | in_channels=in_planes, 75 | out_channels=planes, 76 | kernel_size=kernel, 77 | stride=2, 78 | padding=padding, 79 | output_padding=output_padding, 80 | bias=False)) 81 | layers.append(nn.BatchNorm2d(planes)) 82 | layers.append(nn.ReLU(inplace=True)) 83 | in_planes = planes 84 | 85 | return nn.Sequential(*layers) 86 | 87 | def _get_deconv_cfg(self, deconv_kernel): 88 | """Get configurations for deconv layers.""" 89 | if deconv_kernel == 4: 90 | padding = 1 91 | output_padding = 0 92 | elif deconv_kernel == 3: 93 | padding = 1 94 | output_padding = 1 95 | elif deconv_kernel == 2: 96 | padding = 0 97 | output_padding = 0 98 | else: 99 | raise ValueError(f'Not supported num_kernels ({deconv_kernel}).') 100 | 101 | return deconv_kernel, padding, output_padding 102 | 103 | def init_weights(self): 104 | """Initialize model weights.""" 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | normal_init(m, std=0.001, 
bias=0) 108 | elif isinstance(m, nn.BatchNorm2d): 109 | constant_init(m, 1) 110 | elif isinstance(m, nn.ConvTranspose2d): 111 | normal_init(m, std=0.001) 112 | -------------------------------------------------------------------------------- /mtl_loss_schemes.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # 5 | # Original file: 6 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 7 | # Written by Simon Vandenhende 8 | # 9 | # Modifications: 10 | # Copyright (c) 2024 SCALE Lab, Brown University 11 | # Licensed under the MIT License (see LICENSE for details) 12 | # -------------------------------------------------------- 13 | 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from torch.nn.modules.module import Module 19 | import numpy as np 20 | 21 | 22 | class SoftMaxwithLoss(Module): 23 | """ 24 | This function returns cross entropy loss for semantic segmentation 25 | """ 26 | 27 | def __init__(self, ignore_index=255): 28 | super(SoftMaxwithLoss, self).__init__() 29 | self.softmax = nn.LogSoftmax(dim=1) 30 | self.criterion = nn.NLLLoss(ignore_index=ignore_index) 31 | 32 | def forward(self, out, label): 33 | assert not label.requires_grad 34 | # out shape batch_size x channels x h x w 35 | # label shape batch_size x 1 x h x w 36 | label = label[:, 0, :, :].long() 37 | loss = self.criterion(self.softmax(out), label) 38 | 39 | return loss 40 | 41 | 42 | class BalancedCrossEntropyLoss(Module): 43 | """ 44 | Balanced Cross Entropy Loss with optional ignore regions 45 | """ 46 | 47 | def __init__(self, size_average=True, batch_average=True, pos_weight=None): 48 | super(BalancedCrossEntropyLoss, self).__init__() 49 | self.size_average = size_average 50 | self.batch_average = batch_average 51 | self.pos_weight = pos_weight 52 | 53 | def forward(self, output, label, void_pixels=None): 54 | assert (output.size() == label.size()) 55 | labels = torch.ge(label, 0.5).float() 56 | 57 | # Weighting of the loss, default is HED-style 58 | if self.pos_weight is None: 59 | num_labels_pos = torch.sum(labels) 60 | num_labels_neg = torch.sum(1.0 - labels) 61 | num_total = num_labels_pos + num_labels_neg 62 | w = num_labels_neg / num_total 63 | else: 64 | w = self.pos_weight 65 | 66 | output_gt_zero = torch.ge(output, 0).float() 67 | loss_val = torch.mul(output, (labels - output_gt_zero)) - torch.log( 68 | 1 + torch.exp(output - 2 * torch.mul(output, output_gt_zero))) 69 | 70 | loss_pos_pix = -torch.mul(labels, loss_val) 71 | loss_neg_pix = -torch.mul(1.0 - labels, loss_val) 72 | 73 | if void_pixels is not None and not self.pos_weight: 74 | w_void = torch.le(void_pixels, 0.5).float() 75 | loss_pos_pix = torch.mul(w_void, loss_pos_pix) 76 | loss_neg_pix = torch.mul(w_void, loss_neg_pix) 77 | num_total = num_total - torch.ge(void_pixels, 0.5).float().sum() 78 | w = num_labels_neg / num_total 79 | 80 | loss_pos = torch.sum(loss_pos_pix) 81 | loss_neg = torch.sum(loss_neg_pix) 82 | 83 | final_loss = w * loss_pos + (1 - w) * loss_neg 84 | 85 | if self.size_average: 86 | final_loss /= float(np.prod(label.size())) 87 | elif self.batch_average: 88 | final_loss /= label.size()[0] 89 | 90 | return final_loss 91 | 92 | 93 | class BinaryCrossEntropyLoss(Module): 94 | """ 95 | Binary Cross Entropy with ignore regions, not balanced. 
96 | """ 97 | 98 | def __init__(self, size_average=True, batch_average=True): 99 | super(BinaryCrossEntropyLoss, self).__init__() 100 | self.size_average = size_average 101 | self.batch_average = batch_average 102 | 103 | def forward(self, output, label, void_pixels=None): 104 | assert (output.size() == label.size()) 105 | 106 | labels = torch.ge(label, 0.5).float() 107 | 108 | output_gt_zero = torch.ge(output, 0).float() 109 | loss_val = torch.mul(output, (labels - output_gt_zero)) - torch.log( 110 | 1 + torch.exp(output - 2 * torch.mul(output, output_gt_zero))) 111 | 112 | loss_pos_pix = -torch.mul(labels, loss_val) 113 | loss_neg_pix = -torch.mul(1.0 - labels, loss_val) 114 | 115 | if void_pixels is not None: 116 | w_void = torch.le(void_pixels, 0.5).float() 117 | loss_pos_pix = torch.mul(w_void, loss_pos_pix) 118 | loss_neg_pix = torch.mul(w_void, loss_neg_pix) 119 | 120 | loss_pos = torch.sum(loss_pos_pix) 121 | loss_neg = torch.sum(loss_neg_pix) 122 | final_loss = loss_pos + loss_neg 123 | 124 | if self.size_average: 125 | final_loss /= float(np.prod(label.size())) 126 | elif self.batch_average: 127 | final_loss /= label.size()[0] 128 | 129 | return final_loss 130 | 131 | 132 | class DepthLoss(nn.Module): 133 | """ 134 | Loss for depth prediction. By default L1 loss is used. 135 | """ 136 | 137 | def __init__(self, loss='l1'): 138 | super(DepthLoss, self).__init__() 139 | if loss == 'l1': 140 | self.loss = nn.L1Loss() 141 | 142 | else: 143 | raise NotImplementedError( 144 | 'Loss {} currently not supported in DepthLoss'.format(loss)) 145 | 146 | def forward(self, out, label): 147 | mask = (label != 255) 148 | return self.loss(torch.masked_select(out, mask), torch.masked_select(label, mask)) 149 | 150 | 151 | class Normalize(nn.Module): 152 | def __init__(self): 153 | super(Normalize, self).__init__() 154 | 155 | def forward(self, bottom): 156 | qn = torch.norm(bottom, p=2, dim=1).unsqueeze(dim=1) + 1e-12 157 | top = bottom.div(qn) 158 | 159 | return top 160 | 161 | 162 | class NormalsLoss(Module): 163 | """ 164 | L1 loss with ignore labels 165 | normalize: normalization for surface normals 166 | """ 167 | 168 | def __init__(self, size_average=True, normalize=False, norm=1): 169 | super(NormalsLoss, self).__init__() 170 | 171 | self.size_average = size_average 172 | 173 | if normalize: 174 | self.normalize = Normalize() 175 | else: 176 | self.normalize = None 177 | 178 | if norm == 1: 179 | # print('Using L1 loss for surface normals') 180 | self.loss_func = F.l1_loss 181 | elif norm == 2: 182 | # print('Using L2 loss for surface normals') 183 | self.loss_func = F.mse_loss 184 | else: 185 | raise NotImplementedError 186 | 187 | def forward(self, out, label, ignore_label=255): 188 | assert not label.requires_grad 189 | mask = (label != ignore_label) 190 | n_valid = torch.sum(mask).item() 191 | 192 | if self.normalize is not None: 193 | out_norm = self.normalize(out) 194 | loss = self.loss_func(torch.masked_select( 195 | out_norm, mask), torch.masked_select(label, mask), reduction='sum') 196 | else: 197 | loss = self.loss_func(torch.masked_select( 198 | out, mask), torch.masked_select(label, mask), reduction='sum') 199 | 200 | if self.size_average: 201 | if ignore_label: 202 | ret_loss = torch.div(loss, max(n_valid, 1e-6)) 203 | return ret_loss 204 | else: 205 | ret_loss = torch.div(loss, float(np.prod(label.size()))) 206 | return ret_loss 207 | 208 | return loss 209 | 210 | 211 | class SingleTaskLoss(nn.Module): 212 | def __init__(self, loss_ft, task): 213 | super(SingleTaskLoss, 
self).__init__() 214 | self.loss_ft = loss_ft 215 | self.task = task 216 | 217 | def forward(self, pred, gt): 218 | out = {self.task: self.loss_ft(pred[self.task], gt[self.task])} 219 | out['total'] = out[self.task] 220 | return out 221 | 222 | 223 | class MultiTaskLoss(nn.Module): 224 | def __init__(self, tasks: list, loss_ft: nn.ModuleDict, loss_weights: dict): 225 | super(MultiTaskLoss, self).__init__() 226 | assert (set(tasks) == set(loss_ft.keys())) 227 | assert (set(tasks) == set(loss_weights.keys())) 228 | self.tasks = tasks 229 | self.loss_ft = loss_ft 230 | self.loss_weights = loss_weights 231 | 232 | def forward(self, pred, gt): 233 | out = { 234 | task: self.loss_ft[task](pred[task], gt[task]) for task in self.tasks 235 | } 236 | out['total'] = torch.sum(torch.stack( 237 | [self.loss_weights[t] * out[t] for t in self.tasks])) 238 | return out['total'], out 239 | 240 | 241 | def get_loss(task_cfg, task=None, config={"DATA": {}}): 242 | """ Return loss function for a specific task """ 243 | if task == 'edge': 244 | criterion = BalancedCrossEntropyLoss( 245 | size_average=True, pos_weight=task_cfg.get('edge_w', 0.95)) 246 | 247 | elif task == 'semseg' or task == 'human_parts': 248 | criterion = SoftMaxwithLoss(ignore_index=255) 249 | 250 | elif task == 'normals': 251 | criterion = NormalsLoss(normalize=True, size_average=True, norm=1) 252 | 253 | elif task == 'sal': 254 | criterion = BalancedCrossEntropyLoss(size_average=True) 255 | 256 | elif task == 'depth': 257 | criterion = DepthLoss('l1') 258 | 259 | else: 260 | raise NotImplementedError('Undefined Loss: Choose a task among ' 261 | 'edge, semseg, human_parts, sal, depth, or normals') 262 | 263 | return criterion 264 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # MTLoRA 3 | # GitHub: https://github.com/scale-lab/MTLoRA 4 | # Built upon Swin Transformer (https://github.com/microsoft/Swin-Transformer) 5 | # 6 | # Original file: 7 | # Copyright (c) 2021 Microsoft 8 | # Licensed under the MIT License 9 | # Written by Ze Liu 10 | # 11 | # Modifications: 12 | # Copyright (c) 2024 SCALE Lab, Brown University 13 | # Licensed under the MIT License (see LICENSE for details) 14 | # -------------------------------------------------------- 15 | 16 | 17 | from functools import partial 18 | from torch import optim as optim 19 | 20 | try: 21 | from apex.optimizers import FusedAdam, FusedLAMB 22 | except: 23 | FusedAdam = None 24 | FusedLAMB = None 25 | print("To use FusedLAMB or FusedAdam, please install apex.") 26 | 27 | 28 | def build_optimizer(config, model, simmim=False, is_pretrain=False): 29 | """ 30 | Build optimizer, set weight decay of normalization to 0 by default. 
31 | """ 32 | skip = {} 33 | skip_keywords = {} 34 | if hasattr(model, 'no_weight_decay'): 35 | skip = model.no_weight_decay() 36 | if hasattr(model, 'no_weight_decay_keywords'): 37 | skip_keywords = model.no_weight_decay_keywords() 38 | if simmim: 39 | if is_pretrain: 40 | parameters = get_pretrain_param_groups(model, skip, skip_keywords) 41 | else: 42 | depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' else config.MODEL.SWINV2.DEPTHS 43 | num_layers = sum(depths) 44 | get_layer_func = partial( 45 | get_swin_layer, num_layers=num_layers + 2, depths=depths) 46 | scales = list(config.TRAIN.LAYER_DECAY ** 47 | i for i in reversed(range(num_layers + 2))) 48 | parameters = get_finetune_param_groups( 49 | model, config.TRAIN.BASE_LR, config.TRAIN.WEIGHT_DECAY, get_layer_func, scales, skip, skip_keywords) 50 | else: 51 | parameters = set_weight_decay(model, skip, skip_keywords) 52 | 53 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 54 | optimizer = None 55 | if opt_lower == 'sgd': 56 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 57 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 58 | elif opt_lower == 'adamw': 59 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 60 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 61 | elif opt_lower == 'fused_adam': 62 | optimizer = FusedAdam(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 63 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 64 | elif opt_lower == 'fused_lamb': 65 | optimizer = FusedLAMB(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 66 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 67 | 68 | return optimizer 69 | 70 | 71 | def set_weight_decay(model, skip_list=(), skip_keywords=()): 72 | has_decay = [] 73 | no_decay = [] 74 | 75 | for name, param in model.named_parameters(): 76 | if not param.requires_grad: 77 | continue # frozen weights 78 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 79 | check_keywords_in_name(name, skip_keywords): 80 | no_decay.append(param) 81 | # print(f"{name} has no weight decay") 82 | else: 83 | has_decay.append(param) 84 | return [{'params': has_decay}, 85 | {'params': no_decay, 'weight_decay': 0.}] 86 | 87 | 88 | def check_keywords_in_name(name, keywords=()): 89 | isin = False 90 | for keyword in keywords: 91 | if keyword in name: 92 | isin = True 93 | return isin 94 | 95 | 96 | def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): 97 | has_decay = [] 98 | no_decay = [] 99 | has_decay_name = [] 100 | no_decay_name = [] 101 | 102 | for name, param in model.named_parameters(): 103 | if not param.requires_grad: 104 | continue 105 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 106 | check_keywords_in_name(name, skip_keywords): 107 | no_decay.append(param) 108 | no_decay_name.append(name) 109 | else: 110 | has_decay.append(param) 111 | has_decay_name.append(name) 112 | return [{'params': has_decay}, 113 | {'params': no_decay, 'weight_decay': 0.}] 114 | 115 | 116 | def get_swin_layer(name, num_layers, depths): 117 | if name in ("mask_token"): 118 | return 0 119 | elif name.startswith("patch_embed"): 120 | return 0 121 | elif name.startswith("layers"): 122 | layer_id = int(name.split('.')[1]) 123 | block_id = name.split('.')[3] 124 | if block_id == 'reduction' or block_id == 'norm': 
125 | return sum(depths[:layer_id + 1]) 126 | layer_id = sum(depths[:layer_id]) + int(block_id) 127 | return layer_id + 1 128 | else: 129 | return num_layers - 1 130 | 131 | 132 | def get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=()): 133 | parameter_group_names = {} 134 | parameter_group_vars = {} 135 | 136 | for name, param in model.named_parameters(): 137 | if not param.requires_grad: 138 | continue 139 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 140 | check_keywords_in_name(name, skip_keywords): 141 | group_name = "no_decay" 142 | this_weight_decay = 0. 143 | else: 144 | group_name = "decay" 145 | this_weight_decay = weight_decay 146 | if get_layer_func is not None: 147 | layer_id = get_layer_func(name) 148 | group_name = "layer_%d_%s" % (layer_id, group_name) 149 | else: 150 | layer_id = None 151 | 152 | if group_name not in parameter_group_names: 153 | if scales is not None: 154 | scale = scales[layer_id] 155 | else: 156 | scale = 1. 157 | 158 | parameter_group_names[group_name] = { 159 | "group_name": group_name, 160 | "weight_decay": this_weight_decay, 161 | "params": [], 162 | "lr": lr * scale, 163 | "lr_scale": scale, 164 | } 165 | parameter_group_vars[group_name] = { 166 | "group_name": group_name, 167 | "weight_decay": this_weight_decay, 168 | "params": [], 169 | "lr": lr * scale, 170 | "lr_scale": scale 171 | } 172 | 173 | parameter_group_vars[group_name]["params"].append(param) 174 | parameter_group_names[group_name]["params"].append(name) 175 | return list(parameter_group_vars.values()) 176 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.10 2 | filelock==3.12.0 3 | fsspec==2023.5.0 4 | huggingface-hub==0.15.1 5 | imageio==2.31.0 6 | lazy-loader==0.2 7 | networkx==3.1 8 | opencv-python==4.7.0.72 9 | ptflops==0.7 10 | pywavelets==1.4.1 11 | pyyaml==6.0 12 | safetensors==0.3.1 13 | scikit-image==0.21.0 14 | six==1.16.0 15 | termcolor==2.3.0 16 | tifffile==2023.4.12 17 | timm==0.9.2 18 | tqdm==4.65.0 19 | yacs==0.1.8 20 | mmcls==0.25.0 21 | mmcv==2.0.0 22 | mmengine==0.7.4 23 | mmsegmentation==1.0.0 --------------------------------------------------------------------------------
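
The repository's own training entry point is `main.py`; the sketch below is only meant to make explicit how the modules dumped above fit together: `build_model`/`build_mtl_model` (models/build.py) construct the Swin backbone and the multi-task wrapper, `get_loss`/`MultiTaskLoss` (mtl_loss_schemes.py) build the per-task criteria, and `build_optimizer` (optimizer.py) creates the optimizer. It is an illustrative sketch, not code from the repository. Assumptions not confirmed by the source: `config` is a yacs `CfgNode` loaded from one of the YAML files in `configs/`, `config.TASKS_CONFIG` is an acceptable `task_cfg` argument for `get_loss`, loss weights are uniform, and `train_loader` yields `(image, {task: label})` batches.

```python
# Illustrative sketch only -- not a file from this repository.
# Assumptions (see lead-in): yacs config object, uniform loss weights,
# train_loader yielding (image, {task: label}) pairs.
import torch
import torch.nn as nn

from models.build import build_model, build_mtl_model
from mtl_loss_schemes import MultiTaskLoss, get_loss
from optimizer import build_optimizer


def train_one_epoch(config, train_loader, device="cuda"):
    backbone = build_model(config)                     # Swin encoder (MTLoRA variant if enabled)
    model = build_mtl_model(backbone, config).to(device)

    # One criterion per task, combined with (assumed) uniform weights.
    loss_ft = nn.ModuleDict({task: get_loss(config.TASKS_CONFIG, task)
                             for task in config.TASKS})
    loss_weights = {task: 1.0 for task in config.TASKS}
    criterion = MultiTaskLoss(config.TASKS, loss_ft, loss_weights)

    optimizer = build_optimizer(config, model)

    model.train()
    for images, targets in train_loader:
        images = images.to(device, non_blocking=True)
        targets = {t: v.to(device, non_blocking=True) for t, v in targets.items()}

        preds = model(images)                          # {task: prediction}
        total_loss, per_task_losses = criterion(preds, targets)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
```

Note that `MultiTaskSwin.forward` returns a dictionary keyed by task name, and `MultiTaskLoss.forward` returns both the weighted total loss and the per-task loss dictionary, which is why the sketch unpacks two values from `criterion(...)`.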