├── utils
│   ├── __init__.py
│   ├── cmap.npy
│   ├── optimizer.py
│   ├── helpers.py
│   ├── meter.py
│   ├── datasets.py
│   ├── transforms.py
│   └── augmentations_mm.py
├── models
│   ├── __init__.py
│   ├── modules.py
│   ├── segformer.py
│   ├── mix_transformer.py
│   └── swin_transformer.py
├── figs
│   ├── framework.png
│   ├── homogeneous.png
│   ├── heterogeneous.png
│   └── geminifusion_framework.png
├── mmcv_custom
│   ├── __init__.py
│   ├── runner
│   │   ├── __init__.py
│   │   ├── checkpoint.py
│   │   └── epoch_based_runner.py
│   └── checkpoint.py
├── LICENSE
├── .gitignore
├── README-TokenFusion.md
├── README.md
├── main.py
└── data
    └── nyudv2
        └── val.txt
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .helpers import *
2 | from .meter import *
3 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .mix_transformer import *
2 | from .segformer import WeTr
--------------------------------------------------------------------------------
/utils/cmap.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiaDingCN/GeminiFusion/HEAD/utils/cmap.npy
--------------------------------------------------------------------------------
/figs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiaDingCN/GeminiFusion/HEAD/figs/framework.png
--------------------------------------------------------------------------------
/figs/homogeneous.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiaDingCN/GeminiFusion/HEAD/figs/homogeneous.png
--------------------------------------------------------------------------------
/figs/heterogeneous.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiaDingCN/GeminiFusion/HEAD/figs/heterogeneous.png
--------------------------------------------------------------------------------
/figs/geminifusion_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiaDingCN/GeminiFusion/HEAD/figs/geminifusion_framework.png
--------------------------------------------------------------------------------
/mmcv_custom/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .checkpoint import load_checkpoint
4 |
5 | __all__ = ["load_checkpoint"]
6 |
7 |
--------------------------------------------------------------------------------
/mmcv_custom/runner/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 | from .checkpoint import save_checkpoint
3 | from .epoch_based_runner import EpochBasedRunnerAmp
4 |
5 |
6 | __all__ = ["EpochBasedRunnerAmp", "save_checkpoint"]
7 |
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 jiading
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class PolyWarmupAdamW(torch.optim.AdamW):
4 |
5 | def __init__(self, params, lr, weight_decay, betas, warmup_iter=None, max_iter=None, warmup_ratio=None, power=None):
6 | super().__init__(params, lr=lr, betas=betas,weight_decay=weight_decay, eps=1e-8)
7 |
8 | self.global_step = 0
9 | self.warmup_iter = warmup_iter
10 | self.warmup_ratio = warmup_ratio
11 | self.max_iter = max_iter
12 | self.power = power
13 |
14 | self.__init_lr = [group['lr'] for group in self.param_groups]
15 |
16 | def step(self, closure=None):
17 | ## adjust lr
18 | if self.global_step < self.warmup_iter:
19 |
20 | lr_mult = 1 - (1 - self.global_step / self.warmup_iter) * (1 - self.warmup_ratio)
21 | for i in range(len(self.param_groups)):
22 | self.param_groups[i]['lr'] = self.__init_lr[i] * lr_mult
23 |
24 | elif self.global_step < self.max_iter:
25 |
26 | lr_mult = (1 - self.global_step / self.max_iter) ** self.power
27 | for i in range(len(self.param_groups)):
28 | self.param_groups[i]['lr'] = self.__init_lr[i] * lr_mult
29 |
30 | # step
31 | super().step(closure)
32 |
33 | self.global_step += 1
--------------------------------------------------------------------------------
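
Note: a minimal usage sketch of the schedule implemented above — linear warmup from `warmup_ratio * lr` to `lr` over `warmup_iter` steps, then polynomial decay towards zero until `max_iter`. The toy model and hyperparameter values below are illustrative placeholders, not the settings used by main.py.

```python
import torch
import torch.nn as nn
from utils.optimizer import PolyWarmupAdamW

# Illustrative toy model and schedule settings (placeholders, not the real config).
model = nn.Linear(16, 4)
optimizer = PolyWarmupAdamW(
    model.parameters(),
    lr=6e-5, weight_decay=0.01, betas=(0.9, 0.999),
    warmup_iter=1500, max_iter=40000, warmup_ratio=1e-6, power=1.0,
)

for step in range(100):
    loss = model(torch.randn(8, 16)).pow(2).mean()  # dummy loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # rescales each param group's lr before the AdamW update
```
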
/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | num_parallel = 2
4 |
5 |
6 |
7 |
8 |
9 | class ModuleParallel(nn.Module):
10 | def __init__(self, module):
11 | super(ModuleParallel, self).__init__()
12 | self.module = module
13 |
14 | def forward(self, x_parallel):
15 | return [self.module(x) for x in x_parallel]
16 |
17 |
18 | class Additional_One_ModuleParallel(nn.Module):
19 | def __init__(self, module):
20 | super(Additional_One_ModuleParallel, self).__init__()
21 | self.module = module
22 |
23 | def forward(self, x_parallel, x_arg):
24 | if x_arg is None:
25 | return [self.module(x, None) for x in x_parallel]
26 | elif isinstance(x_arg, list):
27 | return [
28 | self.module(x_parallel[i], x_arg[i]) for i in range(len(x_parallel))
29 | ]
30 | else:
31 | return [self.module(x_parallel[i], x_arg) for i in range(len(x_parallel))]
32 |
33 |
34 | class Additional_Two_ModuleParallel(nn.Module):
35 | def __init__(self, module):
36 | super(Additional_Two_ModuleParallel, self).__init__()
37 | self.module = module
38 |
39 | def forward(self, x_parallel, x_arg1, x_arg2):
40 | return [
41 | self.module(x_parallel[i], x_arg1, x_arg2) for i in range(len(x_parallel))
42 | ]
43 |
44 |
45 | class LayerNormParallel(nn.Module):
46 | def __init__(self, num_features):
47 | super(LayerNormParallel, self).__init__()
48 | for i in range(num_parallel):
49 | setattr(self, "ln_" + str(i), nn.LayerNorm(num_features, eps=1e-6))
50 |
51 | def forward(self, x_parallel):
52 | return [getattr(self, "ln_" + str(i))(x) for i, x in enumerate(x_parallel)]
53 |
--------------------------------------------------------------------------------
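
Note: a small sketch of how these wrappers broadcast a single module over a list of per-modality feature tensors (here `num_parallel = 2`, e.g. RGB and depth). The shapes are illustrative.

```python
import torch
import torch.nn as nn
from models.modules import ModuleParallel, LayerNormParallel

# One shared projection applied to both modality streams.
proj = ModuleParallel(nn.Linear(64, 128))
# One LayerNorm per stream (ln_0 for the first modality, ln_1 for the second).
norm = LayerNormParallel(128)

x_rgb = torch.randn(2, 196, 64)     # (batch, tokens, channels), illustrative shapes
x_depth = torch.randn(2, 196, 64)

out_rgb, out_depth = norm(proj([x_rgb, x_depth]))  # each (2, 196, 128)
```
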
/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib as mpl
4 | import matplotlib.cm as cm
5 | import PIL.Image as pil
6 | import cv2
7 | import os
8 |
9 | IMG_SCALE = 1./255
10 | IMG_MEAN = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
11 | IMG_STD = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
12 | logger = None
13 |
14 |
15 | def print_log(message):
16 | print(message, flush=True)
17 | if logger:
18 | logger.write(str(message) + '\n')
19 |
20 |
21 | def maybe_download(model_name, model_url, model_dir=None, map_location=None):
22 | import os, sys
23 | from six.moves import urllib
24 | if model_dir is None:
25 | torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
26 | model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
27 | if not os.path.exists(model_dir):
28 | os.makedirs(model_dir)
29 | filename = '{}.pth.tar'.format(model_name)
30 | cached_file = os.path.join(model_dir, filename)
31 | if not os.path.exists(cached_file):
32 | url = model_url
33 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
34 | urllib.request.urlretrieve(url, cached_file)
35 | return torch.load(cached_file, map_location=map_location)
36 |
37 |
38 | def prepare_img(img):
39 | return (img * IMG_SCALE - IMG_MEAN) / IMG_STD
40 |
41 |
42 | def make_validation_img(img_, depth_, lab, pre):
43 | cmap = np.load('./utils/cmap.npy')
44 |
45 | img = np.array([i * IMG_STD.reshape((3, 1, 1)) + IMG_MEAN.reshape((3, 1, 1)) for i in img_])
46 | img *= 255
47 | img = img.astype(np.uint8)
48 | img = np.concatenate(img, axis=1)
49 |
50 | depth_ = depth_[0].transpose(1, 2, 0) / max(depth_.max(), 10)
51 | vmax = np.percentile(depth_, 95)
52 | normalizer = mpl.colors.Normalize(vmin=depth_.min(), vmax=vmax)
53 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
54 | depth = (mapper.to_rgba(depth_)[:,:,:3] * 255).astype(np.uint8)
55 | lab = np.concatenate(lab)
56 | lab = np.array([cmap[i.astype(np.uint8) + 1] for i in lab])
57 |
58 | pre = np.concatenate(pre)
59 | pre = np.array([cmap[i.astype(np.uint8) + 1] for i in pre])
60 | img = img.transpose(1, 2, 0)
61 |
62 | return np.concatenate([img, depth, lab, pre], 1)
63 |
--------------------------------------------------------------------------------
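
Note: a minimal sketch of `prepare_img`, which rescales an HxWx3 RGB array from [0, 255] to [0, 1] and applies the ImageNet mean/std defined above; the file path is hypothetical.

```python
import numpy as np
from PIL import Image
from utils.helpers import prepare_img

rgb = np.asarray(Image.open("example_rgb.png"), dtype=np.float32)  # hypothetical path
rgb_norm = prepare_img(rgb)           # (rgb / 255 - IMG_MEAN) / IMG_STD, still HxWx3
chw = rgb_norm.transpose(2, 0, 1)     # CxHxW layout expected by torch models
```
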
/mmcv_custom/runner/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 | import os.path as osp
3 | import time
4 | from tempfile import TemporaryDirectory
5 |
6 | import torch
7 | from torch.optim import Optimizer
8 |
9 | import mmcv
10 | from mmcv.parallel import is_module_wrapper
11 | from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
12 |
13 | try:
14 | import apex
15 | except ImportError:
16 | print("apex is not installed")
17 |
18 |
19 | def save_checkpoint(model, filename, optimizer=None, meta=None):
20 | """Save checkpoint to file.
21 |
22 | The checkpoint will have 4 fields: ``meta``, ``state_dict``,
23 | ``optimizer`` and ``amp``. By default ``meta`` will contain version
24 | and time info.
25 |
26 | Args:
27 | model (Module): Module whose params are to be saved.
28 | filename (str): Checkpoint filename.
29 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
30 | meta (dict, optional): Metadata to be saved in checkpoint.
31 | """
32 | if meta is None:
33 | meta = {}
34 | elif not isinstance(meta, dict):
35 | raise TypeError(f"meta must be a dict or None, but got {type(meta)}")
36 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
37 |
38 | if is_module_wrapper(model):
39 | model = model.module
40 |
41 | if hasattr(model, "CLASSES") and model.CLASSES is not None:
42 | # save class name to the meta
43 | meta.update(CLASSES=model.CLASSES)
44 |
45 | checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))}
46 | # save optimizer state dict in the checkpoint
47 | if isinstance(optimizer, Optimizer):
48 | checkpoint["optimizer"] = optimizer.state_dict()
49 | elif isinstance(optimizer, dict):
50 | checkpoint["optimizer"] = {}
51 | for name, optim in optimizer.items():
52 | checkpoint["optimizer"][name] = optim.state_dict()
53 |
54 | # save amp state dict in the checkpoint
55 | checkpoint["amp"] = apex.amp.state_dict()
56 |
57 | if filename.startswith("pavi://"):
58 | try:
59 | from pavi import modelcloud
60 | from pavi.exception import NodeNotFoundError
61 | except ImportError:
62 | raise ImportError("Please install pavi to load checkpoint from modelcloud.")
63 | model_path = filename[7:]
64 | root = modelcloud.Folder()
65 | model_dir, model_name = osp.split(model_path)
66 | try:
67 | model = modelcloud.get(model_dir)
68 | except NodeNotFoundError:
69 | model = root.create_training_model(model_dir)
70 | with TemporaryDirectory() as tmp_dir:
71 | checkpoint_file = osp.join(tmp_dir, model_name)
72 | with open(checkpoint_file, "wb") as f:
73 | torch.save(checkpoint, f)
74 | f.flush()
75 | model.create_file(checkpoint_file, name=model_name)
76 | else:
77 | mmcv.mkdir_or_exist(osp.dirname(filename))
78 | # immediately flush buffer
79 | with open(filename, "wb") as f:
80 | torch.save(checkpoint, f)
81 | f.flush()
82 |
--------------------------------------------------------------------------------
/utils/meter.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import torch
4 | import numpy as np
5 |
6 |
7 | def confusion_matrix(x, y, n, ignore_label=None, mask=None):
8 | if mask is None:
9 | mask = np.ones_like(x) == 1
10 | k = (x >= 0) & (y < n) & (x != ignore_label) & mask.astype(bool)
11 | return np.bincount(n * x[k].astype(int) + y[k], minlength=n ** 2).reshape(n, n)
12 |
13 |
14 | def getScores(conf_matrix):
15 | if conf_matrix.sum() == 0:
16 | return 0, 0, 0
17 | with np.errstate(divide='ignore', invalid='ignore'):
18 | overall = np.diag(conf_matrix).sum() / float(conf_matrix.sum())
19 | perclass = np.diag(conf_matrix) / conf_matrix.sum(1).astype(float)
20 | IU = np.diag(conf_matrix) / (conf_matrix.sum(1) + conf_matrix.sum(0) \
21 | - np.diag(conf_matrix)).astype(float)
22 | return overall * 100., np.nanmean(perclass) * 100., np.nanmean(IU) * 100.
23 |
24 |
25 | def compute_params(model):
26 | """Compute number of parameters"""
27 | n_total_params = 0
28 | for name, m in model.named_parameters():
29 | n_elem = m.numel()
30 | n_total_params += n_elem
31 | return n_total_params
32 |
33 |
34 | # Adopted from https://raw.githubusercontent.com/pytorch/examples/master/imagenet/main.py
35 | class AverageMeter(object):
36 | """Computes and stores the average and current value"""
37 | def __init__(self):
38 | self.reset()
39 |
40 | def reset(self):
41 | self.val = 0
42 | self.avg = 0
43 | self.sum = 0
44 | self.count = 0
45 |
46 | def update(self, val, n=1):
47 | self.val = val
48 | self.sum += val * n
49 | self.count += n
50 | self.avg = self.sum / self.count
51 |
52 |
53 | class Saver():
54 | """Saver class for managing parameters"""
55 | def __init__(self, args, ckpt_dir, best_val=0, condition=lambda x, y: x > y):
56 | """
57 | Args:
58 | args (dict): dictionary with arguments.
59 | ckpt_dir (str): path to directory in which to store the checkpoint.
60 | best_val (float): initial best value.
61 | condition (function): how to decide whether to save the new checkpoint
62 | by comparing best value and new value (x,y).
63 |
64 | """
65 | if not os.path.exists(ckpt_dir):
66 | os.makedirs(ckpt_dir)
67 | with open('{}/args.json'.format(ckpt_dir), 'w') as f:
68 | json.dump({k: v for k, v in args.items() if isinstance(v, (int, float, str))}, f,
69 | sort_keys = True, indent = 4, ensure_ascii = False)
70 | self.ckpt_dir = ckpt_dir
71 | self.best_val = best_val
72 | self.condition = condition
73 | self._counter = 0
74 |
75 | def _do_save(self, new_val):
76 | """Check whether need to save"""
77 | return self.condition(new_val, self.best_val)
78 |
79 | def save(self, new_val, dict_to_save):
80 | """Save new checkpoint"""
81 | self._counter += 1
82 | if self._do_save(new_val):
83 | # print(' New best value {:.4f}, was {:.4f}'.format(new_val, self.best_val), flush=True)
84 | self.best_val = new_val
85 | dict_to_save['best_val'] = new_val
86 | torch.save(dict_to_save, '{}/model-best.pth.tar'.format(self.ckpt_dir))
87 | else:
88 | dict_to_save['best_val'] = new_val
89 | torch.save(dict_to_save, '{}/checkpoint.pth.tar'.format(self.ckpt_dir))
90 |
91 |
--------------------------------------------------------------------------------
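
Note: a hedged sketch of how the metric helpers above fit together — accumulate a confusion matrix over (ground truth, prediction) pairs, then derive pixel accuracy, mean accuracy and mean IoU. The class count and ignore label are illustrative.

```python
import numpy as np
from utils.meter import confusion_matrix, getScores, AverageMeter

num_classes = 40                    # e.g. NYUDv2-40; illustrative
conf_mat = np.zeros((num_classes, num_classes), dtype=np.int64)

# Random stand-ins for a ground-truth map and a prediction map.
gt = np.random.randint(0, num_classes, (480, 640))
pred = np.random.randint(0, num_classes, (480, 640))
conf_mat += confusion_matrix(gt.flatten(), pred.flatten(), num_classes, ignore_label=255)

pixel_acc, mean_acc, miou = getScores(conf_mat)

loss_meter = AverageMeter()
loss_meter.update(0.42, n=8)        # running average over batches
print(pixel_acc, mean_acc, miou, loss_meter.avg)
```
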
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/mmcv_custom/runner/epoch_based_runner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 | import os.path as osp
3 | import platform
4 | import shutil
5 |
6 | import torch
7 | from torch.optim import Optimizer
8 |
9 | import mmcv
10 | from mmcv.runner import RUNNERS, EpochBasedRunner
11 | from .checkpoint import save_checkpoint
12 |
13 | try:
14 | import apex
15 | except ImportError:
16 | print("apex is not installed")
17 |
18 |
19 | @RUNNERS.register_module()
20 | class EpochBasedRunnerAmp(EpochBasedRunner):
21 | """Epoch-based Runner with AMP support.
22 |
23 | This runner trains models epoch by epoch.
24 | """
25 |
26 | def save_checkpoint(
27 | self,
28 | out_dir,
29 | filename_tmpl="epoch_{}.pth",
30 | save_optimizer=True,
31 | meta=None,
32 | create_symlink=True,
33 | ):
34 | """Save the checkpoint.
35 |
36 | Args:
37 | out_dir (str): The directory that checkpoints are saved.
38 | filename_tmpl (str, optional): The checkpoint filename template,
39 | which contains a placeholder for the epoch number.
40 | Defaults to 'epoch_{}.pth'.
41 | save_optimizer (bool, optional): Whether to save the optimizer to
42 | the checkpoint. Defaults to True.
43 | meta (dict, optional): The meta information to be saved in the
44 | checkpoint. Defaults to None.
45 | create_symlink (bool, optional): Whether to create a symlink
46 | "latest.pth" to point to the latest checkpoint.
47 | Defaults to True.
48 | """
49 | if meta is None:
50 | meta = dict(epoch=self.epoch + 1, iter=self.iter)
51 | elif isinstance(meta, dict):
52 | meta.update(epoch=self.epoch + 1, iter=self.iter)
53 | else:
54 | raise TypeError(f"meta should be a dict or None, but got {type(meta)}")
55 | if self.meta is not None:
56 | meta.update(self.meta)
57 |
58 | filename = filename_tmpl.format(self.epoch + 1)
59 | filepath = osp.join(out_dir, filename)
60 | optimizer = self.optimizer if save_optimizer else None
61 | save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
62 | # in some environments, `os.symlink` is not supported, you may need to
63 | # set `create_symlink` to False
64 | if create_symlink:
65 | dst_file = osp.join(out_dir, "latest.pth")
66 | if platform.system() != "Windows":
67 | mmcv.symlink(filename, dst_file)
68 | else:
69 | shutil.copy(filepath, dst_file)
70 |
71 | def resume(self, checkpoint, resume_optimizer=True, map_location="default"):
72 | if map_location == "default":
73 | if torch.cuda.is_available():
74 | device_id = torch.cuda.current_device()
75 | checkpoint = self.load_checkpoint(
76 | checkpoint,
77 | map_location=lambda storage, loc: storage.cuda(device_id),
78 | )
79 | else:
80 | checkpoint = self.load_checkpoint(checkpoint)
81 | else:
82 | checkpoint = self.load_checkpoint(checkpoint, map_location=map_location)
83 |
84 | self._epoch = checkpoint["meta"]["epoch"]
85 | self._iter = checkpoint["meta"]["iter"]
86 | if "optimizer" in checkpoint and resume_optimizer:
87 | if isinstance(self.optimizer, Optimizer):
88 | self.optimizer.load_state_dict(checkpoint["optimizer"])
89 | elif isinstance(self.optimizer, dict):
90 | for k in self.optimizer.keys():
91 | self.optimizer[k].load_state_dict(checkpoint["optimizer"][k])
92 | else:
93 | raise TypeError(
94 | "Optimizer should be dict or torch.optim.Optimizer "
95 | f"but got {type(self.optimizer)}"
96 | )
97 |
98 | if "amp" in checkpoint:
99 | apex.amp.load_state_dict(checkpoint["amp"])
100 | self.logger.info("load amp state dict")
101 |
102 | self.logger.info("resumed epoch %d, iter %d", self.epoch, self.iter)
103 |
104 |
--------------------------------------------------------------------------------
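
Note: `EpochBasedRunnerAmp` is registered with mmcv's `RUNNERS` registry, so an mmcv-style config can select it by name. A hedged fragment (placeholder values; this repository's own training entry point is main.py, not an mmcv config):

```python
# Hypothetical mmcv-style config fragment; values are placeholders.
runner = dict(type="EpochBasedRunnerAmp", max_epochs=300)
```
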
/utils/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import io
7 |
8 |
9 | def line_to_paths_fn(x, input_names):
10 | return x.decode("utf-8").strip("\n").split("\t")
11 |
12 |
13 | class SegDataset(Dataset):
14 | """Multi-Modality Segmentation dataset.
15 |
16 | Works with any dataset that contains an image
17 | and any number of 2D annotations.
18 |
19 | Args:
20 | data_file (string): Path to the data file with annotations.
21 | data_dir (string): Directory with all the images.
22 | line_to_paths_fn (callable): function to convert a line of data_file
23 | into paths (img_relpath, msk_relpath, ...).
24 | masks_names (list of strings): keys for each annotation mask
25 | (e.g., 'segm', 'depth').
26 | transform_trn (callable, optional): Optional transform
27 | to be applied on a sample during the training stage.
28 | transform_val (callable, optional): Optional transform
29 | to be applied on a sample during the validation stage.
30 | stage (str): initial stage of dataset - either 'train' or 'val'.
31 |
32 | """
33 |
34 | def __init__(
35 | self,
36 | dataset,
37 | data_file,
38 | data_dir,
39 | input_names,
40 | input_mask_idxs,
41 | transform_trn=None,
42 | transform_val=None,
43 | stage="train",
44 | ignore_label=None,
45 | ):
46 | with open(data_file, "rb") as f:
47 | datalist = f.readlines()
48 | self.dataset = dataset
49 | self.datalist = [line_to_paths_fn(l, input_names) for l in datalist]
50 | self.root_dir = data_dir
51 | self.transform_trn = transform_trn
52 | self.transform_val = transform_val
53 | self.stage = stage
54 | self.input_names = input_names
55 | self.input_mask_idxs = input_mask_idxs
56 | self.ignore_label = ignore_label
57 |
58 | def set_stage(self, stage):
59 | """Define which set of transformation to use.
60 |
61 | Args:
62 | stage (str): either 'train' or 'val'
63 |
64 | """
65 | self.stage = stage
66 |
67 | def __len__(self):
68 | return len(self.datalist)
69 |
70 | def __getitem__(self, idx):
71 | idxs = self.input_mask_idxs
72 | names = [os.path.join(self.root_dir, rpath) for rpath in self.datalist[idx]]
73 | sample = {}
74 | for i, key in enumerate(self.input_names):
75 | sample[key] = self.read_image(names[idxs[i]], key)
76 | try:
77 | if self.dataset == "nyudv2":
78 | mask = np.array(Image.open(names[idxs[-1]]))
79 | elif self.dataset == "sunrgbd":
80 | mask = self._open_image(
81 | names[idxs[-1]], cv2.IMREAD_GRAYSCALE, dtype=np.uint8
82 | )
83 | except FileNotFoundError: # for sunrgbd
84 | path = names[idxs[-1]]
85 | num_idx = int(path[-10:-4]) + 5050
86 | path = path[:-10] + "%06d" % num_idx + path[-4:]
87 | mask = np.array(Image.open(path))
88 |
89 | if self.dataset == "sunrgbd":
90 | mask -= 1
91 |
92 | assert len(mask.shape) == 2, "Masks must be encoded without colourmap"
93 | sample["inputs"] = self.input_names
94 | sample["mask"] = mask
95 |
96 | del sample["inputs"]
97 | if self.stage == "train":
98 | if self.transform_trn:
99 | sample = self.transform_trn(sample)
100 | elif self.stage == "val":
101 | if self.transform_val:
102 | sample = self.transform_val(sample)
103 |
104 | return sample
105 |
106 | @staticmethod
107 | def _open_image(filepath, mode=cv2.IMREAD_COLOR, dtype=None):
108 | img = np.array(cv2.imread(filepath, mode), dtype=dtype)
109 | return img
110 |
111 | @staticmethod
112 | def read_image(x, key):
113 | """Simple image reader
114 |
115 | Args:
116 | x (str): path to image.
117 |
118 | Returns image as `np.array`.
119 |
120 | """
121 | img_arr = np.array(Image.open(x))
122 | if len(img_arr.shape) == 2: # grayscale
123 | img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0)
124 | return img_arr
125 |
--------------------------------------------------------------------------------
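
Note: a hedged sketch of constructing `SegDataset` for NYUDv2. The data directory, column indices and ignore label below are illustrative (main.py wires up the real values), and each line of the list file is assumed to hold tab-separated relative paths.

```python
from utils.datasets import SegDataset

# Illustrative arguments; main.py builds the real dataset with its own settings.
dataset = SegDataset(
    dataset="nyudv2",
    data_file="./data/nyudv2/val.txt",   # tab-separated relative paths per line
    data_dir="/cache/datasets/nyudv2",   # default path mentioned in README.md
    input_names=["rgb", "depth"],
    input_mask_idxs=[0, 1, 2],           # hypothetical columns: rgb, depth, mask
    transform_trn=None,
    transform_val=None,
    stage="val",
    ignore_label=255,
)

sample = dataset[0]                       # dict with "rgb", "depth" and "mask" arrays
print(len(dataset), sample["mask"].shape)
```
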
/README-TokenFusion.md:
--------------------------------------------------------------------------------
1 | # Multimodal Token Fusion for Vision Transformers
2 |
3 | By Yikai Wang, Xinghao Chen, Lele Cao, Wenbing Huang, Fuchun Sun, Yunhe Wang.
4 |
5 | [**[Paper]**](https://arxiv.org/pdf/2204.08721.pdf)
6 |
7 | This repository is a PyTorch implementation of "Multimodal Token Fusion for Vision Transformers", in CVPR 2022.
8 |
9 |
10 |
11 |
12 |
13 | Homogeneous predictions,
14 |
15 |
16 |
17 |
18 | Heterogeneous predictions,
19 |
20 |
21 |
22 |
23 |
24 | ## Datasets
25 |
26 | For the semantic segmentation task on NYUDv2 ([official dataset](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)), we provide a link to download the dataset [here](https://drive.google.com/drive/folders/1mXmOXVsd5l9-gYHk92Wpn6AcKAbE0m3X?usp=sharing). The provided dataset was originally preprocessed in this [repository](https://github.com/DrSleep/light-weight-refinenet), and we add depth data to it.
27 |
28 | For the image-to-image translation task, we use the sample dataset of [Taskonomy](http://taskonomy.stanford.edu/); a link to download the sample dataset is [here](https://github.com/alexsax/taskonomy-sample-model-1.git).
29 |
30 | Please modify the data paths in the code where we have added the comment 'Modify data path'.
31 |
32 |
33 | ## Dependencies
34 | ```
35 | python==3.6
36 | pytorch==1.7.1
37 | torchvision==0.8.2
38 | numpy==1.19.2
39 | ```
40 |
41 |
42 | ## Semantic Segmentation
43 |
44 |
45 | First,
46 | ```
47 | cd semantic_segmentation
48 | ```
49 |
50 | Download the [segformer](https://github.com/NVlabs/SegFormer) pretrained model (pretrained on ImageNet) from [weights](https://drive.google.com/drive/folders/1b7bwrInTW4VLEm27YawHOAMSMikga2Ia), e.g., mit_b3.pth. Move this pretrained model to folder 'pretrained'.
51 |
52 | Training script for segmentation with RGB and Depth input,
53 | ```
54 | python main.py --backbone mit_b3 -c exp_name --lamda 1e-6 --gpu 0 1 2
55 | ```
56 |
57 | Evaluation script,
58 | ```
59 | python main.py --gpu 0 --resume path_to_pth --evaluate # optionally use --save-img to visualize results
60 | ```
61 |
62 | Checkpoint models, training logs, mask ratios and the **single-scale** performance on NYUDv2 are provided as follows:
63 |
64 | | Method | Backbone | Pixel Acc. (%) | Mean Acc. (%) | Mean IoU (%) | Download |
65 | |:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|
66 | |[CEN](https://github.com/yikaiw/CEN)| ResNet101 | 76.2 | 62.8 | 51.1 | [Google Drive](https://drive.google.com/drive/folders/1wim_cBG-HW0bdipwA1UbnGeDwjldPIwV?usp=sharing)|
67 | |[CEN](https://github.com/yikaiw/CEN)| ResNet152 | 77.0 | 64.4 | 51.6 | [Google Drive](https://drive.google.com/drive/folders/1DGF6vHLDgBgLrdUNJOLYdoXCuEKbIuRs?usp=sharing)|
68 | |Ours| SegFormer-B3 | 78.7 | 67.5 | 54.8 | [Google Drive](https://drive.google.com/drive/folders/14fi8aABFYqGF7LYKHkiJazHA58OBW1AW?usp=sharing)|
69 |
70 |
71 | A MindSpore implementation is available at: https://gitee.com/mindspore/models/tree/master/research/cv/TokenFusion
72 |
73 | ## Image-to-Image Translation
74 |
75 | First,
76 | ```
77 | cd image2image_translation
78 | ```
79 | Training script, from Shade and Texture to RGB,
80 | ```
81 | python main.py --gpu 0 -c exp_name
82 | ```
83 | This script will auto-evaluate on the validation dataset every 5 training epochs.
84 |
85 | Predicted images will be automatically saved during training, in the following folder structure:
86 |
87 | ```
88 | code_root/ckpt/exp_name/results
89 | ├── input0 # 1st modality input
90 | ├── input1 # 2nd modality input
91 | ├── fake0 # 1st branch output
92 | ├── fake1 # 2nd branch output
93 | ├── fake2 # ensemble output
94 | ├── best # current best output
95 | │ ├── fake0
96 | │ ├── fake1
97 | │ └── fake2
98 | └── real # ground truth output
99 | ```
100 |
101 | Checkpoint models:
102 |
103 | | Method | Task | FID | KID | Download |
104 | |:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|
105 | | [CEN](https://github.com/yikaiw/CEN) |Texture+Shade->RGB | 62.6 | 1.65 | - |
106 | | Ours | Texture+Shade->RGB | 45.5 | 1.00 | [Google Drive](https://drive.google.com/drive/folders/1vkcDv5bHKXZKxCg4dC7R56ts6nLLt6lh?usp=sharing)|
107 |
108 | ## 3D Object Detection (under construction)
109 |
110 | Data preparation, environments, and training scripts follow [Group-Free](https://github.com/zeliu98/Group-Free-3D) and [ImVoteNet](https://github.com/facebookresearch/imvotenet).
111 |
112 | E.g.,
113 | ```
114 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port 2229 --nproc_per_node 4 train_dist.py --max_epoch 600 --val_freq 25 --save_freq 25 --lr_decay_epochs 420 480 540 --num_point 20000 --num_decoder_layers 6 --size_cls_agnostic --size_delta 0.0625 --heading_delta 0.04 --center_delta 0.1111111111111 --weight_decay 0.00000001 --query_points_generator_loss_coef 0.2 --obj_loss_coef 0.4 --dataset sunrgbd --data_root . --use_img --log_dir log/exp_name
115 | ```
116 |
117 | ## Citation
118 |
119 | If you find our work useful for your research, please consider citing the following paper.
120 | ```
121 | @inproceedings{wang2022tokenfusion,
122 | title={Multimodal Token Fusion for Vision Transformers},
123 | author={Wang, Yikai and Chen, Xinghao and Cao, Lele and Huang, Wenbing and Sun, Fuchun and Wang, Yunhe},
124 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
125 | year={2022}
126 | }
127 | ```
128 |
129 |
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## GeminiFusion for Multimodal Segmentation on the NYUDv2 & SUN RGBD Datasets (ICML 2024)
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | [](https://paperswithcode.com/sota/semantic-segmentation-on-deliver-1?p=geminifusion-efficient-pixel-wise-multimodal)
21 | [](https://paperswithcode.com/sota/semantic-segmentation-on-nyu-depth-v2?p=geminifusion-efficient-pixel-wise-multimodal)
22 | [](https://paperswithcode.com/sota/semantic-segmentation-on-sun-rgbd?p=geminifusion-efficient-pixel-wise-multimodal)
23 |
24 |
25 | This is the official implementation of our paper "[GeminiFusion: Efficient Pixel-wise Multimodal Fusion for Vision Transformer](https://arxiv.org/pdf/2406.01210)".
26 |
27 | Authors: Ding Jia, Jianyuan Guo, Kai Han, Han Wu, Chao Zhang, Chang Xu, Xinghao Chen
28 |
29 |
30 |
31 | ## Code List
32 |
33 | We have applied our GeminiFusion to different tasks and datasets:
34 |
35 | * GeminiFusion for Multimodal Semantic Segmentation
36 | * (This branch) [NYUDv2 & SUN RGBD datasets](https://github.com/JiaDingCN/GeminiFusion/tree/main)
37 | * [DeLiVER dataset](https://github.com/JiaDingCN/GeminiFusion/tree/DeLiVER)
38 | * GeminiFusion for Multimodal 3D Object Detection
39 | * [KITTI dataset](https://github.com/JiaDingCN/GeminiFusion/tree/3d_object_detection_kitti)
40 |
41 |
42 | ## Introduction
43 |
44 | We propose GeminiFusion, a pixel-wise fusion approach that capitalizes on aligned cross-modal representations. GeminiFusion elegantly combines intra-modal and inter-modal attentions, dynamically integrating complementary information across modalities. We employ a layer-adaptive noise to adaptively control their interplay on a per-layer basis, thereby achieving a harmonized fusion process. Notably, GeminiFusion maintains linear complexity with respect to the number of input tokens, ensuring this multimodal framework operates with efficiency comparable to unimodal networks. Comprehensive evaluations demonstrate the superior performance of our GeminiFusion against leading-edge techniques.
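
The fusion module itself lives in [models/mix_transformer.py](models/mix_transformer.py) and [models/swin_transformer.py](models/swin_transformer.py); purely as a reading aid, here is a deliberately simplified, schematic sketch of the pixel-wise idea described above (names, shapes and the noise parameterization are illustrative and do not mirror the actual implementation):

```python
import torch
import torch.nn as nn

class PixelWiseFusionSketch(nn.Module):
    """Schematic only: each token attends to its single aligned counterpart in the
    other modality, so the cost stays linear in the number of tokens."""

    def __init__(self, dim, num_layers):
        super().__init__()
        self.q, self.k, self.v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)
        # Layer-adaptive noise scale (illustrative placeholder).
        self.noise = nn.Parameter(torch.zeros(num_layers, 1, 1, dim))

    def forward(self, x_a, x_b, layer_idx):
        # x_a, x_b: (B, N, C) aligned token sequences from two modalities.
        key = self.k(x_b) + self.noise[layer_idx]          # perturbed cross-modal key
        score = (self.q(x_a) * key).sum(-1, keepdim=True)  # one score per token pair
        return x_a + torch.sigmoid(score) * self.v(x_b)    # inject cross-modal value
```
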
45 |
46 |
47 |
48 | ## Framework
49 | 
50 |
51 |
52 |
53 | ## Model Zoo
54 |
55 | ### NYUDv2 dataset
56 |
57 | | Model | backbone| mIoU | Download |
58 | |:-------:|:--------:|:-------:|:-------------------:|
59 | | GeminiFusion | MiT-B3| 56.8 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/mit-b3.pth.tar) |
60 | | GeminiFusion | MiT-B5| 57.7 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/mit_b5.pth.tar) |
61 | | GeminiFusion | swin_tiny| 52.2 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/swin_tiny.pth.tar) |
62 | | GeminiFusion | swin-small| 55.0 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/swin_small.pth.tar) |
63 | | GeminiFusion | swin-large-224| 58.8 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/swin_large.pth.tar) |
64 | | GeminiFusion | swin-large-384| 60.2 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/swin_large_384.pth.tar) |
65 | | GeminiFusion | swin-large-384 +FineTune from SUN 300eps| 60.9 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/finetune-swin-large-384.pth.tar) |
66 |
67 | ### SUN RGBD dataset
68 |
69 | | Model | backbone| mIoU | Download |
70 | |:-------:|:--------:|:-------:|:-------------------:|
71 | | GeminiFusion | MiT-B3| 52.7 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/SUN_v2/mit-b3.pth.tar) |
72 | | GeminiFusion | MiT-B5| 53.3 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/SUN_v2/mit_b5.pth.tar) |
73 | | GeminiFusion | swin_tiny| 50.2 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/SUN_v2/swin_tiny.pth.tar) |
74 | | GeminiFusion | swin-large-384| 54.8 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/SUN_v2/swin-large-384.pth.tar) |
75 |
76 |
77 |
78 | ## Installation
79 |
80 | We build our GeminiFusion on the TokenFusion codebase, which requires no additional installation steps. If you run into any problem with the framework, you may refer to [the official TokenFusion readme](./README-TokenFusion.md).
81 |
82 | Most of the `GeminiFusion`-related code is located in the following files:
83 | * [models/mix_transformer](models/mix_transformer.py): implements the GeminiFusion module for MiT backbones.
84 | * [models/swin_transformer](models/swin_transformer.py): implements the GeminiFusion module for Swin backbones.
85 | * [mmcv_custom](mmcv_custom): loads checkpoints for Swin backbones.
86 | * [main](main.py): enables the SUN RGBD dataset.
87 | * [utils/datasets](utils/datasets.py): enables the SUN RGBD dataset.
88 |
89 | We also deleted config.py from the TokenFusion codebase since it is not used here.
90 |
91 |
92 |
93 | ## Data
94 |
95 | **NYUDv2 Dataset Preparation**
96 |
97 | Please follow [the data preparation instructions for NYUDv2 in the TokenFusion readme](./README-TokenFusion.md#datasets). By default the data path is `/cache/datasets/nyudv2`; you may change it with `--train-dir`.
98 |
99 | **SUN RGBD Dataset Preparation**
100 |
101 | Please download the SUN RGBD dataset following the link in [DFormer](https://github.com/VCIP-RGBD/DFormer?tab=readme-ov-file#2--get-start). By default the data path is `/cache/datasets/sunrgbd_Dformer/SUNRGBD`; you may change it with `--train-dir`.
102 |
103 |
104 |
105 | ## Train
106 |
107 | **NYUDv2 Training**
108 |
109 | On the NYUDv2 dataset, we follow TokenFusion's setting and use 3 GPUs to train GeminiFusion.
110 |
111 | ```shell
112 | # mit-b3
113 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3
114 |
115 | # mit-b5
116 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone mit_b5 --dataset nyudv2 -c nyudv2_mit_b5 --dpr 0.35
117 |
118 | # swin_tiny
119 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_tiny --dataset nyudv2 -c nyudv2_swin_tiny --dpr 0.2
120 |
121 | # swin_small
122 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_small --dataset nyudv2 -c nyudv2_swin_small
123 |
124 | # swin_large
125 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_large --dataset nyudv2 -c nyudv2_swin_large
126 |
127 | # swin_large_window12
128 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_large_window12 --dataset nyudv2 -c nyudv2_swin_large_window12 --dpr 0.2
129 |
130 | # swin-large-384+FineTune from SUN 300eps
131 | # swin-large-384.pth.tar should be downloaded from our link or trained by yourself
132 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_large_window12 --dataset nyudv2 -c swin_large_window12_finetune_dpr0.15_100+200+100 \
133 | --dpr 0.15 --num-epoch 100 200 100 --is_pretrain_finetune --resume ./swin-large-384.pth.tar
134 | ```
135 |
136 | **SUN RGBD Training**
137 |
138 | On the SUN RGBD dataset, we use 4 GPUs to train GeminiFusion.
139 | ```shell
140 | # mit-b3
141 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --backbone mit_b3 --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_mit_b3
142 |
143 | # mit-b5
144 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --backbone mit_b5 --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_mit_b5 --weight_decay 0.05
145 |
146 | # swin_tiny
147 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --backbone swin_tiny --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_swin_tiny
148 |
149 | # swin_large_window12
150 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --backbone swin_large_window12 --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_swin_large_window12
151 | ```
152 |
153 |
154 |
155 | ## Test
156 |
157 | To evaluate checkpoints, add `--eval --resume <checkpoint_path>` to the training command.
158 |
159 | For example, on the NYUDv2 dataset, the training script for GeminiFusion with mit-b3 backbone is:
160 | ```shell
161 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3
162 | ```
163 |
164 | To evaluate the trained or downloaded checkpoint, the eval script is:
165 | ```shell
166 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3 --eval --resume mit-b3.pth.tar
167 | ```
168 |
169 |
170 |
171 | ## Citation
172 |
173 | If you find this work useful for your research, please cite our paper:
174 |
175 | ```
176 | @misc{jia2024geminifusion,
177 | title={GeminiFusion: Efficient Pixel-wise Multimodal Fusion for Vision Transformer},
178 | author={Ding Jia and Jianyuan Guo and Kai Han and Han Wu and Chao Zhang and Chang Xu and Xinghao Chen},
179 | year={2024},
180 | eprint={2406.01210},
181 | archivePrefix={arXiv},
182 | primaryClass={cs.CV}
183 | }
184 | ```
185 |
186 |
187 |
188 | ## Acknowledgement
189 | Part of our code is based on the open-source project [TokenFusion](https://github.com/yikaiw/TokenFusion).
190 |
--------------------------------------------------------------------------------
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | """RefineNet-LightWeight
2 |
3 | RefineNet-LightWeight PyTorch for non-commercial purposes
4 |
5 | Copyright (c) 2018, Vladimir Nekrasov (vladimir.nekrasov@adelaide.edu.au)
6 | All rights reserved.
7 |
8 | Redistribution and use in source and binary forms, with or without
9 | modification, are permitted provided that the following conditions are met:
10 |
11 | * Redistributions of source code must retain the above copyright notice, this
12 | list of conditions and the following disclaimer.
13 |
14 | * Redistributions in binary form must reproduce the above copyright notice,
15 | this list of conditions and the following disclaimer in the documentation
16 | and/or other materials provided with the distribution.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | """
29 |
30 |
31 | import cv2
32 | import numpy as np
33 | import torch
34 |
35 | # Usual dtypes for common modalities
36 | KEYS_TO_DTYPES = {
37 | "rgb": torch.float,
38 | "depth": torch.float,
39 | "normals": torch.float,
40 | "mask": torch.long,
41 | }
42 |
43 |
44 | class Pad(object):
45 | """Pad image and mask to the desired size.
46 |
47 | Args:
48 | size (int) : minimum length/width.
49 | img_val (array) : image padding value.
50 | msk_val (int) : mask padding value.
51 |
52 | """
53 |
54 | def __init__(self, size, img_val, msk_val):
55 | assert isinstance(size, int)
56 | self.size = size
57 | self.img_val = img_val
58 | self.msk_val = msk_val
59 |
60 | def __call__(self, sample):
61 | image = sample["rgb"]
62 | h, w = image.shape[:2]
63 | h_pad = int(np.clip(((self.size - h) + 1) // 2, 0, 1e6))
64 | w_pad = int(np.clip(((self.size - w) + 1) // 2, 0, 1e6))
65 | pad = ((h_pad, h_pad), (w_pad, w_pad))
66 | for key in sample["inputs"]:
67 | sample[key] = self.transform_input(sample[key], pad)
68 | sample["mask"] = np.pad(
69 | sample["mask"], pad, mode="constant", constant_values=self.msk_val
70 | )
71 | return sample
72 |
73 | def transform_input(self, input, pad):
74 | input = np.stack(
75 | [
76 | np.pad(
77 | input[:, :, c],
78 | pad,
79 | mode="constant",
80 | constant_values=self.img_val[c],
81 | )
82 | for c in range(3)
83 | ],
84 | axis=2,
85 | )
86 | return input
87 |
88 |
89 | class RandomCrop(object):
90 | """Crop randomly the image in a sample.
91 |
92 | Args:
93 | crop_size (int): Desired output size.
94 |
95 | """
96 |
97 | def __init__(self, crop_size):
98 | assert isinstance(crop_size, int)
99 | self.crop_size = crop_size
100 | if self.crop_size % 2 != 0:
101 | self.crop_size -= 1
102 |
103 | def __call__(self, sample):
104 | image = sample["rgb"]
105 | h, w = image.shape[:2]
106 | new_h = min(h, self.crop_size)
107 | new_w = min(w, self.crop_size)
108 | top = np.random.randint(0, h - new_h + 1)
109 | left = np.random.randint(0, w - new_w + 1)
110 | for key in sample["inputs"]:
111 | sample[key] = self.transform_input(sample[key], top, new_h, left, new_w)
112 | sample["mask"] = sample["mask"][top : top + new_h, left : left + new_w]
113 | return sample
114 |
115 | def transform_input(self, input, top, new_h, left, new_w):
116 | input = input[top : top + new_h, left : left + new_w]
117 | return input
118 |
119 |
120 | class ResizeAndScale(object):
121 | """Resize shorter/longer side to a given value and randomly scale.
122 |
123 | Args:
124 | side (int) : shorter / longer side value.
125 | low_scale (float) : lower scaling bound.
126 | high_scale (float) : upper scaling bound.
127 | shorter (bool) : whether to resize shorter / longer side.
128 |
129 | """
130 |
131 | def __init__(self, side, low_scale, high_scale, shorter=True):
132 | assert isinstance(side, int)
133 | assert isinstance(low_scale, float)
134 | assert isinstance(high_scale, float)
135 | self.side = side
136 | self.low_scale = low_scale
137 | self.high_scale = high_scale
138 | self.shorter = shorter
139 |
140 | def __call__(self, sample):
141 | image = sample["rgb"]
142 | scale = np.random.uniform(self.low_scale, self.high_scale)
143 | if self.shorter:
144 | min_side = min(image.shape[:2])
145 | if min_side * scale < self.side:
146 | scale = self.side * 1.0 / min_side
147 | else:
148 | max_side = max(image.shape[:2])
149 | if max_side * scale > self.side:
150 | scale = self.side * 1.0 / max_side
151 | inters = {"rgb": cv2.INTER_CUBIC, "depth": cv2.INTER_NEAREST}
152 | for key in sample["inputs"]:
153 | inter = inters[key] if key in inters else cv2.INTER_CUBIC
154 | sample[key] = self.transform_input(sample[key], scale, inter)
155 | sample["mask"] = cv2.resize(
156 | sample["mask"], None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST
157 | )
158 | return sample
159 |
160 | def transform_input(self, input, scale, inter):
161 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter)
162 | return input
163 |
164 |
165 | class CropAlignToMask(object):
166 | """Crop inputs to the size of the mask."""
167 |
168 | def __call__(self, sample):
169 | mask_h, mask_w = sample["mask"].shape[:2]
170 | for key in sample["inputs"]:
171 | sample[key] = self.transform_input(sample[key], mask_h, mask_w)
172 | return sample
173 |
174 | def transform_input(self, input, mask_h, mask_w):
175 | input_h, input_w = input.shape[:2]
176 | if (input_h, input_w) == (mask_h, mask_w):
177 | return input
178 | h, w = (input_h - mask_h) // 2, (input_w - mask_w) // 2
179 | del_h, del_w = (input_h - mask_h) % 2, (input_w - mask_w) % 2
180 | input = input[h : input_h - h - del_h, w : input_w - w - del_w]
181 | assert input.shape[:2] == (mask_h, mask_w)
182 | return input
183 |
184 |
185 | class ResizeAlignToMask(object):
186 | """Resize inputs to the size of the mask."""
187 |
188 | def __call__(self, sample):
189 | mask_h, mask_w = sample["mask"].shape[:2]
190 | assert mask_h == mask_w
191 | inters = {"rgb": cv2.INTER_CUBIC, "depth": cv2.INTER_NEAREST}
192 | for key in sample["inputs"]:
193 | inter = inters[key] if key in inters else cv2.INTER_CUBIC
194 | sample[key] = self.transform_input(sample[key], mask_h, inter)
195 | return sample
196 |
197 | def transform_input(self, input, mask_h, inter):
198 | input_h, input_w = input.shape[:2]
199 | assert input_h == input_w
200 | scale = mask_h / input_h
201 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter)
202 | return input
203 |
204 |
205 | class ResizeInputs(object):
206 | def __init__(self, size):
207 | self.size = size
208 |
209 | def __call__(self, sample):
210 | # sample['rgb'] = sample['rgb'].numpy()
211 | if self.size is None:
212 | return sample
213 | size = sample["rgb"].shape[0]
214 | scale = self.size / size
215 | # print(sample['rgb'].shape, type(sample['rgb']))
216 | inters = {"rgb": cv2.INTER_CUBIC, "depth": cv2.INTER_NEAREST}
217 | for key in sample["inputs"]:
218 | inter = inters[key] if key in inters else cv2.INTER_CUBIC
219 | sample[key] = self.transform_input(sample[key], scale, inter)
220 | return sample
221 |
222 | def transform_input(self, input, scale, inter):
223 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter)
224 | return input
225 |
226 |
227 | class ResizeInputsScale(object):
228 | def __init__(self, scale):
229 | self.scale = scale
230 |
231 | def __call__(self, sample):
232 | if self.scale is None:
233 | return sample
234 | inters = {"rgb": cv2.INTER_CUBIC, "depth": cv2.INTER_NEAREST}
235 | for key in sample["inputs"]:
236 | inter = inters[key] if key in inters else cv2.INTER_CUBIC
237 | sample[key] = self.transform_input(sample[key], self.scale, inter)
238 | return sample
239 |
240 | def transform_input(self, input, scale, inter):
241 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter)
242 | return input
243 |
244 |
245 | class RandomMirror(object):
246 | """Randomly flip the image and the mask"""
247 |
248 | def __call__(self, sample):
249 | do_mirror = np.random.randint(2)
250 | if do_mirror:
251 | for key in sample["inputs"]:
252 | sample[key] = cv2.flip(sample[key], 1)
253 | sample["mask"] = cv2.flip(sample["mask"], 1)
254 | return sample
255 |
256 |
257 | class Normalise(object):
258 | """Normalise a tensor image with mean and standard deviation.
259 | Given mean: (R, G, B) and std: (R, G, B),
260 | will normalise each channel of the torch.*Tensor, i.e.
261 | channel = (scale * channel - mean) / std
262 |
263 | Args:
264 | scale (float): Scaling constant.
265 | mean (sequence): Sequence of means for R,G,B channels respectively.
266 | std (sequence): Sequence of standard deviations for R,G,B channels
267 | respectively.
268 | depth_scale (float): Depth divisor for depth annotations.
269 |
270 | """
271 |
272 | def __init__(self, scale, mean, std, depth_scale=1.0):
273 | self.scale = scale
274 | self.mean = mean
275 | self.std = std
276 | self.depth_scale = depth_scale
277 |
278 | def __call__(self, sample):
279 | for key in sample["inputs"]:
280 | if key == "depth":
281 | continue
282 | sample[key] = (self.scale * sample[key] - self.mean) / self.std
283 | if "depth" in sample:
284 | # sample['depth'] = self.scale * sample['depth']
285 | # sample['depth'] = (self.scale * sample['depth'] - self.mean) / self.std
286 | if self.depth_scale > 0:
287 | sample["depth"] = self.depth_scale * sample["depth"]
288 | elif self.depth_scale == -1: # taskonomy
289 | # sample['depth'] = np.log(1 + sample['depth']) / np.log(2.** 16.0)
290 | sample["depth"] = np.log(1 + sample["depth"])
291 | elif self.depth_scale == -2: # sunrgbd
292 | depth = sample["depth"]
293 | sample["depth"] = (
294 | (depth - depth.min()) * 255.0 / (depth.max() - depth.min())
295 | )
296 | return sample
297 |
298 |
299 | class ToTensor(object):
300 | """Convert ndarrays in sample to Tensors."""
301 |
302 | def __call__(self, sample):
303 | # swap color axis because
304 | # numpy image: H x W x C
305 | # torch image: C X H X W
306 | for key in ["rgb", "depth"]:
307 | sample[key] = torch.from_numpy(sample[key].transpose((2, 0, 1))).to(
308 | KEYS_TO_DTYPES[key] if key in KEYS_TO_DTYPES else KEYS_TO_DTYPES["rgb"]
309 | )
310 | sample["mask"] = torch.from_numpy(sample["mask"]).to(KEYS_TO_DTYPES["mask"])
311 | return sample
312 |
313 |
314 | def make_list(x):
315 | """Returns the given input as a list."""
316 | if isinstance(x, list):
317 | return x
318 | elif isinstance(x, tuple):
319 | return list(x)
320 | else:
321 | return [x]
322 |
--------------------------------------------------------------------------------
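
Note: a hedged sketch of chaining these transforms on a single sample dict via `torchvision.transforms.Compose`; the crop/pad sizes, padding values and input shapes are illustrative, not the training configuration used by main.py.

```python
import numpy as np
from torchvision.transforms import Compose
from utils import transforms as T

pipeline = Compose([
    T.ResizeAndScale(side=400, low_scale=0.5, high_scale=2.0),
    T.Pad(size=400, img_val=[123.675, 116.28, 103.53], msk_val=255),
    T.RandomMirror(),
    T.RandomCrop(crop_size=400),
    T.Normalise(scale=1.0 / 255,
                mean=np.array([0.485, 0.456, 0.406]),
                std=np.array([0.229, 0.224, 0.225]),
                depth_scale=1.0),
    T.ToTensor(),
])

# Each transform iterates over sample["inputs"]; shapes and values are illustrative.
sample = {
    "rgb": np.random.randint(0, 256, (500, 500, 3)).astype(np.float32),
    "depth": np.random.randint(0, 256, (500, 500, 3)).astype(np.float32),
    "mask": np.random.randint(0, 40, (500, 500)).astype(np.uint8),
    "inputs": ["rgb", "depth"],
}
sample = pipeline(sample)   # "rgb"/"depth" -> CxHxW float tensors, "mask" -> long tensor
```
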
/models/segformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from . import mix_transformer
5 | from mmcv.cnn import ConvModule
6 | from .modules import num_parallel
7 | from .swin_transformer import SwinTransformer
8 |
9 |
10 | class MLP(nn.Module):
11 | """
12 | Linear Embedding
13 | """
14 |
15 | def __init__(self, input_dim=2048, embed_dim=768):
16 | super().__init__()
17 | self.proj = nn.Linear(input_dim, embed_dim)
18 |
19 | def forward(self, x):
20 | x = x.flatten(2).transpose(1, 2)
21 | x = self.proj(x)
22 | return x
23 |
24 |
25 | class SegFormerHead(nn.Module):
26 | """
27 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
28 | """
29 |
30 | def __init__(
31 | self,
32 | feature_strides=None,
33 | in_channels=128,
34 | embedding_dim=256,
35 | num_classes=20,
36 | **kwargs
37 | ):
38 | super(SegFormerHead, self).__init__()
39 | self.in_channels = in_channels
40 | self.num_classes = num_classes
41 | assert len(feature_strides) == len(self.in_channels)
42 | assert min(feature_strides) == feature_strides[0]
43 | self.feature_strides = feature_strides
44 |
45 | (
46 | c1_in_channels,
47 | c2_in_channels,
48 | c3_in_channels,
49 | c4_in_channels,
50 | ) = self.in_channels
51 |
52 | # decoder_params = kwargs['decoder_params']
53 | # embedding_dim = decoder_params['embed_dim']
54 |
55 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
56 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
57 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
58 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
59 | self.dropout = nn.Dropout2d(0.1)
60 |
61 | self.linear_fuse = ConvModule(
62 | in_channels=embedding_dim * 4,
63 | out_channels=embedding_dim,
64 | kernel_size=1,
65 | norm_cfg=dict(type="BN", requires_grad=True),
66 | )
67 |
68 | self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
69 |
70 | def forward(self, x):
71 | c1, c2, c3, c4 = x
72 |
73 | ############## MLP decoder on C1-C4 ###########
74 | n, _, h, w = c4.shape
75 |
76 | _c4 = (
77 | self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
78 | )
79 | _c4 = F.interpolate(
80 | _c4, size=c1.size()[2:], mode="bilinear", align_corners=False
81 | )
82 |
83 | _c3 = (
84 | self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
85 | )
86 | _c3 = F.interpolate(
87 | _c3, size=c1.size()[2:], mode="bilinear", align_corners=False
88 | )
89 |
90 | _c2 = (
91 | self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
92 | )
93 | _c2 = F.interpolate(
94 | _c2, size=c1.size()[2:], mode="bilinear", align_corners=False
95 | )
96 |
97 | _c1 = (
98 | self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])
99 | )
100 |
101 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
102 |
103 | x = self.dropout(_c)
104 | x = self.linear_pred(x)
105 |
106 | return x
107 |
108 |
109 | class TransformerBackbone(nn.Module):
110 | def __init__(
111 | self,
112 | backbone: str,
113 | train_backbone: bool,
114 | return_interm_layers: bool,
115 | drop_path_rate,
116 | pretrained_backbone_path,
117 | ):
118 | super().__init__()
119 | out_indices = (0, 1, 2, 3)
120 | if backbone == "swin_tiny":
121 | backbone = SwinTransformer(
122 | embed_dim=96,
123 | depths=[2, 2, 6, 2],
124 | num_heads=[3, 6, 12, 24],
125 | window_size=7,
126 | ape=False,
127 | drop_path_rate=drop_path_rate,
128 | patch_norm=True,
129 | use_checkpoint=False,
130 | out_indices=out_indices,
131 | )
132 | embed_dim = 96
133 | backbone.init_weights(pretrained_backbone_path)
134 | elif backbone == "swin_small":
135 | backbone = SwinTransformer(
136 | embed_dim=96,
137 | depths=[2, 2, 18, 2],
138 | num_heads=[3, 6, 12, 24],
139 | window_size=7,
140 | ape=False,
141 | drop_path_rate=drop_path_rate,
142 | patch_norm=True,
143 | use_checkpoint=False,
144 | out_indices=out_indices,
145 | )
146 | embed_dim = 96
147 | backbone.init_weights(pretrained_backbone_path)
148 | elif backbone == "swin_large":
149 | backbone = SwinTransformer(
150 | embed_dim=192,
151 | depths=[2, 2, 18, 2],
152 | num_heads=[6, 12, 24, 48],
153 | window_size=7,
154 | ape=False,
155 | drop_path_rate=drop_path_rate,
156 | patch_norm=True,
157 | use_checkpoint=False,
158 | out_indices=out_indices,
159 | )
160 | embed_dim = 192
161 | backbone.init_weights(pretrained_backbone_path)
162 | elif backbone == "swin_large_window12":
163 | backbone = SwinTransformer(
164 | pretrain_img_size=384,
165 | embed_dim=192,
166 | depths=[2, 2, 18, 2],
167 | num_heads=[6, 12, 24, 48],
168 | window_size=12,
169 | ape=False,
170 | drop_path_rate=drop_path_rate,
171 | patch_norm=True,
172 | use_checkpoint=False,
173 | out_indices=out_indices,
174 | )
175 | embed_dim = 192
176 | backbone.init_weights(pretrained_backbone_path)
177 | elif backbone == "swin_large_window12_to_1k":
178 | backbone = SwinTransformer(
179 | pretrain_img_size=384,
180 | embed_dim=192,
181 | depths=[2, 2, 18, 2],
182 | num_heads=[6, 12, 24, 48],
183 | window_size=12,
184 | ape=False,
185 | drop_path_rate=drop_path_rate,
186 | patch_norm=True,
187 | use_checkpoint=False,
188 | out_indices=out_indices,
189 | )
190 | embed_dim = 192
191 | backbone.init_weights(pretrained_backbone_path)
192 | else:
193 | raise NotImplementedError
194 |
195 | for name, parameter in backbone.named_parameters():
196 | # TODO: freeze some layers?
197 | if not train_backbone:
198 | parameter.requires_grad_(False)
199 |
200 | if return_interm_layers:
201 |
202 | self.strides = [8, 16, 32]
203 | self.num_channels = [
204 | embed_dim * 2,
205 | embed_dim * 4,
206 | embed_dim * 8,
207 | ]
208 | else:
209 | self.strides = [32]
210 | self.num_channels = [embed_dim * 8]
211 |
212 | self.body = backbone
213 |
214 | def forward(self, input):
215 | xs = self.body(input)
216 |
217 | return xs
218 |
219 |
220 | class WeTr(nn.Module):
221 | def __init__(
222 | self,
223 | backbone,
224 | num_classes=20,
225 | n_heads=8,
226 | dpr=0.1,
227 | drop_rate=0.0,
228 | ):
229 | super().__init__()
230 | self.num_classes = num_classes
231 | self.embedding_dim = 256
232 | self.feature_strides = [4, 8, 16, 32]
233 | self.num_parallel = num_parallel
234 | self.backbone = backbone
235 |
236 | print("-----------------Model Params--------------------------------------")
237 | print("backbone:", backbone)
238 | print("dpr:", dpr)
239 | print("--------------------------------------------------------------")
240 |
241 | if "swin" in backbone:
242 | if backbone == "swin_tiny":
243 | pretrained_backbone_path = "pretrained/swin_tiny_patch4_window7_224.pth"
244 | self.in_channels = [96, 192, 384, 768]
245 | elif backbone == "swin_small":
246 | pretrained_backbone_path = (
247 | "pretrained/swin_small_patch4_window7_224.pth"
248 | )
249 | self.in_channels = [96, 192, 384, 768]
250 | elif backbone == "swin_large_window12":
251 | pretrained_backbone_path = (
252 | "pretrained/swin_large_patch4_window12_384_22k.pth"
253 | )
254 | self.in_channels = [192, 384, 768, 1536]
255 | elif backbone == "swin_large_window12_to_1k":
256 | pretrained_backbone_path = (
257 | "pretrained/swin_large_patch4_window12_384_22kto1k.pth"
258 | )
259 | self.in_channels = [192, 384, 768, 1536]
260 | else:
261 | assert backbone == "swin_large"
262 | pretrained_backbone_path = (
263 | "pretrained/swin_large_patch4_window7_224_22k.pth"
264 | )
265 | self.in_channels = [192, 384, 768, 1536]
266 | self.encoder = TransformerBackbone(
267 | backbone, True, True, dpr, pretrained_backbone_path
268 | )
269 | else:
270 | self.encoder = getattr(mix_transformer, backbone)(n_heads, dpr, drop_rate)
271 | self.in_channels = self.encoder.embed_dims
272 |             ## initialize encoder
273 | state_dict = torch.load("pretrained/" + backbone + ".pth")
274 | state_dict.pop("head.weight")
275 | state_dict.pop("head.bias")
276 | state_dict = expand_state_dict(
277 | self.encoder.state_dict(), state_dict, self.num_parallel
278 | )
279 | self.encoder.load_state_dict(state_dict, strict=True)
280 |
281 | self.decoder = SegFormerHead(
282 | feature_strides=self.feature_strides,
283 | in_channels=self.in_channels,
284 | embedding_dim=self.embedding_dim,
285 | num_classes=self.num_classes,
286 | )
287 |
288 | self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True))
289 | self.register_parameter("alpha", self.alpha)
290 |
291 | def get_param_groups(self):
292 | param_groups = [[], [], []]
293 | for name, param in list(self.encoder.named_parameters()):
294 | if "norm" in name:
295 | param_groups[1].append(param)
296 | else:
297 | param_groups[0].append(param)
298 | for param in list(self.decoder.parameters()):
299 | param_groups[2].append(param)
300 | return param_groups
301 |
302 | def forward(self, x):
303 |
304 | x = self.encoder(x)
305 |
306 | x = [self.decoder(x[0]), self.decoder(x[1])]
307 | ens = 0
308 |         alpha_soft = F.softmax(self.alpha, dim=0)
309 | for l in range(self.num_parallel):
310 | ens += alpha_soft[l] * x[l].detach()
311 | x.append(ens)
312 | return x, None
313 |
314 |
315 | def expand_state_dict(model_dict, state_dict, num_parallel):
316 | model_dict_keys = model_dict.keys()
317 | state_dict_keys = state_dict.keys()
318 | for model_dict_key in model_dict_keys:
319 | model_dict_key_re = model_dict_key.replace("module.", "")
320 | if model_dict_key_re in state_dict_keys:
321 | model_dict[model_dict_key] = state_dict[model_dict_key_re]
322 | for i in range(num_parallel):
323 | ln = ".ln_%d" % i
324 | replace = True if ln in model_dict_key_re else False
325 | model_dict_key_re = model_dict_key_re.replace(ln, "")
326 | if replace and model_dict_key_re in state_dict_keys:
327 | model_dict[model_dict_key] = state_dict[model_dict_key_re]
328 | return model_dict
329 |
330 |
331 | if __name__ == "__main__":
332 | pretrained_weights = torch.load("pretrained/mit_b1.pth")
333 |     wetr = WeTr("mit_b1", num_classes=20).cuda()
334 |     wetr.get_param_groups()
335 |     dummy_input = torch.rand(2, 3, 512, 512).cuda()
336 |     wetr([dummy_input, dummy_input])  # one tensor per modality, e.g. [rgb, depth]
337 |
--------------------------------------------------------------------------------
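A quick note on `WeTr.forward` above: the model returns a list containing the per-modality predictions plus a softmax(`alpha`)-weighted ensemble of the detached branch logits, together with a second value that is always `None` here (unpacked as `masks` in `main.py`). The sketch below is a minimal usage example under the same assumptions as `main.py`: the `mit_b3` checkpoint has been downloaded to `pretrained/` (the constructor loads it unconditionally) and a CUDA device is available; the shapes are illustrative.

```python
import torch
from models.segformer import WeTr

# Assumes pretrained/mit_b3.pth exists and CUDA is available.
segmenter = WeTr("mit_b3", num_classes=40, n_heads=8, dpr=0.1, drop_rate=0.0).cuda()

rgb = torch.rand(2, 3, 480, 640).cuda()    # batch of RGB images
depth = torch.rand(2, 3, 480, 640).cuda()  # depth encoded as a 3-channel input, as in the dataloader

outputs, _ = segmenter([rgb, depth])       # one tensor per modality
# outputs[0], outputs[1]: per-branch logits at 1/4 of the input resolution,
# outputs[2]: softmax(alpha)-weighted ensemble of the detached branch logits.
for out in outputs:
    print(out.shape)                       # e.g. torch.Size([2, 40, 120, 160])
```

As in `train()` in `main.py`, the logits are upsampled to the label resolution before the loss is computed.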
/utils/augmentations_mm.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms.functional as TF
2 | import random
3 | import math
4 | import torch
5 | from torch import Tensor
6 | from typing import Tuple, List, Union, Optional
7 |
8 |
9 | class Compose:
10 | def __init__(self, transforms: list) -> None:
11 | self.transforms = transforms
12 |
13 | def __call__(self, sample: list) -> list:
14 | img, mask = sample["rgb"], sample["mask"]
15 | if mask.ndim == 2:
16 | assert img.shape[1:] == mask.shape
17 | else:
18 | assert img.shape[1:] == mask.shape[1:]
19 |
20 | for transform in self.transforms:
21 | sample = transform(sample)
22 |
23 | return sample
24 |
25 |
26 | class Normalize:
27 | def __init__(
28 | self, mean: list = (0.485, 0.456, 0.406), std: list = (0.229, 0.224, 0.225)
29 | ):
30 | self.mean = mean
31 | self.std = std
32 |
33 | def __call__(self, sample: list) -> list:
34 | for k, v in sample.items():
35 | if k == "mask":
36 | continue
37 | elif k == "rgb":
38 | sample[k] = sample[k].float()
39 | sample[k] /= 255
40 | sample[k] = TF.normalize(sample[k], self.mean, self.std)
41 | else:
42 | sample[k] = sample[k].float()
43 | sample[k] /= 255
44 |
45 | return sample
46 |
47 |
48 | class RandomColorJitter:
49 | def __init__(self, p=0.5) -> None:
50 | self.p = p
51 |
52 | def __call__(self, sample: list) -> list:
53 | if random.random() < self.p:
54 | self.brightness = random.uniform(0.5, 1.5)
55 | sample["rgb"] = TF.adjust_brightness(sample["rgb"], self.brightness)
56 | self.contrast = random.uniform(0.5, 1.5)
57 | sample["rgb"] = TF.adjust_contrast(sample["rgb"], self.contrast)
58 | self.saturation = random.uniform(0.5, 1.5)
59 | sample["rgb"] = TF.adjust_saturation(sample["rgb"], self.saturation)
60 | return sample
61 |
62 |
63 | class AdjustGamma:
64 | def __init__(self, gamma: float, gain: float = 1) -> None:
65 | """
66 | Args:
67 |             gamma: Non-negative real number. A gamma larger than 1 makes the shadows darker, while a gamma smaller than 1 makes dark regions lighter.
68 | gain: constant multiplier
69 | """
70 | self.gamma = gamma
71 | self.gain = gain
72 |
73 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
74 | return TF.adjust_gamma(img, self.gamma, self.gain), mask
75 |
76 |
77 | class RandomAdjustSharpness:
78 | def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
79 | self.sharpness = sharpness_factor
80 | self.p = p
81 |
82 | def __call__(self, sample: list) -> list:
83 | if random.random() < self.p:
84 | sample["rgb"] = TF.adjust_sharpness(sample["rgb"], self.sharpness)
85 | return sample
86 |
87 |
88 | class RandomAutoContrast:
89 | def __init__(self, p: float = 0.5) -> None:
90 | self.p = p
91 |
92 | def __call__(self, sample: list) -> list:
93 | if random.random() < self.p:
94 | sample["rgb"] = TF.autocontrast(sample["rgb"])
95 | return sample
96 |
97 |
98 | class RandomGaussianBlur:
99 | def __init__(self, kernel_size: int = 3, p: float = 0.5) -> None:
100 | self.kernel_size = kernel_size
101 | self.p = p
102 |
103 | def __call__(self, sample: list) -> list:
104 | if random.random() < self.p:
105 | sample["rgb"] = TF.gaussian_blur(sample["rgb"], self.kernel_size)
106 | # img = TF.gaussian_blur(img, self.kernel_size)
107 | return sample
108 |
109 |
110 | class RandomHorizontalFlip:
111 | def __init__(self, p: float = 0.5) -> None:
112 | self.p = p
113 |
114 | def __call__(self, sample: list) -> list:
115 | if random.random() < self.p:
116 | for k, v in sample.items():
117 | sample[k] = TF.hflip(v)
118 | return sample
119 | return sample
120 |
121 |
122 | class RandomVerticalFlip:
123 | def __init__(self, p: float = 0.5) -> None:
124 | self.p = p
125 |
126 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
127 | if random.random() < self.p:
128 | return TF.vflip(img), TF.vflip(mask)
129 | return img, mask
130 |
131 |
132 | class RandomGrayscale:
133 | def __init__(self, p: float = 0.5) -> None:
134 | self.p = p
135 |
136 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
137 | if random.random() < self.p:
138 | img = TF.rgb_to_grayscale(img, 3)
139 | return img, mask
140 |
141 |
142 | class Equalize:
143 | def __call__(self, image, label):
144 | return TF.equalize(image), label
145 |
146 |
147 | class Posterize:
148 | def __init__(self, bits=2):
149 | self.bits = bits # 0-8
150 |
151 | def __call__(self, image, label):
152 | return TF.posterize(image, self.bits), label
153 |
154 |
155 | class Affine:
156 | def __init__(self, angle=0, translate=[0, 0], scale=1.0, shear=[0, 0], seg_fill=0):
157 | self.angle = angle
158 | self.translate = translate
159 | self.scale = scale
160 | self.shear = shear
161 | self.seg_fill = seg_fill
162 |
163 | def __call__(self, img, label):
164 | return TF.affine(
165 | img,
166 | self.angle,
167 | self.translate,
168 | self.scale,
169 | self.shear,
170 | TF.InterpolationMode.BILINEAR,
171 | 0,
172 | ), TF.affine(
173 | label,
174 | self.angle,
175 | self.translate,
176 | self.scale,
177 | self.shear,
178 | TF.InterpolationMode.NEAREST,
179 | self.seg_fill,
180 | )
181 |
182 |
183 | class RandomRotation:
184 | def __init__(
185 | self,
186 | degrees: float = 10.0,
187 | p: float = 0.2,
188 | seg_fill: int = 0,
189 | expand: bool = False,
190 | ) -> None:
191 | """Rotate the image by a random angle between -angle and angle with probability p
192 |
193 | Args:
194 | p: probability
195 | angle: rotation angle value in degrees, counter-clockwise.
196 | expand: Optional expansion flag.
197 | If true, expands the output image to make it large enough to hold the entire rotated image.
198 | If false or omitted, make the output image the same size as the input image.
199 | Note that the expand flag assumes rotation around the center and no translation.
200 | """
201 | self.p = p
202 | self.angle = degrees
203 | self.expand = expand
204 | self.seg_fill = seg_fill
205 |
206 | def __call__(self, sample: list) -> list:
207 | random_angle = random.random() * 2 * self.angle - self.angle
208 | if random.random() < self.p:
209 | for k, v in sample.items():
210 | if k == "mask":
211 | sample[k] = TF.rotate(
212 | v,
213 | random_angle,
214 | TF.InterpolationMode.NEAREST,
215 | self.expand,
216 | fill=self.seg_fill,
217 | )
218 | else:
219 | sample[k] = TF.rotate(
220 | v,
221 | random_angle,
222 | TF.InterpolationMode.BILINEAR,
223 | self.expand,
224 | fill=0,
225 | )
226 | # img = TF.rotate(img, random_angle, TF.InterpolationMode.BILINEAR, self.expand, fill=0)
227 | # mask = TF.rotate(mask, random_angle, TF.InterpolationMode.NEAREST, self.expand, fill=self.seg_fill)
228 | return sample
229 |
230 |
231 | class CenterCrop:
232 | def __init__(self, size: Union[int, List[int], Tuple[int]]) -> None:
233 | """Crops the image at the center
234 |
235 | Args:
236 | output_size: height and width of the crop box. If int, this size is used for both directions.
237 | """
238 | self.size = (size, size) if isinstance(size, int) else size
239 |
240 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
241 | return TF.center_crop(img, self.size), TF.center_crop(mask, self.size)
242 |
243 |
244 | class RandomCrop:
245 | def __init__(self, size: Union[int, List[int], Tuple[int]], p: float = 0.5) -> None:
246 | """Randomly Crops the image.
247 |
248 | Args:
249 | output_size: height and width of the crop box. If int, this size is used for both directions.
250 | """
251 | self.size = (size, size) if isinstance(size, int) else size
252 | self.p = p
253 |
254 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
255 | H, W = img.shape[1:]
256 | tH, tW = self.size
257 |
258 | if random.random() < self.p:
259 | margin_h = max(H - tH, 0)
260 | margin_w = max(W - tW, 0)
261 |             y1 = random.randint(0, margin_h)  # randint is inclusive at both ends
262 |             x1 = random.randint(0, margin_w)
263 | y2 = y1 + tH
264 | x2 = x1 + tW
265 | img = img[:, y1:y2, x1:x2]
266 | mask = mask[:, y1:y2, x1:x2]
267 | return img, mask
268 |
269 |
270 | class Pad:
271 | def __init__(
272 | self, size: Union[List[int], Tuple[int], int], seg_fill: int = 0
273 | ) -> None:
274 | """Pad the given image on all sides with the given "pad" value.
275 | Args:
276 | size: expected output image size (h, w)
277 | fill: Pixel fill value for constant fill. Default is 0. This value is only used when the padding mode is constant.
278 | """
279 | self.size = size
280 | self.seg_fill = seg_fill
281 |
282 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
283 | padding = (0, 0, self.size[1] - img.shape[2], self.size[0] - img.shape[1])
284 | return TF.pad(img, padding), TF.pad(mask, padding, self.seg_fill)
285 |
286 |
287 | class ResizePad:
288 | def __init__(
289 | self, size: Union[int, Tuple[int], List[int]], seg_fill: int = 0
290 | ) -> None:
291 | """Resize the input image to the given size.
292 | Args:
293 | size: Desired output size.
294 | If size is a sequence, the output size will be matched to this.
295 | If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio.
296 | """
297 | self.size = size
298 | self.seg_fill = seg_fill
299 |
300 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
301 | H, W = img.shape[1:]
302 | tH, tW = self.size
303 |
304 | # scale the image
305 | scale_factor = min(tH / H, tW / W) if W > H else max(tH / H, tW / W)
306 | # nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5)
307 | nH, nW = round(H * scale_factor), round(W * scale_factor)
308 | img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)
309 | mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)
310 |
311 | # pad the image
312 | padding = [0, 0, tW - nW, tH - nH]
313 | img = TF.pad(img, padding, fill=0)
314 | mask = TF.pad(mask, padding, fill=self.seg_fill)
315 | return img, mask
316 |
317 |
318 | class Resize:
319 | def __init__(self, size: Union[int, Tuple[int], List[int]]) -> None:
320 | """Resize the input image to the given size.
321 | Args:
322 | size: Desired output size.
323 | If size is a sequence, the output size will be matched to this.
324 | If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio.
325 | """
326 | self.size = size
327 |
328 | def __call__(self, sample: list) -> list:
329 | H, W = sample["rgb"].shape[1:]
330 |
331 | # scale the image
332 | scale_factor = self.size[0] / min(H, W)
333 | nH, nW = round(H * scale_factor), round(W * scale_factor)
334 | for k, v in sample.items():
335 | if k == "mask":
336 | sample[k] = TF.resize(
337 | v.unsqueeze(0), (nH, nW), TF.InterpolationMode.NEAREST
338 | ).squeeze(0)
339 | else:
340 | sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.BILINEAR)
341 | # img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)
342 | # mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)
343 |
344 | # make the image divisible by stride
345 | alignH, alignW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32
346 |
347 | for k, v in sample.items():
348 | if k == "mask":
349 | sample[k] = TF.resize(
350 | v.unsqueeze(0), (alignH, alignW), TF.InterpolationMode.NEAREST
351 | ).squeeze(0)
352 | else:
353 | sample[k] = TF.resize(
354 | v, (alignH, alignW), TF.InterpolationMode.BILINEAR
355 | )
356 | # img = TF.resize(img, (alignH, alignW), TF.InterpolationMode.BILINEAR)
357 | # mask = TF.resize(mask, (alignH, alignW), TF.InterpolationMode.NEAREST)
358 | return sample
359 |
360 |
361 | class RandomResizedCrop:
362 | def __init__(
363 | self,
364 | size: Union[int, Tuple[int], List[int]],
365 | scale: Tuple[float, float] = (0.5, 2.0),
366 | seg_fill: int = 0,
367 | ) -> None:
368 | """Resize the input image to the given size."""
369 | self.size = size
370 | self.scale = scale
371 | self.seg_fill = seg_fill
372 |
373 | def __call__(self, sample: list) -> list:
374 | # img, mask = sample['rgb'], sample['mask']
375 | H, W = sample["rgb"].shape[1:]
376 | tH, tW = self.size
377 |
378 | # get the scale
379 | ratio = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]
380 | # ratio = random.uniform(min(self.scale), max(self.scale))
381 | scale = int(tH * ratio), int(tW * 4 * ratio)
382 | # scale the image
383 | scale_factor = min(max(scale) / max(H, W), min(scale) / min(H, W))
384 | nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5)
385 | # nH, nW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32
386 | for k, v in sample.items():
387 | if k == "mask":
388 | sample[k] = TF.resize(
389 | v.unsqueeze(0),
390 | (nH, nW),
391 | TF.InterpolationMode.NEAREST,
392 | ).squeeze(0)
393 | else:
394 | sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.BILINEAR)
395 |
396 | # random crop
397 | margin_h = max(sample["rgb"].shape[1] - tH, 0)
398 | margin_w = max(sample["rgb"].shape[2] - tW, 0)
399 |         y1 = random.randint(0, margin_h)  # randint is inclusive at both ends
400 |         x1 = random.randint(0, margin_w)
401 | y2 = y1 + tH
402 | x2 = x1 + tW
403 | for k, v in sample.items():
404 | # print("before_1:", k, sample[k].shape)
405 | if len(v.shape) == 3:
406 | sample[k] = v[:, y1:y2, x1:x2]
407 | else:
408 | sample[k] = v[y1:y2, x1:x2]
409 | # print("after_1:", k, sample[k].shape)
410 |
411 | # pad the image
412 | if sample["rgb"].shape[1:] != self.size:
413 | padding = [
414 | 0,
415 | 0,
416 | tW - sample["rgb"].shape[2],
417 | tH - sample["rgb"].shape[1],
418 | ]
419 | for k, v in sample.items():
420 | if k == "mask":
421 | sample[k] = TF.pad(v, padding, fill=self.seg_fill)
422 | else:
423 | sample[k] = TF.pad(v, padding, fill=0)
424 |
425 | return sample
426 |
427 |
428 | def get_train_augmentation(size: Union[int, Tuple[int], List[int]], seg_fill: int = 0):
429 | return Compose(
430 | [
431 | RandomColorJitter(p=0.2), #
432 | RandomHorizontalFlip(p=0.5), #
433 | RandomGaussianBlur((3, 3), p=0.2), #
434 | RandomResizedCrop(size, scale=(0.5, 2.0), seg_fill=seg_fill), #
435 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
436 | ]
437 | )
438 |
439 |
440 | def get_val_augmentation(size: Union[int, Tuple[int], List[int]]):
441 | return Compose(
442 | [Resize(size), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
443 | )
444 |
445 |
446 | if __name__ == "__main__":
447 | h = 230
448 | w = 420
449 | sample = {}
450 | sample["rgb"] = torch.randn(3, h, w)
451 | sample["depth"] = torch.randn(3, h, w)
452 | sample["lidar"] = torch.randn(3, h, w)
453 | sample["event"] = torch.randn(3, h, w)
454 | sample["mask"] = torch.randn(1, h, w)
455 | aug = Compose(
456 | [
457 | RandomHorizontalFlip(p=0.5),
458 | RandomResizedCrop((512, 512)),
459 | Resize((224, 224)),
460 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
461 | ]
462 | )
463 | sample = aug(sample)
464 | for k, v in sample.items():
465 | print(k, v.shape)
466 |
--------------------------------------------------------------------------------
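The transforms in `utils/augmentations_mm.py` differ from plain torchvision transforms in one important way: each callable receives the whole sample dict, so a single random draw (flip, crop, rotation, resize) is applied identically to RGB, depth, and mask, keeping the modalities pixel-aligned. A minimal sketch of that behaviour, using toy tensors and `p=1.0` so the flip always fires:

```python
import torch
from utils.augmentations_mm import RandomHorizontalFlip

sample = {
    "rgb": torch.arange(12.0).reshape(3, 2, 2),
    "depth": torch.arange(12.0).reshape(3, 2, 2),
    "mask": torch.zeros(1, 2, 2),
}

flip = RandomHorizontalFlip(p=1.0)  # force the flip for the demonstration
out = flip(sample)

# All keys are flipped with the same random decision, so rgb and depth
# (identical before the flip) are still identical afterwards.
assert torch.equal(out["rgb"], out["depth"])
```

Note that some classes in this file (`RandomVerticalFlip`, `RandomGrayscale`, `CenterCrop`, `Pad`, `ResizePad`, and a few others) still use the older `(img, mask)` calling convention and are not wired into `get_train_augmentation` / `get_val_augmentation`.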
/mmcv_custom/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 | import io
3 | import os
4 | import os.path as osp
5 | import pkgutil
6 | import time
7 | import warnings
8 | from collections import OrderedDict
9 | from importlib import import_module
10 | from tempfile import TemporaryDirectory
11 |
12 | import torch
13 | import torchvision
14 | from torch.optim import Optimizer
15 | from torch.utils import model_zoo
16 | from torch.nn import functional as F
17 |
18 | import mmcv
19 | from mmcv.fileio import FileClient
20 | from mmcv.fileio import load as load_file
21 | from mmcv.parallel import is_module_wrapper
22 | from mmcv.utils import mkdir_or_exist
23 | from mmcv.runner import get_dist_info
24 |
25 | ENV_MMCV_HOME = "MMCV_HOME"
26 | ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
27 | DEFAULT_CACHE_DIR = "~/.cache"
28 |
29 |
30 | def _get_mmcv_home():
31 | mmcv_home = os.path.expanduser(
32 | os.getenv(
33 | ENV_MMCV_HOME,
34 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "mmcv"),
35 | )
36 | )
37 |
38 | mkdir_or_exist(mmcv_home)
39 | return mmcv_home
40 |
41 |
42 | def load_state_dict(module, state_dict, strict=False, logger=None):
43 | """Load state_dict to a module.
44 |
45 | This method is modified from :meth:`torch.nn.Module.load_state_dict`.
46 | Default value for ``strict`` is set to ``False`` and the message for
47 | param mismatch will be shown even if strict is False.
48 |
49 | Args:
50 | module (Module): Module that receives the state_dict.
51 | state_dict (OrderedDict): Weights.
52 | strict (bool): whether to strictly enforce that the keys
53 | in :attr:`state_dict` match the keys returned by this module's
54 | :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
55 | logger (:obj:`logging.Logger`, optional): Logger to log the error
56 | message. If not specified, print function will be used.
57 | """
58 | unexpected_keys = []
59 | all_missing_keys = []
60 | err_msg = []
61 |
62 | metadata = getattr(state_dict, "_metadata", None)
63 | state_dict = state_dict.copy()
64 | if metadata is not None:
65 | state_dict._metadata = metadata
66 |
67 | # use _load_from_state_dict to enable checkpoint version control
68 | def load(module, prefix=""):
69 | # recursively check parallel module in case that the model has a
70 | # complicated structure, e.g., nn.Module(nn.Module(DDP))
71 | if is_module_wrapper(module):
72 | module = module.module
73 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
74 | module._load_from_state_dict(
75 | state_dict,
76 | prefix,
77 | local_metadata,
78 | True,
79 | all_missing_keys,
80 | unexpected_keys,
81 | err_msg,
82 | )
83 | for name, child in module._modules.items():
84 | if child is not None:
85 | load(child, prefix + name + ".")
86 |
87 | load(module)
88 | load = None # break load->load reference cycle
89 |
90 | # ignore "num_batches_tracked" of BN layers
91 | missing_keys = [key for key in all_missing_keys if "num_batches_tracked" not in key]
92 |
93 | if unexpected_keys:
94 | err_msg.append(
95 | "unexpected key in source " f'state_dict: {", ".join(unexpected_keys)}\n'
96 | )
97 | if missing_keys:
98 | err_msg.append(
99 | f'missing keys in source state_dict: {", ".join(missing_keys)}\n'
100 | )
101 |
102 | rank, _ = get_dist_info()
103 | if len(err_msg) > 0 and rank == 0:
104 | err_msg.insert(0, "The model and loaded state dict do not match exactly\n")
105 | err_msg = "\n".join(err_msg)
106 | if strict:
107 | raise RuntimeError(err_msg)
108 | elif logger is not None:
109 | logger.warning(err_msg)
110 | else:
111 | print(err_msg)
112 |
113 |
114 | def load_url_dist(url, model_dir=None):
115 |     """In a distributed setting, this function only downloads the checkpoint on
116 |     local rank 0."""
117 | rank, world_size = get_dist_info()
118 | rank = int(os.environ.get("LOCAL_RANK", rank))
119 | if rank == 0:
120 | checkpoint = model_zoo.load_url(url, model_dir=model_dir)
121 | if world_size > 1:
122 | torch.distributed.barrier()
123 | if rank > 0:
124 | checkpoint = model_zoo.load_url(url, model_dir=model_dir)
125 | return checkpoint
126 |
127 |
128 | def load_pavimodel_dist(model_path, map_location=None):
129 |     """In a distributed setting, this function only downloads the checkpoint on
130 |     local rank 0."""
131 | try:
132 | from pavi import modelcloud
133 | except ImportError:
134 | raise ImportError("Please install pavi to load checkpoint from modelcloud.")
135 | rank, world_size = get_dist_info()
136 | rank = int(os.environ.get("LOCAL_RANK", rank))
137 | if rank == 0:
138 | model = modelcloud.get(model_path)
139 | with TemporaryDirectory() as tmp_dir:
140 | downloaded_file = osp.join(tmp_dir, model.name)
141 | model.download(downloaded_file)
142 | checkpoint = torch.load(downloaded_file, map_location=map_location)
143 | if world_size > 1:
144 | torch.distributed.barrier()
145 | if rank > 0:
146 | model = modelcloud.get(model_path)
147 | with TemporaryDirectory() as tmp_dir:
148 | downloaded_file = osp.join(tmp_dir, model.name)
149 | model.download(downloaded_file)
150 | checkpoint = torch.load(downloaded_file, map_location=map_location)
151 | return checkpoint
152 |
153 |
154 | def load_fileclient_dist(filename, backend, map_location):
155 |     """In a distributed setting, this function only downloads the checkpoint on
156 |     local rank 0."""
157 | rank, world_size = get_dist_info()
158 | rank = int(os.environ.get("LOCAL_RANK", rank))
159 | allowed_backends = ["ceph"]
160 | if backend not in allowed_backends:
161 | raise ValueError(f"Load from Backend {backend} is not supported.")
162 | if rank == 0:
163 | fileclient = FileClient(backend=backend)
164 | buffer = io.BytesIO(fileclient.get(filename))
165 | checkpoint = torch.load(buffer, map_location=map_location)
166 | if world_size > 1:
167 | torch.distributed.barrier()
168 | if rank > 0:
169 | fileclient = FileClient(backend=backend)
170 | buffer = io.BytesIO(fileclient.get(filename))
171 | checkpoint = torch.load(buffer, map_location=map_location)
172 | return checkpoint
173 |
174 |
175 | def get_torchvision_models():
176 | model_urls = dict()
177 | for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
178 | if ispkg:
179 | continue
180 | _zoo = import_module(f"torchvision.models.{name}")
181 | if hasattr(_zoo, "model_urls"):
182 | _urls = getattr(_zoo, "model_urls")
183 | model_urls.update(_urls)
184 | return model_urls
185 |
186 |
187 | def get_external_models():
188 | mmcv_home = _get_mmcv_home()
189 | default_json_path = osp.join(mmcv.__path__[0], "model_zoo/open_mmlab.json")
190 | default_urls = load_file(default_json_path)
191 | assert isinstance(default_urls, dict)
192 | external_json_path = osp.join(mmcv_home, "open_mmlab.json")
193 | if osp.exists(external_json_path):
194 | external_urls = load_file(external_json_path)
195 | assert isinstance(external_urls, dict)
196 | default_urls.update(external_urls)
197 |
198 | return default_urls
199 |
200 |
201 | def get_mmcls_models():
202 | mmcls_json_path = osp.join(mmcv.__path__[0], "model_zoo/mmcls.json")
203 | mmcls_urls = load_file(mmcls_json_path)
204 |
205 | return mmcls_urls
206 |
207 |
208 | def get_deprecated_model_names():
209 | deprecate_json_path = osp.join(mmcv.__path__[0], "model_zoo/deprecated.json")
210 | deprecate_urls = load_file(deprecate_json_path)
211 | assert isinstance(deprecate_urls, dict)
212 |
213 | return deprecate_urls
214 |
215 |
216 | def _process_mmcls_checkpoint(checkpoint):
217 | state_dict = checkpoint["state_dict"]
218 | new_state_dict = OrderedDict()
219 | for k, v in state_dict.items():
220 | if k.startswith("backbone."):
221 | new_state_dict[k[9:]] = v
222 | new_checkpoint = dict(state_dict=new_state_dict)
223 |
224 | return new_checkpoint
225 |
226 |
227 | def _load_checkpoint(filename, map_location=None):
228 | """Load checkpoint from somewhere (modelzoo, file, url).
229 |
230 | Args:
231 | filename (str): Accept local filepath, URL, ``torchvision://xxx``,
232 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
233 | details.
234 | map_location (str | None): Same as :func:`torch.load`. Default: None.
235 |
236 | Returns:
237 | dict | OrderedDict: The loaded checkpoint. It can be either an
238 | OrderedDict storing model weights or a dict containing other
239 | information, which depends on the checkpoint.
240 | """
241 | if filename.startswith("modelzoo://"):
242 | warnings.warn(
243 | 'The URL scheme of "modelzoo://" is deprecated, please '
244 | 'use "torchvision://" instead'
245 | )
246 | model_urls = get_torchvision_models()
247 | model_name = filename[11:]
248 | checkpoint = load_url_dist(model_urls[model_name])
249 | elif filename.startswith("torchvision://"):
250 | model_urls = get_torchvision_models()
251 | model_name = filename[14:]
252 | checkpoint = load_url_dist(model_urls[model_name])
253 | elif filename.startswith("open-mmlab://"):
254 | model_urls = get_external_models()
255 | model_name = filename[13:]
256 | deprecated_urls = get_deprecated_model_names()
257 | if model_name in deprecated_urls:
258 | warnings.warn(
259 | f"open-mmlab://{model_name} is deprecated in favor "
260 | f"of open-mmlab://{deprecated_urls[model_name]}"
261 | )
262 | model_name = deprecated_urls[model_name]
263 | model_url = model_urls[model_name]
264 | # check if is url
265 | if model_url.startswith(("http://", "https://")):
266 | checkpoint = load_url_dist(model_url)
267 | else:
268 | filename = osp.join(_get_mmcv_home(), model_url)
269 | if not osp.isfile(filename):
270 | raise IOError(f"{filename} is not a checkpoint file")
271 | checkpoint = torch.load(filename, map_location=map_location)
272 | elif filename.startswith("mmcls://"):
273 | model_urls = get_mmcls_models()
274 | model_name = filename[8:]
275 | checkpoint = load_url_dist(model_urls[model_name])
276 | checkpoint = _process_mmcls_checkpoint(checkpoint)
277 | elif filename.startswith(("http://", "https://")):
278 | checkpoint = load_url_dist(filename)
279 | elif filename.startswith("pavi://"):
280 | model_path = filename[7:]
281 | checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
282 | elif filename.startswith("s3://"):
283 | checkpoint = load_fileclient_dist(
284 | filename, backend="ceph", map_location=map_location
285 | )
286 | else:
287 | if not osp.isfile(filename):
288 | raise IOError(f"{filename} is not a checkpoint file")
289 | checkpoint = torch.load(filename, map_location=map_location)
290 | return checkpoint
291 |
292 |
293 | def load_checkpoint(model, filename, map_location="cpu", strict=False, logger=None):
294 | """Load checkpoint from a file or URI.
295 |
296 | Args:
297 | model (Module): Module to load checkpoint.
298 | filename (str): Accept local filepath, URL, ``torchvision://xxx``,
299 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
300 | details.
301 | map_location (str): Same as :func:`torch.load`.
302 | strict (bool): Whether to allow different params for the model and
303 | checkpoint.
304 | logger (:mod:`logging.Logger` or None): The logger for error message.
305 |
306 | Returns:
307 | dict or OrderedDict: The loaded checkpoint.
308 | """
309 | checkpoint = _load_checkpoint(filename, map_location)
310 | # OrderedDict is a subclass of dict
311 | if not isinstance(checkpoint, dict):
312 | raise RuntimeError(f"No state_dict found in checkpoint file {filename}")
313 | # get state_dict from checkpoint
314 | if "state_dict" in checkpoint:
315 | state_dict = checkpoint["state_dict"]
316 | elif "model" in checkpoint:
317 | state_dict = checkpoint["model"]
318 | else:
319 | state_dict = checkpoint
320 | # strip prefix of state_dict
321 | if list(state_dict.keys())[0].startswith("module."):
322 | state_dict = {k[7:]: v for k, v in state_dict.items()}
323 |
324 | # for MoBY, load model of online branch
325 | if sorted(list(state_dict.keys()))[0].startswith("encoder"):
326 | state_dict = {
327 | k.replace("encoder.", ""): v
328 | for k, v in state_dict.items()
329 | if k.startswith("encoder.")
330 | }
331 |
332 | # reshape absolute position embedding
333 | if state_dict.get("absolute_pos_embed") is not None:
334 | absolute_pos_embed = state_dict["absolute_pos_embed"]
335 | N1, L, C1 = absolute_pos_embed.size()
336 | N2, C2, H, W = model.absolute_pos_embed.size()
337 | if N1 != N2 or C1 != C2 or L != H * W:
338 | logger.warning("Error in loading absolute_pos_embed, pass")
339 | else:
340 | state_dict["absolute_pos_embed"] = absolute_pos_embed.view(
341 | N2, H, W, C2
342 | ).permute(0, 3, 1, 2)
343 |
344 | # interpolate position bias table if needed
345 | relative_position_bias_table_keys = [
346 | k for k in state_dict.keys() if "relative_position_bias_table" in k
347 | ]
348 | for table_key in relative_position_bias_table_keys:
349 | table_pretrained = state_dict[table_key]
350 |
351 | new_table_key = ".".join(
352 | table_key.split(".")[:-1] + ["module"] + [table_key.split(".")[-1]]
353 | )
354 |
355 | table_current = model.state_dict()[new_table_key]
356 | L1, nH1 = table_pretrained.size()
357 | L2, nH2 = table_current.size()
358 | if nH1 != nH2:
359 | logger.warning(f"Error in loading {table_key}, pass")
360 | else:
361 | if L1 != L2:
362 | S1 = int(L1**0.5)
363 | S2 = int(L2**0.5)
364 | table_pretrained_resized = F.interpolate(
365 | table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
366 | size=(S2, S2),
367 | mode="bicubic",
368 | )
369 | state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(
370 | 1, 0
371 | )
372 |
373 | new_state_dict = dict()
374 | for key, value in state_dict.items():
375 | new_key = ".".join(key.split(".")[:-1] + ["module"] + [key.split(".")[-1]])
376 | new_key_2 = ".".join(key.split(".")[:-2] + ["module"] + key.split(".")[-2:])
377 | new_state_dict[new_key] = value
378 | new_state_dict[new_key_2] = value
379 | new_state_dict[key] = value
380 | # load state_dict
381 | load_state_dict(model, new_state_dict, strict, logger)
382 | return checkpoint
383 |
384 |
385 | def weights_to_cpu(state_dict):
386 | """Copy a model state_dict to cpu.
387 |
388 | Args:
389 | state_dict (OrderedDict): Model weights on GPU.
390 |
391 | Returns:
392 |         OrderedDict: Model weights on CPU.
393 | """
394 | state_dict_cpu = OrderedDict()
395 | for key, val in state_dict.items():
396 | state_dict_cpu[key] = val.cpu()
397 | return state_dict_cpu
398 |
399 |
400 | def _save_to_state_dict(module, destination, prefix, keep_vars):
401 | """Saves module state to `destination` dictionary.
402 |
403 | This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
404 |
405 | Args:
406 | module (nn.Module): The module to generate state_dict.
407 | destination (dict): A dict where state will be stored.
408 | prefix (str): The prefix for parameters and buffers used in this
409 | module.
410 | """
411 | for name, param in module._parameters.items():
412 | if param is not None:
413 | destination[prefix + name] = param if keep_vars else param.detach()
414 | for name, buf in module._buffers.items():
415 | # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
416 | if buf is not None:
417 | destination[prefix + name] = buf if keep_vars else buf.detach()
418 |
419 |
420 | def get_state_dict(module, destination=None, prefix="", keep_vars=False):
421 | """Returns a dictionary containing a whole state of the module.
422 |
423 | Both parameters and persistent buffers (e.g. running averages) are
424 | included. Keys are corresponding parameter and buffer names.
425 |
426 | This method is modified from :meth:`torch.nn.Module.state_dict` to
427 | recursively check parallel module in case that the model has a complicated
428 | structure, e.g., nn.Module(nn.Module(DDP)).
429 |
430 | Args:
431 | module (nn.Module): The module to generate state_dict.
432 | destination (OrderedDict): Returned dict for the state of the
433 | module.
434 | prefix (str): Prefix of the key.
435 | keep_vars (bool): Whether to keep the variable property of the
436 | parameters. Default: False.
437 |
438 | Returns:
439 | dict: A dictionary containing a whole state of the module.
440 | """
441 | # recursively check parallel module in case that the model has a
442 | # complicated structure, e.g., nn.Module(nn.Module(DDP))
443 | if is_module_wrapper(module):
444 | module = module.module
445 |
446 | # below is the same as torch.nn.Module.state_dict()
447 | if destination is None:
448 | destination = OrderedDict()
449 | destination._metadata = OrderedDict()
450 | destination._metadata[prefix[:-1]] = local_metadata = dict(version=module._version)
451 | _save_to_state_dict(module, destination, prefix, keep_vars)
452 | for name, child in module._modules.items():
453 | if child is not None:
454 | get_state_dict(child, destination, prefix + name + ".", keep_vars=keep_vars)
455 | for hook in module._state_dict_hooks.values():
456 | hook_result = hook(module, destination, prefix, local_metadata)
457 | if hook_result is not None:
458 | destination = hook_result
459 | return destination
460 |
461 |
462 | def save_checkpoint(model, filename, optimizer=None, meta=None):
463 | """Save checkpoint to file.
464 |
465 | The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
466 | ``optimizer``. By default ``meta`` will contain version and time info.
467 |
468 | Args:
469 | model (Module): Module whose params are to be saved.
470 | filename (str): Checkpoint filename.
471 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
472 | meta (dict, optional): Metadata to be saved in checkpoint.
473 | """
474 | if meta is None:
475 | meta = {}
476 | elif not isinstance(meta, dict):
477 | raise TypeError(f"meta must be a dict or None, but got {type(meta)}")
478 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
479 |
480 | if is_module_wrapper(model):
481 | model = model.module
482 |
483 | if hasattr(model, "CLASSES") and model.CLASSES is not None:
484 | # save class name to the meta
485 | meta.update(CLASSES=model.CLASSES)
486 |
487 | checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))}
488 | # save optimizer state dict in the checkpoint
489 | if isinstance(optimizer, Optimizer):
490 | checkpoint["optimizer"] = optimizer.state_dict()
491 | elif isinstance(optimizer, dict):
492 | checkpoint["optimizer"] = {}
493 | for name, optim in optimizer.items():
494 | checkpoint["optimizer"][name] = optim.state_dict()
495 |
496 | if filename.startswith("pavi://"):
497 | try:
498 | from pavi import modelcloud
499 | from pavi.exception import NodeNotFoundError
500 | except ImportError:
501 | raise ImportError("Please install pavi to load checkpoint from modelcloud.")
502 | model_path = filename[7:]
503 | root = modelcloud.Folder()
504 | model_dir, model_name = osp.split(model_path)
505 | try:
506 | model = modelcloud.get(model_dir)
507 | except NodeNotFoundError:
508 | model = root.create_training_model(model_dir)
509 | with TemporaryDirectory() as tmp_dir:
510 | checkpoint_file = osp.join(tmp_dir, model_name)
511 | with open(checkpoint_file, "wb") as f:
512 | torch.save(checkpoint, f)
513 | f.flush()
514 | model.create_file(checkpoint_file, name=model_name)
515 | else:
516 | mmcv.mkdir_or_exist(osp.dirname(filename))
517 | # immediately flush buffer
518 | with open(filename, "wb") as f:
519 | torch.save(checkpoint, f)
520 | f.flush()
521 |
--------------------------------------------------------------------------------
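`mmcv_custom/checkpoint.py` is mostly the stock mmcv loader; the repo-specific part is `load_checkpoint`, which reshapes `absolute_pos_embed`, bicubically resizes `relative_position_bias_table` entries when the pretrained window size differs, and additionally duplicates keys under `.module.`-expanded names, apparently so the wrapped sub-layers of the parallel backbone can pick up the same Swin weights. The lower-level helpers can be exercised on their own; below is a minimal save/load round-trip sketch with a stand-in module (the path `ckpt/tmp_demo.pth` is only for illustration, and mmcv must be installed):

```python
import torch
import torch.nn as nn
from mmcv_custom.checkpoint import save_checkpoint, _load_checkpoint, load_state_dict

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

# The checkpoint carries three fields: meta, state_dict and optimizer.
save_checkpoint(model, "ckpt/tmp_demo.pth", optimizer=optim, meta={"epoch": 1})

ckpt = _load_checkpoint("ckpt/tmp_demo.pth", map_location="cpu")
print(sorted(ckpt.keys()))  # ['meta', 'optimizer', 'state_dict']

# Non-strict load: missing/unexpected keys are reported instead of raising.
load_state_dict(model, ckpt["state_dict"], strict=False)
```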
/main.py:
--------------------------------------------------------------------------------
1 | # general libs
2 | import os, sys, argparse
3 | import random, time
4 | import warnings
5 |
6 | warnings.filterwarnings("ignore")
7 | import cv2
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | import datetime
12 |
13 |
14 | from utils import *
15 | import utils.helpers as helpers
16 | from utils.optimizer import PolyWarmupAdamW
17 | from models.segformer import WeTr
18 | from torch import distributed as dist
19 | from torch.utils.data.distributed import DistributedSampler
20 | from tqdm import tqdm
21 | from utils.augmentations_mm import *
22 | from torch.nn.parallel import DistributedDataParallel as DDP
23 |
24 |
25 | def setup_ddp():
26 | # print(os.environ.keys())
27 |     if "SLURM_PROCID" in os.environ and "RANK" not in os.environ:
28 | # --- multi nodes
29 | world_size = int(os.environ["WORLD_SIZE"])
30 | rank = int(os.environ["SLURM_PROCID"])
31 | gpus_per_node = int(os.environ["SLURM_GPUS_ON_NODE"])
32 | gpu = rank - gpus_per_node * (rank // gpus_per_node)
33 | torch.cuda.set_device(gpu)
34 | dist.init_process_group(
35 | backend="nccl",
36 | world_size=world_size,
37 | rank=rank,
38 | timeout=datetime.timedelta(seconds=7200),
39 | )
40 | elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
41 | rank = int(os.environ["RANK"])
42 | world_size = int(os.environ["WORLD_SIZE"])
43 | # ---
44 | gpu = int(os.environ["LOCAL_RANK"])
45 | torch.cuda.set_device(gpu)
46 | dist.init_process_group(
47 | "nccl",
48 | init_method="env://",
49 | world_size=world_size,
50 | rank=rank,
51 | timeout=datetime.timedelta(seconds=7200),
52 | )
53 | dist.barrier()
54 | else:
55 | gpu = 0
56 | return gpu
57 |
58 |
59 | def cleanup_ddp():
60 | if dist.is_initialized():
61 | dist.destroy_process_group()
62 |
63 |
64 | def get_arguments():
65 | """Parse all the arguments provided from the CLI.
66 |
67 | Returns:
68 |       The parsed command-line arguments (an argparse.Namespace).
69 | """
70 | parser = argparse.ArgumentParser(description="Full Pipeline Training")
71 |
72 | # Dataset
73 | parser.add_argument(
74 | "--dataset",
75 | type=str,
76 | default="nyudv2",
77 | help="Name of the dataset.",
78 | )
79 | parser.add_argument(
80 | "--train-dir",
81 | type=str,
82 | default="/cache/datasets/nyudv2",
83 | help="Path to the training set directory.",
84 | )
85 | parser.add_argument(
86 | "--batch-size",
87 | type=int,
88 | default=2,
89 | help="Batch size to train the segmenter model.",
90 | )
91 | parser.add_argument(
92 | "--num-workers",
93 | type=int,
94 | default=16,
95 | help="Number of workers for pytorch's dataloader.",
96 | )
97 | parser.add_argument(
98 | "--ignore-label",
99 | type=int,
100 | default=255,
101 | help="Label to ignore during training",
102 | )
103 |
104 | # General
105 | parser.add_argument("--name", default="", type=str, help="model name")
106 | parser.add_argument(
107 | "--evaluate",
108 | action="store_true",
109 | default=False,
110 | help="If true, only validate segmentation.",
111 | )
112 | parser.add_argument(
113 | "--freeze-bn",
114 | type=bool,
115 | nargs="+",
116 | default=True,
117 | help="Whether to keep batch norm statistics intact.",
118 | )
119 | parser.add_argument(
120 | "--num-epoch",
121 | type=int,
122 | nargs="+",
123 | default=[100] * 3,
124 | help="Number of epochs to train for segmentation network.",
125 | )
126 | parser.add_argument(
127 | "--random-seed",
128 | type=int,
129 | default=42,
130 | help="Seed to provide (near-)reproducibility.",
131 | )
132 | parser.add_argument(
133 | "-c",
134 | "--ckpt",
135 | default="model",
136 | type=str,
137 | metavar="PATH",
138 | help="path to save checkpoint (default: model)",
139 | )
140 | parser.add_argument(
141 | "--resume",
142 | default="",
143 | type=str,
144 | metavar="PATH",
145 | help="path to latest checkpoint (default: none)",
146 | )
147 | parser.add_argument(
148 | "--val-every",
149 | type=int,
150 | default=5,
151 | help="How often to validate current architecture.",
152 | )
153 | parser.add_argument(
154 | "--print-network",
155 | action="store_true",
156 | default=False,
157 |         help="Whether to print network parameters.",
158 | )
159 | parser.add_argument(
160 | "--print-loss",
161 | action="store_true",
162 | default=False,
163 |         help="Whether to print losses during training.",
164 | )
165 | parser.add_argument(
166 | "--save-image",
167 | type=int,
168 | default=100,
169 | help="Number to save images during evaluating, -1 to save all.",
170 | )
171 | parser.add_argument(
172 | "-i",
173 | "--input",
174 | default=["rgb", "depth"],
175 | type=str,
176 | nargs="+",
177 |         help="input modalities, e.g. rgb depth",
178 | )
179 |
180 | # Optimisers
181 | parser.add_argument("--backbone", default="mit_b3", type=str)
182 | parser.add_argument("--n_heads", default=8, type=int)
183 | parser.add_argument("--drop_rate", default=0.0, type=float)
184 | parser.add_argument("--dpr", default=0.4, type=float)
185 |
186 | parser.add_argument("--weight_decay", default=0.01, type=float)
187 | parser.add_argument("--lr_0", default=6e-5, type=float)
188 | parser.add_argument("--lr_1", default=3e-5, type=float)
189 | parser.add_argument("--lr_2", default=1.5e-5, type=float)
190 | parser.add_argument("--is_pretrain_finetune", action="store_true")
191 |
192 | return parser.parse_args()
193 |
194 |
195 | def create_segmenter(num_classes, gpu, backbone, n_heads, dpr, drop_rate):
196 | segmenter = WeTr(backbone, num_classes, n_heads, dpr, drop_rate)
197 | param_groups = segmenter.get_param_groups()
198 | assert torch.cuda.is_available()
199 | segmenter.to("cuda:" + str(gpu))
200 | return segmenter, param_groups
201 |
202 |
203 | def create_loaders(
204 | dataset,
205 | train_dir,
206 | val_dir,
207 | train_list,
208 | val_list,
209 | batch_size,
210 | num_workers,
211 | ignore_label,
212 | ):
213 | """
214 | Args:
215 | train_dir (str) : path to the root directory of the training set.
216 | val_dir (str) : path to the root directory of the validation set.
217 | train_list (str) : path to the training list.
218 | val_list (str) : path to the validation list.
219 | batch_size (int) : training batch size.
220 | num_workers (int) : number of workers to parallelise data loading operations.
221 | ignore_label (int) : label to pad segmentation masks with
222 |
223 | Returns:
224 |         train_loader, val_loader, train_sampler
225 |
226 | """
227 | # Torch libraries
228 | from torchvision import transforms
229 | from torch.utils.data import DataLoader
230 |
231 | # Custom libraries
232 | from utils.datasets import SegDataset as Dataset
233 | from utils.transforms import ToTensor
234 |
235 | input_names, input_mask_idxs = ["rgb", "depth"], [0, 2, 1]
236 |
237 | if dataset == "nyudv2":
238 | input_scale = [480, 640]
239 | elif dataset == "sunrgbd":
240 | input_scale = [480, 480]
241 |
242 | composed_trn = transforms.Compose(
243 | [
244 | ToTensor(),
245 | RandomColorJitter(p=0.2), #
246 | RandomHorizontalFlip(p=0.5), #
247 | RandomGaussianBlur((3, 3), p=0.2), #
248 | RandomResizedCrop(input_scale, scale=(0.5, 2.0), seg_fill=255), #
249 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
250 | ]
251 | )
252 |
253 | composed_val = transforms.Compose(
254 | [
255 | ToTensor(),
256 | Resize(input_scale),
257 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
258 | ]
259 | )
260 | # Training and validation sets
261 | trainset = Dataset(
262 | dataset=dataset,
263 | data_file=train_list,
264 | data_dir=train_dir,
265 | input_names=input_names,
266 | input_mask_idxs=input_mask_idxs,
267 | transform_trn=composed_trn,
268 | transform_val=composed_val,
269 | stage="train",
270 | ignore_label=ignore_label,
271 | )
272 |
273 | validset = Dataset(
274 | dataset=dataset,
275 | data_file=val_list,
276 | data_dir=val_dir,
277 | input_names=input_names,
278 | input_mask_idxs=input_mask_idxs,
279 | transform_trn=None,
280 | transform_val=composed_val,
281 | stage="val",
282 | ignore_label=ignore_label,
283 | )
284 | print_log(
285 | "Created train set {} examples, val set {} examples".format(
286 | len(trainset), len(validset)
287 | )
288 | )
289 | train_sampler = DistributedSampler(
290 | trainset, dist.get_world_size(), dist.get_rank(), shuffle=True
291 | )
292 |
293 | # Training and validation loaders
294 | train_loader = DataLoader(
295 | trainset,
296 | batch_size=batch_size,
297 | num_workers=num_workers,
298 | pin_memory=True,
299 | drop_last=True,
300 | sampler=train_sampler,
301 | )
302 | val_loader = DataLoader(
303 | validset, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True
304 | )
305 |
306 | return train_loader, val_loader, train_sampler
307 |
308 |
309 | def load_ckpt(ckpt_path, ckpt_dict, is_pretrain_finetune=False):
310 | print("----------------")
311 | ckpt = torch.load(ckpt_path, map_location="cpu")
312 | new_segmenter_ckpt = dict()
313 | if is_pretrain_finetune:
314 | for ckpt_k, ckpt_v in ckpt["segmenter"].items():
315 | if "linear_pred" in ckpt_k:
316 | print(ckpt_k, " is Excluded!")
317 | else:
318 | if "module." in ckpt_k:
319 | new_segmenter_ckpt[ckpt_k[7:]] = ckpt_v
320 | else:
321 | for ckpt_k, ckpt_v in ckpt["segmenter"].items():
322 | new_segmenter_ckpt[ckpt_k] = ckpt_v
323 | if "module." in ckpt_k:
324 | new_segmenter_ckpt[ckpt_k[7:]] = ckpt_v
325 | ckpt["segmenter"] = new_segmenter_ckpt
326 |
327 | for k, v in ckpt_dict.items():
328 | if k in ckpt:
329 | v.load_state_dict(ckpt[k], strict=False)
330 | else:
331 |             print(k, "is missing from the checkpoint!")
332 | best_val = ckpt.get("best_val", 0)
333 | epoch_start = ckpt.get("epoch_start", 0)
334 | if is_pretrain_finetune:
335 | print_log(
336 | "Found [Pretrain] checkpoint at {} with best_val {:.4f} at epoch {}".format(
337 | ckpt_path, best_val, epoch_start
338 | )
339 | )
340 | return 0, 0
341 | else:
342 |
343 | print_log(
344 | "Found checkpoint at {} with best_val {:.4f} at epoch {}".format(
345 | ckpt_path, best_val, epoch_start
346 | )
347 | )
348 | return best_val, epoch_start
349 |
350 |
351 | def train(
352 | segmenter,
353 | input_types,
354 | train_loader,
355 | optimizer,
356 | epoch,
357 | segm_crit,
358 | freeze_bn,
359 | print_loss=False,
360 | ):
361 | """Training segmenter
362 |
363 | Args:
364 | segmenter (nn.Module) : segmentation network
365 | train_loader (DataLoader) : training data iterator
366 |         optimizer (optim) : optimiser for all parameter groups
367 |         print_loss (bool) : whether to print the loss at every step
368 | epoch (int) : current epoch
369 | segm_crit (nn.Loss) : segmentation criterion
370 | freeze_bn (bool) : whether to keep BN params intact
371 |
372 | """
373 | train_loader.dataset.set_stage("train")
374 | segmenter.train()
375 | if freeze_bn:
376 | for module in segmenter.modules():
377 | if isinstance(module, nn.BatchNorm2d):
378 | module.eval()
379 | batch_time = AverageMeter()
380 | losses = AverageMeter()
381 |
382 | for i, sample in tqdm(enumerate(train_loader), total=len(train_loader)):
383 | start = time.time()
384 | inputs = [sample[key].cuda().float() for key in input_types]
385 | target = sample["mask"].cuda().long()
386 | # Compute outputs
387 | outputs, masks = segmenter(inputs)
388 | loss = 0
389 | for output in outputs:
390 | output = nn.functional.interpolate(
391 | output, size=target.size()[1:], mode="bilinear", align_corners=False
392 | )
393 |             soft_output = nn.LogSoftmax(dim=1)(output)
394 | # Compute loss and backpropagate
395 | loss += segm_crit(soft_output, target)
396 |
397 | optimizer.zero_grad()
398 | loss.backward()
399 | if print_loss:
400 | print("step: %-3d: loss=%.2f" % (i, loss), flush=True)
401 | optimizer.step()
402 | losses.update(loss.item())
403 | batch_time.update(time.time() - start)
404 |
405 |
406 | def validate(
407 | segmenter, input_types, val_loader, epoch, save_dir, num_classes=-1, save_image=0
408 | ):
409 | """Validate segmenter
410 |
411 | Args:
412 | segmenter (nn.Module) : segmentation network
413 |         val_loader (DataLoader) : validation data iterator
414 | epoch (int) : current epoch
415 | num_classes (int) : number of classes to consider
416 |
417 | Returns:
418 | Mean IoU (float)
419 | """
420 | global best_iou
421 | val_loader.dataset.set_stage("val")
422 | segmenter.eval()
423 | conf_mat = []
424 | for _ in range(len(input_types) + 1):
425 | conf_mat.append(np.zeros((num_classes, num_classes), dtype=int))
426 | with torch.no_grad():
427 | all_times = 0
428 | count = 0
429 | for i, sample in enumerate(val_loader):
430 | inputs = [sample[key].float().cuda() for key in input_types]
431 | target = sample["mask"]
432 | gt = target[0].data.cpu().numpy().astype(np.uint8)
433 | gt_idx = (
434 | gt < num_classes
435 | ) # Ignore every class index larger than the number of classes
436 |
437 | """from fvcore.nn import FlopCountAnalysis, parameter_count_table
438 |
439 | flops = FlopCountAnalysis(segmenter, inputs)
440 | print("FLOPs: ", flops.total())
441 | print(parameter_count_table(segmenter))
442 | exit()"""
443 |
444 | start_time = time.time()
445 |
446 | outputs, _ = segmenter(inputs)
447 |
448 | end_time = time.time()
449 | all_times += end_time - start_time
450 |
451 | for idx, output in enumerate(outputs):
452 | output = (
453 | cv2.resize(
454 | output[0, :num_classes].data.cpu().numpy().transpose(1, 2, 0),
455 | target.size()[1:][::-1],
456 | interpolation=cv2.INTER_CUBIC,
457 | )
458 | .argmax(axis=2)
459 | .astype(np.uint8)
460 | )
461 | # Compute IoU
462 | conf_mat[idx] += confusion_matrix(
463 | gt[gt_idx], output[gt_idx], num_classes
464 | )
465 | if i < save_image or save_image == -1:
466 | img = make_validation_img(
467 | inputs[0].data.cpu().numpy(),
468 | inputs[1].data.cpu().numpy(),
469 | sample["mask"].data.cpu().numpy(),
470 | output[np.newaxis, :],
471 | )
472 | imgs_folder = os.path.join(save_dir, "imgs")
473 | os.makedirs(imgs_folder, exist_ok=True)
474 | cv2.imwrite(
475 | os.path.join(imgs_folder, "validate_" + str(i) + ".png"),
476 | img[:, :, ::-1],
477 | )
478 | print("imwrite at imgs/validate_%d.png" % i)
479 | count += 1
480 | latency = all_times / count
481 | print("all_times:", all_times, " count:", count, " latency:", latency)
482 |
483 | for idx, input_type in enumerate(input_types + ["ens"]):
484 | glob, mean, iou = getScores(conf_mat[idx])
485 | best_iou_note = ""
486 | if iou > best_iou:
487 | best_iou = iou
488 | best_iou_note = " (best)"
489 | alpha = " "
490 |
491 | input_type_str = "(%s)" % input_type
492 | print_log(
493 | "Epoch %-4d %-7s glob_acc=%-5.2f mean_acc=%-5.2f IoU=%-5.2f%s%s"
494 | % (epoch, input_type_str, glob, mean, iou, alpha, best_iou_note)
495 | )
496 | print_log("")
497 | return iou
498 |
499 |
500 | def main():
501 | global args, best_iou
502 | best_iou = 0
503 | args = get_arguments()
504 | args.val_dir = args.train_dir
505 |
506 | if args.dataset == "nyudv2":
507 | args.train_list = "data/nyudv2/train.txt"
508 | args.val_list = "data/nyudv2/val.txt"
509 | args.num_classes = 40
510 | elif args.dataset == "sunrgbd":
511 | args.train_list = "data/sun/train.txt"
512 | args.val_list = "data/sun/test.txt"
513 | args.num_classes = 37
514 |
515 | args.num_stages = 3
516 | gpu = setup_ddp()
517 | ckpt_dir = os.path.join("ckpt", args.ckpt)
518 | os.makedirs(ckpt_dir, exist_ok=True)
519 | os.system("cp -r *py models utils data %s" % ckpt_dir)
520 | helpers.logger = open(os.path.join(ckpt_dir, "log.txt"), "w+")
521 | print_log(" ".join(sys.argv))
522 |
523 | # Set random seeds
524 | torch.backends.cudnn.deterministic = True
525 | torch.manual_seed(args.random_seed)
526 | if torch.cuda.is_available():
527 | torch.cuda.manual_seed_all(args.random_seed)
528 | np.random.seed(args.random_seed)
529 | random.seed(args.random_seed)
530 | # Generate Segmenter
531 | segmenter, param_groups = create_segmenter(
532 | args.num_classes,
533 | gpu,
534 | args.backbone,
535 | args.n_heads,
536 | args.dpr,
537 | args.drop_rate,
538 | )
539 |
540 | print_log(
541 | "Loaded Segmenter {}, #PARAMS={:3.2f}M".format(
542 | args.backbone, compute_params(segmenter) / 1e6
543 | )
544 | )
545 | # Restore if any
546 | best_val, epoch_start = 0, 0
547 | if args.resume:
548 | if os.path.isfile(args.resume):
549 | best_val, epoch_start = load_ckpt(
550 | args.resume,
551 | {"segmenter": segmenter},
552 | is_pretrain_finetune=args.is_pretrain_finetune,
553 | )
554 | else:
555 | print_log("=> no checkpoint found at '{}'".format(args.resume))
556 | return
557 | no_ddp_segmenter = segmenter
558 | segmenter = DDP(
559 | segmenter, device_ids=[gpu], output_device=0, find_unused_parameters=False
560 | )
561 |
562 | epoch_current = epoch_start
563 | # Criterion
564 | segm_crit = nn.NLLLoss(ignore_index=args.ignore_label).cuda()
565 | # Saver
566 | saver = Saver(
567 | args=vars(args),
568 | ckpt_dir=ckpt_dir,
569 | best_val=best_val,
570 | condition=lambda x, y: x > y,
571 | ) # keep checkpoint with the best validation score
572 |
573 | lrs = [args.lr_0, args.lr_1, args.lr_2]
574 |
575 | print("-------------------------Optimizer Params--------------------")
576 | print("weight_decay:", args.weight_decay)
577 | print("lrs:", lrs)
578 | print("----------------------------------------------------------------")
579 |
580 | for task_idx in range(args.num_stages):
581 | optimizer = PolyWarmupAdamW(
582 |             # param groups: [0] encoder weights, [1] encoder norms (no weight decay), [2] decoder (10x lr)
583 | params=[
584 | {
585 | "params": param_groups[0],
586 | "lr": lrs[task_idx],
587 | "weight_decay": args.weight_decay,
588 | },
589 | {
590 | "params": param_groups[1],
591 | "lr": lrs[task_idx],
592 | "weight_decay": 0.0,
593 | },
594 | {
595 | "params": param_groups[2],
596 | "lr": lrs[task_idx] * 10,
597 | "weight_decay": args.weight_decay,
598 | },
599 | ],
600 | lr=lrs[task_idx],
601 | weight_decay=args.weight_decay,
602 | betas=[0.9, 0.999],
603 | warmup_iter=1500,
604 | max_iter=40000,
605 | warmup_ratio=1e-6,
606 | power=1.0,
607 | )
608 | total_epoch = sum([args.num_epoch[idx] for idx in range(task_idx + 1)])
609 | if epoch_start >= total_epoch:
610 | continue
611 | start = time.time()
612 | torch.cuda.empty_cache()
613 | # Create dataloaders
614 | train_loader, val_loader, train_sampler = create_loaders(
615 | args.dataset,
616 | args.train_dir,
617 | args.val_dir,
618 | args.train_list,
619 | args.val_list,
620 | args.batch_size,
621 | args.num_workers,
622 | args.ignore_label,
623 | )
624 | if args.evaluate:
625 | return validate(
626 | no_ddp_segmenter,
627 | args.input,
628 | val_loader,
629 | 0,
630 | ckpt_dir,
631 | num_classes=args.num_classes,
632 | save_image=args.save_image,
633 | )
634 |
635 |         # Train this stage
636 | print_log("Training Stage {}".format(str(task_idx)))
637 |
638 | for epoch in range(min(args.num_epoch[task_idx], total_epoch - epoch_start)):
639 | train_sampler.set_epoch(epoch)
640 | train(
641 | segmenter,
642 | args.input,
643 | train_loader,
644 | optimizer,
645 | epoch_current,
646 | segm_crit,
647 | args.freeze_bn,
648 | args.print_loss,
649 | )
650 | if (epoch + 1) % (args.val_every) == 0:
651 | miou = validate(
652 | no_ddp_segmenter,
653 | args.input,
654 | val_loader,
655 | epoch_current,
656 | ckpt_dir,
657 | args.num_classes,
658 | )
659 | saver.save(
660 | miou,
661 | {"segmenter": segmenter.state_dict(), "epoch_start": epoch_current},
662 | )
663 | epoch_current += 1
664 |
665 | print_log(
666 | "Stage {} finished, time spent {:.3f}min\n".format(
667 | task_idx, (time.time() - start) / 60.0
668 | )
669 | )
670 |
671 | print_log("All stages are now finished. Best Val is {:.3f}".format(saver.best_val))
672 | helpers.logger.close()
673 | cleanup_ddp()
674 |
675 |
676 | if __name__ == "__main__":
677 | main()
678 |
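# PolyWarmupAdamW (utils/optimizer.py) is constructed in main() with
# warmup_iter, warmup_ratio, max_iter and power. A sketch of the learning-rate
# rule those arguments usually describe (an assumption about the scheduler,
# included only to make the hyper-parameters concrete):
#
#   def _poly_warmup_lr(base_lr, it, warmup_iter=1500, warmup_ratio=1e-6,
#                       max_iter=40000, power=1.0):
#       if it < warmup_iter:
#           # linear ramp from base_lr * warmup_ratio up to base_lr
#           k = (1 - it / warmup_iter) * (1 - warmup_ratio)
#           return base_lr * (1 - k)
#       # polynomial decay towards zero afterwards
#       return base_lr * (1 - it / max_iter) ** power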
--------------------------------------------------------------------------------
/models/mix_transformer.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # ---------------------------------------------------------------
6 | import math
7 | import torch
8 | import torch.nn as nn
9 |
10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
11 | from .modules import ModuleParallel, LayerNormParallel
12 |
13 |
14 | class Mlp(nn.Module):
15 | def __init__(
16 | self,
17 | in_features,
18 | hidden_features=None,
19 | out_features=None,
20 | act_layer=nn.GELU,
21 | drop=0.0,
22 | ):
23 | super().__init__()
24 | out_features = out_features or in_features
25 | hidden_features = hidden_features or in_features
26 | self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features))
27 | self.dwconv = DWConv(hidden_features)
28 | self.act = ModuleParallel(act_layer())
29 | self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features))
30 | self.drop = ModuleParallel(nn.Dropout(drop))
31 |
32 | self.apply(self._init_weights)
33 |
34 | def _init_weights(self, m):
35 | if isinstance(m, nn.Linear):
36 | trunc_normal_(m.weight, std=0.02)
37 | if isinstance(m, nn.Linear) and m.bias is not None:
38 | nn.init.constant_(m.bias, 0)
39 | elif isinstance(m, nn.LayerNorm):
40 | nn.init.constant_(m.bias, 0)
41 | nn.init.constant_(m.weight, 1.0)
42 | elif isinstance(m, nn.Conv2d):
43 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
44 | fan_out //= m.groups
45 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
46 | if m.bias is not None:
47 | m.bias.data.zero_()
48 |
49 | def forward(self, x, H, W):
50 | x = self.fc1(x)
51 | x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)]
52 | x = self.act(x)
53 | x = self.drop(x)
54 | x = self.fc2(x)
55 | x = self.drop(x)
56 | return x
57 |
58 |
59 | class Mlp_2(nn.Module):
60 | """Multilayer perceptron."""
61 |
62 | def __init__(
63 | self,
64 | in_features,
65 | hidden_features=None,
66 | out_features=None,
67 | act_layer=nn.GELU,
68 | drop=0.0,
69 | ):
70 | super().__init__()
71 | out_features = out_features or in_features
72 | hidden_features = hidden_features or in_features
73 | self.fc1 = nn.Linear(in_features, hidden_features)
74 | self.act = act_layer()
75 | self.fc2 = nn.Linear(hidden_features, out_features)
76 | self.drop = nn.Dropout(drop)
77 |
78 | def forward(self, x):
79 | x = self.fc1(x)
80 | x = self.act(x)
81 | x = self.drop(x)
82 | x = self.fc2(x)
83 | x = self.drop(x)
84 | return x
85 |
86 |
87 | class Attention(nn.Module):
88 | def __init__(
89 | self,
90 | dim,
91 | num_heads=8,
92 | qkv_bias=False,
93 | qk_scale=None,
94 | attn_drop=0.0,
95 | proj_drop=0.0,
96 | sr_ratio=1,
97 | n_heads=8,
98 | ):
99 | super().__init__()
100 | assert (
101 | dim % num_heads == 0
102 | ), f"dim {dim} should be divided by num_heads {num_heads}."
103 |
104 | self.dim = dim
105 | self.num_heads = num_heads
106 | head_dim = dim // num_heads
107 | self.scale = qk_scale or head_dim**-0.5
108 |
109 | self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias))
110 | self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias))
111 | self.attn_drop = ModuleParallel(nn.Dropout(attn_drop))
112 | self.proj = ModuleParallel(nn.Linear(dim, dim))
113 | self.proj_drop = ModuleParallel(nn.Dropout(proj_drop))
114 |
115 | self.sr_ratio = sr_ratio
116 | if sr_ratio > 1:
117 | self.sr = ModuleParallel(
118 | nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
119 | )
120 | self.norm = LayerNormParallel(dim)
121 |
122 | self.cross_heads = n_heads
123 | self.cross_attn_0_to_1 = nn.MultiheadAttention(
124 | dim, self.cross_heads, dropout=0.0, batch_first=False
125 | )
126 | self.cross_attn_1_to_0 = nn.MultiheadAttention(
127 | dim, self.cross_heads, dropout=0.0, batch_first=False
128 | )
129 |
130 | self.relation_judger = nn.Sequential(
131 | Mlp_2(dim * 2, dim, dim), torch.nn.Softmax(dim=-1)
132 | )
133 |
134 | self.k_noise = nn.Embedding(2, dim)
135 | self.v_noise = nn.Embedding(2, dim)
136 |
137 | self.apply(self._init_weights)
138 |
139 | def _init_weights(self, m):
140 | if isinstance(m, nn.Linear):
141 | trunc_normal_(m.weight, std=0.02)
142 | if isinstance(m, nn.Linear) and m.bias is not None:
143 | nn.init.constant_(m.bias, 0)
144 | elif isinstance(m, nn.LayerNorm):
145 | nn.init.constant_(m.bias, 0)
146 | nn.init.constant_(m.weight, 1.0)
147 | elif isinstance(m, nn.Conv2d):
148 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
149 | fan_out //= m.groups
150 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
151 | if m.bias is not None:
152 | m.bias.data.zero_()
153 |
154 | def forward(
155 | self,
156 | x,
157 | H,
158 | W,
159 | ):
160 | B, N, C = x[0].shape
161 | q = self.q(x)
162 | q = [
163 | q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
164 | for q_ in q
165 | ]
166 |
167 | if self.sr_ratio > 1:
168 | x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x]
169 | x = self.sr(x)
170 | x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x]
171 | x = self.norm(x)
172 | kv = self.kv(x)
173 | kv = [
174 | kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(
175 | 2, 0, 3, 1, 4
176 | )
177 | for kv_ in kv
178 | ]
179 | else:
180 | kv = self.kv(x)
181 | kv = [
182 | kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(
183 | 2, 0, 3, 1, 4
184 | )
185 | for kv_ in kv
186 | ]
187 | k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]]
188 |
189 | attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)]
190 | attn = [attn_.softmax(dim=-1) for attn_ in attn]
191 | attn = self.attn_drop(attn)
192 |
193 | x = [
194 | (attn_ @ v_).transpose(1, 2).reshape(B, N, C)
195 | for (attn_, v_) in zip(attn, v)
196 | ]
197 |
198 | # cross-attn per batch
199 | new_x0 = []
200 | new_x1 = []
201 | for bs in range(B):
202 | ## 1. 0_to_1 cross attn and skip connect
203 | q = x[0][bs].unsqueeze(0)
204 |
205 | judger_input = torch.cat(
206 | [x[0][bs].unsqueeze(0), x[1][bs].unsqueeze(0)], dim=-1
207 | )
208 |
209 | relation_score = self.relation_judger(judger_input)
210 |
211 | noise_k = self.k_noise.weight[0] + q
212 | noise_v = self.v_noise.weight[0] + q
213 |
214 | k = torch.cat([noise_k, torch.mul(q, relation_score)], dim=0)
215 | v = torch.cat([noise_v, x[1][bs].unsqueeze(0)], dim=0)
216 |
217 | new_x0.append(x[0][bs] + self.cross_attn_0_to_1(q, k, v)[0].squeeze(0))
218 |
219 | ## 2. 1_to_0 cross attn and skip connect
220 | q = x[1][bs].unsqueeze(0)
221 |
222 | judger_input = torch.cat(
223 | [x[1][bs].unsqueeze(0), x[0][bs].unsqueeze(0)], dim=-1
224 | )
225 |
226 | relation_score = self.relation_judger(judger_input)
227 |
228 | noise_k = self.k_noise.weight[1] + q
229 | noise_v = self.v_noise.weight[1] + q
230 |
231 | k = torch.cat([noise_k, torch.mul(q, relation_score)], dim=0)
232 | v = torch.cat([noise_v, x[0][bs].unsqueeze(0)], dim=0)
233 |
234 | new_x1.append(x[1][bs] + self.cross_attn_1_to_0(q, k, v)[0].squeeze(0))
235 |
236 | new_x0 = torch.stack(new_x0)
237 | new_x1 = torch.stack(new_x1)
238 | x[0] = new_x0
239 | x[1] = new_x1
240 |
241 | x = self.proj(x)
242 | x = self.proj_drop(x)
243 |
244 | return x
245 |
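# The per-sample loop above is the GeminiFusion exchange: with
# batch_first=False the (1, N, C) query is read as (tgt_len=1, batch=N,
# embed=C), so every token attends over exactly two keys -- a noise-shifted
# copy of itself and itself re-weighted by the learned relation score -- whose
# values are the noise-shifted query and the other modality's token.
# A minimal standalone sketch of one direction (illustrative modules and
# shapes, not the layer defined above):
#
#   import torch
#   import torch.nn as nn
#
#   N, C, heads = 16, 64, 8
#   q = torch.randn(1, N, C)          # tokens of modality 0 for one sample
#   other = torch.randn(1, N, C)      # tokens of modality 1 for the same sample
#   judger = nn.Sequential(nn.Linear(2 * C, C), nn.GELU(),
#                          nn.Linear(C, C), nn.Softmax(dim=-1))
#   noise = nn.Embedding(2, C)
#   attn = nn.MultiheadAttention(C, heads, batch_first=False)
#
#   score = judger(torch.cat([q, other], dim=-1))            # (1, N, C)
#   k = torch.cat([noise.weight[0] + q, q * score], dim=0)   # (2, N, C)
#   v = torch.cat([noise.weight[0] + q, other], dim=0)       # (2, N, C)
#   fused = q + attn(q, k, v)[0]                             # (1, N, C)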
246 |
247 | class Block(nn.Module):
248 | def __init__(
249 | self,
250 | dim,
251 | num_heads,
252 | mlp_ratio=4.0,
253 | qkv_bias=False,
254 | qk_scale=None,
255 | drop=0.0,
256 | attn_drop=0.0,
257 | drop_path=0.0,
258 | act_layer=nn.GELU,
259 | norm_layer=LayerNormParallel,
260 | sr_ratio=1,
261 | n_heads=8,
262 | ):
263 | super().__init__()
264 | self.norm1 = norm_layer(dim)
265 |
266 | self.attn = Attention(
267 | dim,
268 | num_heads=num_heads,
269 | qkv_bias=qkv_bias,
270 | qk_scale=qk_scale,
271 | attn_drop=attn_drop,
272 | proj_drop=drop,
273 | sr_ratio=sr_ratio,
274 | n_heads=n_heads,
275 | )
276 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
277 | self.drop_path = (
278 | ModuleParallel(DropPath(drop_path))
279 | if drop_path > 0.0
280 | else ModuleParallel(nn.Identity())
281 | )
282 | self.norm2 = norm_layer(dim)
283 | mlp_hidden_dim = int(dim * mlp_ratio)
284 | self.mlp = Mlp(
285 | in_features=dim,
286 | hidden_features=mlp_hidden_dim,
287 | act_layer=act_layer,
288 | drop=drop,
289 | )
290 |
291 | self.apply(self._init_weights)
292 |
293 | def _init_weights(self, m):
294 | if isinstance(m, nn.Linear):
295 | trunc_normal_(m.weight, std=0.02)
296 | if isinstance(m, nn.Linear) and m.bias is not None:
297 | nn.init.constant_(m.bias, 0)
298 | elif isinstance(m, nn.LayerNorm):
299 | nn.init.constant_(m.bias, 0)
300 | nn.init.constant_(m.weight, 1.0)
301 | elif isinstance(m, nn.Conv2d):
302 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
303 | fan_out //= m.groups
304 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
305 | if m.bias is not None:
306 | m.bias.data.zero_()
307 |
308 | def forward(self, x, H, W):
309 | B = x[0].shape[0]
310 |
311 | f = self.drop_path(
312 | self.attn(
313 | self.norm1(x),
314 | H,
315 | W,
316 | )
317 | )
318 | x = [x_ + f_ for (x_, f_) in zip(x, f)]
319 | f = self.drop_path(self.mlp(self.norm2(x), H, W))
320 | x = [x_ + f_ for (x_, f_) in zip(x, f)]
321 |
322 | return x
323 |
324 |
325 | class OverlapPatchEmbed(nn.Module):
326 | """Image to Patch Embedding"""
327 |
328 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
329 | super().__init__()
330 | img_size = to_2tuple(img_size)
331 | patch_size = to_2tuple(patch_size)
332 |
333 | self.img_size = img_size
334 | self.patch_size = patch_size
335 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
336 | self.num_patches = self.H * self.W
337 | self.proj = ModuleParallel(
338 | nn.Conv2d(
339 | in_chans,
340 | embed_dim,
341 | kernel_size=patch_size,
342 | stride=stride,
343 | padding=(patch_size[0] // 2, patch_size[1] // 2),
344 | )
345 | )
346 | self.norm = LayerNormParallel(embed_dim)
347 |
348 | self.apply(self._init_weights)
349 |
350 | def _init_weights(self, m):
351 | if isinstance(m, nn.Linear):
352 | trunc_normal_(m.weight, std=0.02)
353 | if isinstance(m, nn.Linear) and m.bias is not None:
354 | nn.init.constant_(m.bias, 0)
355 | elif isinstance(m, nn.LayerNorm):
356 | nn.init.constant_(m.bias, 0)
357 | nn.init.constant_(m.weight, 1.0)
358 | elif isinstance(m, nn.Conv2d):
359 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
360 | fan_out //= m.groups
361 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
362 | if m.bias is not None:
363 | m.bias.data.zero_()
364 |
365 | def forward(self, x):
366 | x = self.proj(x)
367 | _, _, H, W = x[0].shape
368 | x = [x_.flatten(2).transpose(1, 2) for x_ in x]
369 | x = self.norm(x)
370 | return x, H, W
371 |
372 |
373 | class MixVisionTransformer(nn.Module):
374 | def __init__(
375 | self,
376 | img_size=224,
377 | patch_size=16,
378 | in_chans=3,
379 | num_classes=1000,
380 | embed_dims=[64, 128, 256, 512],
381 | num_heads=[1, 2, 4, 8],
382 | mlp_ratios=[4, 4, 4, 4],
383 | qkv_bias=False,
384 | qk_scale=None,
385 | drop_rate=0.0,
386 | attn_drop_rate=0.0,
387 | drop_path_rate=0.0,
388 | norm_layer=LayerNormParallel,
389 | depths=[3, 4, 6, 3],
390 | sr_ratios=[8, 4, 2, 1],
391 | n_heads=8,
392 | ):
393 | super().__init__()
394 |
395 | self.num_classes = num_classes
396 | self.depths = depths
397 | self.embed_dims = embed_dims
398 |
399 | # patch_embed
400 | self.patch_embed1 = OverlapPatchEmbed(
401 | img_size=img_size,
402 | patch_size=7,
403 | stride=4,
404 | in_chans=in_chans,
405 | embed_dim=embed_dims[0],
406 | )
407 | self.patch_embed2 = OverlapPatchEmbed(
408 | img_size=img_size // 4,
409 | patch_size=3,
410 | stride=2,
411 | in_chans=embed_dims[0],
412 | embed_dim=embed_dims[1],
413 | )
414 | self.patch_embed3 = OverlapPatchEmbed(
415 | img_size=img_size // 8,
416 | patch_size=3,
417 | stride=2,
418 | in_chans=embed_dims[1],
419 | embed_dim=embed_dims[2],
420 | )
421 | self.patch_embed4 = OverlapPatchEmbed(
422 | img_size=img_size // 16,
423 | patch_size=3,
424 | stride=2,
425 | in_chans=embed_dims[2],
426 | embed_dim=embed_dims[3],
427 | )
428 |
429 | # transformer encoder
430 | dpr = [
431 | x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
432 | ] # stochastic depth decay rule
433 | cur = 0
434 | self.block1 = nn.ModuleList(
435 | [
436 | Block(
437 | dim=embed_dims[0],
438 | num_heads=num_heads[0],
439 | mlp_ratio=mlp_ratios[0],
440 | qkv_bias=qkv_bias,
441 | qk_scale=qk_scale,
442 | drop=drop_rate,
443 | attn_drop=attn_drop_rate,
444 | drop_path=dpr[cur + i],
445 | norm_layer=norm_layer,
446 | sr_ratio=sr_ratios[0],
447 | n_heads=n_heads,
448 | )
449 | for i in range(depths[0])
450 | ]
451 | )
452 | self.norm1 = norm_layer(embed_dims[0])
453 |
454 | cur += depths[0]
455 | self.block2 = nn.ModuleList(
456 | [
457 | Block(
458 | dim=embed_dims[1],
459 | num_heads=num_heads[1],
460 | mlp_ratio=mlp_ratios[1],
461 | qkv_bias=qkv_bias,
462 | qk_scale=qk_scale,
463 | drop=drop_rate,
464 | attn_drop=attn_drop_rate,
465 | drop_path=dpr[cur + i],
466 | norm_layer=norm_layer,
467 | sr_ratio=sr_ratios[1],
468 | n_heads=n_heads,
469 | )
470 | for i in range(depths[1])
471 | ]
472 | )
473 | self.norm2 = norm_layer(embed_dims[1])
474 |
475 | cur += depths[1]
476 | self.block3 = nn.ModuleList(
477 | [
478 | Block(
479 | dim=embed_dims[2],
480 | num_heads=num_heads[2],
481 | mlp_ratio=mlp_ratios[2],
482 | qkv_bias=qkv_bias,
483 | qk_scale=qk_scale,
484 | drop=drop_rate,
485 | attn_drop=attn_drop_rate,
486 | drop_path=dpr[cur + i],
487 | norm_layer=norm_layer,
488 | sr_ratio=sr_ratios[2],
489 | n_heads=n_heads,
490 | )
491 | for i in range(depths[2])
492 | ]
493 | )
494 | self.norm3 = norm_layer(embed_dims[2])
495 |
496 | cur += depths[2]
497 | self.block4 = nn.ModuleList(
498 | [
499 | Block(
500 | dim=embed_dims[3],
501 | num_heads=num_heads[3],
502 | mlp_ratio=mlp_ratios[3],
503 | qkv_bias=qkv_bias,
504 | qk_scale=qk_scale,
505 | drop=drop_rate,
506 | attn_drop=attn_drop_rate,
507 | drop_path=dpr[cur + i],
508 | norm_layer=norm_layer,
509 | sr_ratio=sr_ratios[3],
510 | n_heads=n_heads,
511 | )
512 | for i in range(depths[3])
513 | ]
514 | )
515 | self.norm4 = norm_layer(embed_dims[3])
516 |
517 | self.apply(self._init_weights)
518 |
519 | def _init_weights(self, m):
520 | if isinstance(m, nn.Linear):
521 | trunc_normal_(m.weight, std=0.02)
522 | if isinstance(m, nn.Linear) and m.bias is not None:
523 | nn.init.constant_(m.bias, 0)
524 | elif isinstance(m, nn.LayerNorm):
525 | nn.init.constant_(m.bias, 0)
526 | nn.init.constant_(m.weight, 1.0)
527 | elif isinstance(m, nn.Conv2d):
528 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
529 | fan_out //= m.groups
530 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
531 | if m.bias is not None:
532 | m.bias.data.zero_()
533 |
534 | def reset_drop_path(self, drop_path_rate):
535 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
536 | cur = 0
537 | for i in range(self.depths[0]):
538 | self.block1[i].drop_path.drop_prob = dpr[cur + i]
539 |
540 | cur += self.depths[0]
541 | for i in range(self.depths[1]):
542 | self.block2[i].drop_path.drop_prob = dpr[cur + i]
543 |
544 | cur += self.depths[1]
545 | for i in range(self.depths[2]):
546 | self.block3[i].drop_path.drop_prob = dpr[cur + i]
547 |
548 | cur += self.depths[2]
549 | for i in range(self.depths[3]):
550 | self.block4[i].drop_path.drop_prob = dpr[cur + i]
551 |
552 | def freeze_patch_emb(self):
553 | self.patch_embed1.requires_grad = False
554 |
555 | @torch.jit.ignore
556 | def no_weight_decay(self):
557 | return {
558 | "pos_embed1",
559 | "pos_embed2",
560 | "pos_embed3",
561 | "pos_embed4",
562 | "cls_token",
563 |         }  # keeping pos_embed in the no-weight-decay set may be better
564 |
565 | def get_classifier(self):
566 | return self.head
567 |
568 | def reset_classifier(self, num_classes, global_pool=""):
569 | self.num_classes = num_classes
570 | self.head = (
571 |             nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
572 | )
573 |
574 | def forward_features(self, x):
575 | B = x[0].shape[0]
576 | outs0, outs1 = [], []
577 |
578 | # stage 1
579 | x, H, W = self.patch_embed1(x)
580 | for i, blk in enumerate(self.block1):
581 |
582 | x = blk(
583 | x,
584 | H,
585 | W,
586 | )
587 | x = self.norm1(x)
588 | x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x]
589 | outs0.append(x[0])
590 | outs1.append(x[1])
591 |
592 | # stage 2
593 | x, H, W = self.patch_embed2(x)
594 | for i, blk in enumerate(self.block2):
595 |
596 | x = blk(
597 | x,
598 | H,
599 | W,
600 | )
601 | x = self.norm2(x)
602 | x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x]
603 | outs0.append(x[0])
604 | outs1.append(x[1])
605 |
606 | # stage 3
607 | x, H, W = self.patch_embed3(x)
608 | for i, blk in enumerate(self.block3):
609 |
610 | x = blk(
611 | x,
612 | H,
613 | W,
614 | )
615 | x = self.norm3(x)
616 | x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x]
617 | outs0.append(x[0])
618 | outs1.append(x[1])
619 |
620 | # stage 4
621 | x, H, W = self.patch_embed4(x)
622 | for i, blk in enumerate(self.block4):
623 |
624 | x = blk(
625 | x,
626 | H,
627 | W,
628 | )
629 | x = self.norm4(x)
630 | x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x]
631 | outs0.append(x[0])
632 | outs1.append(x[1])
633 |
634 | return [outs0, outs1]
635 |
636 | def forward(self, x):
637 | x = self.forward_features(x)
638 | return x
639 |
640 |
641 | class DWConv(nn.Module):
642 | def __init__(self, dim=768):
643 | super(DWConv, self).__init__()
644 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
645 |
646 | def forward(self, x, H, W):
647 | B, N, C = x.shape
648 | x = x.transpose(1, 2).view(B, C, H, W)
649 | x = self.dwconv(x)
650 | x = x.flatten(2).transpose(1, 2)
651 |
652 | return x
653 |
654 |
655 | class mit_b0(MixVisionTransformer):
656 | def __init__(self, n_heads, dpr, drop_rate, **kwargs):
657 | super(mit_b0, self).__init__(
658 | patch_size=4,
659 | embed_dims=[32, 64, 160, 256],
660 | num_heads=[1, 2, 5, 8],
661 | mlp_ratios=[4, 4, 4, 4],
662 | qkv_bias=True,
663 | norm_layer=LayerNormParallel,
664 | depths=[2, 2, 2, 2],
665 | sr_ratios=[8, 4, 2, 1],
666 | drop_rate=drop_rate,
667 | drop_path_rate=dpr,
668 | n_heads=n_heads,
669 | )
670 |
671 |
672 | class mit_b1(MixVisionTransformer):
673 | def __init__(self, n_heads, dpr, drop_rate, **kwargs):
674 | super(mit_b1, self).__init__(
675 | patch_size=4,
676 | embed_dims=[64, 128, 320, 512],
677 | num_heads=[1, 2, 5, 8],
678 | mlp_ratios=[4, 4, 4, 4],
679 | qkv_bias=True,
680 | norm_layer=LayerNormParallel,
681 | depths=[2, 2, 2, 2],
682 | sr_ratios=[8, 4, 2, 1],
683 | drop_rate=drop_rate,
684 | drop_path_rate=dpr,
685 | n_heads=n_heads,
686 | )
687 |
688 |
689 | class mit_b2(MixVisionTransformer):
690 | def __init__(self, n_heads, dpr, drop_rate, **kwargs):
691 | super(mit_b2, self).__init__(
692 | patch_size=4,
693 | embed_dims=[64, 128, 320, 512],
694 | num_heads=[1, 2, 5, 8],
695 | mlp_ratios=[4, 4, 4, 4],
696 | qkv_bias=True,
697 | norm_layer=LayerNormParallel,
698 | depths=[3, 4, 6, 3],
699 | sr_ratios=[8, 4, 2, 1],
700 | drop_rate=drop_rate,
701 | drop_path_rate=dpr,
702 | n_heads=n_heads,
703 | )
704 |
705 |
706 | class mit_b3(MixVisionTransformer):
707 | def __init__(self, n_heads, dpr, drop_rate, **kwargs):
708 | super(mit_b3, self).__init__(
709 | patch_size=4,
710 | embed_dims=[64, 128, 320, 512],
711 | num_heads=[1, 2, 5, 8],
712 | mlp_ratios=[4, 4, 4, 4],
713 | qkv_bias=True,
714 | norm_layer=LayerNormParallel,
715 | depths=[3, 4, 18, 3],
716 | sr_ratios=[8, 4, 2, 1],
717 | drop_rate=drop_rate,
718 | drop_path_rate=dpr,
719 | n_heads=n_heads,
720 | )
721 |
722 |
723 | class mit_b4(MixVisionTransformer):
724 | def __init__(self, n_heads, dpr, drop_rate, **kwargs):
725 | super(mit_b4, self).__init__(
726 | patch_size=4,
727 | embed_dims=[64, 128, 320, 512],
728 | num_heads=[1, 2, 5, 8],
729 | mlp_ratios=[4, 4, 4, 4],
730 | qkv_bias=True,
731 | norm_layer=LayerNormParallel,
732 | depths=[3, 8, 27, 3],
733 | sr_ratios=[8, 4, 2, 1],
734 | drop_rate=drop_rate,
735 | drop_path_rate=dpr,
736 | n_heads=n_heads,
737 | )
738 |
739 |
740 | class mit_b5(MixVisionTransformer):
741 | def __init__(self, n_heads, dpr, drop_rate, **kwargs):
742 | super(mit_b5, self).__init__(
743 | patch_size=4,
744 | embed_dims=[64, 128, 320, 512],
745 | num_heads=[1, 2, 5, 8],
746 | mlp_ratios=[4, 4, 4, 4],
747 | qkv_bias=True,
748 | norm_layer=LayerNormParallel,
749 | depths=[3, 6, 40, 3],
750 | sr_ratios=[8, 4, 2, 1],
751 | drop_rate=drop_rate,
752 | drop_path_rate=dpr,
753 | n_heads=n_heads,
754 | )
755 |
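# All mit_b* variants take a two-element list (one tensor per modality) and
# return [outs0, outs1], each holding four feature maps at strides 4/8/16/32.
# A quick shape check (a sketch; it assumes both modalities are 3-channel,
# e.g. RGB plus an HHA-style depth encoding):
#
#   import torch
#
#   model = mit_b2(n_heads=8, dpr=0.1, drop_rate=0.0).eval()
#   rgb = torch.randn(1, 3, 480, 640)
#   dep = torch.randn(1, 3, 480, 640)
#   with torch.no_grad():
#       outs0, outs1 = model([rgb, dep])
#   # outs0[0].shape == (1, 64, 120, 160)   # stride 4
#   # outs0[3].shape == (1, 512, 15, 20)    # stride 32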
--------------------------------------------------------------------------------
/models/swin_transformer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu, Yutong Lin, Yixuan Wei
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.utils.checkpoint as checkpoint
12 | import numpy as np
13 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
14 | import math
15 | from mmcv_custom import load_checkpoint
16 | from mmdet.utils import get_root_logger
17 | from .modules import (
18 | ModuleParallel,
19 | Additional_One_ModuleParallel,
20 | LayerNormParallel,
21 | Additional_Two_ModuleParallel,
22 | )
23 |
24 |
25 | class Mlp(nn.Module):
26 | """Multilayer perceptron."""
27 |
28 | def __init__(
29 | self,
30 | in_features,
31 | hidden_features=None,
32 | out_features=None,
33 | act_layer=ModuleParallel(nn.GELU()),
34 | drop=0.0,
35 | ):
36 | super().__init__()
37 | out_features = out_features or in_features
38 | hidden_features = hidden_features or in_features
39 | self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features))
40 | self.act = act_layer
41 | self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features))
42 | self.drop = ModuleParallel(nn.Dropout(drop))
43 |
44 | def forward(self, x):
45 | x = self.fc1(x)
46 | x = self.act(x)
47 | x = self.drop(x)
48 | x = self.fc2(x)
49 | x = self.drop(x)
50 | return x
51 |
52 |
53 | def window_partition(x, window_size):
54 | """
55 | Args:
56 | x: (B, H, W, C)
57 | window_size (int): window size
58 |
59 | Returns:
60 | windows: (num_windows*B, window_size, window_size, C)
61 | """
62 | B, H, W, C = x.shape
63 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
64 | windows = (
65 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
66 | )
67 | return windows
68 |
69 |
70 | def window_reverse(windows, window_size, H, W):
71 | """
72 | Args:
73 | windows: (num_windows*B, window_size, window_size, C)
74 | window_size (int): Window size
75 | H (int): Height of image
76 | W (int): Width of image
77 |
78 | Returns:
79 | x: (B, H, W, C)
80 | """
81 | B = int(windows.shape[0] / (H * W / window_size / window_size))
82 | x = windows.view(
83 | B, H // window_size, W // window_size, window_size, window_size, -1
84 | )
85 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
86 | return x
87 |
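# window_partition and window_reverse are exact inverses whenever H and W are
# multiples of the window size. A small sanity sketch (illustrative shapes):
#
#   import torch
#
#   x = torch.randn(2, 14, 14, 96)              # (B, H, W, C)
#   w = window_partition(x, 7)                  # (2 * 2 * 2, 7, 7, 96)
#   assert torch.equal(window_reverse(w, 7, 14, 14), x)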
88 |
89 | class WindowAttention(nn.Module):
90 | """Window based multi-head self attention (W-MSA) module with relative position bias.
91 | It supports both of shifted and non-shifted window.
92 |
93 | Args:
94 | dim (int): Number of input channels.
95 | window_size (tuple[int]): The height and width of the window.
96 | num_heads (int): Number of attention heads.
97 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
98 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
99 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
100 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
101 | """
102 |
103 | def __init__(
104 | self,
105 | dim,
106 | window_size,
107 | num_heads,
108 | qkv_bias=True,
109 | qk_scale=None,
110 | attn_drop=0.0,
111 | proj_drop=0.0,
112 | ):
113 |
114 | super().__init__()
115 | self.dim = dim
116 | self.window_size = window_size # Wh, Ww
117 | self.num_heads = num_heads
118 | head_dim = dim // num_heads
119 | self.scale = qk_scale or head_dim**-0.5
120 |
121 | # define a parameter table of relative position bias
122 | self.relative_position_bias_table = nn.Parameter(
123 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
124 | ) # 2*Wh-1 * 2*Ww-1, nH
125 |
126 | # get pair-wise relative position index for each token inside the window
127 | coords_h = torch.arange(self.window_size[0])
128 | coords_w = torch.arange(self.window_size[1])
129 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
130 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
131 | relative_coords = (
132 | coords_flatten[:, :, None] - coords_flatten[:, None, :]
133 | ) # 2, Wh*Ww, Wh*Ww
134 | relative_coords = relative_coords.permute(
135 | 1, 2, 0
136 | ).contiguous() # Wh*Ww, Wh*Ww, 2
137 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
138 | relative_coords[:, :, 1] += self.window_size[1] - 1
139 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
140 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
141 | self.register_buffer("relative_position_index", relative_position_index)
142 |
143 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
144 | self.attn_drop = nn.Dropout(attn_drop)
145 | self.proj = nn.Linear(dim, dim)
146 | self.proj_drop = nn.Dropout(proj_drop)
147 |
148 | trunc_normal_(self.relative_position_bias_table, std=0.02)
149 | self.softmax = nn.Softmax(dim=-1)
150 |
151 | def forward(self, x, mask=None):
152 | """Forward function.
153 |
154 | Args:
155 | x: input features with shape of (num_windows*B, N, C)
156 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
157 | """
158 | B_, N, C = x.shape
159 | qkv = (
160 | self.qkv(x)
161 | .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
162 | .permute(2, 0, 3, 1, 4)
163 | )
164 | q, k, v = (
165 | qkv[0],
166 | qkv[1],
167 | qkv[2],
168 | ) # make torchscript happy (cannot use tensor as tuple)
169 |
170 | q = q * self.scale
171 | attn = q @ k.transpose(-2, -1)
172 |
173 | relative_position_bias = self.relative_position_bias_table[
174 | self.relative_position_index.view(-1)
175 | ].view(
176 | self.window_size[0] * self.window_size[1],
177 | self.window_size[0] * self.window_size[1],
178 | -1,
179 | ) # Wh*Ww,Wh*Ww,nH
180 | relative_position_bias = relative_position_bias.permute(
181 | 2, 0, 1
182 | ).contiguous() # nH, Wh*Ww, Wh*Ww
183 | attn = attn + relative_position_bias.unsqueeze(0)
184 |
185 | if mask is not None:
186 |
187 | nW = mask.shape[0]
188 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
189 | 1
190 | ).unsqueeze(0)
191 |
192 | attn = attn.view(-1, self.num_heads, N, N)
193 | attn = self.softmax(attn)
194 | else:
195 | attn = self.softmax(attn)
196 |
197 | attn = self.attn_drop(attn)
198 |
199 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
200 | x = self.proj(x)
201 | x = self.proj_drop(x)
202 | return x
203 |
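# relative_position_index maps every (query, key) pair inside a window to one
# row of relative_position_bias_table. For a 2x2 window each relative offset
# (dh, dw) lies in {-1, 0, 1}, giving (2 * 2 - 1) ** 2 = 9 table rows, and the
# 4x4 index matrix can be reproduced with the same recipe as above (a worked
# example, not code from this file):
#
#   import torch
#
#   coords = torch.stack(torch.meshgrid([torch.arange(2), torch.arange(2)]))
#   flat = torch.flatten(coords, 1)                              # (2, 4)
#   rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()
#   rel[:, :, 0] += 1; rel[:, :, 1] += 1; rel[:, :, 0] *= 3
#   idx = rel.sum(-1)
#   # idx == tensor([[4, 3, 1, 0],
#   #                [5, 4, 2, 1],
#   #                [7, 6, 4, 3],
#   #                [8, 7, 5, 4]])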
204 |
205 | class Mlp_2(nn.Module):
206 | """Multilayer perceptron."""
207 |
208 | def __init__(
209 | self,
210 | in_features,
211 | hidden_features=None,
212 | out_features=None,
213 | act_layer=nn.GELU,
214 | drop=0.0,
215 | ):
216 | super().__init__()
217 | out_features = out_features or in_features
218 | hidden_features = hidden_features or in_features
219 | self.fc1 = nn.Linear(in_features, hidden_features)
220 | self.act = act_layer()
221 | self.fc2 = nn.Linear(hidden_features, out_features)
222 | self.drop = nn.Dropout(drop)
223 |
224 | def forward(self, x):
225 | x = self.fc1(x)
226 | x = self.act(x)
227 | x = self.drop(x)
228 | x = self.fc2(x)
229 | x = self.drop(x)
230 | return x
231 |
232 |
233 | class SwinTransformerBlock(nn.Module):
234 | """Swin Transformer Block.
235 |
236 | Args:
237 | dim (int): Number of input channels.
238 | num_heads (int): Number of attention heads.
239 | window_size (int): Window size.
240 | shift_size (int): Shift size for SW-MSA.
241 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
242 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
243 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
244 | drop (float, optional): Dropout rate. Default: 0.0
245 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
246 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
247 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
248 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
249 | """
250 |
251 | def __init__(
252 | self,
253 | dim,
254 | num_heads,
255 | window_size=7,
256 | shift_size=0,
257 | mlp_ratio=4.0,
258 | qkv_bias=True,
259 | qk_scale=None,
260 | drop=0.0,
261 | attn_drop=0.0,
262 | drop_path=0.0,
263 | act_layer=ModuleParallel(nn.GELU()),
264 | norm_layer=LayerNormParallel,
265 | ):
266 | super().__init__()
267 | self.dim = dim
268 | self.num_heads = num_heads
269 | self.window_size = window_size
270 | self.shift_size = shift_size
271 | self.mlp_ratio = mlp_ratio
272 | assert (
273 | 0 <= self.shift_size < self.window_size
274 |         ), "shift_size must be in [0, window_size)"
275 |
276 | self.norm1 = ModuleParallel(nn.LayerNorm(dim))
277 | self.attn = Additional_One_ModuleParallel(
278 | WindowAttention(
279 | dim,
280 | window_size=to_2tuple(self.window_size),
281 | num_heads=num_heads,
282 | qkv_bias=qkv_bias,
283 | qk_scale=qk_scale,
284 | attn_drop=attn_drop,
285 | proj_drop=drop,
286 | )
287 | )
288 |
289 | self.drop_path = (
290 | ModuleParallel(DropPath(drop_path)) if drop_path > 0.0 else nn.Identity()
291 | )
292 | self.norm2 = ModuleParallel(nn.LayerNorm(dim))
293 | mlp_hidden_dim = int(dim * mlp_ratio)
294 | self.mlp = Mlp(
295 | in_features=dim,
296 | hidden_features=mlp_hidden_dim,
297 | act_layer=act_layer,
298 | drop=drop,
299 | )
300 |
301 | self.H = None
302 | self.W = None
303 |
304 | self.cross_heads = 8
305 | self.cross_attn_0_to_1 = nn.MultiheadAttention(
306 | dim, self.cross_heads, dropout=0.0, batch_first=False
307 | )
308 | self.cross_attn_1_to_0 = nn.MultiheadAttention(
309 | dim, self.cross_heads, dropout=0.0, batch_first=False
310 | )
311 | self.relation_judger = nn.Sequential(
312 | Mlp_2(dim * 2, dim, dim), torch.nn.Softmax(dim=-1)
313 | )
314 | self.k_noise = nn.Embedding(2, dim)
315 | self.v_noise = nn.Embedding(2, dim)
316 |
317 | self.apply(self._init_weights)
318 |
319 | def _init_weights(self, m):
320 | if isinstance(m, nn.Linear):
321 | trunc_normal_(m.weight, std=0.02)
322 | if isinstance(m, nn.Linear) and m.bias is not None:
323 | nn.init.constant_(m.bias, 0)
324 | elif isinstance(m, nn.LayerNorm):
325 | nn.init.constant_(m.bias, 0)
326 | nn.init.constant_(m.weight, 1.0)
327 | elif isinstance(m, nn.Conv2d):
328 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
329 | fan_out //= m.groups
330 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
331 | if m.bias is not None:
332 | m.bias.data.zero_()
333 |
334 | def forward(self, x, mask_matrix):
335 | """Forward function.
336 |
337 | Args:
338 | x: Input feature, tensor size (B, H*W, C).
339 | H, W: Spatial resolution of the input feature.
340 | mask_matrix: Attention mask for cyclic shift.
341 | """
342 | B, L, C = x[0].shape
343 | H, W = self.H, self.W
344 | assert L == H * W, "input feature has wrong size"
345 |
346 | shortcut = x
347 | x = self.norm1(x)
348 | for i in range(len(x)):
349 | x[i] = x[i].view(B, H, W, C)
350 |
351 | # pad feature maps to multiples of window size
352 | pad_l = pad_t = 0
353 | pad_r = (self.window_size - W % self.window_size) % self.window_size
354 | pad_b = (self.window_size - H % self.window_size) % self.window_size
355 | for i in range(len(x)):
356 | x[i] = F.pad(x[i], (0, 0, pad_l, pad_r, pad_t, pad_b))
357 | _, Hp, Wp, _ = x[0].shape
358 |
359 | # cyclic shift
360 | if self.shift_size > 0:
361 | shifted_x = [
362 | torch.roll(
363 | x[i], shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
364 | )
365 | for i in range(len(x))
366 | ]
367 | attn_mask = mask_matrix
368 | else:
369 | shifted_x = x
370 | attn_mask = None
371 |
372 | # partition windows
373 | x_windows = [
374 | window_partition(shifted_x[i], self.window_size)
375 | for i in range(len(shifted_x))
376 | ] # nW*B, window_size, window_size, C
377 |
378 | for i in range(len(x_windows)):
379 | x_windows[i] = x_windows[i].view(
380 | -1, self.window_size * self.window_size, C
381 | ) # nW*B, window_size*window_size, C
382 |
383 | # W-MSA/SW-MSA
384 | attn_windows = self.attn(
385 | x_windows, attn_mask
386 | ) # nW*B, window_size*window_size, C
387 |
388 | # merge windows
389 | for i in range(len(attn_windows)):
390 | attn_windows[i] = attn_windows[i].view(
391 | -1, self.window_size, self.window_size, C
392 | )
393 | shifted_x = [
394 | window_reverse(attn_windows[i], self.window_size, Hp, Wp)
395 | for i in range(len(attn_windows))
396 | ] # B H' W' C
397 |
398 | # reverse cyclic shift
399 | if self.shift_size > 0:
400 | x = [
401 | torch.roll(
402 | shifted_x[i], shifts=(self.shift_size, self.shift_size), dims=(1, 2)
403 | )
404 | for i in range(len(shifted_x))
405 | ]
406 | else:
407 | x = shifted_x
408 |
409 | if pad_r > 0 or pad_b > 0:
410 | for i in range(len(x)):
411 | x[i] = x[i][:, :H, :W, :].contiguous()
412 | for i in range(len(x)):
413 | x[i] = x[i].view(B, H * W, C)
414 |
415 | # cross-attn per batch
416 | new_x0 = []
417 | new_x1 = []
418 | for bs in range(B):
419 | ## 1. 0_to_1 cross attn and skip connect
420 | q = x[0][bs].unsqueeze(0)
421 | # k = v = x[1][bs].unsqueeze(0)
422 | judger_input = torch.cat(
423 | [x[0][bs].unsqueeze(0), x[1][bs].unsqueeze(0)], dim=-1
424 | )
425 |
426 | relation_score = self.relation_judger(judger_input)
427 |
428 | noise_k = self.k_noise.weight[0] + q
429 | noise_v = self.v_noise.weight[0] + q
430 |
431 | k = torch.cat([noise_k, torch.mul(q, relation_score)], dim=0)
432 | v = torch.cat([noise_v, x[1][bs].unsqueeze(0)], dim=0)
433 |
434 | new_x0.append(x[0][bs] + self.cross_attn_0_to_1(q, k, v)[0].squeeze(0))
435 |
436 | ## 2. 1_to_0 cross attn and skip connect
437 | q = x[1][bs].unsqueeze(0)
438 | # k = v = x[0][bs].unsqueeze(0)
439 | judger_input = torch.cat(
440 | [x[1][bs].unsqueeze(0), x[0][bs].unsqueeze(0)], dim=-1
441 | )
442 |
443 | relation_score = self.relation_judger(judger_input)
444 |
445 | noise_k = self.k_noise.weight[1] + q
446 | noise_v = self.v_noise.weight[1] + q
447 |
448 | k = torch.cat([noise_k, torch.mul(q, relation_score)], dim=0)
449 | v = torch.cat([noise_v, x[0][bs].unsqueeze(0)], dim=0)
450 |
451 | new_x1.append(x[1][bs] + self.cross_attn_1_to_0(q, k, v)[0].squeeze(0))
452 |
453 | new_x0 = torch.stack(new_x0)
454 | new_x1 = torch.stack(new_x1)
455 | x[0] = new_x0
456 | x[1] = new_x1
457 |
458 | # FFN
459 | x_dp1 = self.drop_path(x)
460 | for i in range(len(x)):
461 | x[i] = shortcut[i] + x_dp1[i]
462 | x_dp2 = self.drop_path(self.mlp(self.norm2(x)))
463 | for i in range(len(x)):
464 | x[i] = x[i] + x_dp2[i]
465 |
466 | return x
467 |
468 |
469 | class PatchMerging(nn.Module):
470 | """Patch Merging Layer
471 |
472 | Args:
473 | dim (int): Number of input channels.
474 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
475 | """
476 |
477 | def __init__(self, dim, norm_layer=nn.LayerNorm):
478 | super().__init__()
479 | self.dim = dim
480 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
481 | self.norm = norm_layer(4 * dim)
482 |
483 | def forward(self, x, H, W):
484 | """Forward function.
485 |
486 | Args:
487 | x: Input feature, tensor size (B, H*W, C).
488 | H, W: Spatial resolution of the input feature.
489 | """
490 |
491 | B, L, C = x.shape
492 | assert L == H * W, "input feature has wrong size"
493 |
494 | x = x.view(B, H, W, C)
495 |
496 | # padding
497 | pad_input = (H % 2 == 1) or (W % 2 == 1)
498 | if pad_input:
499 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
500 |
501 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
502 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
503 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
504 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
505 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
506 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
507 |
508 | x = self.norm(x)
509 | x = self.reduction(x)
510 |
511 | return x
512 |
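# PatchMerging concatenates each 2x2 neighbourhood (4C channels) and projects
# it back to 2C, halving the spatial resolution. A quick shape sketch
# (illustrative sizes):
#
#   import torch
#
#   pm = PatchMerging(dim=96)
#   x = torch.randn(2, 56 * 56, 96)    # (B, H*W, C)
#   y = pm(x, 56, 56)                  # y.shape == (2, 28 * 28, 192)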
513 |
514 | class BasicLayer(nn.Module):
515 | """A basic Swin Transformer layer for one stage.
516 |
517 | Args:
518 | dim (int): Number of feature channels
519 |         depth (int): Depth of this stage.
520 |         num_heads (int): Number of attention heads.
521 | window_size (int): Local window size. Default: 7.
522 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
523 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
524 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
525 | drop (float, optional): Dropout rate. Default: 0.0
526 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
527 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
528 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
529 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
530 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
531 | """
532 |
533 | def __init__(
534 | self,
535 | dim,
536 | depth,
537 | num_heads,
538 | window_size=7,
539 | mlp_ratio=4.0,
540 | qkv_bias=True,
541 | qk_scale=None,
542 | drop=0.0,
543 | attn_drop=0.0,
544 | drop_path=0.0,
545 | norm_layer=LayerNormParallel,
546 | downsample=None,
547 | use_checkpoint=False,
548 | ):
549 | super().__init__()
550 | self.window_size = window_size
551 | self.shift_size = window_size // 2
552 | self.depth = depth
553 | self.use_checkpoint = use_checkpoint
554 |
555 | # build blocks
556 | self.blocks = nn.ModuleList(
557 | [
558 | SwinTransformerBlock(
559 | dim=dim,
560 | num_heads=num_heads,
561 | window_size=window_size,
562 | shift_size=0 if (i % 2 == 0) else window_size // 2,
563 | mlp_ratio=mlp_ratio,
564 | qkv_bias=qkv_bias,
565 | qk_scale=qk_scale,
566 | drop=drop,
567 | attn_drop=attn_drop,
568 | drop_path=(
569 | drop_path[i] if isinstance(drop_path, list) else drop_path
570 | ),
571 | norm_layer=norm_layer,
572 | )
573 | for i in range(depth)
574 | ]
575 | )
576 |
577 | # patch merging layer
578 | if downsample is not None:
579 | self.downsample = Additional_Two_ModuleParallel(
580 | downsample(dim=dim, norm_layer=nn.LayerNorm)
581 | )
582 | else:
583 | self.downsample = None
584 |
585 | def forward(self, x, H, W):
586 | """Forward function.
587 |
588 | Args:
589 | x: Input feature, tensor size (B, H*W, C).
590 | H, W: Spatial resolution of the input feature.
591 | """
592 |
593 | # calculate attention mask for SW-MSA
594 | Hp = int(np.ceil(H / self.window_size)) * self.window_size
595 | Wp = int(np.ceil(W / self.window_size)) * self.window_size
596 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x[0].device) # 1 Hp Wp 1
597 | h_slices = (
598 | slice(0, -self.window_size),
599 | slice(-self.window_size, -self.shift_size),
600 | slice(-self.shift_size, None),
601 | )
602 | w_slices = (
603 | slice(0, -self.window_size),
604 | slice(-self.window_size, -self.shift_size),
605 | slice(-self.shift_size, None),
606 | )
607 | cnt = 0
608 | for h in h_slices:
609 | for w in w_slices:
610 | img_mask[:, h, w, :] = cnt
611 | cnt += 1
612 |
613 | mask_windows = window_partition(
614 | img_mask, self.window_size
615 | ) # nW, window_size, window_size, 1
616 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
617 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
618 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
619 | attn_mask == 0, float(0.0)
620 | )
621 |
622 | for blk in self.blocks:
623 | blk.H, blk.W = H, W
624 | if self.use_checkpoint:
625 | x = checkpoint.checkpoint(blk, x, attn_mask)
626 | else:
627 | x = blk(x, attn_mask)
628 | if self.downsample is not None:
629 | x_down = self.downsample(x, H, W)
630 | Wh, Ww = (H + 1) // 2, (W + 1) // 2
631 | return x, H, W, x_down, Wh, Ww
632 | else:
633 | return x, H, W, x, H, W
634 |
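# The attention mask built in forward() assigns a region id to every position
# of the padded (Hp, Wp) grid; window positions whose ids differ receive -100
# so that, after the cyclic shift, tokens that came from different image
# regions cannot attend to each other. A worked size example (illustrative
# values only):
#
#   # with H = W = 14, window_size = 7 and shift_size = 3:
#   #   img_mask     -> (1, 14, 14, 1), region ids 0..8
#   #   mask_windows -> (4, 49)
#   #   attn_mask    -> (4, 49, 49), entries in {0.0, -100.0}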
635 |
636 | class PatchEmbed(nn.Module):
637 | """Image to Patch Embedding
638 |
639 | Args:
640 | patch_size (int): Patch token size. Default: 4.
641 | in_chans (int): Number of input image channels. Default: 3.
642 | embed_dim (int): Number of linear projection output channels. Default: 96.
643 | norm_layer (nn.Module, optional): Normalization layer. Default: None
644 | """
645 |
646 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
647 | super().__init__()
648 | patch_size = to_2tuple(patch_size)
649 | self.patch_size = patch_size
650 |
651 | self.in_chans = in_chans
652 | self.embed_dim = embed_dim
653 |
654 | self.proj = ModuleParallel(
655 | nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
656 | )
657 | if norm_layer is not None:
658 | self.norm = norm_layer(embed_dim)
659 | else:
660 | self.norm = None
661 |
662 | def forward(self, x):
663 | """Forward function."""
664 | # padding
665 | _, _, H, W = x[0].size()
666 |         if W % self.patch_size[1] != 0:
667 |             x = [F.pad(x_, (0, self.patch_size[1] - W % self.patch_size[1])) for x_ in x]
668 |         if H % self.patch_size[0] != 0:
669 |             x = [F.pad(x_, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) for x_ in x]
670 |
671 | x = self.proj(x) # B C Wh Ww
672 | if self.norm is not None:
673 | Wh, Ww = x[0].size(2), x[0].size(3)
674 | for i in range(len(x)):
675 | x[i] = x[i].flatten(2).transpose(1, 2)
676 | x = self.norm(x)
677 | for i in range(len(x)):
678 | x[i] = x[i].transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
679 |
680 | return x
681 |
682 |
683 | class SwinTransformer(nn.Module):
684 | """Swin Transformer backbone.
685 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
686 | https://arxiv.org/pdf/2103.14030
687 |
688 | Args:
689 | pretrain_img_size (int): Input image size for training the pretrained model,
690 |             used in absolute position embedding. Default 224.
691 | patch_size (int | tuple(int)): Patch size. Default: 4.
692 | in_chans (int): Number of input image channels. Default: 3.
693 | embed_dim (int): Number of linear projection output channels. Default: 96.
694 | depths (tuple[int]): Depths of each Swin Transformer stage.
695 |         num_heads (tuple[int]): Number of attention heads of each stage.
696 | window_size (int): Window size. Default: 7.
697 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
698 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
699 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
700 | drop_rate (float): Dropout rate.
701 | attn_drop_rate (float): Attention dropout rate. Default: 0.
702 | drop_path_rate (float): Stochastic depth rate. Default: 0.2.
703 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
704 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
705 | patch_norm (bool): If True, add normalization after patch embedding. Default: True.
706 | out_indices (Sequence[int]): Output from which stages.
707 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
708 | -1 means not freezing any parameters.
709 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
710 | """
711 |
712 | def __init__(
713 | self,
714 | pretrain_img_size=224,
715 | patch_size=4,
716 | in_chans=3,
717 | embed_dim=96,
718 | depths=[2, 2, 6, 2],
719 | num_heads=[3, 6, 12, 24],
720 | window_size=7,
721 | mlp_ratio=4.0,
722 | qkv_bias=True,
723 | qk_scale=None,
724 | drop_rate=0.0,
725 | attn_drop_rate=0.0,
726 | drop_path_rate=0.2,
727 | norm_layer=LayerNormParallel,
728 | ape=False,
729 | patch_norm=True,
730 | out_indices=(0, 1, 2, 3),
731 | frozen_stages=-1,
732 | use_checkpoint=False,
733 | ):
734 | super().__init__()
735 | self.drop_path_rate = drop_path_rate
736 | self.pretrain_img_size = pretrain_img_size
737 | self.num_layers = len(depths)
738 | self.embed_dim = embed_dim
739 | self.ape = ape
740 | self.patch_norm = patch_norm
741 | self.out_indices = out_indices
742 | self.frozen_stages = frozen_stages
743 |
744 | # split image into non-overlapping patches
745 | self.patch_embed = PatchEmbed(
746 | patch_size=patch_size,
747 | in_chans=in_chans,
748 | embed_dim=embed_dim,
749 | norm_layer=norm_layer if self.patch_norm else None,
750 | )
751 |
752 | # absolute position embedding
753 | if self.ape:
754 | pretrain_img_size = to_2tuple(pretrain_img_size)
755 | patch_size = to_2tuple(patch_size)
756 | patches_resolution = [
757 | pretrain_img_size[0] // patch_size[0],
758 | pretrain_img_size[1] // patch_size[1],
759 | ]
760 |
761 | self.absolute_pos_embed = nn.Parameter(
762 | torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
763 | )
764 | trunc_normal_(self.absolute_pos_embed, std=0.02)
765 |
766 | self.pos_drop = ModuleParallel(nn.Dropout(p=drop_rate))
767 |
768 | # stochastic depth
769 | dpr = [
770 | x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
771 | ] # stochastic depth decay rule
772 |
773 | # build layers
774 | self.layers = nn.ModuleList()
775 | for i_layer in range(self.num_layers):
776 | layer = BasicLayer(
777 | dim=int(embed_dim * 2**i_layer),
778 | depth=depths[i_layer],
779 | num_heads=num_heads[i_layer],
780 | window_size=window_size,
781 | mlp_ratio=mlp_ratio,
782 | qkv_bias=qkv_bias,
783 | qk_scale=qk_scale,
784 | drop=drop_rate,
785 | attn_drop=attn_drop_rate,
786 | drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
787 | norm_layer=norm_layer,
788 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
789 | use_checkpoint=use_checkpoint,
790 | )
791 | self.layers.append(layer)
792 |
793 | num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
794 | self.num_features = num_features
795 |
796 | # add a norm layer for each output
797 | for i_layer in out_indices:
798 | layer = norm_layer(num_features[i_layer])
799 | layer_name = f"norm{i_layer}"
800 | self.add_module(layer_name, layer)
801 |
802 | self._freeze_stages()
803 |
804 | def _freeze_stages(self):
805 | if self.frozen_stages >= 0:
806 | self.patch_embed.eval()
807 | for param in self.patch_embed.parameters():
808 | param.requires_grad = False
809 |
810 | if self.frozen_stages >= 1 and self.ape:
811 | self.absolute_pos_embed.requires_grad = False
812 |
813 | if self.frozen_stages >= 2:
814 | self.pos_drop.eval()
815 | for i in range(0, self.frozen_stages - 1):
816 | m = self.layers[i]
817 | m.eval()
818 | for param in m.parameters():
819 | param.requires_grad = False
820 |
821 | def init_weights(self, pretrained=None):
822 | """Initialize the weights in backbone.
823 |
824 | Args:
825 | pretrained (str, optional): Path to pre-trained weights.
826 | Defaults to None.
827 | """
828 |
829 | def _init_weights(m):
830 | pass
831 |
832 | if isinstance(pretrained, str):
833 | self.apply(_init_weights)
834 | logger = get_root_logger()
835 | load_checkpoint(self, pretrained, strict=False, logger=logger)
836 | elif pretrained is None:
837 | self.apply(_init_weights)
838 | else:
839 | raise TypeError("pretrained must be a str or None")
840 |
841 | def forward(self, x):
842 | """Forward function."""
843 | x = self.patch_embed(x)
844 |
845 | Wh, Ww = x[0].size(2), x[0].size(3)
846 | if self.ape:
847 | # interpolate the position embedding to the corresponding size
848 | absolute_pos_embed = F.interpolate(
849 | self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
850 | )
851 |             x = [(x_ + absolute_pos_embed).flatten(2).transpose(1, 2) for x_ in x]  # B Wh*Ww C
852 | else:
853 | for i in range(len(x)):
854 | x[i] = x[i].flatten(2).transpose(1, 2)
855 | x = self.pos_drop(x)
856 |
857 | outs = {}
858 | for i in range(self.num_layers):
859 | layer = self.layers[i]
860 | x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
861 |
862 | if i in self.out_indices:
863 | norm_layer = getattr(self, f"norm{i}")
864 | x_out = norm_layer(x_out)
865 |
866 | out = [
867 | x_out[j]
868 | .view(-1, H, W, self.num_features[i])
869 | .permute(0, 3, 1, 2)
870 | .contiguous()
871 | for j in range(len(x_out))
872 | ]
873 | outs[i] = out
874 |
875 | new_x0 = []
876 | new_x1 = []
877 | for i in range(4):
878 | new_x0.append(outs[i][0])
879 | new_x1.append(outs[i][1])
880 | x = [new_x0, new_x1]
881 |
882 | return x
883 |
884 | def train(self, mode=True):
885 |         """Convert the model into training mode while keeping frozen stages frozen."""
886 | super(SwinTransformer, self).train(mode)
887 | self._freeze_stages()
888 |
--------------------------------------------------------------------------------
/data/nyudv2/val.txt:
--------------------------------------------------------------------------------
1 | rgb/000001.png masks/000001.png depth/000001.png
2 | rgb/000002.png masks/000002.png depth/000002.png
3 | rgb/000009.png masks/000009.png depth/000009.png
4 | rgb/000014.png masks/000014.png depth/000014.png
5 | rgb/000015.png masks/000015.png depth/000015.png
6 | rgb/000016.png masks/000016.png depth/000016.png
7 | rgb/000017.png masks/000017.png depth/000017.png
8 | rgb/000018.png masks/000018.png depth/000018.png
9 | rgb/000021.png masks/000021.png depth/000021.png
10 | rgb/000028.png masks/000028.png depth/000028.png
11 | rgb/000029.png masks/000029.png depth/000029.png
12 | rgb/000030.png masks/000030.png depth/000030.png
13 | rgb/000031.png masks/000031.png depth/000031.png
14 | rgb/000032.png masks/000032.png depth/000032.png
15 | rgb/000033.png masks/000033.png depth/000033.png
16 | rgb/000034.png masks/000034.png depth/000034.png
17 | rgb/000035.png masks/000035.png depth/000035.png
18 | rgb/000036.png masks/000036.png depth/000036.png
19 | rgb/000037.png masks/000037.png depth/000037.png
20 | rgb/000038.png masks/000038.png depth/000038.png
21 | rgb/000039.png masks/000039.png depth/000039.png
22 | rgb/000040.png masks/000040.png depth/000040.png
23 | rgb/000041.png masks/000041.png depth/000041.png
24 | rgb/000042.png masks/000042.png depth/000042.png
25 | rgb/000043.png masks/000043.png depth/000043.png
26 | rgb/000046.png masks/000046.png depth/000046.png
27 | rgb/000047.png masks/000047.png depth/000047.png
28 | rgb/000056.png masks/000056.png depth/000056.png
29 | rgb/000057.png masks/000057.png depth/000057.png
30 | rgb/000059.png masks/000059.png depth/000059.png
31 | rgb/000060.png masks/000060.png depth/000060.png
32 | rgb/000061.png masks/000061.png depth/000061.png
33 | rgb/000062.png masks/000062.png depth/000062.png
34 | rgb/000063.png masks/000063.png depth/000063.png
35 | rgb/000076.png masks/000076.png depth/000076.png
36 | rgb/000077.png masks/000077.png depth/000077.png
37 | rgb/000078.png masks/000078.png depth/000078.png
38 | rgb/000079.png masks/000079.png depth/000079.png
39 | rgb/000084.png masks/000084.png depth/000084.png
40 | rgb/000085.png masks/000085.png depth/000085.png
41 | rgb/000086.png masks/000086.png depth/000086.png
42 | rgb/000087.png masks/000087.png depth/000087.png
43 | rgb/000088.png masks/000088.png depth/000088.png
44 | rgb/000089.png masks/000089.png depth/000089.png
45 | rgb/000090.png masks/000090.png depth/000090.png
46 | rgb/000091.png masks/000091.png depth/000091.png
47 | rgb/000117.png masks/000117.png depth/000117.png
48 | rgb/000118.png masks/000118.png depth/000118.png
49 | rgb/000119.png masks/000119.png depth/000119.png
50 | rgb/000125.png masks/000125.png depth/000125.png
51 | rgb/000126.png masks/000126.png depth/000126.png
52 | rgb/000127.png masks/000127.png depth/000127.png
53 | rgb/000128.png masks/000128.png depth/000128.png
54 | rgb/000129.png masks/000129.png depth/000129.png
55 | rgb/000131.png masks/000131.png depth/000131.png
56 | rgb/000132.png masks/000132.png depth/000132.png
57 | rgb/000133.png masks/000133.png depth/000133.png
58 | rgb/000134.png masks/000134.png depth/000134.png
59 | rgb/000137.png masks/000137.png depth/000137.png
60 | rgb/000153.png masks/000153.png depth/000153.png
61 | rgb/000154.png masks/000154.png depth/000154.png
62 | rgb/000155.png masks/000155.png depth/000155.png
63 | rgb/000167.png masks/000167.png depth/000167.png
64 | rgb/000168.png masks/000168.png depth/000168.png
65 | rgb/000169.png masks/000169.png depth/000169.png
66 | rgb/000171.png masks/000171.png depth/000171.png
67 | rgb/000172.png masks/000172.png depth/000172.png
68 | rgb/000173.png masks/000173.png depth/000173.png
69 | rgb/000174.png masks/000174.png depth/000174.png
70 | rgb/000175.png masks/000175.png depth/000175.png
71 | rgb/000176.png masks/000176.png depth/000176.png
72 | rgb/000180.png masks/000180.png depth/000180.png
73 | rgb/000181.png masks/000181.png depth/000181.png
74 | rgb/000182.png masks/000182.png depth/000182.png
75 | rgb/000183.png masks/000183.png depth/000183.png
76 | rgb/000184.png masks/000184.png depth/000184.png
77 | rgb/000185.png masks/000185.png depth/000185.png
78 | rgb/000186.png masks/000186.png depth/000186.png
79 | rgb/000187.png masks/000187.png depth/000187.png
80 | rgb/000188.png masks/000188.png depth/000188.png
81 | rgb/000189.png masks/000189.png depth/000189.png
82 | rgb/000190.png masks/000190.png depth/000190.png
83 | rgb/000191.png masks/000191.png depth/000191.png
84 | rgb/000192.png masks/000192.png depth/000192.png
85 | rgb/000193.png masks/000193.png depth/000193.png
86 | rgb/000194.png masks/000194.png depth/000194.png
87 | rgb/000195.png masks/000195.png depth/000195.png
88 | rgb/000196.png masks/000196.png depth/000196.png
89 | rgb/000197.png masks/000197.png depth/000197.png
90 | rgb/000198.png masks/000198.png depth/000198.png
91 | rgb/000199.png masks/000199.png depth/000199.png
92 | rgb/000200.png masks/000200.png depth/000200.png
93 | rgb/000201.png masks/000201.png depth/000201.png
94 | rgb/000202.png masks/000202.png depth/000202.png
95 | rgb/000207.png masks/000207.png depth/000207.png
96 | rgb/000208.png masks/000208.png depth/000208.png
97 | rgb/000209.png masks/000209.png depth/000209.png
98 | rgb/000210.png masks/000210.png depth/000210.png
99 | rgb/000211.png masks/000211.png depth/000211.png
100 | rgb/000212.png masks/000212.png depth/000212.png
101 | rgb/000220.png masks/000220.png depth/000220.png
102 | rgb/000221.png masks/000221.png depth/000221.png
103 | rgb/000222.png masks/000222.png depth/000222.png
104 | rgb/000250.png masks/000250.png depth/000250.png
105 | rgb/000264.png masks/000264.png depth/000264.png
106 | rgb/000271.png masks/000271.png depth/000271.png
107 | rgb/000272.png masks/000272.png depth/000272.png
108 | rgb/000273.png masks/000273.png depth/000273.png
109 | rgb/000279.png masks/000279.png depth/000279.png
110 | rgb/000280.png masks/000280.png depth/000280.png
111 | rgb/000281.png masks/000281.png depth/000281.png
112 | rgb/000282.png masks/000282.png depth/000282.png
113 | rgb/000283.png masks/000283.png depth/000283.png
114 | rgb/000284.png masks/000284.png depth/000284.png
115 | rgb/000285.png masks/000285.png depth/000285.png
116 | rgb/000296.png masks/000296.png depth/000296.png
117 | rgb/000297.png masks/000297.png depth/000297.png
118 | rgb/000298.png masks/000298.png depth/000298.png
119 | rgb/000299.png masks/000299.png depth/000299.png
120 | rgb/000300.png masks/000300.png depth/000300.png
121 | rgb/000301.png masks/000301.png depth/000301.png
122 | rgb/000302.png masks/000302.png depth/000302.png
123 | rgb/000310.png masks/000310.png depth/000310.png
124 | rgb/000311.png masks/000311.png depth/000311.png
125 | rgb/000312.png masks/000312.png depth/000312.png
126 | rgb/000315.png masks/000315.png depth/000315.png
127 | rgb/000316.png masks/000316.png depth/000316.png
128 | rgb/000317.png masks/000317.png depth/000317.png
129 | rgb/000325.png masks/000325.png depth/000325.png
130 | rgb/000326.png masks/000326.png depth/000326.png
131 | rgb/000327.png masks/000327.png depth/000327.png
132 | rgb/000328.png masks/000328.png depth/000328.png
133 | rgb/000329.png masks/000329.png depth/000329.png
134 | rgb/000330.png masks/000330.png depth/000330.png
135 | rgb/000331.png masks/000331.png depth/000331.png
136 | rgb/000332.png masks/000332.png depth/000332.png
137 | rgb/000333.png masks/000333.png depth/000333.png
138 | rgb/000334.png masks/000334.png depth/000334.png
139 | rgb/000335.png masks/000335.png depth/000335.png
140 | rgb/000351.png masks/000351.png depth/000351.png
141 | rgb/000352.png masks/000352.png depth/000352.png
142 | rgb/000355.png masks/000355.png depth/000355.png
143 | rgb/000356.png masks/000356.png depth/000356.png
144 | rgb/000357.png masks/000357.png depth/000357.png
145 | rgb/000358.png masks/000358.png depth/000358.png
146 | rgb/000359.png masks/000359.png depth/000359.png
147 | rgb/000360.png masks/000360.png depth/000360.png
148 | rgb/000361.png masks/000361.png depth/000361.png
149 | rgb/000362.png masks/000362.png depth/000362.png
150 | rgb/000363.png masks/000363.png depth/000363.png
151 | rgb/000364.png masks/000364.png depth/000364.png
152 | rgb/000384.png masks/000384.png depth/000384.png
153 | rgb/000385.png masks/000385.png depth/000385.png
154 | rgb/000386.png masks/000386.png depth/000386.png
155 | rgb/000387.png masks/000387.png depth/000387.png
156 | rgb/000388.png masks/000388.png depth/000388.png
157 | rgb/000389.png masks/000389.png depth/000389.png
158 | rgb/000390.png masks/000390.png depth/000390.png
159 | rgb/000395.png masks/000395.png depth/000395.png
160 | rgb/000396.png masks/000396.png depth/000396.png
161 | rgb/000397.png masks/000397.png depth/000397.png
162 | rgb/000411.png masks/000411.png depth/000411.png
163 | rgb/000412.png masks/000412.png depth/000412.png
164 | rgb/000413.png masks/000413.png depth/000413.png
165 | rgb/000414.png masks/000414.png depth/000414.png
166 | rgb/000430.png masks/000430.png depth/000430.png
167 | rgb/000431.png masks/000431.png depth/000431.png
168 | rgb/000432.png masks/000432.png depth/000432.png
169 | rgb/000433.png masks/000433.png depth/000433.png
170 | rgb/000434.png masks/000434.png depth/000434.png
171 | rgb/000435.png masks/000435.png depth/000435.png
172 | rgb/000441.png masks/000441.png depth/000441.png
173 | rgb/000442.png masks/000442.png depth/000442.png
174 | rgb/000443.png masks/000443.png depth/000443.png
175 | rgb/000444.png masks/000444.png depth/000444.png
176 | rgb/000445.png masks/000445.png depth/000445.png
177 | rgb/000446.png masks/000446.png depth/000446.png
178 | rgb/000447.png masks/000447.png depth/000447.png
179 | rgb/000448.png masks/000448.png depth/000448.png
180 | rgb/000462.png masks/000462.png depth/000462.png
181 | rgb/000463.png masks/000463.png depth/000463.png
182 | rgb/000464.png masks/000464.png depth/000464.png
183 | rgb/000465.png masks/000465.png depth/000465.png
184 | rgb/000466.png masks/000466.png depth/000466.png
185 | rgb/000469.png masks/000469.png depth/000469.png
186 | rgb/000470.png masks/000470.png depth/000470.png
187 | rgb/000471.png masks/000471.png depth/000471.png
188 | rgb/000472.png masks/000472.png depth/000472.png
189 | rgb/000473.png masks/000473.png depth/000473.png
190 | rgb/000474.png masks/000474.png depth/000474.png
191 | rgb/000475.png masks/000475.png depth/000475.png
192 | rgb/000476.png masks/000476.png depth/000476.png
193 | rgb/000477.png masks/000477.png depth/000477.png
194 | rgb/000508.png masks/000508.png depth/000508.png
195 | rgb/000509.png masks/000509.png depth/000509.png
196 | rgb/000510.png masks/000510.png depth/000510.png
197 | rgb/000511.png masks/000511.png depth/000511.png
198 | rgb/000512.png masks/000512.png depth/000512.png
199 | rgb/000513.png masks/000513.png depth/000513.png
200 | rgb/000515.png masks/000515.png depth/000515.png
201 | rgb/000516.png masks/000516.png depth/000516.png
202 | rgb/000517.png masks/000517.png depth/000517.png
203 | rgb/000518.png masks/000518.png depth/000518.png
204 | rgb/000519.png masks/000519.png depth/000519.png
205 | rgb/000520.png masks/000520.png depth/000520.png
206 | rgb/000521.png masks/000521.png depth/000521.png
207 | rgb/000522.png masks/000522.png depth/000522.png
208 | rgb/000523.png masks/000523.png depth/000523.png
209 | rgb/000524.png masks/000524.png depth/000524.png
210 | rgb/000525.png masks/000525.png depth/000525.png
211 | rgb/000526.png masks/000526.png depth/000526.png
212 | rgb/000531.png masks/000531.png depth/000531.png
213 | rgb/000532.png masks/000532.png depth/000532.png
214 | rgb/000533.png masks/000533.png depth/000533.png
215 | rgb/000537.png masks/000537.png depth/000537.png
216 | rgb/000538.png masks/000538.png depth/000538.png
217 | rgb/000539.png masks/000539.png depth/000539.png
218 | rgb/000549.png masks/000549.png depth/000549.png
219 | rgb/000550.png masks/000550.png depth/000550.png
220 | rgb/000551.png masks/000551.png depth/000551.png
221 | rgb/000555.png masks/000555.png depth/000555.png
222 | rgb/000556.png masks/000556.png depth/000556.png
223 | rgb/000557.png masks/000557.png depth/000557.png
224 | rgb/000558.png masks/000558.png depth/000558.png
225 | rgb/000559.png masks/000559.png depth/000559.png
226 | rgb/000560.png masks/000560.png depth/000560.png
227 | rgb/000561.png masks/000561.png depth/000561.png
228 | rgb/000562.png masks/000562.png depth/000562.png
229 | rgb/000563.png masks/000563.png depth/000563.png
230 | rgb/000564.png masks/000564.png depth/000564.png
231 | rgb/000565.png masks/000565.png depth/000565.png
232 | rgb/000566.png masks/000566.png depth/000566.png
233 | rgb/000567.png masks/000567.png depth/000567.png
234 | rgb/000568.png masks/000568.png depth/000568.png
235 | rgb/000569.png masks/000569.png depth/000569.png
236 | rgb/000570.png masks/000570.png depth/000570.png
237 | rgb/000571.png masks/000571.png depth/000571.png
238 | rgb/000579.png masks/000579.png depth/000579.png
239 | rgb/000580.png masks/000580.png depth/000580.png
240 | rgb/000581.png masks/000581.png depth/000581.png
241 | rgb/000582.png masks/000582.png depth/000582.png
242 | rgb/000583.png masks/000583.png depth/000583.png
243 | rgb/000591.png masks/000591.png depth/000591.png
244 | rgb/000592.png masks/000592.png depth/000592.png
245 | rgb/000593.png masks/000593.png depth/000593.png
246 | rgb/000594.png masks/000594.png depth/000594.png
247 | rgb/000603.png masks/000603.png depth/000603.png
248 | rgb/000604.png masks/000604.png depth/000604.png
249 | rgb/000605.png masks/000605.png depth/000605.png
250 | rgb/000606.png masks/000606.png depth/000606.png
251 | rgb/000607.png masks/000607.png depth/000607.png
252 | rgb/000612.png masks/000612.png depth/000612.png
253 | rgb/000613.png masks/000613.png depth/000613.png
254 | rgb/000617.png masks/000617.png depth/000617.png
255 | rgb/000618.png masks/000618.png depth/000618.png
256 | rgb/000619.png masks/000619.png depth/000619.png
257 | rgb/000620.png masks/000620.png depth/000620.png
258 | rgb/000621.png masks/000621.png depth/000621.png
259 | rgb/000633.png masks/000633.png depth/000633.png
260 | rgb/000634.png masks/000634.png depth/000634.png
261 | rgb/000635.png masks/000635.png depth/000635.png
262 | rgb/000636.png masks/000636.png depth/000636.png
263 | rgb/000637.png masks/000637.png depth/000637.png
264 | rgb/000638.png masks/000638.png depth/000638.png
265 | rgb/000644.png masks/000644.png depth/000644.png
266 | rgb/000645.png masks/000645.png depth/000645.png
267 | rgb/000650.png masks/000650.png depth/000650.png
268 | rgb/000651.png masks/000651.png depth/000651.png
269 | rgb/000656.png masks/000656.png depth/000656.png
270 | rgb/000657.png masks/000657.png depth/000657.png
271 | rgb/000658.png masks/000658.png depth/000658.png
272 | rgb/000663.png masks/000663.png depth/000663.png
273 | rgb/000664.png masks/000664.png depth/000664.png
274 | rgb/000668.png masks/000668.png depth/000668.png
275 | rgb/000669.png masks/000669.png depth/000669.png
276 | rgb/000670.png masks/000670.png depth/000670.png
277 | rgb/000671.png masks/000671.png depth/000671.png
278 | rgb/000672.png masks/000672.png depth/000672.png
279 | rgb/000673.png masks/000673.png depth/000673.png
280 | rgb/000676.png masks/000676.png depth/000676.png
281 | rgb/000677.png masks/000677.png depth/000677.png
282 | rgb/000678.png masks/000678.png depth/000678.png
283 | rgb/000679.png masks/000679.png depth/000679.png
284 | rgb/000680.png masks/000680.png depth/000680.png
285 | rgb/000681.png masks/000681.png depth/000681.png
286 | rgb/000686.png masks/000686.png depth/000686.png
287 | rgb/000687.png masks/000687.png depth/000687.png
288 | rgb/000688.png masks/000688.png depth/000688.png
289 | rgb/000689.png masks/000689.png depth/000689.png
290 | rgb/000690.png masks/000690.png depth/000690.png
291 | rgb/000693.png masks/000693.png depth/000693.png
292 | rgb/000694.png masks/000694.png depth/000694.png
293 | rgb/000697.png masks/000697.png depth/000697.png
294 | rgb/000698.png masks/000698.png depth/000698.png
295 | rgb/000699.png masks/000699.png depth/000699.png
296 | rgb/000706.png masks/000706.png depth/000706.png
297 | rgb/000707.png masks/000707.png depth/000707.png
298 | rgb/000708.png masks/000708.png depth/000708.png
299 | rgb/000709.png masks/000709.png depth/000709.png
300 | rgb/000710.png masks/000710.png depth/000710.png
301 | rgb/000711.png masks/000711.png depth/000711.png
302 | rgb/000712.png masks/000712.png depth/000712.png
303 | rgb/000713.png masks/000713.png depth/000713.png
304 | rgb/000717.png masks/000717.png depth/000717.png
305 | rgb/000718.png masks/000718.png depth/000718.png
306 | rgb/000724.png masks/000724.png depth/000724.png
307 | rgb/000725.png masks/000725.png depth/000725.png
308 | rgb/000726.png masks/000726.png depth/000726.png
309 | rgb/000727.png masks/000727.png depth/000727.png
310 | rgb/000728.png masks/000728.png depth/000728.png
311 | rgb/000731.png masks/000731.png depth/000731.png
312 | rgb/000732.png masks/000732.png depth/000732.png
313 | rgb/000733.png masks/000733.png depth/000733.png
314 | rgb/000734.png masks/000734.png depth/000734.png
315 | rgb/000743.png masks/000743.png depth/000743.png
316 | rgb/000744.png masks/000744.png depth/000744.png
317 | rgb/000759.png masks/000759.png depth/000759.png
318 | rgb/000760.png masks/000760.png depth/000760.png
319 | rgb/000761.png masks/000761.png depth/000761.png
320 | rgb/000762.png masks/000762.png depth/000762.png
321 | rgb/000763.png masks/000763.png depth/000763.png
322 | rgb/000764.png masks/000764.png depth/000764.png
323 | rgb/000765.png masks/000765.png depth/000765.png
324 | rgb/000766.png masks/000766.png depth/000766.png
325 | rgb/000767.png masks/000767.png depth/000767.png
326 | rgb/000768.png masks/000768.png depth/000768.png
327 | rgb/000769.png masks/000769.png depth/000769.png
328 | rgb/000770.png masks/000770.png depth/000770.png
329 | rgb/000771.png masks/000771.png depth/000771.png
330 | rgb/000772.png masks/000772.png depth/000772.png
331 | rgb/000773.png masks/000773.png depth/000773.png
332 | rgb/000774.png masks/000774.png depth/000774.png
333 | rgb/000775.png masks/000775.png depth/000775.png
334 | rgb/000776.png masks/000776.png depth/000776.png
335 | rgb/000777.png masks/000777.png depth/000777.png
336 | rgb/000778.png masks/000778.png depth/000778.png
337 | rgb/000779.png masks/000779.png depth/000779.png
338 | rgb/000780.png masks/000780.png depth/000780.png
339 | rgb/000781.png masks/000781.png depth/000781.png
340 | rgb/000782.png masks/000782.png depth/000782.png
341 | rgb/000783.png masks/000783.png depth/000783.png
342 | rgb/000784.png masks/000784.png depth/000784.png
343 | rgb/000785.png masks/000785.png depth/000785.png
344 | rgb/000786.png masks/000786.png depth/000786.png
345 | rgb/000787.png masks/000787.png depth/000787.png
346 | rgb/000800.png masks/000800.png depth/000800.png
347 | rgb/000801.png masks/000801.png depth/000801.png
348 | rgb/000802.png masks/000802.png depth/000802.png
349 | rgb/000803.png masks/000803.png depth/000803.png
350 | rgb/000804.png masks/000804.png depth/000804.png
351 | rgb/000810.png masks/000810.png depth/000810.png
352 | rgb/000811.png masks/000811.png depth/000811.png
353 | rgb/000812.png masks/000812.png depth/000812.png
354 | rgb/000813.png masks/000813.png depth/000813.png
355 | rgb/000814.png masks/000814.png depth/000814.png
356 | rgb/000821.png masks/000821.png depth/000821.png
357 | rgb/000822.png masks/000822.png depth/000822.png
358 | rgb/000823.png masks/000823.png depth/000823.png
359 | rgb/000833.png masks/000833.png depth/000833.png
360 | rgb/000834.png masks/000834.png depth/000834.png
361 | rgb/000835.png masks/000835.png depth/000835.png
362 | rgb/000836.png masks/000836.png depth/000836.png
363 | rgb/000837.png masks/000837.png depth/000837.png
364 | rgb/000838.png masks/000838.png depth/000838.png
365 | rgb/000839.png masks/000839.png depth/000839.png
366 | rgb/000840.png masks/000840.png depth/000840.png
367 | rgb/000841.png masks/000841.png depth/000841.png
368 | rgb/000842.png masks/000842.png depth/000842.png
369 | rgb/000843.png masks/000843.png depth/000843.png
370 | rgb/000844.png masks/000844.png depth/000844.png
371 | rgb/000845.png masks/000845.png depth/000845.png
372 | rgb/000846.png masks/000846.png depth/000846.png
373 | rgb/000850.png masks/000850.png depth/000850.png
374 | rgb/000851.png masks/000851.png depth/000851.png
375 | rgb/000852.png masks/000852.png depth/000852.png
376 | rgb/000857.png masks/000857.png depth/000857.png
377 | rgb/000858.png masks/000858.png depth/000858.png
378 | rgb/000859.png masks/000859.png depth/000859.png
379 | rgb/000860.png masks/000860.png depth/000860.png
380 | rgb/000861.png masks/000861.png depth/000861.png
381 | rgb/000862.png masks/000862.png depth/000862.png
382 | rgb/000869.png masks/000869.png depth/000869.png
383 | rgb/000870.png masks/000870.png depth/000870.png
384 | rgb/000871.png masks/000871.png depth/000871.png
385 | rgb/000906.png masks/000906.png depth/000906.png
386 | rgb/000907.png masks/000907.png depth/000907.png
387 | rgb/000908.png masks/000908.png depth/000908.png
388 | rgb/000917.png masks/000917.png depth/000917.png
389 | rgb/000918.png masks/000918.png depth/000918.png
390 | rgb/000919.png masks/000919.png depth/000919.png
391 | rgb/000926.png masks/000926.png depth/000926.png
392 | rgb/000927.png masks/000927.png depth/000927.png
393 | rgb/000928.png masks/000928.png depth/000928.png
394 | rgb/000932.png masks/000932.png depth/000932.png
395 | rgb/000933.png masks/000933.png depth/000933.png
396 | rgb/000934.png masks/000934.png depth/000934.png
397 | rgb/000935.png masks/000935.png depth/000935.png
398 | rgb/000945.png masks/000945.png depth/000945.png
399 | rgb/000946.png masks/000946.png depth/000946.png
400 | rgb/000947.png masks/000947.png depth/000947.png
401 | rgb/000959.png masks/000959.png depth/000959.png
402 | rgb/000960.png masks/000960.png depth/000960.png
403 | rgb/000961.png masks/000961.png depth/000961.png
404 | rgb/000962.png masks/000962.png depth/000962.png
405 | rgb/000965.png masks/000965.png depth/000965.png
406 | rgb/000966.png masks/000966.png depth/000966.png
407 | rgb/000967.png masks/000967.png depth/000967.png
408 | rgb/000970.png masks/000970.png depth/000970.png
409 | rgb/000971.png masks/000971.png depth/000971.png
410 | rgb/000972.png masks/000972.png depth/000972.png
411 | rgb/000973.png masks/000973.png depth/000973.png
412 | rgb/000974.png masks/000974.png depth/000974.png
413 | rgb/000975.png masks/000975.png depth/000975.png
414 | rgb/000976.png masks/000976.png depth/000976.png
415 | rgb/000977.png masks/000977.png depth/000977.png
416 | rgb/000991.png masks/000991.png depth/000991.png
417 | rgb/000992.png masks/000992.png depth/000992.png
418 | rgb/000993.png masks/000993.png depth/000993.png
419 | rgb/000994.png masks/000994.png depth/000994.png
420 | rgb/000995.png masks/000995.png depth/000995.png
421 | rgb/001001.png masks/001001.png depth/001001.png
422 | rgb/001002.png masks/001002.png depth/001002.png
423 | rgb/001003.png masks/001003.png depth/001003.png
424 | rgb/001004.png masks/001004.png depth/001004.png
425 | rgb/001010.png masks/001010.png depth/001010.png
426 | rgb/001011.png masks/001011.png depth/001011.png
427 | rgb/001012.png masks/001012.png depth/001012.png
428 | rgb/001021.png masks/001021.png depth/001021.png
429 | rgb/001022.png masks/001022.png depth/001022.png
430 | rgb/001023.png masks/001023.png depth/001023.png
431 | rgb/001032.png masks/001032.png depth/001032.png
432 | rgb/001033.png masks/001033.png depth/001033.png
433 | rgb/001034.png masks/001034.png depth/001034.png
434 | rgb/001038.png masks/001038.png depth/001038.png
435 | rgb/001039.png masks/001039.png depth/001039.png
436 | rgb/001048.png masks/001048.png depth/001048.png
437 | rgb/001049.png masks/001049.png depth/001049.png
438 | rgb/001052.png masks/001052.png depth/001052.png
439 | rgb/001053.png masks/001053.png depth/001053.png
440 | rgb/001057.png masks/001057.png depth/001057.png
441 | rgb/001058.png masks/001058.png depth/001058.png
442 | rgb/001075.png masks/001075.png depth/001075.png
443 | rgb/001076.png masks/001076.png depth/001076.png
444 | rgb/001077.png masks/001077.png depth/001077.png
445 | rgb/001078.png masks/001078.png depth/001078.png
446 | rgb/001079.png masks/001079.png depth/001079.png
447 | rgb/001080.png masks/001080.png depth/001080.png
448 | rgb/001081.png masks/001081.png depth/001081.png
449 | rgb/001082.png masks/001082.png depth/001082.png
450 | rgb/001083.png masks/001083.png depth/001083.png
451 | rgb/001084.png masks/001084.png depth/001084.png
452 | rgb/001088.png masks/001088.png depth/001088.png
453 | rgb/001089.png masks/001089.png depth/001089.png
454 | rgb/001090.png masks/001090.png depth/001090.png
455 | rgb/001091.png masks/001091.png depth/001091.png
456 | rgb/001092.png masks/001092.png depth/001092.png
457 | rgb/001093.png masks/001093.png depth/001093.png
458 | rgb/001094.png masks/001094.png depth/001094.png
459 | rgb/001095.png masks/001095.png depth/001095.png
460 | rgb/001096.png masks/001096.png depth/001096.png
461 | rgb/001098.png masks/001098.png depth/001098.png
462 | rgb/001099.png masks/001099.png depth/001099.png
463 | rgb/001100.png masks/001100.png depth/001100.png
464 | rgb/001101.png masks/001101.png depth/001101.png
465 | rgb/001102.png masks/001102.png depth/001102.png
466 | rgb/001103.png masks/001103.png depth/001103.png
467 | rgb/001104.png masks/001104.png depth/001104.png
468 | rgb/001106.png masks/001106.png depth/001106.png
469 | rgb/001107.png masks/001107.png depth/001107.png
470 | rgb/001108.png masks/001108.png depth/001108.png
471 | rgb/001109.png masks/001109.png depth/001109.png
472 | rgb/001117.png masks/001117.png depth/001117.png
473 | rgb/001118.png masks/001118.png depth/001118.png
474 | rgb/001119.png masks/001119.png depth/001119.png
475 | rgb/001123.png masks/001123.png depth/001123.png
476 | rgb/001124.png masks/001124.png depth/001124.png
477 | rgb/001125.png masks/001125.png depth/001125.png
478 | rgb/001126.png masks/001126.png depth/001126.png
479 | rgb/001127.png masks/001127.png depth/001127.png
480 | rgb/001128.png masks/001128.png depth/001128.png
481 | rgb/001129.png masks/001129.png depth/001129.png
482 | rgb/001130.png masks/001130.png depth/001130.png
483 | rgb/001131.png masks/001131.png depth/001131.png
484 | rgb/001135.png masks/001135.png depth/001135.png
485 | rgb/001136.png masks/001136.png depth/001136.png
486 | rgb/001144.png masks/001144.png depth/001144.png
487 | rgb/001145.png masks/001145.png depth/001145.png
488 | rgb/001146.png masks/001146.png depth/001146.png
489 | rgb/001147.png masks/001147.png depth/001147.png
490 | rgb/001148.png masks/001148.png depth/001148.png
491 | rgb/001149.png masks/001149.png depth/001149.png
492 | rgb/001150.png masks/001150.png depth/001150.png
493 | rgb/001151.png masks/001151.png depth/001151.png
494 | rgb/001152.png masks/001152.png depth/001152.png
495 | rgb/001153.png masks/001153.png depth/001153.png
496 | rgb/001154.png masks/001154.png depth/001154.png
497 | rgb/001155.png masks/001155.png depth/001155.png
498 | rgb/001156.png masks/001156.png depth/001156.png
499 | rgb/001157.png masks/001157.png depth/001157.png
500 | rgb/001158.png masks/001158.png depth/001158.png
501 | rgb/001162.png masks/001162.png depth/001162.png
502 | rgb/001163.png masks/001163.png depth/001163.png
503 | rgb/001164.png masks/001164.png depth/001164.png
504 | rgb/001165.png masks/001165.png depth/001165.png
505 | rgb/001166.png masks/001166.png depth/001166.png
506 | rgb/001167.png masks/001167.png depth/001167.png
507 | rgb/001170.png masks/001170.png depth/001170.png
508 | rgb/001171.png masks/001171.png depth/001171.png
509 | rgb/001174.png masks/001174.png depth/001174.png
510 | rgb/001175.png masks/001175.png depth/001175.png
511 | rgb/001176.png masks/001176.png depth/001176.png
512 | rgb/001179.png masks/001179.png depth/001179.png
513 | rgb/001180.png masks/001180.png depth/001180.png
514 | rgb/001181.png masks/001181.png depth/001181.png
515 | rgb/001182.png masks/001182.png depth/001182.png
516 | rgb/001183.png masks/001183.png depth/001183.png
517 | rgb/001184.png masks/001184.png depth/001184.png
518 | rgb/001192.png masks/001192.png depth/001192.png
519 | rgb/001193.png masks/001193.png depth/001193.png
520 | rgb/001194.png masks/001194.png depth/001194.png
521 | rgb/001195.png masks/001195.png depth/001195.png
522 | rgb/001196.png masks/001196.png depth/001196.png
523 | rgb/001201.png masks/001201.png depth/001201.png
524 | rgb/001202.png masks/001202.png depth/001202.png
525 | rgb/001203.png masks/001203.png depth/001203.png
526 | rgb/001204.png masks/001204.png depth/001204.png
527 | rgb/001205.png masks/001205.png depth/001205.png
528 | rgb/001206.png masks/001206.png depth/001206.png
529 | rgb/001207.png masks/001207.png depth/001207.png
530 | rgb/001208.png masks/001208.png depth/001208.png
531 | rgb/001209.png masks/001209.png depth/001209.png
532 | rgb/001210.png masks/001210.png depth/001210.png
533 | rgb/001211.png masks/001211.png depth/001211.png
534 | rgb/001212.png masks/001212.png depth/001212.png
535 | rgb/001216.png masks/001216.png depth/001216.png
536 | rgb/001217.png masks/001217.png depth/001217.png
537 | rgb/001218.png masks/001218.png depth/001218.png
538 | rgb/001219.png masks/001219.png depth/001219.png
539 | rgb/001220.png masks/001220.png depth/001220.png
540 | rgb/001226.png masks/001226.png depth/001226.png
541 | rgb/001227.png masks/001227.png depth/001227.png
542 | rgb/001228.png masks/001228.png depth/001228.png
543 | rgb/001229.png masks/001229.png depth/001229.png
544 | rgb/001230.png masks/001230.png depth/001230.png
545 | rgb/001233.png masks/001233.png depth/001233.png
546 | rgb/001234.png masks/001234.png depth/001234.png
547 | rgb/001235.png masks/001235.png depth/001235.png
548 | rgb/001247.png masks/001247.png depth/001247.png
549 | rgb/001248.png masks/001248.png depth/001248.png
550 | rgb/001249.png masks/001249.png depth/001249.png
551 | rgb/001250.png masks/001250.png depth/001250.png
552 | rgb/001254.png masks/001254.png depth/001254.png
553 | rgb/001255.png masks/001255.png depth/001255.png
554 | rgb/001256.png masks/001256.png depth/001256.png
555 | rgb/001257.png masks/001257.png depth/001257.png
556 | rgb/001258.png masks/001258.png depth/001258.png
557 | rgb/001259.png masks/001259.png depth/001259.png
558 | rgb/001260.png masks/001260.png depth/001260.png
559 | rgb/001261.png masks/001261.png depth/001261.png
560 | rgb/001262.png masks/001262.png depth/001262.png
561 | rgb/001263.png masks/001263.png depth/001263.png
562 | rgb/001264.png masks/001264.png depth/001264.png
563 | rgb/001265.png masks/001265.png depth/001265.png
564 | rgb/001275.png masks/001275.png depth/001275.png
565 | rgb/001276.png masks/001276.png depth/001276.png
566 | rgb/001277.png masks/001277.png depth/001277.png
567 | rgb/001278.png masks/001278.png depth/001278.png
568 | rgb/001279.png masks/001279.png depth/001279.png
569 | rgb/001280.png masks/001280.png depth/001280.png
570 | rgb/001285.png masks/001285.png depth/001285.png
571 | rgb/001286.png masks/001286.png depth/001286.png
572 | rgb/001287.png masks/001287.png depth/001287.png
573 | rgb/001288.png masks/001288.png depth/001288.png
574 | rgb/001289.png masks/001289.png depth/001289.png
575 | rgb/001290.png masks/001290.png depth/001290.png
576 | rgb/001291.png masks/001291.png depth/001291.png
577 | rgb/001292.png masks/001292.png depth/001292.png
578 | rgb/001293.png masks/001293.png depth/001293.png
579 | rgb/001294.png masks/001294.png depth/001294.png
580 | rgb/001295.png masks/001295.png depth/001295.png
581 | rgb/001297.png masks/001297.png depth/001297.png
582 | rgb/001298.png masks/001298.png depth/001298.png
583 | rgb/001299.png masks/001299.png depth/001299.png
584 | rgb/001302.png masks/001302.png depth/001302.png
585 | rgb/001303.png masks/001303.png depth/001303.png
586 | rgb/001304.png masks/001304.png depth/001304.png
587 | rgb/001305.png masks/001305.png depth/001305.png
588 | rgb/001306.png masks/001306.png depth/001306.png
589 | rgb/001307.png masks/001307.png depth/001307.png
590 | rgb/001308.png masks/001308.png depth/001308.png
591 | rgb/001314.png masks/001314.png depth/001314.png
592 | rgb/001315.png masks/001315.png depth/001315.png
593 | rgb/001329.png masks/001329.png depth/001329.png
594 | rgb/001330.png masks/001330.png depth/001330.png
595 | rgb/001331.png masks/001331.png depth/001331.png
596 | rgb/001332.png masks/001332.png depth/001332.png
597 | rgb/001335.png masks/001335.png depth/001335.png
598 | rgb/001336.png masks/001336.png depth/001336.png
599 | rgb/001337.png masks/001337.png depth/001337.png
600 | rgb/001338.png masks/001338.png depth/001338.png
601 | rgb/001339.png masks/001339.png depth/001339.png
602 | rgb/001340.png masks/001340.png depth/001340.png
603 | rgb/001347.png masks/001347.png depth/001347.png
604 | rgb/001348.png masks/001348.png depth/001348.png
605 | rgb/001349.png masks/001349.png depth/001349.png
606 | rgb/001353.png masks/001353.png depth/001353.png
607 | rgb/001354.png masks/001354.png depth/001354.png
608 | rgb/001355.png masks/001355.png depth/001355.png
609 | rgb/001356.png masks/001356.png depth/001356.png
610 | rgb/001364.png masks/001364.png depth/001364.png
611 | rgb/001365.png masks/001365.png depth/001365.png
612 | rgb/001368.png masks/001368.png depth/001368.png
613 | rgb/001369.png masks/001369.png depth/001369.png
614 | rgb/001384.png masks/001384.png depth/001384.png
615 | rgb/001385.png masks/001385.png depth/001385.png
616 | rgb/001386.png masks/001386.png depth/001386.png
617 | rgb/001387.png masks/001387.png depth/001387.png
618 | rgb/001388.png masks/001388.png depth/001388.png
619 | rgb/001389.png masks/001389.png depth/001389.png
620 | rgb/001390.png masks/001390.png depth/001390.png
621 | rgb/001391.png masks/001391.png depth/001391.png
622 | rgb/001394.png masks/001394.png depth/001394.png
623 | rgb/001395.png masks/001395.png depth/001395.png
624 | rgb/001396.png masks/001396.png depth/001396.png
625 | rgb/001397.png masks/001397.png depth/001397.png
626 | rgb/001398.png masks/001398.png depth/001398.png
627 | rgb/001399.png masks/001399.png depth/001399.png
628 | rgb/001400.png masks/001400.png depth/001400.png
629 | rgb/001401.png masks/001401.png depth/001401.png
630 | rgb/001407.png masks/001407.png depth/001407.png
631 | rgb/001408.png masks/001408.png depth/001408.png
632 | rgb/001409.png masks/001409.png depth/001409.png
633 | rgb/001410.png masks/001410.png depth/001410.png
634 | rgb/001411.png masks/001411.png depth/001411.png
635 | rgb/001412.png masks/001412.png depth/001412.png
636 | rgb/001413.png masks/001413.png depth/001413.png
637 | rgb/001414.png masks/001414.png depth/001414.png
638 | rgb/001421.png masks/001421.png depth/001421.png
639 | rgb/001422.png masks/001422.png depth/001422.png
640 | rgb/001423.png masks/001423.png depth/001423.png
641 | rgb/001424.png masks/001424.png depth/001424.png
642 | rgb/001430.png masks/001430.png depth/001430.png
643 | rgb/001431.png masks/001431.png depth/001431.png
644 | rgb/001432.png masks/001432.png depth/001432.png
645 | rgb/001433.png masks/001433.png depth/001433.png
646 | rgb/001441.png masks/001441.png depth/001441.png
647 | rgb/001442.png masks/001442.png depth/001442.png
648 | rgb/001443.png masks/001443.png depth/001443.png
649 | rgb/001444.png masks/001444.png depth/001444.png
650 | rgb/001445.png masks/001445.png depth/001445.png
651 | rgb/001446.png masks/001446.png depth/001446.png
652 | rgb/001447.png masks/001447.png depth/001447.png
653 | rgb/001448.png masks/001448.png depth/001448.png
654 | rgb/001449.png masks/001449.png depth/001449.png
--------------------------------------------------------------------------------
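Note: the split file above lists one sample per line as three space-separated relative paths (RGB image, segmentation mask, depth map). A minimal sketch of how such a line could be parsed is shown below; the `read_split` helper name and the `data_root` argument are hypothetical, and the actual loader in `utils/datasets.py` may differ.

import os

def read_split(split_path, data_root):
    """Parse an NYUDv2-style split file: each line holds rgb, mask and depth paths."""
    samples = []
    with open(split_path) as f:
        for line in f:
            parts = line.split()
            if len(parts) != 3:
                continue  # skip blank or malformed lines
            rgb, mask, depth = (os.path.join(data_root, p) for p in parts)
            samples.append((rgb, mask, depth))
    return samples

# Hypothetical usage:
# samples = read_split("data/nyudv2/val.txt", data_root="/path/to/nyudv2")
# print(len(samples), samples[0])

--------------------------------------------------------------------------------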