├── utils ├── __init__.py ├── cmap.npy ├── optimizer.py ├── helpers.py ├── meter.py ├── datasets.py ├── transforms.py └── augmentations_mm.py ├── models ├── __init__.py ├── modules.py ├── segformer.py ├── mix_transformer.py └── swin_transformer.py ├── figs ├── framework.png ├── homogeneous.png ├── heterogeneous.png └── geminifusion_framework.png ├── mmcv_custom ├── __init__.py ├── runner │ ├── __init__.py │ ├── checkpoint.py │ └── epoch_based_runner.py └── checkpoint.py ├── LICENSE ├── .gitignore ├── README-TokenFusion.md ├── README.md ├── main.py └── data └── nyudv2 └── val.txt /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .helpers import * 2 | from .meter import * 3 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mix_transformer import * 2 | from .segformer import WeTr -------------------------------------------------------------------------------- /utils/cmap.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiaDingCN/GeminiFusion/HEAD/utils/cmap.npy -------------------------------------------------------------------------------- /figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiaDingCN/GeminiFusion/HEAD/figs/framework.png -------------------------------------------------------------------------------- /figs/homogeneous.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiaDingCN/GeminiFusion/HEAD/figs/homogeneous.png -------------------------------------------------------------------------------- /figs/heterogeneous.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiaDingCN/GeminiFusion/HEAD/figs/heterogeneous.png -------------------------------------------------------------------------------- /figs/geminifusion_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiaDingCN/GeminiFusion/HEAD/figs/geminifusion_framework.png -------------------------------------------------------------------------------- /mmcv_custom/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .checkpoint import load_checkpoint 4 | 5 | __all__ = ["load_checkpoint"] 6 | 7 | -------------------------------------------------------------------------------- /mmcv_custom/runner/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 
2 | from .checkpoint import save_checkpoint 3 | from .epoch_based_runner import EpochBasedRunnerAmp 4 | 5 | 6 | __all__ = ["EpochBasedRunnerAmp", "save_checkpoint"] 7 | 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 jiading 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class PolyWarmupAdamW(torch.optim.AdamW): 4 | 5 | def __init__(self, params, lr, weight_decay, betas, warmup_iter=None, max_iter=None, warmup_ratio=None, power=None): 6 | super().__init__(params, lr=lr, betas=betas,weight_decay=weight_decay, eps=1e-8) 7 | 8 | self.global_step = 0 9 | self.warmup_iter = warmup_iter 10 | self.warmup_ratio = warmup_ratio 11 | self.max_iter = max_iter 12 | self.power = power 13 | 14 | self.__init_lr = [group['lr'] for group in self.param_groups] 15 | 16 | def step(self, closure=None): 17 | ## adjust lr 18 | if self.global_step < self.warmup_iter: 19 | 20 | lr_mult = 1 - (1 - self.global_step / self.warmup_iter) * (1 - self.warmup_ratio) 21 | for i in range(len(self.param_groups)): 22 | self.param_groups[i]['lr'] = self.__init_lr[i] * lr_mult 23 | 24 | elif self.global_step < self.max_iter: 25 | 26 | lr_mult = (1 - self.global_step / self.max_iter) ** self.power 27 | for i in range(len(self.param_groups)): 28 | self.param_groups[i]['lr'] = self.__init_lr[i] * lr_mult 29 | 30 | # step 31 | super().step(closure) 32 | 33 | self.global_step += 1 -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | num_parallel = 2 4 | 5 | 6 | 7 | 8 | 9 | class ModuleParallel(nn.Module): 10 | def __init__(self, module): 11 | super(ModuleParallel, self).__init__() 12 | self.module = module 13 | 14 | def forward(self, x_parallel): 15 | return [self.module(x) for x in x_parallel] 16 | 17 | 18 | class Additional_One_ModuleParallel(nn.Module): 19 | def __init__(self, module): 20 | super(Additional_One_ModuleParallel, self).__init__() 21 | self.module = module 22 | 23 | def forward(self, x_parallel, x_arg): 24 | if x_arg == None: 25 | return 
[self.module(x, None) for x in x_parallel] 26 | elif isinstance(x_arg, list): 27 | return [ 28 | self.module(x_parallel[i], x_arg[i]) for i in range(len(x_parallel)) 29 | ] 30 | else: 31 | return [self.module(x_parallel[i], x_arg) for i in range(len(x_parallel))] 32 | 33 | 34 | class Additional_Two_ModuleParallel(nn.Module): 35 | def __init__(self, module): 36 | super(Additional_Two_ModuleParallel, self).__init__() 37 | self.module = module 38 | 39 | def forward(self, x_parallel, x_arg1, x_arg2): 40 | return [ 41 | self.module(x_parallel[i], x_arg1, x_arg2) for i in range(len(x_parallel)) 42 | ] 43 | 44 | 45 | class LayerNormParallel(nn.Module): 46 | def __init__(self, num_features): 47 | super(LayerNormParallel, self).__init__() 48 | for i in range(num_parallel): 49 | setattr(self, "ln_" + str(i), nn.LayerNorm(num_features, eps=1e-6)) 50 | 51 | def forward(self, x_parallel): 52 | return [getattr(self, "ln_" + str(i))(x) for i, x in enumerate(x_parallel)] 53 | -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib as mpl 4 | import matplotlib.cm as cm 5 | import PIL.Image as pil 6 | import cv2 7 | import os 8 | 9 | IMG_SCALE = 1./255 10 | IMG_MEAN = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)) 11 | IMG_STD = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)) 12 | logger = None 13 | 14 | 15 | def print_log(message): 16 | print(message, flush=True) 17 | if logger: 18 | logger.write(str(message) + '\n') 19 | 20 | 21 | def maybe_download(model_name, model_url, model_dir=None, map_location=None): 22 | import os, sys 23 | from six.moves import urllib 24 | if model_dir is None: 25 | torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch')) 26 | model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models')) 27 | if not os.path.exists(model_dir): 28 | os.makedirs(model_dir) 29 | filename = '{}.pth.tar'.format(model_name) 30 | cached_file = os.path.join(model_dir, filename) 31 | if not os.path.exists(cached_file): 32 | url = model_url 33 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 34 | urllib.request.urlretrieve(url, cached_file) 35 | return torch.load(cached_file, map_location=map_location) 36 | 37 | 38 | def prepare_img(img): 39 | return (img * IMG_SCALE - IMG_MEAN) / IMG_STD 40 | 41 | 42 | def make_validation_img(img_, depth_, lab, pre): 43 | cmap = np.load('./utils/cmap.npy') 44 | 45 | img = np.array([i * IMG_STD.reshape((3, 1, 1)) + IMG_MEAN.reshape((3, 1, 1)) for i in img_]) 46 | img *= 255 47 | img = img.astype(np.uint8) 48 | img = np.concatenate(img, axis=1) 49 | 50 | depth_ = depth_[0].transpose(1, 2, 0) / max(depth_.max(), 10) 51 | vmax = np.percentile(depth_, 95) 52 | normalizer = mpl.colors.Normalize(vmin=depth_.min(), vmax=vmax) 53 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma') 54 | depth = (mapper.to_rgba(depth_)[:,:,:3] * 255).astype(np.uint8) 55 | lab = np.concatenate(lab) 56 | lab = np.array([cmap[i.astype(np.uint8) + 1] for i in lab]) 57 | 58 | pre = np.concatenate(pre) 59 | pre = np.array([cmap[i.astype(np.uint8) + 1] for i in pre]) 60 | img = img.transpose(1, 2, 0) 61 | 62 | return np.concatenate([img, depth, lab, pre], 1) 63 | -------------------------------------------------------------------------------- /mmcv_custom/runner/checkpoint.py: -------------------------------------------------------------------------------- 1 
| # Copyright (c) Open-MMLab. All rights reserved. 2 | import os.path as osp 3 | import time 4 | from tempfile import TemporaryDirectory 5 | 6 | import torch 7 | from torch.optim import Optimizer 8 | 9 | import mmcv 10 | from mmcv.parallel import is_module_wrapper 11 | from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict 12 | 13 | try: 14 | import apex 15 | except: 16 | print("apex is not installed") 17 | 18 | 19 | def save_checkpoint(model, filename, optimizer=None, meta=None): 20 | """Save checkpoint to file. 21 | 22 | The checkpoint will have 4 fields: ``meta``, ``state_dict`` and 23 | ``optimizer``, ``amp``. By default ``meta`` will contain version 24 | and time info. 25 | 26 | Args: 27 | model (Module): Module whose params are to be saved. 28 | filename (str): Checkpoint filename. 29 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. 30 | meta (dict, optional): Metadata to be saved in checkpoint. 31 | """ 32 | if meta is None: 33 | meta = {} 34 | elif not isinstance(meta, dict): 35 | raise TypeError(f"meta must be a dict or None, but got {type(meta)}") 36 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) 37 | 38 | if is_module_wrapper(model): 39 | model = model.module 40 | 41 | if hasattr(model, "CLASSES") and model.CLASSES is not None: 42 | # save class name to the meta 43 | meta.update(CLASSES=model.CLASSES) 44 | 45 | checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))} 46 | # save optimizer state dict in the checkpoint 47 | if isinstance(optimizer, Optimizer): 48 | checkpoint["optimizer"] = optimizer.state_dict() 49 | elif isinstance(optimizer, dict): 50 | checkpoint["optimizer"] = {} 51 | for name, optim in optimizer.items(): 52 | checkpoint["optimizer"][name] = optim.state_dict() 53 | 54 | # save amp state dict in the checkpoint 55 | checkpoint["amp"] = apex.amp.state_dict() 56 | 57 | if filename.startswith("pavi://"): 58 | try: 59 | from pavi import modelcloud 60 | from pavi.exception import NodeNotFoundError 61 | except ImportError: 62 | raise ImportError("Please install pavi to load checkpoint from modelcloud.") 63 | model_path = filename[7:] 64 | root = modelcloud.Folder() 65 | model_dir, model_name = osp.split(model_path) 66 | try: 67 | model = modelcloud.get(model_dir) 68 | except NodeNotFoundError: 69 | model = root.create_training_model(model_dir) 70 | with TemporaryDirectory() as tmp_dir: 71 | checkpoint_file = osp.join(tmp_dir, model_name) 72 | with open(checkpoint_file, "wb") as f: 73 | torch.save(checkpoint, f) 74 | f.flush() 75 | model.create_file(checkpoint_file, name=model_name) 76 | else: 77 | mmcv.mkdir_or_exist(osp.dirname(filename)) 78 | # immediately flush buffer 79 | with open(filename, "wb") as f: 80 | torch.save(checkpoint, f) 81 | f.flush() 82 | -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def confusion_matrix(x, y, n, ignore_label=None, mask=None): 8 | if mask is None: 9 | mask = np.ones_like(x) == 1 10 | k = (x >= 0) & (y < n) & (x != ignore_label) & (mask.astype(np.bool)) 11 | return np.bincount(n * x[k].astype(int) + y[k], minlength=n ** 2).reshape(n, n) 12 | 13 | 14 | def getScores(conf_matrix): 15 | if conf_matrix.sum() == 0: 16 | return 0, 0, 0 17 | with np.errstate(divide='ignore',invalid='ignore'): 18 | overall = np.diag(conf_matrix).sum() / 
np.float(conf_matrix.sum()) 19 | perclass = np.diag(conf_matrix) / conf_matrix.sum(1).astype(np.float) 20 | IU = np.diag(conf_matrix) / (conf_matrix.sum(1) + conf_matrix.sum(0) \ 21 | - np.diag(conf_matrix)).astype(np.float) 22 | return overall * 100., np.nanmean(perclass) * 100., np.nanmean(IU) * 100. 23 | 24 | 25 | def compute_params(model): 26 | """Compute number of parameters""" 27 | n_total_params = 0 28 | for name, m in model.named_parameters(): 29 | n_elem = m.numel() 30 | n_total_params += n_elem 31 | return n_total_params 32 | 33 | 34 | # Adopted from https://raw.githubusercontent.com/pytorch/examples/master/imagenet/main.py 35 | class AverageMeter(object): 36 | """Computes and stores the average and current value""" 37 | def __init__(self): 38 | self.reset() 39 | 40 | def reset(self): 41 | self.val = 0 42 | self.avg = 0 43 | self.sum = 0 44 | self.count = 0 45 | 46 | def update(self, val, n=1): 47 | self.val = val 48 | self.sum += val * n 49 | self.count += n 50 | self.avg = self.sum / self.count 51 | 52 | 53 | class Saver(): 54 | """Saver class for managing parameters""" 55 | def __init__(self, args, ckpt_dir, best_val=0, condition=lambda x, y: x > y): 56 | """ 57 | Args: 58 | args (dict): dictionary with arguments. 59 | ckpt_dir (str): path to directory in which to store the checkpoint. 60 | best_val (float): initial best value. 61 | condition (function): how to decide whether to save the new checkpoint 62 | by comparing best value and new value (x,y). 63 | 64 | """ 65 | if not os.path.exists(ckpt_dir): 66 | os.makedirs(ckpt_dir) 67 | with open('{}/args.json'.format(ckpt_dir), 'w') as f: 68 | json.dump({k: v for k, v in args.items() if isinstance(v, (int, float, str))}, f, 69 | sort_keys = True, indent = 4, ensure_ascii = False) 70 | self.ckpt_dir = ckpt_dir 71 | self.best_val = best_val 72 | self.condition = condition 73 | self._counter = 0 74 | 75 | def _do_save(self, new_val): 76 | """Check whether need to save""" 77 | return self.condition(new_val, self.best_val) 78 | 79 | def save(self, new_val, dict_to_save): 80 | """Save new checkpoint""" 81 | self._counter += 1 82 | if self._do_save(new_val): 83 | # print(' New best value {:.4f}, was {:.4f}'.format(new_val, self.best_val), flush=True) 84 | self.best_val = new_val 85 | dict_to_save['best_val'] = new_val 86 | torch.save(dict_to_save, '{}/model-best.pth.tar'.format(self.ckpt_dir)) 87 | else: 88 | dict_to_save['best_val'] = new_val 89 | torch.save(dict_to_save, '{}/checkpoint.pth.tar'.format(self.ckpt_dir)) 90 | 91 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /mmcv_custom/runner/epoch_based_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 
2 | import os.path as osp 3 | import platform 4 | import shutil 5 | 6 | import torch 7 | from torch.optim import Optimizer 8 | 9 | import mmcv 10 | from mmcv.runner import RUNNERS, EpochBasedRunner 11 | from .checkpoint import save_checkpoint 12 | 13 | try: 14 | import apex 15 | except: 16 | print("apex is not installed") 17 | 18 | 19 | @RUNNERS.register_module() 20 | class EpochBasedRunnerAmp(EpochBasedRunner): 21 | """Epoch-based Runner with AMP support. 22 | 23 | This runner train models epoch by epoch. 24 | """ 25 | 26 | def save_checkpoint( 27 | self, 28 | out_dir, 29 | filename_tmpl="epoch_{}.pth", 30 | save_optimizer=True, 31 | meta=None, 32 | create_symlink=True, 33 | ): 34 | """Save the checkpoint. 35 | 36 | Args: 37 | out_dir (str): The directory that checkpoints are saved. 38 | filename_tmpl (str, optional): The checkpoint filename template, 39 | which contains a placeholder for the epoch number. 40 | Defaults to 'epoch_{}.pth'. 41 | save_optimizer (bool, optional): Whether to save the optimizer to 42 | the checkpoint. Defaults to True. 43 | meta (dict, optional): The meta information to be saved in the 44 | checkpoint. Defaults to None. 45 | create_symlink (bool, optional): Whether to create a symlink 46 | "latest.pth" to point to the latest checkpoint. 47 | Defaults to True. 48 | """ 49 | if meta is None: 50 | meta = dict(epoch=self.epoch + 1, iter=self.iter) 51 | elif isinstance(meta, dict): 52 | meta.update(epoch=self.epoch + 1, iter=self.iter) 53 | else: 54 | raise TypeError(f"meta should be a dict or None, but got {type(meta)}") 55 | if self.meta is not None: 56 | meta.update(self.meta) 57 | 58 | filename = filename_tmpl.format(self.epoch + 1) 59 | filepath = osp.join(out_dir, filename) 60 | optimizer = self.optimizer if save_optimizer else None 61 | save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) 62 | # in some environments, `os.symlink` is not supported, you may need to 63 | # set `create_symlink` to False 64 | if create_symlink: 65 | dst_file = osp.join(out_dir, "latest.pth") 66 | if platform.system() != "Windows": 67 | mmcv.symlink(filename, dst_file) 68 | else: 69 | shutil.copy(filepath, dst_file) 70 | 71 | def resume(self, checkpoint, resume_optimizer=True, map_location="default"): 72 | if map_location == "default": 73 | if torch.cuda.is_available(): 74 | device_id = torch.cuda.current_device() 75 | checkpoint = self.load_checkpoint( 76 | checkpoint, 77 | map_location=lambda storage, loc: storage.cuda(device_id), 78 | ) 79 | else: 80 | checkpoint = self.load_checkpoint(checkpoint) 81 | else: 82 | checkpoint = self.load_checkpoint(checkpoint, map_location=map_location) 83 | 84 | self._epoch = checkpoint["meta"]["epoch"] 85 | self._iter = checkpoint["meta"]["iter"] 86 | if "optimizer" in checkpoint and resume_optimizer: 87 | if isinstance(self.optimizer, Optimizer): 88 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 89 | elif isinstance(self.optimizer, dict): 90 | for k in self.optimizer.keys(): 91 | self.optimizer[k].load_state_dict(checkpoint["optimizer"][k]) 92 | else: 93 | raise TypeError( 94 | "Optimizer should be dict or torch.optim.Optimizer " 95 | f"but got {type(self.optimizer)}" 96 | ) 97 | 98 | if "amp" in checkpoint: 99 | apex.amp.load_state_dict(checkpoint["amp"]) 100 | self.logger.info("load amp state dict") 101 | 102 | self.logger.info("resumed epoch %d, iter %d", self.epoch, self.iter) 103 | 104 | -------------------------------------------------------------------------------- /utils/datasets.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import io 7 | 8 | 9 | def line_to_paths_fn(x, input_names): 10 | return x.decode("utf-8").strip("\n").split("\t") 11 | 12 | 13 | class SegDataset(Dataset): 14 | """Multi-Modality Segmentation dataset. 15 | 16 | Works with any datasets that contain image 17 | and any number of 2D-annotations. 18 | 19 | Args: 20 | data_file (string): Path to the data file with annotations. 21 | data_dir (string): Directory with all the images. 22 | line_to_paths_fn (callable): function to convert a line of data_file 23 | into paths (img_relpath, msk_relpath, ...). 24 | masks_names (list of strings): keys for each annotation mask 25 | (e.g., 'segm', 'depth'). 26 | transform_trn (callable, optional): Optional transform 27 | to be applied on a sample during the training stage. 28 | transform_val (callable, optional): Optional transform 29 | to be applied on a sample during the validation stage. 30 | stage (str): initial stage of dataset - either 'train' or 'val'. 31 | 32 | """ 33 | 34 | def __init__( 35 | self, 36 | dataset, 37 | data_file, 38 | data_dir, 39 | input_names, 40 | input_mask_idxs, 41 | transform_trn=None, 42 | transform_val=None, 43 | stage="train", 44 | ignore_label=None, 45 | ): 46 | with open(data_file, "rb") as f: 47 | datalist = f.readlines() 48 | self.dataset = dataset 49 | self.datalist = [line_to_paths_fn(l, input_names) for l in datalist] 50 | self.root_dir = data_dir 51 | self.transform_trn = transform_trn 52 | self.transform_val = transform_val 53 | self.stage = stage 54 | self.input_names = input_names 55 | self.input_mask_idxs = input_mask_idxs 56 | self.ignore_label = ignore_label 57 | 58 | def set_stage(self, stage): 59 | """Define which set of transformation to use. 
60 | 61 | Args: 62 | stage (str): either 'train' or 'val' 63 | 64 | """ 65 | self.stage = stage 66 | 67 | def __len__(self): 68 | return len(self.datalist) 69 | 70 | def __getitem__(self, idx): 71 | idxs = self.input_mask_idxs 72 | names = [os.path.join(self.root_dir, rpath) for rpath in self.datalist[idx]] 73 | sample = {} 74 | for i, key in enumerate(self.input_names): 75 | sample[key] = self.read_image(names[idxs[i]], key) 76 | try: 77 | if self.dataset == "nyudv2": 78 | mask = np.array(Image.open(names[idxs[-1]])) 79 | elif self.dataset == "sunrgbd": 80 | mask = self._open_image( 81 | names[idxs[-1]], cv2.IMREAD_GRAYSCALE, dtype=np.uint8 82 | ) 83 | except FileNotFoundError: # for sunrgbd 84 | path = names[idxs[-1]] 85 | num_idx = int(path[-10:-4]) + 5050 86 | path = path[:-10] + "%06d" % num_idx + path[-4:] 87 | mask = np.array(Image.open(path)) 88 | 89 | if self.dataset == "sunrgbd": 90 | mask -= 1 91 | 92 | assert len(mask.shape) == 2, "Masks must be encoded without colourmap" 93 | sample["inputs"] = self.input_names 94 | sample["mask"] = mask 95 | 96 | del sample["inputs"] 97 | if self.stage == "train": 98 | if self.transform_trn: 99 | sample = self.transform_trn(sample) 100 | elif self.stage == "val": 101 | if self.transform_val: 102 | sample = self.transform_val(sample) 103 | 104 | return sample 105 | 106 | @staticmethod 107 | def _open_image(filepath, mode=cv2.IMREAD_COLOR, dtype=None): 108 | img = np.array(cv2.imread(filepath, mode), dtype=dtype) 109 | return img 110 | 111 | @staticmethod 112 | def read_image(x, key): 113 | """Simple image reader 114 | 115 | Args: 116 | x (str): path to image. 117 | 118 | Returns image as `np.array`. 119 | 120 | """ 121 | img_arr = np.array(Image.open(x)) 122 | if len(img_arr.shape) == 2: # grayscale 123 | img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0) 124 | return img_arr 125 | -------------------------------------------------------------------------------- /README-TokenFusion.md: -------------------------------------------------------------------------------- 1 | # Multimodal Token Fusion for Vision Transformers 2 | 3 | By Yikai Wang, Xinghao Chen, Lele Cao, Wenbing Huang, Fuchun Sun, Yunhe Wang. 4 | 5 | [**[Paper]**](https://arxiv.org/pdf/2204.08721.pdf) 6 | 7 | This repository is a PyTorch implementation of "Multimodal Token Fusion for Vision Transformers", in CVPR 2022. 8 | 9 |
10 | 11 |
12 | 13 | Homogeneous predictions, 14 |
15 | 16 |
17 | 18 | Heterogeneous predictions, 19 |
20 | 21 |
22 | 23 | 24 | ## Datasets 25 | 26 | For semantic segmentation task on NYUDv2 ([official dataset](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)), we provide a link to download the dataset [here](https://drive.google.com/drive/folders/1mXmOXVsd5l9-gYHk92Wpn6AcKAbE0m3X?usp=sharing). The provided dataset is originally preprocessed in this [repository](https://github.com/DrSleep/light-weight-refinenet), and we add depth data in it. 27 | 28 | For image-to-image translation task, we use the sample dataset of [Taskonomy](http://taskonomy.stanford.edu/), where a link to download the sample dataset is [here](https://github.com/alexsax/taskonomy-sample-model-1.git). 29 | 30 | Please modify the data paths in the codes, where we add comments 'Modify data path'. 31 | 32 | 33 | ## Dependencies 34 | ``` 35 | python==3.6 36 | pytorch==1.7.1 37 | torchvision==0.8.2 38 | numpy==1.19.2 39 | ``` 40 | 41 | 42 | ## Semantic Segmentation 43 | 44 | 45 | First, 46 | ``` 47 | cd semantic_segmentation 48 | ``` 49 | 50 | Download the [segformer](https://github.com/NVlabs/SegFormer) pretrained model (pretrained on ImageNet) from [weights](https://drive.google.com/drive/folders/1b7bwrInTW4VLEm27YawHOAMSMikga2Ia), e.g., mit_b3.pth. Move this pretrained model to folder 'pretrained'. 51 | 52 | Training script for segmentation with RGB and Depth input, 53 | ``` 54 | python main.py --backbone mit_b3 -c exp_name --lamda 1e-6 --gpu 0 1 2 55 | ``` 56 | 57 | Evaluation script, 58 | ``` 59 | python main.py --gpu 0 --resume path_to_pth --evaluate # optionally use --save-img to visualize results 60 | ``` 61 | 62 | Checkpoint models, training logs, mask ratios and the **single-scale** performance on NYUDv2 are provided as follows: 63 | 64 | | Method | Backbone | Pixel Acc. (%) | Mean Acc. (%) | Mean IoU (%) | Download | 65 | |:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|:-----------:| 66 | |[CEN](https://github.com/yikaiw/CEN)| ResNet101 | 76.2 | 62.8 | 51.1 | [Google Drive](https://drive.google.com/drive/folders/1wim_cBG-HW0bdipwA1UbnGeDwjldPIwV?usp=sharing)| 67 | |[CEN](https://github.com/yikaiw/CEN)| ResNet152 | 77.0 | 64.4 | 51.6 | [Google Drive](https://drive.google.com/drive/folders/1DGF6vHLDgBgLrdUNJOLYdoXCuEKbIuRs?usp=sharing)| 68 | |Ours| SegFormer-B3 | 78.7 | 67.5 | 54.8 | [Google Drive](https://drive.google.com/drive/folders/14fi8aABFYqGF7LYKHkiJazHA58OBW1AW?usp=sharing)| 69 | 70 | 71 | Mindspore implementation is available at: https://gitee.com/mindspore/models/tree/master/research/cv/TokenFusion 72 | 73 | ## Image-to-Image Translation 74 | 75 | First, 76 | ``` 77 | cd image2image_translation 78 | ``` 79 | Training script, from Shade and Texture to RGB, 80 | ``` 81 | python main.py --gpu 0 -c exp_name 82 | ``` 83 | This script will auto-evaluate on the validation dataset every 5 training epochs. 
84 | 85 | Predicted images will be automatically saved during training, in the following folder structure: 86 | 87 | ``` 88 | code_root/ckpt/exp_name/results 89 | ├── input0 # 1st modality input 90 | ├── input1 # 2nd modality input 91 | ├── fake0 # 1st branch output 92 | ├── fake1 # 2nd branch output 93 | ├── fake2 # ensemble output 94 | ├── best # current best output 95 | │ ├── fake0 96 | │ ├── fake1 97 | │ └── fake2 98 | └── real # ground truth output 99 | ``` 100 | 101 | Checkpoint models: 102 | 103 | | Method | Task | FID | KID | Download | 104 | |:-----------:|:-----------:|:-----------:|:-----------:|:-----------:| 105 | | [CEN](https://github.com/yikaiw/CEN) |Texture+Shade->RGB | 62.6 | 1.65 | - | 106 | | Ours | Texture+Shade->RGB | 45.5 | 1.00 | [Google Drive](https://drive.google.com/drive/folders/1vkcDv5bHKXZKxCg4dC7R56ts6nLLt6lh?usp=sharing)| 107 | 108 | ## 3D Object Detection (under construction) 109 | 110 | Data preparation, environments, and training scripts follow [Group-Free](https://github.com/zeliu98/Group-Free-3D) and [ImVoteNet](https://github.com/facebookresearch/imvotenet). 111 | 112 | E.g., 113 | ``` 114 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port 2229 --nproc_per_node 4 train_dist.py --max_epoch 600 --val_freq 25 --save_freq 25 --lr_decay_epochs 420 480 540 --num_point 20000 --num_decoder_layers 6 --size_cls_agnostic --size_delta 0.0625 --heading_delta 0.04 --center_delta 0.1111111111111 --weight_decay 0.00000001 --query_points_generator_loss_coef 0.2 --obj_loss_coef 0.4 --dataset sunrgbd --data_root . --use_img --log_dir log/exp_name 115 | ``` 116 | 117 | ## Citation 118 | 119 | If you find our work useful for your research, please consider citing the following paper. 120 | ``` 121 | @inproceedings{wang2022tokenfusion, 122 | title={Multimodal Token Fusion for Vision Transformers}, 123 | author={Wang, Yikai and Chen, Xinghao and Cao, Lele and Huang, Wenbing and Sun, Fuchun and Wang, Yunhe}, 124 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 125 | year={2022} 126 | } 127 | ``` 128 | 129 | 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | ## GeminiFusion for Multimodal Segmentation on NYUDv2 & SUN RGBD Dataset (ICML 2024) 4 | 5 | 
6 | 7 |

8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 |

17 | 18 | 19 | 20 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/geminifusion-efficient-pixel-wise-multimodal/semantic-segmentation-on-deliver-1)](https://paperswithcode.com/sota/semantic-segmentation-on-deliver-1?p=geminifusion-efficient-pixel-wise-multimodal) 21 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/geminifusion-efficient-pixel-wise-multimodal/semantic-segmentation-on-nyu-depth-v2)](https://paperswithcode.com/sota/semantic-segmentation-on-nyu-depth-v2?p=geminifusion-efficient-pixel-wise-multimodal) 22 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/geminifusion-efficient-pixel-wise-multimodal/semantic-segmentation-on-sun-rgbd)](https://paperswithcode.com/sota/semantic-segmentation-on-sun-rgbd?p=geminifusion-efficient-pixel-wise-multimodal) 23 | 24 | 25 | This is the official implementation of our paper "[GeminiFusion: Efficient Pixel-wise Multimodal Fusion for Vision Transformer](https://arxiv.org/pdf/2406.01210)". 26 | 27 | Authors: Ding Jia, Jianyuan Guo, Kai Han, Han Wu, Chao Zhang, Chang Xu, Xinghao Chen 28 | 29 | 30 | 31 | ## Code List 32 | 33 | We have applied our GeminiFusion to different tasks and datasets: 34 | 35 | * GeminiFusion for Multimodal Semantic Segmentation 36 | * (This branch)[NYUDv2 & SUN RGBD datasets](https://github.com/JiaDingCN/GeminiFusion/tree/main) 37 | * [DeLiVER dataset](https://github.com/JiaDingCN/GeminiFusion/tree/DeLiVER) 38 | * GeminiFusion for Multimodal 3D Object Detection 39 | * [KITTI dataset](https://github.com/JiaDingCN/GeminiFusion/tree/3d_object_detection_kitti) 40 | 41 | 42 | ## Introduction 43 | 44 | We propose GeminiFusion, a pixel-wise fusion approach that capitalizes on aligned cross-modal representations. GeminiFusion elegantly combines intra-modal and inter-modal attentions, dynamically integrating complementary information across modalities. We employ a layer-adaptive noise to adaptively control their interplay on a per-layer basis, thereby achieving a harmonized fusion process. Notably, GeminiFusion maintains linear complexity with respect to the number of input tokens, ensuring this multimodal framework operates with efficiency comparable to unimodal networks. Comprehensive evaluations demonstrate the superior performance of our GeminiFusion against leading-edge techniques. 
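For intuition, the sketch below shows what such a pixel-wise fusion step can look like in PyTorch: each token attends only to itself and to the token at the same spatial position in the other modality, with a learnable noise scale on the cross-modal keys, so the cost stays linear in the number of tokens. This is a conceptual illustration only, not the implementation shipped in this repository (see `models/mix_transformer.py` and `models/swin_transformer.py` for the actual GeminiFusion modules); all class names and hyper-parameters below are made up for the example.

```python
import torch
import torch.nn as nn


class PixelwiseCrossFusion(nn.Module):
    """Toy pixel-wise fusion: every token attends to a 2-element set
    (its own token and the same-position token of the other modality),
    so no full N x N attention map is ever built."""

    def __init__(self, dim, noise_std=0.1):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.proj = nn.Linear(dim, dim)
        # per-layer learnable noise scale applied to the cross-modal keys
        self.noise_scale = nn.Parameter(torch.tensor(noise_std))
        self.scale = dim ** -0.5

    def fuse(self, x_a, x_b):
        # x_a, x_b: (B, N, C) token sequences from two spatially aligned modalities
        q = self.q(x_a)                                # queries from modality A
        k_self, v_self = self.k(x_a), self.v(x_a)      # intra-modal key/value
        k_cross, v_cross = self.k(x_b), self.v(x_b)    # inter-modal key/value
        k_cross = k_cross + self.noise_scale * torch.randn_like(k_cross)

        # 2-way softmax per token: keep my own feature vs. take the other modality's
        logits = torch.stack(
            [(q * k_self).sum(-1), (q * k_cross).sum(-1)], dim=-1
        ) * self.scale                                 # (B, N, 2)
        attn = logits.softmax(dim=-1)
        fused = attn[..., 0:1] * v_self + attn[..., 1:2] * v_cross
        return x_a + self.proj(fused)

    def forward(self, x_rgb, x_depth):
        # symmetric: each branch queries the other modality
        return self.fuse(x_rgb, x_depth), self.fuse(x_depth, x_rgb)


if __name__ == "__main__":
    fusion = PixelwiseCrossFusion(dim=64)
    rgb = torch.randn(2, 1024, 64)    # a 32x32 feature map flattened into 1024 tokens
    depth = torch.randn(2, 1024, 64)
    out_rgb, out_depth = fusion(rgb, depth)
    print(out_rgb.shape, out_depth.shape)
```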
45 | 46 | 47 | 48 | ## Framework 49 | ![geminifusion_framework](figs/geminifusion_framework.png) 50 | 51 | 52 | 53 | ## Model Zoo 54 | 55 | ### NYUDv2 dataset 56 | 57 | | Model | backbone| mIoU | Download | 58 | |:-------:|:--------:|:-------:|:-------------------:| 59 | | GeminiFusion | MiT-B3| 56.8 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/mit-b3.pth.tar) | 60 | | GeminiFusion | MiT-B5| 57.7 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/mit_b5.pth.tar) | 61 | | GeminiFusion | swin_tiny| 52.2 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/swin_tiny.pth.tar) | 62 | | GeminiFusion | swin-small| 55.0 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/swin_small.pth.tar) | 63 | | GeminiFusion | swin-large-224| 58.8 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/swin_large.pth.tar) | 64 | | GeminiFusion | swin-large-384| 60.2 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/swin_large_384.pth.tar) | 65 | | GeminiFusion | swin-large-384 +FineTune from SUN 300eps| 60.9 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2_V2/finetune-swin-large-384.pth.tar) | 66 | 67 | ### SUN RGBD dataset 68 | 69 | | Model | backbone| mIoU | Download | 70 | |:-------:|:--------:|:-------:|:-------------------:| 71 | | GeminiFusion | MiT-B3| 52.7 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/SUN_v2/mit-b3.pth.tar) | 72 | | GeminiFusion | MiT-B5| 53.3 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/SUN_v2/mit_b5.pth.tar) | 73 | | GeminiFusion | swin_tiny| 50.2 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/SUN_v2/swin_tiny.pth.tar) | 74 | | GeminiFusion | swin-large-384| 54.8 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/SUN_v2/swin-large-384.pth.tar) | 75 | 76 | 77 | 78 | ## Installation 79 | 80 | We build our GeminiFusion on the TokenFusion codebase, which requires no additional installation steps. If you run into any problem with the framework, you may refer to [the official TokenFusion readme](./README-TokenFusion.md). 81 | 82 | Most of the `GeminiFusion`-related code is located in the following files: 83 | * [models/mix_transformer](models/mix_transformer.py): implements the GeminiFusion module for MiT backbones. 84 | * [models/swin_transformer](models/swin_transformer.py): implements the GeminiFusion module for Swin backbones. 85 | * [mmcv_custom](mmcv_custom): loads checkpoints for Swin backbones. 86 | * [main](main.py): enables the SUN RGBD dataset. 87 | * [utils/datasets](utils/datasets.py): enables the SUN RGBD dataset. 88 | 89 | We also removed config.py from the TokenFusion codebase since it is not used here. 90 | 91 | 92 | 93 | ## Data 94 | 95 | **NYUDv2 Dataset Preparation** 96 | 97 | Please follow [the data preparation instructions for NYUDv2 in the TokenFusion readme](./README-TokenFusion.md#datasets). By default the data path is `/cache/datasets/nyudv2`; you may change it with `--train-dir `. 98 | 99 | **SUN RGBD Dataset Preparation** 100 | 101 | Please download the SUN RGBD dataset following the link in [DFormer](https://github.com/VCIP-RGBD/DFormer?tab=readme-ov-file#2--get-start). By default the data path is `/cache/datasets/sunrgbd_Dformer/SUNRGBD`; you may change it with `--train-dir `. 
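As a quick check of what `--train-dir` has to point at, the snippet below mirrors how `utils/datasets.py` turns one line of the data list file (e.g. `data/nyudv2/val.txt`) into absolute paths: each line holds tab-separated relative paths that are joined onto the dataset root. The column layout and file names in this example are hypothetical; consult your generated list file for the real ones.

```python
import os

# default NYUDv2 root; override with --train-dir
root_dir = "/cache/datasets/nyudv2"

# one hypothetical line of the list file: tab-separated relative paths
line = b"rgb/000001.png\tdepth/000001.png\tmasks/000001.png\n"

# mirrors line_to_paths_fn in utils/datasets.py
rel_paths = line.decode("utf-8").strip("\n").split("\t")
abs_paths = [os.path.join(root_dir, p) for p in rel_paths]
print(abs_paths)
```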
102 | 103 | 104 | 105 | ## Train 106 | 107 | **NYUDv2 Training** 108 | 109 | On the NYUDv2 dataset, we follow the TokenFusion's setting, using 3 GPUs to train the GeminiFusion. 110 | 111 | ```shell 112 | # mit-b3 113 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3 114 | 115 | # mit-b5 116 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone mit_b5 --dataset nyudv2 -c nyudv2_mit_b5 --dpr 0.35 117 | 118 | # swin_tiny 119 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_tiny --dataset nyudv2 -c nyudv2_swin_tiny --dpr 0.2 120 | 121 | # swin_small 122 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_small --dataset nyudv2 -c nyudv2_swin_small 123 | 124 | # swin_large 125 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_large --dataset nyudv2 -c nyudv2_swin_large 126 | 127 | # swin_large_window12 128 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_large_window12 --dataset nyudv2 -c nyudv2_swin_large_window12 --dpr 0.2 129 | 130 | # swin-large-384+FineTune from SUN 300eps 131 | # swin-large-384.pth.tar should be downloaded by our link or trained by yourself 132 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_large_window12 --dataset nyudv2 -c swin_large_window12_finetune_dpr0.15_100+200+100 \ 133 | --dpr 0.15 --num-epoch 100 200 100 --is_pretrain_finetune --resume ./swin-large-384.pth.tar 134 | ``` 135 | 136 | **SUN RGBD Training** 137 | 138 | On the SUN RGBD dataset, we use 4 GPUs to train the GeminiFusion. 139 | ```shell 140 | # mit-b3 141 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --backbone mit_b3 --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_mit_b3 142 | 143 | # mit-b5 144 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --backbone mit_b5 --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_mit_b5 --weight_decay 0.05 145 | 146 | # swin_tiny 147 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --backbone swin_tiny --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_swin_tiny 148 | 149 | # swin_large_window12 150 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --backbone swin_large_window12 --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_swin_large_window12 151 | ``` 152 | 153 | 154 | 155 | ## Test 156 | 157 | To evaluate checkpoints, you need to add `--eval --resume ` after the training script. 
158 | 159 | For example, on the NYUDv2 dataset, the training script for GeminiFusion with mit-b3 backbone is: 160 | ```shell 161 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3 162 | ``` 163 | 164 | To evaluate the trained or downloaded checkpoint, the eval script is: 165 | ```shell 166 | CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3 --eval --resume mit-b3.pth.tar 167 | ``` 168 | 169 | 170 | 171 | ## Citation 172 | 173 | If you find this work useful for your research, please cite our paper: 174 | 175 | ``` 176 | @misc{jia2024geminifusion, 177 | title={GeminiFusion: Efficient Pixel-wise Multimodal Fusion for Vision Transformer}, 178 | author={Ding Jia and Jianyuan Guo and Kai Han and Han Wu and Chao Zhang and Chang Xu and Xinghao Chen}, 179 | year={2024}, 180 | eprint={2406.01210}, 181 | archivePrefix={arXiv}, 182 | primaryClass={cs.CV} 183 | } 184 | ``` 185 | 186 | 187 | 188 | ## Acknowledgement 189 | Part of our code is based on the open-source project [TokenFusion](https://github.com/yikaiw/TokenFusion). 190 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | """RefineNet-LightWeight 2 | 3 | RefineNet-LigthWeight PyTorch for non-commercial purposes 4 | 5 | Copyright (c) 2018, Vladimir Nekrasov (vladimir.nekrasov@adelaide.edu.au) 6 | All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without 9 | modification, are permitted provided that the following conditions are met: 10 | 11 | * Redistributions of source code must retain the above copyright notice, this 12 | list of conditions and the following disclaimer. 13 | 14 | * Redistributions in binary form must reproduce the above copyright notice, 15 | this list of conditions and the following disclaimer in the documentation 16 | and/or other materials provided with the distribution. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | """ 29 | 30 | 31 | import cv2 32 | import numpy as np 33 | import torch 34 | 35 | # Usual dtypes for common modalities 36 | KEYS_TO_DTYPES = { 37 | "rgb": torch.float, 38 | "depth": torch.float, 39 | "normals": torch.float, 40 | "mask": torch.long, 41 | } 42 | 43 | 44 | class Pad(object): 45 | """Pad image and mask to the desired size. 46 | 47 | Args: 48 | size (int) : minimum length/width. 49 | img_val (array) : image padding value. 50 | msk_val (int) : mask padding value. 
51 | 52 | """ 53 | 54 | def __init__(self, size, img_val, msk_val): 55 | assert isinstance(size, int) 56 | self.size = size 57 | self.img_val = img_val 58 | self.msk_val = msk_val 59 | 60 | def __call__(self, sample): 61 | image = sample["rgb"] 62 | h, w = image.shape[:2] 63 | h_pad = int(np.clip(((self.size - h) + 1) // 2, 0, 1e6)) 64 | w_pad = int(np.clip(((self.size - w) + 1) // 2, 0, 1e6)) 65 | pad = ((h_pad, h_pad), (w_pad, w_pad)) 66 | for key in sample["inputs"]: 67 | sample[key] = self.transform_input(sample[key], pad) 68 | sample["mask"] = np.pad( 69 | sample["mask"], pad, mode="constant", constant_values=self.msk_val 70 | ) 71 | return sample 72 | 73 | def transform_input(self, input, pad): 74 | input = np.stack( 75 | [ 76 | np.pad( 77 | input[:, :, c], 78 | pad, 79 | mode="constant", 80 | constant_values=self.img_val[c], 81 | ) 82 | for c in range(3) 83 | ], 84 | axis=2, 85 | ) 86 | return input 87 | 88 | 89 | class RandomCrop(object): 90 | """Crop randomly the image in a sample. 91 | 92 | Args: 93 | crop_size (int): Desired output size. 94 | 95 | """ 96 | 97 | def __init__(self, crop_size): 98 | assert isinstance(crop_size, int) 99 | self.crop_size = crop_size 100 | if self.crop_size % 2 != 0: 101 | self.crop_size -= 1 102 | 103 | def __call__(self, sample): 104 | image = sample["rgb"] 105 | h, w = image.shape[:2] 106 | new_h = min(h, self.crop_size) 107 | new_w = min(w, self.crop_size) 108 | top = np.random.randint(0, h - new_h + 1) 109 | left = np.random.randint(0, w - new_w + 1) 110 | for key in sample["inputs"]: 111 | sample[key] = self.transform_input(sample[key], top, new_h, left, new_w) 112 | sample["mask"] = sample["mask"][top : top + new_h, left : left + new_w] 113 | return sample 114 | 115 | def transform_input(self, input, top, new_h, left, new_w): 116 | input = input[top : top + new_h, left : left + new_w] 117 | return input 118 | 119 | 120 | class ResizeAndScale(object): 121 | """Resize shorter/longer side to a given value and randomly scale. 122 | 123 | Args: 124 | side (int) : shorter / longer side value. 125 | low_scale (float) : lower scaling bound. 126 | high_scale (float) : upper scaling bound. 127 | shorter (bool) : whether to resize shorter / longer side. 
128 | 129 | """ 130 | 131 | def __init__(self, side, low_scale, high_scale, shorter=True): 132 | assert isinstance(side, int) 133 | assert isinstance(low_scale, float) 134 | assert isinstance(high_scale, float) 135 | self.side = side 136 | self.low_scale = low_scale 137 | self.high_scale = high_scale 138 | self.shorter = shorter 139 | 140 | def __call__(self, sample): 141 | image = sample["rgb"] 142 | scale = np.random.uniform(self.low_scale, self.high_scale) 143 | if self.shorter: 144 | min_side = min(image.shape[:2]) 145 | if min_side * scale < self.side: 146 | scale = self.side * 1.0 / min_side 147 | else: 148 | max_side = max(image.shape[:2]) 149 | if max_side * scale > self.side: 150 | scale = self.side * 1.0 / max_side 151 | inters = {"rgb": cv2.INTER_CUBIC, "depth": cv2.INTER_NEAREST} 152 | for key in sample["inputs"]: 153 | inter = inters[key] if key in inters else cv2.INTER_CUBIC 154 | sample[key] = self.transform_input(sample[key], scale, inter) 155 | sample["mask"] = cv2.resize( 156 | sample["mask"], None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST 157 | ) 158 | return sample 159 | 160 | def transform_input(self, input, scale, inter): 161 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter) 162 | return input 163 | 164 | 165 | class CropAlignToMask(object): 166 | """Crop inputs to the size of the mask.""" 167 | 168 | def __call__(self, sample): 169 | mask_h, mask_w = sample["mask"].shape[:2] 170 | for key in sample["inputs"]: 171 | sample[key] = self.transform_input(sample[key], mask_h, mask_w) 172 | return sample 173 | 174 | def transform_input(self, input, mask_h, mask_w): 175 | input_h, input_w = input.shape[:2] 176 | if (input_h, input_w) == (mask_h, mask_w): 177 | return input 178 | h, w = (input_h - mask_h) // 2, (input_w - mask_w) // 2 179 | del_h, del_w = (input_h - mask_h) % 2, (input_w - mask_w) % 2 180 | input = input[h : input_h - h - del_h, w : input_w - w - del_w] 181 | assert input.shape[:2] == (mask_h, mask_w) 182 | return input 183 | 184 | 185 | class ResizeAlignToMask(object): 186 | """Resize inputs to the size of the mask.""" 187 | 188 | def __call__(self, sample): 189 | mask_h, mask_w = sample["mask"].shape[:2] 190 | assert mask_h == mask_w 191 | inters = {"rgb": cv2.INTER_CUBIC, "depth": cv2.INTER_NEAREST} 192 | for key in sample["inputs"]: 193 | inter = inters[key] if key in inters else cv2.INTER_CUBIC 194 | sample[key] = self.transform_input(sample[key], mask_h, inter) 195 | return sample 196 | 197 | def transform_input(self, input, mask_h, inter): 198 | input_h, input_w = input.shape[:2] 199 | assert input_h == input_w 200 | scale = mask_h / input_h 201 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter) 202 | return input 203 | 204 | 205 | class ResizeInputs(object): 206 | def __init__(self, size): 207 | self.size = size 208 | 209 | def __call__(self, sample): 210 | # sample['rgb'] = sample['rgb'].numpy() 211 | if self.size is None: 212 | return sample 213 | size = sample["rgb"].shape[0] 214 | scale = self.size / size 215 | # print(sample['rgb'].shape, type(sample['rgb'])) 216 | inters = {"rgb": cv2.INTER_CUBIC, "depth": cv2.INTER_NEAREST} 217 | for key in sample["inputs"]: 218 | inter = inters[key] if key in inters else cv2.INTER_CUBIC 219 | sample[key] = self.transform_input(sample[key], scale, inter) 220 | return sample 221 | 222 | def transform_input(self, input, scale, inter): 223 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter) 224 | return input 225 | 226 | 227 | class 
ResizeInputsScale(object): 228 | def __init__(self, scale): 229 | self.scale = scale 230 | 231 | def __call__(self, sample): 232 | if self.scale is None: 233 | return sample 234 | inters = {"rgb": cv2.INTER_CUBIC, "depth": cv2.INTER_NEAREST} 235 | for key in sample["inputs"]: 236 | inter = inters[key] if key in inters else cv2.INTER_CUBIC 237 | sample[key] = self.transform_input(sample[key], self.scale, inter) 238 | return sample 239 | 240 | def transform_input(self, input, scale, inter): 241 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter) 242 | return input 243 | 244 | 245 | class RandomMirror(object): 246 | """Randomly flip the image and the mask""" 247 | 248 | def __call__(self, sample): 249 | do_mirror = np.random.randint(2) 250 | if do_mirror: 251 | for key in sample["inputs"]: 252 | sample[key] = cv2.flip(sample[key], 1) 253 | sample["mask"] = cv2.flip(sample["mask"], 1) 254 | return sample 255 | 256 | 257 | class Normalise(object): 258 | """Normalise a tensor image with mean and standard deviation. 259 | Given mean: (R, G, B) and std: (R, G, B), 260 | will normalise each channel of the torch.*Tensor, i.e. 261 | channel = (scale * channel - mean) / std 262 | 263 | Args: 264 | scale (float): Scaling constant. 265 | mean (sequence): Sequence of means for R,G,B channels respecitvely. 266 | std (sequence): Sequence of standard deviations for R,G,B channels 267 | respecitvely. 268 | depth_scale (float): Depth divisor for depth annotations. 269 | 270 | """ 271 | 272 | def __init__(self, scale, mean, std, depth_scale=1.0): 273 | self.scale = scale 274 | self.mean = mean 275 | self.std = std 276 | self.depth_scale = depth_scale 277 | 278 | def __call__(self, sample): 279 | for key in sample["inputs"]: 280 | if key == "depth": 281 | continue 282 | sample[key] = (self.scale * sample[key] - self.mean) / self.std 283 | if "depth" in sample: 284 | # sample['depth'] = self.scale * sample['depth'] 285 | # sample['depth'] = (self.scale * sample['depth'] - self.mean) / self.std 286 | if self.depth_scale > 0: 287 | sample["depth"] = self.depth_scale * sample["depth"] 288 | elif self.depth_scale == -1: # taskonomy 289 | # sample['depth'] = np.log(1 + sample['depth']) / np.log(2.** 16.0) 290 | sample["depth"] = np.log(1 + sample["depth"]) 291 | elif self.depth_scale == -2: # sunrgbd 292 | depth = sample["depth"] 293 | sample["depth"] = ( 294 | (depth - depth.min()) * 255.0 / (depth.max() - depth.min()) 295 | ) 296 | return sample 297 | 298 | 299 | class ToTensor(object): 300 | """Convert ndarrays in sample to Tensors.""" 301 | 302 | def __call__(self, sample): 303 | # swap color axis because 304 | # numpy image: H x W x C 305 | # torch image: C X H X W 306 | for key in ["rgb", "depth"]: 307 | sample[key] = torch.from_numpy(sample[key].transpose((2, 0, 1))).to( 308 | KEYS_TO_DTYPES[key] if key in KEYS_TO_DTYPES else KEYS_TO_DTYPES["rgb"] 309 | ) 310 | sample["mask"] = torch.from_numpy(sample["mask"]).to(KEYS_TO_DTYPES["mask"]) 311 | return sample 312 | 313 | 314 | def make_list(x): 315 | """Returns the given input as a list.""" 316 | if isinstance(x, list): 317 | return x 318 | elif isinstance(x, tuple): 319 | return list(x) 320 | else: 321 | return [x] 322 | -------------------------------------------------------------------------------- /models/segformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from . 
import mix_transformer 5 | from mmcv.cnn import ConvModule 6 | from .modules import num_parallel 7 | from .swin_transformer import SwinTransformer 8 | 9 | 10 | class MLP(nn.Module): 11 | """ 12 | Linear Embedding 13 | """ 14 | 15 | def __init__(self, input_dim=2048, embed_dim=768): 16 | super().__init__() 17 | self.proj = nn.Linear(input_dim, embed_dim) 18 | 19 | def forward(self, x): 20 | x = x.flatten(2).transpose(1, 2) 21 | x = self.proj(x) 22 | return x 23 | 24 | 25 | class SegFormerHead(nn.Module): 26 | """ 27 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 28 | """ 29 | 30 | def __init__( 31 | self, 32 | feature_strides=None, 33 | in_channels=128, 34 | embedding_dim=256, 35 | num_classes=20, 36 | **kwargs 37 | ): 38 | super(SegFormerHead, self).__init__() 39 | self.in_channels = in_channels 40 | self.num_classes = num_classes 41 | assert len(feature_strides) == len(self.in_channels) 42 | assert min(feature_strides) == feature_strides[0] 43 | self.feature_strides = feature_strides 44 | 45 | ( 46 | c1_in_channels, 47 | c2_in_channels, 48 | c3_in_channels, 49 | c4_in_channels, 50 | ) = self.in_channels 51 | 52 | # decoder_params = kwargs['decoder_params'] 53 | # embedding_dim = decoder_params['embed_dim'] 54 | 55 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) 56 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) 57 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) 58 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) 59 | self.dropout = nn.Dropout2d(0.1) 60 | 61 | self.linear_fuse = ConvModule( 62 | in_channels=embedding_dim * 4, 63 | out_channels=embedding_dim, 64 | kernel_size=1, 65 | norm_cfg=dict(type="BN", requires_grad=True), 66 | ) 67 | 68 | self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 69 | 70 | def forward(self, x): 71 | c1, c2, c3, c4 = x 72 | 73 | ############## MLP decoder on C1-C4 ########### 74 | n, _, h, w = c4.shape 75 | 76 | _c4 = ( 77 | self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) 78 | ) 79 | _c4 = F.interpolate( 80 | _c4, size=c1.size()[2:], mode="bilinear", align_corners=False 81 | ) 82 | 83 | _c3 = ( 84 | self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) 85 | ) 86 | _c3 = F.interpolate( 87 | _c3, size=c1.size()[2:], mode="bilinear", align_corners=False 88 | ) 89 | 90 | _c2 = ( 91 | self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) 92 | ) 93 | _c2 = F.interpolate( 94 | _c2, size=c1.size()[2:], mode="bilinear", align_corners=False 95 | ) 96 | 97 | _c1 = ( 98 | self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) 99 | ) 100 | 101 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 102 | 103 | x = self.dropout(_c) 104 | x = self.linear_pred(x) 105 | 106 | return x 107 | 108 | 109 | class TransformerBackbone(nn.Module): 110 | def __init__( 111 | self, 112 | backbone: str, 113 | train_backbone: bool, 114 | return_interm_layers: bool, 115 | drop_path_rate, 116 | pretrained_backbone_path, 117 | ): 118 | super().__init__() 119 | out_indices = (0, 1, 2, 3) 120 | if backbone == "swin_tiny": 121 | backbone = SwinTransformer( 122 | embed_dim=96, 123 | depths=[2, 2, 6, 2], 124 | num_heads=[3, 6, 12, 24], 125 | window_size=7, 126 | ape=False, 127 | drop_path_rate=drop_path_rate, 128 | patch_norm=True, 129 | use_checkpoint=False, 130 | out_indices=out_indices, 131 | ) 132 | embed_dim = 96 133 | 
backbone.init_weights(pretrained_backbone_path) 134 | elif backbone == "swin_small": 135 | backbone = SwinTransformer( 136 | embed_dim=96, 137 | depths=[2, 2, 18, 2], 138 | num_heads=[3, 6, 12, 24], 139 | window_size=7, 140 | ape=False, 141 | drop_path_rate=drop_path_rate, 142 | patch_norm=True, 143 | use_checkpoint=False, 144 | out_indices=out_indices, 145 | ) 146 | embed_dim = 96 147 | backbone.init_weights(pretrained_backbone_path) 148 | elif backbone == "swin_large": 149 | backbone = SwinTransformer( 150 | embed_dim=192, 151 | depths=[2, 2, 18, 2], 152 | num_heads=[6, 12, 24, 48], 153 | window_size=7, 154 | ape=False, 155 | drop_path_rate=drop_path_rate, 156 | patch_norm=True, 157 | use_checkpoint=False, 158 | out_indices=out_indices, 159 | ) 160 | embed_dim = 192 161 | backbone.init_weights(pretrained_backbone_path) 162 | elif backbone == "swin_large_window12": 163 | backbone = SwinTransformer( 164 | pretrain_img_size=384, 165 | embed_dim=192, 166 | depths=[2, 2, 18, 2], 167 | num_heads=[6, 12, 24, 48], 168 | window_size=12, 169 | ape=False, 170 | drop_path_rate=drop_path_rate, 171 | patch_norm=True, 172 | use_checkpoint=False, 173 | out_indices=out_indices, 174 | ) 175 | embed_dim = 192 176 | backbone.init_weights(pretrained_backbone_path) 177 | elif backbone == "swin_large_window12_to_1k": 178 | backbone = SwinTransformer( 179 | pretrain_img_size=384, 180 | embed_dim=192, 181 | depths=[2, 2, 18, 2], 182 | num_heads=[6, 12, 24, 48], 183 | window_size=12, 184 | ape=False, 185 | drop_path_rate=drop_path_rate, 186 | patch_norm=True, 187 | use_checkpoint=False, 188 | out_indices=out_indices, 189 | ) 190 | embed_dim = 192 191 | backbone.init_weights(pretrained_backbone_path) 192 | else: 193 | raise NotImplementedError 194 | 195 | for name, parameter in backbone.named_parameters(): 196 | # TODO: freeze some layers? 
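# When train_backbone is False, the loop below disables requires_grad on every backbone parameter; WeTr constructs this wrapper with train_backbone=True, so nothing is frozen in the default pipeline.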
197 | if not train_backbone: 198 | parameter.requires_grad_(False) 199 | 200 | if return_interm_layers: 201 | 202 | self.strides = [8, 16, 32] 203 | self.num_channels = [ 204 | embed_dim * 2, 205 | embed_dim * 4, 206 | embed_dim * 8, 207 | ] 208 | else: 209 | self.strides = [32] 210 | self.num_channels = [embed_dim * 8] 211 | 212 | self.body = backbone 213 | 214 | def forward(self, input): 215 | xs = self.body(input) 216 | 217 | return xs 218 | 219 | 220 | class WeTr(nn.Module): 221 | def __init__( 222 | self, 223 | backbone, 224 | num_classes=20, 225 | n_heads=8, 226 | dpr=0.1, 227 | drop_rate=0.0, 228 | ): 229 | super().__init__() 230 | self.num_classes = num_classes 231 | self.embedding_dim = 256 232 | self.feature_strides = [4, 8, 16, 32] 233 | self.num_parallel = num_parallel 234 | self.backbone = backbone 235 | 236 | print("-----------------Model Params--------------------------------------") 237 | print("backbone:", backbone) 238 | print("dpr:", dpr) 239 | print("--------------------------------------------------------------") 240 | 241 | if "swin" in backbone: 242 | if backbone == "swin_tiny": 243 | pretrained_backbone_path = "pretrained/swin_tiny_patch4_window7_224.pth" 244 | self.in_channels = [96, 192, 384, 768] 245 | elif backbone == "swin_small": 246 | pretrained_backbone_path = ( 247 | "pretrained/swin_small_patch4_window7_224.pth" 248 | ) 249 | self.in_channels = [96, 192, 384, 768] 250 | elif backbone == "swin_large_window12": 251 | pretrained_backbone_path = ( 252 | "pretrained/swin_large_patch4_window12_384_22k.pth" 253 | ) 254 | self.in_channels = [192, 384, 768, 1536] 255 | elif backbone == "swin_large_window12_to_1k": 256 | pretrained_backbone_path = ( 257 | "pretrained/swin_large_patch4_window12_384_22kto1k.pth" 258 | ) 259 | self.in_channels = [192, 384, 768, 1536] 260 | else: 261 | assert backbone == "swin_large" 262 | pretrained_backbone_path = ( 263 | "pretrained/swin_large_patch4_window7_224_22k.pth" 264 | ) 265 | self.in_channels = [192, 384, 768, 1536] 266 | self.encoder = TransformerBackbone( 267 | backbone, True, True, dpr, pretrained_backbone_path 268 | ) 269 | else: 270 | self.encoder = getattr(mix_transformer, backbone)(n_heads, dpr, drop_rate) 271 | self.in_channels = self.encoder.embed_dims 272 | ## initilize encoder 273 | state_dict = torch.load("pretrained/" + backbone + ".pth") 274 | state_dict.pop("head.weight") 275 | state_dict.pop("head.bias") 276 | state_dict = expand_state_dict( 277 | self.encoder.state_dict(), state_dict, self.num_parallel 278 | ) 279 | self.encoder.load_state_dict(state_dict, strict=True) 280 | 281 | self.decoder = SegFormerHead( 282 | feature_strides=self.feature_strides, 283 | in_channels=self.in_channels, 284 | embedding_dim=self.embedding_dim, 285 | num_classes=self.num_classes, 286 | ) 287 | 288 | self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) 289 | self.register_parameter("alpha", self.alpha) 290 | 291 | def get_param_groups(self): 292 | param_groups = [[], [], []] 293 | for name, param in list(self.encoder.named_parameters()): 294 | if "norm" in name: 295 | param_groups[1].append(param) 296 | else: 297 | param_groups[0].append(param) 298 | for param in list(self.decoder.parameters()): 299 | param_groups[2].append(param) 300 | return param_groups 301 | 302 | def forward(self, x): 303 | 304 | x = self.encoder(x) 305 | 306 | x = [self.decoder(x[0]), self.decoder(x[1])] 307 | ens = 0 308 | alpha_soft = F.softmax(self.alpha) 309 | for l in range(self.num_parallel): 310 | ens += alpha_soft[l] * 
x[l].detach() 311 | x.append(ens) 312 | return x, None 313 | 314 | 315 | def expand_state_dict(model_dict, state_dict, num_parallel): 316 | model_dict_keys = model_dict.keys() 317 | state_dict_keys = state_dict.keys() 318 | for model_dict_key in model_dict_keys: 319 | model_dict_key_re = model_dict_key.replace("module.", "") 320 | if model_dict_key_re in state_dict_keys: 321 | model_dict[model_dict_key] = state_dict[model_dict_key_re] 322 | for i in range(num_parallel): 323 | ln = ".ln_%d" % i 324 | replace = True if ln in model_dict_key_re else False 325 | model_dict_key_re = model_dict_key_re.replace(ln, "") 326 | if replace and model_dict_key_re in state_dict_keys: 327 | model_dict[model_dict_key] = state_dict[model_dict_key_re] 328 | return model_dict 329 | 330 | 331 | if __name__ == "__main__": 332 | pretrained_weights = torch.load("pretrained/mit_b1.pth") 333 | wetr = WeTr("mit_b1", num_classes=20).cuda()  # WeTr takes no embedding_dim/pretrained kwargs; pretrained weights are loaded inside __init__ 334 | wetr.get_param_groups() 335 | dummy_input = [torch.rand(2, 3, 512, 512).cuda() for _ in range(2)]  # one tensor per modality (RGB, depth), matching the inputs built in main.py 336 | wetr(dummy_input) 337 | -------------------------------------------------------------------------------- /utils/augmentations_mm.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms.functional as TF 2 | import random 3 | import math 4 | import torch 5 | from torch import Tensor 6 | from typing import Tuple, List, Union, Optional 7 | 8 | 9 | class Compose: 10 | def __init__(self, transforms: list) -> None: 11 | self.transforms = transforms 12 | 13 | def __call__(self, sample: list) -> list: 14 | img, mask = sample["rgb"], sample["mask"] 15 | if mask.ndim == 2: 16 | assert img.shape[1:] == mask.shape 17 | else: 18 | assert img.shape[1:] == mask.shape[1:] 19 | 20 | for transform in self.transforms: 21 | sample = transform(sample) 22 | 23 | return sample 24 | 25 | 26 | class Normalize: 27 | def __init__( 28 | self, mean: list = (0.485, 0.456, 0.406), std: list = (0.229, 0.224, 0.225) 29 | ): 30 | self.mean = mean 31 | self.std = std 32 | 33 | def __call__(self, sample: list) -> list: 34 | for k, v in sample.items(): 35 | if k == "mask": 36 | continue 37 | elif k == "rgb": 38 | sample[k] = sample[k].float() 39 | sample[k] /= 255 40 | sample[k] = TF.normalize(sample[k], self.mean, self.std) 41 | else: 42 | sample[k] = sample[k].float() 43 | sample[k] /= 255 44 | 45 | return sample 46 | 47 | 48 | class RandomColorJitter: 49 | def __init__(self, p=0.5) -> None: 50 | self.p = p 51 | 52 | def __call__(self, sample: list) -> list: 53 | if random.random() < self.p: 54 | self.brightness = random.uniform(0.5, 1.5) 55 | sample["rgb"] = TF.adjust_brightness(sample["rgb"], self.brightness) 56 | self.contrast = random.uniform(0.5, 1.5) 57 | sample["rgb"] = TF.adjust_contrast(sample["rgb"], self.contrast) 58 | self.saturation = random.uniform(0.5, 1.5) 59 | sample["rgb"] = TF.adjust_saturation(sample["rgb"], self.saturation) 60 | return sample 61 | 62 | 63 | class AdjustGamma: 64 | def __init__(self, gamma: float, gain: float = 1) -> None: 65 | """ 66 | Args: 67 | gamma: Non-negative real number. gamma larger than 1 makes the shadows darker, while gamma smaller than 1 makes dark regions lighter. 
68 | gain: constant multiplier 69 | """ 70 | self.gamma = gamma 71 | self.gain = gain 72 | 73 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 74 | return TF.adjust_gamma(img, self.gamma, self.gain), mask 75 | 76 | 77 | class RandomAdjustSharpness: 78 | def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: 79 | self.sharpness = sharpness_factor 80 | self.p = p 81 | 82 | def __call__(self, sample: list) -> list: 83 | if random.random() < self.p: 84 | sample["rgb"] = TF.adjust_sharpness(sample["rgb"], self.sharpness) 85 | return sample 86 | 87 | 88 | class RandomAutoContrast: 89 | def __init__(self, p: float = 0.5) -> None: 90 | self.p = p 91 | 92 | def __call__(self, sample: list) -> list: 93 | if random.random() < self.p: 94 | sample["rgb"] = TF.autocontrast(sample["rgb"]) 95 | return sample 96 | 97 | 98 | class RandomGaussianBlur: 99 | def __init__(self, kernel_size: int = 3, p: float = 0.5) -> None: 100 | self.kernel_size = kernel_size 101 | self.p = p 102 | 103 | def __call__(self, sample: list) -> list: 104 | if random.random() < self.p: 105 | sample["rgb"] = TF.gaussian_blur(sample["rgb"], self.kernel_size) 106 | # img = TF.gaussian_blur(img, self.kernel_size) 107 | return sample 108 | 109 | 110 | class RandomHorizontalFlip: 111 | def __init__(self, p: float = 0.5) -> None: 112 | self.p = p 113 | 114 | def __call__(self, sample: list) -> list: 115 | if random.random() < self.p: 116 | for k, v in sample.items(): 117 | sample[k] = TF.hflip(v) 118 | return sample 119 | return sample 120 | 121 | 122 | class RandomVerticalFlip: 123 | def __init__(self, p: float = 0.5) -> None: 124 | self.p = p 125 | 126 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 127 | if random.random() < self.p: 128 | return TF.vflip(img), TF.vflip(mask) 129 | return img, mask 130 | 131 | 132 | class RandomGrayscale: 133 | def __init__(self, p: float = 0.5) -> None: 134 | self.p = p 135 | 136 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 137 | if random.random() < self.p: 138 | img = TF.rgb_to_grayscale(img, 3) 139 | return img, mask 140 | 141 | 142 | class Equalize: 143 | def __call__(self, image, label): 144 | return TF.equalize(image), label 145 | 146 | 147 | class Posterize: 148 | def __init__(self, bits=2): 149 | self.bits = bits # 0-8 150 | 151 | def __call__(self, image, label): 152 | return TF.posterize(image, self.bits), label 153 | 154 | 155 | class Affine: 156 | def __init__(self, angle=0, translate=[0, 0], scale=1.0, shear=[0, 0], seg_fill=0): 157 | self.angle = angle 158 | self.translate = translate 159 | self.scale = scale 160 | self.shear = shear 161 | self.seg_fill = seg_fill 162 | 163 | def __call__(self, img, label): 164 | return TF.affine( 165 | img, 166 | self.angle, 167 | self.translate, 168 | self.scale, 169 | self.shear, 170 | TF.InterpolationMode.BILINEAR, 171 | 0, 172 | ), TF.affine( 173 | label, 174 | self.angle, 175 | self.translate, 176 | self.scale, 177 | self.shear, 178 | TF.InterpolationMode.NEAREST, 179 | self.seg_fill, 180 | ) 181 | 182 | 183 | class RandomRotation: 184 | def __init__( 185 | self, 186 | degrees: float = 10.0, 187 | p: float = 0.2, 188 | seg_fill: int = 0, 189 | expand: bool = False, 190 | ) -> None: 191 | """Rotate the image by a random angle between -angle and angle with probability p 192 | 193 | Args: 194 | p: probability 195 | angle: rotation angle value in degrees, counter-clockwise. 196 | expand: Optional expansion flag. 
197 | If true, expands the output image to make it large enough to hold the entire rotated image. 198 | If false or omitted, make the output image the same size as the input image. 199 | Note that the expand flag assumes rotation around the center and no translation. 200 | """ 201 | self.p = p 202 | self.angle = degrees 203 | self.expand = expand 204 | self.seg_fill = seg_fill 205 | 206 | def __call__(self, sample: list) -> list: 207 | random_angle = random.random() * 2 * self.angle - self.angle 208 | if random.random() < self.p: 209 | for k, v in sample.items(): 210 | if k == "mask": 211 | sample[k] = TF.rotate( 212 | v, 213 | random_angle, 214 | TF.InterpolationMode.NEAREST, 215 | self.expand, 216 | fill=self.seg_fill, 217 | ) 218 | else: 219 | sample[k] = TF.rotate( 220 | v, 221 | random_angle, 222 | TF.InterpolationMode.BILINEAR, 223 | self.expand, 224 | fill=0, 225 | ) 226 | # img = TF.rotate(img, random_angle, TF.InterpolationMode.BILINEAR, self.expand, fill=0) 227 | # mask = TF.rotate(mask, random_angle, TF.InterpolationMode.NEAREST, self.expand, fill=self.seg_fill) 228 | return sample 229 | 230 | 231 | class CenterCrop: 232 | def __init__(self, size: Union[int, List[int], Tuple[int]]) -> None: 233 | """Crops the image at the center 234 | 235 | Args: 236 | output_size: height and width of the crop box. If int, this size is used for both directions. 237 | """ 238 | self.size = (size, size) if isinstance(size, int) else size 239 | 240 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 241 | return TF.center_crop(img, self.size), TF.center_crop(mask, self.size) 242 | 243 | 244 | class RandomCrop: 245 | def __init__(self, size: Union[int, List[int], Tuple[int]], p: float = 0.5) -> None: 246 | """Randomly Crops the image. 247 | 248 | Args: 249 | output_size: height and width of the crop box. If int, this size is used for both directions. 250 | """ 251 | self.size = (size, size) if isinstance(size, int) else size 252 | self.p = p 253 | 254 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 255 | H, W = img.shape[1:] 256 | tH, tW = self.size 257 | 258 | if random.random() < self.p: 259 | margin_h = max(H - tH, 0) 260 | margin_w = max(W - tW, 0) 261 | y1 = random.randint(0, margin_h + 1) 262 | x1 = random.randint(0, margin_w + 1) 263 | y2 = y1 + tH 264 | x2 = x1 + tW 265 | img = img[:, y1:y2, x1:x2] 266 | mask = mask[:, y1:y2, x1:x2] 267 | return img, mask 268 | 269 | 270 | class Pad: 271 | def __init__( 272 | self, size: Union[List[int], Tuple[int], int], seg_fill: int = 0 273 | ) -> None: 274 | """Pad the given image on all sides with the given "pad" value. 275 | Args: 276 | size: expected output image size (h, w) 277 | fill: Pixel fill value for constant fill. Default is 0. This value is only used when the padding mode is constant. 278 | """ 279 | self.size = size 280 | self.seg_fill = seg_fill 281 | 282 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 283 | padding = (0, 0, self.size[1] - img.shape[2], self.size[0] - img.shape[1]) 284 | return TF.pad(img, padding), TF.pad(mask, padding, self.seg_fill) 285 | 286 | 287 | class ResizePad: 288 | def __init__( 289 | self, size: Union[int, Tuple[int], List[int]], seg_fill: int = 0 290 | ) -> None: 291 | """Resize the input image to the given size. 292 | Args: 293 | size: Desired output size. 294 | If size is a sequence, the output size will be matched to this. 
295 | If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio. 296 | """ 297 | self.size = size 298 | self.seg_fill = seg_fill 299 | 300 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 301 | H, W = img.shape[1:] 302 | tH, tW = self.size 303 | 304 | # scale the image 305 | scale_factor = min(tH / H, tW / W) if W > H else max(tH / H, tW / W) 306 | # nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5) 307 | nH, nW = round(H * scale_factor), round(W * scale_factor) 308 | img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR) 309 | mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST) 310 | 311 | # pad the image 312 | padding = [0, 0, tW - nW, tH - nH] 313 | img = TF.pad(img, padding, fill=0) 314 | mask = TF.pad(mask, padding, fill=self.seg_fill) 315 | return img, mask 316 | 317 | 318 | class Resize: 319 | def __init__(self, size: Union[int, Tuple[int], List[int]]) -> None: 320 | """Resize the input image to the given size. 321 | Args: 322 | size: Desired output size. 323 | If size is a sequence, the output size will be matched to this. 324 | If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio. 325 | """ 326 | self.size = size 327 | 328 | def __call__(self, sample: list) -> list: 329 | H, W = sample["rgb"].shape[1:] 330 | 331 | # scale the image 332 | scale_factor = self.size[0] / min(H, W) 333 | nH, nW = round(H * scale_factor), round(W * scale_factor) 334 | for k, v in sample.items(): 335 | if k == "mask": 336 | sample[k] = TF.resize( 337 | v.unsqueeze(0), (nH, nW), TF.InterpolationMode.NEAREST 338 | ).squeeze(0) 339 | else: 340 | sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.BILINEAR) 341 | # img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR) 342 | # mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST) 343 | 344 | # make the image divisible by stride 345 | alignH, alignW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32 346 | 347 | for k, v in sample.items(): 348 | if k == "mask": 349 | sample[k] = TF.resize( 350 | v.unsqueeze(0), (alignH, alignW), TF.InterpolationMode.NEAREST 351 | ).squeeze(0) 352 | else: 353 | sample[k] = TF.resize( 354 | v, (alignH, alignW), TF.InterpolationMode.BILINEAR 355 | ) 356 | # img = TF.resize(img, (alignH, alignW), TF.InterpolationMode.BILINEAR) 357 | # mask = TF.resize(mask, (alignH, alignW), TF.InterpolationMode.NEAREST) 358 | return sample 359 | 360 | 361 | class RandomResizedCrop: 362 | def __init__( 363 | self, 364 | size: Union[int, Tuple[int], List[int]], 365 | scale: Tuple[float, float] = (0.5, 2.0), 366 | seg_fill: int = 0, 367 | ) -> None: 368 | """Resize the input image to the given size.""" 369 | self.size = size 370 | self.scale = scale 371 | self.seg_fill = seg_fill 372 | 373 | def __call__(self, sample: list) -> list: 374 | # img, mask = sample['rgb'], sample['mask'] 375 | H, W = sample["rgb"].shape[1:] 376 | tH, tW = self.size 377 | 378 | # get the scale 379 | ratio = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0] 380 | # ratio = random.uniform(min(self.scale), max(self.scale)) 381 | scale = int(tH * ratio), int(tW * 4 * ratio) 382 | # scale the image 383 | scale_factor = min(max(scale) / max(H, W), min(scale) / min(H, W)) 384 | nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5) 385 | # nH, nW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32 386 | for k, v in sample.items(): 387 | if k == "mask": 
388 | sample[k] = TF.resize( 389 | v.unsqueeze(0), 390 | (nH, nW), 391 | TF.InterpolationMode.NEAREST, 392 | ).squeeze(0) 393 | else: 394 | sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.BILINEAR) 395 | 396 | # random crop 397 | margin_h = max(sample["rgb"].shape[1] - tH, 0) 398 | margin_w = max(sample["rgb"].shape[2] - tW, 0) 399 | y1 = random.randint(0, margin_h + 1) 400 | x1 = random.randint(0, margin_w + 1) 401 | y2 = y1 + tH 402 | x2 = x1 + tW 403 | for k, v in sample.items(): 404 | # print("before_1:", k, sample[k].shape) 405 | if len(v.shape) == 3: 406 | sample[k] = v[:, y1:y2, x1:x2] 407 | else: 408 | sample[k] = v[y1:y2, x1:x2] 409 | # print("after_1:", k, sample[k].shape) 410 | 411 | # pad the image 412 | if sample["rgb"].shape[1:] != self.size: 413 | padding = [ 414 | 0, 415 | 0, 416 | tW - sample["rgb"].shape[2], 417 | tH - sample["rgb"].shape[1], 418 | ] 419 | for k, v in sample.items(): 420 | if k == "mask": 421 | sample[k] = TF.pad(v, padding, fill=self.seg_fill) 422 | else: 423 | sample[k] = TF.pad(v, padding, fill=0) 424 | 425 | return sample 426 | 427 | 428 | def get_train_augmentation(size: Union[int, Tuple[int], List[int]], seg_fill: int = 0): 429 | return Compose( 430 | [ 431 | RandomColorJitter(p=0.2), # 432 | RandomHorizontalFlip(p=0.5), # 433 | RandomGaussianBlur((3, 3), p=0.2), # 434 | RandomResizedCrop(size, scale=(0.5, 2.0), seg_fill=seg_fill), # 435 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 436 | ] 437 | ) 438 | 439 | 440 | def get_val_augmentation(size: Union[int, Tuple[int], List[int]]): 441 | return Compose( 442 | [Resize(size), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] 443 | ) 444 | 445 | 446 | if __name__ == "__main__": 447 | h = 230 448 | w = 420 449 | sample = {} 450 | sample["rgb"] = torch.randn(3, h, w) 451 | sample["depth"] = torch.randn(3, h, w) 452 | sample["lidar"] = torch.randn(3, h, w) 453 | sample["event"] = torch.randn(3, h, w) 454 | sample["mask"] = torch.randn(1, h, w) 455 | aug = Compose( 456 | [ 457 | RandomHorizontalFlip(p=0.5), 458 | RandomResizedCrop((512, 512)), 459 | Resize((224, 224)), 460 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 461 | ] 462 | ) 463 | sample = aug(sample) 464 | for k, v in sample.items(): 465 | print(k, v.shape) 466 | -------------------------------------------------------------------------------- /mmcv_custom/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 
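# Customised copy of mmcv.runner.checkpoint: load_checkpoint() below additionally reshapes Swin absolute position embeddings, bicubically interpolates relative_position_bias_table weights when resolutions differ, and duplicates each state_dict key with an extra 'module' segment so the same checkpoint also matches wrapper-module parameter names.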
2 | import io 3 | import os 4 | import os.path as osp 5 | import pkgutil 6 | import time 7 | import warnings 8 | from collections import OrderedDict 9 | from importlib import import_module 10 | from tempfile import TemporaryDirectory 11 | 12 | import torch 13 | import torchvision 14 | from torch.optim import Optimizer 15 | from torch.utils import model_zoo 16 | from torch.nn import functional as F 17 | 18 | import mmcv 19 | from mmcv.fileio import FileClient 20 | from mmcv.fileio import load as load_file 21 | from mmcv.parallel import is_module_wrapper 22 | from mmcv.utils import mkdir_or_exist 23 | from mmcv.runner import get_dist_info 24 | 25 | ENV_MMCV_HOME = "MMCV_HOME" 26 | ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" 27 | DEFAULT_CACHE_DIR = "~/.cache" 28 | 29 | 30 | def _get_mmcv_home(): 31 | mmcv_home = os.path.expanduser( 32 | os.getenv( 33 | ENV_MMCV_HOME, 34 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "mmcv"), 35 | ) 36 | ) 37 | 38 | mkdir_or_exist(mmcv_home) 39 | return mmcv_home 40 | 41 | 42 | def load_state_dict(module, state_dict, strict=False, logger=None): 43 | """Load state_dict to a module. 44 | 45 | This method is modified from :meth:`torch.nn.Module.load_state_dict`. 46 | Default value for ``strict`` is set to ``False`` and the message for 47 | param mismatch will be shown even if strict is False. 48 | 49 | Args: 50 | module (Module): Module that receives the state_dict. 51 | state_dict (OrderedDict): Weights. 52 | strict (bool): whether to strictly enforce that the keys 53 | in :attr:`state_dict` match the keys returned by this module's 54 | :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. 55 | logger (:obj:`logging.Logger`, optional): Logger to log the error 56 | message. If not specified, print function will be used. 
57 | """ 58 | unexpected_keys = [] 59 | all_missing_keys = [] 60 | err_msg = [] 61 | 62 | metadata = getattr(state_dict, "_metadata", None) 63 | state_dict = state_dict.copy() 64 | if metadata is not None: 65 | state_dict._metadata = metadata 66 | 67 | # use _load_from_state_dict to enable checkpoint version control 68 | def load(module, prefix=""): 69 | # recursively check parallel module in case that the model has a 70 | # complicated structure, e.g., nn.Module(nn.Module(DDP)) 71 | if is_module_wrapper(module): 72 | module = module.module 73 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 74 | module._load_from_state_dict( 75 | state_dict, 76 | prefix, 77 | local_metadata, 78 | True, 79 | all_missing_keys, 80 | unexpected_keys, 81 | err_msg, 82 | ) 83 | for name, child in module._modules.items(): 84 | if child is not None: 85 | load(child, prefix + name + ".") 86 | 87 | load(module) 88 | load = None # break load->load reference cycle 89 | 90 | # ignore "num_batches_tracked" of BN layers 91 | missing_keys = [key for key in all_missing_keys if "num_batches_tracked" not in key] 92 | 93 | if unexpected_keys: 94 | err_msg.append( 95 | "unexpected key in source " f'state_dict: {", ".join(unexpected_keys)}\n' 96 | ) 97 | if missing_keys: 98 | err_msg.append( 99 | f'missing keys in source state_dict: {", ".join(missing_keys)}\n' 100 | ) 101 | 102 | rank, _ = get_dist_info() 103 | if len(err_msg) > 0 and rank == 0: 104 | err_msg.insert(0, "The model and loaded state dict do not match exactly\n") 105 | err_msg = "\n".join(err_msg) 106 | if strict: 107 | raise RuntimeError(err_msg) 108 | elif logger is not None: 109 | logger.warning(err_msg) 110 | else: 111 | print(err_msg) 112 | 113 | 114 | def load_url_dist(url, model_dir=None): 115 | """In distributed setting, this function only download checkpoint at local 116 | rank 0.""" 117 | rank, world_size = get_dist_info() 118 | rank = int(os.environ.get("LOCAL_RANK", rank)) 119 | if rank == 0: 120 | checkpoint = model_zoo.load_url(url, model_dir=model_dir) 121 | if world_size > 1: 122 | torch.distributed.barrier() 123 | if rank > 0: 124 | checkpoint = model_zoo.load_url(url, model_dir=model_dir) 125 | return checkpoint 126 | 127 | 128 | def load_pavimodel_dist(model_path, map_location=None): 129 | """In distributed setting, this function only download checkpoint at local 130 | rank 0.""" 131 | try: 132 | from pavi import modelcloud 133 | except ImportError: 134 | raise ImportError("Please install pavi to load checkpoint from modelcloud.") 135 | rank, world_size = get_dist_info() 136 | rank = int(os.environ.get("LOCAL_RANK", rank)) 137 | if rank == 0: 138 | model = modelcloud.get(model_path) 139 | with TemporaryDirectory() as tmp_dir: 140 | downloaded_file = osp.join(tmp_dir, model.name) 141 | model.download(downloaded_file) 142 | checkpoint = torch.load(downloaded_file, map_location=map_location) 143 | if world_size > 1: 144 | torch.distributed.barrier() 145 | if rank > 0: 146 | model = modelcloud.get(model_path) 147 | with TemporaryDirectory() as tmp_dir: 148 | downloaded_file = osp.join(tmp_dir, model.name) 149 | model.download(downloaded_file) 150 | checkpoint = torch.load(downloaded_file, map_location=map_location) 151 | return checkpoint 152 | 153 | 154 | def load_fileclient_dist(filename, backend, map_location): 155 | """In distributed setting, this function only download checkpoint at local 156 | rank 0.""" 157 | rank, world_size = get_dist_info() 158 | rank = int(os.environ.get("LOCAL_RANK", rank)) 159 | 
allowed_backends = ["ceph"] 160 | if backend not in allowed_backends: 161 | raise ValueError(f"Load from Backend {backend} is not supported.") 162 | if rank == 0: 163 | fileclient = FileClient(backend=backend) 164 | buffer = io.BytesIO(fileclient.get(filename)) 165 | checkpoint = torch.load(buffer, map_location=map_location) 166 | if world_size > 1: 167 | torch.distributed.barrier() 168 | if rank > 0: 169 | fileclient = FileClient(backend=backend) 170 | buffer = io.BytesIO(fileclient.get(filename)) 171 | checkpoint = torch.load(buffer, map_location=map_location) 172 | return checkpoint 173 | 174 | 175 | def get_torchvision_models(): 176 | model_urls = dict() 177 | for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): 178 | if ispkg: 179 | continue 180 | _zoo = import_module(f"torchvision.models.{name}") 181 | if hasattr(_zoo, "model_urls"): 182 | _urls = getattr(_zoo, "model_urls") 183 | model_urls.update(_urls) 184 | return model_urls 185 | 186 | 187 | def get_external_models(): 188 | mmcv_home = _get_mmcv_home() 189 | default_json_path = osp.join(mmcv.__path__[0], "model_zoo/open_mmlab.json") 190 | default_urls = load_file(default_json_path) 191 | assert isinstance(default_urls, dict) 192 | external_json_path = osp.join(mmcv_home, "open_mmlab.json") 193 | if osp.exists(external_json_path): 194 | external_urls = load_file(external_json_path) 195 | assert isinstance(external_urls, dict) 196 | default_urls.update(external_urls) 197 | 198 | return default_urls 199 | 200 | 201 | def get_mmcls_models(): 202 | mmcls_json_path = osp.join(mmcv.__path__[0], "model_zoo/mmcls.json") 203 | mmcls_urls = load_file(mmcls_json_path) 204 | 205 | return mmcls_urls 206 | 207 | 208 | def get_deprecated_model_names(): 209 | deprecate_json_path = osp.join(mmcv.__path__[0], "model_zoo/deprecated.json") 210 | deprecate_urls = load_file(deprecate_json_path) 211 | assert isinstance(deprecate_urls, dict) 212 | 213 | return deprecate_urls 214 | 215 | 216 | def _process_mmcls_checkpoint(checkpoint): 217 | state_dict = checkpoint["state_dict"] 218 | new_state_dict = OrderedDict() 219 | for k, v in state_dict.items(): 220 | if k.startswith("backbone."): 221 | new_state_dict[k[9:]] = v 222 | new_checkpoint = dict(state_dict=new_state_dict) 223 | 224 | return new_checkpoint 225 | 226 | 227 | def _load_checkpoint(filename, map_location=None): 228 | """Load checkpoint from somewhere (modelzoo, file, url). 229 | 230 | Args: 231 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 232 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for 233 | details. 234 | map_location (str | None): Same as :func:`torch.load`. Default: None. 235 | 236 | Returns: 237 | dict | OrderedDict: The loaded checkpoint. It can be either an 238 | OrderedDict storing model weights or a dict containing other 239 | information, which depends on the checkpoint. 
240 | """ 241 | if filename.startswith("modelzoo://"): 242 | warnings.warn( 243 | 'The URL scheme of "modelzoo://" is deprecated, please ' 244 | 'use "torchvision://" instead' 245 | ) 246 | model_urls = get_torchvision_models() 247 | model_name = filename[11:] 248 | checkpoint = load_url_dist(model_urls[model_name]) 249 | elif filename.startswith("torchvision://"): 250 | model_urls = get_torchvision_models() 251 | model_name = filename[14:] 252 | checkpoint = load_url_dist(model_urls[model_name]) 253 | elif filename.startswith("open-mmlab://"): 254 | model_urls = get_external_models() 255 | model_name = filename[13:] 256 | deprecated_urls = get_deprecated_model_names() 257 | if model_name in deprecated_urls: 258 | warnings.warn( 259 | f"open-mmlab://{model_name} is deprecated in favor " 260 | f"of open-mmlab://{deprecated_urls[model_name]}" 261 | ) 262 | model_name = deprecated_urls[model_name] 263 | model_url = model_urls[model_name] 264 | # check if is url 265 | if model_url.startswith(("http://", "https://")): 266 | checkpoint = load_url_dist(model_url) 267 | else: 268 | filename = osp.join(_get_mmcv_home(), model_url) 269 | if not osp.isfile(filename): 270 | raise IOError(f"{filename} is not a checkpoint file") 271 | checkpoint = torch.load(filename, map_location=map_location) 272 | elif filename.startswith("mmcls://"): 273 | model_urls = get_mmcls_models() 274 | model_name = filename[8:] 275 | checkpoint = load_url_dist(model_urls[model_name]) 276 | checkpoint = _process_mmcls_checkpoint(checkpoint) 277 | elif filename.startswith(("http://", "https://")): 278 | checkpoint = load_url_dist(filename) 279 | elif filename.startswith("pavi://"): 280 | model_path = filename[7:] 281 | checkpoint = load_pavimodel_dist(model_path, map_location=map_location) 282 | elif filename.startswith("s3://"): 283 | checkpoint = load_fileclient_dist( 284 | filename, backend="ceph", map_location=map_location 285 | ) 286 | else: 287 | if not osp.isfile(filename): 288 | raise IOError(f"{filename} is not a checkpoint file") 289 | checkpoint = torch.load(filename, map_location=map_location) 290 | return checkpoint 291 | 292 | 293 | def load_checkpoint(model, filename, map_location="cpu", strict=False, logger=None): 294 | """Load checkpoint from a file or URI. 295 | 296 | Args: 297 | model (Module): Module to load checkpoint. 298 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 299 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for 300 | details. 301 | map_location (str): Same as :func:`torch.load`. 302 | strict (bool): Whether to allow different params for the model and 303 | checkpoint. 304 | logger (:mod:`logging.Logger` or None): The logger for error message. 305 | 306 | Returns: 307 | dict or OrderedDict: The loaded checkpoint. 
308 | """ 309 | checkpoint = _load_checkpoint(filename, map_location) 310 | # OrderedDict is a subclass of dict 311 | if not isinstance(checkpoint, dict): 312 | raise RuntimeError(f"No state_dict found in checkpoint file {filename}") 313 | # get state_dict from checkpoint 314 | if "state_dict" in checkpoint: 315 | state_dict = checkpoint["state_dict"] 316 | elif "model" in checkpoint: 317 | state_dict = checkpoint["model"] 318 | else: 319 | state_dict = checkpoint 320 | # strip prefix of state_dict 321 | if list(state_dict.keys())[0].startswith("module."): 322 | state_dict = {k[7:]: v for k, v in state_dict.items()} 323 | 324 | # for MoBY, load model of online branch 325 | if sorted(list(state_dict.keys()))[0].startswith("encoder"): 326 | state_dict = { 327 | k.replace("encoder.", ""): v 328 | for k, v in state_dict.items() 329 | if k.startswith("encoder.") 330 | } 331 | 332 | # reshape absolute position embedding 333 | if state_dict.get("absolute_pos_embed") is not None: 334 | absolute_pos_embed = state_dict["absolute_pos_embed"] 335 | N1, L, C1 = absolute_pos_embed.size() 336 | N2, C2, H, W = model.absolute_pos_embed.size() 337 | if N1 != N2 or C1 != C2 or L != H * W: 338 | logger.warning("Error in loading absolute_pos_embed, pass") 339 | else: 340 | state_dict["absolute_pos_embed"] = absolute_pos_embed.view( 341 | N2, H, W, C2 342 | ).permute(0, 3, 1, 2) 343 | 344 | # interpolate position bias table if needed 345 | relative_position_bias_table_keys = [ 346 | k for k in state_dict.keys() if "relative_position_bias_table" in k 347 | ] 348 | for table_key in relative_position_bias_table_keys: 349 | table_pretrained = state_dict[table_key] 350 | 351 | new_table_key = ".".join( 352 | table_key.split(".")[:-1] + ["module"] + [table_key.split(".")[-1]] 353 | ) 354 | 355 | table_current = model.state_dict()[new_table_key] 356 | L1, nH1 = table_pretrained.size() 357 | L2, nH2 = table_current.size() 358 | if nH1 != nH2: 359 | logger.warning(f"Error in loading {table_key}, pass") 360 | else: 361 | if L1 != L2: 362 | S1 = int(L1**0.5) 363 | S2 = int(L2**0.5) 364 | table_pretrained_resized = F.interpolate( 365 | table_pretrained.permute(1, 0).view(1, nH1, S1, S1), 366 | size=(S2, S2), 367 | mode="bicubic", 368 | ) 369 | state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute( 370 | 1, 0 371 | ) 372 | 373 | new_state_dict = dict() 374 | for key, value in state_dict.items(): 375 | new_key = ".".join(key.split(".")[:-1] + ["module"] + [key.split(".")[-1]]) 376 | new_key_2 = ".".join(key.split(".")[:-2] + ["module"] + key.split(".")[-2:]) 377 | new_state_dict[new_key] = value 378 | new_state_dict[new_key_2] = value 379 | new_state_dict[key] = value 380 | # load state_dict 381 | load_state_dict(model, new_state_dict, strict, logger) 382 | return checkpoint 383 | 384 | 385 | def weights_to_cpu(state_dict): 386 | """Copy a model state_dict to cpu. 387 | 388 | Args: 389 | state_dict (OrderedDict): Model weights on GPU. 390 | 391 | Returns: 392 | OrderedDict: Model weights on GPU. 393 | """ 394 | state_dict_cpu = OrderedDict() 395 | for key, val in state_dict.items(): 396 | state_dict_cpu[key] = val.cpu() 397 | return state_dict_cpu 398 | 399 | 400 | def _save_to_state_dict(module, destination, prefix, keep_vars): 401 | """Saves module state to `destination` dictionary. 402 | 403 | This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. 404 | 405 | Args: 406 | module (nn.Module): The module to generate state_dict. 407 | destination (dict): A dict where state will be stored. 
408 | prefix (str): The prefix for parameters and buffers used in this 409 | module. 410 | """ 411 | for name, param in module._parameters.items(): 412 | if param is not None: 413 | destination[prefix + name] = param if keep_vars else param.detach() 414 | for name, buf in module._buffers.items(): 415 | # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d 416 | if buf is not None: 417 | destination[prefix + name] = buf if keep_vars else buf.detach() 418 | 419 | 420 | def get_state_dict(module, destination=None, prefix="", keep_vars=False): 421 | """Returns a dictionary containing a whole state of the module. 422 | 423 | Both parameters and persistent buffers (e.g. running averages) are 424 | included. Keys are corresponding parameter and buffer names. 425 | 426 | This method is modified from :meth:`torch.nn.Module.state_dict` to 427 | recursively check parallel module in case that the model has a complicated 428 | structure, e.g., nn.Module(nn.Module(DDP)). 429 | 430 | Args: 431 | module (nn.Module): The module to generate state_dict. 432 | destination (OrderedDict): Returned dict for the state of the 433 | module. 434 | prefix (str): Prefix of the key. 435 | keep_vars (bool): Whether to keep the variable property of the 436 | parameters. Default: False. 437 | 438 | Returns: 439 | dict: A dictionary containing a whole state of the module. 440 | """ 441 | # recursively check parallel module in case that the model has a 442 | # complicated structure, e.g., nn.Module(nn.Module(DDP)) 443 | if is_module_wrapper(module): 444 | module = module.module 445 | 446 | # below is the same as torch.nn.Module.state_dict() 447 | if destination is None: 448 | destination = OrderedDict() 449 | destination._metadata = OrderedDict() 450 | destination._metadata[prefix[:-1]] = local_metadata = dict(version=module._version) 451 | _save_to_state_dict(module, destination, prefix, keep_vars) 452 | for name, child in module._modules.items(): 453 | if child is not None: 454 | get_state_dict(child, destination, prefix + name + ".", keep_vars=keep_vars) 455 | for hook in module._state_dict_hooks.values(): 456 | hook_result = hook(module, destination, prefix, local_metadata) 457 | if hook_result is not None: 458 | destination = hook_result 459 | return destination 460 | 461 | 462 | def save_checkpoint(model, filename, optimizer=None, meta=None): 463 | """Save checkpoint to file. 464 | 465 | The checkpoint will have 3 fields: ``meta``, ``state_dict`` and 466 | ``optimizer``. By default ``meta`` will contain version and time info. 467 | 468 | Args: 469 | model (Module): Module whose params are to be saved. 470 | filename (str): Checkpoint filename. 471 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. 472 | meta (dict, optional): Metadata to be saved in checkpoint. 
473 | """ 474 | if meta is None: 475 | meta = {} 476 | elif not isinstance(meta, dict): 477 | raise TypeError(f"meta must be a dict or None, but got {type(meta)}") 478 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) 479 | 480 | if is_module_wrapper(model): 481 | model = model.module 482 | 483 | if hasattr(model, "CLASSES") and model.CLASSES is not None: 484 | # save class name to the meta 485 | meta.update(CLASSES=model.CLASSES) 486 | 487 | checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))} 488 | # save optimizer state dict in the checkpoint 489 | if isinstance(optimizer, Optimizer): 490 | checkpoint["optimizer"] = optimizer.state_dict() 491 | elif isinstance(optimizer, dict): 492 | checkpoint["optimizer"] = {} 493 | for name, optim in optimizer.items(): 494 | checkpoint["optimizer"][name] = optim.state_dict() 495 | 496 | if filename.startswith("pavi://"): 497 | try: 498 | from pavi import modelcloud 499 | from pavi.exception import NodeNotFoundError 500 | except ImportError: 501 | raise ImportError("Please install pavi to load checkpoint from modelcloud.") 502 | model_path = filename[7:] 503 | root = modelcloud.Folder() 504 | model_dir, model_name = osp.split(model_path) 505 | try: 506 | model = modelcloud.get(model_dir) 507 | except NodeNotFoundError: 508 | model = root.create_training_model(model_dir) 509 | with TemporaryDirectory() as tmp_dir: 510 | checkpoint_file = osp.join(tmp_dir, model_name) 511 | with open(checkpoint_file, "wb") as f: 512 | torch.save(checkpoint, f) 513 | f.flush() 514 | model.create_file(checkpoint_file, name=model_name) 515 | else: 516 | mmcv.mkdir_or_exist(osp.dirname(filename)) 517 | # immediately flush buffer 518 | with open(filename, "wb") as f: 519 | torch.save(checkpoint, f) 520 | f.flush() 521 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # general libs 2 | import os, sys, argparse 3 | import random, time 4 | import warnings 5 | 6 | warnings.filterwarnings("ignore") 7 | import cv2 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import datetime 12 | 13 | 14 | from utils import * 15 | import utils.helpers as helpers 16 | from utils.optimizer import PolyWarmupAdamW 17 | from models.segformer import WeTr 18 | from torch import distributed as dist 19 | from torch.utils.data.distributed import DistributedSampler 20 | from tqdm import tqdm 21 | from utils.augmentations_mm import * 22 | from torch.nn.parallel import DistributedDataParallel as DDP 23 | 24 | 25 | def setup_ddp(): 26 | # print(os.environ.keys()) 27 | if "SLURM_PROCID" in os.environ and not "RANK" in os.environ: 28 | # --- multi nodes 29 | world_size = int(os.environ["WORLD_SIZE"]) 30 | rank = int(os.environ["SLURM_PROCID"]) 31 | gpus_per_node = int(os.environ["SLURM_GPUS_ON_NODE"]) 32 | gpu = rank - gpus_per_node * (rank // gpus_per_node) 33 | torch.cuda.set_device(gpu) 34 | dist.init_process_group( 35 | backend="nccl", 36 | world_size=world_size, 37 | rank=rank, 38 | timeout=datetime.timedelta(seconds=7200), 39 | ) 40 | elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: 41 | rank = int(os.environ["RANK"]) 42 | world_size = int(os.environ["WORLD_SIZE"]) 43 | # --- 44 | gpu = int(os.environ["LOCAL_RANK"]) 45 | torch.cuda.set_device(gpu) 46 | dist.init_process_group( 47 | "nccl", 48 | init_method="env://", 49 | world_size=world_size, 50 | rank=rank, 51 | 
timeout=datetime.timedelta(seconds=7200), 52 | ) 53 | dist.barrier() 54 | else: 55 | gpu = 0 56 | return gpu 57 | 58 | 59 | def cleanup_ddp(): 60 | if dist.is_initialized(): 61 | dist.destroy_process_group() 62 | 63 | 64 | def get_arguments(): 65 | """Parse all the arguments provided from the CLI. 66 | 67 | Returns: 68 | A list of parsed arguments. 69 | """ 70 | parser = argparse.ArgumentParser(description="Full Pipeline Training") 71 | 72 | # Dataset 73 | parser.add_argument( 74 | "--dataset", 75 | type=str, 76 | default="nyudv2", 77 | help="Name of the dataset.", 78 | ) 79 | parser.add_argument( 80 | "--train-dir", 81 | type=str, 82 | default="/cache/datasets/nyudv2", 83 | help="Path to the training set directory.", 84 | ) 85 | parser.add_argument( 86 | "--batch-size", 87 | type=int, 88 | default=2, 89 | help="Batch size to train the segmenter model.", 90 | ) 91 | parser.add_argument( 92 | "--num-workers", 93 | type=int, 94 | default=16, 95 | help="Number of workers for pytorch's dataloader.", 96 | ) 97 | parser.add_argument( 98 | "--ignore-label", 99 | type=int, 100 | default=255, 101 | help="Label to ignore during training", 102 | ) 103 | 104 | # General 105 | parser.add_argument("--name", default="", type=str, help="model name") 106 | parser.add_argument( 107 | "--evaluate", 108 | action="store_true", 109 | default=False, 110 | help="If true, only validate segmentation.", 111 | ) 112 | parser.add_argument( 113 | "--freeze-bn", 114 | type=bool, 115 | nargs="+", 116 | default=True, 117 | help="Whether to keep batch norm statistics intact.", 118 | ) 119 | parser.add_argument( 120 | "--num-epoch", 121 | type=int, 122 | nargs="+", 123 | default=[100] * 3, 124 | help="Number of epochs to train for segmentation network.", 125 | ) 126 | parser.add_argument( 127 | "--random-seed", 128 | type=int, 129 | default=42, 130 | help="Seed to provide (near-)reproducibility.", 131 | ) 132 | parser.add_argument( 133 | "-c", 134 | "--ckpt", 135 | default="model", 136 | type=str, 137 | metavar="PATH", 138 | help="path to save checkpoint (default: model)", 139 | ) 140 | parser.add_argument( 141 | "--resume", 142 | default="", 143 | type=str, 144 | metavar="PATH", 145 | help="path to latest checkpoint (default: none)", 146 | ) 147 | parser.add_argument( 148 | "--val-every", 149 | type=int, 150 | default=5, 151 | help="How often to validate current architecture.", 152 | ) 153 | parser.add_argument( 154 | "--print-network", 155 | action="store_true", 156 | default=False, 157 | help="Whether print newtork paramemters.", 158 | ) 159 | parser.add_argument( 160 | "--print-loss", 161 | action="store_true", 162 | default=False, 163 | help="Whether print losses during training.", 164 | ) 165 | parser.add_argument( 166 | "--save-image", 167 | type=int, 168 | default=100, 169 | help="Number to save images during evaluating, -1 to save all.", 170 | ) 171 | parser.add_argument( 172 | "-i", 173 | "--input", 174 | default=["rgb", "depth"], 175 | type=str, 176 | nargs="+", 177 | help="input type (image, depth)", 178 | ) 179 | 180 | # Optimisers 181 | parser.add_argument("--backbone", default="mit_b3", type=str) 182 | parser.add_argument("--n_heads", default=8, type=int) 183 | parser.add_argument("--drop_rate", default=0.0, type=float) 184 | parser.add_argument("--dpr", default=0.4, type=float) 185 | 186 | parser.add_argument("--weight_decay", default=0.01, type=float) 187 | parser.add_argument("--lr_0", default=6e-5, type=float) 188 | parser.add_argument("--lr_1", default=3e-5, type=float) 189 | 
parser.add_argument("--lr_2", default=1.5e-5, type=float) 190 | parser.add_argument("--is_pretrain_finetune", action="store_true") 191 | 192 | return parser.parse_args() 193 | 194 | 195 | def create_segmenter(num_classes, gpu, backbone, n_heads, dpr, drop_rate): 196 | segmenter = WeTr(backbone, num_classes, n_heads, dpr, drop_rate) 197 | param_groups = segmenter.get_param_groups() 198 | assert torch.cuda.is_available() 199 | segmenter.to("cuda:" + str(gpu)) 200 | return segmenter, param_groups 201 | 202 | 203 | def create_loaders( 204 | dataset, 205 | train_dir, 206 | val_dir, 207 | train_list, 208 | val_list, 209 | batch_size, 210 | num_workers, 211 | ignore_label, 212 | ): 213 | """ 214 | Args: 215 | train_dir (str) : path to the root directory of the training set. 216 | val_dir (str) : path to the root directory of the validation set. 217 | train_list (str) : path to the training list. 218 | val_list (str) : path to the validation list. 219 | batch_size (int) : training batch size. 220 | num_workers (int) : number of workers to parallelise data loading operations. 221 | ignore_label (int) : label to pad segmentation masks with 222 | 223 | Returns: 224 | train_loader, val loader 225 | 226 | """ 227 | # Torch libraries 228 | from torchvision import transforms 229 | from torch.utils.data import DataLoader 230 | 231 | # Custom libraries 232 | from utils.datasets import SegDataset as Dataset 233 | from utils.transforms import ToTensor 234 | 235 | input_names, input_mask_idxs = ["rgb", "depth"], [0, 2, 1] 236 | 237 | if dataset == "nyudv2": 238 | input_scale = [480, 640] 239 | elif dataset == "sunrgbd": 240 | input_scale = [480, 480] 241 | 242 | composed_trn = transforms.Compose( 243 | [ 244 | ToTensor(), 245 | RandomColorJitter(p=0.2), # 246 | RandomHorizontalFlip(p=0.5), # 247 | RandomGaussianBlur((3, 3), p=0.2), # 248 | RandomResizedCrop(input_scale, scale=(0.5, 2.0), seg_fill=255), # 249 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 250 | ] 251 | ) 252 | 253 | composed_val = transforms.Compose( 254 | [ 255 | ToTensor(), 256 | Resize(input_scale), 257 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 258 | ] 259 | ) 260 | # Training and validation sets 261 | trainset = Dataset( 262 | dataset=dataset, 263 | data_file=train_list, 264 | data_dir=train_dir, 265 | input_names=input_names, 266 | input_mask_idxs=input_mask_idxs, 267 | transform_trn=composed_trn, 268 | transform_val=composed_val, 269 | stage="train", 270 | ignore_label=ignore_label, 271 | ) 272 | 273 | validset = Dataset( 274 | dataset=dataset, 275 | data_file=val_list, 276 | data_dir=val_dir, 277 | input_names=input_names, 278 | input_mask_idxs=input_mask_idxs, 279 | transform_trn=None, 280 | transform_val=composed_val, 281 | stage="val", 282 | ignore_label=ignore_label, 283 | ) 284 | print_log( 285 | "Created train set {} examples, val set {} examples".format( 286 | len(trainset), len(validset) 287 | ) 288 | ) 289 | train_sampler = DistributedSampler( 290 | trainset, dist.get_world_size(), dist.get_rank(), shuffle=True 291 | ) 292 | 293 | # Training and validation loaders 294 | train_loader = DataLoader( 295 | trainset, 296 | batch_size=batch_size, 297 | num_workers=num_workers, 298 | pin_memory=True, 299 | drop_last=True, 300 | sampler=train_sampler, 301 | ) 302 | val_loader = DataLoader( 303 | validset, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True 304 | ) 305 | 306 | return train_loader, val_loader, train_sampler 307 | 308 | 309 | def load_ckpt(ckpt_path, ckpt_dict, 
is_pretrain_finetune=False): 310 | print("----------------") 311 | ckpt = torch.load(ckpt_path, map_location="cpu") 312 | new_segmenter_ckpt = dict() 313 | if is_pretrain_finetune: 314 | for ckpt_k, ckpt_v in ckpt["segmenter"].items(): 315 | if "linear_pred" in ckpt_k: 316 | print(ckpt_k, " is Excluded!") 317 | else: 318 | if "module." in ckpt_k: 319 | new_segmenter_ckpt[ckpt_k[7:]] = ckpt_v 320 | else: 321 | for ckpt_k, ckpt_v in ckpt["segmenter"].items(): 322 | new_segmenter_ckpt[ckpt_k] = ckpt_v 323 | if "module." in ckpt_k: 324 | new_segmenter_ckpt[ckpt_k[7:]] = ckpt_v 325 | ckpt["segmenter"] = new_segmenter_ckpt 326 | 327 | for k, v in ckpt_dict.items(): 328 | if k in ckpt: 329 | v.load_state_dict(ckpt[k], strict=False) 330 | else: 331 | print(v, " is missed!") 332 | best_val = ckpt.get("best_val", 0) 333 | epoch_start = ckpt.get("epoch_start", 0) 334 | if is_pretrain_finetune: 335 | print_log( 336 | "Found [Pretrain] checkpoint at {} with best_val {:.4f} at epoch {}".format( 337 | ckpt_path, best_val, epoch_start 338 | ) 339 | ) 340 | return 0, 0 341 | else: 342 | 343 | print_log( 344 | "Found checkpoint at {} with best_val {:.4f} at epoch {}".format( 345 | ckpt_path, best_val, epoch_start 346 | ) 347 | ) 348 | return best_val, epoch_start 349 | 350 | 351 | def train( 352 | segmenter, 353 | input_types, 354 | train_loader, 355 | optimizer, 356 | epoch, 357 | segm_crit, 358 | freeze_bn, 359 | print_loss=False, 360 | ): 361 | """Training segmenter 362 | 363 | Args: 364 | segmenter (nn.Module) : segmentation network 365 | train_loader (DataLoader) : training data iterator 366 | optim_enc (optim) : optimiser for encoder 367 | optim_dec (optim) : optimiser for decoder 368 | epoch (int) : current epoch 369 | segm_crit (nn.Loss) : segmentation criterion 370 | freeze_bn (bool) : whether to keep BN params intact 371 | 372 | """ 373 | train_loader.dataset.set_stage("train") 374 | segmenter.train() 375 | if freeze_bn: 376 | for module in segmenter.modules(): 377 | if isinstance(module, nn.BatchNorm2d): 378 | module.eval() 379 | batch_time = AverageMeter() 380 | losses = AverageMeter() 381 | 382 | for i, sample in tqdm(enumerate(train_loader), total=len(train_loader)): 383 | start = time.time() 384 | inputs = [sample[key].cuda().float() for key in input_types] 385 | target = sample["mask"].cuda().long() 386 | # Compute outputs 387 | outputs, masks = segmenter(inputs) 388 | loss = 0 389 | for output in outputs: 390 | output = nn.functional.interpolate( 391 | output, size=target.size()[1:], mode="bilinear", align_corners=False 392 | ) 393 | soft_output = nn.LogSoftmax()(output) 394 | # Compute loss and backpropagate 395 | loss += segm_crit(soft_output, target) 396 | 397 | optimizer.zero_grad() 398 | loss.backward() 399 | if print_loss: 400 | print("step: %-3d: loss=%.2f" % (i, loss), flush=True) 401 | optimizer.step() 402 | losses.update(loss.item()) 403 | batch_time.update(time.time() - start) 404 | 405 | 406 | def validate( 407 | segmenter, input_types, val_loader, epoch, save_dir, num_classes=-1, save_image=0 408 | ): 409 | """Validate segmenter 410 | 411 | Args: 412 | segmenter (nn.Module) : segmentation network 413 | val_loader (DataLoader) : training data iterator 414 | epoch (int) : current epoch 415 | num_classes (int) : number of classes to consider 416 | 417 | Returns: 418 | Mean IoU (float) 419 | """ 420 | global best_iou 421 | val_loader.dataset.set_stage("val") 422 | segmenter.eval() 423 | conf_mat = [] 424 | for _ in range(len(input_types) + 1): 425 | 
conf_mat.append(np.zeros((num_classes, num_classes), dtype=int)) 426 | with torch.no_grad(): 427 | all_times = 0 428 | count = 0 429 | for i, sample in enumerate(val_loader): 430 | inputs = [sample[key].float().cuda() for key in input_types] 431 | target = sample["mask"] 432 | gt = target[0].data.cpu().numpy().astype(np.uint8) 433 | gt_idx = ( 434 | gt < num_classes 435 | ) # Ignore every class index larger than the number of classes 436 | 437 | """from fvcore.nn import FlopCountAnalysis, parameter_count_table 438 | 439 | flops = FlopCountAnalysis(segmenter, inputs) 440 | print("FLOPs: ", flops.total()) 441 | print(parameter_count_table(segmenter)) 442 | exit()""" 443 | 444 | start_time = time.time() 445 | 446 | outputs, _ = segmenter(inputs) 447 | 448 | end_time = time.time() 449 | all_times += end_time - start_time 450 | 451 | for idx, output in enumerate(outputs): 452 | output = ( 453 | cv2.resize( 454 | output[0, :num_classes].data.cpu().numpy().transpose(1, 2, 0), 455 | target.size()[1:][::-1], 456 | interpolation=cv2.INTER_CUBIC, 457 | ) 458 | .argmax(axis=2) 459 | .astype(np.uint8) 460 | ) 461 | # Compute IoU 462 | conf_mat[idx] += confusion_matrix( 463 | gt[gt_idx], output[gt_idx], num_classes 464 | ) 465 | if i < save_image or save_image == -1: 466 | img = make_validation_img( 467 | inputs[0].data.cpu().numpy(), 468 | inputs[1].data.cpu().numpy(), 469 | sample["mask"].data.cpu().numpy(), 470 | output[np.newaxis, :], 471 | ) 472 | imgs_folder = os.path.join(save_dir, "imgs") 473 | os.makedirs(imgs_folder, exist_ok=True) 474 | cv2.imwrite( 475 | os.path.join(imgs_folder, "validate_" + str(i) + ".png"), 476 | img[:, :, ::-1], 477 | ) 478 | print("imwrite at imgs/validate_%d.png" % i) 479 | count += 1 480 | latency = all_times / count 481 | print("all_times:", all_times, " count:", count, " latency:", latency) 482 | 483 | for idx, input_type in enumerate(input_types + ["ens"]): 484 | glob, mean, iou = getScores(conf_mat[idx]) 485 | best_iou_note = "" 486 | if iou > best_iou: 487 | best_iou = iou 488 | best_iou_note = " (best)" 489 | alpha = " " 490 | 491 | input_type_str = "(%s)" % input_type 492 | print_log( 493 | "Epoch %-4d %-7s glob_acc=%-5.2f mean_acc=%-5.2f IoU=%-5.2f%s%s" 494 | % (epoch, input_type_str, glob, mean, iou, alpha, best_iou_note) 495 | ) 496 | print_log("") 497 | return iou 498 | 499 | 500 | def main(): 501 | global args, best_iou 502 | best_iou = 0 503 | args = get_arguments() 504 | args.val_dir = args.train_dir 505 | 506 | if args.dataset == "nyudv2": 507 | args.train_list = "data/nyudv2/train.txt" 508 | args.val_list = "data/nyudv2/val.txt" 509 | args.num_classes = 40 510 | elif args.dataset == "sunrgbd": 511 | args.train_list = "data/sun/train.txt" 512 | args.val_list = "data/sun/test.txt" 513 | args.num_classes = 37 514 | 515 | args.num_stages = 3 516 | gpu = setup_ddp() 517 | ckpt_dir = os.path.join("ckpt", args.ckpt) 518 | os.makedirs(ckpt_dir, exist_ok=True) 519 | os.system("cp -r *py models utils data %s" % ckpt_dir) 520 | helpers.logger = open(os.path.join(ckpt_dir, "log.txt"), "w+") 521 | print_log(" ".join(sys.argv)) 522 | 523 | # Set random seeds 524 | torch.backends.cudnn.deterministic = True 525 | torch.manual_seed(args.random_seed) 526 | if torch.cuda.is_available(): 527 | torch.cuda.manual_seed_all(args.random_seed) 528 | np.random.seed(args.random_seed) 529 | random.seed(args.random_seed) 530 | # Generate Segmenter 531 | segmenter, param_groups = create_segmenter( 532 | args.num_classes, 533 | gpu, 534 | args.backbone, 535 | args.n_heads, 536 | 
args.dpr, 537 | args.drop_rate, 538 | ) 539 | 540 | print_log( 541 | "Loaded Segmenter {}, #PARAMS={:3.2f}M".format( 542 | args.backbone, compute_params(segmenter) / 1e6 543 | ) 544 | ) 545 | # Restore if any 546 | best_val, epoch_start = 0, 0 547 | if args.resume: 548 | if os.path.isfile(args.resume): 549 | best_val, epoch_start = load_ckpt( 550 | args.resume, 551 | {"segmenter": segmenter}, 552 | is_pretrain_finetune=args.is_pretrain_finetune, 553 | ) 554 | else: 555 | print_log("=> no checkpoint found at '{}'".format(args.resume)) 556 | return 557 | no_ddp_segmenter = segmenter 558 | segmenter = DDP( 559 | segmenter, device_ids=[gpu], output_device=0, find_unused_parameters=False 560 | ) 561 | 562 | epoch_current = epoch_start 563 | # Criterion 564 | segm_crit = nn.NLLLoss(ignore_index=args.ignore_label).cuda() 565 | # Saver 566 | saver = Saver( 567 | args=vars(args), 568 | ckpt_dir=ckpt_dir, 569 | best_val=best_val, 570 | condition=lambda x, y: x > y, 571 | ) # keep checkpoint with the best validation score 572 | 573 | lrs = [args.lr_0, args.lr_1, args.lr_2] 574 | 575 | print("-------------------------Optimizer Params--------------------") 576 | print("weight_decay:", args.weight_decay) 577 | print("lrs:", lrs) 578 | print("----------------------------------------------------------------") 579 | 580 | for task_idx in range(args.num_stages): 581 | optimizer = PolyWarmupAdamW( 582 | # encoder,encoder-norm,decoder 583 | params=[ 584 | { 585 | "params": param_groups[0], 586 | "lr": lrs[task_idx], 587 | "weight_decay": args.weight_decay, 588 | }, 589 | { 590 | "params": param_groups[1], 591 | "lr": lrs[task_idx], 592 | "weight_decay": 0.0, 593 | }, 594 | { 595 | "params": param_groups[2], 596 | "lr": lrs[task_idx] * 10, 597 | "weight_decay": args.weight_decay, 598 | }, 599 | ], 600 | lr=lrs[task_idx], 601 | weight_decay=args.weight_decay, 602 | betas=[0.9, 0.999], 603 | warmup_iter=1500, 604 | max_iter=40000, 605 | warmup_ratio=1e-6, 606 | power=1.0, 607 | ) 608 | total_epoch = sum([args.num_epoch[idx] for idx in range(task_idx + 1)]) 609 | if epoch_start >= total_epoch: 610 | continue 611 | start = time.time() 612 | torch.cuda.empty_cache() 613 | # Create dataloaders 614 | train_loader, val_loader, train_sampler = create_loaders( 615 | args.dataset, 616 | args.train_dir, 617 | args.val_dir, 618 | args.train_list, 619 | args.val_list, 620 | args.batch_size, 621 | args.num_workers, 622 | args.ignore_label, 623 | ) 624 | if args.evaluate: 625 | return validate( 626 | no_ddp_segmenter, 627 | args.input, 628 | val_loader, 629 | 0, 630 | ckpt_dir, 631 | num_classes=args.num_classes, 632 | save_image=args.save_image, 633 | ) 634 | 635 | # Optimisers 636 | print_log("Training Stage {}".format(str(task_idx))) 637 | 638 | for epoch in range(min(args.num_epoch[task_idx], total_epoch - epoch_start)): 639 | train_sampler.set_epoch(epoch) 640 | train( 641 | segmenter, 642 | args.input, 643 | train_loader, 644 | optimizer, 645 | epoch_current, 646 | segm_crit, 647 | args.freeze_bn, 648 | args.print_loss, 649 | ) 650 | if (epoch + 1) % (args.val_every) == 0: 651 | miou = validate( 652 | no_ddp_segmenter, 653 | args.input, 654 | val_loader, 655 | epoch_current, 656 | ckpt_dir, 657 | args.num_classes, 658 | ) 659 | saver.save( 660 | miou, 661 | {"segmenter": segmenter.state_dict(), "epoch_start": epoch_current}, 662 | ) 663 | epoch_current += 1 664 | 665 | print_log( 666 | "Stage {} finished, time spent {:.3f}min\n".format( 667 | task_idx, (time.time() - start) / 60.0 668 | ) 669 | ) 670 | 671 | 
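# Each of the num_stages (= 3) passes restarts PolyWarmupAdamW with a smaller base LR (by default lr_0 > lr_1 > lr_2), while epoch_current keeps counting across stages so validation logs and saved checkpoints share one epoch axis.
# A typical single-node launch (assumed example; adjust GPU count and paths to your setup): torchrun --nproc_per_node=4 main.py --dataset nyudv2 --train-dir /path/to/nyudv2 --backbone mit_b3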
print_log("All stages are now finished. Best Val is {:.3f}".format(saver.best_val)) 672 | helpers.logger.close() 673 | cleanup_ddp() 674 | 675 | 676 | if __name__ == "__main__": 677 | main() 678 | -------------------------------------------------------------------------------- /models/mix_transformer.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # --------------------------------------------------------------- 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | from .modules import ModuleParallel, LayerNormParallel 12 | 13 | 14 | class Mlp(nn.Module): 15 | def __init__( 16 | self, 17 | in_features, 18 | hidden_features=None, 19 | out_features=None, 20 | act_layer=nn.GELU, 21 | drop=0.0, 22 | ): 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) 27 | self.dwconv = DWConv(hidden_features) 28 | self.act = ModuleParallel(act_layer()) 29 | self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) 30 | self.drop = ModuleParallel(nn.Dropout(drop)) 31 | 32 | self.apply(self._init_weights) 33 | 34 | def _init_weights(self, m): 35 | if isinstance(m, nn.Linear): 36 | trunc_normal_(m.weight, std=0.02) 37 | if isinstance(m, nn.Linear) and m.bias is not None: 38 | nn.init.constant_(m.bias, 0) 39 | elif isinstance(m, nn.LayerNorm): 40 | nn.init.constant_(m.bias, 0) 41 | nn.init.constant_(m.weight, 1.0) 42 | elif isinstance(m, nn.Conv2d): 43 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 44 | fan_out //= m.groups 45 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 46 | if m.bias is not None: 47 | m.bias.data.zero_() 48 | 49 | def forward(self, x, H, W): 50 | x = self.fc1(x) 51 | x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] 52 | x = self.act(x) 53 | x = self.drop(x) 54 | x = self.fc2(x) 55 | x = self.drop(x) 56 | return x 57 | 58 | 59 | class Mlp_2(nn.Module): 60 | """Multilayer perceptron.""" 61 | 62 | def __init__( 63 | self, 64 | in_features, 65 | hidden_features=None, 66 | out_features=None, 67 | act_layer=nn.GELU, 68 | drop=0.0, 69 | ): 70 | super().__init__() 71 | out_features = out_features or in_features 72 | hidden_features = hidden_features or in_features 73 | self.fc1 = nn.Linear(in_features, hidden_features) 74 | self.act = act_layer() 75 | self.fc2 = nn.Linear(hidden_features, out_features) 76 | self.drop = nn.Dropout(drop) 77 | 78 | def forward(self, x): 79 | x = self.fc1(x) 80 | x = self.act(x) 81 | x = self.drop(x) 82 | x = self.fc2(x) 83 | x = self.drop(x) 84 | return x 85 | 86 | 87 | class Attention(nn.Module): 88 | def __init__( 89 | self, 90 | dim, 91 | num_heads=8, 92 | qkv_bias=False, 93 | qk_scale=None, 94 | attn_drop=0.0, 95 | proj_drop=0.0, 96 | sr_ratio=1, 97 | n_heads=8, 98 | ): 99 | super().__init__() 100 | assert ( 101 | dim % num_heads == 0 102 | ), f"dim {dim} should be divided by num_heads {num_heads}." 
103 | 104 | self.dim = dim 105 | self.num_heads = num_heads 106 | head_dim = dim // num_heads 107 | self.scale = qk_scale or head_dim**-0.5 108 | 109 | self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) 110 | self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) 111 | self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) 112 | self.proj = ModuleParallel(nn.Linear(dim, dim)) 113 | self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) 114 | 115 | self.sr_ratio = sr_ratio 116 | if sr_ratio > 1: 117 | self.sr = ModuleParallel( 118 | nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 119 | ) 120 | self.norm = LayerNormParallel(dim) 121 | 122 | self.cross_heads = n_heads 123 | self.cross_attn_0_to_1 = nn.MultiheadAttention( 124 | dim, self.cross_heads, dropout=0.0, batch_first=False 125 | ) 126 | self.cross_attn_1_to_0 = nn.MultiheadAttention( 127 | dim, self.cross_heads, dropout=0.0, batch_first=False 128 | ) 129 | 130 | self.relation_judger = nn.Sequential( 131 | Mlp_2(dim * 2, dim, dim), torch.nn.Softmax(dim=-1) 132 | ) 133 | 134 | self.k_noise = nn.Embedding(2, dim) 135 | self.v_noise = nn.Embedding(2, dim) 136 | 137 | self.apply(self._init_weights) 138 | 139 | def _init_weights(self, m): 140 | if isinstance(m, nn.Linear): 141 | trunc_normal_(m.weight, std=0.02) 142 | if isinstance(m, nn.Linear) and m.bias is not None: 143 | nn.init.constant_(m.bias, 0) 144 | elif isinstance(m, nn.LayerNorm): 145 | nn.init.constant_(m.bias, 0) 146 | nn.init.constant_(m.weight, 1.0) 147 | elif isinstance(m, nn.Conv2d): 148 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 149 | fan_out //= m.groups 150 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 151 | if m.bias is not None: 152 | m.bias.data.zero_() 153 | 154 | def forward( 155 | self, 156 | x, 157 | H, 158 | W, 159 | ): 160 | B, N, C = x[0].shape 161 | q = self.q(x) 162 | q = [ 163 | q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 164 | for q_ in q 165 | ] 166 | 167 | if self.sr_ratio > 1: 168 | x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] 169 | x = self.sr(x) 170 | x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] 171 | x = self.norm(x) 172 | kv = self.kv(x) 173 | kv = [ 174 | kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute( 175 | 2, 0, 3, 1, 4 176 | ) 177 | for kv_ in kv 178 | ] 179 | else: 180 | kv = self.kv(x) 181 | kv = [ 182 | kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute( 183 | 2, 0, 3, 1, 4 184 | ) 185 | for kv_ in kv 186 | ] 187 | k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] 188 | 189 | attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] 190 | attn = [attn_.softmax(dim=-1) for attn_ in attn] 191 | attn = self.attn_drop(attn) 192 | 193 | x = [ 194 | (attn_ @ v_).transpose(1, 2).reshape(B, N, C) 195 | for (attn_, v_) in zip(attn, v) 196 | ] 197 | 198 | # cross-attn per batch 199 | new_x0 = [] 200 | new_x1 = [] 201 | for bs in range(B): 202 | ## 1. 
0_to_1 cross attn and skip connect 203 | q = x[0][bs].unsqueeze(0) 204 | 205 | judger_input = torch.cat( 206 | [x[0][bs].unsqueeze(0), x[1][bs].unsqueeze(0)], dim=-1 207 | ) 208 | 209 | relation_score = self.relation_judger(judger_input) 210 | 211 | noise_k = self.k_noise.weight[0] + q 212 | noise_v = self.v_noise.weight[0] + q 213 | 214 | k = torch.cat([noise_k, torch.mul(q, relation_score)], dim=0) 215 | v = torch.cat([noise_v, x[1][bs].unsqueeze(0)], dim=0) 216 | 217 | new_x0.append(x[0][bs] + self.cross_attn_0_to_1(q, k, v)[0].squeeze(0)) 218 | 219 | ## 2. 1_to_0 cross attn and skip connect 220 | q = x[1][bs].unsqueeze(0) 221 | 222 | judger_input = torch.cat( 223 | [x[1][bs].unsqueeze(0), x[0][bs].unsqueeze(0)], dim=-1 224 | ) 225 | 226 | relation_score = self.relation_judger(judger_input) 227 | 228 | noise_k = self.k_noise.weight[1] + q 229 | noise_v = self.v_noise.weight[1] + q 230 | 231 | k = torch.cat([noise_k, torch.mul(q, relation_score)], dim=0) 232 | v = torch.cat([noise_v, x[0][bs].unsqueeze(0)], dim=0) 233 | 234 | new_x1.append(x[1][bs] + self.cross_attn_1_to_0(q, k, v)[0].squeeze(0)) 235 | 236 | new_x0 = torch.stack(new_x0) 237 | new_x1 = torch.stack(new_x1) 238 | x[0] = new_x0 239 | x[1] = new_x1 240 | 241 | x = self.proj(x) 242 | x = self.proj_drop(x) 243 | 244 | return x 245 | 246 | 247 | class Block(nn.Module): 248 | def __init__( 249 | self, 250 | dim, 251 | num_heads, 252 | mlp_ratio=4.0, 253 | qkv_bias=False, 254 | qk_scale=None, 255 | drop=0.0, 256 | attn_drop=0.0, 257 | drop_path=0.0, 258 | act_layer=nn.GELU, 259 | norm_layer=LayerNormParallel, 260 | sr_ratio=1, 261 | n_heads=8, 262 | ): 263 | super().__init__() 264 | self.norm1 = norm_layer(dim) 265 | 266 | self.attn = Attention( 267 | dim, 268 | num_heads=num_heads, 269 | qkv_bias=qkv_bias, 270 | qk_scale=qk_scale, 271 | attn_drop=attn_drop, 272 | proj_drop=drop, 273 | sr_ratio=sr_ratio, 274 | n_heads=n_heads, 275 | ) 276 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 277 | self.drop_path = ( 278 | ModuleParallel(DropPath(drop_path)) 279 | if drop_path > 0.0 280 | else ModuleParallel(nn.Identity()) 281 | ) 282 | self.norm2 = norm_layer(dim) 283 | mlp_hidden_dim = int(dim * mlp_ratio) 284 | self.mlp = Mlp( 285 | in_features=dim, 286 | hidden_features=mlp_hidden_dim, 287 | act_layer=act_layer, 288 | drop=drop, 289 | ) 290 | 291 | self.apply(self._init_weights) 292 | 293 | def _init_weights(self, m): 294 | if isinstance(m, nn.Linear): 295 | trunc_normal_(m.weight, std=0.02) 296 | if isinstance(m, nn.Linear) and m.bias is not None: 297 | nn.init.constant_(m.bias, 0) 298 | elif isinstance(m, nn.LayerNorm): 299 | nn.init.constant_(m.bias, 0) 300 | nn.init.constant_(m.weight, 1.0) 301 | elif isinstance(m, nn.Conv2d): 302 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 303 | fan_out //= m.groups 304 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 305 | if m.bias is not None: 306 | m.bias.data.zero_() 307 | 308 | def forward(self, x, H, W): 309 | B = x[0].shape[0] 310 | 311 | f = self.drop_path( 312 | self.attn( 313 | self.norm1(x), 314 | H, 315 | W, 316 | ) 317 | ) 318 | x = [x_ + f_ for (x_, f_) in zip(x, f)] 319 | f = self.drop_path(self.mlp(self.norm2(x), H, W)) 320 | x = [x_ + f_ for (x_, f_) in zip(x, f)] 321 | 322 | return x 323 | 324 | 325 | class OverlapPatchEmbed(nn.Module): 326 | """Image to Patch Embedding""" 327 | 328 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 329 | super().__init__() 330 | img_size = 
to_2tuple(img_size) 331 | patch_size = to_2tuple(patch_size) 332 | 333 | self.img_size = img_size 334 | self.patch_size = patch_size 335 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 336 | self.num_patches = self.H * self.W 337 | self.proj = ModuleParallel( 338 | nn.Conv2d( 339 | in_chans, 340 | embed_dim, 341 | kernel_size=patch_size, 342 | stride=stride, 343 | padding=(patch_size[0] // 2, patch_size[1] // 2), 344 | ) 345 | ) 346 | self.norm = LayerNormParallel(embed_dim) 347 | 348 | self.apply(self._init_weights) 349 | 350 | def _init_weights(self, m): 351 | if isinstance(m, nn.Linear): 352 | trunc_normal_(m.weight, std=0.02) 353 | if isinstance(m, nn.Linear) and m.bias is not None: 354 | nn.init.constant_(m.bias, 0) 355 | elif isinstance(m, nn.LayerNorm): 356 | nn.init.constant_(m.bias, 0) 357 | nn.init.constant_(m.weight, 1.0) 358 | elif isinstance(m, nn.Conv2d): 359 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 360 | fan_out //= m.groups 361 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 362 | if m.bias is not None: 363 | m.bias.data.zero_() 364 | 365 | def forward(self, x): 366 | x = self.proj(x) 367 | _, _, H, W = x[0].shape 368 | x = [x_.flatten(2).transpose(1, 2) for x_ in x] 369 | x = self.norm(x) 370 | return x, H, W 371 | 372 | 373 | class MixVisionTransformer(nn.Module): 374 | def __init__( 375 | self, 376 | img_size=224, 377 | patch_size=16, 378 | in_chans=3, 379 | num_classes=1000, 380 | embed_dims=[64, 128, 256, 512], 381 | num_heads=[1, 2, 4, 8], 382 | mlp_ratios=[4, 4, 4, 4], 383 | qkv_bias=False, 384 | qk_scale=None, 385 | drop_rate=0.0, 386 | attn_drop_rate=0.0, 387 | drop_path_rate=0.0, 388 | norm_layer=LayerNormParallel, 389 | depths=[3, 4, 6, 3], 390 | sr_ratios=[8, 4, 2, 1], 391 | n_heads=8, 392 | ): 393 | super().__init__() 394 | 395 | self.num_classes = num_classes 396 | self.depths = depths 397 | self.embed_dims = embed_dims 398 | 399 | # patch_embed 400 | self.patch_embed1 = OverlapPatchEmbed( 401 | img_size=img_size, 402 | patch_size=7, 403 | stride=4, 404 | in_chans=in_chans, 405 | embed_dim=embed_dims[0], 406 | ) 407 | self.patch_embed2 = OverlapPatchEmbed( 408 | img_size=img_size // 4, 409 | patch_size=3, 410 | stride=2, 411 | in_chans=embed_dims[0], 412 | embed_dim=embed_dims[1], 413 | ) 414 | self.patch_embed3 = OverlapPatchEmbed( 415 | img_size=img_size // 8, 416 | patch_size=3, 417 | stride=2, 418 | in_chans=embed_dims[1], 419 | embed_dim=embed_dims[2], 420 | ) 421 | self.patch_embed4 = OverlapPatchEmbed( 422 | img_size=img_size // 16, 423 | patch_size=3, 424 | stride=2, 425 | in_chans=embed_dims[2], 426 | embed_dim=embed_dims[3], 427 | ) 428 | 429 | # transformer encoder 430 | dpr = [ 431 | x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) 432 | ] # stochastic depth decay rule 433 | cur = 0 434 | self.block1 = nn.ModuleList( 435 | [ 436 | Block( 437 | dim=embed_dims[0], 438 | num_heads=num_heads[0], 439 | mlp_ratio=mlp_ratios[0], 440 | qkv_bias=qkv_bias, 441 | qk_scale=qk_scale, 442 | drop=drop_rate, 443 | attn_drop=attn_drop_rate, 444 | drop_path=dpr[cur + i], 445 | norm_layer=norm_layer, 446 | sr_ratio=sr_ratios[0], 447 | n_heads=n_heads, 448 | ) 449 | for i in range(depths[0]) 450 | ] 451 | ) 452 | self.norm1 = norm_layer(embed_dims[0]) 453 | 454 | cur += depths[0] 455 | self.block2 = nn.ModuleList( 456 | [ 457 | Block( 458 | dim=embed_dims[1], 459 | num_heads=num_heads[1], 460 | mlp_ratio=mlp_ratios[1], 461 | qkv_bias=qkv_bias, 462 | qk_scale=qk_scale, 463 | drop=drop_rate, 464 
| attn_drop=attn_drop_rate, 465 | drop_path=dpr[cur + i], 466 | norm_layer=norm_layer, 467 | sr_ratio=sr_ratios[1], 468 | n_heads=n_heads, 469 | ) 470 | for i in range(depths[1]) 471 | ] 472 | ) 473 | self.norm2 = norm_layer(embed_dims[1]) 474 | 475 | cur += depths[1] 476 | self.block3 = nn.ModuleList( 477 | [ 478 | Block( 479 | dim=embed_dims[2], 480 | num_heads=num_heads[2], 481 | mlp_ratio=mlp_ratios[2], 482 | qkv_bias=qkv_bias, 483 | qk_scale=qk_scale, 484 | drop=drop_rate, 485 | attn_drop=attn_drop_rate, 486 | drop_path=dpr[cur + i], 487 | norm_layer=norm_layer, 488 | sr_ratio=sr_ratios[2], 489 | n_heads=n_heads, 490 | ) 491 | for i in range(depths[2]) 492 | ] 493 | ) 494 | self.norm3 = norm_layer(embed_dims[2]) 495 | 496 | cur += depths[2] 497 | self.block4 = nn.ModuleList( 498 | [ 499 | Block( 500 | dim=embed_dims[3], 501 | num_heads=num_heads[3], 502 | mlp_ratio=mlp_ratios[3], 503 | qkv_bias=qkv_bias, 504 | qk_scale=qk_scale, 505 | drop=drop_rate, 506 | attn_drop=attn_drop_rate, 507 | drop_path=dpr[cur + i], 508 | norm_layer=norm_layer, 509 | sr_ratio=sr_ratios[3], 510 | n_heads=n_heads, 511 | ) 512 | for i in range(depths[3]) 513 | ] 514 | ) 515 | self.norm4 = norm_layer(embed_dims[3]) 516 | 517 | self.apply(self._init_weights) 518 | 519 | def _init_weights(self, m): 520 | if isinstance(m, nn.Linear): 521 | trunc_normal_(m.weight, std=0.02) 522 | if isinstance(m, nn.Linear) and m.bias is not None: 523 | nn.init.constant_(m.bias, 0) 524 | elif isinstance(m, nn.LayerNorm): 525 | nn.init.constant_(m.bias, 0) 526 | nn.init.constant_(m.weight, 1.0) 527 | elif isinstance(m, nn.Conv2d): 528 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 529 | fan_out //= m.groups 530 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 531 | if m.bias is not None: 532 | m.bias.data.zero_() 533 | 534 | def reset_drop_path(self, drop_path_rate): 535 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 536 | cur = 0 537 | for i in range(self.depths[0]): 538 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 539 | 540 | cur += self.depths[0] 541 | for i in range(self.depths[1]): 542 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 543 | 544 | cur += self.depths[1] 545 | for i in range(self.depths[2]): 546 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 547 | 548 | cur += self.depths[2] 549 | for i in range(self.depths[3]): 550 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 551 | 552 | def freeze_patch_emb(self): 553 | self.patch_embed1.requires_grad = False 554 | 555 | @torch.jit.ignore 556 | def no_weight_decay(self): 557 | return { 558 | "pos_embed1", 559 | "pos_embed2", 560 | "pos_embed3", 561 | "pos_embed4", 562 | "cls_token", 563 | } # has pos_embed may be better 564 | 565 | def get_classifier(self): 566 | return self.head 567 | 568 | def reset_classifier(self, num_classes, global_pool=""): 569 | self.num_classes = num_classes 570 | self.head = ( 571 | nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 572 | ) 573 | 574 | def forward_features(self, x): 575 | B = x[0].shape[0] 576 | outs0, outs1 = [], [] 577 | 578 | # stage 1 579 | x, H, W = self.patch_embed1(x) 580 | for i, blk in enumerate(self.block1): 581 | 582 | x = blk( 583 | x, 584 | H, 585 | W, 586 | ) 587 | x = self.norm1(x) 588 | x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] 589 | outs0.append(x[0]) 590 | outs1.append(x[1]) 591 | 592 | # stage 2 593 | x, H, W = self.patch_embed2(x) 594 | for i, blk in enumerate(self.block2): 595 | 
596 | x = blk( 597 | x, 598 | H, 599 | W, 600 | ) 601 | x = self.norm2(x) 602 | x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] 603 | outs0.append(x[0]) 604 | outs1.append(x[1]) 605 | 606 | # stage 3 607 | x, H, W = self.patch_embed3(x) 608 | for i, blk in enumerate(self.block3): 609 | 610 | x = blk( 611 | x, 612 | H, 613 | W, 614 | ) 615 | x = self.norm3(x) 616 | x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] 617 | outs0.append(x[0]) 618 | outs1.append(x[1]) 619 | 620 | # stage 4 621 | x, H, W = self.patch_embed4(x) 622 | for i, blk in enumerate(self.block4): 623 | 624 | x = blk( 625 | x, 626 | H, 627 | W, 628 | ) 629 | x = self.norm4(x) 630 | x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] 631 | outs0.append(x[0]) 632 | outs1.append(x[1]) 633 | 634 | return [outs0, outs1] 635 | 636 | def forward(self, x): 637 | x = self.forward_features(x) 638 | return x 639 | 640 | 641 | class DWConv(nn.Module): 642 | def __init__(self, dim=768): 643 | super(DWConv, self).__init__() 644 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 645 | 646 | def forward(self, x, H, W): 647 | B, N, C = x.shape 648 | x = x.transpose(1, 2).view(B, C, H, W) 649 | x = self.dwconv(x) 650 | x = x.flatten(2).transpose(1, 2) 651 | 652 | return x 653 | 654 | 655 | class mit_b0(MixVisionTransformer): 656 | def __init__(self, n_heads, dpr, drop_rate, **kwargs): 657 | super(mit_b0, self).__init__( 658 | patch_size=4, 659 | embed_dims=[32, 64, 160, 256], 660 | num_heads=[1, 2, 5, 8], 661 | mlp_ratios=[4, 4, 4, 4], 662 | qkv_bias=True, 663 | norm_layer=LayerNormParallel, 664 | depths=[2, 2, 2, 2], 665 | sr_ratios=[8, 4, 2, 1], 666 | drop_rate=drop_rate, 667 | drop_path_rate=dpr, 668 | n_heads=n_heads, 669 | ) 670 | 671 | 672 | class mit_b1(MixVisionTransformer): 673 | def __init__(self, n_heads, dpr, drop_rate, **kwargs): 674 | super(mit_b1, self).__init__( 675 | patch_size=4, 676 | embed_dims=[64, 128, 320, 512], 677 | num_heads=[1, 2, 5, 8], 678 | mlp_ratios=[4, 4, 4, 4], 679 | qkv_bias=True, 680 | norm_layer=LayerNormParallel, 681 | depths=[2, 2, 2, 2], 682 | sr_ratios=[8, 4, 2, 1], 683 | drop_rate=drop_rate, 684 | drop_path_rate=dpr, 685 | n_heads=n_heads, 686 | ) 687 | 688 | 689 | class mit_b2(MixVisionTransformer): 690 | def __init__(self, n_heads, dpr, drop_rate, **kwargs): 691 | super(mit_b2, self).__init__( 692 | patch_size=4, 693 | embed_dims=[64, 128, 320, 512], 694 | num_heads=[1, 2, 5, 8], 695 | mlp_ratios=[4, 4, 4, 4], 696 | qkv_bias=True, 697 | norm_layer=LayerNormParallel, 698 | depths=[3, 4, 6, 3], 699 | sr_ratios=[8, 4, 2, 1], 700 | drop_rate=drop_rate, 701 | drop_path_rate=dpr, 702 | n_heads=n_heads, 703 | ) 704 | 705 | 706 | class mit_b3(MixVisionTransformer): 707 | def __init__(self, n_heads, dpr, drop_rate, **kwargs): 708 | super(mit_b3, self).__init__( 709 | patch_size=4, 710 | embed_dims=[64, 128, 320, 512], 711 | num_heads=[1, 2, 5, 8], 712 | mlp_ratios=[4, 4, 4, 4], 713 | qkv_bias=True, 714 | norm_layer=LayerNormParallel, 715 | depths=[3, 4, 18, 3], 716 | sr_ratios=[8, 4, 2, 1], 717 | drop_rate=drop_rate, 718 | drop_path_rate=dpr, 719 | n_heads=n_heads, 720 | ) 721 | 722 | 723 | class mit_b4(MixVisionTransformer): 724 | def __init__(self, n_heads, dpr, drop_rate, **kwargs): 725 | super(mit_b4, self).__init__( 726 | patch_size=4, 727 | embed_dims=[64, 128, 320, 512], 728 | num_heads=[1, 2, 5, 8], 729 | mlp_ratios=[4, 4, 4, 4], 730 | qkv_bias=True, 731 | norm_layer=LayerNormParallel, 732 | depths=[3, 8, 27, 
3], 733 | sr_ratios=[8, 4, 2, 1], 734 | drop_rate=drop_rate, 735 | drop_path_rate=dpr, 736 | n_heads=n_heads, 737 | ) 738 | 739 | 740 | class mit_b5(MixVisionTransformer): 741 | def __init__(self, n_heads, dpr, drop_rate, **kwargs): 742 | super(mit_b5, self).__init__( 743 | patch_size=4, 744 | embed_dims=[64, 128, 320, 512], 745 | num_heads=[1, 2, 5, 8], 746 | mlp_ratios=[4, 4, 4, 4], 747 | qkv_bias=True, 748 | norm_layer=LayerNormParallel, 749 | depths=[3, 6, 40, 3], 750 | sr_ratios=[8, 4, 2, 1], 751 | drop_rate=drop_rate, 752 | drop_path_rate=dpr, 753 | n_heads=n_heads, 754 | ) 755 | -------------------------------------------------------------------------------- /models/swin_transformer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu, Yutong Lin, Yixuan Wei 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as checkpoint 12 | import numpy as np 13 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 14 | import math 15 | from mmcv_custom import load_checkpoint 16 | from mmdet.utils import get_root_logger 17 | from .modules import ( 18 | ModuleParallel, 19 | Additional_One_ModuleParallel, 20 | LayerNormParallel, 21 | Additional_Two_ModuleParallel, 22 | ) 23 | 24 | 25 | class Mlp(nn.Module): 26 | """Multilayer perceptron.""" 27 | 28 | def __init__( 29 | self, 30 | in_features, 31 | hidden_features=None, 32 | out_features=None, 33 | act_layer=ModuleParallel(nn.GELU()), 34 | drop=0.0, 35 | ): 36 | super().__init__() 37 | out_features = out_features or in_features 38 | hidden_features = hidden_features or in_features 39 | self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) 40 | self.act = act_layer 41 | self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) 42 | self.drop = ModuleParallel(nn.Dropout(drop)) 43 | 44 | def forward(self, x): 45 | x = self.fc1(x) 46 | x = self.act(x) 47 | x = self.drop(x) 48 | x = self.fc2(x) 49 | x = self.drop(x) 50 | return x 51 | 52 | 53 | def window_partition(x, window_size): 54 | """ 55 | Args: 56 | x: (B, H, W, C) 57 | window_size (int): window size 58 | 59 | Returns: 60 | windows: (num_windows*B, window_size, window_size, C) 61 | """ 62 | B, H, W, C = x.shape 63 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 64 | windows = ( 65 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 66 | ) 67 | return windows 68 | 69 | 70 | def window_reverse(windows, window_size, H, W): 71 | """ 72 | Args: 73 | windows: (num_windows*B, window_size, window_size, C) 74 | window_size (int): Window size 75 | H (int): Height of image 76 | W (int): Width of image 77 | 78 | Returns: 79 | x: (B, H, W, C) 80 | """ 81 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 82 | x = windows.view( 83 | B, H // window_size, W // window_size, window_size, window_size, -1 84 | ) 85 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 86 | return x 87 | 88 | 89 | class WindowAttention(nn.Module): 90 | """Window based multi-head self attention (W-MSA) module with relative position bias. 91 | It supports both of shifted and non-shifted window. 
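The attention logits are offset by a learned relative position bias, looked up from a (2*Wh-1)*(2*Ww-1) x num_heads table through a precomputed relative_position_index buffer.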
92 | 93 | Args: 94 | dim (int): Number of input channels. 95 | window_size (tuple[int]): The height and width of the window. 96 | num_heads (int): Number of attention heads. 97 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 98 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 99 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 100 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 101 | """ 102 | 103 | def __init__( 104 | self, 105 | dim, 106 | window_size, 107 | num_heads, 108 | qkv_bias=True, 109 | qk_scale=None, 110 | attn_drop=0.0, 111 | proj_drop=0.0, 112 | ): 113 | 114 | super().__init__() 115 | self.dim = dim 116 | self.window_size = window_size # Wh, Ww 117 | self.num_heads = num_heads 118 | head_dim = dim // num_heads 119 | self.scale = qk_scale or head_dim**-0.5 120 | 121 | # define a parameter table of relative position bias 122 | self.relative_position_bias_table = nn.Parameter( 123 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) 124 | ) # 2*Wh-1 * 2*Ww-1, nH 125 | 126 | # get pair-wise relative position index for each token inside the window 127 | coords_h = torch.arange(self.window_size[0]) 128 | coords_w = torch.arange(self.window_size[1]) 129 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 130 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 131 | relative_coords = ( 132 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 133 | ) # 2, Wh*Ww, Wh*Ww 134 | relative_coords = relative_coords.permute( 135 | 1, 2, 0 136 | ).contiguous() # Wh*Ww, Wh*Ww, 2 137 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 138 | relative_coords[:, :, 1] += self.window_size[1] - 1 139 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 140 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 141 | self.register_buffer("relative_position_index", relative_position_index) 142 | 143 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 144 | self.attn_drop = nn.Dropout(attn_drop) 145 | self.proj = nn.Linear(dim, dim) 146 | self.proj_drop = nn.Dropout(proj_drop) 147 | 148 | trunc_normal_(self.relative_position_bias_table, std=0.02) 149 | self.softmax = nn.Softmax(dim=-1) 150 | 151 | def forward(self, x, mask=None): 152 | """Forward function. 
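Self-attention is computed within each window; the relative position bias is added to the attention logits, and the (0/-inf) shift mask, when given, blocks attention between tokens from different pre-shift regions. The output keeps the input shape (num_windows*B, N, C).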
153 | 154 | Args: 155 | x: input features with shape of (num_windows*B, N, C) 156 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 157 | """ 158 | B_, N, C = x.shape 159 | qkv = ( 160 | self.qkv(x) 161 | .reshape(B_, N, 3, self.num_heads, C // self.num_heads) 162 | .permute(2, 0, 3, 1, 4) 163 | ) 164 | q, k, v = ( 165 | qkv[0], 166 | qkv[1], 167 | qkv[2], 168 | ) # make torchscript happy (cannot use tensor as tuple) 169 | 170 | q = q * self.scale 171 | attn = q @ k.transpose(-2, -1) 172 | 173 | relative_position_bias = self.relative_position_bias_table[ 174 | self.relative_position_index.view(-1) 175 | ].view( 176 | self.window_size[0] * self.window_size[1], 177 | self.window_size[0] * self.window_size[1], 178 | -1, 179 | ) # Wh*Ww,Wh*Ww,nH 180 | relative_position_bias = relative_position_bias.permute( 181 | 2, 0, 1 182 | ).contiguous() # nH, Wh*Ww, Wh*Ww 183 | attn = attn + relative_position_bias.unsqueeze(0) 184 | 185 | if mask is not None: 186 | 187 | nW = mask.shape[0] 188 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( 189 | 1 190 | ).unsqueeze(0) 191 | 192 | attn = attn.view(-1, self.num_heads, N, N) 193 | attn = self.softmax(attn) 194 | else: 195 | attn = self.softmax(attn) 196 | 197 | attn = self.attn_drop(attn) 198 | 199 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 200 | x = self.proj(x) 201 | x = self.proj_drop(x) 202 | return x 203 | 204 | 205 | class Mlp_2(nn.Module): 206 | """Multilayer perceptron.""" 207 | 208 | def __init__( 209 | self, 210 | in_features, 211 | hidden_features=None, 212 | out_features=None, 213 | act_layer=nn.GELU, 214 | drop=0.0, 215 | ): 216 | super().__init__() 217 | out_features = out_features or in_features 218 | hidden_features = hidden_features or in_features 219 | self.fc1 = nn.Linear(in_features, hidden_features) 220 | self.act = act_layer() 221 | self.fc2 = nn.Linear(hidden_features, out_features) 222 | self.drop = nn.Dropout(drop) 223 | 224 | def forward(self, x): 225 | x = self.fc1(x) 226 | x = self.act(x) 227 | x = self.drop(x) 228 | x = self.fc2(x) 229 | x = self.drop(x) 230 | return x 231 | 232 | 233 | class SwinTransformerBlock(nn.Module): 234 | """Swin Transformer Block. 235 | 236 | Args: 237 | dim (int): Number of input channels. 238 | num_heads (int): Number of attention heads. 239 | window_size (int): Window size. 240 | shift_size (int): Shift size for SW-MSA. 241 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 242 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 243 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 244 | drop (float, optional): Dropout rate. Default: 0.0 245 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 246 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 247 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 248 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 249 | """ 250 | 251 | def __init__( 252 | self, 253 | dim, 254 | num_heads, 255 | window_size=7, 256 | shift_size=0, 257 | mlp_ratio=4.0, 258 | qkv_bias=True, 259 | qk_scale=None, 260 | drop=0.0, 261 | attn_drop=0.0, 262 | drop_path=0.0, 263 | act_layer=ModuleParallel(nn.GELU()), 264 | norm_layer=LayerNormParallel, 265 | ): 266 | super().__init__() 267 | self.dim = dim 268 | self.num_heads = num_heads 269 | self.window_size = window_size 270 | self.shift_size = shift_size 271 | self.mlp_ratio = mlp_ratio 272 | assert ( 273 | 0 <= self.shift_size < self.window_size 274 | ), "shift_size must in 0-window_size" 275 | 276 | self.norm1 = ModuleParallel(nn.LayerNorm(dim)) 277 | self.attn = Additional_One_ModuleParallel( 278 | WindowAttention( 279 | dim, 280 | window_size=to_2tuple(self.window_size), 281 | num_heads=num_heads, 282 | qkv_bias=qkv_bias, 283 | qk_scale=qk_scale, 284 | attn_drop=attn_drop, 285 | proj_drop=drop, 286 | ) 287 | ) 288 | 289 | self.drop_path = ( 290 | ModuleParallel(DropPath(drop_path)) if drop_path > 0.0 else nn.Identity() 291 | ) 292 | self.norm2 = ModuleParallel(nn.LayerNorm(dim)) 293 | mlp_hidden_dim = int(dim * mlp_ratio) 294 | self.mlp = Mlp( 295 | in_features=dim, 296 | hidden_features=mlp_hidden_dim, 297 | act_layer=act_layer, 298 | drop=drop, 299 | ) 300 | 301 | self.H = None 302 | self.W = None 303 | 304 | self.cross_heads = 8 305 | self.cross_attn_0_to_1 = nn.MultiheadAttention( 306 | dim, self.cross_heads, dropout=0.0, batch_first=False 307 | ) 308 | self.cross_attn_1_to_0 = nn.MultiheadAttention( 309 | dim, self.cross_heads, dropout=0.0, batch_first=False 310 | ) 311 | self.relation_judger = nn.Sequential( 312 | Mlp_2(dim * 2, dim, dim), torch.nn.Softmax(dim=-1) 313 | ) 314 | self.k_noise = nn.Embedding(2, dim) 315 | self.v_noise = nn.Embedding(2, dim) 316 | 317 | self.apply(self._init_weights) 318 | 319 | def _init_weights(self, m): 320 | if isinstance(m, nn.Linear): 321 | trunc_normal_(m.weight, std=0.02) 322 | if isinstance(m, nn.Linear) and m.bias is not None: 323 | nn.init.constant_(m.bias, 0) 324 | elif isinstance(m, nn.LayerNorm): 325 | nn.init.constant_(m.bias, 0) 326 | nn.init.constant_(m.weight, 1.0) 327 | elif isinstance(m, nn.Conv2d): 328 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 329 | fan_out //= m.groups 330 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 331 | if m.bias is not None: 332 | m.bias.data.zero_() 333 | 334 | def forward(self, x, mask_matrix): 335 | """Forward function. 336 | 337 | Args: 338 | x: Input feature, tensor size (B, H*W, C). 339 | H, W: Spatial resolution of the input feature. 340 | mask_matrix: Attention mask for cyclic shift. 
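Note that x is a list of two per-modality feature tensors; (shifted) window attention is applied to each stream independently, after which the two streams are fused by the bidirectional cross-attention below before the FFN.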
341 | """ 342 | B, L, C = x[0].shape 343 | H, W = self.H, self.W 344 | assert L == H * W, "input feature has wrong size" 345 | 346 | shortcut = x 347 | x = self.norm1(x) 348 | for i in range(len(x)): 349 | x[i] = x[i].view(B, H, W, C) 350 | 351 | # pad feature maps to multiples of window size 352 | pad_l = pad_t = 0 353 | pad_r = (self.window_size - W % self.window_size) % self.window_size 354 | pad_b = (self.window_size - H % self.window_size) % self.window_size 355 | for i in range(len(x)): 356 | x[i] = F.pad(x[i], (0, 0, pad_l, pad_r, pad_t, pad_b)) 357 | _, Hp, Wp, _ = x[0].shape 358 | 359 | # cyclic shift 360 | if self.shift_size > 0: 361 | shifted_x = [ 362 | torch.roll( 363 | x[i], shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) 364 | ) 365 | for i in range(len(x)) 366 | ] 367 | attn_mask = mask_matrix 368 | else: 369 | shifted_x = x 370 | attn_mask = None 371 | 372 | # partition windows 373 | x_windows = [ 374 | window_partition(shifted_x[i], self.window_size) 375 | for i in range(len(shifted_x)) 376 | ] # nW*B, window_size, window_size, C 377 | 378 | for i in range(len(x_windows)): 379 | x_windows[i] = x_windows[i].view( 380 | -1, self.window_size * self.window_size, C 381 | ) # nW*B, window_size*window_size, C 382 | 383 | # W-MSA/SW-MSA 384 | attn_windows = self.attn( 385 | x_windows, attn_mask 386 | ) # nW*B, window_size*window_size, C 387 | 388 | # merge windows 389 | for i in range(len(attn_windows)): 390 | attn_windows[i] = attn_windows[i].view( 391 | -1, self.window_size, self.window_size, C 392 | ) 393 | shifted_x = [ 394 | window_reverse(attn_windows[i], self.window_size, Hp, Wp) 395 | for i in range(len(attn_windows)) 396 | ] # B H' W' C 397 | 398 | # reverse cyclic shift 399 | if self.shift_size > 0: 400 | x = [ 401 | torch.roll( 402 | shifted_x[i], shifts=(self.shift_size, self.shift_size), dims=(1, 2) 403 | ) 404 | for i in range(len(shifted_x)) 405 | ] 406 | else: 407 | x = shifted_x 408 | 409 | if pad_r > 0 or pad_b > 0: 410 | for i in range(len(x)): 411 | x[i] = x[i][:, :H, :W, :].contiguous() 412 | for i in range(len(x)): 413 | x[i] = x[i].view(B, H * W, C) 414 | 415 | # cross-attn per batch 416 | new_x0 = [] 417 | new_x1 = [] 418 | for bs in range(B): 419 | ## 1. 0_to_1 cross attn and skip connect 420 | q = x[0][bs].unsqueeze(0) 421 | # k = v = x[1][bs].unsqueeze(0) 422 | judger_input = torch.cat( 423 | [x[0][bs].unsqueeze(0), x[1][bs].unsqueeze(0)], dim=-1 424 | ) 425 | 426 | relation_score = self.relation_judger(judger_input) 427 | 428 | noise_k = self.k_noise.weight[0] + q 429 | noise_v = self.v_noise.weight[0] + q 430 | 431 | k = torch.cat([noise_k, torch.mul(q, relation_score)], dim=0) 432 | v = torch.cat([noise_v, x[1][bs].unsqueeze(0)], dim=0) 433 | 434 | new_x0.append(x[0][bs] + self.cross_attn_0_to_1(q, k, v)[0].squeeze(0)) 435 | 436 | ## 2. 
1_to_0 cross attn and skip connect 437 | q = x[1][bs].unsqueeze(0) 438 | # k = v = x[0][bs].unsqueeze(0) 439 | judger_input = torch.cat( 440 | [x[1][bs].unsqueeze(0), x[0][bs].unsqueeze(0)], dim=-1 441 | ) 442 | 443 | relation_score = self.relation_judger(judger_input) 444 | 445 | noise_k = self.k_noise.weight[1] + q 446 | noise_v = self.v_noise.weight[1] + q 447 | 448 | k = torch.cat([noise_k, torch.mul(q, relation_score)], dim=0) 449 | v = torch.cat([noise_v, x[0][bs].unsqueeze(0)], dim=0) 450 | 451 | new_x1.append(x[1][bs] + self.cross_attn_1_to_0(q, k, v)[0].squeeze(0)) 452 | 453 | new_x0 = torch.stack(new_x0) 454 | new_x1 = torch.stack(new_x1) 455 | x[0] = new_x0 456 | x[1] = new_x1 457 | 458 | # FFN 459 | x_dp1 = self.drop_path(x) 460 | for i in range(len(x)): 461 | x[i] = shortcut[i] + x_dp1[i] 462 | x_dp2 = self.drop_path(self.mlp(self.norm2(x))) 463 | for i in range(len(x)): 464 | x[i] = x[i] + x_dp2[i] 465 | 466 | return x 467 | 468 | 469 | class PatchMerging(nn.Module): 470 | """Patch Merging Layer 471 | 472 | Args: 473 | dim (int): Number of input channels. 474 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 475 | """ 476 | 477 | def __init__(self, dim, norm_layer=nn.LayerNorm): 478 | super().__init__() 479 | self.dim = dim 480 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 481 | self.norm = norm_layer(4 * dim) 482 | 483 | def forward(self, x, H, W): 484 | """Forward function. 485 | 486 | Args: 487 | x: Input feature, tensor size (B, H*W, C). 488 | H, W: Spatial resolution of the input feature. 489 | """ 490 | 491 | B, L, C = x.shape 492 | assert L == H * W, "input feature has wrong size" 493 | 494 | x = x.view(B, H, W, C) 495 | 496 | # padding 497 | pad_input = (H % 2 == 1) or (W % 2 == 1) 498 | if pad_input: 499 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 500 | 501 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 502 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 503 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 504 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 505 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 506 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 507 | 508 | x = self.norm(x) 509 | x = self.reduction(x) 510 | 511 | return x 512 | 513 | 514 | class BasicLayer(nn.Module): 515 | """A basic Swin Transformer layer for one stage. 516 | 517 | Args: 518 | dim (int): Number of feature channels 519 | depth (int): Depths of this stage. 520 | num_heads (int): Number of attention head. 521 | window_size (int): Local window size. Default: 7. 522 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 523 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 524 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 525 | drop (float, optional): Dropout rate. Default: 0.0 526 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 527 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 528 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 529 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 530 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
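The forward pass returns x, H, W together with the downsampled features x_down and their resolution Wh, Ww; when no downsample layer is configured, x_down is simply x at the original resolution.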
531 | """ 532 | 533 | def __init__( 534 | self, 535 | dim, 536 | depth, 537 | num_heads, 538 | window_size=7, 539 | mlp_ratio=4.0, 540 | qkv_bias=True, 541 | qk_scale=None, 542 | drop=0.0, 543 | attn_drop=0.0, 544 | drop_path=0.0, 545 | norm_layer=LayerNormParallel, 546 | downsample=None, 547 | use_checkpoint=False, 548 | ): 549 | super().__init__() 550 | self.window_size = window_size 551 | self.shift_size = window_size // 2 552 | self.depth = depth 553 | self.use_checkpoint = use_checkpoint 554 | 555 | # build blocks 556 | self.blocks = nn.ModuleList( 557 | [ 558 | SwinTransformerBlock( 559 | dim=dim, 560 | num_heads=num_heads, 561 | window_size=window_size, 562 | shift_size=0 if (i % 2 == 0) else window_size // 2, 563 | mlp_ratio=mlp_ratio, 564 | qkv_bias=qkv_bias, 565 | qk_scale=qk_scale, 566 | drop=drop, 567 | attn_drop=attn_drop, 568 | drop_path=( 569 | drop_path[i] if isinstance(drop_path, list) else drop_path 570 | ), 571 | norm_layer=norm_layer, 572 | ) 573 | for i in range(depth) 574 | ] 575 | ) 576 | 577 | # patch merging layer 578 | if downsample is not None: 579 | self.downsample = Additional_Two_ModuleParallel( 580 | downsample(dim=dim, norm_layer=nn.LayerNorm) 581 | ) 582 | else: 583 | self.downsample = None 584 | 585 | def forward(self, x, H, W): 586 | """Forward function. 587 | 588 | Args: 589 | x: Input feature, tensor size (B, H*W, C). 590 | H, W: Spatial resolution of the input feature. 591 | """ 592 | 593 | # calculate attention mask for SW-MSA 594 | Hp = int(np.ceil(H / self.window_size)) * self.window_size 595 | Wp = int(np.ceil(W / self.window_size)) * self.window_size 596 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x[0].device) # 1 Hp Wp 1 597 | h_slices = ( 598 | slice(0, -self.window_size), 599 | slice(-self.window_size, -self.shift_size), 600 | slice(-self.shift_size, None), 601 | ) 602 | w_slices = ( 603 | slice(0, -self.window_size), 604 | slice(-self.window_size, -self.shift_size), 605 | slice(-self.shift_size, None), 606 | ) 607 | cnt = 0 608 | for h in h_slices: 609 | for w in w_slices: 610 | img_mask[:, h, w, :] = cnt 611 | cnt += 1 612 | 613 | mask_windows = window_partition( 614 | img_mask, self.window_size 615 | ) # nW, window_size, window_size, 1 616 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 617 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 618 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( 619 | attn_mask == 0, float(0.0) 620 | ) 621 | 622 | for blk in self.blocks: 623 | blk.H, blk.W = H, W 624 | if self.use_checkpoint: 625 | x = checkpoint.checkpoint(blk, x, attn_mask) 626 | else: 627 | x = blk(x, attn_mask) 628 | if self.downsample is not None: 629 | x_down = self.downsample(x, H, W) 630 | Wh, Ww = (H + 1) // 2, (W + 1) // 2 631 | return x, H, W, x_down, Wh, Ww 632 | else: 633 | return x, H, W, x, H, W 634 | 635 | 636 | class PatchEmbed(nn.Module): 637 | """Image to Patch Embedding 638 | 639 | Args: 640 | patch_size (int): Patch token size. Default: 4. 641 | in_chans (int): Number of input image channels. Default: 3. 642 | embed_dim (int): Number of linear projection output channels. Default: 96. 643 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 644 | """ 645 | 646 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 647 | super().__init__() 648 | patch_size = to_2tuple(patch_size) 649 | self.patch_size = patch_size 650 | 651 | self.in_chans = in_chans 652 | self.embed_dim = embed_dim 653 | 654 | self.proj = ModuleParallel( 655 | nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 656 | ) 657 | if norm_layer is not None: 658 | self.norm = norm_layer(embed_dim) 659 | else: 660 | self.norm = None 661 | 662 | def forward(self, x): 663 | """Forward function.""" 664 | # padding 665 | _, _, H, W = x[0].size() 666 | if W % self.patch_size[1] != 0: 667 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) 668 | if H % self.patch_size[0] != 0: 669 | x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) 670 | 671 | x = self.proj(x) # B C Wh Ww 672 | if self.norm is not None: 673 | Wh, Ww = x[0].size(2), x[0].size(3) 674 | for i in range(len(x)): 675 | x[i] = x[i].flatten(2).transpose(1, 2) 676 | x = self.norm(x) 677 | for i in range(len(x)): 678 | x[i] = x[i].transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) 679 | 680 | return x 681 | 682 | 683 | class SwinTransformer(nn.Module): 684 | """Swin Transformer backbone. 685 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 686 | https://arxiv.org/pdf/2103.14030 687 | 688 | Args: 689 | pretrain_img_size (int): Input image size for training the pretrained model, 690 | used in absolute postion embedding. Default 224. 691 | patch_size (int | tuple(int)): Patch size. Default: 4. 692 | in_chans (int): Number of input image channels. Default: 3. 693 | embed_dim (int): Number of linear projection output channels. Default: 96. 694 | depths (tuple[int]): Depths of each Swin Transformer stage. 695 | num_heads (tuple[int]): Number of attention head of each stage. 696 | window_size (int): Window size. Default: 7. 697 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 698 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 699 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. 700 | drop_rate (float): Dropout rate. 701 | attn_drop_rate (float): Attention dropout rate. Default: 0. 702 | drop_path_rate (float): Stochastic depth rate. Default: 0.2. 703 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 704 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. 705 | patch_norm (bool): If True, add normalization after patch embedding. Default: True. 706 | out_indices (Sequence[int]): Output from which stages. 707 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 708 | -1 means not freezing any parameters. 709 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
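Example (an illustrative sketch assuming two input modalities, e.g. RGB and depth, and the default settings; output channels follow embed_dim * 2**i):

    backbone = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
    rgb = torch.randn(1, 3, 224, 224)
    depth = torch.randn(1, 3, 224, 224)
    feats_rgb, feats_depth = backbone([rgb, depth])
    # each is a list of 4 feature maps with 96, 192, 384 and 768 channels
    # at strides 4, 8, 16 and 32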
710 | """ 711 | 712 | def __init__( 713 | self, 714 | pretrain_img_size=224, 715 | patch_size=4, 716 | in_chans=3, 717 | embed_dim=96, 718 | depths=[2, 2, 6, 2], 719 | num_heads=[3, 6, 12, 24], 720 | window_size=7, 721 | mlp_ratio=4.0, 722 | qkv_bias=True, 723 | qk_scale=None, 724 | drop_rate=0.0, 725 | attn_drop_rate=0.0, 726 | drop_path_rate=0.2, 727 | norm_layer=LayerNormParallel, 728 | ape=False, 729 | patch_norm=True, 730 | out_indices=(0, 1, 2, 3), 731 | frozen_stages=-1, 732 | use_checkpoint=False, 733 | ): 734 | super().__init__() 735 | self.drop_path_rate = drop_path_rate 736 | self.pretrain_img_size = pretrain_img_size 737 | self.num_layers = len(depths) 738 | self.embed_dim = embed_dim 739 | self.ape = ape 740 | self.patch_norm = patch_norm 741 | self.out_indices = out_indices 742 | self.frozen_stages = frozen_stages 743 | 744 | # split image into non-overlapping patches 745 | self.patch_embed = PatchEmbed( 746 | patch_size=patch_size, 747 | in_chans=in_chans, 748 | embed_dim=embed_dim, 749 | norm_layer=norm_layer if self.patch_norm else None, 750 | ) 751 | 752 | # absolute position embedding 753 | if self.ape: 754 | pretrain_img_size = to_2tuple(pretrain_img_size) 755 | patch_size = to_2tuple(patch_size) 756 | patches_resolution = [ 757 | pretrain_img_size[0] // patch_size[0], 758 | pretrain_img_size[1] // patch_size[1], 759 | ] 760 | 761 | self.absolute_pos_embed = nn.Parameter( 762 | torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) 763 | ) 764 | trunc_normal_(self.absolute_pos_embed, std=0.02) 765 | 766 | self.pos_drop = ModuleParallel(nn.Dropout(p=drop_rate)) 767 | 768 | # stochastic depth 769 | dpr = [ 770 | x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) 771 | ] # stochastic depth decay rule 772 | 773 | # build layers 774 | self.layers = nn.ModuleList() 775 | for i_layer in range(self.num_layers): 776 | layer = BasicLayer( 777 | dim=int(embed_dim * 2**i_layer), 778 | depth=depths[i_layer], 779 | num_heads=num_heads[i_layer], 780 | window_size=window_size, 781 | mlp_ratio=mlp_ratio, 782 | qkv_bias=qkv_bias, 783 | qk_scale=qk_scale, 784 | drop=drop_rate, 785 | attn_drop=attn_drop_rate, 786 | drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], 787 | norm_layer=norm_layer, 788 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 789 | use_checkpoint=use_checkpoint, 790 | ) 791 | self.layers.append(layer) 792 | 793 | num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] 794 | self.num_features = num_features 795 | 796 | # add a norm layer for each output 797 | for i_layer in out_indices: 798 | layer = norm_layer(num_features[i_layer]) 799 | layer_name = f"norm{i_layer}" 800 | self.add_module(layer_name, layer) 801 | 802 | self._freeze_stages() 803 | 804 | def _freeze_stages(self): 805 | if self.frozen_stages >= 0: 806 | self.patch_embed.eval() 807 | for param in self.patch_embed.parameters(): 808 | param.requires_grad = False 809 | 810 | if self.frozen_stages >= 1 and self.ape: 811 | self.absolute_pos_embed.requires_grad = False 812 | 813 | if self.frozen_stages >= 2: 814 | self.pos_drop.eval() 815 | for i in range(0, self.frozen_stages - 1): 816 | m = self.layers[i] 817 | m.eval() 818 | for param in m.parameters(): 819 | param.requires_grad = False 820 | 821 | def init_weights(self, pretrained=None): 822 | """Initialize the weights in backbone. 823 | 824 | Args: 825 | pretrained (str, optional): Path to pre-trained weights. 826 | Defaults to None. 
827 | """ 828 | 829 | def _init_weights(m): 830 | pass 831 | 832 | if isinstance(pretrained, str): 833 | self.apply(_init_weights) 834 | logger = get_root_logger() 835 | load_checkpoint(self, pretrained, strict=False, logger=logger) 836 | elif pretrained is None: 837 | self.apply(_init_weights) 838 | else: 839 | raise TypeError("pretrained must be a str or None") 840 | 841 | def forward(self, x): 842 | """Forward function.""" 843 | x = self.patch_embed(x) 844 | 845 | Wh, Ww = x[0].size(2), x[0].size(3) 846 | if self.ape: 847 | # interpolate the position embedding to the corresponding size 848 | absolute_pos_embed = F.interpolate( 849 | self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" 850 | ) 851 | x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C 852 | else: 853 | for i in range(len(x)): 854 | x[i] = x[i].flatten(2).transpose(1, 2) 855 | x = self.pos_drop(x) 856 | 857 | outs = {} 858 | for i in range(self.num_layers): 859 | layer = self.layers[i] 860 | x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) 861 | 862 | if i in self.out_indices: 863 | norm_layer = getattr(self, f"norm{i}") 864 | x_out = norm_layer(x_out) 865 | 866 | out = [ 867 | x_out[j] 868 | .view(-1, H, W, self.num_features[i]) 869 | .permute(0, 3, 1, 2) 870 | .contiguous() 871 | for j in range(len(x_out)) 872 | ] 873 | outs[i] = out 874 | 875 | new_x0 = [] 876 | new_x1 = [] 877 | for i in range(4): 878 | new_x0.append(outs[i][0]) 879 | new_x1.append(outs[i][1]) 880 | x = [new_x0, new_x1] 881 | 882 | return x 883 | 884 | def train(self, mode=True): 885 | """Convert the model into training mode while keep layers freezed.""" 886 | super(SwinTransformer, self).train(mode) 887 | self._freeze_stages() 888 | -------------------------------------------------------------------------------- /data/nyudv2/val.txt: -------------------------------------------------------------------------------- 1 | rgb/000001.png masks/000001.png depth/000001.png 2 | rgb/000002.png masks/000002.png depth/000002.png 3 | rgb/000009.png masks/000009.png depth/000009.png 4 | rgb/000014.png masks/000014.png depth/000014.png 5 | rgb/000015.png masks/000015.png depth/000015.png 6 | rgb/000016.png masks/000016.png depth/000016.png 7 | rgb/000017.png masks/000017.png depth/000017.png 8 | rgb/000018.png masks/000018.png depth/000018.png 9 | rgb/000021.png masks/000021.png depth/000021.png 10 | rgb/000028.png masks/000028.png depth/000028.png 11 | rgb/000029.png masks/000029.png depth/000029.png 12 | rgb/000030.png masks/000030.png depth/000030.png 13 | rgb/000031.png masks/000031.png depth/000031.png 14 | rgb/000032.png masks/000032.png depth/000032.png 15 | rgb/000033.png masks/000033.png depth/000033.png 16 | rgb/000034.png masks/000034.png depth/000034.png 17 | rgb/000035.png masks/000035.png depth/000035.png 18 | rgb/000036.png masks/000036.png depth/000036.png 19 | rgb/000037.png masks/000037.png depth/000037.png 20 | rgb/000038.png masks/000038.png depth/000038.png 21 | rgb/000039.png masks/000039.png depth/000039.png 22 | rgb/000040.png masks/000040.png depth/000040.png 23 | rgb/000041.png masks/000041.png depth/000041.png 24 | rgb/000042.png masks/000042.png depth/000042.png 25 | rgb/000043.png masks/000043.png depth/000043.png 26 | rgb/000046.png masks/000046.png depth/000046.png 27 | rgb/000047.png masks/000047.png depth/000047.png 28 | rgb/000056.png masks/000056.png depth/000056.png 29 | rgb/000057.png masks/000057.png depth/000057.png 30 | rgb/000059.png masks/000059.png depth/000059.png 31 | rgb/000060.png masks/000060.png 
depth/000060.png
32 | rgb/000061.png masks/000061.png depth/000061.png
33 | rgb/000062.png masks/000062.png depth/000062.png
34 | rgb/000063.png masks/000063.png depth/000063.png
35 | rgb/000076.png masks/000076.png depth/000076.png
36 | rgb/000077.png masks/000077.png depth/000077.png
37 | rgb/000078.png masks/000078.png depth/000078.png
38 | rgb/000079.png masks/000079.png depth/000079.png
39 | rgb/000084.png masks/000084.png depth/000084.png
40 | rgb/000085.png masks/000085.png depth/000085.png
41 | rgb/000086.png masks/000086.png depth/000086.png
42 | rgb/000087.png masks/000087.png depth/000087.png
43 | rgb/000088.png masks/000088.png depth/000088.png
44 | rgb/000089.png masks/000089.png depth/000089.png
45 | rgb/000090.png masks/000090.png depth/000090.png
46 | rgb/000091.png masks/000091.png depth/000091.png
47 | rgb/000117.png masks/000117.png depth/000117.png
48 | rgb/000118.png masks/000118.png depth/000118.png
49 | rgb/000119.png masks/000119.png depth/000119.png
50 | rgb/000125.png masks/000125.png depth/000125.png
51 | rgb/000126.png masks/000126.png depth/000126.png
52 | rgb/000127.png masks/000127.png depth/000127.png
53 | rgb/000128.png masks/000128.png depth/000128.png
54 | rgb/000129.png masks/000129.png depth/000129.png
55 | rgb/000131.png masks/000131.png depth/000131.png
56 | rgb/000132.png masks/000132.png depth/000132.png
57 | rgb/000133.png masks/000133.png depth/000133.png
58 | rgb/000134.png masks/000134.png depth/000134.png
59 | rgb/000137.png masks/000137.png depth/000137.png
60 | rgb/000153.png masks/000153.png depth/000153.png
61 | rgb/000154.png masks/000154.png depth/000154.png
62 | rgb/000155.png masks/000155.png depth/000155.png
63 | rgb/000167.png masks/000167.png depth/000167.png
64 | rgb/000168.png masks/000168.png depth/000168.png
65 | rgb/000169.png masks/000169.png depth/000169.png
66 | rgb/000171.png masks/000171.png depth/000171.png
67 | rgb/000172.png masks/000172.png depth/000172.png
68 | rgb/000173.png masks/000173.png depth/000173.png
69 | rgb/000174.png masks/000174.png depth/000174.png
70 | rgb/000175.png masks/000175.png depth/000175.png
71 | rgb/000176.png masks/000176.png depth/000176.png
72 | rgb/000180.png masks/000180.png depth/000180.png
73 | rgb/000181.png masks/000181.png depth/000181.png
74 | rgb/000182.png masks/000182.png depth/000182.png
75 | rgb/000183.png masks/000183.png depth/000183.png
76 | rgb/000184.png masks/000184.png depth/000184.png
77 | rgb/000185.png masks/000185.png depth/000185.png
78 | rgb/000186.png masks/000186.png depth/000186.png
79 | rgb/000187.png masks/000187.png depth/000187.png
80 | rgb/000188.png masks/000188.png depth/000188.png
81 | rgb/000189.png masks/000189.png depth/000189.png
82 | rgb/000190.png masks/000190.png depth/000190.png
83 | rgb/000191.png masks/000191.png depth/000191.png
84 | rgb/000192.png masks/000192.png depth/000192.png
85 | rgb/000193.png masks/000193.png depth/000193.png
86 | rgb/000194.png masks/000194.png depth/000194.png
87 | rgb/000195.png masks/000195.png depth/000195.png
88 | rgb/000196.png masks/000196.png depth/000196.png
89 | rgb/000197.png masks/000197.png depth/000197.png
90 | rgb/000198.png masks/000198.png depth/000198.png
91 | rgb/000199.png masks/000199.png depth/000199.png
92 | rgb/000200.png masks/000200.png depth/000200.png
93 | rgb/000201.png masks/000201.png depth/000201.png
94 | rgb/000202.png masks/000202.png depth/000202.png
95 | rgb/000207.png masks/000207.png depth/000207.png
96 | rgb/000208.png masks/000208.png depth/000208.png
97 | rgb/000209.png masks/000209.png depth/000209.png
98 | rgb/000210.png masks/000210.png depth/000210.png
99 | rgb/000211.png masks/000211.png depth/000211.png
100 | rgb/000212.png masks/000212.png depth/000212.png
101 | rgb/000220.png masks/000220.png depth/000220.png
102 | rgb/000221.png masks/000221.png depth/000221.png
103 | rgb/000222.png masks/000222.png depth/000222.png
104 | rgb/000250.png masks/000250.png depth/000250.png
105 | rgb/000264.png masks/000264.png depth/000264.png
106 | rgb/000271.png masks/000271.png depth/000271.png
107 | rgb/000272.png masks/000272.png depth/000272.png
108 | rgb/000273.png masks/000273.png depth/000273.png
109 | rgb/000279.png masks/000279.png depth/000279.png
110 | rgb/000280.png masks/000280.png depth/000280.png
111 | rgb/000281.png masks/000281.png depth/000281.png
112 | rgb/000282.png masks/000282.png depth/000282.png
113 | rgb/000283.png masks/000283.png depth/000283.png
114 | rgb/000284.png masks/000284.png depth/000284.png
115 | rgb/000285.png masks/000285.png depth/000285.png
116 | rgb/000296.png masks/000296.png depth/000296.png
117 | rgb/000297.png masks/000297.png depth/000297.png
118 | rgb/000298.png masks/000298.png depth/000298.png
119 | rgb/000299.png masks/000299.png depth/000299.png
120 | rgb/000300.png masks/000300.png depth/000300.png
121 | rgb/000301.png masks/000301.png depth/000301.png
122 | rgb/000302.png masks/000302.png depth/000302.png
123 | rgb/000310.png masks/000310.png depth/000310.png
124 | rgb/000311.png masks/000311.png depth/000311.png
125 | rgb/000312.png masks/000312.png depth/000312.png
126 | rgb/000315.png masks/000315.png depth/000315.png
127 | rgb/000316.png masks/000316.png depth/000316.png
128 | rgb/000317.png masks/000317.png depth/000317.png
129 | rgb/000325.png masks/000325.png depth/000325.png
130 | rgb/000326.png masks/000326.png depth/000326.png
131 | rgb/000327.png masks/000327.png depth/000327.png
132 | rgb/000328.png masks/000328.png depth/000328.png
133 | rgb/000329.png masks/000329.png depth/000329.png
134 | rgb/000330.png masks/000330.png depth/000330.png
135 | rgb/000331.png masks/000331.png depth/000331.png
136 | rgb/000332.png masks/000332.png depth/000332.png
137 | rgb/000333.png masks/000333.png depth/000333.png
138 | rgb/000334.png masks/000334.png depth/000334.png
139 | rgb/000335.png masks/000335.png depth/000335.png
140 | rgb/000351.png masks/000351.png depth/000351.png
141 | rgb/000352.png masks/000352.png depth/000352.png
142 | rgb/000355.png masks/000355.png depth/000355.png
143 | rgb/000356.png masks/000356.png depth/000356.png
144 | rgb/000357.png masks/000357.png depth/000357.png
145 | rgb/000358.png masks/000358.png depth/000358.png
146 | rgb/000359.png masks/000359.png depth/000359.png
147 | rgb/000360.png masks/000360.png depth/000360.png
148 | rgb/000361.png masks/000361.png depth/000361.png
149 | rgb/000362.png masks/000362.png depth/000362.png
150 | rgb/000363.png masks/000363.png depth/000363.png
151 | rgb/000364.png masks/000364.png depth/000364.png
152 | rgb/000384.png masks/000384.png depth/000384.png
153 | rgb/000385.png masks/000385.png depth/000385.png
154 | rgb/000386.png masks/000386.png depth/000386.png
155 | rgb/000387.png masks/000387.png depth/000387.png
156 | rgb/000388.png masks/000388.png depth/000388.png
157 | rgb/000389.png masks/000389.png depth/000389.png
158 | rgb/000390.png masks/000390.png depth/000390.png
159 | rgb/000395.png masks/000395.png depth/000395.png
160 | rgb/000396.png masks/000396.png depth/000396.png
161 | rgb/000397.png masks/000397.png depth/000397.png
162 | rgb/000411.png masks/000411.png depth/000411.png
163 | rgb/000412.png masks/000412.png depth/000412.png
164 | rgb/000413.png masks/000413.png depth/000413.png
165 | rgb/000414.png masks/000414.png depth/000414.png
166 | rgb/000430.png masks/000430.png depth/000430.png
167 | rgb/000431.png masks/000431.png depth/000431.png
168 | rgb/000432.png masks/000432.png depth/000432.png
169 | rgb/000433.png masks/000433.png depth/000433.png
170 | rgb/000434.png masks/000434.png depth/000434.png
171 | rgb/000435.png masks/000435.png depth/000435.png
172 | rgb/000441.png masks/000441.png depth/000441.png
173 | rgb/000442.png masks/000442.png depth/000442.png
174 | rgb/000443.png masks/000443.png depth/000443.png
175 | rgb/000444.png masks/000444.png depth/000444.png
176 | rgb/000445.png masks/000445.png depth/000445.png
177 | rgb/000446.png masks/000446.png depth/000446.png
178 | rgb/000447.png masks/000447.png depth/000447.png
179 | rgb/000448.png masks/000448.png depth/000448.png
180 | rgb/000462.png masks/000462.png depth/000462.png
181 | rgb/000463.png masks/000463.png depth/000463.png
182 | rgb/000464.png masks/000464.png depth/000464.png
183 | rgb/000465.png masks/000465.png depth/000465.png
184 | rgb/000466.png masks/000466.png depth/000466.png
185 | rgb/000469.png masks/000469.png depth/000469.png
186 | rgb/000470.png masks/000470.png depth/000470.png
187 | rgb/000471.png masks/000471.png depth/000471.png
188 | rgb/000472.png masks/000472.png depth/000472.png
189 | rgb/000473.png masks/000473.png depth/000473.png
190 | rgb/000474.png masks/000474.png depth/000474.png
191 | rgb/000475.png masks/000475.png depth/000475.png
192 | rgb/000476.png masks/000476.png depth/000476.png
193 | rgb/000477.png masks/000477.png depth/000477.png
194 | rgb/000508.png masks/000508.png depth/000508.png
195 | rgb/000509.png masks/000509.png depth/000509.png
196 | rgb/000510.png masks/000510.png depth/000510.png
197 | rgb/000511.png masks/000511.png depth/000511.png
198 | rgb/000512.png masks/000512.png depth/000512.png
199 | rgb/000513.png masks/000513.png depth/000513.png
200 | rgb/000515.png masks/000515.png depth/000515.png
201 | rgb/000516.png masks/000516.png depth/000516.png
202 | rgb/000517.png masks/000517.png depth/000517.png
203 | rgb/000518.png masks/000518.png depth/000518.png
204 | rgb/000519.png masks/000519.png depth/000519.png
205 | rgb/000520.png masks/000520.png depth/000520.png
206 | rgb/000521.png masks/000521.png depth/000521.png
207 | rgb/000522.png masks/000522.png depth/000522.png
208 | rgb/000523.png masks/000523.png depth/000523.png
209 | rgb/000524.png masks/000524.png depth/000524.png
210 | rgb/000525.png masks/000525.png depth/000525.png
211 | rgb/000526.png masks/000526.png depth/000526.png
212 | rgb/000531.png masks/000531.png depth/000531.png
213 | rgb/000532.png masks/000532.png depth/000532.png
214 | rgb/000533.png masks/000533.png depth/000533.png
215 | rgb/000537.png masks/000537.png depth/000537.png
216 | rgb/000538.png masks/000538.png depth/000538.png
217 | rgb/000539.png masks/000539.png depth/000539.png
218 | rgb/000549.png masks/000549.png depth/000549.png
219 | rgb/000550.png masks/000550.png depth/000550.png
220 | rgb/000551.png masks/000551.png depth/000551.png
221 | rgb/000555.png masks/000555.png depth/000555.png
222 | rgb/000556.png masks/000556.png depth/000556.png
223 | rgb/000557.png masks/000557.png depth/000557.png
224 | rgb/000558.png masks/000558.png depth/000558.png
225 | rgb/000559.png masks/000559.png depth/000559.png
226 | rgb/000560.png masks/000560.png depth/000560.png
227 | rgb/000561.png masks/000561.png depth/000561.png
228 | rgb/000562.png masks/000562.png depth/000562.png
229 | rgb/000563.png masks/000563.png depth/000563.png
230 | rgb/000564.png masks/000564.png depth/000564.png
231 | rgb/000565.png masks/000565.png depth/000565.png
232 | rgb/000566.png masks/000566.png depth/000566.png
233 | rgb/000567.png masks/000567.png depth/000567.png
234 | rgb/000568.png masks/000568.png depth/000568.png
235 | rgb/000569.png masks/000569.png depth/000569.png
236 | rgb/000570.png masks/000570.png depth/000570.png
237 | rgb/000571.png masks/000571.png depth/000571.png
238 | rgb/000579.png masks/000579.png depth/000579.png
239 | rgb/000580.png masks/000580.png depth/000580.png
240 | rgb/000581.png masks/000581.png depth/000581.png
241 | rgb/000582.png masks/000582.png depth/000582.png
242 | rgb/000583.png masks/000583.png depth/000583.png
243 | rgb/000591.png masks/000591.png depth/000591.png
244 | rgb/000592.png masks/000592.png depth/000592.png
245 | rgb/000593.png masks/000593.png depth/000593.png
246 | rgb/000594.png masks/000594.png depth/000594.png
247 | rgb/000603.png masks/000603.png depth/000603.png
248 | rgb/000604.png masks/000604.png depth/000604.png
249 | rgb/000605.png masks/000605.png depth/000605.png
250 | rgb/000606.png masks/000606.png depth/000606.png
251 | rgb/000607.png masks/000607.png depth/000607.png
252 | rgb/000612.png masks/000612.png depth/000612.png
253 | rgb/000613.png masks/000613.png depth/000613.png
254 | rgb/000617.png masks/000617.png depth/000617.png
255 | rgb/000618.png masks/000618.png depth/000618.png
256 | rgb/000619.png masks/000619.png depth/000619.png
257 | rgb/000620.png masks/000620.png depth/000620.png
258 | rgb/000621.png masks/000621.png depth/000621.png
259 | rgb/000633.png masks/000633.png depth/000633.png
260 | rgb/000634.png masks/000634.png depth/000634.png
261 | rgb/000635.png masks/000635.png depth/000635.png
262 | rgb/000636.png masks/000636.png depth/000636.png
263 | rgb/000637.png masks/000637.png depth/000637.png
264 | rgb/000638.png masks/000638.png depth/000638.png
265 | rgb/000644.png masks/000644.png depth/000644.png
266 | rgb/000645.png masks/000645.png depth/000645.png
267 | rgb/000650.png masks/000650.png depth/000650.png
268 | rgb/000651.png masks/000651.png depth/000651.png
269 | rgb/000656.png masks/000656.png depth/000656.png
270 | rgb/000657.png masks/000657.png depth/000657.png
271 | rgb/000658.png masks/000658.png depth/000658.png
272 | rgb/000663.png masks/000663.png depth/000663.png
273 | rgb/000664.png masks/000664.png depth/000664.png
274 | rgb/000668.png masks/000668.png depth/000668.png
275 | rgb/000669.png masks/000669.png depth/000669.png
276 | rgb/000670.png masks/000670.png depth/000670.png
277 | rgb/000671.png masks/000671.png depth/000671.png
278 | rgb/000672.png masks/000672.png depth/000672.png
279 | rgb/000673.png masks/000673.png depth/000673.png
280 | rgb/000676.png masks/000676.png depth/000676.png
281 | rgb/000677.png masks/000677.png depth/000677.png
282 | rgb/000678.png masks/000678.png depth/000678.png
283 | rgb/000679.png masks/000679.png depth/000679.png
284 | rgb/000680.png masks/000680.png depth/000680.png
285 | rgb/000681.png masks/000681.png depth/000681.png
286 | rgb/000686.png masks/000686.png depth/000686.png
287 | rgb/000687.png masks/000687.png depth/000687.png
288 | rgb/000688.png masks/000688.png depth/000688.png
289 | rgb/000689.png masks/000689.png depth/000689.png
290 | rgb/000690.png masks/000690.png depth/000690.png
291 | rgb/000693.png masks/000693.png depth/000693.png
292 | rgb/000694.png masks/000694.png depth/000694.png
293 | rgb/000697.png masks/000697.png depth/000697.png
294 | rgb/000698.png masks/000698.png depth/000698.png
295 | rgb/000699.png masks/000699.png depth/000699.png
296 | rgb/000706.png masks/000706.png depth/000706.png
297 | rgb/000707.png masks/000707.png depth/000707.png
298 | rgb/000708.png masks/000708.png depth/000708.png
299 | rgb/000709.png masks/000709.png depth/000709.png
300 | rgb/000710.png masks/000710.png depth/000710.png
301 | rgb/000711.png masks/000711.png depth/000711.png
302 | rgb/000712.png masks/000712.png depth/000712.png
303 | rgb/000713.png masks/000713.png depth/000713.png
304 | rgb/000717.png masks/000717.png depth/000717.png
305 | rgb/000718.png masks/000718.png depth/000718.png
306 | rgb/000724.png masks/000724.png depth/000724.png
307 | rgb/000725.png masks/000725.png depth/000725.png
308 | rgb/000726.png masks/000726.png depth/000726.png
309 | rgb/000727.png masks/000727.png depth/000727.png
310 | rgb/000728.png masks/000728.png depth/000728.png
311 | rgb/000731.png masks/000731.png depth/000731.png
312 | rgb/000732.png masks/000732.png depth/000732.png
313 | rgb/000733.png masks/000733.png depth/000733.png
314 | rgb/000734.png masks/000734.png depth/000734.png
315 | rgb/000743.png masks/000743.png depth/000743.png
316 | rgb/000744.png masks/000744.png depth/000744.png
317 | rgb/000759.png masks/000759.png depth/000759.png
318 | rgb/000760.png masks/000760.png depth/000760.png
319 | rgb/000761.png masks/000761.png depth/000761.png
320 | rgb/000762.png masks/000762.png depth/000762.png
321 | rgb/000763.png masks/000763.png depth/000763.png
322 | rgb/000764.png masks/000764.png depth/000764.png
323 | rgb/000765.png masks/000765.png depth/000765.png
324 | rgb/000766.png masks/000766.png depth/000766.png
325 | rgb/000767.png masks/000767.png depth/000767.png
326 | rgb/000768.png masks/000768.png depth/000768.png
327 | rgb/000769.png masks/000769.png depth/000769.png
328 | rgb/000770.png masks/000770.png depth/000770.png
329 | rgb/000771.png masks/000771.png depth/000771.png
330 | rgb/000772.png masks/000772.png depth/000772.png
331 | rgb/000773.png masks/000773.png depth/000773.png
332 | rgb/000774.png masks/000774.png depth/000774.png
333 | rgb/000775.png masks/000775.png depth/000775.png
334 | rgb/000776.png masks/000776.png depth/000776.png
335 | rgb/000777.png masks/000777.png depth/000777.png
336 | rgb/000778.png masks/000778.png depth/000778.png
337 | rgb/000779.png masks/000779.png depth/000779.png
338 | rgb/000780.png masks/000780.png depth/000780.png
339 | rgb/000781.png masks/000781.png depth/000781.png
340 | rgb/000782.png masks/000782.png depth/000782.png
341 | rgb/000783.png masks/000783.png depth/000783.png
342 | rgb/000784.png masks/000784.png depth/000784.png
343 | rgb/000785.png masks/000785.png depth/000785.png
344 | rgb/000786.png masks/000786.png depth/000786.png
345 | rgb/000787.png masks/000787.png depth/000787.png
346 | rgb/000800.png masks/000800.png depth/000800.png
347 | rgb/000801.png masks/000801.png depth/000801.png
348 | rgb/000802.png masks/000802.png depth/000802.png
349 | rgb/000803.png masks/000803.png depth/000803.png
350 | rgb/000804.png masks/000804.png depth/000804.png
351 | rgb/000810.png masks/000810.png depth/000810.png
352 | rgb/000811.png masks/000811.png depth/000811.png
353 | rgb/000812.png masks/000812.png depth/000812.png
354 | rgb/000813.png masks/000813.png depth/000813.png
355 | rgb/000814.png masks/000814.png depth/000814.png
356 | rgb/000821.png masks/000821.png depth/000821.png
357 | rgb/000822.png masks/000822.png depth/000822.png
358 | rgb/000823.png masks/000823.png depth/000823.png
359 | rgb/000833.png masks/000833.png depth/000833.png
360 | rgb/000834.png masks/000834.png depth/000834.png
361 | rgb/000835.png masks/000835.png depth/000835.png
362 | rgb/000836.png masks/000836.png depth/000836.png
363 | rgb/000837.png masks/000837.png depth/000837.png
364 | rgb/000838.png masks/000838.png depth/000838.png
365 | rgb/000839.png masks/000839.png depth/000839.png
366 | rgb/000840.png masks/000840.png depth/000840.png
367 | rgb/000841.png masks/000841.png depth/000841.png
368 | rgb/000842.png masks/000842.png depth/000842.png
369 | rgb/000843.png masks/000843.png depth/000843.png
370 | rgb/000844.png masks/000844.png depth/000844.png
371 | rgb/000845.png masks/000845.png depth/000845.png
372 | rgb/000846.png masks/000846.png depth/000846.png
373 | rgb/000850.png masks/000850.png depth/000850.png
374 | rgb/000851.png masks/000851.png depth/000851.png
375 | rgb/000852.png masks/000852.png depth/000852.png
376 | rgb/000857.png masks/000857.png depth/000857.png
377 | rgb/000858.png masks/000858.png depth/000858.png
378 | rgb/000859.png masks/000859.png depth/000859.png
379 | rgb/000860.png masks/000860.png depth/000860.png
380 | rgb/000861.png masks/000861.png depth/000861.png
381 | rgb/000862.png masks/000862.png depth/000862.png
382 | rgb/000869.png masks/000869.png depth/000869.png
383 | rgb/000870.png masks/000870.png depth/000870.png
384 | rgb/000871.png masks/000871.png depth/000871.png
385 | rgb/000906.png masks/000906.png depth/000906.png
386 | rgb/000907.png masks/000907.png depth/000907.png
387 | rgb/000908.png masks/000908.png depth/000908.png
388 | rgb/000917.png masks/000917.png depth/000917.png
389 | rgb/000918.png masks/000918.png depth/000918.png
390 | rgb/000919.png masks/000919.png depth/000919.png
391 | rgb/000926.png masks/000926.png depth/000926.png
392 | rgb/000927.png masks/000927.png depth/000927.png
393 | rgb/000928.png masks/000928.png depth/000928.png
394 | rgb/000932.png masks/000932.png depth/000932.png
395 | rgb/000933.png masks/000933.png depth/000933.png
396 | rgb/000934.png masks/000934.png depth/000934.png
397 | rgb/000935.png masks/000935.png depth/000935.png
398 | rgb/000945.png masks/000945.png depth/000945.png
399 | rgb/000946.png masks/000946.png depth/000946.png
400 | rgb/000947.png masks/000947.png depth/000947.png
401 | rgb/000959.png masks/000959.png depth/000959.png
402 | rgb/000960.png masks/000960.png depth/000960.png
403 | rgb/000961.png masks/000961.png depth/000961.png
404 | rgb/000962.png masks/000962.png depth/000962.png
405 | rgb/000965.png masks/000965.png depth/000965.png
406 | rgb/000966.png masks/000966.png depth/000966.png
407 | rgb/000967.png masks/000967.png depth/000967.png
408 | rgb/000970.png masks/000970.png depth/000970.png
409 | rgb/000971.png masks/000971.png depth/000971.png
410 | rgb/000972.png masks/000972.png depth/000972.png
411 | rgb/000973.png masks/000973.png depth/000973.png
412 | rgb/000974.png masks/000974.png depth/000974.png
413 | rgb/000975.png masks/000975.png depth/000975.png
414 | rgb/000976.png masks/000976.png depth/000976.png
415 | rgb/000977.png masks/000977.png depth/000977.png
416 | rgb/000991.png masks/000991.png depth/000991.png
417 | rgb/000992.png masks/000992.png depth/000992.png
418 | rgb/000993.png masks/000993.png depth/000993.png
419 | rgb/000994.png masks/000994.png depth/000994.png
420 | rgb/000995.png masks/000995.png depth/000995.png
421 | rgb/001001.png masks/001001.png depth/001001.png
422 | rgb/001002.png masks/001002.png depth/001002.png
423 | rgb/001003.png masks/001003.png depth/001003.png
424 | rgb/001004.png masks/001004.png depth/001004.png
425 | rgb/001010.png masks/001010.png depth/001010.png
426 | rgb/001011.png masks/001011.png depth/001011.png
427 | rgb/001012.png masks/001012.png depth/001012.png
428 | rgb/001021.png masks/001021.png depth/001021.png
429 | rgb/001022.png masks/001022.png depth/001022.png
430 | rgb/001023.png masks/001023.png depth/001023.png
431 | rgb/001032.png masks/001032.png depth/001032.png
432 | rgb/001033.png masks/001033.png depth/001033.png
433 | rgb/001034.png masks/001034.png depth/001034.png
434 | rgb/001038.png masks/001038.png depth/001038.png
435 | rgb/001039.png masks/001039.png depth/001039.png
436 | rgb/001048.png masks/001048.png depth/001048.png
437 | rgb/001049.png masks/001049.png depth/001049.png
438 | rgb/001052.png masks/001052.png depth/001052.png
439 | rgb/001053.png masks/001053.png depth/001053.png
440 | rgb/001057.png masks/001057.png depth/001057.png
441 | rgb/001058.png masks/001058.png depth/001058.png
442 | rgb/001075.png masks/001075.png depth/001075.png
443 | rgb/001076.png masks/001076.png depth/001076.png
444 | rgb/001077.png masks/001077.png depth/001077.png
445 | rgb/001078.png masks/001078.png depth/001078.png
446 | rgb/001079.png masks/001079.png depth/001079.png
447 | rgb/001080.png masks/001080.png depth/001080.png
448 | rgb/001081.png masks/001081.png depth/001081.png
449 | rgb/001082.png masks/001082.png depth/001082.png
450 | rgb/001083.png masks/001083.png depth/001083.png
451 | rgb/001084.png masks/001084.png depth/001084.png
452 | rgb/001088.png masks/001088.png depth/001088.png
453 | rgb/001089.png masks/001089.png depth/001089.png
454 | rgb/001090.png masks/001090.png depth/001090.png
455 | rgb/001091.png masks/001091.png depth/001091.png
456 | rgb/001092.png masks/001092.png depth/001092.png
457 | rgb/001093.png masks/001093.png depth/001093.png
458 | rgb/001094.png masks/001094.png depth/001094.png
459 | rgb/001095.png masks/001095.png depth/001095.png
460 | rgb/001096.png masks/001096.png depth/001096.png
461 | rgb/001098.png masks/001098.png depth/001098.png
462 | rgb/001099.png masks/001099.png depth/001099.png
463 | rgb/001100.png masks/001100.png depth/001100.png
464 | rgb/001101.png masks/001101.png depth/001101.png
465 | rgb/001102.png masks/001102.png depth/001102.png
466 | rgb/001103.png masks/001103.png depth/001103.png
467 | rgb/001104.png masks/001104.png depth/001104.png
468 | rgb/001106.png masks/001106.png depth/001106.png
469 | rgb/001107.png masks/001107.png depth/001107.png
470 | rgb/001108.png masks/001108.png depth/001108.png
471 | rgb/001109.png masks/001109.png depth/001109.png
472 | rgb/001117.png masks/001117.png depth/001117.png
473 | rgb/001118.png masks/001118.png depth/001118.png
474 | rgb/001119.png masks/001119.png depth/001119.png
475 | rgb/001123.png masks/001123.png depth/001123.png
476 | rgb/001124.png masks/001124.png depth/001124.png
477 | rgb/001125.png masks/001125.png depth/001125.png
478 | rgb/001126.png masks/001126.png depth/001126.png
479 | rgb/001127.png masks/001127.png depth/001127.png
480 | rgb/001128.png masks/001128.png depth/001128.png
481 | rgb/001129.png masks/001129.png depth/001129.png
482 | rgb/001130.png masks/001130.png depth/001130.png
483 | rgb/001131.png masks/001131.png depth/001131.png
484 | rgb/001135.png masks/001135.png depth/001135.png
485 | rgb/001136.png masks/001136.png depth/001136.png
486 | rgb/001144.png masks/001144.png depth/001144.png
487 | rgb/001145.png masks/001145.png depth/001145.png
488 | rgb/001146.png masks/001146.png depth/001146.png
489 | rgb/001147.png masks/001147.png depth/001147.png
490 | rgb/001148.png masks/001148.png depth/001148.png
491 | rgb/001149.png masks/001149.png depth/001149.png
492 | rgb/001150.png masks/001150.png depth/001150.png
493 | rgb/001151.png masks/001151.png depth/001151.png
494 | rgb/001152.png masks/001152.png depth/001152.png
495 | rgb/001153.png masks/001153.png depth/001153.png
496 | rgb/001154.png masks/001154.png depth/001154.png
497 | rgb/001155.png masks/001155.png depth/001155.png
498 | rgb/001156.png masks/001156.png depth/001156.png
499 | rgb/001157.png masks/001157.png depth/001157.png
500 | rgb/001158.png masks/001158.png depth/001158.png
501 | rgb/001162.png masks/001162.png depth/001162.png
502 | rgb/001163.png masks/001163.png depth/001163.png
503 | rgb/001164.png masks/001164.png depth/001164.png
504 | rgb/001165.png masks/001165.png depth/001165.png
505 | rgb/001166.png masks/001166.png depth/001166.png
506 | rgb/001167.png masks/001167.png depth/001167.png
507 | rgb/001170.png masks/001170.png depth/001170.png
508 | rgb/001171.png masks/001171.png depth/001171.png
509 | rgb/001174.png masks/001174.png depth/001174.png
510 | rgb/001175.png masks/001175.png depth/001175.png
511 | rgb/001176.png masks/001176.png depth/001176.png
512 | rgb/001179.png masks/001179.png depth/001179.png
513 | rgb/001180.png masks/001180.png depth/001180.png
514 | rgb/001181.png masks/001181.png depth/001181.png
515 | rgb/001182.png masks/001182.png depth/001182.png
516 | rgb/001183.png masks/001183.png depth/001183.png
517 | rgb/001184.png masks/001184.png depth/001184.png
518 | rgb/001192.png masks/001192.png depth/001192.png
519 | rgb/001193.png masks/001193.png depth/001193.png
520 | rgb/001194.png masks/001194.png depth/001194.png
521 | rgb/001195.png masks/001195.png depth/001195.png
522 | rgb/001196.png masks/001196.png depth/001196.png
523 | rgb/001201.png masks/001201.png depth/001201.png
524 | rgb/001202.png masks/001202.png depth/001202.png
525 | rgb/001203.png masks/001203.png depth/001203.png
526 | rgb/001204.png masks/001204.png depth/001204.png
527 | rgb/001205.png masks/001205.png depth/001205.png
528 | rgb/001206.png masks/001206.png depth/001206.png
529 | rgb/001207.png masks/001207.png depth/001207.png
530 | rgb/001208.png masks/001208.png depth/001208.png
531 | rgb/001209.png masks/001209.png depth/001209.png
532 | rgb/001210.png masks/001210.png depth/001210.png
533 | rgb/001211.png masks/001211.png depth/001211.png
534 | rgb/001212.png masks/001212.png depth/001212.png
535 | rgb/001216.png masks/001216.png depth/001216.png
536 | rgb/001217.png masks/001217.png depth/001217.png
537 | rgb/001218.png masks/001218.png depth/001218.png
538 | rgb/001219.png masks/001219.png depth/001219.png
539 | rgb/001220.png masks/001220.png depth/001220.png
540 | rgb/001226.png masks/001226.png depth/001226.png
541 | rgb/001227.png masks/001227.png depth/001227.png
542 | rgb/001228.png masks/001228.png depth/001228.png
543 | rgb/001229.png masks/001229.png depth/001229.png
544 | rgb/001230.png masks/001230.png depth/001230.png
545 | rgb/001233.png masks/001233.png depth/001233.png
546 | rgb/001234.png masks/001234.png depth/001234.png
547 | rgb/001235.png masks/001235.png depth/001235.png
548 | rgb/001247.png masks/001247.png depth/001247.png
549 | rgb/001248.png masks/001248.png depth/001248.png
550 | rgb/001249.png masks/001249.png depth/001249.png
551 | rgb/001250.png masks/001250.png depth/001250.png
552 | rgb/001254.png masks/001254.png depth/001254.png
553 | rgb/001255.png masks/001255.png depth/001255.png
554 | rgb/001256.png masks/001256.png depth/001256.png
555 | rgb/001257.png masks/001257.png depth/001257.png
556 | rgb/001258.png masks/001258.png depth/001258.png
557 | rgb/001259.png masks/001259.png depth/001259.png
558 | rgb/001260.png masks/001260.png depth/001260.png
559 | rgb/001261.png masks/001261.png depth/001261.png
560 | rgb/001262.png masks/001262.png depth/001262.png
561 | rgb/001263.png masks/001263.png depth/001263.png
562 | rgb/001264.png masks/001264.png depth/001264.png
563 | rgb/001265.png masks/001265.png depth/001265.png
564 | rgb/001275.png masks/001275.png depth/001275.png
565 | rgb/001276.png masks/001276.png depth/001276.png
566 | rgb/001277.png masks/001277.png depth/001277.png
567 | rgb/001278.png masks/001278.png depth/001278.png
568 | rgb/001279.png masks/001279.png depth/001279.png
569 | rgb/001280.png masks/001280.png depth/001280.png
570 | rgb/001285.png masks/001285.png depth/001285.png
571 | rgb/001286.png masks/001286.png depth/001286.png
572 | rgb/001287.png masks/001287.png depth/001287.png
573 | rgb/001288.png masks/001288.png depth/001288.png
574 | rgb/001289.png masks/001289.png depth/001289.png
575 | rgb/001290.png masks/001290.png depth/001290.png
576 | rgb/001291.png masks/001291.png depth/001291.png
577 | rgb/001292.png masks/001292.png depth/001292.png
578 | rgb/001293.png masks/001293.png depth/001293.png
579 | rgb/001294.png masks/001294.png depth/001294.png
580 | rgb/001295.png masks/001295.png depth/001295.png
581 | rgb/001297.png masks/001297.png depth/001297.png
582 | rgb/001298.png masks/001298.png depth/001298.png
583 | rgb/001299.png masks/001299.png depth/001299.png
584 | rgb/001302.png masks/001302.png depth/001302.png
585 | rgb/001303.png masks/001303.png depth/001303.png
586 | rgb/001304.png masks/001304.png depth/001304.png
587 | rgb/001305.png masks/001305.png depth/001305.png
588 | rgb/001306.png masks/001306.png depth/001306.png
589 | rgb/001307.png masks/001307.png depth/001307.png
590 | rgb/001308.png masks/001308.png depth/001308.png
591 | rgb/001314.png masks/001314.png depth/001314.png
592 | rgb/001315.png masks/001315.png depth/001315.png
593 | rgb/001329.png masks/001329.png depth/001329.png
594 | rgb/001330.png masks/001330.png depth/001330.png
595 | rgb/001331.png masks/001331.png depth/001331.png
596 | rgb/001332.png masks/001332.png depth/001332.png
597 | rgb/001335.png masks/001335.png depth/001335.png
598 | rgb/001336.png masks/001336.png depth/001336.png
599 | rgb/001337.png masks/001337.png depth/001337.png
600 | rgb/001338.png masks/001338.png depth/001338.png
601 | rgb/001339.png masks/001339.png depth/001339.png
602 | rgb/001340.png masks/001340.png depth/001340.png
603 | rgb/001347.png masks/001347.png depth/001347.png
604 | rgb/001348.png masks/001348.png depth/001348.png
605 | rgb/001349.png masks/001349.png depth/001349.png
606 | rgb/001353.png masks/001353.png depth/001353.png
607 | rgb/001354.png masks/001354.png depth/001354.png
608 | rgb/001355.png masks/001355.png depth/001355.png
609 | rgb/001356.png masks/001356.png depth/001356.png
610 | rgb/001364.png masks/001364.png depth/001364.png
611 | rgb/001365.png masks/001365.png depth/001365.png
612 | rgb/001368.png masks/001368.png depth/001368.png
613 | rgb/001369.png masks/001369.png depth/001369.png
614 | rgb/001384.png masks/001384.png depth/001384.png
615 | rgb/001385.png masks/001385.png depth/001385.png
616 | rgb/001386.png masks/001386.png depth/001386.png
617 | rgb/001387.png masks/001387.png depth/001387.png
618 | rgb/001388.png masks/001388.png depth/001388.png
619 | rgb/001389.png masks/001389.png depth/001389.png
620 | rgb/001390.png masks/001390.png depth/001390.png
621 | rgb/001391.png masks/001391.png depth/001391.png
622 | rgb/001394.png masks/001394.png depth/001394.png
623 | rgb/001395.png masks/001395.png depth/001395.png
624 | rgb/001396.png masks/001396.png depth/001396.png
625 | rgb/001397.png masks/001397.png depth/001397.png
626 | rgb/001398.png masks/001398.png depth/001398.png
627 | rgb/001399.png masks/001399.png depth/001399.png
628 | rgb/001400.png masks/001400.png depth/001400.png
629 | rgb/001401.png masks/001401.png depth/001401.png
630 | rgb/001407.png masks/001407.png depth/001407.png
631 | rgb/001408.png masks/001408.png depth/001408.png
632 | rgb/001409.png masks/001409.png depth/001409.png
633 | rgb/001410.png masks/001410.png depth/001410.png
634 | rgb/001411.png masks/001411.png depth/001411.png
635 | rgb/001412.png masks/001412.png depth/001412.png
636 | rgb/001413.png masks/001413.png depth/001413.png
637 | rgb/001414.png masks/001414.png depth/001414.png
638 | rgb/001421.png masks/001421.png depth/001421.png
639 | rgb/001422.png masks/001422.png depth/001422.png
640 | rgb/001423.png masks/001423.png depth/001423.png
641 | rgb/001424.png masks/001424.png depth/001424.png
642 | rgb/001430.png masks/001430.png depth/001430.png
643 | rgb/001431.png masks/001431.png depth/001431.png
644 | rgb/001432.png masks/001432.png depth/001432.png
645 | rgb/001433.png masks/001433.png depth/001433.png
646 | rgb/001441.png masks/001441.png depth/001441.png
647 | rgb/001442.png masks/001442.png depth/001442.png
648 | rgb/001443.png masks/001443.png depth/001443.png
649 | rgb/001444.png masks/001444.png depth/001444.png
650 | rgb/001445.png masks/001445.png depth/001445.png
651 | rgb/001446.png masks/001446.png depth/001446.png
652 | rgb/001447.png masks/001447.png depth/001447.png
653 | rgb/001448.png masks/001448.png depth/001448.png
654 | rgb/001449.png masks/001449.png depth/001449.png
--------------------------------------------------------------------------------
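Note on the split-file format: each line of data/nyudv2/val.txt is a whitespace-separated triple of relative paths (RGB image, segmentation mask, depth map) resolved against the NYUDv2 data root; the file enumerates the 654 validation samples. The actual loading logic lives in utils/datasets.py; the snippet below is only a minimal illustrative sketch of how such a split file could be parsed, and the `load_split` helper and `data_root` argument are hypothetical names introduced here, not part of this repository.

```python
import os

def load_split(data_root, split_file):
    """Hypothetical helper: parse a split file such as data/nyudv2/val.txt
    into (rgb, mask, depth) absolute-path triples. Illustrative only; the
    repository's own loader in utils/datasets.py may differ."""
    samples = []
    with open(split_file) as f:
        for line in f:
            parts = line.split()
            if len(parts) != 3:  # skip blank or malformed lines
                continue
            rgb_rel, mask_rel, depth_rel = parts
            samples.append(tuple(os.path.join(data_root, p)
                                 for p in (rgb_rel, mask_rel, depth_rel)))
    return samples

# Example usage (paths are placeholders):
# val_samples = load_split("/path/to/nyudv2", "data/nyudv2/val.txt")
# print(len(val_samples))  # 654 validation triples for this split file
```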