├── script ├── ignore.txt ├── model │ ├── __init__.py │ ├── conditional_modules.py │ └── swin_transformer_parallel.py ├── data │ ├── taskonomy │ │ ├── __init__.py │ │ ├── metadata │ │ │ ├── train_val_test_debug.csv │ │ │ └── train_val_test_tiny.csv │ │ ├── splits.py │ │ ├── task_configs.py │ │ ├── transforms.py │ │ └── taskonomy_dataset_s3.py │ ├── nyuv2.py │ └── nyuv2_same_batch.py ├── requirements.txt ├── loss │ ├── __pycache__ │ │ ├── losses.cpython-38.pyc │ │ └── metrics.cpython-38.pyc │ ├── losses.py │ └── metrics.py ├── evaluate.py ├── train_nyu.py ├── train_nyu_single_task.py └── train_taskonomy.py ├── avtar.gif ├── docs ├── code_icon.png ├── troa-new.png ├── youtube.png ├── nyu-qr-new.pdf ├── nyu-qr-new.png ├── TAA-finalised.png ├── film-finalised.pdf ├── film-finalised.png ├── teaser-iccv1.png ├── pdf_icon_32x32.jpeg ├── uda-results-new.pdf ├── uda-results-new.png ├── ICCV-presentation.pdf ├── bibtex_icon_36x36.png ├── presentation_icon.png ├── vision-adapter-new.pdf ├── vision-adapter-new.png ├── qualitative-taskonomy.png ├── overall-vision-adapter-architecture.png ├── offcanvas.css └── index.html ├── README.md ├── LICENSE └── .gitignore /script/ignore.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /script/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /avtar.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/avtar.gif -------------------------------------------------------------------------------- /docs/code_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/code_icon.png -------------------------------------------------------------------------------- /docs/troa-new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/troa-new.png -------------------------------------------------------------------------------- /docs/youtube.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/youtube.png -------------------------------------------------------------------------------- /docs/nyu-qr-new.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/nyu-qr-new.pdf -------------------------------------------------------------------------------- /docs/nyu-qr-new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/nyu-qr-new.png -------------------------------------------------------------------------------- /script/data/taskonomy/__init__.py: -------------------------------------------------------------------------------- 1 | from .taskonomy_dataset_s3 import TaskonomyDatasetS3 -------------------------------------------------------------------------------- /docs/TAA-finalised.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/TAA-finalised.png 
-------------------------------------------------------------------------------- /docs/film-finalised.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/film-finalised.pdf -------------------------------------------------------------------------------- /docs/film-finalised.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/film-finalised.png -------------------------------------------------------------------------------- /docs/teaser-iccv1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/teaser-iccv1.png -------------------------------------------------------------------------------- /docs/pdf_icon_32x32.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/pdf_icon_32x32.jpeg -------------------------------------------------------------------------------- /docs/uda-results-new.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/uda-results-new.pdf -------------------------------------------------------------------------------- /docs/uda-results-new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/uda-results-new.png -------------------------------------------------------------------------------- /script/data/taskonomy/metadata/train_val_test_debug.csv: -------------------------------------------------------------------------------- 1 | id,train,val,test 2 | allensville,1,1,1 -------------------------------------------------------------------------------- /docs/ICCV-presentation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/ICCV-presentation.pdf -------------------------------------------------------------------------------- /docs/bibtex_icon_36x36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/bibtex_icon_36x36.png -------------------------------------------------------------------------------- /docs/presentation_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/presentation_icon.png -------------------------------------------------------------------------------- /docs/vision-adapter-new.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/vision-adapter-new.pdf -------------------------------------------------------------------------------- /docs/vision-adapter-new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/vision-adapter-new.png -------------------------------------------------------------------------------- /script/requirements.txt: -------------------------------------------------------------------------------- 1 | timm 2 | einops 3 | torchmetrics 4 | tensorboard 5 | transformers 6 | boto3 -------------------------------------------------------------------------------- 
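The packages listed in requirements.txt cover the model, data, and logging stack. In addition, the Taskonomy loader further down (script/data/taskonomy/taskonomy_dataset_s3.py) expects its S3 endpoint and credentials to be set as environment variables; the variable names below are taken from that loader's docstring, while the pre-flight check itself is only an illustrative sketch, not part of the repository.

```python
import os

# Environment variables read by TaskonomyDatasetS3 (see its docstring further down).
# This check is an illustrative addition, not code from the repository.
REQUIRED_S3_VARS = (
    "S3_ENDPOINT",
    "S3_TASKONOMY_ACCESS",
    "S3_TASKONOMY_KEY",
    "S3_TASKONOMY_BUCKET",
)

missing = [name for name in REQUIRED_S3_VARS if not os.environ.get(name)]
if missing:
    raise RuntimeError(f"Missing S3 configuration: {', '.join(missing)}")
```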
/docs/qualitative-taskonomy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/qualitative-taskonomy.png -------------------------------------------------------------------------------- /docs/overall-vision-adapter-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/overall-vision-adapter-architecture.png -------------------------------------------------------------------------------- /script/loss/__pycache__/losses.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/script/loss/__pycache__/losses.cpython-38.pyc -------------------------------------------------------------------------------- /script/loss/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/script/loss/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Vision Transformer Adapters for Generalizable Multitask Learning 3 | Deblina Bhattacharjee, Sabine Süsstrunk, and Mathieu Salzmann 4 | [![DOI](https://zenodo.org/badge/680203400.svg)](https://zenodo.org/doi/10.5281/zenodo.11067070) 5 | 6 | ICCV 2023 Paper: https://arxiv.org/abs/2308.12372 7 | 8 | https://ivrl.github.io/VTAGML/ 9 | ![Figure Abstract](avtar.gif) 10 | 11 | 12 | -------------------------------------------------------------------------------- /script/data/taskonomy/metadata/train_val_test_tiny.csv: -------------------------------------------------------------------------------- 1 | id,train,val,test hanson,1,0,0 merom,1,0,0 klickitat,1,0,0 onaga,1,0,0 leonardo,1,0,0 marstons,1,0,0 newfields,1,0,0 pinesdale,1,0,0 lakeville,1,0,0 cosmos,1,0,0 benevolence,1,0,0 pomaria,1,0,0 tolstoy,1,0,0 shelbyville,1,0,0 allensville,1,0,0 wainscott,1,0,0 beechwood,1,0,0 coffeen,1,0,0 stockman,1,0,0 hiteman,1,0,0 woodbine,1,0,0 lindenwood,1,0,0 forkland,1,0,0 mifflinburg,1,0,0 ranchester,1,0,0 wiconisco,0,1,0 corozal,0,1,0 collierville,0,1,0 markleeville,0,1,0 darden,0,1,0 ihlen,0,0,1 muleshoe,0,0,1 uvalda,0,0,1 noxapater,0,0,1 mcdade,0,0,1 -------------------------------------------------------------------------------- /script/data/taskonomy/splits.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | 4 | 5 | def get_splits(split_path, forbidden_buildings=[]): 6 | with open(split_path) as csvfile: 7 | readCSV = csv.reader(csvfile, delimiter=',') 8 | 9 | train_list = [] 10 | val_list = [] 11 | test_list = [] 12 | 13 | for row in readCSV: 14 | name, is_train, is_val, is_test = row 15 | if name in forbidden_buildings: 16 | continue 17 | if is_train == '1': 18 | train_list.append(name) 19 | if is_val == '1': 20 | val_list.append(name) 21 | if is_test == '1': 22 | test_list.append(name) 23 | return { 24 | 'train': sorted(train_list), 25 | 'val': sorted(val_list), 26 | 'test': sorted(test_list) 27 | } 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Images and Visual 
Representation Laboratory (IVRL) at EPFL 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /script/loss/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torch.cuda.amp as amp 6 | 7 | class berHuLoss(nn.Module): 8 | def __init__(self): 9 | """ 10 | https://github.com/lhoyer/improving_segmentation_with_selfsupervised_depth/ 11 | """ 12 | super(berHuLoss, self).__init__() 13 | 14 | 15 | def make_valid_mask(self, tens, mask_val, conf=1e-7): 16 | 17 | valid_mask = (tens > (mask_val+conf) ) | (tens < (mask_val-conf)) 18 | 19 | return valid_mask 20 | 21 | 22 | def forward(self, inp, target, apply_log=False, threshold=.2, mask_val=None): 23 | if apply_log: 24 | inp, target = torch.log(1 + inp), torch.log(1 + target) 25 | 26 | if mask_val is None: 27 | valid_mask = (target > 0).detach() 28 | else: 29 | valid_mask = self.make_valid_mask(target, mask_val) 30 | 31 | absdiff = torch.abs(target - inp) * valid_mask #* mask 32 | C = threshold * torch.max(absdiff).item() 33 | loss = torch.mean(torch.where(absdiff <= C, 34 | absdiff, 35 | (absdiff * absdiff + C * C) / (2 * C))) 36 | return loss -------------------------------------------------------------------------------- /docs/offcanvas.css: -------------------------------------------------------------------------------- 1 | /* 2 | * Style tweaks 3 | * -------------------------------------------------- 4 | */ 5 | html, 6 | body { 7 | overflow-x: hidden; /*Prevent scroll on narrow devices */ 8 | padding-top: 30px; 9 | text-align: justify; 10 | } 11 | footer { 12 | padding: 10px 0; 13 | } 14 | .authors { 15 | font-size: 20px; 16 | } 17 | /*.container { 18 | max-width: 768px; 19 | }*/ 20 | .container { 21 | max-width: 1000px; 22 | } 23 | p { 24 | font-size: 16px; 25 | /*padding-bottom: 20px;*/ 26 | } 27 | 28 | li { 29 | font-size: 16px; 30 | } 31 | 32 | h2 { 33 | text-align: center; 34 | align: center; 35 | } 36 | 37 | .jumbotron{ 38 | text-align: center; 39 | } 40 | 41 | .btn { 42 | font-size: 18px; 43 | } 44 | 45 | .btn-disabled { 46 | /*background-color: #f4f4f4;*/ 47 | } 48 | 49 | .jumbotron h2 { 50 | font-size: 36px; 51 | } 52 | 53 | .section { 54 | padding-top: 30px; 55 | } 56 | 57 | .center{ 58 | display: block; 59 | margin-left: auto; 60 | margin-right: auto; 61 | } 62 | 63 | .vcontainer { 64 | 
position: relative; 65 | width: 100%; 66 | height: 0; 67 | padding-bottom: 56.25%; 68 | } 69 | .video { 70 | position: absolute; 71 | top: 0; 72 | left: 0; 73 | width: 100%; 74 | height: 100%; 75 | } 76 | 77 | .gif { 78 | padding:10px; 79 | display: block; 80 | margin-left: auto; 81 | margin-right: auto; 82 | text-align: center; 83 | } 84 | 85 | .caption { 86 | width:75%; 87 | font-size:14px 88 | } 89 | 90 | .bibtexsection { 91 | font-family: "Courier",monospace; 92 | font-size:16px; 93 | white-space:pre; 94 | background-color: #f4f4f4; 95 | text-align:left; 96 | } 97 | 98 | .canvas-row canvas { 99 | max-width:100%; 100 | } 101 | 102 | .padding-0{ 103 | padding-right:0; 104 | padding-left:0; 105 | } 106 | 107 | .vspace-top { 108 | margin-top: 30px; 109 | } 110 | -------------------------------------------------------------------------------- /script/loss/metrics.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | 7 | def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor): 8 | """ 9 | Taken from: 10 | https://www.kaggle.com/iezepov/fast-iou-scoring-metric-in-pytorch-and-numpy/comments 11 | """ 12 | 13 | SMOOTH = 1e-6 14 | # You can comment out this line if you are passing tensors of equal shape 15 | # But if you are passing output from UNet or something it will most probably 16 | # be with the BATCH x 1 x H x W shape 17 | outputs = outputs.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W 18 | labels = labels.squeeze(1) 19 | 20 | intersection = (outputs & labels).float().sum( 21 | (1, 2)) # Will be zero if Truth=0 or Prediction=0 22 | 23 | union = (outputs | labels).float().sum( 24 | (1, 2)) # Will be zzero if both are 0 25 | 26 | # We smooth our devision to avoid 0/0 27 | iou = (intersection + SMOOTH) / (union + SMOOTH) 28 | 29 | # This is equal to comparing with thresolds 30 | thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10 31 | 32 | # Or thresholded.mean() if you are interested in average across the batch 33 | return thresholded.mean() 34 | 35 | 36 | 37 | def eval_depth(pred, target): 38 | 39 | """ 40 | Taken from: 41 | https://github.com/wl-zhao/VPD/blob/main/depth/utils_depth/metrics.py 42 | """ 43 | 44 | rmse_temp = 0 45 | d1_temp = 0 46 | 47 | for current_target, current_pred in zip(target, pred): 48 | ##assert current_gt_sparse.shape == current_pred.shape 49 | 50 | thresh = torch.max((current_target / current_pred), (current_pred / current_target)) 51 | 52 | d1 = (thresh < 1.25).float().mean()#torch.sum(thresh < 1.25).float().mean()# / len(thresh) 53 | #d2 = torch.sum(thresh < 1.25 ** 2).float() / len(thresh) 54 | #d3 = torch.sum(thresh < 1.25 ** 3).float() / len(thresh) 55 | 56 | diff = current_pred - current_target 57 | diff_log = torch.log(current_pred) - torch.log(current_target) 58 | 59 | #abs_rel = torch.mean(torch.abs(diff) / target) 60 | #sq_rel = torch.mean(torch.pow(diff, 2) / target) 61 | 62 | rmse = torch.sqrt(torch.mean(torch.pow(diff, 2))) 63 | rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log , 2))) 64 | 65 | #log10 = torch.mean(torch.abs(torch.log10(pred) - torch.log10(target))) 66 | #silog = torch.sqrt(torch.pow(diff_log, 2).mean() - 0.5 * torch.pow(diff_log.mean(), 2)) 67 | 68 | #return {'d1': d1.item(), 'd2': d2.item(), 'd3': d3.item(), 'abs_rel': abs_rel.item(), 69 | # 'sq_rel': sq_rel.item(), 'rmse': rmse.item(), 'rmse_log': rmse_log.item(), 70 | # 'log10':log10.item(), 'silog':silog.item()} 71 | 72 | rmse_temp += rmse 73 | 
d1_temp += d1 74 | 75 | return {'d1': d1_temp.item()/len(pred),'rmse': rmse_temp.item()/len(pred)} 76 | -------------------------------------------------------------------------------- /script/data/taskonomy/task_configs.py: -------------------------------------------------------------------------------- 1 | #################### 2 | # Tasks 3 | #################### 4 | import torch 5 | 6 | 7 | task_parameters = { 8 | 'class_object':{ 9 | 'num_classes': 1000, 10 | 'ext': 'npy', 11 | 'domain_id': 'class_object', 12 | }, 13 | 'class_scene':{ 14 | 'num_classes': 365, 15 | 'ext': 'npy', 16 | 'domain_id': 'class_scene', 17 | }, 18 | 'depth_zbuffer':{ 19 | 'num_channels': 1, 20 | 'mask_val': 1.0, 21 | 'clamp_to': (0.0, 8000.0 / (2**16 - 1)), # Same as consistency 22 | 'ext': 'png', 23 | 'domain_id': 'depth_zbuffer', 24 | }, 25 | 'depth_euclidean':{ 26 | 'num_channels': 1, 27 | 'clamp_to': (0.0, 8000.0 / (2**16 - 1)), # Same as consistency 28 | # 'mask_val': 1.0, 29 | 'ext': 'png', 30 | 'domain_id': 'depth_euclidean', 31 | }, 32 | 'edge_texture': { 33 | 'num_channels': 1, 34 | 'clamp_to': (0.0, 0.25), 35 | 'ext': 'png', 36 | 'domain_id': 'edge_texture', 37 | }, 38 | 'edge_occlusion': { 39 | 'num_channels': 1, 40 | 'ext': 'png', 41 | 'domain_id': 'edge_occlusion', 42 | }, 43 | 'keypoints3d': { 44 | 'num_channels': 1, 45 | 'ext': 'png', 46 | 'domain_id': 'keypoints3d', 47 | }, 48 | 'keypoints2d':{ 49 | 'num_channels': 1, 50 | 'ext': 'png', 51 | 'domain_id': 'keypoints2d', 52 | }, 53 | 'principal_curvature':{ 54 | 'num_channels': 3, 55 | 'mask_val': 0.0, 56 | 'ext': 'png', 57 | 'domain_id': 'principal_curvature', 58 | }, 59 | 'reshading':{ 60 | 'num_channels': 1, 61 | 'ext': 'png', 62 | 'domain_id': 'reshading', 63 | }, 64 | 'normal':{ 65 | 'num_channels': 3, 66 | 'mask_val': 0.502, 67 | 'ext': 'png', 68 | 'domain_id': 'normal', 69 | }, 70 | 'mask_valid':{ 71 | 'num_channels': 1, 72 | 'mask_val': 0.0, 73 | 'ext': 'png', 74 | 'domain_id': 'depth_zbuffer', 75 | }, 76 | 'rgb':{ 77 | 'num_channels': 3, 78 | 'ext': 'png', 79 | 'domain_id': 'rgb', 80 | }, 81 | 'segment_semantic': { 82 | 'num_channels': 18, 83 | 'ext': 'png', 84 | 'domain_id': 'segmentsemantic', 85 | }, 86 | 'segment_unsup2d':{ 87 | 'num_channels': 64, 88 | 'ext': 'png', 89 | 'domain_id': 'segment_unsup2d', 90 | }, 91 | 'segment_unsup25d':{ 92 | 'num_channels': 64, 93 | 'ext': 'png', 94 | 'domain_id': 'segment_unsup25d', 95 | }, 96 | } 97 | 98 | 99 | PIX_TO_PIX_TASKS = ['colorization', 'edge_texture', 'edge_occlusion', 'keypoints3d', 'keypoints2d', 'reshading', 'depth_zbuffer', 'depth_euclidean', 'curvature', 'autoencoding', 'denoising', 'normal', 'inpainting', 'segment_unsup2d', 'segment_unsup25d', 'segment_semantic', ] 100 | FEED_FORWARD_TASKS = ['class_object', 'class_scene', 'room_layout', 'vanishing_point'] 101 | SINGLE_IMAGE_TASKS = PIX_TO_PIX_TASKS + FEED_FORWARD_TASKS 102 | SIAMESE_TASKS = ['fix_pose', 'jigsaw', 'ego_motion', 'point_match', 'non_fixated_pose'] 103 | 104 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 
| MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /script/data/taskonomy/transforms.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Optional 9 | 10 | from .task_configs import task_parameters 11 | 12 | 13 | MAKE_RESCALE_0_1_NEG1_POS1 = lambda n_chan: transforms.Normalize([0.5]*n_chan, [0.5]*n_chan) 14 | RESCALE_0_1_NEG1_POS1 = transforms.Normalize([0.5], [0.5]) # This needs to be different depending on num out chans 15 | MAKE_RESCALE_0_MAX_NEG1_POS1 = lambda maxx: transforms.Normalize([maxx / 2.], [maxx * 1.0]) 16 | RESCALE_0_255_NEG1_POS1 = transforms.Normalize([127.5,127.5,127.5], [255, 255, 255]) 17 | MAKE_RESCALE_0_MAX_0_POS1 = lambda maxx: transforms.Normalize([0.0], [maxx * 1.0]) 18 | 19 | # For semantic segmentation 20 | transform_dense_labels = lambda img: torch.Tensor(np.array(img)).long() # avoids normalizing 21 | 22 | # Transforms to a 3-channel tensor and then changes [0,1] -> [0, 1] 23 | transform_8bit = transforms.Compose([ 24 | transforms.ToTensor(), 25 | ]) 26 | 27 | # Transforms to a n-channel tensor and then changes [0,1] -> [0, 1]. Keeps only the first n-channels 28 | def transform_8bit_n_channel(n_channel=1, crop_channels=True): 29 | if crop_channels: 30 | crop_channels_fn = lambda x: x[:n_channel] if x.shape[0] > n_channel else x 31 | else: 32 | crop_channels_fn = lambda x: x 33 | return transforms.Compose([ 34 | transforms.ToTensor(), 35 | crop_channels_fn, 36 | ]) 37 | 38 | # Transforms to a 1-channel tensor and then changes [0,1] -> [0, 1]. 39 | def transform_16bit_single_channel(im): 40 | im = transforms.ToTensor()(np.array(im)) 41 | im = im.float() / (2 ** 16 - 1.0) 42 | return im 43 | 44 | def make_valid_mask(mask_float, max_pool_size=4): 45 | ''' 46 | Creates a mask indicating the valid parts of the image(s). 47 | Enlargens masked area using a max pooling operation. 48 | 49 | Args: 50 | mask_float: A (b x c x h x w) mask as loaded from the Taskonomy loader. 51 | max_pool_size: Parameter to choose how much to enlarge masked area. 
52 | ''' 53 | squeeze = False 54 | if len(mask_float.shape) == 3: 55 | mask_float = mask_float.unsqueeze(0) 56 | squeeze = True 57 | _, _, h, w = mask_float.shape 58 | mask_float = 1 - mask_float 59 | mask_float = F.max_pool2d(mask_float, kernel_size=max_pool_size) 60 | mask_float = F.interpolate(mask_float, (h, w), mode='nearest') 61 | mask_valid = mask_float == 0 62 | mask_valid = mask_valid[0] if squeeze else mask_valid 63 | return mask_valid 64 | 65 | 66 | def task_transform(file, task: str, image_size=Optional[int]): 67 | transform = None 68 | 69 | if task in ['rgb', 'normal']: 70 | transform = transform_8bit 71 | elif task in ['mask_valid']: 72 | transform = transforms.Compose([ 73 | transforms.ToTensor(), 74 | make_valid_mask 75 | ]) 76 | elif task in ['keypoints2d', 'keypoints3d', 'depth_euclidean', 'depth_zbuffer', 'edge_texture', 'edge_occlusion']: 77 | #transform = transform_16bit_single_channel 78 | transform = transforms.Compose([ 79 | transforms.ToTensor() 80 | ]) 81 | elif task in ['principal_curvature', 'curvature']: 82 | transform = transform_8bit_n_channel(2) 83 | elif task in ['reshading']: 84 | transform = transform_8bit_n_channel(1) 85 | elif task in ['segment_semantic', 'segment_instance', 'segment_panoptic', 'fragments', 'segment_unsup2d', 'segment_unsup25d']: # this is stored as 1 channel image (H,W) where each pixel value is a different class 86 | transform = transform_dense_labels 87 | elif task in ['class_object', 'class_scene']: 88 | transform = torch.Tensor 89 | image_size = None 90 | else: 91 | transform = lambda x: x 92 | 93 | """if 'clamp_to' in task_parameters[task]: 94 | minn, maxx = task_parameters[task]['clamp_to'] 95 | if minn > 0: 96 | raise NotImplementedError("Rescaling (min1, max1) -> (min2, max2) not implemented for min1, min2 != 0 (task {})".format(task)) 97 | transform = transforms.Compose([ 98 | transform, 99 | MAKE_RESCALE_0_MAX_0_POS1(maxx) 100 | ])""" 101 | 102 | 103 | if image_size is not None: 104 | if task == 'fragments': 105 | resize_frag = lambda frag: F.interpolate(frag.permute(2,0,1).unsqueeze(0).float(), image_size, mode='nearest').long()[0].permute(1,2,0) 106 | transform = transforms.Compose([ 107 | transform, 108 | resize_frag 109 | ]) 110 | else: 111 | resize_method = Image.BILINEAR if task in ['rgb'] else Image.NEAREST 112 | transform = transforms.Compose([ 113 | transforms.Resize(image_size, resize_method), 114 | transform 115 | ]) 116 | 117 | 118 | if transform is not None: 119 | file = transform(file) 120 | 121 | return file 122 | -------------------------------------------------------------------------------- /script/data/taskonomy/taskonomy_dataset_s3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import boto3 4 | import json 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from torch.utils.data import Dataset 9 | from PIL import Image 10 | #from PIL import ImageFile 11 | #ImageFile.LOAD_TRUNCATED_IMAGES = True # TODO: Fix these images and then remove this 12 | 13 | 14 | from .task_configs import task_parameters 15 | from .transforms import task_transform 16 | from .splits import get_splits 17 | 18 | 19 | filter_amount_dict = {'low': 10000, 'medium': 50000, 'high': 100000} 20 | forbidden_buildings = [ 21 | 'mosquito', 'tansboro', 'tomkins', 'darnestown', 'brinnon', # We do not have the rgb data for tomkins, darnestown, brinnon 22 | 'rough', 'grace', 'wiconisco' # Contain some wrong viewpoints 23 | ] 24 | 25 | 26 | class 
TaskonomyDatasetS3(Dataset): 27 | def __init__(self, 28 | tasks, 29 | split='train', 30 | variant='fullplus', 31 | rm_incomplete=True, 32 | image_size=256, 33 | max_images=None, 34 | seed=0, 35 | filter_amount='medium'): 36 | ''' 37 | Taskonomy EPFL-S3 dataloader. 38 | Make sure the environment variables S3_ENDPOINT, S3_TASKONOMY_ACCESS, 39 | S3_TASKONOMY_KEY, and S3_TASKONOMY_BUCKET are set. 40 | 41 | Args: 42 | tasks: List of tasks 43 | split: One of {'train', 'val', 'test', 'all'} 44 | variant: One of {'debug', 'tiny', 'medium', 'full', 'fullplus'} 45 | rm_incomplete: Set to True to only keep samples that have every task 46 | image_size: Target image size 47 | max_images: Optional subset selection 48 | seed: Random seed for deterministic shuffling order 49 | filter_amount: How many "bad" images to remove. One of {'low', 'medium', 'high'}. 50 | ''' 51 | super(TaskonomyDatasetS3, self).__init__() 52 | self.tasks = tasks 53 | self.split = split 54 | self.variant = variant 55 | self.rm_incomplete = rm_incomplete 56 | self.image_size=image_size 57 | self.max_images = max_images 58 | self.seed = seed 59 | self.filter_amount = filter_amount 60 | 61 | # S3 bucket setup 62 | self.session = boto3.session.Session() 63 | self.s3_client = self.session.client( 64 | service_name='s3', 65 | aws_access_key_id=os.environ.get('S3_TASKONOMY_ACCESS'), 66 | aws_secret_access_key=os.environ.get('S3_TASKONOMY_KEY'), 67 | endpoint_url=os.environ.get('S3_ENDPOINT') 68 | ) 69 | self.bucket_name = os.environ.get('S3_TASKONOMY_BUCKET') 70 | 71 | # DataFrame containing information whether or not any file for any task exists 72 | self.df_meta = pd.read_pickle(os.path.join(os.path.dirname(__file__), 'metadata', 'taskonomy_files.pkl.gz')) 73 | 74 | # Select splits based on selected size/variant 75 | splits = get_splits( 76 | os.path.join(os.path.dirname(__file__), 'metadata', f'train_val_test_{variant}.csv'), 77 | forbidden_buildings=forbidden_buildings 78 | ) 79 | if split == 'all': 80 | self.buildings = list(set(splits['train']) | set(splits['val']) | set(splits['test'])) 81 | else: 82 | self.buildings = splits[split] 83 | self.buildings = sorted(self.buildings) 84 | self.df_meta = self.df_meta.loc[self.buildings] 85 | 86 | # Filter bad images 87 | df_filter = pd.read_pickle(os.path.join(os.path.dirname(__file__), 'metadata', 'taskonomy_filter_scores.pkl.gz')) 88 | df_filter = df_filter[:filter_amount_dict[filter_amount]] 89 | filtered_indices = self.df_meta.index.difference(df_filter.index) 90 | self.df_meta = self.df_meta.loc[filtered_indices] 91 | 92 | self.df_meta = self.df_meta[tasks] # Select tasks of interest 93 | if rm_incomplete: 94 | # Only select rows where we have all the tasks 95 | self.df_meta = self.df_meta[self.df_meta.all(axis=1)] 96 | self.df_meta = self.df_meta.sample(frac=1, random_state=seed) # Random shuffle 97 | self.df_meta = self.df_meta[:max_images] if max_images is not None else self.df_meta # Select subset if so desired 98 | 99 | print(f'Using {len(self.df_meta)} images from variant {self.variant} in split {self.split}.') 100 | 101 | 102 | def __len__(self): 103 | return len(self.df_meta) 104 | 105 | def __getitem__(self, index): 106 | 107 | # building / point / view are encoded in dataframe index 108 | building, point, view = building, point, view = self.df_meta.iloc[index].name 109 | # TODO: Remove this try/except after we made sure there are no bad/missing images! 110 | # Very slow if it fails. 
111 | try: 112 | 113 | result = {} 114 | for task in self.tasks: 115 | # Load from S3 bucket 116 | ext = task_parameters[task]['ext'] 117 | domain_id = task_parameters[task]['domain_id'] 118 | key = f'taskonomy_imgs/{task}/{building}/point_{point}_view_{view}_domain_{domain_id}.{ext}' 119 | obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)['Body'].read() 120 | 121 | # Convert bytes to image / json / array / etc... 122 | if ext == 'png': 123 | file = Image.open(io.BytesIO(obj)) 124 | elif ext == 'json': 125 | file = json.load(io.BytesIO(obj)) 126 | if task == 'point_info': 127 | file['building'] = building 128 | file.pop('nonfixated_points_in_view') 129 | elif ext == 'npy': 130 | file = np.frombuffer(obj) 131 | else: 132 | raise NotImplementedError(f'Loading extension {ext} not yet implemented') 133 | 134 | # Perform transformations 135 | file = task_transform(file, task=task, image_size=self.image_size) 136 | 137 | result[task] = file 138 | 139 | return torch.stack([result[self.tasks[0]],result[self.tasks[0]]]), torch.stack([result[t].view(-1,self.image_size,self.image_size) for i,t in enumerate(self.tasks) if i!=0] ),torch.LongTensor([i for i in range(len(self.tasks)-1)]) 140 | 141 | 142 | except Exception as e : 143 | # In case image was faulty or not uploaded yet, try with random other image 144 | 145 | return self[np.random.randint(len(self))] 146 | -------------------------------------------------------------------------------- /script/model/conditional_modules.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numbers 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class FiLM(nn.Module): 9 | """ Feature-wise Linear Modulation (FiLM) layer""" 10 | def __init__(self, input_size, output_size, num_film_layers=1, layer_norm=False): 11 | """ 12 | :param input_size: feature size of x_cond 13 | :param output_size: feature size of x_to_film 14 | :param layer_norm: true or false 15 | """ 16 | super(FiLM, self).__init__() 17 | self.input_size = input_size 18 | self.output_size = output_size 19 | self.num_film_layers = num_film_layers 20 | self.layer_norm = nn.LayerNorm(output_size) if layer_norm else None 21 | film_output_size = self.output_size * num_film_layers * 2 22 | self.gb_weights = nn.Linear(self.input_size, film_output_size) 23 | self.gb_weights.bias.data.fill_(0) 24 | 25 | def forward(self, x_cond, x_to_film): 26 | gb = self.gb_weights(x_cond).unsqueeze(1) 27 | gamma, beta = torch.chunk(gb, 2, dim=-1) 28 | out = (1 + gamma) * x_to_film + beta 29 | if self.layer_norm is not None: 30 | out = self.layer_norm(out) 31 | return out 32 | 33 | 34 | class TAA(nn.Module): 35 | """ Task Adapted Attention layer""" 36 | def __init__(self, input_size, output_size, blocks=1, num_film_layers=1, layer_norm=False): 37 | """ 38 | :param input_size: feature size of x_cond 39 | :param output_size: feature size of x_to_film 40 | :param layer_norm: true or false 41 | """ 42 | super(TAA, self).__init__() 43 | self.input_size = input_size 44 | self.output_size = output_size 45 | self.num_film_layers = num_film_layers 46 | self.layer_norm = nn.LayerNorm(output_size) if layer_norm else None 47 | self.blocks = blocks 48 | film_output_size = self.output_size * num_film_layers * 2 49 | self.gb_weights = nn.Linear(self.input_size, film_output_size) 50 | self.gb_weights.bias.data.fill_(0) 51 | 52 | def forward(self, x_cond, x_to_film): 53 | """gb = self.gb_weights(x_cond).unsqueeze(1) 54 | gamma, beta = 
torch.chunk(gb, 2, dim=-1) 55 | out = (1 + gamma) * x_to_film + beta 56 | """ 57 | 58 | gb = self.gb_weights(x_cond).unsqueeze(1) 59 | 60 | gamma, beta = torch.chunk(gb, 2, dim=-1) 61 | 62 | out = (1 + gamma) * x_to_film + beta 63 | 64 | 65 | if self.layer_norm is not None: 66 | out = self.layer_norm(out) 67 | out = [torch.block_diag(*list(out_b.chunk(self.blocks, 0))) for out_b in out] 68 | out = torch.stack(out) 69 | return out[:, :, :out.size(1)] 70 | 71 | 72 | class TaskScaledNorm(nn.Module): 73 | r"""Applies Task Scaled Normalization over a mini-batch of inputs. 74 | 75 | .. math:: 76 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma(z) + \beta(z) 77 | 78 | The mean and standard-deviation are calculated separately over the last 79 | certain number dimensions which have to be of the shape specified by 80 | :attr:`normalized_shape`. 81 | :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of 82 | :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. 83 | 84 | .. note:: 85 | Unlike Batch Normalization and Instance Normalization, which applies 86 | scalar scale and bias for each entire channel/plane with the 87 | :attr:`affine`, Layer Normalization applies per-element scale and 88 | bias with :attr:`elementwise_affine`. 89 | 90 | This layer uses statistics computed from input data in both training and 91 | evaluation modes. The affine transformation is modulated by a task scaled tensor. 92 | In our case, we use task embeddings. 93 | 94 | Args: 95 | normalized_shape (int or list or torch.Size): input shape from an expected input 96 | of size 97 | 98 | .. math:: 99 | [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] 100 | \times \ldots \times \text{normalized\_shape}[-1]] 101 | 102 | If a single integer is used, it is treated as a singleton list, and this module will 103 | normalize over the last dimension which is expected to be of that specific size. 104 | eps: a value added to the denominator for numerical stability. Default: 1e-5 105 | elementwise_affine: a boolean value that when set to ``True``, this module 106 | has learnable per-element affine parameters initialized to ones (for weights) 107 | and zeros (for biases). Default: ``True``. 
108 | 109 | Shape: 110 | - Input: :math:`(N, *)` 111 | - Output: :math:`(N, *)` (same shape as input) 112 | 113 | Examples:: 114 | 115 | >>> input_ = torch.randn(20, 5, 10, 10) 116 | >>> condition = torch.randn(20, 10) 117 | >>> # With Learnable Parameters 118 | >>> m = TaskScaledNorm([10, 10]) 119 | >>> # Normalize over last dimension of size 10 120 | >>> m = nn.LayerNorm(10) 121 | >>> # Activating the module 122 | >>> output = m(input_, condition) 123 | 124 | """ 125 | __constants__ = ['normalized_shape', 'condition_size', 'weight', 'bias', 'eps'] 126 | 127 | def __init__(self, normalized_shape, condition_size, eps=1e-5): 128 | super(TaskScaledNorm, self).__init__() 129 | if isinstance(normalized_shape, numbers.Integral): 130 | normalized_shape = (normalized_shape,) 131 | self.normalized_shape = tuple(normalized_shape) 132 | 133 | self.condition_size = condition_size 134 | self.eps = eps 135 | 136 | self.weight = nn.Parameter(torch.Tensor(*normalized_shape)) 137 | self.ln_weight_modulation = FiLM(condition_size, sum(normalized_shape)) 138 | self.bias = nn.Parameter(torch.Tensor(*normalized_shape)) 139 | self.reset_parameters() 140 | 141 | def reset_parameters(self): 142 | nn.init.ones_(self.weight) 143 | nn.init.zeros_(self.bias) 144 | 145 | def forward(self, input_, condition, task_id): 146 | unique_task_ids = torch.unique(task_id) 147 | cln_output = torch.zeros_like(input_) 148 | for unique_task_id in unique_task_ids: 149 | task_id_filter = task_id == unique_task_id 150 | task_emb = condition[task_id_filter][0].unsqueeze(0) 151 | weight = self.ln_weight_modulation(task_emb, self.weight).view(-1) 152 | cln_output[task_id_filter] = F.layer_norm(input_[task_id_filter], self.normalized_shape, weight, self.bias, self.eps) 153 | return cln_output 154 | 155 | def extra_repr(self): 156 | return '{normalized_shape}, {condition_size}, eps={eps}'.format(**self.__dict__) 157 | 158 | 159 | class ConditionalBottleNeck(nn.Module): 160 | """Down projection and up projection with FiLM layers within Transformer layer.""" 161 | def __init__(self, hidden_size, output_size): 162 | super(ConditionalBottleNeck, self).__init__() 163 | self.emb_transf = nn.Linear(hidden_size, hidden_size) 164 | self.hidden_modulation = FiLM(hidden_size, output_size) 165 | self.down_proj_layer = nn.Linear(output_size, output_size//3) 166 | self.up_proj_layer = nn.Linear(output_size//3, output_size) 167 | 168 | def forward(self, x_cond, hidden_states): 169 | x_cond = self.emb_transf(x_cond) 170 | hidden_states = self.hidden_modulation(x_cond=x_cond, x_to_film=hidden_states) 171 | hidden_states = self.down_proj_layer(hidden_states) 172 | hidden_states = self.up_proj_layer(hidden_states) 173 | return hidden_states 174 | -------------------------------------------------------------------------------- /script/evaluate.py: -------------------------------------------------------------------------------- 1 | from data.taskonomy.taskonomy_dataset_s3 import TaskonomyDatasetS3 2 | from matplotlib import pyplot as plt 3 | import torch 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from torch.utils.data import DataLoader 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | from torch.utils.tensorboard import SummaryWriter 10 | import transformers 11 | from tqdm import tqdm 12 | import numpy as np 13 | import os 14 | import pickle 15 | import cv2 16 | import json 17 | import argparse 18 | 19 | 20 | from torchvision.utils import make_grid, save_image 21 | 22 | 23 | from data.nyuv2_same_batch 
import NYUv2SameBatchDataset 24 | from model.swin_transformer import SwinTransformer 25 | from loss.losses import berHuLoss 26 | from loss.metrics import iou_pytorch, eval_depth 27 | from data.nyuv2 import NYUv2Dataset 28 | 29 | 30 | def get_config(): 31 | parser = argparse.ArgumentParser(description='Train the network') 32 | parser.add_argument('--config', help='train config file path') 33 | 34 | args = parser.parse_args() 35 | 36 | with open(args.config, "r") as jsonfile: 37 | config = json.load(jsonfile) 38 | 39 | return config 40 | 41 | 42 | def get_dataloaders(tasks, batch_size, setting="nyu", task=None): 43 | 44 | if setting == "taskonomy": 45 | 46 | test_dataset = TaskonomyDatasetS3( 47 | tasks=["rgb", "segment_semantic", "depth_euclidean"], split="val", variant="tiny", image_size=224) 48 | 49 | g = torch.Generator() 50 | g.manual_seed(61) 51 | 52 | k_samples = 16*100 53 | perm = torch.randperm(len(test_dataset), generator=g) 54 | idx = perm[:k_samples].tolist() 55 | 56 | subset_dataset_test = torch.utils.data.Subset(test_dataset, idx) 57 | 58 | dataloader = DataLoader(subset_dataset_test, 59 | batch_size=batch_size, shuffle=False) 60 | 61 | return dataloader 62 | 63 | if setting == "nyu": 64 | 65 | IMAGE_SIZE = (480, 640) 66 | 67 | test_t = torch.nn.Sequential( 68 | transforms.CenterCrop(480), transforms.Resize(224)) 69 | train_t_input_image = torch.nn.Sequential(transforms.ColorJitter(brightness=( 70 | 0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.8, 1.2), hue=(-0.1, 0.1))) 71 | 72 | test_dataset = NYUv2SameBatchDataset(root="./data/nyuv2", tasks=tasks, download=False, train=False, 73 | rgb_transform=test_t, seg_transform=test_t, sn_transform=test_t, depth_transform=test_t) 74 | 75 | dataloader = DataLoader( 76 | test_dataset, batch_size=batch_size, shuffle=False) 77 | 78 | return dataloader 79 | 80 | if setting == "nyu_single_task": 81 | 82 | IMAGE_SIZE = (480, 640) 83 | 84 | test_t = torch.nn.Sequential( 85 | transforms.CenterCrop(480), transforms.Resize(224)) 86 | train_t_input_image = torch.nn.Sequential(transforms.ColorJitter(brightness=( 87 | 0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.8, 1.2), hue=(-0.1, 0.1))) 88 | 89 | test_dataset = NYUv2Dataset(root="./data/nyuv2", tasks=tasks, download=False, train=False, 90 | rgb_transform=test_t, seg_transform=test_t, sn_transform=test_t, depth_transform=test_t) 91 | 92 | if task == "segmentation": 93 | test_dataset = torch.utils.data.Subset( 94 | test_dataset, range(len(test_dataset)//2)) 95 | 96 | if task == "depth": 97 | test_dataset = torch.utils.data.Subset( 98 | test_dataset, range(len(test_dataset)//2, len(test_dataset))) 99 | 100 | dataloader = DataLoader( 101 | test_dataset, batch_size=batch_size, shuffle=False) 102 | 103 | return dataloader 104 | 105 | 106 | def calc_seg_metrics(logit_task, label_task): 107 | 108 | max_labels = torch.argmax(logit_task, dim=1, keepdim=True) 109 | iou = iou_pytorch(max_labels, label_task) 110 | 111 | return max_labels, iou 112 | 113 | 114 | def disp2meters(d): 115 | return (65536.0 / d - 1) / 1e4 116 | 117 | 118 | def load_model(model, PATH, device): 119 | checkpoint = torch.load(PATH, map_location=device) 120 | model.load_state_dict(checkpoint['model_state_dict']) 121 | model = model.to(device) 122 | return model 123 | 124 | 125 | def evaluate(model, dataloader, device, task=None): 126 | test_loss = 0 127 | epoch_ious = [] 128 | epoch_eval_depths_d1 = [] 129 | 130 | epoch_loss_seg_test = [] 131 | epoch_loss_depth_test = [] 132 | 133 | model.eval() 134 | for i, (img, label, task_id) in 
enumerate(dataloader, 0): 135 | 136 | img = img.view((-1, 3, 224, 224)).to(device) 137 | label = label.view((-1, 1, 224, 224)).to(device) 138 | task_id = task_id.view(-1).to(device) 139 | 140 | if task is not None: 141 | task_id = torch.zeros_like(task_id) 142 | 143 | logits, unique_task_ids_list = model(img, task_id) 144 | 145 | loss = 0 146 | 147 | for j, unique_task_id in enumerate(unique_task_ids_list): 148 | 149 | task_id_filter = task_id == unique_task_id 150 | 151 | logit_task = logits[j] 152 | label_task = label[task_id_filter] 153 | B = logit_task.shape[0] 154 | 155 | if unique_task_id == 0 and task != "depth": 156 | 157 | label_task = label_task.long() 158 | 159 | max_labels, iou = calc_seg_metrics(logit_task, label_task) 160 | 161 | epoch_ious.append(iou.cpu().numpy()) 162 | 163 | else: 164 | 165 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1 166 | label_task = 65536.0 / (label_task + 1) 167 | 168 | evaluation = eval_depth(disp2meters( 169 | logit_task), disp2meters(label_task)) 170 | epoch_eval_depths_d1.append(evaluation["d1"]) 171 | 172 | print("Mean IOU: ", np.mean(epoch_ious)) 173 | print("D1: ", np.mean(epoch_eval_depths_d1)) 174 | 175 | 176 | def save_images(model, dataloader, device, num_images=1, task=None, setting=None): 177 | 178 | img_count = 0 179 | for i, (img, label, task_id) in enumerate(dataloader, 0): 180 | 181 | img = img.view((-1, 3, 224, 224)).to(device) 182 | label = label.view((-1, 1, 224, 224)).to(device) 183 | task_id = task_id.view(-1).to(device) 184 | 185 | if task is not None: 186 | task_id = torch.zeros_like(task_id) 187 | 188 | logits, unique_task_ids_list = model(img, task_id) 189 | 190 | for j in range(len(img)): 191 | if len(logits) == 1: 192 | fig, axs = plt.subplots(1, 3, figsize=(12, 5)) 193 | 194 | axs[0].imshow(torch.permute(img[j].cpu(), (1, 2, 0))) 195 | axs[0].set_xlabel('RGB Image') 196 | 197 | if task == "segmentation": 198 | k = torch.argmax(logits[0][j], dim=0, keepdim=True) 199 | k[0][-1][-1] = torch.max(label[j]) 200 | label[j][0][-1][-1] = torch.max(k) 201 | else: 202 | k = disp2meters(torch.nn.functional.sigmoid( 203 | logits[0][j])*65535 + 1) 204 | 205 | axs[1].imshow(torch.permute(label[j].cpu(), (1, 2, 0))) 206 | axs[1].set_xlabel(f'{task.capitalize()} Label') 207 | 208 | axs[2].imshow(k.detach().view(224, 224, 1).cpu()) 209 | axs[2].set_xlabel(f'{task.capitalize()} Prediction') 210 | 211 | plt.savefig(f'./images/{img_count}.png') 212 | img_count += 1 213 | 214 | else: 215 | if j % 2 == 1: 216 | continue 217 | 218 | c = j//2 219 | fig, axs = plt.subplots(1, 5, figsize=(20, 5)) 220 | axs[0].imshow(torch.permute(img[j].cpu(), (1, 2, 0))) 221 | axs[0].set_xlabel('RGB Image') 222 | 223 | k = torch.argmax(logits[0][c], dim=0, keepdim=True) 224 | 225 | k[0][-1][-1] = 18 if setting == "taskonomy" else 13 226 | label[j][0][-1][-1] = 18 if setting == "taskonomy" else 13 227 | 228 | axs[1].imshow(torch.permute(label[j].cpu(), (1, 2, 0))) 229 | axs[1].set_xlabel('Segmentation Label') 230 | 231 | axs[2].imshow(k.detach().view(224, 224, 1).cpu()) 232 | axs[2].set_xlabel('Segmentation Prediction') 233 | 234 | label[j+1][label[j+1] == 65535] = 0 235 | axs[3].imshow(torch.permute(label[j+1].cpu(), (1, 2, 0))) 236 | axs[3].set_xlabel('Depth Label') 237 | k2 = disp2meters(torch.nn.functional.sigmoid( 238 | logits[1][c])*65535 + 1) 239 | axs[4].imshow(k2.detach().view(224, 224, 1).cpu()) 240 | axs[4].set_xlabel('Depth Prediction') 241 | 242 | plt.savefig(f'./images/{img_count}.png') 243 | img_count += 1 244 | 245 | if img_count 
== num_images: 246 | return 247 | 248 | 249 | def main(): 250 | 251 | config = get_config() 252 | 253 | if config["setting"] != "nyu_single_task" and "task" in config.keys(): 254 | print("Do not put task parameter on multitask networks!") 255 | return 256 | 257 | torch.manual_seed(61) 258 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 259 | 260 | tasks = {0: "segmentation", 1: "depth"} 261 | print("Creating dataset...") 262 | dataloader = get_dataloaders( 263 | tasks, 16, config["setting"], config["task"] if "task" in config.keys() else None) 264 | 265 | print("Loading model...") 266 | 267 | tasks = ["segmentation", "depth"] 268 | task_classes = [14, 1] if config["setting"] != "taskonomy" else [18, 1] 269 | if config["setting"] == "nyu_single_task": 270 | tasks = [config["task"]] 271 | task_classes = [14 if config["task"] == "segmentation" else 1] 272 | 273 | model = SwinTransformer(img_size=224, 274 | patch_size=4, 275 | in_chans=3, 276 | num_classes=21841, 277 | embed_dim=96, 278 | depths=[2, 2, 18, 2], 279 | depths_decoder=[2, 2, 2, 2], 280 | num_heads=[3, 6, 12, 24], 281 | window_size=7, 282 | mlp_ratio=4., 283 | qkv_bias=True, 284 | qk_scale=True, 285 | drop_rate=0, 286 | drop_rate_decoder=0.6, 287 | drop_path_rate=0.2, 288 | ape=False, 289 | patch_norm=True, 290 | use_checkpoint=False, 291 | tasks=tasks, 292 | task_classes=task_classes, 293 | conditioned_blocks=config["conditioned_blocks"] if config["setting"] != "nyu_single_task" else [ 294 | [], [], [], []], 295 | adapter=config["adapter"] if config["setting"] != "nyu_single_task" else False, 296 | use_conditional_layer=config["use_conditional_layer_norm"] if config["setting"] == "nyu" else False) 297 | 298 | model = load_model(model, config["model_path"], device) 299 | 300 | print("Evaluating...") 301 | evaluate(model, dataloader, device, 302 | task=config["task"] if "task" in config.keys() else None) 303 | 304 | print("Saving Images...") 305 | save_images(model, dataloader, device, num_images=config["num_generated_images"], 306 | task=config["task"] if "task" in config.keys() else None, setting=config["setting"]) 307 | 308 | 309 | if __name__ == '__main__': 310 | main() 311 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 10 | 11 | 12 | Vision Transformer Adapters for Generalizable Multitask Learning 13 | 14 | 15 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 |
25 |
26 |

Vision Transformer Adapters for Generalizable Multitask Learning

27 |

ICCV 2023

28 |

29 |
30 |

31 | Deblina Bhattacharjee, 32 | Sabine Süsstrunk, 33 | Mathieu Salzmann 34 |

35 |
36 | Paper 37 | Code 38 | Poster 39 | Slides 40 |
41 |
42 | 43 |
44 |
45 | 46 |
47 |
48 | 49 |
50 |
51 |
52 | 54 |
55 |

56 | We introduce the first multitasking vision transformer adapters that learn generalizable task affinities which can be applied to novel tasks and domains. Integrated into an off-the-shelf vision transformer backbone, our adapters can simultaneously solve multiple dense vision tasks in a parameter-efficient manner, unlike existing multitasking transformers that are parametrically expensive. In contrast to concurrent methods, we do not require retraining or fine-tuning whenever a new task or domain is added. We introduce a task-adapted attention mechanism within our adapter framework that combines gradient-based task similarities with attention-based ones. The learned task affinities generalize to the following settings: zero-shot task transfer, unsupervised domain adaptation, and generalization without fine-tuning to novel domains. We demonstrate that our approach outperforms not only the existing convolutional neural network-based multitasking methods but also the vision transformer-based ones. 57 |

58 | 59 |
60 |

Method Architecture

61 |
62 |
63 |
64 | 65 |
66 |
67 |

68 | Detailed overview of our architecture. The frozen transformer encoder module (in orange) extracts a shared representation of the input image, which is then utilized to learn the task affinities in our novel vision transformer adapters (in purple). Each adapter layer uses gradient task similarity (TROA) (in yellow) and Task-Adapted Attention (TAA) to learn the task affinities, which are communicated with skip connections (in blue) between consecutive adapter layers. The task embeddings are then decoded by the fully-supervised transformer decoders (in green) for the respective tasks. Note that the transformer decoders are shared but have different task heads (in grey). For clarity, only three tasks are depicted here and TAA is explained in a separate figure below. 69 |

70 |
71 |
72 | 73 |

Vision Transformer Adapter Module

74 | 75 |
76 |
77 | 78 |
79 |
80 |

81 | Overview of our vision transformer adapter module. Our vision adapters learn transferable and generalizable task affinities in a parameter-efficient way. We show two blocks to depict the skip connectivity between them. The main modules (TROA) and (TAA) of our vision transformer adapters are depicted below. 82 |

83 |
84 | 85 |

Task Representation Optimization Algorithm (TROA)

86 | 87 |
88 |
89 | 90 |
91 |
92 |

93 | We show the task affinities from TROA when four tasks comprising semantic segmentation (SemSeg), depth, surface normal, and edges are jointly learned. We show that TROA learns a strong task affinity between the gradients of the same task, for example, segmentation with segmentation, which is expected. TROA also learns task affinities between proximate tasks such as segmentation and depth, as well as between non-proximate tasks such as segmentation and normal. Note that task dependence is asymmetric, i.e., segmentation does not affect normal in the same way that normal affects segmentation. These task affinities are used by our novel task-adapted attention module as described in what follows. 94 |

95 |
96 | 97 |

Matching the Feature Dimensions using FiLM

98 | 99 |
100 |
101 | 102 |
103 |
104 |

105 | Detailed overview of Feature-wise Linear Modulation (FiLM), which linearly shifts and scales task representations to match the dimensions of the feature maps. The orange rectangular area is FiLM. 106 |
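The FiLM layer described in this caption is implemented as the FiLM module in script/model/conditional_modules.py, shown later in this dump. A minimal usage sketch follows; the tensor sizes and the import path (which assumes the script/ directory is on PYTHONPATH, as in evaluate.py) are illustrative assumptions rather than values prescribed by the repository.

```python
import torch
from model.conditional_modules import FiLM  # assumes script/ is on PYTHONPATH

# Hypothetical sizes: a 96-dim task embedding modulates 196 tokens of width 384.
film = FiLM(input_size=96, output_size=384)
task_emb = torch.randn(4, 96)      # (batch, conditioning dim): task representation
tokens = torch.randn(4, 196, 384)  # (batch, tokens, features): features to be modulated
out = film(task_emb, tokens)       # scaled and shifted features, shape (4, 196, 384)
```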

107 |
108 | 109 |

Task-Adapted Attention

110 | 111 |
112 |
113 | 114 |
115 |
116 |

117 | Overview of our Task-Adapted Attention (TAA) mechanism, which combines task affinities with image attention. Note that the process shown in the foreground is for a single attention head; it is repeated for 'M' heads to give the task-adapted multi-head attention. 118 |
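The corresponding module is TAA in script/model/conditional_modules.py, which FiLM-modulates the token features with the task embedding and then assembles a block-diagonal, task-adapted attention term. A small calling sketch under assumed sizes (the import path again assumes script/ is on PYTHONPATH; none of the numbers are prescribed by the repository):

```python
import torch
from model.conditional_modules import TAA  # assumes script/ is on PYTHONPATH

# Hypothetical sizes: 96-dim task embedding, 196 tokens of width 384, 2 blocks.
taa = TAA(input_size=96, output_size=384, blocks=2, layer_norm=True)
task_emb = torch.randn(4, 96)      # task conditioning vector per sample
tokens = torch.randn(4, 196, 384)  # token features to be adapted
adapted = taa(task_emb, tokens)    # task-adapted attention term; (4, 196, 196) with these sizes
```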

119 |
120 | 121 | 122 |
123 |

Multitasking Results

124 | 125 |
126 |
127 |
128 | 129 |
130 |
131 |

132 | Multitask Learning comparison on the NYUDv2 benchmark in the 'S-D-N-E' setting. Our model outperforms all the multitask baselines, i.e. ST-MTL, InvPT, Taskprompter, and MulT. For instance, our model correctly segments and predicts the surface normals of the elements within the yellow-circled region, unlike the baselines. All the methods are based on the same Swin-B V2 backbone. Best seen on screen and zoomed in. For more details and quantitative results, please refer to our paper. 133 |

134 |
135 | 136 |
137 |
138 | 139 |
140 |
141 |

142 | Multitask Learning comparison on the Taskonomy benchmark in the 'S-D-N-E' setting. Our model outperforms all the multitask baselines. For instance, our model correctly segments and predicts the surface normals of the elements within the yellow-circled region, unlike the baselines. All the methods are based on the same Swin-B V2 backbone. Best seen on screen and zoomed in. For more details and quantitative results, please refer to our paper. 143 |

144 |
145 |
146 | 147 |
148 |

Unsupervised Domain Adaptation (UDA)

149 |
150 | 151 |
152 |
153 | 154 |
155 |
156 |

157 | Unsupervised Domain Adaptation (UDA) results on Synthia->Cityscapes. Our model outperforms the CNN-based baseline (XTAM-UDA) and the Swin-B V2-based baselines (1-task Swin-UDA and MulT-UDA). For instance, our method can predict the depth of the car tail light, unlike the baselines. Best seen on screen and zoomed in within the yellow-circled region. 158 |

159 |
160 |
161 |
162 |

Bibtex

163 |
164 |
165 |
166 |
167 | @misc{bhattacharjee2023vision, 168 | title={Vision Transformer Adapters for Generalizable Multitask Learning}, 169 | author={Deblina Bhattacharjee and Sabine Süsstrunk and Mathieu Salzmann}, 170 | year={2023}, 171 | eprint={2308.12372}, 172 | archivePrefix={arXiv}, 173 | primaryClass={cs.CV} 174 | } 175 |
176 |
177 |
178 |
179 |

Acknowledgement

180 |
181 |
182 |

This work was supported in part by the Swiss National Science Foundation via the Sinergia grant CRSII5-180359.

183 | 184 |
185 |
186 |
187 |
188 | 189 |
190 | 191 | 192 | 195 | 198 | 201 | 202 | 203 | -------------------------------------------------------------------------------- /script/data/nyuv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | author: Mihai Suteu 3 | date: 15/05/19 4 | https://github.com/xapharius/pytorch-nyuv2 5 | """ 6 | 7 | 8 | import os 9 | import sys 10 | import h5py 11 | import torch 12 | import shutil 13 | import random 14 | import tarfile 15 | import zipfile 16 | import requests 17 | import numpy as np 18 | from typing import Dict 19 | 20 | from PIL import Image 21 | from torch.utils.data import Dataset 22 | from torchvision.datasets.utils import download_url 23 | 24 | SEG = 0 25 | DEP = 1 26 | SN = 2 27 | 28 | 29 | class NYUv2Dataset(Dataset): 30 | """ 31 | PyTorch wrapper for the NYUv2 dataset focused on multi-task learning. 32 | Data sources available: RGB, Semantic Segmentation, Surface Normals, Depth Images. 33 | If no transformation is provided, the image type will not be returned. 34 | 35 | ### Output 36 | All images are of size: 640 x 480 37 | 38 | 1. RGB: 3 channel input image 39 | 40 | 2. Semantic Segmentation: 1 channel representing one of the 14 (0 - 41 | background) classes. Conversion to int will happen automatically if 42 | transformation ends in a tensor. Task name: "segmentation" 43 | 44 | 3. Surface Normals: 3 channels, with values in [0, 1]. Task name: "surface_normals" 45 | 46 | 4. Depth Images: 1 channel with floats representing the distance in meters. 47 | Conversion will happen automatically if transformation ends in a tensor. Task name: "depth" 48 | """ 49 | 50 | def __init__( 51 | self, 52 | root: str, 53 | tasks: Dict[int, str], 54 | train: bool = True, 55 | download: bool = False, 56 | rgb_transform=None, 57 | seg_transform=None, 58 | sn_transform=None, 59 | depth_transform=None, 60 | rgb_transform2=None, 61 | ): 62 | """ 63 | Will return tuples based on what data source has been enabled (rgb, seg etc). 64 | 65 | :param root: path to root folder (eg /data/NYUv2) 66 | :param train: whether to load the train or test set 67 | :param download: whether to download and process data if missing 68 | :param rgb_transform: the transformation pipeline for rbg images 69 | :param seg_transform: the transformation pipeline for segmentation images. If 70 | the transformation ends in a tensor, the result will be automatically 71 | converted to int in [0, 14) 72 | :param sn_transform: the transformation pipeline for surface normal images 73 | :param depth_transform: the transformation pipeline for depth images. 
If the 74 | transformation ends in a tensor, the result will be automatically converted 75 | to meters 76 | """ 77 | super().__init__() 78 | self.root = root 79 | 80 | self.rgb_transform = rgb_transform 81 | self.rgb_transform2 = rgb_transform2 82 | self.seg_transform = seg_transform 83 | self.depth_transform = depth_transform 84 | self.sn_transform = sn_transform 85 | 86 | self.train = train 87 | self._split = "train" if train else "test" 88 | 89 | if download: 90 | self.download() 91 | 92 | # rgb folder as ground truth 93 | self._files = sorted(os.listdir(os.path.join(root, f"{self._split}_rgb_pt"))) 94 | self.num_img = len(self._files) 95 | 96 | self.num_tasks = len(tasks) 97 | self.tasks = tasks 98 | 99 | self.task_dict = self._get_task_dict() 100 | 101 | self.folder = lambda name: os.path.join(self.root, f"{self._split}_{name}_pt") 102 | 103 | self.seg_images = torch.load(f"{root}/combined/{self._split}_seg13_pt.pt") 104 | self.depth_images = torch.load(f"{root}/combined/{self._split}_depth_pt.pt") 105 | 106 | 107 | 108 | 109 | def __getitem__(self, index: int): 110 | 111 | task = index // self.num_img 112 | rgb_image = index % self.num_img 113 | seed = random.randrange(sys.maxsize) 114 | rgb = None 115 | state = None 116 | 117 | if self.rgb_transform is not None: 118 | random.seed(seed) 119 | img = torch.load(os.path.join(self.folder("rgb"), self._files[rgb_image])) # self.rgb_images[rgb_image, :,:,:]# 120 | ### https://github.com/pytorch/vision/issues/9#issuecomment-789308878 121 | state = torch.get_rng_state() 122 | rgb = self.rgb_transform(img) 123 | if self.rgb_transform2 is not None: 124 | rgb = self.rgb_transform2(rgb) 125 | 126 | label = self._get_task_label(task, rgb_image, state) 127 | 128 | return rgb, label, task 129 | 130 | def _get_task_dict(self): 131 | 132 | task_dict = dict() 133 | 134 | for i in self.tasks.keys(): 135 | 136 | task_type = self.tasks[i] 137 | if task_type == "segmentation": 138 | task_dict[i] = SEG 139 | elif task_type == "surface_normals": 140 | task_dict[i] = SN 141 | elif task_type == "depth": 142 | task_dict[i] = DEP 143 | 144 | return task_dict 145 | 146 | 147 | def _get_task_label(self, task, rgb_image, state): 148 | seed = random.randrange(sys.maxsize) 149 | 150 | task_type = self.task_dict[task] 151 | if task_type == SEG: 152 | if self.seg_transform is not None: 153 | random.seed(seed) 154 | img = self.seg_images[rgb_image, :,:,:] 155 | torch.set_rng_state(state) 156 | img = self.seg_transform(img) 157 | if isinstance(img, torch.Tensor): 158 | # ToTensor scales to [0, 1] by default 159 | img = (img * 255).long() 160 | return img 161 | 162 | if task_type == SN: # kontrol et 163 | if self.sn_transform is not None: 164 | random.seed(seed) 165 | img = self.rgb_images[rgb_image, :,:,:] 166 | torch.set_rng_state(state) 167 | img = self.sn_transform(img) 168 | return img 169 | 170 | if task_type == DEP: 171 | if self.depth_transform is not None: 172 | random.seed(seed) 173 | img = self.depth_images[rgb_image, :,:,:] 174 | torch.set_rng_state(state) 175 | img = self.depth_transform(img) 176 | if isinstance(img, torch.Tensor): 177 | # depth png is uint16 178 | img = img.float() 179 | return img 180 | 181 | 182 | 183 | 184 | 185 | def __len__(self): 186 | return len(self._files) * self.num_tasks 187 | 188 | def __repr__(self): 189 | fmt_str = f"Dataset {self.__class__.__name__}\n" 190 | fmt_str += f" Number of data points: {self.__len__()}\n" 191 | fmt_str += f" Split: {self._split}\n" 192 | fmt_str += f" Root Location: {self.root}\n" 193 | tmp = " RGB 
Transforms: " 194 | fmt_str += "{0}{1}\n".format( 195 | tmp, self.rgb_transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 196 | ) 197 | tmp = " Seg Transforms: " 198 | fmt_str += "{0}{1}\n".format( 199 | tmp, self.seg_transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 200 | ) 201 | tmp = " SN Transforms: " 202 | fmt_str += "{0}{1}\n".format( 203 | tmp, self.sn_transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 204 | ) 205 | tmp = " Depth Transforms: " 206 | fmt_str += "{0}{1}\n".format( 207 | tmp, self.depth_transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 208 | ) 209 | return fmt_str 210 | 211 | def _check_exists(self) -> bool: 212 | """ 213 | Only checking for folder existence 214 | """ 215 | try: 216 | for split in ["train", "test"]: 217 | for part, transform in zip( 218 | ["rgb", "seg13", "depth"],#"sn", 219 | [ 220 | self.rgb_transform, 221 | self.seg_transform, 222 | self.sn_transform, 223 | self.depth_transform, 224 | ], 225 | ): 226 | if transform is None: 227 | continue 228 | path = os.path.join(self.root, f"{split}_{part}_pt") 229 | if not os.path.exists(path): 230 | raise FileNotFoundError("Missing Folder") 231 | except FileNotFoundError as e: 232 | return False 233 | return True 234 | 235 | def download(self): 236 | if self._check_exists(): 237 | return 238 | if self.rgb_transform is not None: 239 | download_rgb(self.root) 240 | if self.seg_transform is not None: 241 | download_seg(self.root) 242 | if self.sn_transform is not None: 243 | download_sn(self.root) 244 | if self.depth_transform is not None: 245 | download_depth(self.root) 246 | print("Done!") 247 | 248 | 249 | def download_rgb(root: str): 250 | train_url = "http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz" 251 | test_url = "http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz" 252 | 253 | def _proc(url: str, dst: str): 254 | if not os.path.exists(dst): 255 | tar = os.path.join(root, url.split("/")[-1]) 256 | if not os.path.exists(tar): 257 | download_url(url, root) 258 | if os.path.exists(tar): 259 | _unpack(tar) 260 | _replace_folder(tar.rstrip(".tgz"), dst) 261 | _rename_files(dst, lambda x: x.split("_")[2]) 262 | 263 | _proc(train_url, os.path.join(root, "train_rgb")) 264 | _proc(test_url, os.path.join(root, "test_rgb")) 265 | 266 | 267 | def download_seg(root: str): 268 | train_url = "https://github.com/ankurhanda/nyuv2-meta-data/raw/master/train_labels_13/nyuv2_train_class13.tgz" 269 | test_url = "https://github.com/ankurhanda/nyuv2-meta-data/raw/master/test_labels_13/nyuv2_test_class13.tgz" 270 | 271 | def _proc(url: str, dst: str): 272 | if not os.path.exists(dst): 273 | tar = os.path.join(root, url.split("/")[-1]) 274 | if not os.path.exists(tar): 275 | download_url(url, root) 276 | if os.path.exists(tar): 277 | _unpack(tar) 278 | _replace_folder(tar.rstrip(".tgz"), dst) 279 | _rename_files(dst, lambda x: x.split("_")[3]) 280 | 281 | _proc(train_url, os.path.join(root, "train_seg13")) 282 | _proc(test_url, os.path.join(root, "test_seg13")) 283 | 284 | 285 | def download_sn(root: str): 286 | url = "https://www.dropbox.com/s/dn5sxhlgml78l03/nyu_normals_gt.zip" 287 | train_dst = os.path.join(root, "train_sn") 288 | test_dst = os.path.join(root, "test_sn") 289 | 290 | if not os.path.exists(train_dst) or not os.path.exists(test_dst): 291 | tar = os.path.join(root, url.split("/")[-1]) 292 | if not os.path.exists(tar): 293 | req = requests.get(url + "?dl=1") # dropbox 294 | with open(tar, 'wb') as f: 295 | f.write(req.content) 296 | if os.path.exists(tar): 297 | _unpack(tar) 298 | if not 
os.path.exists(train_dst): 299 | _replace_folder( 300 | os.path.join(root, "nyu_normals_gt", "train"), train_dst 301 | ) 302 | _rename_files(train_dst, lambda x: x[1:]) 303 | if not os.path.exists(test_dst): 304 | _replace_folder(os.path.join(root, "nyu_normals_gt", "test"), test_dst) 305 | _rename_files(test_dst, lambda x: x[1:]) 306 | shutil.rmtree(os.path.join(root, "nyu_normals_gt")) 307 | 308 | 309 | def download_depth(root: str): 310 | url = ( 311 | "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat" 312 | ) 313 | train_dst = os.path.join(root, "train_depth") 314 | test_dst = os.path.join(root, "test_depth") 315 | 316 | if not os.path.exists(train_dst) or not os.path.exists(test_dst): 317 | tar = os.path.join(root, url.split("/")[-1]) 318 | if not os.path.exists(tar): 319 | download_url(url, root) 320 | if os.path.exists(tar): 321 | train_ids = [ 322 | f.split(".")[0] for f in os.listdir(os.path.join(root, "train_rgb")) 323 | ] 324 | _create_depth_files(tar, root, train_ids) 325 | 326 | 327 | def _unpack(file: str): 328 | """ 329 | Unpacks tar and zip, does nothing for any other type 330 | :param file: path of file 331 | """ 332 | path = file.rsplit(".", 1)[0] 333 | 334 | if file.endswith(".tgz"): 335 | tar = tarfile.open(file, "r:gz") 336 | tar.extractall(path) 337 | tar.close() 338 | elif file.endswith(".zip"): 339 | zip = zipfile.ZipFile(file, "r") 340 | zip.extractall(path) 341 | zip.close() 342 | 343 | 344 | def _rename_files(folder: str, rename_func: callable): 345 | """ 346 | Renames all files inside a folder based on the passed rename function 347 | :param folder: path to folder that contains files 348 | :param rename_func: function renaming filename (not including path) str -> str 349 | """ 350 | imgs_old = os.listdir(folder) 351 | imgs_new = [rename_func(file) for file in imgs_old] 352 | for img_old, img_new in zip(imgs_old, imgs_new): 353 | shutil.move(os.path.join(folder, img_old), os.path.join(folder, img_new)) 354 | 355 | 356 | def _replace_folder(src: str, dst: str): 357 | """ 358 | Rename src into dst, replacing/overwriting dst if it exists. 
359 | """ 360 | if os.path.exists(dst): 361 | shutil.rmtree(dst) 362 | shutil.move(src, dst) 363 | 364 | 365 | def _create_depth_files(mat_file: str, root: str, train_ids: list): 366 | """ 367 | Extract the depth arrays from the mat file into images 368 | :param mat_file: path to the official labelled dataset .mat file 369 | :param root: The root directory of the dataset 370 | :param train_ids: the IDs of the training images as string (for splitting) 371 | """ 372 | os.mkdir(os.path.join(root, "train_depth")) 373 | os.mkdir(os.path.join(root, "test_depth")) 374 | train_ids = set(train_ids) 375 | 376 | depths = h5py.File(mat_file, "r")["depths"] 377 | for i in range(len(depths)): 378 | img = (depths[i] * 1e4).astype(np.uint16).T 379 | id_ = str(i + 1).zfill(4) 380 | folder = "train" if id_ in train_ids else "test" 381 | save_path = os.path.join(root, f"{folder}_depth", id_ + ".png") 382 | Image.fromarray(img).save(save_path) 383 | -------------------------------------------------------------------------------- /script/data/nyuv2_same_batch.py: -------------------------------------------------------------------------------- 1 | """ 2 | author: Mihai Suteu 3 | date: 15/05/19 4 | https://github.com/xapharius/pytorch-nyuv2 5 | """ 6 | 7 | 8 | import os 9 | import sys 10 | import h5py 11 | import torch 12 | import shutil 13 | import random 14 | import tarfile 15 | import zipfile 16 | import requests 17 | import numpy as np 18 | from typing import Dict 19 | 20 | from PIL import Image 21 | from torch.utils.data import Dataset 22 | from torchvision.datasets.utils import download_url 23 | 24 | SEG = 0 25 | DEP = 1 26 | SN = 2 27 | 28 | 29 | class NYUv2SameBatchDataset(Dataset): 30 | """ 31 | PyTorch wrapper for the NYUv2 dataset focused on multi-task learning. 32 | Data sources available: RGB, Semantic Segmentation, Surface Normals, Depth Images. 33 | If no transformation is provided, the image type will not be returned. 34 | 35 | ### Output 36 | All images are of size: 640 x 480 37 | 38 | 1. RGB: 3 channel input image 39 | 40 | 2. Semantic Segmentation: 1 channel representing one of the 14 (0 - 41 | background) classes. Conversion to int will happen automatically if 42 | transformation ends in a tensor. Task name: "segmentation" 43 | 44 | 3. Surface Normals: 3 channels, with values in [0, 1]. Task name: "surface_normals" 45 | 46 | 4. Depth Images: 1 channel with floats representing the distance in meters. 47 | Conversion will happen automatically if transformation ends in a tensor. Task name: "depth" 48 | """ 49 | 50 | def __init__( 51 | self, 52 | root: str, 53 | tasks: Dict[int, str], 54 | train: bool = True, 55 | download: bool = False, 56 | rgb_transform=None, 57 | seg_transform=None, 58 | sn_transform=None, 59 | depth_transform=None, 60 | rgb_transform2=None, 61 | ): 62 | """ 63 | Will return tuples based on what data source has been enabled (rgb, seg etc). 64 | 65 | :param root: path to root folder (eg /data/NYUv2) 66 | :param train: whether to load the train or test set 67 | :param download: whether to download and process data if missing 68 | :param rgb_transform: the transformation pipeline for rbg images 69 | :param seg_transform: the transformation pipeline for segmentation images. If 70 | the transformation ends in a tensor, the result will be automatically 71 | converted to int in [0, 14) 72 | :param sn_transform: the transformation pipeline for surface normal images 73 | :param depth_transform: the transformation pipeline for depth images. 
If the 74 | transformation ends in a tensor, the result will be automatically converted 75 | to meters 76 | """ 77 | super().__init__() 78 | self.root = root 79 | 80 | self.rgb_transform = rgb_transform 81 | self.rgb_transform2 = rgb_transform2 82 | self.seg_transform = seg_transform 83 | self.depth_transform = depth_transform 84 | self.sn_transform = sn_transform 85 | 86 | self.train = train 87 | self._split = "train" if train else "test" 88 | 89 | if download: 90 | self.download() 91 | 92 | 93 | # rgb folder as ground truth 94 | self._files = sorted(os.listdir(os.path.join(root, f"{self._split}_rgb_pt"))) 95 | self.num_img = len(self._files) 96 | 97 | self.num_tasks = len(tasks) 98 | self.tasks = tasks 99 | 100 | self.task_dict = self._get_task_dict() 101 | 102 | self.folder = lambda name: os.path.join(self.root, f"{self._split}_{name}_pt") 103 | 104 | self.seg_images = torch.load(f"{root}/combined/{self._split}_seg13_pt.pt") 105 | self.depth_images = torch.load(f"{root}/combined/{self._split}_depth_pt.pt") 106 | 107 | 108 | def __getitem__(self, index: int): 109 | 110 | rgb_image = index 111 | seed = random.randrange(sys.maxsize) 112 | rgb = None 113 | state = None 114 | 115 | if self.rgb_transform is not None: 116 | random.seed(seed) 117 | img = torch.load(os.path.join(self.folder("rgb"), self._files[rgb_image])) 118 | ### https://github.com/pytorch/vision/issues/9#issuecomment-789308878 119 | state = torch.get_rng_state() 120 | rgb = self.rgb_transform(img) 121 | if self.rgb_transform2 is not None: 122 | rgb = self.rgb_transform2(rgb) 123 | 124 | label_seg = self._get_task_label(0, rgb_image, state) 125 | label_depth = self._get_task_label(1, rgb_image, state) 126 | 127 | return torch.stack([rgb,rgb]), torch.stack([label_seg, label_depth]), torch.LongTensor([0,1]) 128 | 129 | def _get_task_dict(self): 130 | 131 | task_dict = dict() 132 | 133 | for i in self.tasks.keys(): 134 | 135 | task_type = self.tasks[i] 136 | if task_type == "segmentation": 137 | task_dict[i] = SEG 138 | elif task_type == "surface_normals": 139 | task_dict[i] = SN 140 | elif task_type == "depth": 141 | task_dict[i] = DEP 142 | 143 | return task_dict 144 | 145 | 146 | def _get_task_label(self, task, rgb_image, state): 147 | seed = random.randrange(sys.maxsize) 148 | 149 | task_type = self.task_dict[task] 150 | if task_type == SEG: 151 | if self.seg_transform is not None: 152 | random.seed(seed) 153 | img = self.seg_images[rgb_image, :,:,:] 154 | torch.set_rng_state(state) 155 | img = self.seg_transform(img) 156 | if isinstance(img, torch.Tensor): 157 | # ToTensor scales to [0, 1] by default 158 | img = (img * 255).long() 159 | return img 160 | 161 | if task_type == SN: # kontrol et 162 | if self.sn_transform is not None: 163 | random.seed(seed) 164 | img = self.rgb_images[rgb_image, :,:,:] 165 | torch.set_rng_state(state) 166 | img = self.sn_transform(img) 167 | return img 168 | 169 | if task_type == DEP: 170 | if self.depth_transform is not None: 171 | random.seed(seed) 172 | img = self.depth_images[rgb_image, :,:,:] 173 | torch.set_rng_state(state) 174 | img = self.depth_transform(img) 175 | if isinstance(img, torch.Tensor): 176 | # depth png is uint16 177 | img = img.float() # / 1e4 178 | return img 179 | 180 | 181 | 182 | 183 | 184 | def __len__(self): 185 | return len(self._files) 186 | 187 | def __repr__(self): 188 | fmt_str = f"Dataset {self.__class__.__name__}\n" 189 | fmt_str += f" Number of data points: {self.__len__()}\n" 190 | fmt_str += f" Split: {self._split}\n" 191 | fmt_str += f" Root Location: 
{self.root}\n" 192 | tmp = " RGB Transforms: " 193 | fmt_str += "{0}{1}\n".format( 194 | tmp, self.rgb_transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 195 | ) 196 | tmp = " Seg Transforms: " 197 | fmt_str += "{0}{1}\n".format( 198 | tmp, self.seg_transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 199 | ) 200 | tmp = " SN Transforms: " 201 | fmt_str += "{0}{1}\n".format( 202 | tmp, self.sn_transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 203 | ) 204 | tmp = " Depth Transforms: " 205 | fmt_str += "{0}{1}\n".format( 206 | tmp, self.depth_transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 207 | ) 208 | return fmt_str 209 | 210 | def _check_exists(self) -> bool: 211 | """ 212 | Only checking for folder existence 213 | """ 214 | try: 215 | for split in ["train", "test"]: 216 | for part, transform in zip( 217 | ["rgb", "seg13", "depth"],#"sn", 218 | [ 219 | self.rgb_transform, 220 | self.seg_transform, 221 | self.sn_transform, 222 | self.depth_transform, 223 | ], 224 | ): 225 | if transform is None: 226 | continue 227 | path = os.path.join(self.root, f"{split}_{part}_pt") 228 | if not os.path.exists(path): 229 | raise FileNotFoundError("Missing Folder") 230 | except FileNotFoundError as e: 231 | return False 232 | return True 233 | 234 | def download(self): 235 | if self._check_exists(): 236 | return 237 | if self.rgb_transform is not None: 238 | download_rgb(self.root) 239 | if self.seg_transform is not None: 240 | download_seg(self.root) 241 | if self.sn_transform is not None: 242 | download_sn(self.root) 243 | if self.depth_transform is not None: 244 | download_depth(self.root) 245 | print("Done!") 246 | 247 | 248 | def download_rgb(root: str): 249 | train_url = "http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz" 250 | test_url = "http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz" 251 | 252 | def _proc(url: str, dst: str): 253 | if not os.path.exists(dst): 254 | tar = os.path.join(root, url.split("/")[-1]) 255 | if not os.path.exists(tar): 256 | download_url(url, root) 257 | if os.path.exists(tar): 258 | _unpack(tar) 259 | _replace_folder(tar.rstrip(".tgz"), dst) 260 | _rename_files(dst, lambda x: x.split("_")[2]) 261 | 262 | _proc(train_url, os.path.join(root, "train_rgb")) 263 | _proc(test_url, os.path.join(root, "test_rgb")) 264 | 265 | 266 | def download_seg(root: str): 267 | train_url = "https://github.com/ankurhanda/nyuv2-meta-data/raw/master/train_labels_13/nyuv2_train_class13.tgz" 268 | test_url = "https://github.com/ankurhanda/nyuv2-meta-data/raw/master/test_labels_13/nyuv2_test_class13.tgz" 269 | 270 | def _proc(url: str, dst: str): 271 | if not os.path.exists(dst): 272 | tar = os.path.join(root, url.split("/")[-1]) 273 | if not os.path.exists(tar): 274 | download_url(url, root) 275 | if os.path.exists(tar): 276 | _unpack(tar) 277 | _replace_folder(tar.rstrip(".tgz"), dst) 278 | _rename_files(dst, lambda x: x.split("_")[3]) 279 | 280 | _proc(train_url, os.path.join(root, "train_seg13")) 281 | _proc(test_url, os.path.join(root, "test_seg13")) 282 | 283 | 284 | def download_sn(root: str): 285 | url = "https://www.dropbox.com/s/dn5sxhlgml78l03/nyu_normals_gt.zip" 286 | train_dst = os.path.join(root, "train_sn") 287 | test_dst = os.path.join(root, "test_sn") 288 | 289 | if not os.path.exists(train_dst) or not os.path.exists(test_dst): 290 | tar = os.path.join(root, url.split("/")[-1]) 291 | if not os.path.exists(tar): 292 | req = requests.get(url + "?dl=1") # dropbox 293 | with open(tar, 'wb') as f: 294 | f.write(req.content) 295 | if os.path.exists(tar): 
296 | _unpack(tar) 297 | if not os.path.exists(train_dst): 298 | _replace_folder( 299 | os.path.join(root, "nyu_normals_gt", "train"), train_dst 300 | ) 301 | _rename_files(train_dst, lambda x: x[1:]) 302 | if not os.path.exists(test_dst): 303 | _replace_folder(os.path.join(root, "nyu_normals_gt", "test"), test_dst) 304 | _rename_files(test_dst, lambda x: x[1:]) 305 | shutil.rmtree(os.path.join(root, "nyu_normals_gt")) 306 | 307 | 308 | def download_depth(root: str): 309 | url = ( 310 | "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat" 311 | ) 312 | train_dst = os.path.join(root, "train_depth") 313 | test_dst = os.path.join(root, "test_depth") 314 | 315 | if not os.path.exists(train_dst) or not os.path.exists(test_dst): 316 | tar = os.path.join(root, url.split("/")[-1]) 317 | if not os.path.exists(tar): 318 | download_url(url, root) 319 | if os.path.exists(tar): 320 | train_ids = [ 321 | f.split(".")[0] for f in os.listdir(os.path.join(root, "train_rgb")) 322 | ] 323 | _create_depth_files(tar, root, train_ids) 324 | 325 | 326 | def _unpack(file: str): 327 | """ 328 | Unpacks tar and zip, does nothing for any other type 329 | :param file: path of file 330 | """ 331 | path = file.rsplit(".", 1)[0] 332 | 333 | if file.endswith(".tgz"): 334 | tar = tarfile.open(file, "r:gz") 335 | tar.extractall(path) 336 | tar.close() 337 | elif file.endswith(".zip"): 338 | zip = zipfile.ZipFile(file, "r") 339 | zip.extractall(path) 340 | zip.close() 341 | 342 | 343 | def _rename_files(folder: str, rename_func: callable): 344 | """ 345 | Renames all files inside a folder based on the passed rename function 346 | :param folder: path to folder that contains files 347 | :param rename_func: function renaming filename (not including path) str -> str 348 | """ 349 | imgs_old = os.listdir(folder) 350 | imgs_new = [rename_func(file) for file in imgs_old] 351 | for img_old, img_new in zip(imgs_old, imgs_new): 352 | shutil.move(os.path.join(folder, img_old), os.path.join(folder, img_new)) 353 | 354 | 355 | def _replace_folder(src: str, dst: str): 356 | """ 357 | Rename src into dst, replacing/overwriting dst if it exists. 
358 | """ 359 | if os.path.exists(dst): 360 | shutil.rmtree(dst) 361 | shutil.move(src, dst) 362 | 363 | 364 | def _create_depth_files(mat_file: str, root: str, train_ids: list): 365 | """ 366 | Extract the depth arrays from the mat file into images 367 | :param mat_file: path to the official labelled dataset .mat file 368 | :param root: The root directory of the dataset 369 | :param train_ids: the IDs of the training images as string (for splitting) 370 | """ 371 | os.mkdir(os.path.join(root, "train_depth")) 372 | os.mkdir(os.path.join(root, "test_depth")) 373 | train_ids = set(train_ids) 374 | 375 | depths = h5py.File(mat_file, "r")["depths"] 376 | for i in range(len(depths)): 377 | img = (depths[i] * 1e4).astype(np.uint16).T 378 | id_ = str(i + 1).zfill(4) 379 | folder = "train" if id_ in train_ids else "test" 380 | save_path = os.path.join(root, f"{folder}_depth", id_ + ".png") 381 | Image.fromarray(img).save(save_path) 382 | -------------------------------------------------------------------------------- /script/train_nyu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import DataLoader 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | from torch.utils.tensorboard import SummaryWriter 8 | import transformers 9 | from tqdm import tqdm 10 | import numpy as np 11 | import os 12 | import pickle 13 | import cv2 14 | import json 15 | import argparse 16 | 17 | from data.nyuv2_same_batch import NYUv2SameBatchDataset 18 | from model.swin_transformer import SwinTransformer 19 | from loss.losses import berHuLoss 20 | from loss.metrics import iou_pytorch, eval_depth 21 | 22 | 23 | def get_config(): 24 | parser = argparse.ArgumentParser(description='Train the network') 25 | parser.add_argument('--config', help='train config file path') 26 | 27 | args = parser.parse_args() 28 | 29 | with open(args.config, "r") as jsonfile: 30 | config = json.load(jsonfile) 31 | 32 | return config 33 | 34 | def freeze_encoder_layers( 35 | model, 36 | conditioned_blocks= [[], [], [*range(12, 18)], []], 37 | unfrozen_modules=[ 38 | "random_weight_matrix", 39 | "film.gb_weights", 40 | "ln_weight_modulation.gb_weights", 41 | "adapter", 42 | "task_type_embeddings", 43 | "patch_embed", 44 | "decoder", 45 | "bottleneck" 46 | ], 47 | frozen_encoder = False 48 | ): 49 | for name, param in model.named_parameters(): 50 | param.requires_grad = not frozen_encoder 51 | 52 | for module in unfrozen_modules: 53 | if module in name: 54 | param.requires_grad = True 55 | 56 | if name.startswith("layers"): 57 | splitted = name.split(".") 58 | 59 | if len(conditioned_blocks[int(splitted[1])]) > 0 and splitted[2]=="blocks" and (int(splitted[3]) in conditioned_blocks[int(splitted[1])]): 60 | param.requires_grad = True 61 | elif name.startswith("norm"): 62 | param.requires_grad = True 63 | 64 | def disp2meters(d): 65 | return (65536.0 / d - 1 ) / 1e4 66 | 67 | def calc_seg_metrics(logit_task, label_task): 68 | 69 | max_labels = torch.argmax(logit_task, dim = 1, keepdim=True) 70 | iou = iou_pytorch(max_labels, label_task) 71 | 72 | return max_labels, iou 73 | 74 | def train(model, train_loader, test_loader, optimizer, scheduler, criterion, epochs, tensorboard_name, seg_weight, depth_weight, start_epoch = 0, device = "cuda", tb_writer=None): 75 | 76 | # Training loop 77 | model.train() 78 | 79 | iters = len(train_loader) 80 | 81 | for e in tqdm(range(epochs)): 82 | 83 | 84 
| epoch = e + start_epoch + 1 85 | 86 | epoch_loss = 0.0 87 | epoch_loss_seg = [] 88 | epoch_loss_depth = [] 89 | 90 | train_ious = [] 91 | train_depths_rmse = [] 92 | train_depths_d1 = [] 93 | 94 | for i, (img, label, task_id) in enumerate(train_loader, 0): 95 | model.train() 96 | 97 | img = img.view((-1, 3, 224, 224)).to(device) 98 | label = label.view((-1, 1, 224, 224)).to(device) 99 | task_id = task_id.view(-1).to(device) 100 | 101 | logits, unique_task_ids_list = model(img, task_id) 102 | 103 | loss = 0 104 | 105 | for j, unique_task_id in enumerate(unique_task_ids_list): 106 | 107 | task_id_filter = task_id == unique_task_id 108 | 109 | logit_task = logits[j] 110 | label_task = label[task_id_filter] 111 | 112 | B = logit_task.shape[0] 113 | 114 | 115 | # Task is segmentation 116 | if unique_task_id == 0: 117 | label_task = label_task.long() 118 | 119 | a = criterion[unique_task_id](logit_task.view(B,14,-1), label_task.view(B,-1)) 120 | epoch_loss_seg.append(a.item()) 121 | 122 | loss += a * seg_weight 123 | 124 | # compute metrics every 10 epochs 125 | if epoch%10==0: 126 | max_labels, iou = calc_seg_metrics(logit_task, label_task) 127 | train_ious.append(iou.cpu().numpy()) 128 | 129 | else: 130 | label_task = 65536.0 / (label_task + 1) 131 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1 132 | 133 | a = criterion[unique_task_id](logit_task, label_task) 134 | epoch_loss_depth.append(a.item()) 135 | 136 | loss += a* depth_weight 137 | 138 | 139 | 140 | # compute metrics every 10 epochs 141 | if epoch%10==0: 142 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task)) 143 | 144 | train_depths_rmse.append(evaluation["rmse"]) 145 | train_depths_d1.append(evaluation["d1"]) 146 | 147 | 148 | optimizer.zero_grad() 149 | loss.backward() 150 | optimizer.step() 151 | 152 | 153 | epoch_loss += loss.item() 154 | 155 | scheduler.step() 156 | 157 | 158 | # Compute validation metrics every 5 epochs 159 | if epoch % 5==0: 160 | 161 | test_loss = 0 162 | epoch_ious = [] 163 | epoch_eval_depths_rmse = [] 164 | epoch_eval_depths_d1 = [] 165 | 166 | epoch_loss_seg_test = [] 167 | epoch_loss_depth_test = [] 168 | 169 | model.eval() 170 | for i, (img, label, task_id) in enumerate(test_loader, 0): 171 | 172 | img = img.view((-1, 3, 224, 224)).to(device) 173 | label = label.view((-1, 1, 224, 224)).to(device) 174 | task_id = task_id.view(-1).to(device) 175 | 176 | logits, unique_task_ids_list = model(img, task_id) 177 | 178 | loss = 0 179 | 180 | for j, unique_task_id in enumerate(unique_task_ids_list): 181 | 182 | 183 | task_id_filter = task_id == unique_task_id 184 | 185 | logit_task = logits[j] 186 | label_task = label[task_id_filter] 187 | B = logit_task.shape[0] 188 | 189 | if unique_task_id == 0: 190 | 191 | label_task = label_task.long() 192 | 193 | a = criterion[unique_task_id](logit_task.view(B,14,-1), label_task.long().view(B,-1)) 194 | epoch_loss_seg_test.append(a.item()) 195 | 196 | loss += a * seg_weight 197 | 198 | max_labels, iou = calc_seg_metrics(logit_task, label_task) 199 | 200 | epoch_ious.append(iou.cpu().numpy()) 201 | 202 | else: 203 | 204 | 205 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1 206 | label_task = 65536.0 / (label_task + 1) 207 | 208 | a = criterion[unique_task_id](logit_task, label_task)#* len(logit_task) 209 | epoch_loss_depth_test.append(a.item()) 210 | 211 | loss += a* depth_weight 212 | 213 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task)) 214 | 
epoch_eval_depths_rmse.append(evaluation["rmse"]) 215 | epoch_eval_depths_d1.append(evaluation["d1"]) 216 | 217 | test_loss += loss.item() 218 | 219 | tb_writer.add_scalar(f"{tensorboard_name}/learning_rate", scheduler.get_last_lr()[0] , epoch) 220 | tb_writer.add_scalar(f"{tensorboard_name}/train_loss", epoch_loss/ len(train_loader) , epoch) 221 | tb_writer.add_scalar(f"{tensorboard_name}/test_loss", test_loss/len(test_loader) , epoch) 222 | tb_writer.add_scalar(f"{tensorboard_name}/mean_iou", np.mean(epoch_ious) , epoch) 223 | tb_writer.add_scalar(f"{tensorboard_name}/depth_rmse", np.mean(epoch_eval_depths_rmse) , epoch) 224 | tb_writer.add_scalar(f"{tensorboard_name}/depth_d1", np.mean(epoch_eval_depths_d1) , epoch) 225 | tb_writer.add_scalar(f"{tensorboard_name}/seg_loss", np.mean(epoch_loss_seg) , epoch) 226 | tb_writer.add_scalar(f"{tensorboard_name}/depth_loss", np.mean(epoch_loss_depth) , epoch) 227 | tb_writer.add_scalar(f"{tensorboard_name}/seg_loss_test", np.mean(epoch_loss_seg_test) , epoch) 228 | tb_writer.add_scalar(f"{tensorboard_name}/depth_loss_test", np.mean(epoch_loss_depth_test) , epoch) 229 | 230 | # Save training metrics every 10 epochs 231 | if epoch%10 == 0: 232 | tb_writer.add_scalar(f"{tensorboard_name}/train_mean_iou", np.mean(train_ious) , epoch) 233 | tb_writer.add_scalar(f"{tensorboard_name}/train_depth_rmse", np.mean(train_depths_rmse) , epoch) 234 | tb_writer.add_scalar(f"{tensorboard_name}/train_depth_d1", np.mean(train_depths_d1) , epoch) 235 | 236 | 237 | 238 | # save the model every 500 epochs 239 | if epoch % 500 == 0 or epoch == (epochs-1): 240 | torch.save({ 241 | 'epoch': epoch, 242 | 'model_state_dict': model.state_dict(), 243 | 'optimizer_state_dict': optimizer.state_dict(), 244 | 'scheduler_state_dict': scheduler.state_dict(), 245 | }, f"{tensorboard_name}.pt") 246 | 247 | 248 | def load_model(model, optimizer, scheduler, PATH): 249 | checkpoint = torch.load(PATH, map_location=device) 250 | model.load_state_dict(checkpoint['model_state_dict']) 251 | model = model.to(device) 252 | 253 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 254 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 255 | epoch = checkpoint['epoch'] 256 | return model, optimizer, scheduler, epoch 257 | 258 | def get_dataloaders(tasks, batch_size): 259 | 260 | IMAGE_SIZE = (480, 640) 261 | 262 | train_t = torch.nn.Sequential(transforms.RandomResizedCrop(224, scale=(0.5, 1.0)), transforms.RandomHorizontalFlip()) 263 | test_t = torch.nn.Sequential(transforms.CenterCrop(480), transforms.Resize(224)) 264 | train_t_input_image = torch.nn.Sequential(transforms.ColorJitter(brightness=(0.8, 1.2),contrast =(0.8, 1.2), saturation=(0.8, 1.2), hue=(-0.1,0.1))) 265 | 266 | train_dataset = NYUv2SameBatchDataset(root="./data/nyuv2", tasks=tasks, download=False, train=True, 267 | rgb_transform=train_t, rgb_transform2=train_t_input_image, seg_transform=train_t, sn_transform=train_t, depth_transform=train_t) 268 | 269 | test_dataset = NYUv2SameBatchDataset(root="./data/nyuv2", tasks=tasks, download=False, train=False, 270 | rgb_transform=test_t, seg_transform=test_t, sn_transform=test_t, depth_transform=test_t) 271 | 272 | print("Train dataset size:", len(train_dataset)) 273 | print("Test dataset size:", len(test_dataset)) 274 | 275 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 276 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 277 | 278 | return train_dataloader, test_dataloader 279 | 280 | 281 
| 282 | def main(): 283 | 284 | config = get_config() 285 | 286 | tb_writer = SummaryWriter(f'runs/{config["experiment_name"]}') 287 | 288 | torch.manual_seed(61) 289 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 290 | 291 | tasks = {0:"segmentation", 1:"depth"} 292 | 293 | batch_size = config["batch_size"] 294 | 295 | print("Creating datasets...") 296 | train_dataloader, test_dataloader = get_dataloaders(tasks, batch_size) 297 | 298 | print("Loading model...") 299 | 300 | model = SwinTransformer(img_size=224, 301 | patch_size=4, 302 | in_chans=3, 303 | num_classes=21841, 304 | embed_dim=96, 305 | depths=[2, 2, 18, 2 ], 306 | depths_decoder =[2, 2, 2, 2 ], 307 | num_heads=[ 3, 6, 12, 24 ], 308 | window_size=7, 309 | mlp_ratio=4., 310 | qkv_bias=True, 311 | qk_scale=True, 312 | drop_rate=0, 313 | drop_rate_decoder=0.6, 314 | drop_path_rate=0.2, 315 | ape=False, 316 | patch_norm=True, 317 | use_checkpoint=False, 318 | tasks = ["segmentation", "depth"], 319 | task_classes = [14, 1], 320 | conditioned_blocks = config["conditioned_blocks"], 321 | adapter = config["adapter"], 322 | use_conditional_layer = config["use_conditional_layer_norm"]) 323 | 324 | epochs = config["epochs"] 325 | optimizer = optim.AdamW(model.parameters(), lr=2e-5, betas=(0.9, 0.98)) 326 | 327 | 328 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 4e-5, epochs = epochs, steps_per_epoch = len(train_dataloader), pct_start = 0.1) 329 | 330 | 331 | if config["continue_training"]: 332 | model, optimizer, scheduler, start_epoch = load_model(model, optimizer, scheduler, config["experiment_name"]+".pt") 333 | print("Continue model loaded") 334 | 335 | else: 336 | start_epoch = -1 337 | model.load_state_dict(torch.load('./pretrained/swin_small_patch4_window7_224_22k.pth')['model'],strict=False) 338 | 339 | model = model.to(device) 340 | print("Pretrained model loaded") 341 | 342 | 343 | freeze_encoder_layers(model, conditioned_blocks = config["conditioned_blocks"], frozen_encoder = config["frozen_encoder"]) 344 | model = model.to(device) 345 | 346 | criterion = [] 347 | segmentation_criteon = torch.nn.CrossEntropyLoss() 348 | criterion.append(segmentation_criteon) 349 | 350 | depth_criterion = berHuLoss() 351 | criterion.append(depth_criterion) 352 | 353 | print("Training",config["experiment_name"],"...") 354 | 355 | train(model, train_dataloader, test_dataloader, optimizer, scheduler, criterion, epochs, config["experiment_name"], config["seg_weight"], config["depth_weight"], start_epoch, device, tb_writer) 356 | 357 | 358 | if __name__ == '__main__': 359 | main() -------------------------------------------------------------------------------- /script/train_nyu_single_task.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import DataLoader 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | from torch.utils.tensorboard import SummaryWriter 8 | import transformers 9 | from tqdm import tqdm 10 | import numpy as np 11 | import os 12 | import pickle 13 | import cv2 14 | import json 15 | import argparse 16 | 17 | from data.nyuv2 import NYUv2Dataset 18 | from model.swin_transformer import SwinTransformer 19 | from loss.losses import berHuLoss 20 | from loss.metrics import iou_pytorch, eval_depth 21 | 22 | 23 | def get_config(): 24 | parser = argparse.ArgumentParser(description='Train the network') 25 | 
parser.add_argument('--config', help='train config file path') 26 | 27 | args = parser.parse_args() 28 | 29 | with open(args.config, "r") as jsonfile: 30 | config = json.load(jsonfile) 31 | 32 | return config 33 | 34 | def freeze_encoder_layers( 35 | model, 36 | conditioned_blocks= [[], [], [*range(12, 18)], []], 37 | unfrozen_modules=[ 38 | "random_weight_matrix", 39 | "film.gb_weights", 40 | "ln_weight_modulation.gb_weights", 41 | "adapter", 42 | "task_type_embeddings", 43 | "patch_embed", 44 | "decoder", 45 | "bottleneck" 46 | ], 47 | frozen_encoder = False 48 | ): 49 | for name, param in model.named_parameters(): 50 | param.requires_grad = not frozen_encoder 51 | 52 | for module in unfrozen_modules: 53 | if module in name: 54 | param.requires_grad = True 55 | 56 | if name.startswith("layers"): 57 | splitted = name.split(".") 58 | 59 | if len(conditioned_blocks[int(splitted[1])]) > 0 and splitted[2]=="blocks" and (int(splitted[3]) in conditioned_blocks[int(splitted[1])]): 60 | param.requires_grad = True 61 | elif name.startswith("norm"): 62 | param.requires_grad = True 63 | 64 | def disp2meters(d): 65 | return (65536.0 / d - 1 ) / 1e4 66 | 67 | def calc_seg_metrics(logit_task, label_task): 68 | 69 | max_labels = torch.argmax(logit_task, dim = 1, keepdim=True) 70 | iou = iou_pytorch(max_labels, label_task) 71 | 72 | return max_labels, iou 73 | 74 | def train(model, train_loader, test_loader, optimizer, scheduler, criterion, epochs, tensorboard_name, start_epoch = 0, device = "cuda", tb_writer=None, task="segmentation"): 75 | 76 | # Training loop 77 | model.train() 78 | 79 | iters = len(train_loader) 80 | 81 | for e in tqdm(range(epochs)): 82 | 83 | 84 | epoch = e + start_epoch + 1 85 | 86 | epoch_loss = 0.0 87 | epoch_loss_seg = [] 88 | epoch_loss_depth = [] 89 | 90 | train_ious = [] 91 | train_depths_rmse = [] 92 | train_depths_d1 = [] 93 | 94 | for i, (img, label, task_id) in enumerate(train_loader, 0): 95 | model.train() 96 | 97 | img = img.view((-1, 3, 224, 224)).to(device) 98 | label = label.view((-1, 1, 224, 224)).to(device) 99 | task_id = torch.zeros_like(task_id.view(-1).to(device)) 100 | 101 | logits, unique_task_ids_list = model(img, task_id) 102 | 103 | loss = 0 104 | 105 | for j, unique_task_id in enumerate(unique_task_ids_list): 106 | 107 | task_id_filter = task_id == unique_task_id 108 | 109 | logit_task = logits[j] 110 | label_task = label[task_id_filter] 111 | 112 | B = logit_task.shape[0] 113 | 114 | 115 | # Task is segmentation 116 | if task == "segmentation": 117 | label_task = label_task.long() 118 | 119 | a = criterion[0](logit_task.view(B,14,-1), label_task.view(B,-1)) 120 | epoch_loss_seg.append(a.item()) 121 | 122 | loss += a 123 | 124 | # compute metrics every 10 epochs 125 | if epoch%10==0: 126 | max_labels, iou = calc_seg_metrics(logit_task, label_task) 127 | train_ious.append(iou.cpu().numpy()) 128 | 129 | else: 130 | label_task = 65536.0 / (label_task + 1) 131 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1 132 | 133 | a = criterion[1](logit_task, label_task) 134 | epoch_loss_depth.append(a.item()) 135 | 136 | loss += a 137 | 138 | 139 | 140 | # compute metrics every 10 epochs 141 | if epoch%10==0: 142 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task)) 143 | 144 | train_depths_rmse.append(evaluation["rmse"]) 145 | train_depths_d1.append(evaluation["d1"]) 146 | 147 | 148 | optimizer.zero_grad() 149 | loss.backward() 150 | optimizer.step() 151 | 152 | 153 | epoch_loss += loss.item() 154 | 155 | scheduler.step() 156 | 
157 | 158 | # Compute validation metrics every 5 epochs 159 | if epoch % 5==0: 160 | 161 | test_loss = 0 162 | epoch_ious = [] 163 | epoch_eval_depths_rmse = [] 164 | epoch_eval_depths_d1 = [] 165 | 166 | epoch_loss_seg_test = [] 167 | epoch_loss_depth_test = [] 168 | 169 | model.eval() 170 | for i, (img, label, task_id) in enumerate(test_loader, 0): 171 | 172 | img = img.view((-1, 3, 224, 224)).to(device) 173 | label = label.view((-1, 1, 224, 224)).to(device) 174 | task_id = torch.zeros_like(task_id.view(-1).to(device)) 175 | 176 | logits, unique_task_ids_list = model(img, task_id) 177 | 178 | loss = 0 179 | 180 | for j, unique_task_id in enumerate(unique_task_ids_list): 181 | 182 | 183 | task_id_filter = task_id == unique_task_id 184 | 185 | logit_task = logits[j] 186 | label_task = label[task_id_filter] 187 | B = logit_task.shape[0] 188 | 189 | if task == "segmentation": 190 | 191 | label_task = label_task.long() 192 | 193 | a = criterion[0](logit_task.view(B,14,-1), label_task.long().view(B,-1)) 194 | epoch_loss_seg_test.append(a.item()) 195 | 196 | loss += a 197 | 198 | max_labels, iou = calc_seg_metrics(logit_task, label_task) 199 | 200 | epoch_ious.append(iou.cpu().numpy()) 201 | 202 | else: 203 | 204 | 205 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1 206 | label_task = 65536.0 / (label_task + 1) 207 | 208 | a = criterion[1](logit_task, label_task) 209 | epoch_loss_depth_test.append(a.item()) 210 | 211 | loss += a 212 | 213 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task)) 214 | epoch_eval_depths_rmse.append(evaluation["rmse"]) 215 | epoch_eval_depths_d1.append(evaluation["d1"]) 216 | 217 | test_loss += loss.item() 218 | 219 | tb_writer.add_scalar(f"{tensorboard_name}/learning_rate", scheduler.get_last_lr()[0] , epoch) 220 | tb_writer.add_scalar(f"{tensorboard_name}/train_loss", epoch_loss/ len(train_loader) , epoch) 221 | tb_writer.add_scalar(f"{tensorboard_name}/test_loss", test_loss/len(test_loader) , epoch) 222 | tb_writer.add_scalar(f"{tensorboard_name}/mean_iou", np.mean(epoch_ious) , epoch) 223 | tb_writer.add_scalar(f"{tensorboard_name}/depth_rmse", np.mean(epoch_eval_depths_rmse) , epoch) 224 | tb_writer.add_scalar(f"{tensorboard_name}/depth_d1", np.mean(epoch_eval_depths_d1) , epoch) 225 | tb_writer.add_scalar(f"{tensorboard_name}/seg_loss", np.mean(epoch_loss_seg) , epoch) 226 | tb_writer.add_scalar(f"{tensorboard_name}/depth_loss", np.mean(epoch_loss_depth) , epoch) 227 | tb_writer.add_scalar(f"{tensorboard_name}/seg_loss_test", np.mean(epoch_loss_seg_test) , epoch) 228 | tb_writer.add_scalar(f"{tensorboard_name}/depth_loss_test", np.mean(epoch_loss_depth_test) , epoch) 229 | 230 | # Save training metrics every 10 epochs 231 | if epoch%10 == 0: 232 | tb_writer.add_scalar(f"{tensorboard_name}/train_mean_iou", np.mean(train_ious) , epoch) 233 | tb_writer.add_scalar(f"{tensorboard_name}/train_depth_rmse", np.mean(train_depths_rmse) , epoch) 234 | tb_writer.add_scalar(f"{tensorboard_name}/train_depth_d1", np.mean(train_depths_d1) , epoch) 235 | 236 | 237 | 238 | # save the model every 500 epochs 239 | if epoch % 500 == 0 or epoch == ((epochs)-1): 240 | torch.save({ 241 | 'epoch': epoch, 242 | 'model_state_dict': model.state_dict(), 243 | 'optimizer_state_dict': optimizer.state_dict(), 244 | 'scheduler_state_dict': scheduler.state_dict(), 245 | }, f"{tensorboard_name}.pt") 246 | 247 | 248 | def load_model(model, optimizer, scheduler, PATH): 249 | checkpoint = torch.load(PATH, map_location=device) 250 | 
model.load_state_dict(checkpoint['model_state_dict']) 251 | model = model.to(device) 252 | 253 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 254 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 255 | epoch = checkpoint['epoch'] 256 | return model, optimizer, scheduler, epoch 257 | 258 | def get_dataloaders(tasks, task, batch_size): 259 | 260 | IMAGE_SIZE = (480, 640) 261 | 262 | train_t = torch.nn.Sequential(transforms.RandomResizedCrop(224, scale=(0.5, 1.0)), transforms.RandomHorizontalFlip()) 263 | test_t = torch.nn.Sequential(transforms.CenterCrop(480), transforms.Resize(224)) 264 | train_t_input_image = torch.nn.Sequential(transforms.ColorJitter(brightness=(0.8, 1.2),contrast =(0.8, 1.2), saturation=(0.8, 1.2), hue=(-0.1,0.1))) 265 | 266 | train_dataset = NYUv2Dataset(root="./data/nyuv2", tasks=tasks, download=False, train=True, 267 | rgb_transform=train_t, rgb_transform2=train_t_input_image, seg_transform=train_t, sn_transform=train_t, depth_transform=train_t) 268 | 269 | test_dataset = NYUv2Dataset(root="./data/nyuv2", tasks=tasks, download=False, train=False, 270 | rgb_transform=test_t, seg_transform=test_t, sn_transform=test_t, depth_transform=test_t) 271 | 272 | if task == "segmentation": 273 | train_dataset = torch.utils.data.Subset(train_dataset, range(len(train_dataset)//2)) 274 | test_dataset = torch.utils.data.Subset(test_dataset, range(len(test_dataset)//2)) 275 | 276 | if task == "depth": 277 | 278 | train_dataset = torch.utils.data.Subset(train_dataset, range(len(train_dataset)//2, len(train_dataset))) 279 | test_dataset = torch.utils.data.Subset(test_dataset, range(len(test_dataset)//2, len(test_dataset))) 280 | 281 | 282 | print("Train dataset size:", len(train_dataset)) 283 | print("Test dataset size:", len(test_dataset)) 284 | 285 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 286 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 287 | 288 | return train_dataloader, test_dataloader 289 | 290 | 291 | 292 | def main(): 293 | # default `log_dir` is "runs" - we'll be more specific here 294 | 295 | config = get_config() 296 | 297 | tb_writer = SummaryWriter(f'runs/{config["experiment_name"]}') 298 | 299 | torch.manual_seed(61) 300 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 301 | 302 | tasks = {0:"segmentation", 1:"depth"} 303 | batch_size = config["batch_size"] 304 | 305 | print("Creating datasets...") 306 | train_dataloader, test_dataloader = get_dataloaders(tasks, config["task"], batch_size) 307 | 308 | print("Loading model...") 309 | 310 | model = SwinTransformer(img_size=224, 311 | patch_size=4, 312 | in_chans=3, 313 | num_classes=21841, 314 | embed_dim=96, 315 | depths=[2, 2, 18, 2 ], 316 | depths_decoder =[2, 2, 2, 2 ], 317 | num_heads=[ 3, 6, 12, 24 ], 318 | window_size=7, 319 | mlp_ratio=4., 320 | qkv_bias=True, 321 | qk_scale=True, 322 | drop_rate=0, 323 | drop_rate_decoder=0.6, 324 | drop_path_rate=0.2, 325 | ape=False, 326 | patch_norm=True, 327 | use_checkpoint=False, 328 | tasks = [config["task"]], 329 | task_classes = [14 if config["task"]=="segmentation" else 1], 330 | conditioned_blocks = [[],[],[],[]], 331 | adapter = False, 332 | use_conditional_layer = False) 333 | 334 | epochs = config["epochs"] 335 | optimizer = optim.AdamW(model.parameters(), lr=2e-5, betas=(0.9, 0.98)) 336 | 337 | 338 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 4e-5, epochs = epochs, steps_per_epoch = len(train_dataloader), pct_start 
= 0.1) 339 | 340 | scheduler_batch_step = True 341 | use_scheduler = True 342 | 343 | if config["continue_training"]: 344 | model, optimizer, scheduler, start_epoch = load_model(model, optimizer, scheduler, config["experiment_name"]+".pt") 345 | print("Continue model loaded") 346 | 347 | else: 348 | start_epoch = -1 349 | model.load_state_dict(torch.load('./pretrained/swin_small_patch4_window7_224_22k.pth')['model'],strict=False) 350 | 351 | model = model.to(device) 352 | print("Pretrained model loaded") 353 | 354 | 355 | freeze_encoder_layers(model, conditioned_blocks = [[],[],[],[]], frozen_encoder = config["frozen_encoder"]) 356 | model = model.to(device) 357 | 358 | criterion = [] 359 | segmentation_criteon = torch.nn.CrossEntropyLoss() 360 | criterion.append(segmentation_criteon) 361 | 362 | depth_criterion = berHuLoss() 363 | criterion.append(depth_criterion) 364 | 365 | print("Training",config["experiment_name"],"...") 366 | 367 | 368 | train(model, train_dataloader, test_dataloader, optimizer, scheduler, criterion, epochs, config["experiment_name"], start_epoch, device, tb_writer, config["task"]) 369 | 370 | 371 | if __name__ == '__main__': 372 | main() -------------------------------------------------------------------------------- /script/train_taskonomy.py: -------------------------------------------------------------------------------- 1 | from data.taskonomy.taskonomy_dataset_s3 import TaskonomyDatasetS3 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | from torch.utils.data import DataLoader 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | from torch.utils.tensorboard import SummaryWriter 9 | import transformers 10 | from tqdm import tqdm 11 | import numpy as np 12 | import os 13 | import pickle 14 | import cv2 15 | import json 16 | import argparse 17 | 18 | # Small adjustments are made to the Swin Transformer model for parallel run, not related to architecture 19 | from model.swin_transformer_parallel import SwinTransformer 20 | from loss.losses import berHuLoss 21 | from loss.metrics import iou_pytorch, eval_depth 22 | import transformers 23 | 24 | 25 | unique_task_ids_list = [0,1] 26 | 27 | def get_config(): 28 | parser = argparse.ArgumentParser(description='Train the network') 29 | parser.add_argument('--config', help='train config file path') 30 | 31 | args = parser.parse_args() 32 | 33 | with open(args.config, "r") as jsonfile: 34 | config = json.load(jsonfile) 35 | 36 | return config 37 | 38 | def freeze_encoder_layers( 39 | model, 40 | conditioned_blocks= [[], [], [*range(12, 18)], []], 41 | unfrozen_modules=[ 42 | "random_weight_matrix", 43 | "film.gb_weights", 44 | "ln_weight_modulation.gb_weights", 45 | "adapter", 46 | "task_type_embeddings", 47 | "patch_embed", 48 | "decoder", 49 | "bottleneck" 50 | ], 51 | frozen_encoder = False 52 | ): 53 | for name, param in model.named_parameters(): 54 | param.requires_grad = not frozen_encoder # remove 'not' for a frozen encoder 55 | 56 | for module in unfrozen_modules: 57 | if module in name: 58 | param.requires_grad = True 59 | 60 | if name.startswith("layers"): 61 | splitted = name.split(".") 62 | 63 | if len(conditioned_blocks[int(splitted[1])]) > 0 and splitted[2]=="blocks" and (int(splitted[3]) in conditioned_blocks[int(splitted[1])]): 64 | param.requires_grad = True 65 | elif name.startswith("norm"): 66 | param.requires_grad = True 67 | 68 | 69 | def disp2meters(d): 70 | return (65536.0 / d - 1 ) / 1e4 71 | 72 | def calc_seg_metrics(logit_task, 
label_task): 73 | 74 | max_labels = torch.argmax(logit_task, dim = 1, keepdim=True) 75 | iou = iou_pytorch(max_labels, label_task) 76 | 77 | return max_labels, iou 78 | 79 | def check_val(model, dataloader, criterion, index, tensorboard_name, seg_weight, depth_weight, device, tb_writer): 80 | test_loss = 0 81 | epoch_ious = [] 82 | epoch_eval_depths_rmse = [] 83 | epoch_eval_depths_d1 = [] 84 | 85 | epoch_loss_seg_test = [] 86 | epoch_loss_depth_test = [] 87 | 88 | model.eval() 89 | 90 | for i, (img, label, task_id) in enumerate(dataloader, 0): 91 | 92 | img = img.view((-1, 3, 224, 224)).to(device) 93 | label = label.view((-1, 1, 224, 224)).to(device) 94 | task_id = task_id.view(-1).to(device) 95 | 96 | logits = model(img, task_id) 97 | 98 | loss = 0 99 | 100 | for j, unique_task_id in enumerate(unique_task_ids_list): 101 | 102 | 103 | task_id_filter = task_id == unique_task_id 104 | 105 | logit_task = logits[j] 106 | if logit_task is None: 107 | continue 108 | 109 | label_task = label[task_id_filter] 110 | B = logit_task.shape[0] 111 | 112 | 113 | if unique_task_id == 0: 114 | 115 | label_task = label_task.long() 116 | 117 | a = criterion[unique_task_id](logit_task.view(B,18,-1), label_task.long().view(B,-1)) 118 | loss += a * seg_weight 119 | 120 | epoch_loss_seg_test.append(a.item() ) 121 | 122 | max_labels = torch.argmax(logit_task, dim = 1, keepdim=True) 123 | 124 | iou = iou_pytorch(max_labels, label_task) 125 | 126 | epoch_ious.append(iou.cpu().numpy()) 127 | 128 | else: 129 | 130 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1 131 | label_task = 65536.0 / (label_task + 1) 132 | 133 | a = criterion[unique_task_id](logit_task, label_task, mask_val = 1.0)#* len(logit_task) 134 | loss += a* depth_weight 135 | epoch_loss_depth_test.append(a.item() ) 136 | 137 | 138 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task)) 139 | epoch_eval_depths_rmse.append(evaluation["rmse"]) 140 | epoch_eval_depths_d1.append(evaluation["d1"]) 141 | 142 | test_loss += loss.item() 143 | 144 | tb_writer.add_scalar(f"{tensorboard_name}/mid_test_loss", test_loss/len(dataloader) , index) 145 | tb_writer.add_scalar(f"{tensorboard_name}/mid_mean_iou", np.mean(epoch_ious) , index) 146 | tb_writer.add_scalar(f"{tensorboard_name}/mid_depth_rmse", np.mean(epoch_eval_depths_rmse) , index) 147 | tb_writer.add_scalar(f"{tensorboard_name}/mid_depth_d1", np.mean(epoch_eval_depths_d1) , index) 148 | tb_writer.add_scalar(f"{tensorboard_name}/mid_seg_loss_test", np.mean(epoch_loss_seg_test) , index) 149 | tb_writer.add_scalar(f"{tensorboard_name}/mid_depth_loss_test", np.mean(epoch_loss_depth_test) , index) 150 | 151 | 152 | 153 | 154 | 155 | def train(model, train_loader, test_loader,mid_test_dataloader, optimizer, scheduler, criterion, epochs, tensorboard_name, seg_weight, depth_weight, start_epoch = 0, device = "cuda", tb_writer=None): 156 | 157 | model.train() 158 | 159 | train_losses = [] 160 | test_losses = [] 161 | ious = [] 162 | eval_depths = [] 163 | 164 | iters = len(train_loader) 165 | 166 | 167 | mid_iter_count = 0 168 | mid_iter_count = 30 169 | 170 | 171 | for e in tqdm(range(epochs), desc="epoch", position=0): 172 | 173 | epoch = e + start_epoch + 1 174 | 175 | epoch_loss = 0.0 176 | epoch_loss_seg = [] 177 | epoch_loss_depth = [] 178 | 179 | 180 | 181 | #model.train() 182 | train_ious = [] 183 | train_depths_rmse = [] 184 | train_depths_d1 = [] 185 | 186 | 187 | for i, (img, label, task_id) in tqdm(enumerate(train_loader, 0), desc="iter", position=1, leave=False): 188 | 
model.train() 189 | 190 | img = img.view((-1, 3, 224, 224)).to(device) 191 | label = label.view((-1, 1, 224, 224)).to(device) 192 | task_id = task_id.view(-1).to(device) 193 | 194 | 195 | logits = model(img, task_id) 196 | 197 | loss = 0 198 | 199 | 200 | for j, unique_task_id in enumerate(unique_task_ids_list): 201 | 202 | task_id_filter = task_id == unique_task_id 203 | 204 | 205 | logit_task = logits[j] 206 | if logit_task is None: 207 | continue 208 | label_task = label[task_id_filter] 209 | 210 | B = logit_task.shape[0] 211 | 212 | if unique_task_id == 0: 213 | label_task = label_task.long() 214 | 215 | a = criterion[unique_task_id](logit_task.view(B,18,-1), label_task.view(B,-1)) #* len(logit_task) 216 | loss += a * seg_weight 217 | epoch_loss_seg.append(a.item() ) 218 | 219 | if epoch%1==0 : 220 | max_labels = torch.argmax(logit_task, dim = 1, keepdim=True) 221 | iou = iou_pytorch(max_labels, label_task) 222 | train_ious.append(iou.cpu().numpy()) 223 | 224 | 225 | else: 226 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1 227 | label_task = 65536.0 / (label_task + 1) 228 | 229 | a = criterion[unique_task_id](logit_task, label_task, mask_val = 1.0) #* len(logit_task) 230 | loss += a* depth_weight 231 | epoch_loss_depth.append(a.item()) 232 | 233 | if epoch%1==0: 234 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task)) 235 | train_depths_rmse.append(evaluation["rmse"]) 236 | train_depths_d1.append(evaluation["d1"]) 237 | 238 | 239 | optimizer.zero_grad() 240 | loss.backward() 241 | optimizer.step() 242 | 243 | 244 | epoch_loss += loss.item() 245 | 246 | scheduler.step() 247 | 248 | 249 | if i % 1000 == 0: 250 | check_val(model, mid_test_dataloader, criterion, mid_iter_count*1000, tensorboard_name, seg_weight, depth_weight, device, tb_writer) 251 | mid_iter_count += 1 252 | 253 | tb_writer.add_scalar(f"{tensorboard_name}/mid_train_loss", epoch_loss/ (i+1) , mid_iter_count*1000) 254 | 255 | torch.save({ 256 | 'epoch': epoch, 257 | 'iter': i, 258 | 'model_state_dict': model.module.state_dict(), 259 | 'optimizer_state_dict': optimizer.state_dict(), 260 | 'scheduler_state_dict': scheduler.state_dict(), 261 | }, f"{tensorboard_name}.pt") 262 | 263 | 264 | 265 | if epoch % 1==0: 266 | test_loss = 0 267 | epoch_ious = [] 268 | epoch_eval_depths_rmse = [] 269 | epoch_eval_depths_d1 = [] 270 | 271 | epoch_loss_seg_test = [] 272 | epoch_loss_depth_test = [] 273 | 274 | model.eval() 275 | for i, (img, label, task_id) in enumerate(test_loader, 0): 276 | 277 | img = img.view((-1, 3, 224, 224)).to(device) 278 | label = label.view((-1, 1, 224, 224)).to(device) 279 | task_id = task_id.view(-1).to(device) 280 | 281 | 282 | logits = model(img, task_id) 283 | 284 | 285 | loss = 0 286 | 287 | for j, unique_task_id in enumerate(unique_task_ids_list): 288 | 289 | 290 | task_id_filter = task_id == unique_task_id 291 | 292 | logit_task = logits[j] 293 | if logit_task is None: 294 | continue 295 | label_task = label[task_id_filter] 296 | B = logit_task.shape[0] 297 | 298 | 299 | if unique_task_id == 0: 300 | 301 | label_task = label_task.long() 302 | 303 | a = criterion[unique_task_id](logit_task.view(B,18,-1), label_task.long().view(B,-1)) 304 | loss += a * seg_weight 305 | 306 | epoch_loss_seg_test.append(a.item() ) 307 | 308 | if epoch%1==0: 309 | 310 | max_labels = torch.argmax(logit_task, dim = 1, keepdim=True) 311 | 312 | iou = iou_pytorch(max_labels, label_task) 313 | 314 | epoch_ious.append(iou.cpu().numpy()) 315 | 316 | else: 317 | 318 | 319 | logit_task = 
torch.nn.functional.sigmoid(logit_task)*65535 + 1 320 | label_task = 65536.0 / (label_task + 1) 321 | 322 | a = criterion[unique_task_id](logit_task, label_task, mask_val = 1.0)#* len(logit_task) 323 | loss += a* depth_weight 324 | 325 | epoch_loss_depth_test.append(a.item() ) 326 | 327 | if epoch%1==0: 328 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task)) 329 | epoch_eval_depths_rmse.append(evaluation["rmse"]) 330 | epoch_eval_depths_d1.append(evaluation["d1"]) 331 | 332 | 333 | 334 | test_loss += loss.item() 335 | 336 | 337 | 338 | if epoch % 1==0: 339 | 340 | 341 | tb_writer.add_scalar(f"{tensorboard_name}/learning_rate", scheduler.get_last_lr()[0] , epoch) 342 | tb_writer.add_scalar(f"{tensorboard_name}/train_loss", epoch_loss/ len(train_loader) , epoch) 343 | tb_writer.add_scalar(f"{tensorboard_name}/test_loss", test_loss/len(test_loader) , epoch) 344 | tb_writer.add_scalar(f"{tensorboard_name}/mean_iou", np.mean(epoch_ious) , epoch) 345 | tb_writer.add_scalar(f"{tensorboard_name}/depth_rmse", np.mean(epoch_eval_depths_rmse) , epoch) 346 | tb_writer.add_scalar(f"{tensorboard_name}/depth_d1", np.mean(epoch_eval_depths_d1) , epoch) 347 | 348 | tb_writer.add_scalar(f"{tensorboard_name}/seg_loss", np.mean(epoch_loss_seg) , epoch) 349 | tb_writer.add_scalar(f"{tensorboard_name}/depth_loss", np.mean(epoch_loss_depth) , epoch) 350 | 351 | tb_writer.add_scalar(f"{tensorboard_name}/seg_loss_test", np.mean(epoch_loss_seg_test) , epoch) 352 | tb_writer.add_scalar(f"{tensorboard_name}/depth_loss_test", np.mean(epoch_loss_depth_test) , epoch) 353 | 354 | if epoch%1 == 0: 355 | tb_writer.add_scalar(f"{tensorboard_name}/train_mean_iou", np.mean(train_ious) , epoch) 356 | tb_writer.add_scalar(f"{tensorboard_name}/train_depth_rmse", np.mean(train_depths_rmse) , epoch) 357 | tb_writer.add_scalar(f"{tensorboard_name}/train_depth_d1", np.mean(train_depths_d1) , epoch) 358 | 359 | 360 | 361 | if epoch % 1 == 0: 362 | torch.save({ 363 | 'epoch': epoch, 364 | 'model_state_dict': model.state_dict(), 365 | 'optimizer_state_dict': optimizer.state_dict(), 366 | 'scheduler_state_dict': scheduler.state_dict(), 367 | }, f"{tensorboard_name}.pt") 368 | 369 | return train_losses, test_losses, ious, eval_depths 370 | 371 | 372 | def load_model(model, optimizer, scheduler, PATH): 373 | checkpoint = torch.load(PATH, map_location=device) 374 | model.load_state_dict(checkpoint['model_state_dict']) 375 | model = model.to(device) 376 | 377 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 378 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 379 | epoch = checkpoint['epoch'] 380 | return model, optimizer, scheduler, epoch 381 | 382 | 383 | def get_dataloaders(tasks, batch_size): 384 | 385 | train_dataset = TaskonomyDatasetS3(tasks=["rgb", "segment_semantic","depth_euclidean"], split="train", variant="tiny", image_size=224) 386 | test_dataset = TaskonomyDatasetS3(tasks=["rgb", "segment_semantic","depth_euclidean"], split="val", variant="tiny", image_size=224) 387 | 388 | print("Train dataset size:", len(train_dataset)) 389 | print("Test dataset size:", len(test_dataset)) 390 | 391 | g = torch.Generator() 392 | g.manual_seed(61) 393 | 394 | k_samples = 16*100 395 | perm = torch.randperm(len(test_dataset), generator=g) 396 | idx = perm[:k_samples].tolist() 397 | 398 | subset_dataset_test = torch.utils.data.Subset(test_dataset, idx) 399 | 400 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 401 | mid_test_dataloader = 
DataLoader(subset_dataset_test, batch_size=batch_size, shuffle=False)
402 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
403 | 
404 | return train_dataloader, mid_test_dataloader, test_dataloader
405 | 
406 | class _CustomDataParallel(torch.nn.DataParallel):
407 | def __init__(self, model):
408 | super(_CustomDataParallel, self).__init__(model)
409 | 
410 | def __getattr__(self, name):
411 | try:
412 | return super(_CustomDataParallel, self).__getattr__(name)
413 | except AttributeError:
414 | return getattr(self.module, name)
415 | 
416 | def main():
417 | 
418 | config = get_config()
419 | 
420 | tb_writer = SummaryWriter(f'runs/{config["experiment_name"]}')
421 | 
422 | torch.manual_seed(61)
423 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
424 | 
425 | tasks = {0:"segmentation", 1:"depth"} # add 2:"normals" and 3:"edges" to replicate the above code
426 | 
427 | batch_size = config["batch_size"]
428 | print("Creating datasets...")
429 | train_dataloader, mid_test_dataloader, test_dataloader = get_dataloaders(tasks, batch_size)
430 | 
431 | 
432 | print("Loading model...")
433 | model = SwinTransformer(img_size=224,
434 | patch_size=4,
435 | in_chans=3,
436 | num_classes=21841,
437 | embed_dim=96,
438 | depths=[2, 2, 18, 2 ],
439 | depths_decoder =[2, 2, 2, 2 ],
440 | num_heads=[ 3, 6, 12, 24 ],
441 | window_size=7,
442 | mlp_ratio=4.,
443 | qkv_bias=True,
444 | qk_scale=True,  # note: a truthy qk_scale overrides the default head_dim ** -0.5 scaling in WindowAttention
445 | drop_rate=0,
446 | drop_rate_decoder=0.6,
447 | drop_path_rate=0.2,
448 | ape=False,
449 | patch_norm=True,
450 | use_checkpoint=False,
451 | tasks = ["segmentation", "depth"],
452 | task_classes = [18, 1],
453 | conditioned_blocks = config["conditioned_blocks"],
454 | adapter=config["adapter"])
455 | 
456 | 
457 | epochs = config["epochs"]
458 | 
459 | 
460 | optimizer = optim.AdamW(model.parameters(), lr=2e-5, betas=(0.9, 0.98))#, weight_decay=0.001)
461 | 
462 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 4e-5, epochs = epochs, steps_per_epoch = len(train_dataloader), pct_start = 0.1)
463 | 
464 | 
465 | if config["continue_training"]:
466 | 
467 | model, optimizer, scheduler, start_epoch = load_model(model, optimizer, scheduler, config["experiment_name"]+".pt")
468 | print("Checkpoint loaded; continuing training")
469 | 
470 | else:
471 | start_epoch = -1
472 | model.load_state_dict(torch.load('./pretrained/swin_small_patch4_window7_224_22k.pth')['model'], strict=False)
473 | model = model.to(device)
474 | print("Model loaded")
475 | 
476 | 
477 | 
478 | freeze_encoder_layers(model, conditioned_blocks = config["conditioned_blocks"], frozen_encoder = config["frozen_encoder"])
479 | model = model.to(device)
480 | 
481 | model = torch.nn.DataParallel(model, device_ids=[0,1,2,3])
482 | print("Model on cuda:", next(model.parameters()).is_cuda)
483 | 
484 | criterion = []  # per-task losses, indexed by task id: 0 -> segmentation, 1 -> depth
485 | segmentation_criterion = torch.nn.CrossEntropyLoss(ignore_index = 0)
486 | criterion.append(segmentation_criterion)
487 | 
488 | depth_criterion = berHuLoss()
489 | criterion.append(depth_criterion)
490 | 
491 | print("Training", config["experiment_name"], "...")
492 | 
493 | 
494 | train(model, train_dataloader, test_dataloader, mid_test_dataloader, optimizer, scheduler, criterion, epochs, config["experiment_name"], config["seg_weight"], config["depth_weight"], start_epoch, device, tb_writer)
495 | 
496 | 
497 | if __name__ == '__main__':
498 | main()
499 | 
-------------------------------------------------------------------------------- /script/model/swin_transformer_parallel.py: 
-------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.checkpoint as checkpoint 11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 12 | import math 13 | from model.conditional_modules import TAA, ConditionalBottleNeck, TaskScaledNorm 14 | from einops import rearrange 15 | 16 | import re 17 | import logging 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | try: 22 | import os 23 | import sys 24 | 25 | kernel_path = os.path.abspath(os.path.join('..')) 26 | sys.path.append(kernel_path) 27 | from kernels.window_process.window_process import WindowProcess, WindowProcessReverse 28 | 29 | except: 30 | WindowProcess = None 31 | WindowProcessReverse = None 32 | print("[Warning] Fused window process have not been installed. Please refer to get_started.md for installation.") 33 | 34 | 35 | class Mlp(nn.Module): 36 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 37 | super().__init__() 38 | out_features = out_features or in_features 39 | hidden_features = hidden_features or in_features 40 | self.fc1 = nn.Linear(in_features, hidden_features) 41 | self.act = act_layer() 42 | self.fc2 = nn.Linear(hidden_features, out_features) 43 | self.drop = nn.Dropout(drop) 44 | 45 | def forward(self, x): 46 | x = self.fc1(x) 47 | x = self.act(x) 48 | x = self.drop(x) 49 | x = self.fc2(x) 50 | x = self.drop(x) 51 | return x 52 | 53 | 54 | def window_partition(x, window_size): 55 | """ 56 | Args: 57 | x: (B, H, W, C) 58 | window_size (int): window size 59 | 60 | Returns: 61 | windows: (num_windows*B, window_size, window_size, C) 62 | """ 63 | B, H, W, C = x.shape 64 | x = x.view(B, H // window_size, window_size, 65 | W // window_size, window_size, C) 66 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous( 67 | ).view(-1, window_size, window_size, C) 68 | return windows 69 | 70 | 71 | def window_reverse(windows, window_size, H, W): 72 | """ 73 | Args: 74 | windows: (num_windows*B, window_size, window_size, C) 75 | window_size (int): Window size 76 | H (int): Height of image 77 | W (int): Width of image 78 | 79 | Returns: 80 | x: (B, H, W, C) 81 | """ 82 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 83 | x = windows.view(B, H // window_size, W // window_size, 84 | window_size, window_size, -1) 85 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 86 | return x 87 | 88 | 89 | class WindowAttention(nn.Module): 90 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 91 | It supports both of shifted and non-shifted window. 92 | 93 | Args: 94 | dim (int): Number of input channels. 95 | window_size (tuple[int]): The height and width of the window. 96 | num_heads (int): Number of attention heads. 97 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 98 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 99 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 100 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 101 | task_configs (dict): Configuration for the tasks 102 | """ 103 | 104 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., task_configs=None): 105 | 106 | super().__init__() 107 | self.dim = dim 108 | self.window_size = window_size # Wh, Ww 109 | self.num_heads = num_heads 110 | head_dim = dim // num_heads 111 | self.scale = qk_scale or head_dim ** -0.5 112 | 113 | # define a parameter table of relative position bias 114 | self.relative_position_bias_table = nn.Parameter( 115 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 116 | 117 | # get pair-wise relative position index for each token inside the window 118 | coords_h = torch.arange(self.window_size[0]) 119 | coords_w = torch.arange(self.window_size[1]) 120 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 121 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 122 | relative_coords = coords_flatten[:, :, None] - \ 123 | coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 124 | relative_coords = relative_coords.permute( 125 | 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 126 | relative_coords[:, :, 0] += self.window_size[0] - \ 127 | 1 # shift to start from 0 128 | relative_coords[:, :, 1] += self.window_size[1] - 1 129 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 130 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 131 | self.register_buffer("relative_position_index", 132 | relative_position_index) 133 | 134 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 135 | self.attn_drop = nn.Dropout(attn_drop) 136 | self.proj = nn.Linear(dim, dim) 137 | self.proj_drop = nn.Dropout(proj_drop) 138 | 139 | trunc_normal_(self.relative_position_bias_table, std=.02) 140 | self.softmax = nn.Softmax(dim=-1) 141 | 142 | self.task_configs = task_configs 143 | if task_configs is not None: 144 | 145 | self.max_seq_length = task_configs["max_seq_length"] 146 | self.hidden_size = task_configs["hidden_size"] 147 | 148 | self.num_blocks = self.hidden_size // self.max_seq_length 149 | self.taa_attn = TAA( 150 | self.hidden_size, math.ceil( 151 | self.max_seq_length / self.num_blocks), self.num_blocks 152 | ) 153 | 154 | self.random_weight_matrix = nn.Parameter( 155 | torch.zeros( 156 | [self.max_seq_length, math.ceil( 157 | self.max_seq_length / self.num_blocks)] 158 | ), 159 | requires_grad=True, 160 | ) 161 | 162 | def forward(self, x, task_embedding=None, mask=None): 163 | """ 164 | Args: 165 | x: input features with shape of (num_windows*B, N, C) 166 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 167 | """ 168 | B_, N, C = x.shape 169 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // 170 | self.num_heads).permute(2, 0, 3, 1, 4) 171 | q, k, v = qkv[0], qkv[1], qkv[2] 172 | 173 | q = q * self.scale 174 | attn = (q @ k.transpose(-2, -1)) 175 | 176 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 177 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 178 | relative_position_bias = relative_position_bias.permute( 179 | 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 180 | attn = attn + relative_position_bias.unsqueeze(0) 181 | 182 | 183 | 184 | if self.task_configs is not None: 185 | 186 | attn2 = self.taa_attn( 187 | x_cond=task_embedding, 188 | x_to_film=self.random_weight_matrix, 189 | ) 190 | 191 | 192 | 193 | attn = attn.view(len(task_embedding), 
-1, *(attn.shape[1:])) 194 | 195 | 196 | 197 | attn = attn + attn2.unsqueeze(1).unsqueeze(1) 198 | 199 | attn = attn.view(-1, *(attn.shape[2:])) 200 | 201 | 202 | 203 | if mask is not None: 204 | nW = mask.shape[0] 205 | attn = attn.view(B_ // nW, nW, self.num_heads, N, 206 | N) + mask.unsqueeze(1).unsqueeze(0) 207 | attn = attn.view(-1, self.num_heads, N, N) 208 | attn = self.softmax(attn) 209 | else: 210 | attn = self.softmax(attn) 211 | 212 | attn = self.attn_drop(attn) 213 | 214 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 215 | x = self.proj(x) 216 | x = self.proj_drop(x) 217 | return x 218 | 219 | def extra_repr(self) -> str: 220 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 221 | 222 | def flops(self, N): 223 | # calculate flops for 1 window with token length of N 224 | flops = 0 225 | # qkv = self.qkv(x) 226 | flops += N * self.dim * 3 * self.dim 227 | # attn = (q @ k.transpose(-2, -1)) 228 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 229 | # x = (attn @ v) 230 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 231 | # x = self.proj(x) 232 | flops += N * self.dim * self.dim 233 | return flops 234 | 235 | 236 | class SwinTransformerBlock(nn.Module): 237 | r""" Swin Transformer Block. 238 | 239 | Args: 240 | dim (int): Number of input channels. 241 | input_resolution (tuple[int]): Input resulotion. 242 | num_heads (int): Number of attention heads. 243 | window_size (int): Window size. 244 | shift_size (int): Shift size for SW-MSA. 245 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 246 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 247 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 248 | drop (float, optional): Dropout rate. Default: 0.0 249 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 250 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 251 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 252 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 253 | fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. 
Default: False 254 | task_configs (dict): Configuration for the tasks 255 | use_tsn_layer (bool, optional): Whether to use Task Scaled Normalization or regular layer normalization 256 | """ 257 | 258 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 259 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 260 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, 261 | fused_window_process=False, task_configs=None, hidden_size=343): 262 | super().__init__() 263 | self.dim = dim 264 | self.input_resolution = input_resolution 265 | self.num_heads = num_heads 266 | self.window_size = window_size 267 | self.shift_size = shift_size 268 | self.mlp_ratio = mlp_ratio 269 | self.task_configs = task_configs 270 | if min(self.input_resolution) <= self.window_size: 271 | # if window size is larger than input resolution, we don't partition windows 272 | self.shift_size = 0 273 | self.window_size = min(self.input_resolution) 274 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 275 | 276 | 277 | self.norm1 = norm_layer(dim) 278 | 279 | 280 | self.attn = WindowAttention( 281 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 282 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, task_configs=task_configs) 283 | 284 | self.drop_path = DropPath( 285 | drop_path) if drop_path > 0. else nn.Identity() 286 | 287 | self.norm2 = norm_layer(dim) 288 | 289 | 290 | mlp_hidden_dim = int(dim * mlp_ratio) 291 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 292 | act_layer=act_layer, drop=drop) 293 | 294 | if self.shift_size > 0: 295 | # calculate attention mask for SW-MSA 296 | H, W = self.input_resolution 297 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 298 | h_slices = (slice(0, -self.window_size), 299 | slice(-self.window_size, -self.shift_size), 300 | slice(-self.shift_size, None)) 301 | w_slices = (slice(0, -self.window_size), 302 | slice(-self.window_size, -self.shift_size), 303 | slice(-self.shift_size, None)) 304 | cnt = 0 305 | for h in h_slices: 306 | for w in w_slices: 307 | img_mask[:, h, w, :] = cnt 308 | cnt += 1 309 | 310 | # nW, window_size, window_size, 1 311 | mask_windows = window_partition(img_mask, self.window_size) 312 | mask_windows = mask_windows.view(-1, 313 | self.window_size * self.window_size) 314 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 315 | attn_mask = attn_mask.masked_fill( 316 | attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 317 | else: 318 | attn_mask = None 319 | 320 | self.register_buffer("attn_mask", attn_mask) 321 | self.fused_window_process = fused_window_process 322 | 323 | def forward(self, x, task_embedding=None, task_id = None): 324 | H, W = self.input_resolution 325 | B, L, C = x.shape 326 | assert L == H * W, "input feature has wrong size" 327 | 328 | skipconnect = x 329 | 330 | x = self.norm1(x) 331 | 332 | x = x.view(B, H, W, C) 333 | 334 | # cyclic shift 335 | if self.shift_size > 0: 336 | if not self.fused_window_process: 337 | shifted_x = torch.roll( 338 | x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 339 | # partition windows 340 | # nW*B, window_size, window_size, C 341 | x_windows = window_partition(shifted_x, self.window_size) 342 | else: 343 | x_windows = WindowProcess.apply( 344 | x, B, H, W, C, -self.shift_size, self.window_size) 345 | else: 346 | shifted_x = x 347 | # partition windows 348 | # nW*B, window_size, window_size, C 349 | x_windows 
= window_partition(shifted_x, self.window_size) 350 | 351 | # nW*B, window_size*window_size, C 352 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 353 | 354 | # W-MSA/SW-MSA 355 | # nW*B, window_size*window_size, C 356 | attn_windows = self.attn( 357 | x_windows, task_embedding=task_embedding, mask=self.attn_mask) 358 | 359 | # merge windows 360 | attn_windows = attn_windows.view(-1, 361 | self.window_size, self.window_size, C) 362 | 363 | 364 | # reverse cyclic shift 365 | if self.shift_size > 0: 366 | if not self.fused_window_process: 367 | shifted_x = window_reverse( 368 | attn_windows, self.window_size, H, W) # B H' W' C 369 | x = torch.roll(shifted_x, shifts=( 370 | self.shift_size, self.shift_size), dims=(1, 2)) 371 | else: 372 | x = WindowProcessReverse.apply( 373 | attn_windows, B, H, W, C, self.shift_size, self.window_size) 374 | else: 375 | shifted_x = window_reverse( 376 | attn_windows, self.window_size, H, W) # B H' W' C 377 | x = shifted_x 378 | 379 | x = x.view(B, H * W, C) 380 | x = skipconnect + self.drop_path(x) 381 | 382 | ''' Feed Forward Network''' 383 | x = x + self.drop_path(self.mlp(self.norm2(x))) 384 | 385 | 386 | return x 387 | 388 | def extra_repr(self) -> str: 389 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 390 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 391 | 392 | def flops(self): 393 | flops = 0 394 | H, W = self.input_resolution 395 | # norm1 396 | flops += self.dim * H * W 397 | # W-MSA/SW-MSA 398 | nW = H * W / self.window_size / self.window_size 399 | flops += nW * self.attn.flops(self.window_size * self.window_size) 400 | # mlp 401 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 402 | # norm2 403 | flops += self.dim * H * W 404 | return flops 405 | 406 | 407 | class PatchMerging(nn.Module): 408 | r""" Patch Merging Layer. 409 | 410 | Args: 411 | input_resolution (tuple[int]): Resolution of input feature. 412 | dim (int): Number of input channels. 413 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 414 | """ 415 | 416 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 417 | super().__init__() 418 | self.input_resolution = input_resolution 419 | self.dim = dim 420 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 421 | self.norm = norm_layer(4 * dim) 422 | 423 | def forward(self, x): 424 | """ 425 | x: B, H*W, C 426 | """ 427 | H, W = self.input_resolution 428 | B, L, C = x.shape 429 | assert L == H * W, "input feature has wrong size" 430 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 431 | 432 | x = x.view(B, H, W, C) 433 | 434 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 435 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 436 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 437 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 438 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 439 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 440 | 441 | x = self.norm(x) 442 | x = self.reduction(x) 443 | 444 | return x 445 | 446 | def extra_repr(self) -> str: 447 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 448 | 449 | def flops(self): 450 | H, W = self.input_resolution 451 | flops = H * W * self.dim 452 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 453 | return flops 454 | 455 | 456 | class BasicLayer(nn.Module): 457 | """ A basic Swin Transformer layer for one stage. 458 | 459 | Args: 460 | dim (int): Number of input channels. 
461 | input_resolution (tuple[int]): Input resolution. 462 | depth (int): Number of blocks. 463 | num_heads (int): Number of attention heads. 464 | window_size (int): Local window size. 465 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 466 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 467 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 468 | drop (float, optional): Dropout rate. Default: 0.0 469 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 470 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 471 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 472 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 473 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 474 | fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False 475 | task_configs (dict): Configuration for the tasks 476 | conditioned_blocks (list): List of transformer blocks to adapt 477 | adapter (boolean): Whether to use adapters or not 478 | use_tsn_layer (boolean): Whether to use regular or task scaled normalization 479 | """ 480 | 481 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 482 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 483 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, 484 | fused_window_process=False, task_configs=None, conditioned_blocks=[0], adapter=False, hidden_size=343): 485 | 486 | super().__init__() 487 | self.dim = dim 488 | self.input_resolution = input_resolution 489 | self.depth = depth 490 | self.use_checkpoint = use_checkpoint 491 | self.task_configs = task_configs 492 | self.adapter = adapter 493 | 494 | # build blocks 495 | self.blocks = nn.ModuleList([ 496 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 497 | num_heads=num_heads, window_size=window_size, 498 | shift_size=0 if ( 499 | i % 2 == 0) else window_size // 2, 500 | mlp_ratio=mlp_ratio, 501 | qkv_bias=qkv_bias, qk_scale=qk_scale, 502 | drop=drop, attn_drop=attn_drop, 503 | drop_path=drop_path[i] if isinstance( 504 | drop_path, list) else drop_path, 505 | norm_layer=norm_layer, 506 | fused_window_process=fused_window_process, 507 | task_configs=task_configs if i in conditioned_blocks else None, 508 | hidden_size = hidden_size) 509 | for i in range(depth)]) 510 | 511 | if self.adapter: 512 | self.adapter_layer = nn.ModuleList([ 513 | ConditionalBottleNeck(task_configs["hidden_size"], self.dim) 514 | for i in range(depth)]) 515 | else: 516 | self.adapter_layer = [None for i in range(depth)] 517 | 518 | # patch merging layer 519 | if downsample is not None: 520 | self.downsample = downsample( 521 | input_resolution, dim=dim, norm_layer=norm_layer) 522 | 523 | self.downsample_bottleneck = downsample( 524 | input_resolution, dim=dim, norm_layer=norm_layer) 525 | else: 526 | self.downsample = None 527 | self.downsample_bottleneck = None 528 | 529 | def forward(self, x, hidden_film=None, task_embedding=None, task_id = None): 530 | 531 | if hidden_film is None: 532 | hidden_film = torch.zeros_like(x) 533 | 534 | for i, (blk, adapter_module) in enumerate(zip(self.blocks, self.adapter_layer)): 535 | if self.use_checkpoint: 536 | x = checkpoint.checkpoint(blk, x, 
task_embedding, task_id) 537 | else: 538 | x = blk(x, task_embedding=task_embedding, task_id = task_id) 539 | 540 | 541 | if self.adapter: 542 | hidden_film = adapter_module( 543 | x_cond=task_embedding, hidden_states=x + hidden_film 544 | ) 545 | 546 | else: 547 | hidden_film = None 548 | 549 | 550 | 551 | 552 | if self.downsample is not None: 553 | x = self.downsample(x) 554 | if self.adapter: 555 | hidden_film = self.downsample_bottleneck(hidden_film) 556 | return x, hidden_film 557 | 558 | def extra_repr(self) -> str: 559 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 560 | 561 | def flops(self): 562 | flops = 0 563 | for blk in self.blocks: 564 | flops += blk.flops() 565 | if self.downsample is not None: 566 | flops += self.downsample.flops() 567 | return flops 568 | 569 | 570 | class BasicLayer_up(nn.Module): 571 | """ A basic Swin Transformer layer for one stage. 572 | 573 | Args: 574 | dim (int): Number of input channels. 575 | input_resolution (tuple[int]): Input resolution. 576 | depth (int): Number of blocks. 577 | num_heads (int): Number of attention heads. 578 | window_size (int): Local window size. 579 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 580 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 581 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 582 | drop (float, optional): Dropout rate. Default: 0.0 583 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 584 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 585 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 586 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 587 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 588 | use_tsn_layer (boolean): Whether to use regular or task scaled normalization 589 | """ 590 | 591 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 592 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 593 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): 594 | 595 | super().__init__() 596 | self.dim = dim 597 | self.input_resolution = input_resolution 598 | self.depth = depth 599 | self.use_checkpoint = use_checkpoint 600 | 601 | # build blocks 602 | self.blocks = nn.ModuleList([ 603 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 604 | num_heads=num_heads, window_size=window_size, 605 | shift_size=0 if ( 606 | i % 2 == 0) else window_size // 2, 607 | mlp_ratio=mlp_ratio, 608 | qkv_bias=qkv_bias, qk_scale=qk_scale, 609 | drop=drop, attn_drop=attn_drop, 610 | drop_path=drop_path[i] if isinstance( 611 | drop_path, list) else drop_path, 612 | norm_layer=norm_layer) 613 | for i in range(depth)]) 614 | 615 | # patch merging layer 616 | if upsample is not None: 617 | self.upsample = PatchExpand( 618 | input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) 619 | else: 620 | self.upsample = None 621 | 622 | def forward(self, x): 623 | for blk in self.blocks: 624 | if self.use_checkpoint: 625 | x = checkpoint.checkpoint(blk, x) 626 | else: 627 | x = blk(x) 628 | if self.upsample is not None: 629 | x = self.upsample(x) 630 | return x 631 | 632 | 633 | class PatchEmbed(nn.Module): 634 | r""" Image to Patch Embedding 635 | 636 | Args: 637 | img_size (int): Image size. Default: 224. 
638 | patch_size (int): Patch token size. Default: 4.
639 | in_chans (int): Number of input image channels. Default: 3.
640 | embed_dim (int): Number of linear projection output channels. Default: 96.
641 | norm_layer (nn.Module, optional): Normalization layer. Default: None
642 | """
643 | 
644 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
645 | super().__init__()
646 | img_size = to_2tuple(img_size)
647 | patch_size = to_2tuple(patch_size)
648 | patches_resolution = [img_size[0] //
649 | patch_size[0], img_size[1] // patch_size[1]]
650 | self.img_size = img_size
651 | self.patch_size = patch_size
652 | self.patches_resolution = patches_resolution
653 | self.num_patches = patches_resolution[0] * patches_resolution[1]
654 | 
655 | self.in_chans = in_chans
656 | self.embed_dim = embed_dim
657 | 
658 | self.proj = nn.Conv2d(in_chans, embed_dim,
659 | kernel_size=patch_size, stride=patch_size)
660 | if norm_layer is not None:
661 | self.norm = norm_layer(embed_dim)
662 | else:
663 | self.norm = None
664 | 
665 | def forward(self, x):
666 | B, C, H, W = x.shape
667 | # FIXME look at relaxing size constraints
668 | assert H == self.img_size[0] and W == self.img_size[1], \
669 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
670 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
671 | if self.norm is not None:
672 | x = self.norm(x)
673 | return x
674 | 
675 | def flops(self):
676 | Ho, Wo = self.patches_resolution
677 | flops = Ho * Wo * self.embed_dim * self.in_chans * \
678 | (self.patch_size[0] * self.patch_size[1])
679 | if self.norm is not None:
680 | flops += Ho * Wo * self.embed_dim
681 | return flops
682 | 
683 | 
684 | class PatchExpand(nn.Module):
685 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
686 | super().__init__()
687 | self.input_resolution = input_resolution
688 | self.dim = dim
689 | self.expand = nn.Linear(
690 | dim, 2*dim, bias=False) if dim_scale == 2 else nn.Identity()
691 | 
692 | self.norm = norm_layer(dim*2)
693 | 
694 | def forward(self, x):
695 | """
696 | x: B, H*W, C
697 | """
698 | H, W = self.input_resolution
699 | x = self.expand(x)
700 | B, L, C = x.shape
701 | assert L == H * W, "Input feature has wrong size"
702 | 
703 | x = self.norm(x)  # normalize before rearranging into the 2x-upsampled spatial grid
704 | 
705 | x = x.view(B, H, W, C)
706 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
707 | x = x.view(B, -1, C//4)
708 | return x
709 | 
710 | 
711 | class FinalPatchExpand_X4(nn.Module):
712 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
713 | super().__init__()
714 | self.input_resolution = input_resolution
715 | self.dim = dim
716 | self.dim_scale = dim_scale
717 | self.expand = nn.Linear(dim, 16*dim, bias=False)  # second argument is 16*dim: 4x4 spatial expansion of dim channels
718 | self.output_dim = dim
719 | self.norm = norm_layer(16*dim)#norm_layer(self.output_dim)
720 | 
721 | def forward(self, x):
722 | """
723 | x: B, H*W, C
724 | """
725 | H, W = self.input_resolution
726 | x = self.expand(x)
727 | B, L, C = x.shape
728 | assert L == H * W, "Input feature has wrong size"
729 | 
730 | x = self.norm(x)
731 | 
732 | x = x.view(B, H, W, C)
733 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
734 | x = x.view(B, -1, self.output_dim)
735 | 
736 | #x = self.norm(x)
737 | return x
738 | 
739 | class SwinTransformer(nn.Module):
740 | r""" Swin Transformer
741 | A PyTorch impl of: `Swin Transformer: 
Hierarchical Vision Transformer using Shifted Windows` - 742 | https://arxiv.org/pdf/2103.14030 743 | 744 | Args: 745 | img_size (int | tuple(int)): Input image size. Default 224 746 | patch_size (int | tuple(int)): Patch size. Default: 4 747 | in_chans (int): Number of input image channels. Default: 3 748 | num_classes (int): Number of classes for classification head. Default: 1000 749 | embed_dim (int): Patch embedding dimension. Default: 96 750 | depths (tuple(int)): Depth of each Swin Transformer layer. 751 | num_heads (tuple(int)): Number of attention heads in different layers. 752 | window_size (int): Window size. Default: 7 753 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 754 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 755 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 756 | drop_rate (float): Dropout rate. Default: 0 757 | attn_drop_rate (float): Attention dropout rate. Default: 0 758 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 759 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 760 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 761 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 762 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 763 | fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False 764 | tasks (list): List of tasks 765 | final_upsample (strin): Setting to expand the last layer 766 | task_classes (list): List of number of prediction classes for each task 767 | conditioned_blocks (list): List of transformer blocks to adapt 768 | adapter (boolean): Whether to use adapters or not 769 | use_tsn_layer (boolean): Whether to use regular or task scaled normalization 770 | """ 771 | 772 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 773 | embed_dim=96, depths=[2, 2, 6, 2], depths_decoder=[2,2,2,2], num_heads=[3, 6, 12, 24], 774 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 775 | drop_rate=0.,drop_rate_decoder=0., attn_drop_rate=0., drop_path_rate=0.1, 776 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 777 | use_checkpoint=False, fused_window_process=False, 778 | hidden_size=None, tasks=["segmentation"], final_upsample="expand_first", task_classes = [100], 779 | conditioned_blocks = [[],[],[12],[]], adapter = False, 780 | **kwargs): 781 | super().__init__() 782 | 783 | self.num_classes = num_classes 784 | self.num_layers = len(depths) 785 | self.embed_dim = embed_dim 786 | self.ape = ape 787 | self.patch_norm = patch_norm 788 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 789 | self.mlp_ratio = mlp_ratio 790 | self.final_upsample = final_upsample 791 | self.task_classes = task_classes 792 | self.adapter = adapter 793 | 794 | assert len(task_classes) == len(tasks), "number of tasks must match the number of classes" 795 | 796 | assert len(conditioned_blocks) == self.num_layers, "give conditioned block index for each layer" 797 | 798 | # split image into non-overlapping patches 799 | self.patch_embed = PatchEmbed( 800 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 801 | norm_layer=norm_layer if self.patch_norm else None) 802 | num_patches = self.patch_embed.num_patches 803 | patches_resolution = 
self.patch_embed.patches_resolution 804 | self.patches_resolution = patches_resolution 805 | 806 | # absolute position embedding 807 | if self.ape: 808 | self.absolute_pos_embed = nn.Parameter( 809 | torch.zeros(1, num_patches, embed_dim)) 810 | trunc_normal_(self.absolute_pos_embed, std=.02) 811 | 812 | self.pos_drop = nn.Dropout(p=drop_rate) 813 | 814 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 815 | sum(depths))] 816 | self.max_seq_length = window_size * window_size 817 | 818 | if hidden_size is None: 819 | self.hidden_size = self.max_seq_length * \ 820 | window_size 821 | else: 822 | self.hidden_size = hidden_size 823 | 824 | self.task_id_2_task_idx = {i: i for i, t in enumerate(tasks)} 825 | 826 | self.task_type_embeddings = nn.Embedding( 827 | len(tasks), self.hidden_size) 828 | 829 | self.task_configs = {"hidden_size": self.hidden_size, 830 | "max_seq_length": self.max_seq_length} 831 | 832 | # build layers 833 | self.layers = nn.ModuleList() 834 | for i_layer in range(self.num_layers): 835 | 836 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 837 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 838 | patches_resolution[1] // (2 ** i_layer)), 839 | depth=depths[i_layer], 840 | num_heads=num_heads[i_layer], 841 | window_size=window_size, 842 | mlp_ratio=self.mlp_ratio, 843 | qkv_bias=qkv_bias, qk_scale=qk_scale, 844 | drop=drop_rate, attn_drop=attn_drop_rate, 845 | drop_path=dpr[sum(depths[:i_layer]):sum( 846 | depths[:i_layer + 1])], 847 | norm_layer=norm_layer, 848 | downsample=PatchMerging if ( 849 | i_layer < self.num_layers - 1) else None, 850 | use_checkpoint=use_checkpoint, 851 | fused_window_process=fused_window_process, 852 | task_configs=self.task_configs, 853 | conditioned_blocks=conditioned_blocks[i_layer], 854 | adapter = adapter, 855 | hidden_size = self.hidden_size) 856 | self.layers.append(layer) 857 | 858 | self.norm = norm_layer(self.num_features) 859 | 860 | # Decoder Module 861 | self.decoder_layers_layers_up = nn.ModuleList() 862 | self.decoder_layers_concat_back_dim = nn.ModuleList() 863 | self.decoder_layers_norm_up = nn.ModuleList() 864 | self.decoder_layers_up = nn.ModuleList() 865 | self.decoder_layers_output = nn.ModuleList() 866 | 867 | for i, task in enumerate(tasks): 868 | task_modules = dict() 869 | layers_up = nn.ModuleList() 870 | concat_back_dim = nn.ModuleList() 871 | for i_layer in range(self.num_layers): 872 | concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)), 873 | int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity() 874 | if i_layer == 0: 875 | layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), 876 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer) 877 | else: 878 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), 879 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), 880 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), 881 | depth=depths_decoder[( 882 | self.num_layers-1-i_layer)], 883 | num_heads=num_heads[( 884 | self.num_layers-1-i_layer)], 885 | window_size=window_size, 886 | mlp_ratio=self.mlp_ratio, 887 | qkv_bias=qkv_bias, qk_scale=qk_scale, 888 | drop=drop_rate_decoder, attn_drop=attn_drop_rate, 889 | drop_path=dpr[sum(depths[:( 890 | self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], 891 | 
norm_layer=norm_layer, 892 | upsample=PatchExpand if ( 893 | i_layer < self.num_layers - 1) else None, 894 | use_checkpoint=use_checkpoint) 895 | layers_up.append(layer_up) 896 | concat_back_dim.append(concat_linear) 897 | 898 | self.decoder_layers_layers_up.append(layers_up) 899 | self.decoder_layers_concat_back_dim.append(concat_back_dim) 900 | self.decoder_layers_norm_up.append(norm_layer(self.embed_dim)) 901 | 902 | if self.final_upsample == "expand_first": 903 | up = FinalPatchExpand_X4(input_resolution=( 904 | img_size//patch_size, img_size//patch_size), dim_scale=4, dim=embed_dim) 905 | 906 | self.decoder_layers_up.append(up) 907 | 908 | dec_output = nn.Conv2d( 909 | in_channels=embed_dim, out_channels=self.task_classes[i], kernel_size=1, bias=False) 910 | 911 | self.decoder_layers_output.append(dec_output) 912 | 913 | self.apply(self._init_weights) 914 | 915 | def _init_weights(self, m): 916 | if isinstance(m, nn.Linear): 917 | trunc_normal_(m.weight, std=.02) 918 | if isinstance(m, nn.Linear) and m.bias is not None: 919 | nn.init.constant_(m.bias, 0) 920 | elif isinstance(m, nn.LayerNorm): 921 | nn.init.constant_(m.bias, 0) 922 | nn.init.constant_(m.weight, 1.0) 923 | 924 | @torch.jit.ignore 925 | def no_weight_decay(self): 926 | return {'absolute_pos_embed'} 927 | 928 | @torch.jit.ignore 929 | def no_weight_decay_keywords(self): 930 | return {'relative_position_bias_table'} 931 | 932 | def forward_features_old(self, x, task_embedding): 933 | x = self.patch_embed(x) 934 | if self.ape: 935 | x = x + self.absolute_pos_embed 936 | x = self.pos_drop(x) 937 | 938 | for layer in self.layers: 939 | x, hidden_film = layer(x, task_embedding=task_embedding) 940 | 941 | x = self.norm(x) 942 | 943 | 944 | return x 945 | 946 | def forward_features(self, x, task_embedding, task_id): 947 | x = self.patch_embed(x) 948 | if self.ape: 949 | x = x + self.absolute_pos_embed 950 | x = self.pos_drop(x) 951 | x_downsample = [] 952 | 953 | for i,layer in enumerate(self.layers): 954 | x_downsample.append(x) 955 | if i == 0: 956 | x, hidden = layer(x, task_embedding=task_embedding, task_id = task_id) 957 | else: 958 | x, hidden = layer(x, hidden, task_embedding=task_embedding, task_id = task_id) 959 | 960 | if self.adapter: 961 | x = hidden 962 | 963 | x = self.norm(x) 964 | 965 | return x, x_downsample 966 | 967 | #Skip connection 968 | def forward_up_features(self, x, x_downsample, layers_up, concat_back_dim, norm_up, print = False): 969 | for inx, layer_up in enumerate(layers_up): 970 | if inx == 0: 971 | x = layer_up(x) 972 | else: 973 | x = torch.cat([x, x_downsample[3-inx]], -1) 974 | x = concat_back_dim[inx](x) 975 | x = layer_up(x) 976 | x = norm_up(x) 977 | return x 978 | 979 | def up_x4(self, x, up, output): 980 | H, W = self.patches_resolution 981 | B, L, C = x.shape 982 | assert L == H*W, "Input features have wrong size" 983 | 984 | if self.final_upsample == "expand_first": 985 | x = up(x) 986 | x = x.view(B, 4*H, 4*W, -1) 987 | x = x.permute(0, 3, 1, 2) 988 | x = output(x) 989 | return x 990 | 991 | def forward_old(self, x, task_id): 992 | task_type = self._create_task_type(task_id) 993 | task_embedding = self.task_type_embeddings(task_type) 994 | x = self.forward_features(x, task_embedding) 995 | return x 996 | 997 | def forward(self, x, task_id): 998 | task_type, unique_task_ids_list = self._create_task_type(task_id) 999 | task_embedding = self.task_type_embeddings(task_type) 1000 | 1001 | x, x_downsample = self.forward_features(x, task_embedding, task_id) 1002 | 1003 | logits = 
[None]*len(self.task_classes)
1004 | # decode only the tasks present in this batch; entries for absent tasks stay None
1005 | for unique_task_id in unique_task_ids_list:
1006 | task_id_filter = task_id == unique_task_id
1007 | layers_up = self.decoder_layers_layers_up[unique_task_id]
1008 | concat_back_dim = self.decoder_layers_concat_back_dim[unique_task_id]
1009 | norm_up = self.decoder_layers_norm_up[unique_task_id]
1010 | up = self.decoder_layers_up[unique_task_id]
1011 | dec_output = self.decoder_layers_output[unique_task_id]
1012 | 
1013 | x_downsample_up = []
1014 | 
1015 | for x_it in x_downsample:
1016 | x_downsample_up.append(x_it[task_id_filter])
1017 | 
1018 | # NOTE: the two branches below are identical; every task currently goes through the same decoder-up path
1019 | if unique_task_id == 1:
1020 | x_up = self.forward_up_features(
1021 | x[task_id_filter], x_downsample_up, layers_up, concat_back_dim, norm_up)
1022 | else:
1023 | x_up = self.forward_up_features(
1024 | x[task_id_filter], x_downsample_up, layers_up, concat_back_dim, norm_up)
1025 | x_up = self.up_x4(x_up, up, dec_output)
1026 | logits[unique_task_id] = x_up
1027 | 
1028 | return tuple(logits)
1029 | 
1030 | 
1031 | def flops(self):
1032 | flops = 0
1033 | flops += self.patch_embed.flops()
1034 | for i, layer in enumerate(self.layers):
1035 | flops += layer.flops()
1036 | flops += self.num_features * \
1037 | self.patches_resolution[0] * \
1038 | self.patches_resolution[1] // (2 ** self.num_layers)
1039 | flops += self.num_features * self.num_classes
1040 | return flops
1041 | 
1042 | def _create_task_type(self, task_id):
1043 | task_type = task_id.clone()
1044 | unique_task_ids = torch.unique(task_type)
1045 | unique_task_ids_list = (
1046 | unique_task_ids.cpu().numpy()
1047 | if unique_task_ids.is_cuda
1048 | else unique_task_ids.numpy()
1049 | )
1050 | for unique_task_id in unique_task_ids_list:
1051 | task_type[task_type == unique_task_id] = self.task_id_2_task_idx[
1052 | unique_task_id
1053 | ]
1054 | return task_type, unique_task_ids_list
1055 | 
--------------------------------------------------------------------------------
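For quick orientation, here is a minimal sketch (not part of the repository) of how the task-conditioned forward pass in swin_transformer_parallel.py is exercised by the training loops above: every sample carries an integer task id (0 = segmentation with 18 classes, 1 = single-channel depth in the Taskonomy script), and the model returns one logit tensor per task, with None for tasks absent from the batch. The conditioned_blocks and adapter values below are illustrative placeholders (in the real runs they come from the experiment config), most other hyperparameters are left at their defaults for brevity, and the script/ directory is assumed to be on PYTHONPATH.

import torch
from model.swin_transformer_parallel import SwinTransformer

# two-task model; placeholders marked below, everything else at its default value
model = SwinTransformer(img_size=224,
                        window_size=7,
                        tasks=["segmentation", "depth"],
                        task_classes=[18, 1],                   # 18 semantic classes, 1 depth channel
                        conditioned_blocks=[[], [], [12], []],  # placeholder; set per experiment config
                        adapter=False)                          # placeholder; set per experiment config
model.eval()

img = torch.randn(4, 3, 224, 224)       # batch of RGB crops
task_id = torch.tensor([0, 0, 1, 1])    # first two samples: segmentation, last two: depth

with torch.no_grad():
    seg_logits, depth_logits = model(img, task_id)

print(seg_logits.shape)    # torch.Size([2, 18, 224, 224]) -- only the task_id == 0 samples
print(depth_logits.shape)  # torch.Size([2, 1, 224, 224])  -- only the task_id == 1 samples

If a batch contains only one task, the other entry of the returned tuple is None, which is exactly the case the training and validation loops guard against with their `if logit_task is None: continue` check.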
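The depth branch in train() and check_val() never regresses metric depth directly: the raw 16-bit Taskonomy depth_euclidean label d is remapped to an inverse-depth-style target 65536 / (d + 1), and the depth head's raw output is squashed into the same (1, 65536] range via sigmoid(logit) * 65535 + 1. berHuLoss is then called with mask_val = 1.0, which is exactly the encoding of the sentinel raw value 65535 (65536 / 65536 = 1.0), so the mask presumably drops invalid / maximum-depth pixels; whether berHuLoss in loss/losses.py actually compares mask_val against the encoded target in this way is an assumption. A small numeric illustration of the two mappings:

import torch

raw_depth = torch.tensor([0.0, 127.0, 4095.0, 65535.0])  # example raw 16-bit label values
target = 65536.0 / (raw_depth + 1)                        # -> 65536, 512, 16, 1

logit = torch.tensor([-2.0, 0.0, 3.0])                    # example depth-head outputs
pred = torch.sigmoid(logit) * 65535 + 1                   # squashed into (1, 65536]
print(target)
print(pred)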