├── script
│   ├── ignore.txt
│   ├── model
│   │   ├── __init__.py
│   │   ├── conditional_modules.py
│   │   └── swin_transformer_parallel.py
│   ├── data
│   │   ├── taskonomy
│   │   │   ├── __init__.py
│   │   │   ├── metadata
│   │   │   │   ├── train_val_test_debug.csv
│   │   │   │   └── train_val_test_tiny.csv
│   │   │   ├── splits.py
│   │   │   ├── task_configs.py
│   │   │   ├── transforms.py
│   │   │   └── taskonomy_dataset_s3.py
│   │   ├── nyuv2.py
│   │   └── nyuv2_same_batch.py
│   ├── requirements.txt
│   ├── loss
│   │   ├── __pycache__
│   │   │   ├── losses.cpython-38.pyc
│   │   │   └── metrics.cpython-38.pyc
│   │   ├── losses.py
│   │   └── metrics.py
│   ├── evaluate.py
│   ├── train_nyu.py
│   ├── train_nyu_single_task.py
│   └── train_taskonomy.py
├── avtar.gif
├── docs
│   ├── code_icon.png
│   ├── troa-new.png
│   ├── youtube.png
│   ├── nyu-qr-new.pdf
│   ├── nyu-qr-new.png
│   ├── TAA-finalised.png
│   ├── film-finalised.pdf
│   ├── film-finalised.png
│   ├── teaser-iccv1.png
│   ├── pdf_icon_32x32.jpeg
│   ├── uda-results-new.pdf
│   ├── uda-results-new.png
│   ├── ICCV-presentation.pdf
│   ├── bibtex_icon_36x36.png
│   ├── presentation_icon.png
│   ├── vision-adapter-new.pdf
│   ├── vision-adapter-new.png
│   ├── qualitative-taskonomy.png
│   ├── overall-vision-adapter-architecture.png
│   ├── offcanvas.css
│   └── index.html
├── README.md
├── LICENSE
└── .gitignore
/script/ignore.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/script/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/avtar.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/avtar.gif
--------------------------------------------------------------------------------
/docs/code_icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/code_icon.png
--------------------------------------------------------------------------------
/docs/troa-new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/troa-new.png
--------------------------------------------------------------------------------
/docs/youtube.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/youtube.png
--------------------------------------------------------------------------------
/docs/nyu-qr-new.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/nyu-qr-new.pdf
--------------------------------------------------------------------------------
/docs/nyu-qr-new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/nyu-qr-new.png
--------------------------------------------------------------------------------
/script/data/taskonomy/__init__.py:
--------------------------------------------------------------------------------
1 | from .taskonomy_dataset_s3 import TaskonomyDatasetS3
--------------------------------------------------------------------------------
/docs/TAA-finalised.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/TAA-finalised.png
--------------------------------------------------------------------------------
/docs/film-finalised.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/film-finalised.pdf
--------------------------------------------------------------------------------
/docs/film-finalised.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/film-finalised.png
--------------------------------------------------------------------------------
/docs/teaser-iccv1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/teaser-iccv1.png
--------------------------------------------------------------------------------
/docs/pdf_icon_32x32.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/pdf_icon_32x32.jpeg
--------------------------------------------------------------------------------
/docs/uda-results-new.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/uda-results-new.pdf
--------------------------------------------------------------------------------
/docs/uda-results-new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/uda-results-new.png
--------------------------------------------------------------------------------
/script/data/taskonomy/metadata/train_val_test_debug.csv:
--------------------------------------------------------------------------------
1 | id,train,val,test
2 | allensville,1,1,1
--------------------------------------------------------------------------------
/docs/ICCV-presentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/ICCV-presentation.pdf
--------------------------------------------------------------------------------
/docs/bibtex_icon_36x36.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/bibtex_icon_36x36.png
--------------------------------------------------------------------------------
/docs/presentation_icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/presentation_icon.png
--------------------------------------------------------------------------------
/docs/vision-adapter-new.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/vision-adapter-new.pdf
--------------------------------------------------------------------------------
/docs/vision-adapter-new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/vision-adapter-new.png
--------------------------------------------------------------------------------
/script/requirements.txt:
--------------------------------------------------------------------------------
1 | timm
2 | einops
3 | torchmetrics
4 | tensorboard
5 | transformers
6 | boto3
--------------------------------------------------------------------------------
/docs/qualitative-taskonomy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/qualitative-taskonomy.png
--------------------------------------------------------------------------------
/docs/overall-vision-adapter-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/docs/overall-vision-adapter-architecture.png
--------------------------------------------------------------------------------
/script/loss/__pycache__/losses.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/script/loss/__pycache__/losses.cpython-38.pyc
--------------------------------------------------------------------------------
/script/loss/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVRL/VTAGML/HEAD/script/loss/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Vision Transformer Adapters for Generalizable Multitask Learning
3 | Deblina Bhattacharjee, Sabine Süsstrunk, and Mathieu Salzmann
4 | [DOI](https://zenodo.org/doi/10.5281/zenodo.11067070)
5 |
6 | ICCV 2023 Paper: https://arxiv.org/abs/2308.12372
7 |
8 | https://ivrl.github.io/VTAGML/
9 | 
10 |
11 |
12 |
--------------------------------------------------------------------------------
/script/data/taskonomy/metadata/train_val_test_tiny.csv:
--------------------------------------------------------------------------------
1 | id,train,val,test
2 | hanson,1,0,0
3 | merom,1,0,0
4 | klickitat,1,0,0
5 | onaga,1,0,0
6 | leonardo,1,0,0
7 | marstons,1,0,0
8 | newfields,1,0,0
9 | pinesdale,1,0,0
10 | lakeville,1,0,0
11 | cosmos,1,0,0
12 | benevolence,1,0,0
13 | pomaria,1,0,0
14 | tolstoy,1,0,0
15 | shelbyville,1,0,0
16 | allensville,1,0,0
17 | wainscott,1,0,0
18 | beechwood,1,0,0
19 | coffeen,1,0,0
20 | stockman,1,0,0
21 | hiteman,1,0,0
22 | woodbine,1,0,0
23 | lindenwood,1,0,0
24 | forkland,1,0,0
25 | mifflinburg,1,0,0
26 | ranchester,1,0,0
27 | wiconisco,0,1,0
28 | corozal,0,1,0
29 | collierville,0,1,0
30 | markleeville,0,1,0
31 | darden,0,1,0
32 | ihlen,0,0,1
33 | muleshoe,0,0,1
34 | uvalda,0,0,1
35 | noxapater,0,0,1
36 | mcdade,0,0,1
--------------------------------------------------------------------------------
/script/data/taskonomy/splits.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 |
4 |
5 | def get_splits(split_path, forbidden_buildings=[]):
6 | with open(split_path) as csvfile:
7 | readCSV = csv.reader(csvfile, delimiter=',')
8 |
9 | train_list = []
10 | val_list = []
11 | test_list = []
12 |
13 | for row in readCSV:
14 | name, is_train, is_val, is_test = row
15 | if name in forbidden_buildings:
16 | continue
17 | if is_train == '1':
18 | train_list.append(name)
19 | if is_val == '1':
20 | val_list.append(name)
21 | if is_test == '1':
22 | test_list.append(name)
23 | return {
24 | 'train': sorted(train_list),
25 | 'val': sorted(val_list),
26 | 'test': sorted(test_list)
27 | }
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
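
A minimal usage sketch for get_splits, assuming it is run from the script/ directory so that the relative metadata path resolves; the counts in the final comment refer to the train_val_test_tiny.csv shown above.

import os

from data.taskonomy.splits import get_splits

# Path to the 'tiny' split CSV shipped under script/data/taskonomy/metadata/
csv_path = os.path.join('data', 'taskonomy', 'metadata', 'train_val_test_tiny.csv')

# Buildings passed in forbidden_buildings are skipped entirely
splits = get_splits(csv_path, forbidden_buildings=['wiconisco'])

print(len(splits['train']), len(splits['val']), len(splits['test']))
# With the tiny CSV above: 25 train, 4 val (wiconisco excluded) and 5 test buildings
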
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Images and Visual Representation Laboratory (IVRL) at EPFL
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/script/loss/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import torch.cuda.amp as amp
6 |
7 | class berHuLoss(nn.Module):
8 | def __init__(self):
9 | """
10 | https://github.com/lhoyer/improving_segmentation_with_selfsupervised_depth/
11 | """
12 | super(berHuLoss, self).__init__()
13 |
14 |
15 | def make_valid_mask(self, tens, mask_val, conf=1e-7):
16 |
17 | valid_mask = (tens > (mask_val+conf) ) | (tens < (mask_val-conf))
18 |
19 | return valid_mask
20 |
21 |
22 | def forward(self, inp, target, apply_log=False, threshold=.2, mask_val=None):
23 | if apply_log:
24 | inp, target = torch.log(1 + inp), torch.log(1 + target)
25 |
26 | if mask_val is None:
27 | valid_mask = (target > 0).detach()
28 | else:
29 | valid_mask = self.make_valid_mask(target, mask_val)
30 |
31 | absdiff = torch.abs(target - inp) * valid_mask #* mask
32 | C = threshold * torch.max(absdiff).item()
33 | loss = torch.mean(torch.where(absdiff <= C,
34 | absdiff,
35 | (absdiff * absdiff + C * C) / (2 * C)))
36 | return loss
--------------------------------------------------------------------------------
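
A small sketch of how berHuLoss might be called on a depth prediction, using random tensors; when mask_val is not given, zero-valued target pixels are treated as invalid.

import torch

from loss.losses import berHuLoss  # assuming the script/ directory is on the Python path

criterion = berHuLoss()

# Random stand-ins for a depth prediction and its ground truth: batch of 4, 1 channel, 224x224
pred = torch.rand(4, 1, 224, 224)
target = torch.rand(4, 1, 224, 224)
target[:, :, :10, :] = 0.0  # e.g. missing sensor readings; excluded via the (target > 0) mask

loss = criterion(pred, target)                      # reverse Huber; invalid pixels contribute no error
loss_log = criterion(pred, target, apply_log=True)  # optional log(1 + x) compression of both tensors
print(loss.item(), loss_log.item())
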
/docs/offcanvas.css:
--------------------------------------------------------------------------------
1 | /*
2 | * Style tweaks
3 | * --------------------------------------------------
4 | */
5 | html,
6 | body {
7 | overflow-x: hidden; /*Prevent scroll on narrow devices */
8 | padding-top: 30px;
9 | text-align: justify;
10 | }
11 | footer {
12 | padding: 10px 0;
13 | }
14 | .authors {
15 | font-size: 20px;
16 | }
17 | /*.container {
18 | max-width: 768px;
19 | }*/
20 | .container {
21 | max-width: 1000px;
22 | }
23 | p {
24 | font-size: 16px;
25 | /*padding-bottom: 20px;*/
26 | }
27 |
28 | li {
29 | font-size: 16px;
30 | }
31 |
32 | h2 {
33 | text-align: center;
34 | align: center;
35 | }
36 |
37 | .jumbotron{
38 | text-align: center;
39 | }
40 |
41 | .btn {
42 | font-size: 18px;
43 | }
44 |
45 | .btn-disabled {
46 | /*background-color: #f4f4f4;*/
47 | }
48 |
49 | .jumbotron h2 {
50 | font-size: 36px;
51 | }
52 |
53 | .section {
54 | padding-top: 30px;
55 | }
56 |
57 | .center{
58 | display: block;
59 | margin-left: auto;
60 | margin-right: auto;
61 | }
62 |
63 | .vcontainer {
64 | position: relative;
65 | width: 100%;
66 | height: 0;
67 | padding-bottom: 56.25%;
68 | }
69 | .video {
70 | position: absolute;
71 | top: 0;
72 | left: 0;
73 | width: 100%;
74 | height: 100%;
75 | }
76 |
77 | .gif {
78 | padding:10px;
79 | display: block;
80 | margin-left: auto;
81 | margin-right: auto;
82 | text-align: center;
83 | }
84 |
85 | .caption {
86 | width:75%;
87 | font-size:14px
88 | }
89 |
90 | .bibtexsection {
91 | font-family: "Courier",monospace;
92 | font-size:16px;
93 | white-space:pre;
94 | background-color: #f4f4f4;
95 | text-align:left;
96 | }
97 |
98 | .canvas-row canvas {
99 | max-width:100%;
100 | }
101 |
102 | .padding-0{
103 | padding-right:0;
104 | padding-left:0;
105 | }
106 |
107 | .vspace-top {
108 | margin-top: 30px;
109 | }
110 |
--------------------------------------------------------------------------------
/script/loss/metrics.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch.nn.functional as F
4 | import torch
5 |
6 |
7 | def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor):
8 | """
9 | Taken from:
10 | https://www.kaggle.com/iezepov/fast-iou-scoring-metric-in-pytorch-and-numpy/comments
11 | """
12 |
13 | SMOOTH = 1e-6
14 | # You can comment out this line if you are passing tensors of equal shape
15 | # But if you are passing output from UNet or something it will most probably
16 | # be with the BATCH x 1 x H x W shape
17 | outputs = outputs.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W
18 | labels = labels.squeeze(1)
19 |
20 | intersection = (outputs & labels).float().sum(
21 | (1, 2)) # Will be zero if Truth=0 or Prediction=0
22 |
23 | union = (outputs | labels).float().sum(
24 |         (1, 2))  # Will be zero if both are 0
25 |
26 |     # We smooth our division to avoid 0/0
27 | iou = (intersection + SMOOTH) / (union + SMOOTH)
28 |
29 |     # This is equal to comparing with thresholds
30 | thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10
31 |
32 | # Or thresholded.mean() if you are interested in average across the batch
33 | return thresholded.mean()
34 |
35 |
36 |
37 | def eval_depth(pred, target):
38 |
39 | """
40 | Taken from:
41 | https://github.com/wl-zhao/VPD/blob/main/depth/utils_depth/metrics.py
42 | """
43 |
44 | rmse_temp = 0
45 | d1_temp = 0
46 |
47 | for current_target, current_pred in zip(target, pred):
48 | ##assert current_gt_sparse.shape == current_pred.shape
49 |
50 | thresh = torch.max((current_target / current_pred), (current_pred / current_target))
51 |
52 | d1 = (thresh < 1.25).float().mean()#torch.sum(thresh < 1.25).float().mean()# / len(thresh)
53 | #d2 = torch.sum(thresh < 1.25 ** 2).float() / len(thresh)
54 | #d3 = torch.sum(thresh < 1.25 ** 3).float() / len(thresh)
55 |
56 | diff = current_pred - current_target
57 | diff_log = torch.log(current_pred) - torch.log(current_target)
58 |
59 | #abs_rel = torch.mean(torch.abs(diff) / target)
60 | #sq_rel = torch.mean(torch.pow(diff, 2) / target)
61 |
62 | rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
63 | rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log , 2)))
64 |
65 | #log10 = torch.mean(torch.abs(torch.log10(pred) - torch.log10(target)))
66 | #silog = torch.sqrt(torch.pow(diff_log, 2).mean() - 0.5 * torch.pow(diff_log.mean(), 2))
67 |
68 | #return {'d1': d1.item(), 'd2': d2.item(), 'd3': d3.item(), 'abs_rel': abs_rel.item(),
69 | # 'sq_rel': sq_rel.item(), 'rmse': rmse.item(), 'rmse_log': rmse_log.item(),
70 | # 'log10':log10.item(), 'silog':silog.item()}
71 |
72 | rmse_temp += rmse
73 | d1_temp += d1
74 |
75 | return {'d1': d1_temp.item()/len(pred),'rmse': rmse_temp.item()/len(pred)}
76 |
--------------------------------------------------------------------------------
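
A brief sketch of the two metrics above on dummy data; iou_pytorch expects integer or boolean tensors (it uses bitwise & and |), and eval_depth expects strictly positive depths because of the ratio and log terms.

import torch

from loss.metrics import iou_pytorch, eval_depth  # assuming script/ is the working directory

# Binary masks of shape (batch, 1, H, W); must be integer or bool for & and | to work
outputs = (torch.rand(2, 1, 64, 64) > 0.5).int()
labels = (torch.rand(2, 1, 64, 64) > 0.5).int()
print('thresholded IoU:', iou_pytorch(outputs, labels).item())

# Strictly positive depth maps; eval_depth iterates over the batch dimension
pred = torch.rand(2, 1, 64, 64) + 0.1
target = torch.rand(2, 1, 64, 64) + 0.1
print(eval_depth(pred, target))  # {'d1': ..., 'rmse': ...}
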
/script/data/taskonomy/task_configs.py:
--------------------------------------------------------------------------------
1 | ####################
2 | # Tasks
3 | ####################
4 | import torch
5 |
6 |
7 | task_parameters = {
8 | 'class_object':{
9 | 'num_classes': 1000,
10 | 'ext': 'npy',
11 | 'domain_id': 'class_object',
12 | },
13 | 'class_scene':{
14 | 'num_classes': 365,
15 | 'ext': 'npy',
16 | 'domain_id': 'class_scene',
17 | },
18 | 'depth_zbuffer':{
19 | 'num_channels': 1,
20 | 'mask_val': 1.0,
21 | 'clamp_to': (0.0, 8000.0 / (2**16 - 1)), # Same as consistency
22 | 'ext': 'png',
23 | 'domain_id': 'depth_zbuffer',
24 | },
25 | 'depth_euclidean':{
26 | 'num_channels': 1,
27 | 'clamp_to': (0.0, 8000.0 / (2**16 - 1)), # Same as consistency
28 | # 'mask_val': 1.0,
29 | 'ext': 'png',
30 | 'domain_id': 'depth_euclidean',
31 | },
32 | 'edge_texture': {
33 | 'num_channels': 1,
34 | 'clamp_to': (0.0, 0.25),
35 | 'ext': 'png',
36 | 'domain_id': 'edge_texture',
37 | },
38 | 'edge_occlusion': {
39 | 'num_channels': 1,
40 | 'ext': 'png',
41 | 'domain_id': 'edge_occlusion',
42 | },
43 | 'keypoints3d': {
44 | 'num_channels': 1,
45 | 'ext': 'png',
46 | 'domain_id': 'keypoints3d',
47 | },
48 | 'keypoints2d':{
49 | 'num_channels': 1,
50 | 'ext': 'png',
51 | 'domain_id': 'keypoints2d',
52 | },
53 | 'principal_curvature':{
54 | 'num_channels': 3,
55 | 'mask_val': 0.0,
56 | 'ext': 'png',
57 | 'domain_id': 'principal_curvature',
58 | },
59 | 'reshading':{
60 | 'num_channels': 1,
61 | 'ext': 'png',
62 | 'domain_id': 'reshading',
63 | },
64 | 'normal':{
65 | 'num_channels': 3,
66 | 'mask_val': 0.502,
67 | 'ext': 'png',
68 | 'domain_id': 'normal',
69 | },
70 | 'mask_valid':{
71 | 'num_channels': 1,
72 | 'mask_val': 0.0,
73 | 'ext': 'png',
74 | 'domain_id': 'depth_zbuffer',
75 | },
76 | 'rgb':{
77 | 'num_channels': 3,
78 | 'ext': 'png',
79 | 'domain_id': 'rgb',
80 | },
81 | 'segment_semantic': {
82 | 'num_channels': 18,
83 | 'ext': 'png',
84 | 'domain_id': 'segmentsemantic',
85 | },
86 | 'segment_unsup2d':{
87 | 'num_channels': 64,
88 | 'ext': 'png',
89 | 'domain_id': 'segment_unsup2d',
90 | },
91 | 'segment_unsup25d':{
92 | 'num_channels': 64,
93 | 'ext': 'png',
94 | 'domain_id': 'segment_unsup25d',
95 | },
96 | }
97 |
98 |
99 | PIX_TO_PIX_TASKS = ['colorization', 'edge_texture', 'edge_occlusion', 'keypoints3d', 'keypoints2d', 'reshading', 'depth_zbuffer', 'depth_euclidean', 'curvature', 'autoencoding', 'denoising', 'normal', 'inpainting', 'segment_unsup2d', 'segment_unsup25d', 'segment_semantic', ]
100 | FEED_FORWARD_TASKS = ['class_object', 'class_scene', 'room_layout', 'vanishing_point']
101 | SINGLE_IMAGE_TASKS = PIX_TO_PIX_TASKS + FEED_FORWARD_TASKS
102 | SIAMESE_TASKS = ['fix_pose', 'jigsaw', 'ego_motion', 'point_match', 'non_fixated_pose']
103 |
104 |
--------------------------------------------------------------------------------
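
A short sketch of how the task_parameters dictionary above is typically consumed (the same lookup pattern appears in taskonomy_dataset_s3.py); note that some names in PIX_TO_PIX_TASKS, such as 'colorization', have no entry in task_parameters.

from data.taskonomy.task_configs import task_parameters, PIX_TO_PIX_TASKS

# Look up the storage format and channel count of a dense task
cfg = task_parameters['depth_zbuffer']
print(cfg['ext'], cfg['num_channels'], cfg.get('clamp_to'))  # png 1 (0.0, ~0.122)

# Not every task defines every key, so use .get() for optional fields such as mask_val
for task in PIX_TO_PIX_TASKS:
    params = task_parameters.get(task)
    if params is None:
        continue  # e.g. 'colorization', 'autoencoding' and 'denoising' are not configured here
    print(task, params['ext'], params.get('mask_val'))
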
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/script/data/taskonomy/transforms.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import torch
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from typing import Optional
9 |
10 | from .task_configs import task_parameters
11 |
12 |
13 | MAKE_RESCALE_0_1_NEG1_POS1 = lambda n_chan: transforms.Normalize([0.5]*n_chan, [0.5]*n_chan)
14 | RESCALE_0_1_NEG1_POS1 = transforms.Normalize([0.5], [0.5]) # This needs to be different depending on num out chans
15 | MAKE_RESCALE_0_MAX_NEG1_POS1 = lambda maxx: transforms.Normalize([maxx / 2.], [maxx * 1.0])
16 | RESCALE_0_255_NEG1_POS1 = transforms.Normalize([127.5,127.5,127.5], [255, 255, 255])
17 | MAKE_RESCALE_0_MAX_0_POS1 = lambda maxx: transforms.Normalize([0.0], [maxx * 1.0])
18 |
19 | # For semantic segmentation
20 | transform_dense_labels = lambda img: torch.Tensor(np.array(img)).long() # avoids normalizing
21 |
22 | # Transforms to a 3-channel tensor and rescales [0, 255] -> [0, 1]
23 | transform_8bit = transforms.Compose([
24 | transforms.ToTensor(),
25 | ])
26 |
27 | # Transforms to an n-channel tensor and rescales [0, 255] -> [0, 1]. Keeps only the first n channels
28 | def transform_8bit_n_channel(n_channel=1, crop_channels=True):
29 | if crop_channels:
30 | crop_channels_fn = lambda x: x[:n_channel] if x.shape[0] > n_channel else x
31 | else:
32 | crop_channels_fn = lambda x: x
33 | return transforms.Compose([
34 | transforms.ToTensor(),
35 | crop_channels_fn,
36 | ])
37 |
38 | # Transforms to a 1-channel tensor and rescales [0, 2^16 - 1] -> [0, 1].
39 | def transform_16bit_single_channel(im):
40 | im = transforms.ToTensor()(np.array(im))
41 | im = im.float() / (2 ** 16 - 1.0)
42 | return im
43 |
44 | def make_valid_mask(mask_float, max_pool_size=4):
45 | '''
46 | Creates a mask indicating the valid parts of the image(s).
47 |     Enlarges the masked area using a max pooling operation.
48 |
49 | Args:
50 | mask_float: A (b x c x h x w) mask as loaded from the Taskonomy loader.
51 | max_pool_size: Parameter to choose how much to enlarge masked area.
52 | '''
53 | squeeze = False
54 | if len(mask_float.shape) == 3:
55 | mask_float = mask_float.unsqueeze(0)
56 | squeeze = True
57 | _, _, h, w = mask_float.shape
58 | mask_float = 1 - mask_float
59 | mask_float = F.max_pool2d(mask_float, kernel_size=max_pool_size)
60 | mask_float = F.interpolate(mask_float, (h, w), mode='nearest')
61 | mask_valid = mask_float == 0
62 | mask_valid = mask_valid[0] if squeeze else mask_valid
63 | return mask_valid
64 |
65 |
66 | def task_transform(file, task: str, image_size: Optional[int] = None):
67 | transform = None
68 |
69 | if task in ['rgb', 'normal']:
70 | transform = transform_8bit
71 | elif task in ['mask_valid']:
72 | transform = transforms.Compose([
73 | transforms.ToTensor(),
74 | make_valid_mask
75 | ])
76 | elif task in ['keypoints2d', 'keypoints3d', 'depth_euclidean', 'depth_zbuffer', 'edge_texture', 'edge_occlusion']:
77 | #transform = transform_16bit_single_channel
78 | transform = transforms.Compose([
79 | transforms.ToTensor()
80 | ])
81 | elif task in ['principal_curvature', 'curvature']:
82 | transform = transform_8bit_n_channel(2)
83 | elif task in ['reshading']:
84 | transform = transform_8bit_n_channel(1)
85 | elif task in ['segment_semantic', 'segment_instance', 'segment_panoptic', 'fragments', 'segment_unsup2d', 'segment_unsup25d']: # this is stored as 1 channel image (H,W) where each pixel value is a different class
86 | transform = transform_dense_labels
87 | elif task in ['class_object', 'class_scene']:
88 | transform = torch.Tensor
89 | image_size = None
90 | else:
91 | transform = lambda x: x
92 |
93 | """if 'clamp_to' in task_parameters[task]:
94 | minn, maxx = task_parameters[task]['clamp_to']
95 | if minn > 0:
96 | raise NotImplementedError("Rescaling (min1, max1) -> (min2, max2) not implemented for min1, min2 != 0 (task {})".format(task))
97 | transform = transforms.Compose([
98 | transform,
99 | MAKE_RESCALE_0_MAX_0_POS1(maxx)
100 | ])"""
101 |
102 |
103 | if image_size is not None:
104 | if task == 'fragments':
105 | resize_frag = lambda frag: F.interpolate(frag.permute(2,0,1).unsqueeze(0).float(), image_size, mode='nearest').long()[0].permute(1,2,0)
106 | transform = transforms.Compose([
107 | transform,
108 | resize_frag
109 | ])
110 | else:
111 | resize_method = Image.BILINEAR if task in ['rgb'] else Image.NEAREST
112 | transform = transforms.Compose([
113 | transforms.Resize(image_size, resize_method),
114 | transform
115 | ])
116 |
117 |
118 | if transform is not None:
119 | file = transform(file)
120 |
121 | return file
122 |
--------------------------------------------------------------------------------
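
A small sketch of task_transform on synthetic PIL images standing in for Taskonomy files (an 8-bit RGB image and a single-channel label map), assuming the signature with image_size as an optional keyword as above.

import numpy as np
from PIL import Image

from data.taskonomy.transforms import task_transform

# Synthetic stand-ins for Taskonomy files
rgb = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
seg = Image.fromarray(np.random.randint(0, 18, (256, 256), dtype=np.uint8))

rgb_t = task_transform(rgb, task='rgb', image_size=224)               # float tensor, 3 x 224 x 224, in [0, 1]
seg_t = task_transform(seg, task='segment_semantic', image_size=224)  # long tensor, 224 x 224, nearest-neighbour resize
print(rgb_t.shape, rgb_t.dtype, seg_t.shape, seg_t.dtype)
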
/script/data/taskonomy/taskonomy_dataset_s3.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import boto3
4 | import json
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from torch.utils.data import Dataset
9 | from PIL import Image
10 | #from PIL import ImageFile
11 | #ImageFile.LOAD_TRUNCATED_IMAGES = True # TODO: Fix these images and then remove this
12 |
13 |
14 | from .task_configs import task_parameters
15 | from .transforms import task_transform
16 | from .splits import get_splits
17 |
18 |
19 | filter_amount_dict = {'low': 10000, 'medium': 50000, 'high': 100000}
20 | forbidden_buildings = [
21 | 'mosquito', 'tansboro', 'tomkins', 'darnestown', 'brinnon', # We do not have the rgb data for tomkins, darnestown, brinnon
22 | 'rough', 'grace', 'wiconisco' # Contain some wrong viewpoints
23 | ]
24 |
25 |
26 | class TaskonomyDatasetS3(Dataset):
27 | def __init__(self,
28 | tasks,
29 | split='train',
30 | variant='fullplus',
31 | rm_incomplete=True,
32 | image_size=256,
33 | max_images=None,
34 | seed=0,
35 | filter_amount='medium'):
36 | '''
37 | Taskonomy EPFL-S3 dataloader.
38 | Make sure the environment variables S3_ENDPOINT, S3_TASKONOMY_ACCESS,
39 | S3_TASKONOMY_KEY, and S3_TASKONOMY_BUCKET are set.
40 |
41 | Args:
42 | tasks: List of tasks
43 | split: One of {'train', 'val', 'test', 'all'}
44 | variant: One of {'debug', 'tiny', 'medium', 'full', 'fullplus'}
45 | rm_incomplete: Set to True to only keep samples that have every task
46 | image_size: Target image size
47 | max_images: Optional subset selection
48 | seed: Random seed for deterministic shuffling order
49 | filter_amount: How many "bad" images to remove. One of {'low', 'medium', 'high'}.
50 | '''
51 | super(TaskonomyDatasetS3, self).__init__()
52 | self.tasks = tasks
53 | self.split = split
54 | self.variant = variant
55 | self.rm_incomplete = rm_incomplete
56 | self.image_size=image_size
57 | self.max_images = max_images
58 | self.seed = seed
59 | self.filter_amount = filter_amount
60 |
61 | # S3 bucket setup
62 | self.session = boto3.session.Session()
63 | self.s3_client = self.session.client(
64 | service_name='s3',
65 | aws_access_key_id=os.environ.get('S3_TASKONOMY_ACCESS'),
66 | aws_secret_access_key=os.environ.get('S3_TASKONOMY_KEY'),
67 | endpoint_url=os.environ.get('S3_ENDPOINT')
68 | )
69 | self.bucket_name = os.environ.get('S3_TASKONOMY_BUCKET')
70 |
71 | # DataFrame containing information whether or not any file for any task exists
72 | self.df_meta = pd.read_pickle(os.path.join(os.path.dirname(__file__), 'metadata', 'taskonomy_files.pkl.gz'))
73 |
74 | # Select splits based on selected size/variant
75 | splits = get_splits(
76 | os.path.join(os.path.dirname(__file__), 'metadata', f'train_val_test_{variant}.csv'),
77 | forbidden_buildings=forbidden_buildings
78 | )
79 | if split == 'all':
80 | self.buildings = list(set(splits['train']) | set(splits['val']) | set(splits['test']))
81 | else:
82 | self.buildings = splits[split]
83 | self.buildings = sorted(self.buildings)
84 | self.df_meta = self.df_meta.loc[self.buildings]
85 |
86 | # Filter bad images
87 | df_filter = pd.read_pickle(os.path.join(os.path.dirname(__file__), 'metadata', 'taskonomy_filter_scores.pkl.gz'))
88 | df_filter = df_filter[:filter_amount_dict[filter_amount]]
89 | filtered_indices = self.df_meta.index.difference(df_filter.index)
90 | self.df_meta = self.df_meta.loc[filtered_indices]
91 |
92 | self.df_meta = self.df_meta[tasks] # Select tasks of interest
93 | if rm_incomplete:
94 | # Only select rows where we have all the tasks
95 | self.df_meta = self.df_meta[self.df_meta.all(axis=1)]
96 | self.df_meta = self.df_meta.sample(frac=1, random_state=seed) # Random shuffle
97 | self.df_meta = self.df_meta[:max_images] if max_images is not None else self.df_meta # Select subset if so desired
98 |
99 | print(f'Using {len(self.df_meta)} images from variant {self.variant} in split {self.split}.')
100 |
101 |
102 | def __len__(self):
103 | return len(self.df_meta)
104 |
105 | def __getitem__(self, index):
106 |
107 | # building / point / view are encoded in dataframe index
108 |         building, point, view = self.df_meta.iloc[index].name
109 | # TODO: Remove this try/except after we made sure there are no bad/missing images!
110 | # Very slow if it fails.
111 | try:
112 |
113 | result = {}
114 | for task in self.tasks:
115 | # Load from S3 bucket
116 | ext = task_parameters[task]['ext']
117 | domain_id = task_parameters[task]['domain_id']
118 | key = f'taskonomy_imgs/{task}/{building}/point_{point}_view_{view}_domain_{domain_id}.{ext}'
119 | obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)['Body'].read()
120 |
121 | # Convert bytes to image / json / array / etc...
122 | if ext == 'png':
123 | file = Image.open(io.BytesIO(obj))
124 | elif ext == 'json':
125 | file = json.load(io.BytesIO(obj))
126 | if task == 'point_info':
127 | file['building'] = building
128 | file.pop('nonfixated_points_in_view')
129 | elif ext == 'npy':
130 | file = np.frombuffer(obj)
131 | else:
132 | raise NotImplementedError(f'Loading extension {ext} not yet implemented')
133 |
134 | # Perform transformations
135 | file = task_transform(file, task=task, image_size=self.image_size)
136 |
137 | result[task] = file
138 |
139 | return torch.stack([result[self.tasks[0]],result[self.tasks[0]]]), torch.stack([result[t].view(-1,self.image_size,self.image_size) for i,t in enumerate(self.tasks) if i!=0] ),torch.LongTensor([i for i in range(len(self.tasks)-1)])
140 |
141 |
142 | except Exception as e :
143 | # In case image was faulty or not uploaded yet, try with random other image
144 |
145 | return self[np.random.randint(len(self))]
146 |
--------------------------------------------------------------------------------
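
A construction sketch for TaskonomyDatasetS3. Running it requires access to the private S3 bucket (the environment variables below are placeholders) and the metadata pickles taskonomy_files.pkl.gz and taskonomy_filter_scores.pkl.gz referenced in __init__, which are not listed in the tree above.

import os

from torch.utils.data import DataLoader

from data.taskonomy import TaskonomyDatasetS3

# Placeholder credentials; the real values must point to the bucket holding the Taskonomy images
os.environ.setdefault('S3_ENDPOINT', 'https://s3.example.com')
os.environ.setdefault('S3_TASKONOMY_ACCESS', '<access-key>')
os.environ.setdefault('S3_TASKONOMY_KEY', '<secret-key>')
os.environ.setdefault('S3_TASKONOMY_BUCKET', '<bucket-name>')

# The first task is the input (rgb); the remaining tasks become the stacked label maps
dataset = TaskonomyDatasetS3(
    tasks=['rgb', 'segment_semantic', 'depth_euclidean'],
    split='train', variant='tiny', image_size=224)

loader = DataLoader(dataset, batch_size=8, shuffle=True)
imgs, labels, task_ids = next(iter(loader))
# imgs: (8, 2, 3, 224, 224)  - the rgb image repeated once per target task
# labels: (8, 2, 1, 224, 224); task_ids: (8, 2) with one id per target task
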
/script/model/conditional_modules.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numbers
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class FiLM(nn.Module):
9 | """ Feature-wise Linear Modulation (FiLM) layer"""
10 | def __init__(self, input_size, output_size, num_film_layers=1, layer_norm=False):
11 | """
12 | :param input_size: feature size of x_cond
13 | :param output_size: feature size of x_to_film
14 | :param layer_norm: true or false
15 | """
16 | super(FiLM, self).__init__()
17 | self.input_size = input_size
18 | self.output_size = output_size
19 | self.num_film_layers = num_film_layers
20 | self.layer_norm = nn.LayerNorm(output_size) if layer_norm else None
21 | film_output_size = self.output_size * num_film_layers * 2
22 | self.gb_weights = nn.Linear(self.input_size, film_output_size)
23 | self.gb_weights.bias.data.fill_(0)
24 |
25 | def forward(self, x_cond, x_to_film):
26 | gb = self.gb_weights(x_cond).unsqueeze(1)
27 | gamma, beta = torch.chunk(gb, 2, dim=-1)
28 | out = (1 + gamma) * x_to_film + beta
29 | if self.layer_norm is not None:
30 | out = self.layer_norm(out)
31 | return out
32 |
33 |
34 | class TAA(nn.Module):
35 | """ Task Adapted Attention layer"""
36 | def __init__(self, input_size, output_size, blocks=1, num_film_layers=1, layer_norm=False):
37 | """
38 | :param input_size: feature size of x_cond
39 | :param output_size: feature size of x_to_film
40 | :param layer_norm: true or false
41 | """
42 | super(TAA, self).__init__()
43 | self.input_size = input_size
44 | self.output_size = output_size
45 | self.num_film_layers = num_film_layers
46 | self.layer_norm = nn.LayerNorm(output_size) if layer_norm else None
47 | self.blocks = blocks
48 | film_output_size = self.output_size * num_film_layers * 2
49 | self.gb_weights = nn.Linear(self.input_size, film_output_size)
50 | self.gb_weights.bias.data.fill_(0)
51 |
52 | def forward(self, x_cond, x_to_film):
53 | """gb = self.gb_weights(x_cond).unsqueeze(1)
54 | gamma, beta = torch.chunk(gb, 2, dim=-1)
55 | out = (1 + gamma) * x_to_film + beta
56 | """
57 |
58 | gb = self.gb_weights(x_cond).unsqueeze(1)
59 |
60 | gamma, beta = torch.chunk(gb, 2, dim=-1)
61 |
62 | out = (1 + gamma) * x_to_film + beta
63 |
64 |
65 | if self.layer_norm is not None:
66 | out = self.layer_norm(out)
67 | out = [torch.block_diag(*list(out_b.chunk(self.blocks, 0))) for out_b in out]
68 | out = torch.stack(out)
69 | return out[:, :, :out.size(1)]
70 |
71 |
72 | class TaskScaledNorm(nn.Module):
73 | r"""Applies Task Scaled Normalization over a mini-batch of inputs.
74 |
75 | .. math::
76 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma(z) + \beta(z)
77 |
78 | The mean and standard-deviation are calculated separately over the last
79 | certain number dimensions which have to be of the shape specified by
80 | :attr:`normalized_shape`.
81 | :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
82 | :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
83 |
84 | .. note::
85 | Unlike Batch Normalization and Instance Normalization, which applies
86 | scalar scale and bias for each entire channel/plane with the
87 | :attr:`affine`, Layer Normalization applies per-element scale and
88 | bias with :attr:`elementwise_affine`.
89 |
90 | This layer uses statistics computed from input data in both training and
91 | evaluation modes. The affine transformation is modulated by a task scaled tensor.
92 | In our case, we use task embeddings.
93 |
94 | Args:
95 | normalized_shape (int or list or torch.Size): input shape from an expected input
96 | of size
97 |
98 | .. math::
99 | [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
100 | \times \ldots \times \text{normalized\_shape}[-1]]
101 |
102 | If a single integer is used, it is treated as a singleton list, and this module will
103 | normalize over the last dimension which is expected to be of that specific size.
104 | eps: a value added to the denominator for numerical stability. Default: 1e-5
105 | elementwise_affine: a boolean value that when set to ``True``, this module
106 | has learnable per-element affine parameters initialized to ones (for weights)
107 | and zeros (for biases). Default: ``True``.
108 |
109 | Shape:
110 | - Input: :math:`(N, *)`
111 | - Output: :math:`(N, *)` (same shape as input)
112 |
113 | Examples::
114 |
115 | >>> input_ = torch.randn(20, 5, 10, 10)
116 | >>> condition = torch.randn(20, 10)
117 | >>> # With Learnable Parameters
118 | >>> m = TaskScaledNorm([10, 10])
119 | >>> # Normalize over last dimension of size 10
120 | >>> m = nn.LayerNorm(10)
121 | >>> # Activating the module
122 | >>> output = m(input_, condition)
123 |
124 | """
125 | __constants__ = ['normalized_shape', 'condition_size', 'weight', 'bias', 'eps']
126 |
127 | def __init__(self, normalized_shape, condition_size, eps=1e-5):
128 | super(TaskScaledNorm, self).__init__()
129 | if isinstance(normalized_shape, numbers.Integral):
130 | normalized_shape = (normalized_shape,)
131 | self.normalized_shape = tuple(normalized_shape)
132 |
133 | self.condition_size = condition_size
134 | self.eps = eps
135 |
136 | self.weight = nn.Parameter(torch.Tensor(*normalized_shape))
137 | self.ln_weight_modulation = FiLM(condition_size, sum(normalized_shape))
138 | self.bias = nn.Parameter(torch.Tensor(*normalized_shape))
139 | self.reset_parameters()
140 |
141 | def reset_parameters(self):
142 | nn.init.ones_(self.weight)
143 | nn.init.zeros_(self.bias)
144 |
145 | def forward(self, input_, condition, task_id):
146 | unique_task_ids = torch.unique(task_id)
147 | cln_output = torch.zeros_like(input_)
148 | for unique_task_id in unique_task_ids:
149 | task_id_filter = task_id == unique_task_id
150 | task_emb = condition[task_id_filter][0].unsqueeze(0)
151 | weight = self.ln_weight_modulation(task_emb, self.weight).view(-1)
152 | cln_output[task_id_filter] = F.layer_norm(input_[task_id_filter], self.normalized_shape, weight, self.bias, self.eps)
153 | return cln_output
154 |
155 | def extra_repr(self):
156 | return '{normalized_shape}, {condition_size}, eps={eps}'.format(**self.__dict__)
157 |
158 |
159 | class ConditionalBottleNeck(nn.Module):
160 | """Down projection and up projection with FiLM layers within Transformer layer."""
161 | def __init__(self, hidden_size, output_size):
162 | super(ConditionalBottleNeck, self).__init__()
163 | self.emb_transf = nn.Linear(hidden_size, hidden_size)
164 | self.hidden_modulation = FiLM(hidden_size, output_size)
165 | self.down_proj_layer = nn.Linear(output_size, output_size//3)
166 | self.up_proj_layer = nn.Linear(output_size//3, output_size)
167 |
168 | def forward(self, x_cond, hidden_states):
169 | x_cond = self.emb_transf(x_cond)
170 | hidden_states = self.hidden_modulation(x_cond=x_cond, x_to_film=hidden_states)
171 | hidden_states = self.down_proj_layer(hidden_states)
172 | hidden_states = self.up_proj_layer(hidden_states)
173 | return hidden_states
174 |
--------------------------------------------------------------------------------
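
A shape-level sketch of the conditioning modules above, with arbitrary sizes chosen to resemble Swin windows (49 tokens, embedding width 96):

import torch

from model.conditional_modules import FiLM, ConditionalBottleNeck, TaskScaledNorm

batch, tokens, hidden = 4, 49, 96              # e.g. one 7x7 Swin window, embed_dim 96
task_emb = torch.randn(batch, hidden)          # one conditioning (task) embedding per sample
features = torch.randn(batch, tokens, hidden)  # token features to be modulated

film = FiLM(input_size=hidden, output_size=hidden, layer_norm=True)
print(film(task_emb, features).shape)          # (1 + gamma) * x + beta, broadcast over tokens: [4, 49, 96]

bottleneck = ConditionalBottleNeck(hidden_size=hidden, output_size=hidden)
print(bottleneck(task_emb, features).shape)    # FiLM, then down-projection to hidden // 3 and back: [4, 49, 96]

tsn = TaskScaledNorm(normalized_shape=hidden, condition_size=hidden)
task_id = torch.tensor([0, 0, 1, 1])           # two samples per task
print(tsn(features, task_emb, task_id).shape)  # layer norm with a task-modulated weight: [4, 49, 96]
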
/script/evaluate.py:
--------------------------------------------------------------------------------
1 | from data.taskonomy.taskonomy_dataset_s3 import TaskonomyDatasetS3
2 | from matplotlib import pyplot as plt
3 | import torch
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | from torch.utils.data import DataLoader
7 | import torch.optim as optim
8 | import torch.nn.functional as F
9 | from torch.utils.tensorboard import SummaryWriter
10 | import transformers
11 | from tqdm import tqdm
12 | import numpy as np
13 | import os
14 | import pickle
15 | import cv2
16 | import json
17 | import argparse
18 |
19 |
20 | from torchvision.utils import make_grid, save_image
21 |
22 |
23 | from data.nyuv2_same_batch import NYUv2SameBatchDataset
24 | from model.swin_transformer import SwinTransformer
25 | from loss.losses import berHuLoss
26 | from loss.metrics import iou_pytorch, eval_depth
27 | from data.nyuv2 import NYUv2Dataset
28 |
29 |
30 | def get_config():
31 | parser = argparse.ArgumentParser(description='Train the network')
32 | parser.add_argument('--config', help='train config file path')
33 |
34 | args = parser.parse_args()
35 |
36 | with open(args.config, "r") as jsonfile:
37 | config = json.load(jsonfile)
38 |
39 | return config
40 |
41 |
42 | def get_dataloaders(tasks, batch_size, setting="nyu", task=None):
43 |
44 | if setting == "taskonomy":
45 |
46 | test_dataset = TaskonomyDatasetS3(
47 | tasks=["rgb", "segment_semantic", "depth_euclidean"], split="val", variant="tiny", image_size=224)
48 |
49 | g = torch.Generator()
50 | g.manual_seed(61)
51 |
52 | k_samples = 16*100
53 | perm = torch.randperm(len(test_dataset), generator=g)
54 | idx = perm[:k_samples].tolist()
55 |
56 | subset_dataset_test = torch.utils.data.Subset(test_dataset, idx)
57 |
58 | dataloader = DataLoader(subset_dataset_test,
59 | batch_size=batch_size, shuffle=False)
60 |
61 | return dataloader
62 |
63 | if setting == "nyu":
64 |
65 | IMAGE_SIZE = (480, 640)
66 |
67 | test_t = torch.nn.Sequential(
68 | transforms.CenterCrop(480), transforms.Resize(224))
69 | train_t_input_image = torch.nn.Sequential(transforms.ColorJitter(brightness=(
70 | 0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.8, 1.2), hue=(-0.1, 0.1)))
71 |
72 | test_dataset = NYUv2SameBatchDataset(root="./data/nyuv2", tasks=tasks, download=False, train=False,
73 | rgb_transform=test_t, seg_transform=test_t, sn_transform=test_t, depth_transform=test_t)
74 |
75 | dataloader = DataLoader(
76 | test_dataset, batch_size=batch_size, shuffle=False)
77 |
78 | return dataloader
79 |
80 | if setting == "nyu_single_task":
81 |
82 | IMAGE_SIZE = (480, 640)
83 |
84 | test_t = torch.nn.Sequential(
85 | transforms.CenterCrop(480), transforms.Resize(224))
86 | train_t_input_image = torch.nn.Sequential(transforms.ColorJitter(brightness=(
87 | 0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.8, 1.2), hue=(-0.1, 0.1)))
88 |
89 | test_dataset = NYUv2Dataset(root="./data/nyuv2", tasks=tasks, download=False, train=False,
90 | rgb_transform=test_t, seg_transform=test_t, sn_transform=test_t, depth_transform=test_t)
91 |
92 | if task == "segmentation":
93 | test_dataset = torch.utils.data.Subset(
94 | test_dataset, range(len(test_dataset)//2))
95 |
96 | if task == "depth":
97 | test_dataset = torch.utils.data.Subset(
98 | test_dataset, range(len(test_dataset)//2, len(test_dataset)))
99 |
100 | dataloader = DataLoader(
101 | test_dataset, batch_size=batch_size, shuffle=False)
102 |
103 | return dataloader
104 |
105 |
106 | def calc_seg_metrics(logit_task, label_task):
107 |
108 | max_labels = torch.argmax(logit_task, dim=1, keepdim=True)
109 | iou = iou_pytorch(max_labels, label_task)
110 |
111 | return max_labels, iou
112 |
113 |
114 | def disp2meters(d):
115 | return (65536.0 / d - 1) / 1e4
116 |
117 |
118 | def load_model(model, PATH, device):
119 | checkpoint = torch.load(PATH, map_location=device)
120 | model.load_state_dict(checkpoint['model_state_dict'])
121 | model = model.to(device)
122 | return model
123 |
124 |
125 | def evaluate(model, dataloader, device, task=None):
126 | test_loss = 0
127 | epoch_ious = []
128 | epoch_eval_depths_d1 = []
129 |
130 | epoch_loss_seg_test = []
131 | epoch_loss_depth_test = []
132 |
133 | model.eval()
134 | for i, (img, label, task_id) in enumerate(dataloader, 0):
135 |
136 | img = img.view((-1, 3, 224, 224)).to(device)
137 | label = label.view((-1, 1, 224, 224)).to(device)
138 | task_id = task_id.view(-1).to(device)
139 |
140 | if task is not None:
141 | task_id = torch.zeros_like(task_id)
142 |
143 | logits, unique_task_ids_list = model(img, task_id)
144 |
145 | loss = 0
146 |
147 | for j, unique_task_id in enumerate(unique_task_ids_list):
148 |
149 | task_id_filter = task_id == unique_task_id
150 |
151 | logit_task = logits[j]
152 | label_task = label[task_id_filter]
153 | B = logit_task.shape[0]
154 |
155 | if unique_task_id == 0 and task != "depth":
156 |
157 | label_task = label_task.long()
158 |
159 | max_labels, iou = calc_seg_metrics(logit_task, label_task)
160 |
161 | epoch_ious.append(iou.cpu().numpy())
162 |
163 | else:
164 |
165 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1
166 | label_task = 65536.0 / (label_task + 1)
167 |
168 | evaluation = eval_depth(disp2meters(
169 | logit_task), disp2meters(label_task))
170 | epoch_eval_depths_d1.append(evaluation["d1"])
171 |
172 | print("Mean IOU: ", np.mean(epoch_ious))
173 | print("D1: ", np.mean(epoch_eval_depths_d1))
174 |
175 |
176 | def save_images(model, dataloader, device, num_images=1, task=None, setting=None):
177 |
178 | img_count = 0
179 | for i, (img, label, task_id) in enumerate(dataloader, 0):
180 |
181 | img = img.view((-1, 3, 224, 224)).to(device)
182 | label = label.view((-1, 1, 224, 224)).to(device)
183 | task_id = task_id.view(-1).to(device)
184 |
185 | if task is not None:
186 | task_id = torch.zeros_like(task_id)
187 |
188 | logits, unique_task_ids_list = model(img, task_id)
189 |
190 | for j in range(len(img)):
191 | if len(logits) == 1:
192 | fig, axs = plt.subplots(1, 3, figsize=(12, 5))
193 |
194 | axs[0].imshow(torch.permute(img[j].cpu(), (1, 2, 0)))
195 | axs[0].set_xlabel('RGB Image')
196 |
197 | if task == "segmentation":
198 | k = torch.argmax(logits[0][j], dim=0, keepdim=True)
199 | k[0][-1][-1] = torch.max(label[j])
200 | label[j][0][-1][-1] = torch.max(k)
201 | else:
202 | k = disp2meters(torch.nn.functional.sigmoid(
203 | logits[0][j])*65535 + 1)
204 |
205 | axs[1].imshow(torch.permute(label[j].cpu(), (1, 2, 0)))
206 | axs[1].set_xlabel(f'{task.capitalize()} Label')
207 |
208 | axs[2].imshow(k.detach().view(224, 224, 1).cpu())
209 | axs[2].set_xlabel(f'{task.capitalize()} Prediction')
210 |
211 | plt.savefig(f'./images/{img_count}.png')
212 | img_count += 1
213 |
214 | else:
215 | if j % 2 == 1:
216 | continue
217 |
218 | c = j//2
219 | fig, axs = plt.subplots(1, 5, figsize=(20, 5))
220 | axs[0].imshow(torch.permute(img[j].cpu(), (1, 2, 0)))
221 | axs[0].set_xlabel('RGB Image')
222 |
223 | k = torch.argmax(logits[0][c], dim=0, keepdim=True)
224 |
225 | k[0][-1][-1] = 18 if setting == "taskonomy" else 13
226 | label[j][0][-1][-1] = 18 if setting == "taskonomy" else 13
227 |
228 | axs[1].imshow(torch.permute(label[j].cpu(), (1, 2, 0)))
229 | axs[1].set_xlabel('Segmentation Label')
230 |
231 | axs[2].imshow(k.detach().view(224, 224, 1).cpu())
232 | axs[2].set_xlabel('Segmentation Prediction')
233 |
234 | label[j+1][label[j+1] == 65535] = 0
235 | axs[3].imshow(torch.permute(label[j+1].cpu(), (1, 2, 0)))
236 | axs[3].set_xlabel('Depth Label')
237 | k2 = disp2meters(torch.nn.functional.sigmoid(
238 | logits[1][c])*65535 + 1)
239 | axs[4].imshow(k2.detach().view(224, 224, 1).cpu())
240 | axs[4].set_xlabel('Depth Prediction')
241 |
242 | plt.savefig(f'./images/{img_count}.png')
243 | img_count += 1
244 |
245 | if img_count == num_images:
246 | return
247 |
248 |
249 | def main():
250 |
251 | config = get_config()
252 |
253 | if config["setting"] != "nyu_single_task" and "task" in config.keys():
254 | print("Do not put task parameter on multitask networks!")
255 | return
256 |
257 | torch.manual_seed(61)
258 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
259 |
260 | tasks = {0: "segmentation", 1: "depth"}
261 | print("Creating dataset...")
262 | dataloader = get_dataloaders(
263 | tasks, 16, config["setting"], config["task"] if "task" in config.keys() else None)
264 |
265 | print("Loading model...")
266 |
267 | tasks = ["segmentation", "depth"]
268 | task_classes = [14, 1] if config["setting"] != "taskonomy" else [18, 1]
269 | if config["setting"] == "nyu_single_task":
270 | tasks = [config["task"]]
271 | task_classes = [14 if config["task"] == "segmentation" else 1]
272 |
273 | model = SwinTransformer(img_size=224,
274 | patch_size=4,
275 | in_chans=3,
276 | num_classes=21841,
277 | embed_dim=96,
278 | depths=[2, 2, 18, 2],
279 | depths_decoder=[2, 2, 2, 2],
280 | num_heads=[3, 6, 12, 24],
281 | window_size=7,
282 | mlp_ratio=4.,
283 | qkv_bias=True,
284 | qk_scale=True,
285 | drop_rate=0,
286 | drop_rate_decoder=0.6,
287 | drop_path_rate=0.2,
288 | ape=False,
289 | patch_norm=True,
290 | use_checkpoint=False,
291 | tasks=tasks,
292 | task_classes=task_classes,
293 | conditioned_blocks=config["conditioned_blocks"] if config["setting"] != "nyu_single_task" else [
294 | [], [], [], []],
295 | adapter=config["adapter"] if config["setting"] != "nyu_single_task" else False,
296 | use_conditional_layer=config["use_conditional_layer_norm"] if config["setting"] == "nyu" else False)
297 |
298 | model = load_model(model, config["model_path"], device)
299 |
300 | print("Evaluating...")
301 | evaluate(model, dataloader, device,
302 | task=config["task"] if "task" in config.keys() else None)
303 |
304 | print("Saving Images...")
305 | save_images(model, dataloader, device, num_images=config["num_generated_images"],
306 | task=config["task"] if "task" in config.keys() else None, setting=config["setting"])
307 |
308 |
309 | if __name__ == '__main__':
310 | main()
311 |
--------------------------------------------------------------------------------
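
evaluate.py loads its settings from a JSON file passed via --config. The schema is not documented in the repository, so the sketch below only writes the keys the script visibly reads; every value is a placeholder rather than a shipped default.

import json

eval_config = {
    "setting": "nyu",                        # one of "nyu", "nyu_single_task", "taskonomy"
    "model_path": "./checkpoints/model.pt",  # placeholder path to a trained checkpoint
    "num_generated_images": 4,
    "conditioned_blocks": [[], [], [], []],  # per-stage conditioned block indices; depends on the trained model
    "adapter": True,
    "use_conditional_layer_norm": True,
    # "task": "depth",                       # only valid when setting == "nyu_single_task"
}

with open("eval_nyu.json", "w") as f:
    json.dump(eval_config, f, indent=2)

# then: python evaluate.py --config eval_nyu.json
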
/docs/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
10 |
11 |
12 | Vision Transformer Adapters for Generalizable Multitask Learning
13 |
14 |
15 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 | VIDEO
54 |
55 |
56 | We introduce the first multitasking vision transformer adapters that learn generalizable task affinities which can be applied to novel tasks and domains. Integrated into an off-the-shelf vision transformer backbone, our adapters can simultaneously solve multiple dense vision tasks in a parameter-efficient manner, unlike existing multitasking transformers that are parametrically expensive. In contrast to concurrent methods, we do not require retraining or fine-tuning whenever a new task or domain is added. We introduce a task-adapted attention mechanism within our adapter framework that combines gradient-based task similarities with attention-based ones. The learned task affinities generalize to the following settings: zero-shot task transfer, unsupervised domain adaptation, and generalization without fine-tuning to novel domains. We demonstrate that our approach outperforms not only the existing convolutional neural network-based multitasking methods but also the vision transformer-based ones.
57 |
58 |
59 |
60 |
Method Architecture
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 | Detailed overview of our architecture. The frozen transformer encoder module (in orange) extracts a shared representation of the input image, which is then utilized to learn the task affinities in our novel vision transformer adapters (in purple). Each adapter layer uses gradient task similarity (TROA) (in yellow) and Task-Adapted Attention (TAA) to learn the task affinities, which are communicated with skip connections (in blue) between consecutive adapter layers. The task embeddings are then decoded by the fully-supervised transformer decoders (in green) for the respective tasks. Note that the transformer decoders are shared but have different task heads (in grey). For clarity, only three tasks are depicted here and TAA is explained in a separate figure below.
69 |
70 |
71 |
72 |
73 |
Vision Transformer Adapter Module
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 | Overview of our vision transformer adapter module. Our vision adapters learn transferable and generalizable task affinities in a parameter-efficient way. We show two blocks to depict the skip connectivity between them. The main modules (TROA) and (TAA) of our vision transformer adapters are depicted below.
82 |
83 |
84 |
85 |
Task Representation Optimization Algorithm (TROA)
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 | We show the task affinities from TROA when four tasks comprising semantic segmentation (SemSeg), depth, surface normal, and edges are jointly learned. We show that TROA learns a strong task affinity between the same task gradients, for example, segmentation with segmentation. This is a self-explanatory observation. Consequently, TROA also learns task affinities between proximate tasks such as segmentation and depth, and task affinities between non-proximate tasks such as segmentation and normal. Note that task dependence is asymmetric, i.e., segmentation does not affect normal in the same way that normal affects segmentation. These task affinities are used by our novel task-adapted attention module as described in what follows.
94 |
95 |
96 |
97 |
Matching the Feature Dimensions using FiLM
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 | Detailed overview of Feature-wise Linear Modulation (FiLM), which linearly shifts and scales task representations to match the dimensions of the feature maps. The orange rectangular area is FiLM.
106 |
107 |
108 |
109 |
Task-Adapted Attention
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 | Overview of our Task-Adapted Attention (TAA) mechanism that combines task affinities with image attention. Note that the process, in the foreground, is for a single attention head which is repeated for 'M' heads to give us the task-adapted multi-head attention.
118 |
119 |
120 |
121 |
122 |
123 |
Multitasking Results
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 | Multitask Learning comparison on the NYUDv2 benchmark in the 'S-D-N-E' setting. Our model outperforms all the multitask baselines, i.e., ST-MTL, InvPT, Taskprompter, and MulT. For instance, our model correctly segments and predicts the surface normal of the elements within the yellow-circled region, unlike the baseline. All the methods are based on the same Swin-B V2 backbone. Best seen on screen and zoomed in. For more details and quantitative results, please refer to our paper.
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 | Multitask learning comparison on the Taskonomy benchmark in the 'S-D-N-E' setting. Our model outperforms all the multitask baselines. For instance, our model correctly segments and predicts the surface normals of the elements within the yellow-circled region, unlike the baselines. All the methods are based on the same Swin-B V2 backbone. Best seen on screen and zoomed in. For more details and quantitative results, please refer to our paper.
143 |
144 |
145 |
146 |
147 |
148 |
Unsupervised Domain Adaptation (UDA)
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 | Unsupervised Domain Adaptation (UDA) results on Synthia->Cityscapes. Our model outperforms the CNN-based baseline (XTAM-UDA) and the Swin-B V2-based baselines (1-task Swin-UDA, MulT-UDA). For instance, our method can predict the depth of the car tail light, unlike the baselines. Best seen on screen and zoomed in on the yellow-circled region.
158 |
159 |
160 |
161 |
162 |
Bibtex
163 |
164 |
165 |
166 |
167 | @misc{bhattacharjee2023vision,
168 | title={Vision Transformer Adapters for Generalizable Multitask Learning},
169 | author={Deblina Bhattacharjee and Sabine Süsstrunk and Mathieu Salzmann},
170 | year={2023},
171 | eprint={2308.12372},
172 | archivePrefix={arXiv},
173 | primaryClass={cs.CV}
174 | }
175 |
176 |
177 |
178 |
179 |
Acknowledgement
180 |
181 |
182 |
This work was supported in part by the Swiss National Science Foundation via the Sinergia grant CRSII5-180359.
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
195 |
198 |
201 |
202 |
203 |
--------------------------------------------------------------------------------
/script/data/nyuv2.py:
--------------------------------------------------------------------------------
1 | """
2 | author: Mihai Suteu
3 | date: 15/05/19
4 | https://github.com/xapharius/pytorch-nyuv2
5 | """
6 |
7 |
8 | import os
9 | import sys
10 | import h5py
11 | import torch
12 | import shutil
13 | import random
14 | import tarfile
15 | import zipfile
16 | import requests
17 | import numpy as np
18 | from typing import Dict
19 |
20 | from PIL import Image
21 | from torch.utils.data import Dataset
22 | from torchvision.datasets.utils import download_url
23 |
24 | SEG = 0
25 | DEP = 1
26 | SN = 2
27 |
28 |
29 | class NYUv2Dataset(Dataset):
30 | """
31 | PyTorch wrapper for the NYUv2 dataset focused on multi-task learning.
32 | Data sources available: RGB, Semantic Segmentation, Surface Normals, Depth Images.
33 | If no transformation is provided, the image type will not be returned.
34 |
35 | ### Output
36 | All images are of size: 640 x 480
37 |
38 | 1. RGB: 3 channel input image
39 |
40 | 2. Semantic Segmentation: 1 channel representing one of the 14 (0 -
41 | background) classes. Conversion to int will happen automatically if
42 | transformation ends in a tensor. Task name: "segmentation"
43 |
44 | 3. Surface Normals: 3 channels, with values in [0, 1]. Task name: "surface_normals"
45 |
46 | 4. Depth Images: 1 channel with floats representing the distance in meters.
47 | Conversion will happen automatically if transformation ends in a tensor. Task name: "depth"
48 | """
49 |
50 | def __init__(
51 | self,
52 | root: str,
53 | tasks: Dict[int, str],
54 | train: bool = True,
55 | download: bool = False,
56 | rgb_transform=None,
57 | seg_transform=None,
58 | sn_transform=None,
59 | depth_transform=None,
60 | rgb_transform2=None,
61 | ):
62 | """
63 | Will return tuples based on what data source has been enabled (rgb, seg etc).
64 |
65 | :param root: path to root folder (eg /data/NYUv2)
66 | :param train: whether to load the train or test set
67 | :param download: whether to download and process data if missing
68 | :param rgb_transform: the transformation pipeline for rgb images
69 | :param seg_transform: the transformation pipeline for segmentation images. If
70 | the transformation ends in a tensor, the result will be automatically
71 | converted to int in [0, 14)
72 | :param sn_transform: the transformation pipeline for surface normal images
73 | :param depth_transform: the transformation pipeline for depth images. If the
74 | transformation ends in a tensor, the result will be automatically converted
75 | to meters
76 | """
77 | super().__init__()
78 | self.root = root
79 |
80 | self.rgb_transform = rgb_transform
81 | self.rgb_transform2 = rgb_transform2
82 | self.seg_transform = seg_transform
83 | self.depth_transform = depth_transform
84 | self.sn_transform = sn_transform
85 |
86 | self.train = train
87 | self._split = "train" if train else "test"
88 |
89 | if download:
90 | self.download()
91 |
92 | # rgb folder as ground truth
93 | self._files = sorted(os.listdir(os.path.join(root, f"{self._split}_rgb_pt")))
94 | self.num_img = len(self._files)
95 |
96 | self.num_tasks = len(tasks)
97 | self.tasks = tasks
98 |
99 | self.task_dict = self._get_task_dict()
100 |
101 | self.folder = lambda name: os.path.join(self.root, f"{self._split}_{name}_pt")
102 |
103 | self.seg_images = torch.load(f"{root}/combined/{self._split}_seg13_pt.pt")
104 | self.depth_images = torch.load(f"{root}/combined/{self._split}_depth_pt.pt")
105 |
106 |
107 |
108 |
109 | def __getitem__(self, index: int):
110 |
111 | task = index // self.num_img
112 | rgb_image = index % self.num_img
113 | seed = random.randrange(sys.maxsize)
114 | rgb = None
115 | state = None
116 |
117 | if self.rgb_transform is not None:
118 | random.seed(seed)
119 | img = torch.load(os.path.join(self.folder("rgb"), self._files[rgb_image])) # self.rgb_images[rgb_image, :,:,:]#
120 | ### https://github.com/pytorch/vision/issues/9#issuecomment-789308878
121 | state = torch.get_rng_state()
122 | rgb = self.rgb_transform(img)
123 | if self.rgb_transform2 is not None:
124 | rgb = self.rgb_transform2(rgb)
125 |
126 | label = self._get_task_label(task, rgb_image, state)
127 |
128 | return rgb, label, task
129 |
130 | def _get_task_dict(self):
131 |
132 | task_dict = dict()
133 |
134 | for i in self.tasks.keys():
135 |
136 | task_type = self.tasks[i]
137 | if task_type == "segmentation":
138 | task_dict[i] = SEG
139 | elif task_type == "surface_normals":
140 | task_dict[i] = SN
141 | elif task_type == "depth":
142 | task_dict[i] = DEP
143 |
144 | return task_dict
145 |
146 |
147 | def _get_task_label(self, task, rgb_image, state):
148 | seed = random.randrange(sys.maxsize)
149 |
150 | task_type = self.task_dict[task]
151 | if task_type == SEG:
152 | if self.seg_transform is not None:
153 | random.seed(seed)
154 | img = self.seg_images[rgb_image, :,:,:]
155 | torch.set_rng_state(state)
156 | img = self.seg_transform(img)
157 | if isinstance(img, torch.Tensor):
158 | # ToTensor scales to [0, 1] by default
159 | img = (img * 255).long()
160 | return img
161 |
162 | if task_type == SN:  # check this: self.rgb_images below is never defined in __init__, so this branch fails if surface normals are requested
163 | if self.sn_transform is not None:
164 | random.seed(seed)
165 | img = self.rgb_images[rgb_image, :,:,:]
166 | torch.set_rng_state(state)
167 | img = self.sn_transform(img)
168 | return img
169 |
170 | if task_type == DEP:
171 | if self.depth_transform is not None:
172 | random.seed(seed)
173 | img = self.depth_images[rgb_image, :,:,:]
174 | torch.set_rng_state(state)
175 | img = self.depth_transform(img)
176 | if isinstance(img, torch.Tensor):
177 | # depth png is uint16
178 | img = img.float()
179 | return img
180 |
181 |
182 |
183 |
184 |
185 | def __len__(self):
186 | return len(self._files) * self.num_tasks
187 |
188 | def __repr__(self):
189 | fmt_str = f"Dataset {self.__class__.__name__}\n"
190 | fmt_str += f" Number of data points: {self.__len__()}\n"
191 | fmt_str += f" Split: {self._split}\n"
192 | fmt_str += f" Root Location: {self.root}\n"
193 | tmp = " RGB Transforms: "
194 | fmt_str += "{0}{1}\n".format(
195 | tmp, self.rgb_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
196 | )
197 | tmp = " Seg Transforms: "
198 | fmt_str += "{0}{1}\n".format(
199 | tmp, self.seg_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
200 | )
201 | tmp = " SN Transforms: "
202 | fmt_str += "{0}{1}\n".format(
203 | tmp, self.sn_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
204 | )
205 | tmp = " Depth Transforms: "
206 | fmt_str += "{0}{1}\n".format(
207 | tmp, self.depth_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
208 | )
209 | return fmt_str
210 |
211 | def _check_exists(self) -> bool:
212 | """
213 | Only checking for folder existence
214 | """
215 | try:
216 | for split in ["train", "test"]:
217 | for part, transform in zip(
218 | ["rgb", "seg13", "depth"],#"sn",
219 | [
220 | self.rgb_transform,
221 | self.seg_transform,
222 | # self.sn_transform,  # excluded so the transforms align with the parts list above, where "sn" is commented out
223 | self.depth_transform,
224 | ],
225 | ):
226 | if transform is None:
227 | continue
228 | path = os.path.join(self.root, f"{split}_{part}_pt")
229 | if not os.path.exists(path):
230 | raise FileNotFoundError("Missing Folder")
231 | except FileNotFoundError as e:
232 | return False
233 | return True
234 |
235 | def download(self):
236 | if self._check_exists():
237 | return
238 | if self.rgb_transform is not None:
239 | download_rgb(self.root)
240 | if self.seg_transform is not None:
241 | download_seg(self.root)
242 | if self.sn_transform is not None:
243 | download_sn(self.root)
244 | if self.depth_transform is not None:
245 | download_depth(self.root)
246 | print("Done!")
247 |
248 |
249 | def download_rgb(root: str):
250 | train_url = "http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz"
251 | test_url = "http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz"
252 |
253 | def _proc(url: str, dst: str):
254 | if not os.path.exists(dst):
255 | tar = os.path.join(root, url.split("/")[-1])
256 | if not os.path.exists(tar):
257 | download_url(url, root)
258 | if os.path.exists(tar):
259 | _unpack(tar)
260 | _replace_folder(tar.rstrip(".tgz"), dst)
261 | _rename_files(dst, lambda x: x.split("_")[2])
262 |
263 | _proc(train_url, os.path.join(root, "train_rgb"))
264 | _proc(test_url, os.path.join(root, "test_rgb"))
265 |
266 |
267 | def download_seg(root: str):
268 | train_url = "https://github.com/ankurhanda/nyuv2-meta-data/raw/master/train_labels_13/nyuv2_train_class13.tgz"
269 | test_url = "https://github.com/ankurhanda/nyuv2-meta-data/raw/master/test_labels_13/nyuv2_test_class13.tgz"
270 |
271 | def _proc(url: str, dst: str):
272 | if not os.path.exists(dst):
273 | tar = os.path.join(root, url.split("/")[-1])
274 | if not os.path.exists(tar):
275 | download_url(url, root)
276 | if os.path.exists(tar):
277 | _unpack(tar)
278 | _replace_folder(tar.rstrip(".tgz"), dst)
279 | _rename_files(dst, lambda x: x.split("_")[3])
280 |
281 | _proc(train_url, os.path.join(root, "train_seg13"))
282 | _proc(test_url, os.path.join(root, "test_seg13"))
283 |
284 |
285 | def download_sn(root: str):
286 | url = "https://www.dropbox.com/s/dn5sxhlgml78l03/nyu_normals_gt.zip"
287 | train_dst = os.path.join(root, "train_sn")
288 | test_dst = os.path.join(root, "test_sn")
289 |
290 | if not os.path.exists(train_dst) or not os.path.exists(test_dst):
291 | tar = os.path.join(root, url.split("/")[-1])
292 | if not os.path.exists(tar):
293 | req = requests.get(url + "?dl=1") # dropbox
294 | with open(tar, 'wb') as f:
295 | f.write(req.content)
296 | if os.path.exists(tar):
297 | _unpack(tar)
298 | if not os.path.exists(train_dst):
299 | _replace_folder(
300 | os.path.join(root, "nyu_normals_gt", "train"), train_dst
301 | )
302 | _rename_files(train_dst, lambda x: x[1:])
303 | if not os.path.exists(test_dst):
304 | _replace_folder(os.path.join(root, "nyu_normals_gt", "test"), test_dst)
305 | _rename_files(test_dst, lambda x: x[1:])
306 | shutil.rmtree(os.path.join(root, "nyu_normals_gt"))
307 |
308 |
309 | def download_depth(root: str):
310 | url = (
311 | "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat"
312 | )
313 | train_dst = os.path.join(root, "train_depth")
314 | test_dst = os.path.join(root, "test_depth")
315 |
316 | if not os.path.exists(train_dst) or not os.path.exists(test_dst):
317 | tar = os.path.join(root, url.split("/")[-1])
318 | if not os.path.exists(tar):
319 | download_url(url, root)
320 | if os.path.exists(tar):
321 | train_ids = [
322 | f.split(".")[0] for f in os.listdir(os.path.join(root, "train_rgb"))
323 | ]
324 | _create_depth_files(tar, root, train_ids)
325 |
326 |
327 | def _unpack(file: str):
328 | """
329 | Unpacks tar and zip, does nothing for any other type
330 | :param file: path of file
331 | """
332 | path = file.rsplit(".", 1)[0]
333 |
334 | if file.endswith(".tgz"):
335 | tar = tarfile.open(file, "r:gz")
336 | tar.extractall(path)
337 | tar.close()
338 | elif file.endswith(".zip"):
339 | zip = zipfile.ZipFile(file, "r")
340 | zip.extractall(path)
341 | zip.close()
342 |
343 |
344 | def _rename_files(folder: str, rename_func: callable):
345 | """
346 | Renames all files inside a folder based on the passed rename function
347 | :param folder: path to folder that contains files
348 | :param rename_func: function renaming filename (not including path) str -> str
349 | """
350 | imgs_old = os.listdir(folder)
351 | imgs_new = [rename_func(file) for file in imgs_old]
352 | for img_old, img_new in zip(imgs_old, imgs_new):
353 | shutil.move(os.path.join(folder, img_old), os.path.join(folder, img_new))
354 |
355 |
356 | def _replace_folder(src: str, dst: str):
357 | """
358 | Rename src into dst, replacing/overwriting dst if it exists.
359 | """
360 | if os.path.exists(dst):
361 | shutil.rmtree(dst)
362 | shutil.move(src, dst)
363 |
364 |
365 | def _create_depth_files(mat_file: str, root: str, train_ids: list):
366 | """
367 | Extract the depth arrays from the mat file into images
368 | :param mat_file: path to the official labelled dataset .mat file
369 | :param root: The root directory of the dataset
370 | :param train_ids: the IDs of the training images as string (for splitting)
371 | """
372 | os.mkdir(os.path.join(root, "train_depth"))
373 | os.mkdir(os.path.join(root, "test_depth"))
374 | train_ids = set(train_ids)
375 |
376 | depths = h5py.File(mat_file, "r")["depths"]
377 | for i in range(len(depths)):
378 | img = (depths[i] * 1e4).astype(np.uint16).T
379 | id_ = str(i + 1).zfill(4)
380 | folder = "train" if id_ in train_ids else "test"
381 | save_path = os.path.join(root, f"{folder}_depth", id_ + ".png")
382 | Image.fromarray(img).save(save_path)
383 |
--------------------------------------------------------------------------------
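A minimal usage sketch for the NYUv2Dataset defined above, assuming the preprocessed `*_rgb_pt` folders and `combined/*_pt.pt` tensors it loads already exist under `./data/nyuv2`; the transforms and batch size are illustrative, not the repository's training settings. Indices beyond the number of images select the next task, so one pass over the dataset yields (image, label, task_id) triplets for every task.

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from data.nyuv2 import NYUv2Dataset

# Center-crop/resize pipeline in the style of the evaluation transforms (illustrative values).
t = torch.nn.Sequential(transforms.CenterCrop(480), transforms.Resize(224))
dataset = NYUv2Dataset(root="./data/nyuv2", tasks={0: "segmentation", 1: "depth"},
                       train=False, rgb_transform=t, seg_transform=t, depth_transform=t)
loader = DataLoader(dataset, batch_size=8, shuffle=False)
rgb, label, task_id = next(iter(loader))  # task_id holds the task index of each sample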
/script/data/nyuv2_same_batch.py:
--------------------------------------------------------------------------------
1 | """
2 | author: Mihai Suteu
3 | date: 15/05/19
4 | https://github.com/xapharius/pytorch-nyuv2
5 | """
6 |
7 |
8 | import os
9 | import sys
10 | import h5py
11 | import torch
12 | import shutil
13 | import random
14 | import tarfile
15 | import zipfile
16 | import requests
17 | import numpy as np
18 | from typing import Dict
19 |
20 | from PIL import Image
21 | from torch.utils.data import Dataset
22 | from torchvision.datasets.utils import download_url
23 |
24 | SEG = 0
25 | DEP = 1
26 | SN = 2
27 |
28 |
29 | class NYUv2SameBatchDataset(Dataset):
30 | """
31 | PyTorch wrapper for the NYUv2 dataset focused on multi-task learning.
32 | Data sources available: RGB, Semantic Segmentation, Surface Normals, Depth Images.
33 | If no transformation is provided, the image type will not be returned.
34 |
35 | ### Output
36 | All images are of size: 640 x 480
37 |
38 | 1. RGB: 3 channel input image
39 |
40 | 2. Semantic Segmentation: 1 channel representing one of the 14 (0 -
41 | background) classes. Conversion to int will happen automatically if
42 | transformation ends in a tensor. Task name: "segmentation"
43 |
44 | 3. Surface Normals: 3 channels, with values in [0, 1]. Task name: "surface_normals"
45 |
46 | 4. Depth Images: 1 channel with floats representing the distance in meters.
47 | Conversion will happen automatically if transformation ends in a tensor. Task name: "depth"
48 | """
49 |
50 | def __init__(
51 | self,
52 | root: str,
53 | tasks: Dict[int, str],
54 | train: bool = True,
55 | download: bool = False,
56 | rgb_transform=None,
57 | seg_transform=None,
58 | sn_transform=None,
59 | depth_transform=None,
60 | rgb_transform2=None,
61 | ):
62 | """
63 | Will return tuples based on what data source has been enabled (rgb, seg etc).
64 |
65 | :param root: path to root folder (eg /data/NYUv2)
66 | :param train: whether to load the train or test set
67 | :param download: whether to download and process data if missing
68 | :param rgb_transform: the transformation pipeline for rgb images
69 | :param seg_transform: the transformation pipeline for segmentation images. If
70 | the transformation ends in a tensor, the result will be automatically
71 | converted to int in [0, 14)
72 | :param sn_transform: the transformation pipeline for surface normal images
73 | :param depth_transform: the transformation pipeline for depth images. If the
74 | transformation ends in a tensor, the result will be automatically converted
75 | to meters
76 | """
77 | super().__init__()
78 | self.root = root
79 |
80 | self.rgb_transform = rgb_transform
81 | self.rgb_transform2 = rgb_transform2
82 | self.seg_transform = seg_transform
83 | self.depth_transform = depth_transform
84 | self.sn_transform = sn_transform
85 |
86 | self.train = train
87 | self._split = "train" if train else "test"
88 |
89 | if download:
90 | self.download()
91 |
92 |
93 | # rgb folder as ground truth
94 | self._files = sorted(os.listdir(os.path.join(root, f"{self._split}_rgb_pt")))
95 | self.num_img = len(self._files)
96 |
97 | self.num_tasks = len(tasks)
98 | self.tasks = tasks
99 |
100 | self.task_dict = self._get_task_dict()
101 |
102 | self.folder = lambda name: os.path.join(self.root, f"{self._split}_{name}_pt")
103 |
104 | self.seg_images = torch.load(f"{root}/combined/{self._split}_seg13_pt.pt")
105 | self.depth_images = torch.load(f"{root}/combined/{self._split}_depth_pt.pt")
106 |
107 |
108 | def __getitem__(self, index: int):
109 |
110 | rgb_image = index
111 | seed = random.randrange(sys.maxsize)
112 | rgb = None
113 | state = None
114 |
115 | if self.rgb_transform is not None:
116 | random.seed(seed)
117 | img = torch.load(os.path.join(self.folder("rgb"), self._files[rgb_image]))
118 | ### https://github.com/pytorch/vision/issues/9#issuecomment-789308878
119 | state = torch.get_rng_state()
120 | rgb = self.rgb_transform(img)
121 | if self.rgb_transform2 is not None:
122 | rgb = self.rgb_transform2(rgb)
123 |
124 | label_seg = self._get_task_label(0, rgb_image, state)
125 | label_depth = self._get_task_label(1, rgb_image, state)
126 |
127 | return torch.stack([rgb,rgb]), torch.stack([label_seg, label_depth]), torch.LongTensor([0,1])
128 |
129 | def _get_task_dict(self):
130 |
131 | task_dict = dict()
132 |
133 | for i in self.tasks.keys():
134 |
135 | task_type = self.tasks[i]
136 | if task_type == "segmentation":
137 | task_dict[i] = SEG
138 | elif task_type == "surface_normals":
139 | task_dict[i] = SN
140 | elif task_type == "depth":
141 | task_dict[i] = DEP
142 |
143 | return task_dict
144 |
145 |
146 | def _get_task_label(self, task, rgb_image, state):
147 | seed = random.randrange(sys.maxsize)
148 |
149 | task_type = self.task_dict[task]
150 | if task_type == SEG:
151 | if self.seg_transform is not None:
152 | random.seed(seed)
153 | img = self.seg_images[rgb_image, :,:,:]
154 | torch.set_rng_state(state)
155 | img = self.seg_transform(img)
156 | if isinstance(img, torch.Tensor):
157 | # ToTensor scales to [0, 1] by default
158 | img = (img * 255).long()
159 | return img
160 |
161 | if task_type == SN:  # check this: self.rgb_images below is never defined in __init__, so this branch fails if surface normals are requested
162 | if self.sn_transform is not None:
163 | random.seed(seed)
164 | img = self.rgb_images[rgb_image, :,:,:]
165 | torch.set_rng_state(state)
166 | img = self.sn_transform(img)
167 | return img
168 |
169 | if task_type == DEP:
170 | if self.depth_transform is not None:
171 | random.seed(seed)
172 | img = self.depth_images[rgb_image, :,:,:]
173 | torch.set_rng_state(state)
174 | img = self.depth_transform(img)
175 | if isinstance(img, torch.Tensor):
176 | # depth png is uint16
177 | img = img.float() # / 1e4
178 | return img
179 |
180 |
181 |
182 |
183 |
184 | def __len__(self):
185 | return len(self._files)
186 |
187 | def __repr__(self):
188 | fmt_str = f"Dataset {self.__class__.__name__}\n"
189 | fmt_str += f" Number of data points: {self.__len__()}\n"
190 | fmt_str += f" Split: {self._split}\n"
191 | fmt_str += f" Root Location: {self.root}\n"
192 | tmp = " RGB Transforms: "
193 | fmt_str += "{0}{1}\n".format(
194 | tmp, self.rgb_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
195 | )
196 | tmp = " Seg Transforms: "
197 | fmt_str += "{0}{1}\n".format(
198 | tmp, self.seg_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
199 | )
200 | tmp = " SN Transforms: "
201 | fmt_str += "{0}{1}\n".format(
202 | tmp, self.sn_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
203 | )
204 | tmp = " Depth Transforms: "
205 | fmt_str += "{0}{1}\n".format(
206 | tmp, self.depth_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
207 | )
208 | return fmt_str
209 |
210 | def _check_exists(self) -> bool:
211 | """
212 | Only checking for folder existence
213 | """
214 | try:
215 | for split in ["train", "test"]:
216 | for part, transform in zip(
217 | ["rgb", "seg13", "depth"],#"sn",
218 | [
219 | self.rgb_transform,
220 | self.seg_transform,
221 | # self.sn_transform,  # excluded so the transforms align with the parts list above, where "sn" is commented out
222 | self.depth_transform,
223 | ],
224 | ):
225 | if transform is None:
226 | continue
227 | path = os.path.join(self.root, f"{split}_{part}_pt")
228 | if not os.path.exists(path):
229 | raise FileNotFoundError("Missing Folder")
230 | except FileNotFoundError as e:
231 | return False
232 | return True
233 |
234 | def download(self):
235 | if self._check_exists():
236 | return
237 | if self.rgb_transform is not None:
238 | download_rgb(self.root)
239 | if self.seg_transform is not None:
240 | download_seg(self.root)
241 | if self.sn_transform is not None:
242 | download_sn(self.root)
243 | if self.depth_transform is not None:
244 | download_depth(self.root)
245 | print("Done!")
246 |
247 |
248 | def download_rgb(root: str):
249 | train_url = "http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz"
250 | test_url = "http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz"
251 |
252 | def _proc(url: str, dst: str):
253 | if not os.path.exists(dst):
254 | tar = os.path.join(root, url.split("/")[-1])
255 | if not os.path.exists(tar):
256 | download_url(url, root)
257 | if os.path.exists(tar):
258 | _unpack(tar)
259 | _replace_folder(tar.rstrip(".tgz"), dst)
260 | _rename_files(dst, lambda x: x.split("_")[2])
261 |
262 | _proc(train_url, os.path.join(root, "train_rgb"))
263 | _proc(test_url, os.path.join(root, "test_rgb"))
264 |
265 |
266 | def download_seg(root: str):
267 | train_url = "https://github.com/ankurhanda/nyuv2-meta-data/raw/master/train_labels_13/nyuv2_train_class13.tgz"
268 | test_url = "https://github.com/ankurhanda/nyuv2-meta-data/raw/master/test_labels_13/nyuv2_test_class13.tgz"
269 |
270 | def _proc(url: str, dst: str):
271 | if not os.path.exists(dst):
272 | tar = os.path.join(root, url.split("/")[-1])
273 | if not os.path.exists(tar):
274 | download_url(url, root)
275 | if os.path.exists(tar):
276 | _unpack(tar)
277 | _replace_folder(tar.rstrip(".tgz"), dst)
278 | _rename_files(dst, lambda x: x.split("_")[3])
279 |
280 | _proc(train_url, os.path.join(root, "train_seg13"))
281 | _proc(test_url, os.path.join(root, "test_seg13"))
282 |
283 |
284 | def download_sn(root: str):
285 | url = "https://www.dropbox.com/s/dn5sxhlgml78l03/nyu_normals_gt.zip"
286 | train_dst = os.path.join(root, "train_sn")
287 | test_dst = os.path.join(root, "test_sn")
288 |
289 | if not os.path.exists(train_dst) or not os.path.exists(test_dst):
290 | tar = os.path.join(root, url.split("/")[-1])
291 | if not os.path.exists(tar):
292 | req = requests.get(url + "?dl=1") # dropbox
293 | with open(tar, 'wb') as f:
294 | f.write(req.content)
295 | if os.path.exists(tar):
296 | _unpack(tar)
297 | if not os.path.exists(train_dst):
298 | _replace_folder(
299 | os.path.join(root, "nyu_normals_gt", "train"), train_dst
300 | )
301 | _rename_files(train_dst, lambda x: x[1:])
302 | if not os.path.exists(test_dst):
303 | _replace_folder(os.path.join(root, "nyu_normals_gt", "test"), test_dst)
304 | _rename_files(test_dst, lambda x: x[1:])
305 | shutil.rmtree(os.path.join(root, "nyu_normals_gt"))
306 |
307 |
308 | def download_depth(root: str):
309 | url = (
310 | "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat"
311 | )
312 | train_dst = os.path.join(root, "train_depth")
313 | test_dst = os.path.join(root, "test_depth")
314 |
315 | if not os.path.exists(train_dst) or not os.path.exists(test_dst):
316 | tar = os.path.join(root, url.split("/")[-1])
317 | if not os.path.exists(tar):
318 | download_url(url, root)
319 | if os.path.exists(tar):
320 | train_ids = [
321 | f.split(".")[0] for f in os.listdir(os.path.join(root, "train_rgb"))
322 | ]
323 | _create_depth_files(tar, root, train_ids)
324 |
325 |
326 | def _unpack(file: str):
327 | """
328 | Unpacks tar and zip, does nothing for any other type
329 | :param file: path of file
330 | """
331 | path = file.rsplit(".", 1)[0]
332 |
333 | if file.endswith(".tgz"):
334 | tar = tarfile.open(file, "r:gz")
335 | tar.extractall(path)
336 | tar.close()
337 | elif file.endswith(".zip"):
338 | zip = zipfile.ZipFile(file, "r")
339 | zip.extractall(path)
340 | zip.close()
341 |
342 |
343 | def _rename_files(folder: str, rename_func: callable):
344 | """
345 | Renames all files inside a folder based on the passed rename function
346 | :param folder: path to folder that contains files
347 | :param rename_func: function renaming filename (not including path) str -> str
348 | """
349 | imgs_old = os.listdir(folder)
350 | imgs_new = [rename_func(file) for file in imgs_old]
351 | for img_old, img_new in zip(imgs_old, imgs_new):
352 | shutil.move(os.path.join(folder, img_old), os.path.join(folder, img_new))
353 |
354 |
355 | def _replace_folder(src: str, dst: str):
356 | """
357 | Rename src into dst, replacing/overwriting dst if it exists.
358 | """
359 | if os.path.exists(dst):
360 | shutil.rmtree(dst)
361 | shutil.move(src, dst)
362 |
363 |
364 | def _create_depth_files(mat_file: str, root: str, train_ids: list):
365 | """
366 | Extract the depth arrays from the mat file into images
367 | :param mat_file: path to the official labelled dataset .mat file
368 | :param root: The root directory of the dataset
369 | :param train_ids: the IDs of the training images as string (for splitting)
370 | """
371 | os.mkdir(os.path.join(root, "train_depth"))
372 | os.mkdir(os.path.join(root, "test_depth"))
373 | train_ids = set(train_ids)
374 |
375 | depths = h5py.File(mat_file, "r")["depths"]
376 | for i in range(len(depths)):
377 | img = (depths[i] * 1e4).astype(np.uint16).T
378 | id_ = str(i + 1).zfill(4)
379 | folder = "train" if id_ in train_ids else "test"
380 | save_path = os.path.join(root, f"{folder}_depth", id_ + ".png")
381 | Image.fromarray(img).save(save_path)
382 |
--------------------------------------------------------------------------------
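For the same-batch variant above, each item stacks both tasks for one image, so a DataLoader batch carries an extra task dimension that the training scripts flatten with `.view((-1, 3, 224, 224))`. A minimal sketch of the resulting shapes, assuming the same preprocessed tensors as above and a 224x224 crop (illustrative values):

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from data.nyuv2_same_batch import NYUv2SameBatchDataset

train_t = torch.nn.Sequential(transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
                              transforms.RandomHorizontalFlip())
dataset = NYUv2SameBatchDataset(root="./data/nyuv2", tasks={0: "segmentation", 1: "depth"},
                                train=True, rgb_transform=train_t, seg_transform=train_t,
                                depth_transform=train_t)
img, label, task_id = next(iter(DataLoader(dataset, batch_size=4)))
# img:     (4, 2, 3, 224, 224) -> .view((-1, 3, 224, 224)) flattens tasks into the batch
# label:   (4, 2, 1, 224, 224)
# task_id: (4, 2), each row is [0, 1] for segmentation and depth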
/script/train_nyu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | from torch.utils.data import DataLoader
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | from torch.utils.tensorboard import SummaryWriter
8 | import transformers
9 | from tqdm import tqdm
10 | import numpy as np
11 | import os
12 | import pickle
13 | import cv2
14 | import json
15 | import argparse
16 |
17 | from data.nyuv2_same_batch import NYUv2SameBatchDataset
18 | from model.swin_transformer import SwinTransformer
19 | from loss.losses import berHuLoss
20 | from loss.metrics import iou_pytorch, eval_depth
21 |
22 |
23 | def get_config():
24 | parser = argparse.ArgumentParser(description='Train the network')
25 | parser.add_argument('--config', help='train config file path')
26 |
27 | args = parser.parse_args()
28 |
29 | with open(args.config, "r") as jsonfile:
30 | config = json.load(jsonfile)
31 |
32 | return config
33 |
34 | def freeze_encoder_layers(
35 | model,
36 | conditioned_blocks= [[], [], [*range(12, 18)], []],
37 | unfrozen_modules=[
38 | "random_weight_matrix",
39 | "film.gb_weights",
40 | "ln_weight_modulation.gb_weights",
41 | "adapter",
42 | "task_type_embeddings",
43 | "patch_embed",
44 | "decoder",
45 | "bottleneck"
46 | ],
47 | frozen_encoder = False
48 | ):
49 | for name, param in model.named_parameters():
50 | param.requires_grad = not frozen_encoder
51 |
52 | for module in unfrozen_modules:
53 | if module in name:
54 | param.requires_grad = True
55 |
56 | if name.startswith("layers"):
57 | splitted = name.split(".")
58 |
59 | if len(conditioned_blocks[int(splitted[1])]) > 0 and splitted[2]=="blocks" and (int(splitted[3]) in conditioned_blocks[int(splitted[1])]):
60 | param.requires_grad = True
61 | elif name.startswith("norm"):
62 | param.requires_grad = True
63 |
64 | def disp2meters(d):
65 | return (65536.0 / d - 1 ) / 1e4
66 |
67 | def calc_seg_metrics(logit_task, label_task):
68 |
69 | max_labels = torch.argmax(logit_task, dim = 1, keepdim=True)
70 | iou = iou_pytorch(max_labels, label_task)
71 |
72 | return max_labels, iou
73 |
74 | def train(model, train_loader, test_loader, optimizer, scheduler, criterion, epochs, tensorboard_name, seg_weight, depth_weight, start_epoch = 0, device = "cuda", tb_writer=None):
75 |
76 | # Training loop
77 | model.train()
78 |
79 | iters = len(train_loader)
80 |
81 | for e in tqdm(range(epochs)):
82 |
83 |
84 | epoch = e + start_epoch + 1
85 |
86 | epoch_loss = 0.0
87 | epoch_loss_seg = []
88 | epoch_loss_depth = []
89 |
90 | train_ious = []
91 | train_depths_rmse = []
92 | train_depths_d1 = []
93 |
94 | for i, (img, label, task_id) in enumerate(train_loader, 0):
95 | model.train()
96 |
97 | img = img.view((-1, 3, 224, 224)).to(device)
98 | label = label.view((-1, 1, 224, 224)).to(device)
99 | task_id = task_id.view(-1).to(device)
100 |
101 | logits, unique_task_ids_list = model(img, task_id)
102 |
103 | loss = 0
104 |
105 | for j, unique_task_id in enumerate(unique_task_ids_list):
106 |
107 | task_id_filter = task_id == unique_task_id
108 |
109 | logit_task = logits[j]
110 | label_task = label[task_id_filter]
111 |
112 | B = logit_task.shape[0]
113 |
114 |
115 | # Task is segmentation
116 | if unique_task_id == 0:
117 | label_task = label_task.long()
118 |
119 | a = criterion[unique_task_id](logit_task.view(B,14,-1), label_task.view(B,-1))
120 | epoch_loss_seg.append(a.item())
121 |
122 | loss += a * seg_weight
123 |
124 | # compute metrics every 10 epochs
125 | if epoch%10==0:
126 | max_labels, iou = calc_seg_metrics(logit_task, label_task)
127 | train_ious.append(iou.cpu().numpy())
128 |
129 | else:
130 | label_task = 65536.0 / (label_task + 1)
131 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1
132 |
133 | a = criterion[unique_task_id](logit_task, label_task)
134 | epoch_loss_depth.append(a.item())
135 |
136 | loss += a* depth_weight
137 |
138 |
139 |
140 | # compute metrics every 10 epochs
141 | if epoch%10==0:
142 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task))
143 |
144 | train_depths_rmse.append(evaluation["rmse"])
145 | train_depths_d1.append(evaluation["d1"])
146 |
147 |
148 | optimizer.zero_grad()
149 | loss.backward()
150 | optimizer.step()
151 |
152 |
153 | epoch_loss += loss.item()
154 |
155 | scheduler.step()
156 |
157 |
158 | # Compute validation metrics every 5 epochs
159 | if epoch % 5==0:
160 |
161 | test_loss = 0
162 | epoch_ious = []
163 | epoch_eval_depths_rmse = []
164 | epoch_eval_depths_d1 = []
165 |
166 | epoch_loss_seg_test = []
167 | epoch_loss_depth_test = []
168 |
169 | model.eval()
170 | for i, (img, label, task_id) in enumerate(test_loader, 0):
171 |
172 | img = img.view((-1, 3, 224, 224)).to(device)
173 | label = label.view((-1, 1, 224, 224)).to(device)
174 | task_id = task_id.view(-1).to(device)
175 |
176 | logits, unique_task_ids_list = model(img, task_id)
177 |
178 | loss = 0
179 |
180 | for j, unique_task_id in enumerate(unique_task_ids_list):
181 |
182 |
183 | task_id_filter = task_id == unique_task_id
184 |
185 | logit_task = logits[j]
186 | label_task = label[task_id_filter]
187 | B = logit_task.shape[0]
188 |
189 | if unique_task_id == 0:
190 |
191 | label_task = label_task.long()
192 |
193 | a = criterion[unique_task_id](logit_task.view(B,14,-1), label_task.long().view(B,-1))
194 | epoch_loss_seg_test.append(a.item())
195 |
196 | loss += a * seg_weight
197 |
198 | max_labels, iou = calc_seg_metrics(logit_task, label_task)
199 |
200 | epoch_ious.append(iou.cpu().numpy())
201 |
202 | else:
203 |
204 |
205 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1
206 | label_task = 65536.0 / (label_task + 1)
207 |
208 | a = criterion[unique_task_id](logit_task, label_task)#* len(logit_task)
209 | epoch_loss_depth_test.append(a.item())
210 |
211 | loss += a* depth_weight
212 |
213 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task))
214 | epoch_eval_depths_rmse.append(evaluation["rmse"])
215 | epoch_eval_depths_d1.append(evaluation["d1"])
216 |
217 | test_loss += loss.item()
218 |
219 | tb_writer.add_scalar(f"{tensorboard_name}/learning_rate", scheduler.get_last_lr()[0] , epoch)
220 | tb_writer.add_scalar(f"{tensorboard_name}/train_loss", epoch_loss/ len(train_loader) , epoch)
221 | tb_writer.add_scalar(f"{tensorboard_name}/test_loss", test_loss/len(test_loader) , epoch)
222 | tb_writer.add_scalar(f"{tensorboard_name}/mean_iou", np.mean(epoch_ious) , epoch)
223 | tb_writer.add_scalar(f"{tensorboard_name}/depth_rmse", np.mean(epoch_eval_depths_rmse) , epoch)
224 | tb_writer.add_scalar(f"{tensorboard_name}/depth_d1", np.mean(epoch_eval_depths_d1) , epoch)
225 | tb_writer.add_scalar(f"{tensorboard_name}/seg_loss", np.mean(epoch_loss_seg) , epoch)
226 | tb_writer.add_scalar(f"{tensorboard_name}/depth_loss", np.mean(epoch_loss_depth) , epoch)
227 | tb_writer.add_scalar(f"{tensorboard_name}/seg_loss_test", np.mean(epoch_loss_seg_test) , epoch)
228 | tb_writer.add_scalar(f"{tensorboard_name}/depth_loss_test", np.mean(epoch_loss_depth_test) , epoch)
229 |
230 | # Save training metrics every 10 epochs
231 | if epoch%10 == 0:
232 | tb_writer.add_scalar(f"{tensorboard_name}/train_mean_iou", np.mean(train_ious) , epoch)
233 | tb_writer.add_scalar(f"{tensorboard_name}/train_depth_rmse", np.mean(train_depths_rmse) , epoch)
234 | tb_writer.add_scalar(f"{tensorboard_name}/train_depth_d1", np.mean(train_depths_d1) , epoch)
235 |
236 |
237 |
238 | # save the model every 500 epochs
239 | if epoch % 500 == 0 or epoch == (epochs-1):
240 | torch.save({
241 | 'epoch': epoch,
242 | 'model_state_dict': model.state_dict(),
243 | 'optimizer_state_dict': optimizer.state_dict(),
244 | 'scheduler_state_dict': scheduler.state_dict(),
245 | }, f"{tensorboard_name}.pt")
246 |
247 |
248 | def load_model(model, optimizer, scheduler, PATH, device="cuda"):
249 | checkpoint = torch.load(PATH, map_location=device)
250 | model.load_state_dict(checkpoint['model_state_dict'])
251 | model = model.to(device)
252 |
253 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
254 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
255 | epoch = checkpoint['epoch']
256 | return model, optimizer, scheduler, epoch
257 |
258 | def get_dataloaders(tasks, batch_size):
259 |
260 | IMAGE_SIZE = (480, 640)
261 |
262 | train_t = torch.nn.Sequential(transforms.RandomResizedCrop(224, scale=(0.5, 1.0)), transforms.RandomHorizontalFlip())
263 | test_t = torch.nn.Sequential(transforms.CenterCrop(480), transforms.Resize(224))
264 | train_t_input_image = torch.nn.Sequential(transforms.ColorJitter(brightness=(0.8, 1.2),contrast =(0.8, 1.2), saturation=(0.8, 1.2), hue=(-0.1,0.1)))
265 |
266 | train_dataset = NYUv2SameBatchDataset(root="./data/nyuv2", tasks=tasks, download=False, train=True,
267 | rgb_transform=train_t, rgb_transform2=train_t_input_image, seg_transform=train_t, sn_transform=train_t, depth_transform=train_t)
268 |
269 | test_dataset = NYUv2SameBatchDataset(root="./data/nyuv2", tasks=tasks, download=False, train=False,
270 | rgb_transform=test_t, seg_transform=test_t, sn_transform=test_t, depth_transform=test_t)
271 |
272 | print("Train dataset size:", len(train_dataset))
273 | print("Test dataset size:", len(test_dataset))
274 |
275 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
276 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
277 |
278 | return train_dataloader, test_dataloader
279 |
280 |
281 |
282 | def main():
283 |
284 | config = get_config()
285 |
286 | tb_writer = SummaryWriter(f'runs/{config["experiment_name"]}')
287 |
288 | torch.manual_seed(61)
289 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
290 |
291 | tasks = {0:"segmentation", 1:"depth"}
292 |
293 | batch_size = config["batch_size"]
294 |
295 | print("Creating datasets...")
296 | train_dataloader, test_dataloader = get_dataloaders(tasks, batch_size)
297 |
298 | print("Loading model...")
299 |
300 | model = SwinTransformer(img_size=224,
301 | patch_size=4,
302 | in_chans=3,
303 | num_classes=21841,
304 | embed_dim=96,
305 | depths=[2, 2, 18, 2 ],
306 | depths_decoder =[2, 2, 2, 2 ],
307 | num_heads=[ 3, 6, 12, 24 ],
308 | window_size=7,
309 | mlp_ratio=4.,
310 | qkv_bias=True,
311 | qk_scale=True,
312 | drop_rate=0,
313 | drop_rate_decoder=0.6,
314 | drop_path_rate=0.2,
315 | ape=False,
316 | patch_norm=True,
317 | use_checkpoint=False,
318 | tasks = ["segmentation", "depth"],
319 | task_classes = [14, 1],
320 | conditioned_blocks = config["conditioned_blocks"],
321 | adapter = config["adapter"],
322 | use_conditional_layer = config["use_conditional_layer_norm"])
323 |
324 | epochs = config["epochs"]
325 | optimizer = optim.AdamW(model.parameters(), lr=2e-5, betas=(0.9, 0.98))
326 |
327 |
328 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 4e-5, epochs = epochs, steps_per_epoch = len(train_dataloader), pct_start = 0.1)
329 |
330 |
331 | if config["continue_training"]:
332 | model, optimizer, scheduler, start_epoch = load_model(model, optimizer, scheduler, config["experiment_name"]+".pt", device)
333 | print("Continue model loaded")
334 |
335 | else:
336 | start_epoch = -1
337 | model.load_state_dict(torch.load('./pretrained/swin_small_patch4_window7_224_22k.pth')['model'],strict=False)
338 |
339 | model = model.to(device)
340 | print("Pretrained model loaded")
341 |
342 |
343 | freeze_encoder_layers(model, conditioned_blocks = config["conditioned_blocks"], frozen_encoder = config["frozen_encoder"])
344 | model = model.to(device)
345 |
346 | criterion = []
347 | segmentation_criterion = torch.nn.CrossEntropyLoss()
348 | criterion.append(segmentation_criterion)
349 |
350 | depth_criterion = berHuLoss()
351 | criterion.append(depth_criterion)
352 |
353 | print("Training",config["experiment_name"],"...")
354 |
355 | train(model, train_dataloader, test_dataloader, optimizer, scheduler, criterion, epochs, config["experiment_name"], config["seg_weight"], config["depth_weight"], start_epoch, device, tb_writer)
356 |
357 |
358 | if __name__ == '__main__':
359 | main()
--------------------------------------------------------------------------------
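train_nyu.py above reads every hyperparameter from a JSON file passed via `--config`. Below is a minimal sketch of such a file, written from Python and containing only the keys the script actually looks up; the values and the output path `configs/nyu_adapter.json` are placeholders, not the paper's settings.

import json

# Keys mirror the config[...] lookups in train_nyu.py; values are illustrative.
config = {
    "experiment_name": "nyu_seg_depth_adapter",
    "batch_size": 8,
    "epochs": 500,
    "seg_weight": 1.0,
    "depth_weight": 1.0,
    "conditioned_blocks": [[], [], list(range(12, 18)), []],
    "adapter": True,
    "use_conditional_layer_norm": True,
    "frozen_encoder": True,
    "continue_training": False,
}

with open("configs/nyu_adapter.json", "w") as f:  # hypothetical path
    json.dump(config, f, indent=2)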
/script/train_nyu_single_task.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | from torch.utils.data import DataLoader
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | from torch.utils.tensorboard import SummaryWriter
8 | import transformers
9 | from tqdm import tqdm
10 | import numpy as np
11 | import os
12 | import pickle
13 | import cv2
14 | import json
15 | import argparse
16 |
17 | from data.nyuv2 import NYUv2Dataset
18 | from model.swin_transformer import SwinTransformer
19 | from loss.losses import berHuLoss
20 | from loss.metrics import iou_pytorch, eval_depth
21 |
22 |
23 | def get_config():
24 | parser = argparse.ArgumentParser(description='Train the network')
25 | parser.add_argument('--config', help='train config file path')
26 |
27 | args = parser.parse_args()
28 |
29 | with open(args.config, "r") as jsonfile:
30 | config = json.load(jsonfile)
31 |
32 | return config
33 |
34 | def freeze_encoder_layers(
35 | model,
36 | conditioned_blocks= [[], [], [*range(12, 18)], []],
37 | unfrozen_modules=[
38 | "random_weight_matrix",
39 | "film.gb_weights",
40 | "ln_weight_modulation.gb_weights",
41 | "adapter",
42 | "task_type_embeddings",
43 | "patch_embed",
44 | "decoder",
45 | "bottleneck"
46 | ],
47 | frozen_encoder = False
48 | ):
49 | for name, param in model.named_parameters():
50 | param.requires_grad = not frozen_encoder
51 |
52 | for module in unfrozen_modules:
53 | if module in name:
54 | param.requires_grad = True
55 |
56 | if name.startswith("layers"):
57 | splitted = name.split(".")
58 |
59 | if len(conditioned_blocks[int(splitted[1])]) > 0 and splitted[2]=="blocks" and (int(splitted[3]) in conditioned_blocks[int(splitted[1])]):
60 | param.requires_grad = True
61 | elif name.startswith("norm"):
62 | param.requires_grad = True
63 |
64 | def disp2meters(d):
65 | return (65536.0 / d - 1 ) / 1e4
66 |
67 | def calc_seg_metrics(logit_task, label_task):
68 |
69 | max_labels = torch.argmax(logit_task, dim = 1, keepdim=True)
70 | iou = iou_pytorch(max_labels, label_task)
71 |
72 | return max_labels, iou
73 |
74 | def train(model, train_loader, test_loader, optimizer, scheduler, criterion, epochs, tensorboard_name, start_epoch = 0, device = "cuda", tb_writer=None, task="segmentation"):
75 |
76 | # Training loop
77 | model.train()
78 |
79 | iters = len(train_loader)
80 |
81 | for e in tqdm(range(epochs)):
82 |
83 |
84 | epoch = e + start_epoch + 1
85 |
86 | epoch_loss = 0.0
87 | epoch_loss_seg = []
88 | epoch_loss_depth = []
89 |
90 | train_ious = []
91 | train_depths_rmse = []
92 | train_depths_d1 = []
93 |
94 | for i, (img, label, task_id) in enumerate(train_loader, 0):
95 | model.train()
96 |
97 | img = img.view((-1, 3, 224, 224)).to(device)
98 | label = label.view((-1, 1, 224, 224)).to(device)
99 | task_id = torch.zeros_like(task_id.view(-1).to(device))
100 |
101 | logits, unique_task_ids_list = model(img, task_id)
102 |
103 | loss = 0
104 |
105 | for j, unique_task_id in enumerate(unique_task_ids_list):
106 |
107 | task_id_filter = task_id == unique_task_id
108 |
109 | logit_task = logits[j]
110 | label_task = label[task_id_filter]
111 |
112 | B = logit_task.shape[0]
113 |
114 |
115 | # Task is segmentation
116 | if task == "segmentation":
117 | label_task = label_task.long()
118 |
119 | a = criterion[0](logit_task.view(B,14,-1), label_task.view(B,-1))
120 | epoch_loss_seg.append(a.item())
121 |
122 | loss += a
123 |
124 | # compute metrics every 10 epochs
125 | if epoch%10==0:
126 | max_labels, iou = calc_seg_metrics(logit_task, label_task)
127 | train_ious.append(iou.cpu().numpy())
128 |
129 | else:
130 | label_task = 65536.0 / (label_task + 1)
131 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1
132 |
133 | a = criterion[1](logit_task, label_task)
134 | epoch_loss_depth.append(a.item())
135 |
136 | loss += a
137 |
138 |
139 |
140 | # compute metrics every 10 epochs
141 | if epoch%10==0:
142 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task))
143 |
144 | train_depths_rmse.append(evaluation["rmse"])
145 | train_depths_d1.append(evaluation["d1"])
146 |
147 |
148 | optimizer.zero_grad()
149 | loss.backward()
150 | optimizer.step()
151 |
152 |
153 | epoch_loss += loss.item()
154 |
155 | scheduler.step()
156 |
157 |
158 | # Compute validation metrics every 5 epochs
159 | if epoch % 5==0:
160 |
161 | test_loss = 0
162 | epoch_ious = []
163 | epoch_eval_depths_rmse = []
164 | epoch_eval_depths_d1 = []
165 |
166 | epoch_loss_seg_test = []
167 | epoch_loss_depth_test = []
168 |
169 | model.eval()
170 | for i, (img, label, task_id) in enumerate(test_loader, 0):
171 |
172 | img = img.view((-1, 3, 224, 224)).to(device)
173 | label = label.view((-1, 1, 224, 224)).to(device)
174 | task_id = torch.zeros_like(task_id.view(-1).to(device))
175 |
176 | logits, unique_task_ids_list = model(img, task_id)
177 |
178 | loss = 0
179 |
180 | for j, unique_task_id in enumerate(unique_task_ids_list):
181 |
182 |
183 | task_id_filter = task_id == unique_task_id
184 |
185 | logit_task = logits[j]
186 | label_task = label[task_id_filter]
187 | B = logit_task.shape[0]
188 |
189 | if task == "segmentation":
190 |
191 | label_task = label_task.long()
192 |
193 | a = criterion[0](logit_task.view(B,14,-1), label_task.long().view(B,-1))
194 | epoch_loss_seg_test.append(a.item())
195 |
196 | loss += a
197 |
198 | max_labels, iou = calc_seg_metrics(logit_task, label_task)
199 |
200 | epoch_ious.append(iou.cpu().numpy())
201 |
202 | else:
203 |
204 |
205 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1
206 | label_task = 65536.0 / (label_task + 1)
207 |
208 | a = criterion[1](logit_task, label_task)
209 | epoch_loss_depth_test.append(a.item())
210 |
211 | loss += a
212 |
213 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task))
214 | epoch_eval_depths_rmse.append(evaluation["rmse"])
215 | epoch_eval_depths_d1.append(evaluation["d1"])
216 |
217 | test_loss += loss.item()
218 |
219 | tb_writer.add_scalar(f"{tensorboard_name}/learning_rate", scheduler.get_last_lr()[0] , epoch)
220 | tb_writer.add_scalar(f"{tensorboard_name}/train_loss", epoch_loss/ len(train_loader) , epoch)
221 | tb_writer.add_scalar(f"{tensorboard_name}/test_loss", test_loss/len(test_loader) , epoch)
222 | tb_writer.add_scalar(f"{tensorboard_name}/mean_iou", np.mean(epoch_ious) , epoch)
223 | tb_writer.add_scalar(f"{tensorboard_name}/depth_rmse", np.mean(epoch_eval_depths_rmse) , epoch)
224 | tb_writer.add_scalar(f"{tensorboard_name}/depth_d1", np.mean(epoch_eval_depths_d1) , epoch)
225 | tb_writer.add_scalar(f"{tensorboard_name}/seg_loss", np.mean(epoch_loss_seg) , epoch)
226 | tb_writer.add_scalar(f"{tensorboard_name}/depth_loss", np.mean(epoch_loss_depth) , epoch)
227 | tb_writer.add_scalar(f"{tensorboard_name}/seg_loss_test", np.mean(epoch_loss_seg_test) , epoch)
228 | tb_writer.add_scalar(f"{tensorboard_name}/depth_loss_test", np.mean(epoch_loss_depth_test) , epoch)
229 |
230 | # Save training metrics every 10 epochs
231 | if epoch%10 == 0:
232 | tb_writer.add_scalar(f"{tensorboard_name}/train_mean_iou", np.mean(train_ious) , epoch)
233 | tb_writer.add_scalar(f"{tensorboard_name}/train_depth_rmse", np.mean(train_depths_rmse) , epoch)
234 | tb_writer.add_scalar(f"{tensorboard_name}/train_depth_d1", np.mean(train_depths_d1) , epoch)
235 |
236 |
237 |
238 | # save the model every 500 epochs
239 | if epoch % 500 == 0 or epoch == ((epochs)-1):
240 | torch.save({
241 | 'epoch': epoch,
242 | 'model_state_dict': model.state_dict(),
243 | 'optimizer_state_dict': optimizer.state_dict(),
244 | 'scheduler_state_dict': scheduler.state_dict(),
245 | }, f"{tensorboard_name}.pt")
246 |
247 |
248 | def load_model(model, optimizer, scheduler, PATH, device="cuda"):
249 | checkpoint = torch.load(PATH, map_location=device)
250 | model.load_state_dict(checkpoint['model_state_dict'])
251 | model = model.to(device)
252 |
253 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
254 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
255 | epoch = checkpoint['epoch']
256 | return model, optimizer, scheduler, epoch
257 |
258 | def get_dataloaders(tasks, task, batch_size):
259 |
260 | IMAGE_SIZE = (480, 640)
261 |
262 | train_t = torch.nn.Sequential(transforms.RandomResizedCrop(224, scale=(0.5, 1.0)), transforms.RandomHorizontalFlip())
263 | test_t = torch.nn.Sequential(transforms.CenterCrop(480), transforms.Resize(224))
264 | train_t_input_image = torch.nn.Sequential(transforms.ColorJitter(brightness=(0.8, 1.2),contrast =(0.8, 1.2), saturation=(0.8, 1.2), hue=(-0.1,0.1)))
265 |
266 | train_dataset = NYUv2Dataset(root="./data/nyuv2", tasks=tasks, download=False, train=True,
267 | rgb_transform=train_t, rgb_transform2=train_t_input_image, seg_transform=train_t, sn_transform=train_t, depth_transform=train_t)
268 |
269 | test_dataset = NYUv2Dataset(root="./data/nyuv2", tasks=tasks, download=False, train=False,
270 | rgb_transform=test_t, seg_transform=test_t, sn_transform=test_t, depth_transform=test_t)
271 |
272 | if task == "segmentation":
273 | train_dataset = torch.utils.data.Subset(train_dataset, range(len(train_dataset)//2))
274 | test_dataset = torch.utils.data.Subset(test_dataset, range(len(test_dataset)//2))
275 |
276 | if task == "depth":
277 |
278 | train_dataset = torch.utils.data.Subset(train_dataset, range(len(train_dataset)//2, len(train_dataset)))
279 | test_dataset = torch.utils.data.Subset(test_dataset, range(len(test_dataset)//2, len(test_dataset)))
280 |
281 |
282 | print("Train dataset size:", len(train_dataset))
283 | print("Test dataset size:", len(test_dataset))
284 |
285 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
286 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
287 |
288 | return train_dataloader, test_dataloader
289 |
290 |
291 |
292 | def main():
293 | # default `log_dir` is "runs" - we'll be more specific here
294 |
295 | config = get_config()
296 |
297 | tb_writer = SummaryWriter(f'runs/{config["experiment_name"]}')
298 |
299 | torch.manual_seed(61)
300 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
301 |
302 | tasks = {0:"segmentation", 1:"depth"}
303 | batch_size = config["batch_size"]
304 |
305 | print("Creating datasets...")
306 | train_dataloader, test_dataloader = get_dataloaders(tasks, config["task"], batch_size)
307 |
308 | print("Loading model...")
309 |
310 | model = SwinTransformer(img_size=224,
311 | patch_size=4,
312 | in_chans=3,
313 | num_classes=21841,
314 | embed_dim=96,
315 | depths=[2, 2, 18, 2 ],
316 | depths_decoder =[2, 2, 2, 2 ],
317 | num_heads=[ 3, 6, 12, 24 ],
318 | window_size=7,
319 | mlp_ratio=4.,
320 | qkv_bias=True,
321 | qk_scale=True,
322 | drop_rate=0,
323 | drop_rate_decoder=0.6,
324 | drop_path_rate=0.2,
325 | ape=False,
326 | patch_norm=True,
327 | use_checkpoint=False,
328 | tasks = [config["task"]],
329 | task_classes = [14 if config["task"]=="segmentation" else 1],
330 | conditioned_blocks = [[],[],[],[]],
331 | adapter = False,
332 | use_conditional_layer = False)
333 |
334 | epochs = config["epochs"]
335 | optimizer = optim.AdamW(model.parameters(), lr=2e-5, betas=(0.9, 0.98))
336 |
337 |
338 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 4e-5, epochs = epochs, steps_per_epoch = len(train_dataloader), pct_start = 0.1)
339 |
340 | scheduler_batch_step = True
341 | use_scheduler = True
342 |
343 | if config["continue_training"]:
344 | model, optimizer, scheduler, start_epoch = load_model(model, optimizer, scheduler, config["experiment_name"]+".pt", device)
345 | print("Continue model loaded")
346 |
347 | else:
348 | start_epoch = -1
349 | model.load_state_dict(torch.load('./pretrained/swin_small_patch4_window7_224_22k.pth')['model'],strict=False)
350 |
351 | model = model.to(device)
352 | print("Pretrained model loaded")
353 |
354 |
355 | freeze_encoder_layers(model, conditioned_blocks = [[],[],[],[]], frozen_encoder = config["frozen_encoder"])
356 | model = model.to(device)
357 |
358 | criterion = []
359 | segmentation_criterion = torch.nn.CrossEntropyLoss()
360 | criterion.append(segmentation_criterion)
361 |
362 | depth_criterion = berHuLoss()
363 | criterion.append(depth_criterion)
364 |
365 | print("Training",config["experiment_name"],"...")
366 |
367 |
368 | train(model, train_dataloader, test_dataloader, optimizer, scheduler, criterion, epochs, config["experiment_name"], start_epoch, device, tb_writer, config["task"])
369 |
370 |
371 | if __name__ == '__main__':
372 | main()
--------------------------------------------------------------------------------
/script/train_taskonomy.py:
--------------------------------------------------------------------------------
1 | from data.taskonomy.taskonomy_dataset_s3 import TaskonomyDatasetS3
2 | import torch
3 | import torchvision
4 | import torchvision.transforms as transforms
5 | from torch.utils.data import DataLoader
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | from torch.utils.tensorboard import SummaryWriter
9 | import transformers
10 | from tqdm import tqdm
11 | import numpy as np
12 | import os
13 | import pickle
14 | import cv2
15 | import json
16 | import argparse
17 |
18 | # Small adjustments are made to the Swin Transformer model for parallel run, not related to architecture
19 | from model.swin_transformer_parallel import SwinTransformer
20 | from loss.losses import berHuLoss
21 | from loss.metrics import iou_pytorch, eval_depth
22 |
23 |
24 |
25 | unique_task_ids_list = [0,1]
26 |
27 | def get_config():
28 | parser = argparse.ArgumentParser(description='Train the network')
29 | parser.add_argument('--config', help='train config file path')
30 |
31 | args = parser.parse_args()
32 |
33 | with open(args.config, "r") as jsonfile:
34 | config = json.load(jsonfile)
35 |
36 | return config
37 |
38 | def freeze_encoder_layers(
39 | model,
40 | conditioned_blocks= [[], [], [*range(12, 18)], []],
41 | unfrozen_modules=[
42 | "random_weight_matrix",
43 | "film.gb_weights",
44 | "ln_weight_modulation.gb_weights",
45 | "adapter",
46 | "task_type_embeddings",
47 | "patch_embed",
48 | "decoder",
49 | "bottleneck"
50 | ],
51 | frozen_encoder = False
52 | ):
53 | for name, param in model.named_parameters():
54 | param.requires_grad = not frozen_encoder  # all encoder parameters are frozen when frozen_encoder is True
55 |
56 | for module in unfrozen_modules:
57 | if module in name:
58 | param.requires_grad = True
59 |
60 | if name.startswith("layers"):
61 | splitted = name.split(".")
62 |
63 | if len(conditioned_blocks[int(splitted[1])]) > 0 and splitted[2]=="blocks" and (int(splitted[3]) in conditioned_blocks[int(splitted[1])]):
64 | param.requires_grad = True
65 | elif name.startswith("norm"):
66 | param.requires_grad = True
67 |
68 |
69 | def disp2meters(d):
70 | return (65536.0 / d - 1 ) / 1e4
71 |
72 | def calc_seg_metrics(logit_task, label_task):
73 |
74 | max_labels = torch.argmax(logit_task, dim = 1, keepdim=True)
75 | iou = iou_pytorch(max_labels, label_task)
76 |
77 | return max_labels, iou
78 |
79 | def check_val(model, dataloader, criterion, index, tensorboard_name, seg_weight, depth_weight, device, tb_writer):
80 | test_loss = 0
81 | epoch_ious = []
82 | epoch_eval_depths_rmse = []
83 | epoch_eval_depths_d1 = []
84 |
85 | epoch_loss_seg_test = []
86 | epoch_loss_depth_test = []
87 |
88 | model.eval()
89 |
90 | for i, (img, label, task_id) in enumerate(dataloader, 0):
91 |
92 | img = img.view((-1, 3, 224, 224)).to(device)
93 | label = label.view((-1, 1, 224, 224)).to(device)
94 | task_id = task_id.view(-1).to(device)
95 |
96 | logits = model(img, task_id)
97 |
98 | loss = 0
99 |
100 | for j, unique_task_id in enumerate(unique_task_ids_list):
101 |
102 |
103 | task_id_filter = task_id == unique_task_id
104 |
105 | logit_task = logits[j]
106 | if logit_task is None:
107 | continue
108 |
109 | label_task = label[task_id_filter]
110 | B = logit_task.shape[0]
111 |
112 |
113 | if unique_task_id == 0:
114 |
115 | label_task = label_task.long()
116 |
117 | a = criterion[unique_task_id](logit_task.view(B,18,-1), label_task.long().view(B,-1))
118 | loss += a * seg_weight
119 |
120 | epoch_loss_seg_test.append(a.item() )
121 |
122 | max_labels = torch.argmax(logit_task, dim = 1, keepdim=True)
123 |
124 | iou = iou_pytorch(max_labels, label_task)
125 |
126 | epoch_ious.append(iou.cpu().numpy())
127 |
128 | else:
129 |
130 | logit_task = torch.nn.functional.sigmoid(logit_task)*65535 + 1
131 | label_task = 65536.0 / (label_task + 1)
132 |
133 | a = criterion[unique_task_id](logit_task, label_task, mask_val = 1.0)#* len(logit_task)
134 | loss += a* depth_weight
135 | epoch_loss_depth_test.append(a.item() )
136 |
137 |
138 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task))
139 | epoch_eval_depths_rmse.append(evaluation["rmse"])
140 | epoch_eval_depths_d1.append(evaluation["d1"])
141 |
142 | test_loss += loss.item()
143 |
144 | tb_writer.add_scalar(f"{tensorboard_name}/mid_test_loss", test_loss/len(dataloader) , index)
145 | tb_writer.add_scalar(f"{tensorboard_name}/mid_mean_iou", np.mean(epoch_ious) , index)
146 | tb_writer.add_scalar(f"{tensorboard_name}/mid_depth_rmse", np.mean(epoch_eval_depths_rmse) , index)
147 | tb_writer.add_scalar(f"{tensorboard_name}/mid_depth_d1", np.mean(epoch_eval_depths_d1) , index)
148 | tb_writer.add_scalar(f"{tensorboard_name}/mid_seg_loss_test", np.mean(epoch_loss_seg_test) , index)
149 | tb_writer.add_scalar(f"{tensorboard_name}/mid_depth_loss_test", np.mean(epoch_loss_depth_test) , index)
150 |
151 |
152 |
153 |
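# Note: check_val above runs its forward passes with autograd enabled. An optional memory
# saving (not part of the original flow) would be to invoke it under no_grad, e.g.:
#   with torch.no_grad():
#       check_val(model, mid_test_dataloader, criterion, mid_iter_count * 1000,
#                 tensorboard_name, seg_weight, depth_weight, device, tb_writer)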
154 |
155 | def train(model, train_loader, test_loader,mid_test_dataloader, optimizer, scheduler, criterion, epochs, tensorboard_name, seg_weight, depth_weight, start_epoch = 0, device = "cuda", tb_writer=None):
156 |
157 | model.train()
158 |
159 | train_losses = []
160 | test_losses = []
161 | ious = []
162 | eval_depths = []
163 |
164 | iters = len(train_loader)
165 |
166 |
167 | # starting offset for the mid-epoch TensorBoard step counter (set to a non-zero value here, presumably to continue a previous run's logging)
168 | mid_iter_count = 30
169 |
170 |
171 | for e in tqdm(range(epochs), desc="epoch", position=0):
172 |
173 | epoch = e + start_epoch + 1
174 |
175 | epoch_loss = 0.0
176 | epoch_loss_seg = []
177 | epoch_loss_depth = []
178 |
179 |
180 |
181 | #model.train()
182 | train_ious = []
183 | train_depths_rmse = []
184 | train_depths_d1 = []
185 |
186 |
187 | for i, (img, label, task_id) in tqdm(enumerate(train_loader, 0), desc="iter", position=1, leave=False):
188 | model.train()
189 |
190 | img = img.view((-1, 3, 224, 224)).to(device)
191 | label = label.view((-1, 1, 224, 224)).to(device)
192 | task_id = task_id.view(-1).to(device)
193 |
194 |
195 | logits = model(img, task_id)
196 |
197 | loss = 0
198 |
199 |
200 | for j, unique_task_id in enumerate(unique_task_ids_list):
201 |
202 | task_id_filter = task_id == unique_task_id
203 |
204 |
205 | logit_task = logits[j]
206 | if logit_task is None:
207 | continue
208 | label_task = label[task_id_filter]
209 |
210 | B = logit_task.shape[0]
211 |
212 | if unique_task_id == 0:
213 | label_task = label_task.long()
214 |
215 | a = criterion[unique_task_id](logit_task.view(B,18,-1), label_task.view(B,-1)) #* len(logit_task)
216 | loss += a * seg_weight
217 | epoch_loss_seg.append(a.item() )
218 |
219 | if epoch % 1 == 0:  # always true; raise the modulus to compute training metrics only every N epochs
220 | max_labels = torch.argmax(logit_task, dim = 1, keepdim=True)
221 | iou = iou_pytorch(max_labels, label_task)
222 | train_ious.append(iou.cpu().numpy())
223 |
224 |
225 | else:
226 | logit_task = torch.sigmoid(logit_task) * 65535 + 1
227 | label_task = 65536.0 / (label_task + 1)
228 |
229 | a = criterion[unique_task_id](logit_task, label_task, mask_val = 1.0) #* len(logit_task)
230 | loss += a* depth_weight
231 | epoch_loss_depth.append(a.item())
232 |
233 | if epoch%1==0:
234 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task))
235 | train_depths_rmse.append(evaluation["rmse"])
236 | train_depths_d1.append(evaluation["d1"])
237 |
238 |
239 | optimizer.zero_grad()
240 | loss.backward()
241 | optimizer.step()
242 |
243 |
244 | epoch_loss += loss.item()
245 |
246 | scheduler.step()
247 |
248 |
249 | if i % 1000 == 0:
250 | check_val(model, mid_test_dataloader, criterion, mid_iter_count*1000, tensorboard_name, seg_weight, depth_weight, device, tb_writer)
251 | mid_iter_count += 1
252 |
253 | tb_writer.add_scalar(f"{tensorboard_name}/mid_train_loss", epoch_loss/ (i+1) , mid_iter_count*1000)
254 |
255 | torch.save({
256 | 'epoch': epoch,
257 | 'iter': i,
258 | 'model_state_dict': model.module.state_dict(),
259 | 'optimizer_state_dict': optimizer.state_dict(),
260 | 'scheduler_state_dict': scheduler.state_dict(),
261 | }, f"{tensorboard_name}.pt")
262 |
263 |
264 |
265 | if epoch % 1==0:
266 | test_loss = 0
267 | epoch_ious = []
268 | epoch_eval_depths_rmse = []
269 | epoch_eval_depths_d1 = []
270 |
271 | epoch_loss_seg_test = []
272 | epoch_loss_depth_test = []
273 |
274 | model.eval()
275 | for i, (img, label, task_id) in enumerate(test_loader, 0):
276 |
277 | img = img.view((-1, 3, 224, 224)).to(device)
278 | label = label.view((-1, 1, 224, 224)).to(device)
279 | task_id = task_id.view(-1).to(device)
280 |
281 |
282 | logits = model(img, task_id)
283 |
284 |
285 | loss = 0
286 |
287 | for j, unique_task_id in enumerate(unique_task_ids_list):
288 |
289 |
290 | task_id_filter = task_id == unique_task_id
291 |
292 | logit_task = logits[j]
293 | if logit_task is None:
294 | continue
295 | label_task = label[task_id_filter]
296 | B = logit_task.shape[0]
297 |
298 |
299 | if unique_task_id == 0:
300 |
301 | label_task = label_task.long()
302 |
303 | a = criterion[unique_task_id](logit_task.view(B,18,-1), label_task.long().view(B,-1))
304 | loss += a * seg_weight
305 |
306 | epoch_loss_seg_test.append(a.item() )
307 |
308 | if epoch%1==0:
309 |
310 | max_labels = torch.argmax(logit_task, dim = 1, keepdim=True)
311 |
312 | iou = iou_pytorch(max_labels, label_task)
313 |
314 | epoch_ious.append(iou.cpu().numpy())
315 |
316 | else:
317 |
318 |
319 | logit_task = torch.sigmoid(logit_task) * 65535 + 1
320 | label_task = 65536.0 / (label_task + 1)
321 |
322 | a = criterion[unique_task_id](logit_task, label_task, mask_val = 1.0)#* len(logit_task)
323 | loss += a* depth_weight
324 |
325 | epoch_loss_depth_test.append(a.item() )
326 |
327 | if epoch%1==0:
328 | evaluation = eval_depth(disp2meters(logit_task), disp2meters(label_task))
329 | epoch_eval_depths_rmse.append(evaluation["rmse"])
330 | epoch_eval_depths_d1.append(evaluation["d1"])
331 |
332 |
333 |
334 | test_loss += loss.item()
335 |
336 |
337 |
338 | if epoch % 1==0:
339 |
340 |
341 | tb_writer.add_scalar(f"{tensorboard_name}/learning_rate", scheduler.get_last_lr()[0] , epoch)
342 | tb_writer.add_scalar(f"{tensorboard_name}/train_loss", epoch_loss/ len(train_loader) , epoch)
343 | tb_writer.add_scalar(f"{tensorboard_name}/test_loss", test_loss/len(test_loader) , epoch)
344 | tb_writer.add_scalar(f"{tensorboard_name}/mean_iou", np.mean(epoch_ious) , epoch)
345 | tb_writer.add_scalar(f"{tensorboard_name}/depth_rmse", np.mean(epoch_eval_depths_rmse) , epoch)
346 | tb_writer.add_scalar(f"{tensorboard_name}/depth_d1", np.mean(epoch_eval_depths_d1) , epoch)
347 |
348 | tb_writer.add_scalar(f"{tensorboard_name}/seg_loss", np.mean(epoch_loss_seg) , epoch)
349 | tb_writer.add_scalar(f"{tensorboard_name}/depth_loss", np.mean(epoch_loss_depth) , epoch)
350 |
351 | tb_writer.add_scalar(f"{tensorboard_name}/seg_loss_test", np.mean(epoch_loss_seg_test) , epoch)
352 | tb_writer.add_scalar(f"{tensorboard_name}/depth_loss_test", np.mean(epoch_loss_depth_test) , epoch)
353 |
354 | if epoch%1 == 0:
355 | tb_writer.add_scalar(f"{tensorboard_name}/train_mean_iou", np.mean(train_ious) , epoch)
356 | tb_writer.add_scalar(f"{tensorboard_name}/train_depth_rmse", np.mean(train_depths_rmse) , epoch)
357 | tb_writer.add_scalar(f"{tensorboard_name}/train_depth_d1", np.mean(train_depths_d1) , epoch)
358 |
359 |
360 |
361 | if epoch % 1 == 0:
362 | torch.save({
363 | 'epoch': epoch,
364 | 'model_state_dict': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),  # unwrap DataParallel so load_model can restore the bare model
365 | 'optimizer_state_dict': optimizer.state_dict(),
366 | 'scheduler_state_dict': scheduler.state_dict(),
367 | }, f"{tensorboard_name}.pt")
368 |
369 | return train_losses, test_losses, ious, eval_depths  # these accumulators are currently left empty; metrics are logged to TensorBoard instead
370 |
371 |
372 | def load_model(model, optimizer, scheduler, PATH, device="cuda"):
373 | checkpoint = torch.load(PATH, map_location=device)
374 | model.load_state_dict(checkpoint['model_state_dict'])
375 | model = model.to(device)
376 |
377 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
378 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
379 | epoch = checkpoint['epoch']
380 | return model, optimizer, scheduler, epoch
381 |
382 |
383 | def get_dataloaders(tasks, batch_size):
384 |
385 | train_dataset = TaskonomyDatasetS3(tasks=["rgb", "segment_semantic","depth_euclidean"], split="train", variant="tiny", image_size=224)
386 | test_dataset = TaskonomyDatasetS3(tasks=["rgb", "segment_semantic","depth_euclidean"], split="val", variant="tiny", image_size=224)
387 |
388 | print("Train dataset size:", len(train_dataset))
389 | print("Test dataset size:", len(test_dataset))
390 |
391 | g = torch.Generator()
392 | g.manual_seed(61)
393 |
394 | k_samples = 16*100
395 | perm = torch.randperm(len(test_dataset), generator=g)
396 | idx = perm[:k_samples].tolist()
397 |
398 | subset_dataset_test = torch.utils.data.Subset(test_dataset, idx)
399 |
400 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
401 | mid_test_dataloader = DataLoader(subset_dataset_test, batch_size=batch_size, shuffle=False)
402 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
403 |
404 | return train_dataloader, mid_test_dataloader, test_dataloader
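# Notes on get_dataloaders (descriptive): the `tasks` argument is currently unused, since the
# task list is hard-coded above, and k_samples = 16*100 carves out a fixed, seeded subset of
# 1,600 validation images so the mid-epoch check_val calls stay comparable across runs.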
405 |
406 | class _CustomDataParallel(torch.nn.DataParallel):
407 | def __init__(self, model):
408 | super(_CustomDataParallel, self).__init__(model)
409 |
410 | def __getattr__(self, name):
411 | try:
412 | return super(_CustomDataParallel, self).__getattr__(name)
413 | except AttributeError:
414 | return getattr(self.module, name)
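# _CustomDataParallel forwards attribute lookups to the wrapped model, so e.g.
# wrapper.patch_embed or wrapper.task_type_embeddings resolve to model.module.<attr>.
# main() below currently wraps with plain torch.nn.DataParallel, which only exposes
# attributes of the wrapper itself (hence the explicit model.module.state_dict() when saving).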
415 |
416 | def main():
417 |
418 | config = get_config()
419 |
420 | tb_writer = SummaryWriter(f'runs/{config["experiment_name"]}')
421 |
422 | torch.manual_seed(61)
423 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
424 |
425 | tasks = {0:"segmentation", 1:"depth"} # add 2:"normals" and 3:"edges" to extend this to the four-task setup
426 |
427 | batch_size = config["batch_size"]
428 | print("Creating datasets...")
429 | train_dataloader, mid_test_dataloader, test_dataloader = get_dataloaders(tasks, batch_size)
430 |
431 |
432 | print("Loading model...")
433 | model = SwinTransformer(img_size=224,
434 | patch_size=4,
435 | in_chans=3,
436 | num_classes=21841,
437 | embed_dim=96,
438 | depths=[2, 2, 18, 2 ],
439 | depths_decoder =[2, 2, 2, 2 ],
440 | num_heads=[ 3, 6, 12, 24 ],
441 | window_size=7,
442 | mlp_ratio=4.,
443 | qkv_bias=True,
444 | qk_scale=True,  # note: a truthy value here overrides the default head_dim**-0.5 scaling in WindowAttention (scale becomes 1)
445 | drop_rate=0,
446 | drop_rate_decoder=0.6,
447 | drop_path_rate=0.2,
448 | ape=False,
449 | patch_norm=True,
450 | use_checkpoint=False,
451 | tasks = ["segmentation", "depth"],
452 | task_classes = [18, 1],
453 | conditioned_blocks = config["conditioned_blocks"],
454 | adapter=config["adapter"])
455 |
456 |
457 | epochs = config["epochs"]
458 |
459 |
460 | optimizer = optim.AdamW(model.parameters(), lr=2e-5, betas=(0.9, 0.98))#, weight_decay=0.001)
461 |
462 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 4e-5, epochs = epochs, steps_per_epoch = len(train_dataloader), pct_start = 0.1)
463 |
464 |
465 | if config["continue_training"]:
466 |
467 | model, optimizer, scheduler, start_epoch = load_model(model, optimizer, scheduler, config["experiment_name"]+".pt", device=device)
468 | print("Continue model loaded")
469 |
470 | else:
471 | start_epoch = -1
472 | model.load_state_dict(torch.load('./pretrained/swin_small_patch4_window7_224_22k.pth')['model'],strict=False)
473 | model = model.to(device)
474 | print("Model loaded")
475 |
476 |
477 |
478 | freeze_encoder_layers(model, conditioned_blocks = config["conditioned_blocks"], frozen_encoder = config["frozen_encoder"])
479 | model = model.to(device)
480 |
481 | model = torch.nn.DataParallel(model, device_ids=[0,1,2,3])  # assumes four visible GPUs
482 | print("Model on cuda:",next(model.parameters()).is_cuda)
483 |
484 | criterion = []
485 | segmentation_criterion = torch.nn.CrossEntropyLoss(ignore_index = 0)
486 | criterion.append(segmentation_criterion)
487 |
488 | depth_criterion = berHuLoss()
489 | criterion.append(depth_criterion)
490 |
491 | print("Training",config["experiment_name"],"...")
492 |
493 |
494 | train(model, train_dataloader, test_dataloader, mid_test_dataloader, optimizer, scheduler, criterion, epochs, config["experiment_name"], config["seg_weight"], config["depth_weight"], start_epoch, device, tb_writer)
495 |
496 |
497 | if __name__ == '__main__':
498 | main()
499 |
--------------------------------------------------------------------------------
/script/model/swin_transformer_parallel.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.utils.checkpoint as checkpoint
11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12 | import math
13 | from model.conditional_modules import TAA, ConditionalBottleNeck, TaskScaledNorm
14 | from einops import rearrange
15 |
16 | import re
17 | import logging
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | try:
22 | import os
23 | import sys
24 |
25 | kernel_path = os.path.abspath(os.path.join('..'))
26 | sys.path.append(kernel_path)
27 | from kernels.window_process.window_process import WindowProcess, WindowProcessReverse
28 |
29 | except Exception:
30 | WindowProcess = None
31 | WindowProcessReverse = None
32 | print("[Warning] Fused window process has not been installed. Please refer to get_started.md for installation.")
33 |
34 |
35 | class Mlp(nn.Module):
36 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
37 | super().__init__()
38 | out_features = out_features or in_features
39 | hidden_features = hidden_features or in_features
40 | self.fc1 = nn.Linear(in_features, hidden_features)
41 | self.act = act_layer()
42 | self.fc2 = nn.Linear(hidden_features, out_features)
43 | self.drop = nn.Dropout(drop)
44 |
45 | def forward(self, x):
46 | x = self.fc1(x)
47 | x = self.act(x)
48 | x = self.drop(x)
49 | x = self.fc2(x)
50 | x = self.drop(x)
51 | return x
52 |
53 |
54 | def window_partition(x, window_size):
55 | """
56 | Args:
57 | x: (B, H, W, C)
58 | window_size (int): window size
59 |
60 | Returns:
61 | windows: (num_windows*B, window_size, window_size, C)
62 | """
63 | B, H, W, C = x.shape
64 | x = x.view(B, H // window_size, window_size,
65 | W // window_size, window_size, C)
66 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous(
67 | ).view(-1, window_size, window_size, C)
68 | return windows
69 |
70 |
71 | def window_reverse(windows, window_size, H, W):
72 | """
73 | Args:
74 | windows: (num_windows*B, window_size, window_size, C)
75 | window_size (int): Window size
76 | H (int): Height of image
77 | W (int): Width of image
78 |
79 | Returns:
80 | x: (B, H, W, C)
81 | """
82 | B = int(windows.shape[0] / (H * W / window_size / window_size))
83 | x = windows.view(B, H // window_size, W // window_size,
84 | window_size, window_size, -1)
85 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
86 | return x
87 |
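# Shape sketch (illustrative, using the 224x224 / window_size=7 setting of this repo): with
# x of shape (B, 56, 56, 96), window_partition returns (B*64, 7, 7, 96) because the 56x56
# grid splits into 8*8 = 64 windows; window_reverse(windows, 7, 56, 56) maps it back to
# (B, 56, 56, 96).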
88 |
89 | class WindowAttention(nn.Module):
90 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
91 | It supports both of shifted and non-shifted window.
92 |
93 | Args:
94 | dim (int): Number of input channels.
95 | window_size (tuple[int]): The height and width of the window.
96 | num_heads (int): Number of attention heads.
97 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
98 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
99 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
100 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
101 | task_configs (dict): Configuration for the tasks
102 | """
103 |
104 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., task_configs=None):
105 |
106 | super().__init__()
107 | self.dim = dim
108 | self.window_size = window_size # Wh, Ww
109 | self.num_heads = num_heads
110 | head_dim = dim // num_heads
111 | self.scale = qk_scale or head_dim ** -0.5
112 |
113 | # define a parameter table of relative position bias
114 | self.relative_position_bias_table = nn.Parameter(
115 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
116 |
117 | # get pair-wise relative position index for each token inside the window
118 | coords_h = torch.arange(self.window_size[0])
119 | coords_w = torch.arange(self.window_size[1])
120 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
121 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
122 | relative_coords = coords_flatten[:, :, None] - \
123 | coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
124 | relative_coords = relative_coords.permute(
125 | 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
126 | relative_coords[:, :, 0] += self.window_size[0] - \
127 | 1 # shift to start from 0
128 | relative_coords[:, :, 1] += self.window_size[1] - 1
129 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
130 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
131 | self.register_buffer("relative_position_index",
132 | relative_position_index)
133 |
134 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
135 | self.attn_drop = nn.Dropout(attn_drop)
136 | self.proj = nn.Linear(dim, dim)
137 | self.proj_drop = nn.Dropout(proj_drop)
138 |
139 | trunc_normal_(self.relative_position_bias_table, std=.02)
140 | self.softmax = nn.Softmax(dim=-1)
141 |
142 | self.task_configs = task_configs
143 | if task_configs is not None:
144 |
145 | self.max_seq_length = task_configs["max_seq_length"]
146 | self.hidden_size = task_configs["hidden_size"]
147 |
148 | self.num_blocks = self.hidden_size // self.max_seq_length
149 | self.taa_attn = TAA(
150 | self.hidden_size, math.ceil(
151 | self.max_seq_length / self.num_blocks), self.num_blocks
152 | )
153 |
154 | self.random_weight_matrix = nn.Parameter(
155 | torch.zeros(
156 | [self.max_seq_length, math.ceil(
157 | self.max_seq_length / self.num_blocks)]
158 | ),
159 | requires_grad=True,
160 | )
161 |
162 | def forward(self, x, task_embedding=None, mask=None):
163 | """
164 | Args:
165 | x: input features with shape of (num_windows*B, N, C)
166 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
167 | """
168 | B_, N, C = x.shape
169 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C //
170 | self.num_heads).permute(2, 0, 3, 1, 4)
171 | q, k, v = qkv[0], qkv[1], qkv[2]
172 |
173 | q = q * self.scale
174 | attn = (q @ k.transpose(-2, -1))
175 |
176 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
177 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
178 | relative_position_bias = relative_position_bias.permute(
179 | 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
180 | attn = attn + relative_position_bias.unsqueeze(0)
181 |
182 |
183 |
184 | if self.task_configs is not None:
185 |
186 | attn2 = self.taa_attn(
187 | x_cond=task_embedding,
188 | x_to_film=self.random_weight_matrix,
189 | )
190 |
191 |
192 |
193 | attn = attn.view(len(task_embedding), -1, *(attn.shape[1:]))
194 |
195 |
196 |
197 | attn = attn + attn2.unsqueeze(1).unsqueeze(1)
198 |
199 | attn = attn.view(-1, *(attn.shape[2:]))
200 |
201 |
202 |
203 | if mask is not None:
204 | nW = mask.shape[0]
205 | attn = attn.view(B_ // nW, nW, self.num_heads, N,
206 | N) + mask.unsqueeze(1).unsqueeze(0)
207 | attn = attn.view(-1, self.num_heads, N, N)
208 | attn = self.softmax(attn)
209 | else:
210 | attn = self.softmax(attn)
211 |
212 | attn = self.attn_drop(attn)
213 |
214 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
215 | x = self.proj(x)
216 | x = self.proj_drop(x)
217 | return x
218 |
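# Shape sketch for the task-conditioning branch in forward above (inferred from this file;
# TAA itself is defined in model.conditional_modules): attn is (num_windows*B, nH, N, N)
# with N = Wh*Ww, and task_embedding is expected to hold one embedding per image, so the
# view(len(task_embedding), -1, ...) separates the window axis per image, and attn2, a
# per-image N x N bias computed by TAA from the task embedding and random_weight_matrix,
# is broadcast over windows and heads before the softmax.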
219 | def extra_repr(self) -> str:
220 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
221 |
222 | def flops(self, N):
223 | # calculate flops for 1 window with token length of N
224 | flops = 0
225 | # qkv = self.qkv(x)
226 | flops += N * self.dim * 3 * self.dim
227 | # attn = (q @ k.transpose(-2, -1))
228 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
229 | # x = (attn @ v)
230 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
231 | # x = self.proj(x)
232 | flops += N * self.dim * self.dim
233 | return flops
234 |
235 |
236 | class SwinTransformerBlock(nn.Module):
237 | r""" Swin Transformer Block.
238 |
239 | Args:
240 | dim (int): Number of input channels.
241 | input_resolution (tuple[int]): Input resolution.
242 | num_heads (int): Number of attention heads.
243 | window_size (int): Window size.
244 | shift_size (int): Shift size for SW-MSA.
245 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
246 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
247 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
248 | drop (float, optional): Dropout rate. Default: 0.0
249 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
250 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
251 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
252 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
253 | fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
254 | task_configs (dict): Configuration for the tasks
255 | use_tsn_layer (bool, optional): Whether to use Task Scaled Normalization or regular layer normalization
256 | """
257 |
258 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
259 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
260 | act_layer=nn.GELU, norm_layer=nn.LayerNorm,
261 | fused_window_process=False, task_configs=None, hidden_size=343):
262 | super().__init__()
263 | self.dim = dim
264 | self.input_resolution = input_resolution
265 | self.num_heads = num_heads
266 | self.window_size = window_size
267 | self.shift_size = shift_size
268 | self.mlp_ratio = mlp_ratio
269 | self.task_configs = task_configs
270 | if min(self.input_resolution) <= self.window_size:
271 | # if window size is larger than input resolution, we don't partition windows
272 | self.shift_size = 0
273 | self.window_size = min(self.input_resolution)
274 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
275 |
276 |
277 | self.norm1 = norm_layer(dim)
278 |
279 |
280 | self.attn = WindowAttention(
281 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
282 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, task_configs=task_configs)
283 |
284 | self.drop_path = DropPath(
285 | drop_path) if drop_path > 0. else nn.Identity()
286 |
287 | self.norm2 = norm_layer(dim)
288 |
289 |
290 | mlp_hidden_dim = int(dim * mlp_ratio)
291 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
292 | act_layer=act_layer, drop=drop)
293 |
294 | if self.shift_size > 0:
295 | # calculate attention mask for SW-MSA
296 | H, W = self.input_resolution
297 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
298 | h_slices = (slice(0, -self.window_size),
299 | slice(-self.window_size, -self.shift_size),
300 | slice(-self.shift_size, None))
301 | w_slices = (slice(0, -self.window_size),
302 | slice(-self.window_size, -self.shift_size),
303 | slice(-self.shift_size, None))
304 | cnt = 0
305 | for h in h_slices:
306 | for w in w_slices:
307 | img_mask[:, h, w, :] = cnt
308 | cnt += 1
309 |
310 | # nW, window_size, window_size, 1
311 | mask_windows = window_partition(img_mask, self.window_size)
312 | mask_windows = mask_windows.view(-1,
313 | self.window_size * self.window_size)
314 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
315 | attn_mask = attn_mask.masked_fill(
316 | attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
317 | else:
318 | attn_mask = None
319 |
320 | self.register_buffer("attn_mask", attn_mask)
321 | self.fused_window_process = fused_window_process
322 |
323 | def forward(self, x, task_embedding=None, task_id = None):
324 | H, W = self.input_resolution
325 | B, L, C = x.shape
326 | assert L == H * W, "input feature has wrong size"
327 |
328 | skipconnect = x
329 |
330 | x = self.norm1(x)
331 |
332 | x = x.view(B, H, W, C)
333 |
334 | # cyclic shift
335 | if self.shift_size > 0:
336 | if not self.fused_window_process:
337 | shifted_x = torch.roll(
338 | x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
339 | # partition windows
340 | # nW*B, window_size, window_size, C
341 | x_windows = window_partition(shifted_x, self.window_size)
342 | else:
343 | x_windows = WindowProcess.apply(
344 | x, B, H, W, C, -self.shift_size, self.window_size)
345 | else:
346 | shifted_x = x
347 | # partition windows
348 | # nW*B, window_size, window_size, C
349 | x_windows = window_partition(shifted_x, self.window_size)
350 |
351 | # nW*B, window_size*window_size, C
352 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
353 |
354 | # W-MSA/SW-MSA
355 | # nW*B, window_size*window_size, C
356 | attn_windows = self.attn(
357 | x_windows, task_embedding=task_embedding, mask=self.attn_mask)
358 |
359 | # merge windows
360 | attn_windows = attn_windows.view(-1,
361 | self.window_size, self.window_size, C)
362 |
363 |
364 | # reverse cyclic shift
365 | if self.shift_size > 0:
366 | if not self.fused_window_process:
367 | shifted_x = window_reverse(
368 | attn_windows, self.window_size, H, W) # B H' W' C
369 | x = torch.roll(shifted_x, shifts=(
370 | self.shift_size, self.shift_size), dims=(1, 2))
371 | else:
372 | x = WindowProcessReverse.apply(
373 | attn_windows, B, H, W, C, self.shift_size, self.window_size)
374 | else:
375 | shifted_x = window_reverse(
376 | attn_windows, self.window_size, H, W) # B H' W' C
377 | x = shifted_x
378 |
379 | x = x.view(B, H * W, C)
380 | x = skipconnect + self.drop_path(x)
381 |
382 | ''' Feed Forward Network'''
383 | x = x + self.drop_path(self.mlp(self.norm2(x)))
384 |
385 |
386 | return x
387 |
388 | def extra_repr(self) -> str:
389 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
390 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
391 |
392 | def flops(self):
393 | flops = 0
394 | H, W = self.input_resolution
395 | # norm1
396 | flops += self.dim * H * W
397 | # W-MSA/SW-MSA
398 | nW = H * W / self.window_size / self.window_size
399 | flops += nW * self.attn.flops(self.window_size * self.window_size)
400 | # mlp
401 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
402 | # norm2
403 | flops += self.dim * H * W
404 | return flops
405 |
406 |
407 | class PatchMerging(nn.Module):
408 | r""" Patch Merging Layer.
409 |
410 | Args:
411 | input_resolution (tuple[int]): Resolution of input feature.
412 | dim (int): Number of input channels.
413 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
414 | """
415 |
416 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
417 | super().__init__()
418 | self.input_resolution = input_resolution
419 | self.dim = dim
420 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
421 | self.norm = norm_layer(4 * dim)
422 |
423 | def forward(self, x):
424 | """
425 | x: B, H*W, C
426 | """
427 | H, W = self.input_resolution
428 | B, L, C = x.shape
429 | assert L == H * W, "input feature has wrong size"
430 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
431 |
432 | x = x.view(B, H, W, C)
433 |
434 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
435 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
436 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
437 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
438 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
439 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
440 |
441 | x = self.norm(x)
442 | x = self.reduction(x)
443 |
444 | return x
445 |
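# Shape sketch (illustrative, first encoder stage of the model below: input_resolution (56, 56),
# dim=96): PatchMerging maps (B, 3136, 96) to (B, 784, 192); the four strided views are
# concatenated to (B, 784, 384), normalised, then reduced by the linear layer.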
446 | def extra_repr(self) -> str:
447 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
448 |
449 | def flops(self):
450 | H, W = self.input_resolution
451 | flops = H * W * self.dim
452 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
453 | return flops
454 |
455 |
456 | class BasicLayer(nn.Module):
457 | """ A basic Swin Transformer layer for one stage.
458 |
459 | Args:
460 | dim (int): Number of input channels.
461 | input_resolution (tuple[int]): Input resolution.
462 | depth (int): Number of blocks.
463 | num_heads (int): Number of attention heads.
464 | window_size (int): Local window size.
465 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
466 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
467 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
468 | drop (float, optional): Dropout rate. Default: 0.0
469 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
470 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
471 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
472 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
473 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
474 | fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
475 | task_configs (dict): Configuration for the tasks
476 | conditioned_blocks (list): List of transformer blocks to adapt
477 | adapter (boolean): Whether to use adapters or not
478 | use_tsn_layer (boolean): Whether to use regular or task scaled normalization
479 | """
480 |
481 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
482 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
483 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
484 | fused_window_process=False, task_configs=None, conditioned_blocks=[0], adapter=False, hidden_size=343):
485 |
486 | super().__init__()
487 | self.dim = dim
488 | self.input_resolution = input_resolution
489 | self.depth = depth
490 | self.use_checkpoint = use_checkpoint
491 | self.task_configs = task_configs
492 | self.adapter = adapter
493 |
494 | # build blocks
495 | self.blocks = nn.ModuleList([
496 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
497 | num_heads=num_heads, window_size=window_size,
498 | shift_size=0 if (
499 | i % 2 == 0) else window_size // 2,
500 | mlp_ratio=mlp_ratio,
501 | qkv_bias=qkv_bias, qk_scale=qk_scale,
502 | drop=drop, attn_drop=attn_drop,
503 | drop_path=drop_path[i] if isinstance(
504 | drop_path, list) else drop_path,
505 | norm_layer=norm_layer,
506 | fused_window_process=fused_window_process,
507 | task_configs=task_configs if i in conditioned_blocks else None,
508 | hidden_size = hidden_size)
509 | for i in range(depth)])
510 |
511 | if self.adapter:
512 | self.adapter_layer = nn.ModuleList([
513 | ConditionalBottleNeck(task_configs["hidden_size"], self.dim)
514 | for i in range(depth)])
515 | else:
516 | self.adapter_layer = [None for i in range(depth)]
517 |
518 | # patch merging layer
519 | if downsample is not None:
520 | self.downsample = downsample(
521 | input_resolution, dim=dim, norm_layer=norm_layer)
522 |
523 | self.downsample_bottleneck = downsample(
524 | input_resolution, dim=dim, norm_layer=norm_layer)
525 | else:
526 | self.downsample = None
527 | self.downsample_bottleneck = None
528 |
529 | def forward(self, x, hidden_film=None, task_embedding=None, task_id = None):
530 |
531 | if hidden_film is None:
532 | hidden_film = torch.zeros_like(x)
533 |
534 | for i, (blk, adapter_module) in enumerate(zip(self.blocks, self.adapter_layer)):
535 | if self.use_checkpoint:
536 | x = checkpoint.checkpoint(blk, x, task_embedding, task_id)
537 | else:
538 | x = blk(x, task_embedding=task_embedding, task_id = task_id)
539 |
540 |
541 | if self.adapter:
542 | hidden_film = adapter_module(
543 | x_cond=task_embedding, hidden_states=x + hidden_film
544 | )
545 |
546 | else:
547 | hidden_film = None
548 |
549 |
550 |
551 |
552 | if self.downsample is not None:
553 | x = self.downsample(x)
554 | if self.adapter:
555 | hidden_film = self.downsample_bottleneck(hidden_film)
556 | return x, hidden_film
557 |
558 | def extra_repr(self) -> str:
559 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
560 |
561 | def flops(self):
562 | flops = 0
563 | for blk in self.blocks:
564 | flops += blk.flops()
565 | if self.downsample is not None:
566 | flops += self.downsample.flops()
567 | return flops
568 |
569 |
570 | class BasicLayer_up(nn.Module):
571 | """ A basic Swin Transformer layer for one stage.
572 |
573 | Args:
574 | dim (int): Number of input channels.
575 | input_resolution (tuple[int]): Input resolution.
576 | depth (int): Number of blocks.
577 | num_heads (int): Number of attention heads.
578 | window_size (int): Local window size.
579 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
580 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
581 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
582 | drop (float, optional): Dropout rate. Default: 0.0
583 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
584 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
585 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
586 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
587 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
588 | use_tsn_layer (boolean): Whether to use regular or task scaled normalization
589 | """
590 |
591 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
592 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
593 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
594 |
595 | super().__init__()
596 | self.dim = dim
597 | self.input_resolution = input_resolution
598 | self.depth = depth
599 | self.use_checkpoint = use_checkpoint
600 |
601 | # build blocks
602 | self.blocks = nn.ModuleList([
603 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
604 | num_heads=num_heads, window_size=window_size,
605 | shift_size=0 if (
606 | i % 2 == 0) else window_size // 2,
607 | mlp_ratio=mlp_ratio,
608 | qkv_bias=qkv_bias, qk_scale=qk_scale,
609 | drop=drop, attn_drop=attn_drop,
610 | drop_path=drop_path[i] if isinstance(
611 | drop_path, list) else drop_path,
612 | norm_layer=norm_layer)
613 | for i in range(depth)])
614 |
615 | # patch merging layer
616 | if upsample is not None:
617 | self.upsample = PatchExpand(
618 | input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
619 | else:
620 | self.upsample = None
621 |
622 | def forward(self, x):
623 | for blk in self.blocks:
624 | if self.use_checkpoint:
625 | x = checkpoint.checkpoint(blk, x)
626 | else:
627 | x = blk(x)
628 | if self.upsample is not None:
629 | x = self.upsample(x)
630 | return x
631 |
632 |
633 | class PatchEmbed(nn.Module):
634 | r""" Image to Patch Embedding
635 |
636 | Args:
637 | img_size (int): Image size. Default: 224.
638 | patch_size (int): Patch token size. Default: 4.
639 | in_chans (int): Number of input image channels. Default: 3.
640 | embed_dim (int): Number of linear projection output channels. Default: 96.
641 | norm_layer (nn.Module, optional): Normalization layer. Default: None
642 | """
643 |
644 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
645 | super().__init__()
646 | img_size = to_2tuple(img_size)
647 | patch_size = to_2tuple(patch_size)
648 | patches_resolution = [img_size[0] //
649 | patch_size[0], img_size[1] // patch_size[1]]
650 | self.img_size = img_size
651 | self.patch_size = patch_size
652 | self.patches_resolution = patches_resolution
653 | self.num_patches = patches_resolution[0] * patches_resolution[1]
654 |
655 | self.in_chans = in_chans
656 | self.embed_dim = embed_dim
657 |
658 | self.proj = nn.Conv2d(in_chans, embed_dim,
659 | kernel_size=patch_size, stride=patch_size)
660 | if norm_layer is not None:
661 | self.norm = norm_layer(embed_dim)
662 | else:
663 | self.norm = None
664 |
665 | def forward(self, x):
666 | B, C, H, W = x.shape
667 | # FIXME look at relaxing size constraints
668 | assert H == self.img_size[0] and W == self.img_size[1], \
669 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
670 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
671 | if self.norm is not None:
672 | x = self.norm(x)
673 | return x
674 |
675 | def flops(self):
676 | Ho, Wo = self.patches_resolution
677 | flops = Ho * Wo * self.embed_dim * self.in_chans * \
678 | (self.patch_size[0] * self.patch_size[1])
679 | if self.norm is not None:
680 | flops += Ho * Wo * self.embed_dim
681 | return flops
682 |
683 |
684 | class PatchExpand(nn.Module):
685 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
686 | super().__init__()
687 | self.input_resolution = input_resolution
688 | self.dim = dim
689 | self.expand = nn.Linear(
690 | dim, 2*dim, bias=False) if dim_scale == 2 else nn.Identity()
691 |
692 | self.norm = norm_layer(dim*2)
693 |
694 | def forward(self, x):
695 | """
696 | x: B, H*W, C
697 | """
698 | H, W = self.input_resolution
699 | x = self.expand(x)
700 | B, L, C = x.shape
701 | assert L == H * W, "Input feature has wrong size"
702 |
703 | x = self.norm(x)
704 |
705 | x = x.view(B, H, W, C)
706 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
707 | x = x.view(B, -1, C//4)
708 | return x
709 |
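# Shape sketch (illustrative, first decoder step of the model below: input_resolution (7, 7),
# dim=768): PatchExpand maps (B, 49, 768) to (B, 49, 1536) with the linear layer and
# rearranges to (B, 196, 384), doubling the spatial resolution while halving the channels.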
710 |
711 | class FinalPatchExpand_X4(nn.Module):
712 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
713 | super().__init__()
714 | self.input_resolution = input_resolution
715 | self.dim = dim
716 | self.dim_scale = dim_scale
717 | self.expand = nn.Linear(dim, 16*dim, bias=False)  # expand channels 16x before the 4x spatial rearrange
718 | self.output_dim = dim
719 | self.norm = norm_layer(16*dim)#norm_layer(self.output_dim)
720 |
721 | def forward(self, x):
722 | """
723 | x: B, H*W, C
724 | """
725 | H, W = self.input_resolution
726 | x = self.expand(x)
727 | B, L, C = x.shape
728 | assert L == H * W, "Input feature has wrong size"
729 |
730 | x = self.norm(x)
731 |
732 | x = x.view(B, H, W, C)
733 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
734 | x = x.view(B, -1, self.output_dim)
735 |
736 | #x = self.norm(x)
737 | return x
738 |
739 | class SwinTransformer(nn.Module):
740 | r""" Swin Transformer
741 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
742 | https://arxiv.org/pdf/2103.14030
743 |
744 | Args:
745 | img_size (int | tuple(int)): Input image size. Default 224
746 | patch_size (int | tuple(int)): Patch size. Default: 4
747 | in_chans (int): Number of input image channels. Default: 3
748 | num_classes (int): Number of classes for classification head. Default: 1000
749 | embed_dim (int): Patch embedding dimension. Default: 96
750 | depths (tuple(int)): Depth of each Swin Transformer layer.
751 | num_heads (tuple(int)): Number of attention heads in different layers.
752 | window_size (int): Window size. Default: 7
753 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
754 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
755 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
756 | drop_rate (float): Dropout rate. Default: 0
757 | attn_drop_rate (float): Attention dropout rate. Default: 0
758 | drop_path_rate (float): Stochastic depth rate. Default: 0.1
759 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
760 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
761 | patch_norm (bool): If True, add normalization after patch embedding. Default: True
762 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
763 | fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
764 | tasks (list): List of tasks
765 | final_upsample (str): Upsampling mode for the final decoder layer. Default: "expand_first"
766 | task_classes (list): List of number of prediction classes for each task
767 | conditioned_blocks (list): List of transformer blocks to adapt
768 | adapter (boolean): Whether to use adapters or not
769 | use_tsn_layer (boolean): Whether to use regular or task scaled normalization
770 | """
771 |
772 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
773 | embed_dim=96, depths=[2, 2, 6, 2], depths_decoder=[2,2,2,2], num_heads=[3, 6, 12, 24],
774 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
775 | drop_rate=0.,drop_rate_decoder=0., attn_drop_rate=0., drop_path_rate=0.1,
776 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
777 | use_checkpoint=False, fused_window_process=False,
778 | hidden_size=None, tasks=["segmentation"], final_upsample="expand_first", task_classes = [100],
779 | conditioned_blocks = [[],[],[12],[]], adapter = False,
780 | **kwargs):
781 | super().__init__()
782 |
783 | self.num_classes = num_classes
784 | self.num_layers = len(depths)
785 | self.embed_dim = embed_dim
786 | self.ape = ape
787 | self.patch_norm = patch_norm
788 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
789 | self.mlp_ratio = mlp_ratio
790 | self.final_upsample = final_upsample
791 | self.task_classes = task_classes
792 | self.adapter = adapter
793 |
794 | assert len(task_classes) == len(tasks), "task_classes must provide one entry per task"
795 |
796 | assert len(conditioned_blocks) == self.num_layers, "conditioned_blocks must provide one list per encoder layer"
797 |
798 | # split image into non-overlapping patches
799 | self.patch_embed = PatchEmbed(
800 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
801 | norm_layer=norm_layer if self.patch_norm else None)
802 | num_patches = self.patch_embed.num_patches
803 | patches_resolution = self.patch_embed.patches_resolution
804 | self.patches_resolution = patches_resolution
805 |
806 | # absolute position embedding
807 | if self.ape:
808 | self.absolute_pos_embed = nn.Parameter(
809 | torch.zeros(1, num_patches, embed_dim))
810 | trunc_normal_(self.absolute_pos_embed, std=.02)
811 |
812 | self.pos_drop = nn.Dropout(p=drop_rate)
813 |
814 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
815 | sum(depths))]
816 | self.max_seq_length = window_size * window_size
817 |
818 | if hidden_size is None:
819 | self.hidden_size = self.max_seq_length * \
820 | window_size
821 | else:
822 | self.hidden_size = hidden_size
823 |
824 | self.task_id_2_task_idx = {i: i for i, t in enumerate(tasks)}
825 |
826 | self.task_type_embeddings = nn.Embedding(
827 | len(tasks), self.hidden_size)
828 |
829 | self.task_configs = {"hidden_size": self.hidden_size,
830 | "max_seq_length": self.max_seq_length}
831 |
832 | # build layers
833 | self.layers = nn.ModuleList()
834 | for i_layer in range(self.num_layers):
835 |
836 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
837 | input_resolution=(patches_resolution[0] // (2 ** i_layer),
838 | patches_resolution[1] // (2 ** i_layer)),
839 | depth=depths[i_layer],
840 | num_heads=num_heads[i_layer],
841 | window_size=window_size,
842 | mlp_ratio=self.mlp_ratio,
843 | qkv_bias=qkv_bias, qk_scale=qk_scale,
844 | drop=drop_rate, attn_drop=attn_drop_rate,
845 | drop_path=dpr[sum(depths[:i_layer]):sum(
846 | depths[:i_layer + 1])],
847 | norm_layer=norm_layer,
848 | downsample=PatchMerging if (
849 | i_layer < self.num_layers - 1) else None,
850 | use_checkpoint=use_checkpoint,
851 | fused_window_process=fused_window_process,
852 | task_configs=self.task_configs,
853 | conditioned_blocks=conditioned_blocks[i_layer],
854 | adapter = adapter,
855 | hidden_size = self.hidden_size)
856 | self.layers.append(layer)
857 |
858 | self.norm = norm_layer(self.num_features)
859 |
860 | # Decoder Module
861 | self.decoder_layers_layers_up = nn.ModuleList()
862 | self.decoder_layers_concat_back_dim = nn.ModuleList()
863 | self.decoder_layers_norm_up = nn.ModuleList()
864 | self.decoder_layers_up = nn.ModuleList()
865 | self.decoder_layers_output = nn.ModuleList()
866 |
867 | for i, task in enumerate(tasks):
868 | task_modules = dict()
869 | layers_up = nn.ModuleList()
870 | concat_back_dim = nn.ModuleList()
871 | for i_layer in range(self.num_layers):
872 | concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)),
873 | int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity()
874 | if i_layer == 0:
875 | layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
876 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)
877 | else:
878 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),
879 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
880 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),
881 | depth=depths_decoder[(
882 | self.num_layers-1-i_layer)],
883 | num_heads=num_heads[(
884 | self.num_layers-1-i_layer)],
885 | window_size=window_size,
886 | mlp_ratio=self.mlp_ratio,
887 | qkv_bias=qkv_bias, qk_scale=qk_scale,
888 | drop=drop_rate_decoder, attn_drop=attn_drop_rate,
889 | drop_path=dpr[sum(depths[:(
890 | self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])],
891 | norm_layer=norm_layer,
892 | upsample=PatchExpand if (
893 | i_layer < self.num_layers - 1) else None,
894 | use_checkpoint=use_checkpoint)
895 | layers_up.append(layer_up)
896 | concat_back_dim.append(concat_linear)
897 |
898 | self.decoder_layers_layers_up.append(layers_up)
899 | self.decoder_layers_concat_back_dim.append(concat_back_dim)
900 | self.decoder_layers_norm_up.append(norm_layer(self.embed_dim))
901 |
902 | if self.final_upsample == "expand_first":
903 | up = FinalPatchExpand_X4(input_resolution=(
904 | img_size//patch_size, img_size//patch_size), dim_scale=4, dim=embed_dim)
905 |
906 | self.decoder_layers_up.append(up)
907 |
908 | dec_output = nn.Conv2d(
909 | in_channels=embed_dim, out_channels=self.task_classes[i], kernel_size=1, bias=False)
910 |
911 | self.decoder_layers_output.append(dec_output)
912 |
913 | self.apply(self._init_weights)
914 |
915 | def _init_weights(self, m):
916 | if isinstance(m, nn.Linear):
917 | trunc_normal_(m.weight, std=.02)
918 | if isinstance(m, nn.Linear) and m.bias is not None:
919 | nn.init.constant_(m.bias, 0)
920 | elif isinstance(m, nn.LayerNorm):
921 | nn.init.constant_(m.bias, 0)
922 | nn.init.constant_(m.weight, 1.0)
923 |
924 | @torch.jit.ignore
925 | def no_weight_decay(self):
926 | return {'absolute_pos_embed'}
927 |
928 | @torch.jit.ignore
929 | def no_weight_decay_keywords(self):
930 | return {'relative_position_bias_table'}
931 |
932 | def forward_features_old(self, x, task_embedding):
933 | x = self.patch_embed(x)
934 | if self.ape:
935 | x = x + self.absolute_pos_embed
936 | x = self.pos_drop(x)
937 |
938 | for layer in self.layers:
939 | x, hidden_film = layer(x, task_embedding=task_embedding)
940 |
941 | x = self.norm(x)
942 |
943 |
944 | return x
945 |
946 | def forward_features(self, x, task_embedding, task_id):
947 | x = self.patch_embed(x)
948 | if self.ape:
949 | x = x + self.absolute_pos_embed
950 | x = self.pos_drop(x)
951 | x_downsample = []
952 |
953 | for i,layer in enumerate(self.layers):
954 | x_downsample.append(x)
955 | if i == 0:
956 | x, hidden = layer(x, task_embedding=task_embedding, task_id = task_id)
957 | else:
958 | x, hidden = layer(x, hidden, task_embedding=task_embedding, task_id = task_id)
959 |
960 | if self.adapter:
961 | x = hidden
962 |
963 | x = self.norm(x)
964 |
965 | return x, x_downsample
966 |
967 | #Skip connection
968 | def forward_up_features(self, x, x_downsample, layers_up, concat_back_dim, norm_up):
969 | for inx, layer_up in enumerate(layers_up):
970 | if inx == 0:
971 | x = layer_up(x)
972 | else:
973 | x = torch.cat([x, x_downsample[3-inx]], -1)
974 | x = concat_back_dim[inx](x)
975 | x = layer_up(x)
976 | x = norm_up(x)
977 | return x
978 |
979 | def up_x4(self, x, up, output):
980 | H, W = self.patches_resolution
981 | B, L, C = x.shape
982 | assert L == H*W, "Input features have wrong size"
983 |
984 | if self.final_upsample == "expand_first":
985 | x = up(x)
986 | x = x.view(B, 4*H, 4*W, -1)
987 | x = x.permute(0, 3, 1, 2)
988 | x = output(x)
989 | return x
990 |
991 | def forward_old(self, x, task_id):
992 | task_type = self._create_task_type(task_id)
993 | task_embedding = self.task_type_embeddings(task_type)
994 | x = self.forward_features(x, task_embedding)
995 | return x
996 |
997 | def forward(self, x, task_id):
998 | task_type, unique_task_ids_list = self._create_task_type(task_id)
999 | task_embedding = self.task_type_embeddings(task_type)
1000 |
1001 | x, x_downsample = self.forward_features(x, task_embedding, task_id)
1002 |
1003 | logits = [None]*len(self.task_classes)
1004 |
1005 | for unique_task_id in unique_task_ids_list:
1006 | task_id_filter = task_id == unique_task_id
1007 | layers_up = self.decoder_layers_layers_up[unique_task_id]
1008 | concat_back_dim = self.decoder_layers_concat_back_dim[unique_task_id]
1009 | norm_up = self.decoder_layers_norm_up[unique_task_id]
1010 | up = self.decoder_layers_up[unique_task_id]
1011 | dec_output = self.decoder_layers_output[unique_task_id]
1012 |
1013 | x_downsample_up = []
1014 |
1015 | for x_it in x_downsample:
1016 | x_downsample_up.append(x_it[task_id_filter])
1017 |
1018 |
1019 | # the upsampling call is identical for every task; the task-specific decoder modules were selected above
1020 | x_up = self.forward_up_features(
1021 | x[task_id_filter], x_downsample_up,
1022 | layers_up, concat_back_dim, norm_up)
1023 |
1024 |
1025 | x_up = self.up_x4(x_up, up, dec_output)
1026 | logits[unique_task_id] = x_up
1027 |
1028 | return tuple(logits)
1029 |
1030 |
1031 | def flops(self):
1032 | flops = 0
1033 | flops += self.patch_embed.flops()
1034 | for i, layer in enumerate(self.layers):
1035 | flops += layer.flops()
1036 | flops += self.num_features * \
1037 | self.patches_resolution[0] * \
1038 | self.patches_resolution[1] // (2 ** self.num_layers)
1039 | flops += self.num_features * self.num_classes
1040 | return flops
1041 |
1042 | def _create_task_type(self, task_id):
1043 | task_type = task_id.clone()
1044 | unique_task_ids = torch.unique(task_type)
1045 | unique_task_ids_list = (
1046 | unique_task_ids.cpu().numpy()
1047 | if unique_task_ids.is_cuda
1048 | else unique_task_ids.numpy()
1049 | )
1050 | for unique_task_id in unique_task_ids_list:
1051 | task_type[task_type == unique_task_id] = self.task_id_2_task_idx[
1052 | unique_task_id
1053 | ]
1054 | return task_type, unique_task_ids_list
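# Example (illustrative, two-task setting of the training scripts): task_id = tensor([0, 0, 1, 1])
# yields task_type = tensor([0, 0, 1, 1]) (task_id_2_task_idx is the identity mapping here) and
# unique_task_ids_list = [0, 1], which forward() uses to route samples to their task decoders.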
1055 |
--------------------------------------------------------------------------------