├── README.md ├── configs ├── __init__.py ├── mtmamba_cityscapes.yml ├── mtmamba_nyud.yml ├── mtmamba_pascal.yml ├── mtmamba_plus_cityscapes.yml ├── mtmamba_plus_nyud.yml ├── mtmamba_plus_pascal.yml └── mypath.py ├── data ├── __init__.py ├── cityscapes.py ├── db_info │ ├── context_classes.json │ ├── nyu_classes.json │ ├── pascal_map.npy │ └── pascal_part.json ├── mat2png.py ├── nyud.py ├── pascal_context.py └── transforms.py ├── evaluation ├── __init__.py ├── eval_depth.py ├── eval_edge.py ├── eval_human_parts.py ├── eval_normals.py ├── eval_sal.py ├── eval_semseg.py ├── evaluate_utils.py └── jaccard.py ├── losses ├── __init__.py ├── loss_functions.py └── loss_schemes.py ├── main.py ├── models ├── CTM.py ├── MTMamba.py ├── MTMamba_plus.py └── utils.py ├── pretrained_ckpts ├── run.sh └── swin2mmseg.py └── utils ├── __init__.py ├── common_config.py ├── config.py ├── custom_collate.py ├── logger.py ├── test_utils.py ├── train_utils.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # MTMamba 2 | 3 | This repository contains codes and models for the following papers: 4 | 5 | > Baijiong Lin, Weisen Jiang, Pengguang Chen, Yu Zhang, Shu Liu, and Ying-Cong Chen. MTMamba: Enhancing Multi-Task Dense Scene Understanding by Mamba-Based Decoders. In *European Conference on Computer Vision*, 2024. 6 | 7 | > Baijiong Lin, Weisen Jiang, Pengguang Chen, Shu Liu, and Ying-Cong Chen. MTMamba++: Enhancing Multi-Task Dense Scene Understanding via Mamba-Based Decoders. *arXiv preprint arXiv:2408.15101*, 2024. 8 | 9 | ## Requirements 10 | 11 | - PyTorch 2.0.0 12 | 13 | - timm 0.9.16 14 | 15 | - mmsegmentation 1.2.2 16 | 17 | - mamba-ssm 1.1.2 18 | 19 | - CUDA 11.8 20 | 21 | 22 | 23 | ## Usage 24 | 25 | 1. Prepare the pretrained Swin-Large checkpoint by running the following command 26 | 27 | ```shell 28 | cd pretrained_ckpts 29 | bash run.sh 30 | cd ../ 31 | ``` 32 | 33 | 2. Download the data from [PASCALContext.tar.gz](https://hkustconnect-my.sharepoint.com/:u:/g/personal/hyeae_connect_ust_hk/ER57KyZdEdxPtgMCai7ioV0BXCmAhYzwFftCwkTiMmuM7w?e=2Ex4ab), [NYUDv2.tar.gz](https://hkustconnect-my.sharepoint.com/:u:/g/personal/hyeae_connect_ust_hk/EZ-2tWIDYSFKk7SCcHRimskBhgecungms4WFa_L-255GrQ?e=6jAt4c), and then extract them. You need to modify the dataset directory as ```db_root``` variable in ```configs/mypath.py```. 34 | 35 | 3. Train the model. 
Taking training NYUDv2 as an example, you can run the following command 36 | 37 | ```shell 38 | python -m torch.distributed.launch --nproc_per_node 8 main.py --run_mode train --config_exp ./configs/mtmamba_nyud.yml 39 | ``` 40 | 41 |         You can download the pretrained models from [mtmamba_nyud.pth.tar](https://hkustgz-my.sharepoint.com/:u:/g/personal/blin241_connect_hkust-gz_edu_cn/EdP6lzTOEIRLggFVLlbzPWUBZrsRPoEkdtNpYjm_H2K54A?e=IwsaaG), [mtmamba_pascal.pth.tar](https://hkustgz-my.sharepoint.com/:u:/g/personal/blin241_connect_hkust-gz_edu_cn/ET0zoRo2mq9OoYJlHZZy2eQB5lh6W-yayKzih6ejwD7awQ?e=DUZFGE), [mtmamba_cityscapes.pth.tar](https://hkustgz-my.sharepoint.com/:u:/g/personal/blin241_connect_hkust-gz_edu_cn/EVfY4W2qn85Ihe8rANBiKisBM0xxGn4OnmuOjRJ9FWNGeA?e=TsyE5B), [mtmamba_plus_nyud.pth.tar](https://hkustgz-my.sharepoint.com/:u:/g/personal/blin241_connect_hkust-gz_edu_cn/Ecjm9MJ5SwBGlPfg4YAxGGABagrzm81LM_TI3h6jADkpvA?e=KePvfD), [mtmamba_plus_pascal.pth.tar](https://hkustgz-my.sharepoint.com/:u:/g/personal/blin241_connect_hkust-gz_edu_cn/EaVpHcqrNihIsfyMeyPR614BpzSrk2ubRSIdBUHLcwZTjA?e=DpRajc), [mtmamba_plus_cityscapes.pth.tar](https://hkustgz-my.sharepoint.com/:u:/g/personal/blin241_connect_hkust-gz_edu_cn/EZHHVmXbGChFsvyorMKOvncBU06opYPC0FuVCg8X8Yg8gw?e=8lnvdI). 42 | 43 | 4. Evaluation. You can run the following command, 44 | 45 | ```shell 46 | python -m torch.distributed.launch --nproc_per_node 1 main.py --run_mode infer --config_exp ./configs/mtmamba_nyud.yml --trained_model ./ckpts/mtmamba_nyud.pth.tar 47 | ``` 48 | 49 | Acknowledgement 50 | --------------- 51 | 52 | We would like to thank the authors that release the public repositories: [Multi-Task-Transformer](https://github.com/prismformore/Multi-Task-Transformer), [mamba](https://github.com/state-spaces/mamba), and [VMamba](https://github.com/MzeroMiko/VMamba). 
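Before launching evaluation with a downloaded checkpoint, it can help to confirm that the `.pth.tar` file deserializes correctly. The snippet below is a minimal sketch, assuming the released files are standard `torch.save` archives saved under `./ckpts/`; it makes no assumption about which keys the archive contains and simply prints whatever is there.

```python
# Minimal sanity check for a downloaded checkpoint (assumes a standard
# torch.save archive; adjust the path to wherever you stored the file).
import torch

ckpt = torch.load('./ckpts/mtmamba_nyud.pth.tar', map_location='cpu')
if isinstance(ckpt, dict):
    # Show the top-level keys so you can see how the weights/metadata are stored.
    print(list(ckpt.keys()))
else:
    print(type(ckpt))
```

If the file loads without error and lists its top-level keys, it is ready to be passed to `main.py` via `--trained_model` as in the evaluation command above.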
53 | 54 | 55 | 56 | ## Citation 57 | 58 | If you found this code/work to be useful in your own research, please cite the following: 59 | 60 | ```latex 61 | @inproceedings{lin2024mtmamba, 62 | title={{MTMamba}: Enhancing Multi-Task Dense Scene Understanding by Mamba-Based Decoders}, 63 | author={Lin, Baijiong and Jiang, Weisen and Chen, Pengguang and Zhang, Yu and Liu, Shu and Chen, Ying-Cong}, 64 | booktitle={European Conference on Computer Vision}, 65 | year={2024} 66 | } 67 | 68 | @article{lin2024mtmambaplus, 69 | title={{MTMamba++}: Enhancing Multi-Task Dense Scene Understanding via Mamba-Based Decoders}, 70 | author={Lin, Baijiong and Jiang, Weisen and Chen, Pengguang and Liu, Shu and Chen, Ying-Cong}, 71 | journal={arXiv preprint arXiv:2408.15101}, 72 | year={2024} 73 | } 74 | ``` 75 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/MTMamba/17d132d2a7a1ef80c35ffe40c8dd3ff72fa07d77/configs/__init__.py -------------------------------------------------------------------------------- /configs/mtmamba_cityscapes.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | version_name: mtmamba_city 3 | out_dir: "./" 4 | 5 | # Database 6 | train_db_name: Cityscapes 7 | val_db_name: Cityscapes 8 | trBatch: 1 9 | valBatch: 4 10 | nworkers: 2 11 | ignore_index: 255 12 | 13 | # Optimizer and scheduler 14 | intermediate_supervision: False 15 | val_interval: 1000 16 | epochs: 999999 17 | max_iter: 40000 18 | optimizer: adamw 19 | optimizer_kwargs: 20 | lr: 0.0001 21 | weight_decay: 0.000001 22 | scheduler: poly 23 | 24 | # Model 25 | model: MTMamba 26 | backbone: swin_large 27 | 28 | # Tasks 29 | task_dictionary: 30 | include_semseg: True 31 | include_depth: True 32 | 33 | # Loss kwargs 34 | loss_kwargs: 35 | loss_scheme: log 36 | loss_weights: 37 | semseg: 1.0 38 | depth: 1.0 39 | -------------------------------------------------------------------------------- /configs/mtmamba_nyud.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | version_name: mtmamba_nyud 3 | out_dir: "./" 4 | 5 | # Database 6 | train_db_name: NYUD 7 | val_db_name: NYUD 8 | trBatch: 1 9 | valBatch: 6 10 | nworkers: 4 11 | ignore_index: 255 12 | 13 | # Optimizer and scheduler 14 | intermediate_supervision: False 15 | val_interval: 1000 16 | epochs: 999999 17 | max_iter: 50000 18 | optimizer: adamw 19 | optimizer_kwargs: 20 | lr: 0.0001 21 | weight_decay: 0.00001 22 | scheduler: poly 23 | 24 | # Model 25 | model: MTMamba 26 | backbone: swin_large 27 | 28 | # Tasks 29 | task_dictionary: 30 | include_semseg: True 31 | include_depth: True 32 | include_edge: True 33 | include_normals: True 34 | edge_w: 0.95 35 | 36 | # Loss kwargs 37 | loss_kwargs: 38 | loss_scheme: log 39 | loss_weights: 40 | semseg: 1.0 41 | depth: 1.0 42 | edge: 1.0 43 | normals: 1.0 44 | -------------------------------------------------------------------------------- /configs/mtmamba_pascal.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | version_name: mtmamba_pascal 3 | out_dir: "./" 4 | 5 | # Database 6 | train_db_name: PASCALContext 7 | val_db_name: PASCALContext 8 | trBatch: 1 9 | valBatch: 6 10 | nworkers: 4 11 | ignore_index: 255 12 | 13 | # Optimizer and scheduler 14 | intermediate_supervision: False 15 | val_interval: 1000 16 
| epochs: 999999 17 | max_iter: 50000 18 | optimizer: adamw 19 | optimizer_kwargs: 20 | lr: 0.0001 21 | weight_decay: 0.00001 22 | scheduler: poly 23 | 24 | # Model 25 | model: MTMamba 26 | backbone: swin_large 27 | 28 | # Tasks 29 | task_dictionary: 30 | include_semseg: True 31 | include_human_parts: True 32 | include_sal: True 33 | include_edge: True 34 | include_normals: True 35 | edge_w: 0.95 36 | 37 | # Loss kwargs 38 | loss_kwargs: 39 | loss_scheme: log 40 | loss_weights: 41 | semseg: 1.0 42 | human_parts: 1.0 43 | sal: 1.0 44 | edge: 1.0 45 | normals: 1.0 46 | -------------------------------------------------------------------------------- /configs/mtmamba_plus_cityscapes.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | version_name: mtmamba_plus_city 3 | out_dir: "./" 4 | 5 | # Database 6 | train_db_name: Cityscapes 7 | val_db_name: Cityscapes 8 | trBatch: 1 9 | valBatch: 4 10 | nworkers: 2 11 | ignore_index: 255 12 | 13 | # Optimizer and scheduler 14 | intermediate_supervision: False 15 | val_interval: 1000 16 | epochs: 999999 17 | max_iter: 40000 18 | optimizer: adamw 19 | optimizer_kwargs: 20 | lr: 0.0001 21 | weight_decay: 0.000001 22 | scheduler: poly 23 | 24 | # Model 25 | model: MTMamba_plus 26 | backbone: swin_large 27 | 28 | # Tasks 29 | task_dictionary: 30 | include_semseg: True 31 | include_depth: True 32 | 33 | # Loss kwargs 34 | loss_kwargs: 35 | loss_scheme: log 36 | loss_weights: 37 | semseg: 1.0 38 | depth: 1.0 39 | -------------------------------------------------------------------------------- /configs/mtmamba_plus_nyud.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | version_name: mtmamba_plus_nyud 3 | out_dir: "./" 4 | 5 | # Database 6 | train_db_name: NYUD 7 | val_db_name: NYUD 8 | trBatch: 1 9 | valBatch: 6 10 | nworkers: 2 11 | ignore_index: 255 12 | 13 | # Optimizer and scheduler 14 | intermediate_supervision: False 15 | val_interval: 1000 16 | epochs: 99999 17 | max_iter: 40000 18 | optimizer: adamw 19 | optimizer_kwargs: 20 | lr: 0.00002 21 | weight_decay: 0.000001 22 | scheduler: poly 23 | 24 | # Model 25 | model: MTMamba_plus 26 | backbone: swin_large 27 | 28 | # Tasks 29 | task_dictionary: 30 | include_semseg: True 31 | include_depth: True 32 | include_edge: True 33 | include_normals: True 34 | edge_w: 0.95 35 | 36 | # Loss kwargs 37 | loss_kwargs: 38 | loss_scheme: none 39 | loss_weights: 40 | semseg: 1.0 41 | depth: 1.0 42 | normals: 10.0 43 | edge: 50.0 44 | -------------------------------------------------------------------------------- /configs/mtmamba_plus_pascal.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | version_name: mtmamba_plus_pascal 3 | out_dir: "./" 4 | 5 | # Database 6 | train_db_name: PASCALContext 7 | val_db_name: PASCALContext 8 | trBatch: 1 9 | valBatch: 6 10 | nworkers: 2 11 | ignore_index: 255 12 | 13 | # Optimizer and scheduler 14 | intermediate_supervision: False 15 | val_interval: 1000 16 | epochs: 99999 17 | max_iter: 40000 18 | optimizer: adamw 19 | optimizer_kwargs: 20 | lr: 0.00008 21 | weight_decay: 0.000001 22 | scheduler: poly 23 | 24 | # Model 25 | model: MTMamba_plus 26 | backbone: swin_large 27 | 28 | # Tasks 29 | task_dictionary: 30 | include_semseg: True 31 | include_human_parts: True 32 | include_sal: True 33 | include_edge: True 34 | include_normals: True 35 | edge_w: 0.95 36 | 37 | # Loss kwargs 38 | loss_kwargs: 39 | loss_scheme: log 40 | loss_weights: 
41 | semseg: 1.0 42 | human_parts: 1.0 43 | sal: 1.0 44 | edge: 1.0 45 | normals: 1.0 46 | -------------------------------------------------------------------------------- /configs/mypath.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | db_root = '/path/to/' 4 | PROJECT_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)).split('/')[0] 5 | 6 | db_names = {'PASCALContext': 'PASCALContext', 'NYUD_MT': 'NYUDv2', 'Cityscapes': 'cityscapes_all'} 7 | db_paths = {} 8 | for database, db_pa in db_names.items(): 9 | db_paths[database] = os.path.join(db_root, db_pa) -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/MTMamba/17d132d2a7a1ef80c35ffe40c8dd3ff72fa07d77/data/__init__.py -------------------------------------------------------------------------------- /data/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import scipy.io as sio 5 | from scipy.spatial.transform import Rotation 6 | import torch.utils.data as data 7 | from PIL import Image, ImageDraw, ImageFont 8 | import imageio, cv2 9 | import torch 10 | import copy 11 | import torchvision.transforms as t_transforms 12 | import pickle 13 | 14 | def recursive_glob(rootdir='.', suffix=''): 15 | """Performs recursive glob with given suffix and rootdir 16 | :param rootdir is the root directory 17 | :param suffix is the suffix to be searched 18 | """ 19 | return [os.path.join(looproot, filename) 20 | for looproot, _, filenames in os.walk(rootdir) 21 | for filename in filenames if filename.endswith(suffix)] 22 | 23 | def imresize(img, size, mode, resample): 24 | size = (size[1], size[0]) # width, height 25 | _img = Image.fromarray(img)#, mode=mode) 26 | _img = _img.resize(size, resample) 27 | _img = np.array(_img) 28 | return _img 29 | 30 | class CITYSCAPES(data.Dataset): 31 | def __init__(self, p, root, split=["train"], is_transform=False, 32 | img_size=[1024, 2048], augmentations=None, 33 | task_list=['semseg', 'depth'], ignore_index=255): 34 | 35 | if isinstance(split, str): 36 | split = [split] 37 | else: 38 | split.sort() 39 | split = split 40 | 41 | self.split = split 42 | self.root = root 43 | self.split_text = '+'.join(split) 44 | self.is_transform = is_transform 45 | self.augmentations = augmentations 46 | self.n_classes = 19 47 | self.img_size = img_size 48 | 49 | self.task_flags = {'semseg': True, 'insseg': False, 'depth': True} 50 | self.task_list = task_list 51 | self.files = {} 52 | 53 | self.files[self.split_text] = [] 54 | for _split in self.split: 55 | self.images_base = os.path.join(self.root, 'leftImg8bit', _split) 56 | self.annotations_base = os.path.join(self.root, 'gtFine', _split) 57 | self.files[self.split_text] += recursive_glob(rootdir=self.images_base, suffix='.png') 58 | self.depth_base = os.path.join(self.root, 'disparity', _split) 59 | ori_img_no = len(self.files[self.split_text]) 60 | 61 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 62 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] 63 | self.class_names = ['road', 'sidewalk', 'building', 'wall', 'fence',\ 64 | 'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',\ 65 | 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', \ 66 | 
'motorcycle', 'bicycle'] 67 | 68 | self.ignore_index = ignore_index 69 | self.class_map = dict(zip(self.valid_classes, range(19))) 70 | 71 | self.ori_img_size = [1024, 2048] 72 | self.label_dw_ratio = img_size[0] / self.ori_img_size[0] # hacking 73 | 74 | if len(self.files[self.split_text]) < 2: 75 | raise Exception("No files for split=[%s] found in %s" % (self.split_text, self.images_base)) 76 | 77 | # image to tensor 78 | mean = [0.485, 0.456, 0.406] 79 | std = [0.229, 0.224, 0.225] 80 | self.img_transform = t_transforms.Compose([t_transforms.ToTensor(), t_transforms.Normalize(mean, std)]) 81 | 82 | 83 | def __len__(self): 84 | return len(self.files[self.split_text]) 85 | 86 | def __getitem__(self, index): 87 | 88 | img_path = self.files[self.split_text][index].rstrip() 89 | lbl_path = os.path.join(self.annotations_base, 90 | img_path.split(os.sep)[-2], 91 | os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png') 92 | depth_path = os.path.join(self.depth_base, 93 | img_path.split(os.sep)[-2], 94 | os.path.basename(img_path)[:-15] + 'disparity.png') 95 | 96 | 97 | img = cv2.imread(img_path).astype(np.float32) 98 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 99 | sample = {'image': img} 100 | sample['meta'] = {'img_name': img_path.split('.')[0].split('/')[-1], 101 | 'img_size': (img.shape[0], img.shape[1]), 102 | 'scale_factor': np.array([self.img_size[1]/img.shape[1], self.img_size[0]/img.shape[0]]), # in xy order 103 | } 104 | 105 | if 'semseg' in self.task_list: 106 | lbl = imageio.imread(lbl_path) 107 | sample['semseg'] = lbl 108 | 109 | if 'depth' in self.task_list: 110 | depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED).astype(np.float32) # disparity 111 | 112 | depth[depth>0] = (depth[depth>0] - 1) / 256 # disparity values 113 | 114 | # make the invalid idx to -1 115 | depth[depth==0] = -1 116 | 117 | # assign the disparity of sky to zero 118 | sky_mask = lbl == 10 119 | depth[sky_mask] = 0 120 | 121 | sample['depth'] = depth 122 | 123 | if self.augmentations is not None: 124 | sample = self.augmentations(sample) 125 | 126 | if 'semseg' in self.task_list: 127 | sample['semseg'] = self.encode_segmap(sample['semseg']) 128 | 129 | if self.is_transform: 130 | sample = self.transform(sample) 131 | 132 | return sample 133 | 134 | def transform(self, sample): 135 | img = sample['image'] 136 | if 'semseg' in self.task_list: 137 | lbl = sample['semseg'] 138 | if 'depth' in self.task_list: 139 | depth = sample['depth'] 140 | 141 | img_ori_shape = img.shape[:2] 142 | img = img.astype(np.uint8) 143 | 144 | if self.img_size != self.ori_img_size: 145 | img = imresize(img, (self.img_size[0], self.img_size[1]), 'RGB', Image.BILINEAR) 146 | 147 | if 'semseg' in self.task_list: 148 | classes = np.unique(lbl) 149 | lbl = lbl.astype(int) 150 | 151 | if 'depth' in self.task_list: 152 | depth = np.expand_dims(depth, axis=0) 153 | depth = torch.from_numpy(depth).float() 154 | sample['depth'] = depth 155 | 156 | if 'semseg' in self.task_list: 157 | if not np.all(np.unique(lbl[lbl!=self.ignore_index]) < self.n_classes): 158 | print('after det', classes, np.unique(lbl)) 159 | raise ValueError("Segmentation map contained invalid class values") 160 | lbl = torch.from_numpy(lbl).long() 161 | sample['semseg'] = lbl 162 | 163 | img = self.img_transform(img) 164 | sample['image'] = img 165 | 166 | return sample 167 | 168 | def encode_segmap(self, mask): 169 | for _voidc in self.void_classes: 170 | mask[mask==_voidc] = self.ignore_index 171 | old_mask = mask.copy() 172 | for _validc in self.valid_classes: 173 | 
mask[old_mask==_validc] = self.class_map[_validc] 174 | return mask 175 | 176 | class ComposeAug(object): 177 | def __init__(self, augmentations): 178 | self.augmentations = augmentations 179 | 180 | def __call__(self, sample): 181 | sample['image'], sample['semseg'], sample['depth'] = np.array(sample['image'], dtype=np.uint8), np.array(sample['semseg'], dtype=np.uint8), np.array(sample['depth'], dtype=np.float32) 182 | sample['image'], sample['semseg'], sample['depth'] = Image.fromarray(sample['image'], mode='RGB'), Image.fromarray(sample['semseg'], mode='L'), Image.fromarray(sample['depth'], mode='F') 183 | if 'insseg' in sample.keys(): 184 | sample['insseg'] = np.array(sample['insseg'], dtype=np.int32) 185 | sample['insseg'] = Image.fromarray(sample['insseg'], mode='I') 186 | 187 | assert sample['image'].size == sample['semseg'].size 188 | assert sample['image'].size == sample['depth'].size 189 | if 'insseg' in sample.keys(): 190 | assert sample['image'].size == sample['insseg'].size 191 | 192 | for a in self.augmentations: 193 | sample = a(sample) 194 | 195 | sample['image'] = np.array(sample['image']) 196 | sample['semseg'] = np.array(sample['semseg'], dtype=np.uint8) 197 | sample['depth'] = np.array(sample['depth'], dtype=np.float32) 198 | if 'insseg' in sample.keys(): 199 | sample['insseg'] = np.array(sample['insseg'], dtype=np.uint64) 200 | 201 | return sample -------------------------------------------------------------------------------- /data/db_info/context_classes.json: -------------------------------------------------------------------------------- 1 | {"accordion": 1, "aeroplane": 2, "air conditioner": 3, "antenna": 4, "artillery": 5, "ashtray": 6, "atrium": 7, "baby carriage": 8, "bag": 9, "ball": 10, "balloon": 11, "bamboo weaving": 12, "barrel": 13, "baseball bat": 14, "basket": 15, "basketball backboard": 16, "bathtub": 17, "bed": 18, "bedclothes": 19, "beer": 20, "bell": 21, "bench": 22, "bicycle": 23, "binoculars": 24, "bird": 25, "bird cage": 26, "bird feeder": 27, "bird nest": 28, "blackboard": 29, "board": 30, "boat": 31, "bone": 32, "book": 33, "bottle": 34, "bottle opener": 35, "bowl": 36, "box": 37, "bracelet": 38, "brick": 39, "bridge": 40, "broom": 41, "brush": 42, "bucket": 43, "building": 44, "bus": 45, "cabinet": 46, "cabinet door": 47, "cage": 48, "cake": 49, "calculator": 50, "calendar": 51, "camel": 52, "camera": 53, "camera lens": 54, "can": 55, "candle": 56, "candle holder": 57, "cap": 58, "car": 59, "card": 60, "cart": 61, "case": 62, "casette recorder": 63, "cash register": 64, "cat": 65, "cd": 66, "cd player": 67, "ceiling": 68, "cell phone": 69, "cello": 70, "chain": 71, "chair": 72, "chessboard": 73, "chicken": 74, "chopstick": 75, "clip": 76, "clippers": 77, "clock": 78, "closet": 79, "cloth": 80, "clothes tree": 81, "coffee": 82, "coffee machine": 83, "comb": 84, "computer": 85, "concrete": 86, "cone": 87, "container": 88, "control booth": 89, "controller": 90, "cooker": 91, "copying machine": 92, "coral": 93, "cork": 94, "corkscrew": 95, "counter": 96, "court": 97, "cow": 98, "crabstick": 99, "crane": 100, "crate": 101, "cross": 102, "crutch": 103, "cup": 104, "curtain": 105, "cushion": 106, "cutting board": 107, "dais": 108, "disc": 109, "disc case": 110, "dishwasher": 111, "dock": 112, "dog": 113, "dolphin": 114, "door": 115, "drainer": 116, "dray": 117, "drink dispenser": 118, "drinking machine": 119, "drop": 120, "drug": 121, "drum": 122, "drum kit": 123, "duck": 124, "dumbbell": 125, "earphone": 126, "earrings": 127, "egg": 128, "electric 
fan": 129, "electric iron": 130, "electric pot": 131, "electric saw": 132, "electronic keyboard": 133, "engine": 134, "envelope": 135, "equipment": 136, "escalator": 137, "exhibition booth": 138, "extinguisher": 139, "eyeglass": 140, "fan": 141, "faucet": 142, "fax machine": 143, "fence": 144, "ferris wheel": 145, "fire extinguisher": 146, "fire hydrant": 147, "fire place": 148, "fish": 149, "fish tank": 150, "fishbowl": 151, "fishing net": 152, "fishing pole": 153, "flag": 154, "flagstaff": 155, "flame": 156, "flashlight": 157, "floor": 158, "flower": 159, "fly": 160, "foam": 161, "food": 162, "footbridge": 163, "forceps": 164, "fork": 165, "forklift": 166, "fountain": 167, "fox": 168, "frame": 169, "fridge": 170, "frog": 171, "fruit": 172, "funnel": 173, "furnace": 174, "game controller": 175, "game machine": 176, "gas cylinder": 177, "gas hood": 178, "gas stove": 179, "gift box": 180, "glass": 181, "glass marble": 182, "globe": 183, "glove": 184, "goal": 185, "grandstand": 186, "grass": 187, "gravestone": 188, "ground": 189, "guardrail": 190, "guitar": 191, "gun": 192, "hammer": 193, "hand cart": 194, "handle": 195, "handrail": 196, "hanger": 197, "hard disk drive": 198, "hat": 199, "hay": 200, "headphone": 201, "heater": 202, "helicopter": 203, "helmet": 204, "holder": 205, "hook": 206, "horse": 207, "horse-drawn carriage": 208, "hot-air balloon": 209, "hydrovalve": 210, "ice": 211, "inflator pump": 212, "ipod": 213, "iron": 214, "ironing board": 215, "jar": 216, "kart": 217, "kettle": 218, "key": 219, "keyboard": 220, "kitchen range": 221, "kite": 222, "knife": 223, "knife block": 224, "ladder": 225, "ladder truck": 226, "ladle": 227, "laptop": 228, "leaves": 229, "lid": 230, "life buoy": 231, "light": 232, "light bulb": 233, "lighter": 234, "line": 235, "lion": 236, "lobster": 237, "lock": 238, "machine": 239, "mailbox": 240, "mannequin": 241, "map": 242, "mask": 243, "mat": 244, "match book": 245, "mattress": 246, "menu": 247, "metal": 248, "meter box": 249, "microphone": 250, "microwave": 251, "mirror": 252, "missile": 253, "model": 254, "money": 255, "monkey": 256, "mop": 257, "motorbike": 258, "mountain": 259, "mouse": 260, "mouse pad": 261, "musical instrument": 262, "napkin": 263, "net": 264, "newspaper": 265, "oar": 266, "ornament": 267, "outlet": 268, "oven": 269, "oxygen bottle": 270, "pack": 271, "pan": 272, "paper": 273, "paper box": 274, "paper cutter": 275, "parachute": 276, "parasol": 277, "parterre": 278, "patio": 279, "pelage": 280, "pen": 281, "pen container": 282, "pencil": 283, "person": 284, "photo": 285, "piano": 286, "picture": 287, "pig": 288, "pillar": 289, "pillow": 290, "pipe": 291, "pitcher": 292, "plant": 293, "plastic": 294, "plate": 295, "platform": 296, "player": 297, "playground": 298, "pliers": 299, "plume": 300, "poker": 301, "poker chip": 302, "pole": 303, "pool table": 304, "postcard": 305, "poster": 306, "pot": 307, "pottedplant": 308, "printer": 309, "projector": 310, "pumpkin": 311, "rabbit": 312, "racket": 313, "radiator": 314, "radio": 315, "rail": 316, "rake": 317, "ramp": 318, "range hood": 319, "receiver": 320, "recorder": 321, "recreational machines": 322, "remote control": 323, "road": 324, "robot": 325, "rock": 326, "rocket": 327, "rocking horse": 328, "rope": 329, "rug": 330, "ruler": 331, "runway": 332, "saddle": 333, "sand": 334, "saw": 335, "scale": 336, "scanner": 337, "scissors": 338, "scoop": 339, "screen": 340, "screwdriver": 341, "sculpture": 342, "scythe": 343, "sewer": 344, "sewing machine": 345, "shed": 346, "sheep": 347, 
"shell": 348, "shelves": 349, "shoe": 350, "shopping cart": 351, "shovel": 352, "sidecar": 353, "sidewalk": 354, "sign": 355, "signal light": 356, "sink": 357, "skateboard": 358, "ski": 359, "sky": 360, "sled": 361, "slippers": 362, "smoke": 363, "snail": 364, "snake": 365, "snow": 366, "snowmobiles": 367, "sofa": 368, "spanner": 369, "spatula": 370, "speaker": 371, "speed bump": 372, "spice container": 373, "spoon": 374, "sprayer": 375, "squirrel": 376, "stage": 377, "stair": 378, "stapler": 379, "stick": 380, "sticky note": 381, "stone": 382, "stool": 383, "stove": 384, "straw": 385, "stretcher": 386, "sun": 387, "sunglass": 388, "sunshade": 389, "surveillance camera": 390, "swan": 391, "sweeper": 392, "swim ring": 393, "swimming pool": 394, "swing": 395, "switch": 396, "table": 397, "tableware": 398, "tank": 399, "tap": 400, "tape": 401, "tarp": 402, "telephone": 403, "telephone booth": 404, "tent": 405, "tire": 406, "toaster": 407, "toilet": 408, "tong": 409, "tool": 410, "toothbrush": 411, "towel": 412, "toy": 413, "toy car": 414, "track": 415, "train": 416, "trampoline": 417, "trash bin": 418, "tray": 419, "tree": 420, "tricycle": 421, "tripod": 422, "trophy": 423, "truck": 424, "tube": 425, "turtle": 426, "tvmonitor": 427, "tweezers": 428, "typewriter": 429, "umbrella": 430, "unknown": 431, "vacuum cleaner": 432, "vending machine": 433, "video camera": 434, "video game console": 435, "video player": 436, "video tape": 437, "violin": 438, "wakeboard": 439, "wall": 440, "wallet": 441, "wardrobe": 442, "washing machine": 443, "watch": 444, "water": 445, "water dispenser": 446, "water pipe": 447, "water skate board": 448, "watermelon": 449, "whale": 450, "wharf": 451, "wheel": 452, "wheelchair": 453, "window": 454, "window blinds": 455, "wineglass": 456, "wire": 457, "wood": 458, "wool": 459} -------------------------------------------------------------------------------- /data/db_info/nyu_classes.json: -------------------------------------------------------------------------------- 1 | {"air conditioner": 79, "air duct": 38, "air vent": 25, "alarm": 525, "alarm clock": 156, "album": 822, "aluminium foil": 708, "american flag": 870, "antenna": 796, "apple": 334, "ashtray": 377, "avocado": 680, "baby chair": 494, "baby gate": 591, "back scrubber": 656, "backpack": 206, "bag": 55, "bag of bagels": 690, "bag of chips": 245, "bag of flour": 285, "bag of hot dog buns": 747, "bag of oreo": 692, "bagel": 689, "baking dish": 260, "ball": 60, "balloon": 385, "banana": 147, "banana peel": 691, "banister": 453, "bar": 51, "bar of soap": 564, "barrel": 343, "baseball": 825, "basket": 39, "basketball": 542, "basketball hoop": 162, "bassinet": 414, "bathtub": 136, "bean bag": 797, "bed": 157, "bed sheets": 352, "bedding package": 808, "beeper": 780, "belt": 610, "bench": 204, "bicycle": 189, "bicycle helmet": 337, "bin": 307, "binder": 399, "blackboard": 225, "blanket": 312, "blender": 268, "blinds": 80, "board": 408, "book": 1, "book holder": 827, "bookend": 374, "bookrack": 224, "books": 85, "bookshelf": 88, "boomerang": 773, "bottle": 2, "bottle of comet": 755, "bottle of contact lens solution": 633, "bottle of hand wash liquid": 677, "bottle of ketchup": 750, "bottle of liquid": 685, "bottle of listerine": 676, "bottle of perfume": 840, "bottle of soap": 502, "bowl": 22, "box": 26, "box of paper": 503, "box of ziplock bags": 271, "bracelet": 860, "bread": 246, "brick": 695, "briefcase": 617, "broom": 328, "bucket": 427, "bulb": 688, "bunk bed": 804, "business cards": 535, "butterfly sculpture": 
712, "button": 774, "cabinet": 3, "cable box": 168, "cable modem": 73, "cable rack": 104, "cables": 450, "cactus": 641, "cake": 289, "calculator": 200, "calendar": 583, "camera": 40, "can": 329, "can of beer": 857, "can of food": 280, "can opener": 279, "candelabra": 605, "candle": 137, "candlestick": 148, "cane": 555, "canister": 794, "cannister": 355, "cans of cat food": 593, "cap stand": 441, "car": 530, "cardboard sheet": 452, "cardboard tube": 413, "cart": 305, "carton": 397, "case": 851, "casserole dish": 365, "cat": 594, "cat bed": 608, "cat cage": 580, "cat house": 856, "cd": 207, "cd disc": 585, "ceiling": 4, "celery": 720, "cell phone": 290, "cell phone charger": 602, "centerpiece": 878, "ceramic frog": 643, "certificate": 790, "chair": 5, "chalkboard": 428, "chandelier": 342, "chapstick": 726, "charger": 743, "charger and wire": 574, "chart": 411, "chart roll": 495, "chart stand": 393, "chessboard": 198, "chest": 344, "child carrier": 491, "chimney": 702, "circuit breaker box": 112, "classroom board": 392, "cleaner": 548, "cleaning wipes": 381, "clipboard": 536, "clock": 56, "cloth bag": 492, "cloth drying stand": 549, "clothes": 141, "clothing detergent": 501, "clothing dryer": 498, "clothing drying rack": 556, "clothing hamper": 770, "clothing hanger": 214, "clothing iron": 572, "clothing washer": 499, "coaster": 387, "coat hanger": 400, "coffee bag": 226, "coffee grinder": 237, "coffee machine": 234, "coffee packet": 227, "coffee pot": 893, "coffee table": 356, "coins": 308, "coke bottle": 297, "collander": 694, "cologne": 176, "column": 94, "comb": 809, "comforter": 484, "computer": 46, "computer disk": 616, "conch shell": 673, "cone": 6, "console controller": 613, "console system": 518, "contact lens case": 634, "contact lens solution bottle": 173, "container": 140, "container of skin cream": 637, "cooking pan": 252, "cooking pot cover": 761, "copper vessel": 528, "cordless phone": 474, "cordless telephone": 545, "cork board": 34, "corkscrew": 713, "corn": 716, "counter": 7, "cradle": 493, "crate": 183, "crayon": 511, "cream": 635, "cream tube": 653, "crib": 485, "crock pot": 330, "cup": 35, "curtain": 89, "curtain rod": 582, "cutting board": 247, "decanter": 345, "decoration item": 842, "decorative bottle": 767, "decorative bowl": 826, "decorative candle": 865, "decorative dish": 757, "decorative egg": 862, "decorative item": 853, "decorative plate": 383, "decorative platter": 370, "deoderant": 159, "desk": 36, "desk drawer": 475, "desk mat": 473, "desser": 829, "dish brush": 248, "dish cover": 368, "dish rack": 581, "dish scrubber": 261, "dishes": 733, "dishwasher": 8, "display board": 444, "display case": 540, "display platter": 877, "dog": 701, "dog bed": 858, "dog bowl": 697, "dog cage": 703, "dog toy": 736, "doily": 892, "doll": 99, "doll house": 486, "dollar bill": 810, "dolly": 219, "door": 28, "door window reflection": 642, "door curtain": 663, "door facing trimreflection": 657, "door frame": 615, "door knob": 27, "door lock": 646, "door way": 609, "door way arch": 686, "doorreflection": 658, "drain": 567, "drawer": 174, "drawer handle": 371, "drawer knob": 833, "dresser": 169, "drum": 145, "drying rack": 262, "drying stand": 554, "duck": 887, "duster": 115, "dvd": 197, "dvd player": 170, "dvds": 325, "earplugs": 152, "educational display": 419, "eggplant": 888, "eggs": 699, "electric box": 550, "electric mixer": 369, "electric toothbrush": 142, "electric toothbrush base": 629, "electrical kettle": 738, "electrical outlet": 98, "electronic drumset": 816, 
"envelope": 476, "envelopes": 843, "eraser": 100, "ethernet jack": 118, "excercise ball": 155, "excercise equipment": 457, "excercise machine": 558, "exit sign": 86, "eye glasses": 335, "eyeball plastic ball": 787, "face wash cream": 665, "fan": 74, "faucet": 9, "faucet handle": 568, "fax machine": 68, "fiberglass case": 543, "figurine": 836, "file": 75, "file box": 734, "file container": 469, "file holder": 410, "file pad": 619, "file stand": 479, "filing shelves": 401, "fire alarm": 338, "fire extinguisher": 10, "fireplace": 372, "fish tank": 782, "flag": 405, "flashcard": 201, "flashlight": 666, "flask": 693, "flask set": 760, "flatbed scanner": 537, "flipboard": 106, "floor": 11, "floor mat": 143, "floor trim": 868, "flower": 81, "flower basket": 595, "flower box": 471, "flower pot": 146, "folder": 69, "folders": 213, "food processor": 715, "food wrapped on a tray": 752, "foosball table": 510, "foot rest": 163, "football": 166, "fork": 349, "framed certificate": 544, "fruit": 286, "fruit basket": 728, "fruit platter": 596, "fruit stand": 681, "fruitplate": 682, "frying pan": 318, "furnace": 551, "furniture": 524, "game system": 516, "game table": 429, "garage door": 850, "garbage bag": 269, "garbage bin": 12, "garlic": 763, "gate": 223, "gift wrapping": 351, "gift wrapping roll": 185, "glass": 612, "glass baking dish": 316, "glass box": 622, "glass container": 636, "glass dish": 721, "glass pane": 412, "glass pot": 304, "glass rack": 216, "glass set": 705, "glass ware": 889, "globe": 347, "globe stand": 466, "glove": 729, "gold piece": 880, "grandfather clock": 462, "grapefruit": 597, "green screen": 57, "grill": 700, "guitar": 300, "guitar case": 771, "hair brush": 120, "hair dryer": 577, "hamburger bun": 748, "hammer": 883, "hand blender": 599, "hand fan": 845, "hand sanitizer": 76, "hand sanitizer dispenser": 505, "hand sculpture": 309, "hand weight": 838, "handle": 758, "hanger": 211, "hangers": 209, "hanging hooks": 96, "hat": 193, "head phone": 586, "head phones": 584, "headband": 802, "headboard": 161, "headphones": 160, "heater": 111, "heating tray": 714, "hockey glove": 194, "hockey stick": 195, "hole puncher": 61, "hookah": 187, "hooks": 95, "hoola hoop": 512, "horse toy": 513, "hot dogs": 722, "hot water heater": 228, "humidifier": 340, "id card": 478, "incense candle": 644, "incense holder": 672, "indoor fountain": 863, "inkwell": 824, "ipad": 386, "iphone": 296, "ipod": 310, "ipod dock": 817, "iron box": 557, "iron grill": 463, "ironing board": 313, "jacket": 324, "jar": 70, "jeans": 849, "jersey": 311, "jug": 687, "juicer": 746, "karate belts": 775, "key": 378, "keyboard": 47, "kichen towel": 264, "kinect": 823, "kitchen container plastic": 739, "kitchen island": 456, "kitchen items": 253, "kitchen utensil": 266, "kitchen utensils": 753, "kiwi": 598, "knife": 259, "knife rack": 258, "knob": 652, "knobs": 600, "label": 759, "ladder": 48, "ladel": 254, "lamp": 144, "lamp shade": 859, "laptop": 37, "laundry basket": 164, "laundry detergent jug": 500, "lazy susan": 679, "lectern": 882, "leg of a girl": 409, "lego": 805, "lemon": 765, "letter stand": 620, "lid": 533, "lid of jar": 445, "life jacket": 784, "light": 62, "light bulb": 566, "light switch": 301, "light switchreflection": 659, "lighting track": 354, "lint comb": 798, "lint roller": 178, "litter box": 606, "lock": 180, "luggage": 783, "luggage rack": 803, "lunch bag": 407, "machine": 220, "magazine": 71, "magazine holder": 468, "magic 8ball": 839, "magnet": 23, "mail shelf": 65, "mail tray": 618, "mailshelf": 153, 
"makeup brush": 121, "manilla envelope": 63, "mantel": 58, "mantle": 874, "map": 107, "mask": 191, "matchbox": 884, "mattress": 576, "measuring cup": 730, "medal": 776, "medicine tube": 660, "mellon": 707, "menorah": 336, "mens suit": 167, "mens tie": 315, "mezuza": 531, "microphone": 818, "microphone stand": 821, "microwave": 13, "mini display platform": 869, "mirror": 122, "model boat": 789, "modem": 91, "money": 482, "monitor": 49, "mortar and pestle": 357, "motion camera": 52, "mouse": 103, "mouse pad": 539, "muffins": 229, "mug hanger": 749, "mug holder": 744, "music keyboard": 819, "music stand": 820, "music stereo": 442, "nailclipper": 569, "napkin": 244, "napkin dispenser": 230, "napkin holder": 235, "napkin ring": 350, "necklace": 341, "necklace holder": 779, "newspapers": 873, "night stand": 158, "notebook": 210, "notecards": 438, "oil container": 683, "onion": 322, "orange": 709, "orange juicer": 745, "orange plastic cap": 704, "ornamental item": 527, "ornamental plant": 459, "ornamental pot": 735, "ottoman": 359, "oven": 238, "oven handle": 366, "oven mitt": 754, "package of bedroom sheets": 807, "package of bottled water": 875, "package of water": 684, "pan": 589, "paper": 15, "paper bundle": 534, "paper cutter": 108, "paper holder": 470, "paper rack": 77, "paper towel": 113, "paper towel dispenser": 14, "paper towel holder": 281, "paper tray": 538, "paper weight": 480, "papers": 483, "peach": 710, "pen": 97, "pen box": 190, "pen cup": 786, "pen holder": 464, "pen stand": 314, "pencil": 396, "pencil holder": 101, "pepper": 885, "pepper grinder": 579, "pepper shaker": 455, "perfume": 655, "perfume box": 654, "person": 331, "personal care liquid": 649, "phone jack": 363, "photo": 508, "photo album": 864, "piano": 298, "piano bench": 460, "picture": 64, "picture of fish": 394, "piece of wood": 552, "pig": 811, "pillow": 119, "pineapple": 740, "ping pong ball": 623, "ping pong racket": 624, "ping pong racquet": 627, "ping pong table": 625, "pipe": 41, "pitcher": 273, "pizza box": 274, "placard": 420, "placemat": 154, "plant": 82, "plant pot": 239, "plaque": 231, "plastic bowl": 320, "plastic box": 395, "plastic chair": 489, "plastic crate": 402, "plastic cup of coffee": 621, "plastic dish": 723, "plastic rack": 403, "plastic toy container": 514, "plastic tray": 404, "plastic tub": 232, "plate": 233, "platter": 129, "playpen": 815, "pool ball": 520, "pool sticks": 517, "pool table": 515, "poster board": 406, "poster case": 116, "pot": 16, "potato": 323, "power surge": 451, "printer": 66, "projector": 90, "projector screen": 53, "puppy toy": 791, "purse": 181, "pyramid": 472, "quill": 793, "quilt": 575, "radiator": 236, "radio": 188, "rags": 852, "railing": 497, "range hood": 380, "razor": 632, "reflection of window shutters": 861, "refridgerator": 17, "remote control": 175, "roll of paper towels": 449, "roll of toilet paper": 674, "rolled carpet": 571, "rolled up rug": 891, "room divider": 87, "rope": 560, "router": 303, "rug": 130, "ruler": 72, "salt and pepper": 737, "salt container": 361, "salt shaker": 332, "saucer": 217, "scale": 639, "scarf": 240, "scenary": 832, "scissor": 29, "sculpture": 294, "sculpture of the chrysler building": 846, "sculpture of the eiffel tower": 847, "sculpture of the empire state building": 848, "security camera": 212, "server": 360, "serving dish": 867, "serving platter": 876, "serving spoon": 249, "sewing machine": 890, "shaver": 171, "shaving cream": 570, "sheet": 559, "sheet music": 461, "sheet of metal": 287, "sheets": 348, "shelf frame": 855, 
"shelves": 42, "shirts in hanger": 302, "shoe": 149, "shoe hanger": 834, "shoe rack": 614, "shoelace": 785, "shofar": 546, "shopping baskets": 222, "shopping cart": 319, "shorts": 192, "shovel": 607, "show piece": 454, "shower base": 667, "shower cap": 132, "shower curtain": 123, "shower head": 650, "shower hose": 669, "shower knob": 651, "shower pipe": 664, "shower tube": 675, "sifter": 727, "sign": 208, "sink": 24, "sink protector": 270, "six pack of beer": 382, "sleeping bag": 841, "slide": 814, "soap": 133, "soap box": 671, "soap dish": 638, "soap holder": 506, "soap stand": 640, "soap tray": 662, "soccer ball": 837, "sock": 165, "sofa": 83, "soft toy": 422, "soft toy group": 421, "spatula": 255, "speaker": 54, "spice bottle": 272, "spice rack": 241, "spice stand": 256, "sponge": 250, "spoon": 283, "spoon sets": 592, "spoon stand": 282, "spot light": 353, "squash": 717, "squeeze tube": 131, "stack of plates": 358, "stacked bins boxes": 446, "stacked chairs": 43, "stacked plastic racks": 447, "stairs": 215, "stamp": 114, "stand": 50, "staple remover": 202, "stapler": 67, "steamer": 742, "step stool": 276, "stereo": 84, "stick": 529, "sticker": 725, "sticks": 561, "stones": 578, "stool": 150, "storage bin": 812, "storage chest": 813, "storage rack": 448, "storage shelvesbooks": 430, "storage space": 645, "stove": 242, "stove burner": 18, "stroller": 373, "stuffed animal": 177, "styrofoam object": 44, "suger jar": 741, "suit jacket": 379, "suitcase": 199, "surge protect": 611, "surge protector": 326, "switchbox": 364, "table": 19, "table runner": 375, "tablecloth": 292, "tag": 218, "tape": 109, "tape dispenser": 30, "tea box": 879, "tea cannister": 769, "tea coaster": 711, "tea kettle": 243, "tea pot": 678, "telephone": 32, "telephone cord": 31, "telescope": 467, "television": 172, "tennis racket": 626, "tent": 835, "thermostat": 110, "throw": 872, "tin foil": 265, "tissue": 648, "tissue box": 138, "tissue roll": 764, "toaster": 251, "toaster oven": 275, "toilet": 124, "toilet bowl brush": 565, "toilet brush": 630, "toilet paper": 139, "toilet paper holder": 647, "toilet plunger": 563, "toiletries": 631, "toiletries bag": 125, "toothbrush": 127, "toothbrush holder": 126, "toothpaste": 128, "toothpaste holder": 670, "torah": 894, "torch": 696, "towel": 135, "towel rod": 134, "toy": 389, "toy apple": 830, "toy bin": 417, "toy boat": 795, "toy bottle": 182, "toy box": 434, "toy car": 415, "toy cash register": 532, "toy chair": 487, "toy chest": 801, "toy cube": 423, "toy cuboid": 431, "toy cylinder": 424, "toy dog": 831, "toy doll": 465, "toy horse": 828, "toy house": 490, "toy kitchen": 751, "toy phone": 435, "toy plane": 481, "toy pyramid": 788, "toy rectangle": 425, "toy shelf": 416, "toy sink": 436, "toy sofa": 488, "toy stroller": 854, "toy table": 526, "toy tree": 432, "toy triangle": 426, "toy truck": 391, "toy trucks": 439, "toyhouse": 437, "toys basket": 390, "toys box": 496, "toys rack": 443, "toys shelf": 418, "track light": 33, "trampoline": 521, "travel bag": 799, "tray": 179, "treadmill": 458, "tree sculpture": 541, "tricycle": 522, "trinket": 844, "trivet": 257, "trolley": 504, "trolly": 221, "trophy": 547, "tub of tupperware": 604, "tumbler": 327, "tuna cans": 590, "tupperware": 762, "tv stand": 291, "typewriter": 376, "umbrella": 203, "unknown": 20, "urn": 151, "usb drive": 587, "utensil": 267, "utensil container": 362, "utensils": 317, "vacuum cleaner": 306, "vase": 78, "vasoline": 184, "vegetable": 724, "vegetable peeler": 277, "vegetables": 719, "vessel": 263, "vessel 
set": 706, "vessels": 601, "vhs tapes": 871, "video game": 519, "vuvuzela": 196, "waffle maker": 288, "walkie talkie": 398, "walkietalkie": 866, "wall": 21, "wall decoration": 186, "wall divider": 800, "wall hand sanitizer dispenser": 440, "wall stand": 295, "wallet": 661, "wardrobe": 772, "washing machine": 278, "watch": 384, "water carboy": 102, "water cooler": 509, "water dispenser": 507, "water filter": 731, "water fountain": 339, "water heater": 588, "water purifier": 93, "watermellon": 718, "webcam": 781, "whisk": 367, "whiteboard": 45, "whiteboard eraser": 388, "whiteboard marker": 117, "wii": 523, "window": 59, "window box": 778, "window cover": 573, "window frame": 477, "window seat": 777, "window shelf": 668, "wine": 766, "wine accessory": 732, "wine bottle": 333, "wine glass": 293, "wine rack": 299, "wire": 92, "wire basket": 603, "wire board": 792, "wire rack": 105, "wire tray": 768, "wooden container": 321, "wooden kitchen utensils": 284, "wooden pillar": 553, "wooden plank": 698, "wooden planks": 562, "wooden toy": 433, "wooden utensil": 756, "wooden utensils": 346, "wreathe": 881, "xbox": 628, "yarmulka": 806, "yellow pepper": 886, "yoga mat": 205} -------------------------------------------------------------------------------- /data/db_info/pascal_map.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/MTMamba/17d132d2a7a1ef80c35ffe40c8dd3ff72fa07d77/data/db_info/pascal_map.npy -------------------------------------------------------------------------------- /data/db_info/pascal_part.json: -------------------------------------------------------------------------------- 1 | { 2 | "1": {"body": 1 , "engine_1": 11, "engine_10": 20, "engine_2": 12, "engine_3": 13, "engine_4": 14, "engine_5": 15, "engine_6": 16, "engine_7": 17, "engine_8": 18, "engine_9": 19, "lwing": 3, "rwing": 4, "stern": 2, "tail": 5, "wheel_1": 21, "wheel_10": 30, "wheel_2": 22, "wheel_3": 23, "wheel_4": 24, "wheel_5": 25, "wheel_6": 26, "wheel_7": 27, "wheel_8": 28, "wheel_9": 29}, 3 | "2": {"bwheel": 2 , "chainwheel": 5, "fwheel": 1, "handlebar": 4, "headlight_1": 11, "headlight_10": 20, "headlight_2": 12, "headlight_3": 13, "headlight_4": 14, "headlight_5": 15, "headlight_6": 16, "headlight_7": 17, "headlight_8": 18, "headlight_9": 19, "saddle": 3}, 4 | "3": {"beak": 4 , "head": 1, "leye": 2, "lfoot": 10, "lleg": 9, "lwing": 7, "neck": 6, "reye": 3, "rfoot": 12, "rleg": 11, "rwing": 8, "tail": 13, "torso": 5}, 5 | "4": {}, 6 | "5": {"body": 2 , "cap": 1}, 7 | "6": {"backside": 4 , "bliplate": 9, "door_1": 11, "door_10": 20, "door_2": 12, "door_3": 13, "door_4": 14, "door_5": 15, "door_6": 16, "door_7": 17, "door_8": 18, "door_9": 19, "fliplate": 8, "frontside": 1, "headlight_1": 31, "headlight_10": 40, "headlight_2": 32, "headlight_3": 33, "headlight_4": 34, "headlight_5": 35, "headlight_6": 36, "headlight_7": 37, "headlight_8": 38, "headlight_9": 39, "leftmirror": 6, "leftside": 2, "rightmirror": 7, "rightside": 3, "roofside": 5, "wheel_1": 21, "wheel_10": 30, "wheel_2": 22, "wheel_3": 23, "wheel_4": 24, "wheel_5": 25, "wheel_6": 26, "wheel_7": 27, "wheel_8": 28, "wheel_9": 29, "window_1": 41, "window_10": 50, "window_11": 51, "window_12": 52, "window_13": 53, "window_14": 54, "window_15": 55, "window_16": 56, "window_17": 57, "window_18": 58, "window_19": 59, "window_2": 42, "window_20": 60, "window_3": 43, "window_4": 44, "window_5": 45, "window_6": 46, "window_7": 47, "window_8": 48, "window_9": 49}, 8 | "7": 
{"backside": 4 , "bliplate": 9, "door_1": 11, "door_10": 20, "door_2": 12, "door_3": 13, "door_4": 14, "door_5": 15, "door_6": 16, "door_7": 17, "door_8": 18, "door_9": 19, "fliplate": 8, "frontside": 1, "headlight_1": 31, "headlight_10": 40, "headlight_2": 32, "headlight_3": 33, "headlight_4": 34, "headlight_5": 35, "headlight_6": 36, "headlight_7": 37, "headlight_8": 38, "headlight_9": 39, "leftmirror": 6, "leftside": 2, "rightmirror": 7, "rightside": 3, "roofside": 5, "wheel_1": 21, "wheel_10": 30, "wheel_2": 22, "wheel_3": 23, "wheel_4": 24, "wheel_5": 25, "wheel_6": 26, "wheel_7": 27, "wheel_8": 28, "wheel_9": 29, "window_1": 41, "window_10": 50, "window_11": 51, "window_12": 52, "window_13": 53, "window_14": 54, "window_15": 55, "window_16": 56, "window_17": 57, "window_18": 58, "window_19": 59, "window_2": 42, "window_20": 60, "window_3": 43, "window_4": 44, "window_5": 45, "window_6": 46, "window_7": 47, "window_8": 48, "window_9": 49}, 9 | "8": {"head": 1 , "lbleg": 13, "lbpa": 14, "lear": 4, "leye": 2, "lfleg": 9, "lfpa": 10, "neck": 8, "nose": 6, "rbleg": 15, "rbpa": 16, "rear": 5, "reye": 3, "rfleg": 11, "rfpa": 12, "tail": 17, "torso": 7}, 10 | "9": {}, 11 | "10": {"head": 1 , "lblleg": 16, "lbuleg": 15, "lear": 4, "leye": 2, "lflleg": 12, "lfuleg": 11, "lhorn": 7, "muzzle": 6, "neck": 10, "rblleg": 18, "rbuleg": 17, "rear": 5, "reye": 3, "rflleg": 14, "rfuleg": 13, "rhorn": 8, "tail": 19, "torso": 9}, 12 | "11": {}, 13 | "12": {"head": 1 , "lbleg": 13, "lbpa": 14, "lear": 4, "leye": 2, "lfleg": 9, "lfpa": 10, "muzzle": 20, "neck": 8, "nose": 6, "rbleg": 15, "rbpa": 16, "rear": 5, "reye": 3, "rfleg": 11, "rfpa": 12, "tail": 17, "torso": 7}, 14 | "13": {"head": 1 , "lbho": 32, "lblleg": 16, "lbuleg": 15, "lear": 4, "leye": 2, "lfho": 30, "lflleg": 12, "lfuleg": 11, "muzzle": 6, "neck": 10, "rbho": 33, "rblleg": 18, "rbuleg": 17, "rear": 5, "reye": 3, "rfho": 31, "rflleg": 14, "rfuleg": 13, "tail": 19, "torso": 9}, 15 | "14": {"bwheel": 2 , "fwheel": 1, "handlebar": 3, "headlight_1": 11, "headlight_10": 20, "headlight_2": 12, "headlight_3": 13, "headlight_4": 14, "headlight_5": 15, "headlight_6": 16, "headlight_7": 17, "headlight_8": 18, "headlight_9": 19, "saddle": 4}, 16 | "15": {"hair": 10 , "head": 1, "lear": 4, "lebrow": 6, "leye": 2, "lfoot": 21, "lhand": 15, "llarm": 13, "llleg": 19, "luarm": 14, "luleg": 20, "mouth": 9, "neck": 12, "nose": 8, "rear": 5, "rebrow": 7, "reye": 3, "rfoot": 24, "rhand": 18, "rlarm": 16, "rlleg": 22, "ruarm": 17, "ruleg": 23, "torso": 11}, 17 | "16": {"plant": 2 , "pot": 1}, 18 | "17": {"head": 1 , "lblleg": 16, "lbuleg": 15, "lear": 4, "leye": 2, "lflleg": 12, "lfuleg": 11, "lhorn": 7, "muzzle": 6, "neck": 10, "rblleg": 18, "rbuleg": 17, "rear": 5, "reye": 3, "rflleg": 14, "rfuleg": 13, "rhorn": 8, "tail": 19, "torso": 9}, 19 | "18": {}, 20 | "19": {"cbackside_1": 61 , "cbackside_10": 70, "cbackside_2": 62, "cbackside_3": 63, "cbackside_4": 64, "cbackside_5": 65, "cbackside_6": 66, "cbackside_7": 67, "cbackside_8": 68, "cbackside_9": 69, "cfrontside_1": 31, "cfrontside_10": 40, "cfrontside_2": 32, "cfrontside_3": 33, "cfrontside_4": 34, "cfrontside_5": 35, "cfrontside_6": 36, "cfrontside_7": 37, "cfrontside_8": 38, "cfrontside_9": 39, "cleftside_1": 41, "cleftside_10": 50, "cleftside_2": 42, "cleftside_3": 43, "cleftside_4": 44, "cleftside_5": 45, "cleftside_6": 46, "cleftside_7": 47, "cleftside_8": 48, "cleftside_9": 49, "coach_1": 21, "coach_10": 30, "coach_2": 22, "coach_3": 23, "coach_4": 24, "coach_5": 25, "coach_6": 26, "coach_7": 27, 
"coach_8": 28, "coach_9": 29, "crightside_1": 51, "crightside_10": 60, "crightside_2": 52, "crightside_3": 53, "crightside_4": 54, "crightside_5": 55, "crightside_6": 56, "crightside_7": 57, "crightside_8": 58, "crightside_9": 59, "croofside_1": 71, "croofside_10": 80, "croofside_2": 72, "croofside_3": 73, "croofside_4": 74, "croofside_5": 75, "croofside_6": 76, "croofside_7": 77, "croofside_8": 78, "croofside_9": 79, "hbackside": 5, "head": 1, "headlight_1": 11, "headlight_10": 20, "headlight_2": 12, "headlight_3": 13, "headlight_4": 14, "headlight_5": 15, "headlight_6": 16, "headlight_7": 17, "headlight_8": 18, "headlight_9": 19, "hfrontside": 2, "hleftside": 3, "hrightside": 4, "hroofside": 6}, 21 | "20": {"screen": 1 } 22 | } 23 | -------------------------------------------------------------------------------- /data/mat2png.py: -------------------------------------------------------------------------------- 1 | import os, glob, cv2, imageio 2 | import scipy.io as sio 3 | from skimage.morphology import thin 4 | import numpy as np 5 | from PIL import Image 6 | import tqdm 7 | 8 | gt_dir = '/dataset/baijionglin/dataset/PASCALContext/pascal-context/trainval' 9 | img_dir = '/dataset/baijionglin/dataset/PASCALContext/JPEGImages' 10 | save_dir = '/dataset/baijionglin/dataset/PASCALContext/edge' 11 | ids = sorted([os.path.split(file)[-1] for file in glob.glob(os.path.join(gt_dir, "*.mat"))]) 12 | 13 | for i in tqdm.tqdm(ids): 14 | i = os.path.splitext(i)[0] 15 | # print(i) 16 | gt = os.path.join(gt_dir, "{}.mat".format(i)) 17 | _tmp = sio.loadmat(gt) 18 | _edge = cv2.Laplacian(_tmp['LabelMap'], cv2.CV_64F) 19 | _edge = thin(np.abs(_edge) > 0).astype(np.float32) 20 | 21 | img_path = os.path.join(img_dir, "{}.jpg".format(i)) 22 | _img = np.array(Image.open(img_path).convert('RGB')).astype(np.float32) 23 | if _edge.shape != _img.shape[:2]: 24 | _edge = cv2.resize(_edge, _img.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) 25 | _edge = np.array(_edge) * 255 26 | imageio.imwrite(os.path.join(save_dir, "{}.png".format(i)), _edge.astype(np.uint8)) 27 | 28 | # break -------------------------------------------------------------------------------- /data/nyud.py: -------------------------------------------------------------------------------- 1 | # This code is referenced from 2 | # https://github.com/facebookresearch/astmt/ 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # All rights reserved. 6 | # 7 | # License: Attribution-NonCommercial 4.0 International 8 | 9 | import os 10 | import sys 11 | import tarfile 12 | import cv2 13 | 14 | from PIL import Image 15 | import numpy as np 16 | import torch.utils.data as data 17 | import scipy.io as sio 18 | from six.moves import urllib 19 | 20 | class NYUD_MT(data.Dataset): 21 | """ 22 | from MTI-Net, changed for using ATRC data 23 | NYUD dataset for multi-task learning. 24 | Includes semantic segmentation and depth prediction. 
25 | 26 | Data can also be found at: 27 | https://drive.google.com/file/d/14EAEMXmd3zs2hIMY63UhHPSFPDAkiTzw/view?usp=sharing 28 | 29 | """ 30 | 31 | 32 | def __init__(self, 33 | root=None, 34 | download=False, 35 | split='val', 36 | transform=None, 37 | retname=True, 38 | overfit=False, 39 | do_edge=False, 40 | do_semseg=False, 41 | do_normals=False, 42 | do_depth=False, 43 | ): 44 | 45 | self.root = root 46 | 47 | if download: 48 | raise NotImplementedError 49 | 50 | self.transform = transform 51 | 52 | if isinstance(split, str): 53 | self.split = [split] 54 | else: 55 | split.sort() 56 | self.split = split 57 | 58 | self.retname = retname 59 | 60 | # Original Images 61 | self.im_ids = [] 62 | self.images = [] 63 | _image_dir = os.path.join(root, 'images') 64 | 65 | # Edge Detection 66 | self.do_edge = do_edge 67 | self.edges = [] 68 | _edge_gt_dir = os.path.join(root, 'edge') 69 | 70 | # Semantic segmentation 71 | self.do_semseg = do_semseg 72 | self.semsegs = [] 73 | _semseg_gt_dir = os.path.join(root, 'segmentation') 74 | 75 | # Surface Normals 76 | self.do_normals = do_normals 77 | self.normals = [] 78 | _normal_gt_dir = os.path.join(root, 'normals') 79 | 80 | # Depth 81 | self.do_depth = do_depth 82 | self.depths = [] 83 | _depth_gt_dir = os.path.join(root, 'depth') 84 | 85 | # train/val/test splits are pre-cut 86 | _splits_dir = os.path.join(root, 'gt_sets') 87 | 88 | print('Initializing dataloader for NYUD {} set'.format(''.join(self.split))) 89 | for splt in self.split: 90 | with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), 'r') as f: 91 | lines = f.read().splitlines() 92 | 93 | for ii, line in enumerate(lines): 94 | 95 | # Images 96 | _image = os.path.join(_image_dir, line + '.png') 97 | assert os.path.isfile(_image) 98 | self.images.append(_image) 99 | self.im_ids.append(line.rstrip('\n')) 100 | 101 | # Edges 102 | _edge = os.path.join(self.root, _edge_gt_dir, line + '.png') 103 | assert os.path.isfile(_edge) 104 | self.edges.append(_edge) 105 | 106 | # Semantic Segmentation 107 | _semseg = os.path.join(self.root, _semseg_gt_dir, line + '.png') 108 | assert os.path.isfile(_semseg) 109 | self.semsegs.append(_semseg) 110 | 111 | # Surface Normals 112 | _normal = os.path.join(self.root, _normal_gt_dir, line + '.png') 113 | assert os.path.isfile(_normal) 114 | self.normals.append(_normal) 115 | 116 | # Depth Prediction 117 | _depth = os.path.join(self.root, _depth_gt_dir, line + '.npy') 118 | assert os.path.isfile(_depth) 119 | self.depths.append(_depth) 120 | 121 | if self.do_edge: 122 | assert (len(self.images) == len(self.edges)) 123 | if self.do_semseg: 124 | assert (len(self.images) == len(self.semsegs)) 125 | if self.do_depth: 126 | assert (len(self.images) == len(self.depths)) 127 | if self.do_normals: 128 | assert (len(self.images) == len(self.normals)) 129 | 130 | # Uncomment to overfit to one image 131 | if overfit: 132 | n_of = 64 133 | self.images = self.images[:n_of] 134 | self.im_ids = self.im_ids[:n_of] 135 | 136 | # Display stats 137 | print('Number of dataset images: {:d}'.format(len(self.images))) 138 | 139 | def __getitem__(self, index): 140 | sample = {} 141 | 142 | _img = self._load_img(index) 143 | sample['image'] = _img 144 | 145 | if self.do_edge: 146 | _edge = self._load_edge(index) 147 | if _edge.shape[:2] != _img.shape[:2]: 148 | _edge = cv2.resize(_edge, _img.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) 149 | sample['edge'] = _edge 150 | 151 | if self.do_semseg: 152 | _semseg = self._load_semseg(index) 153 | if _semseg.shape[:2] != 
_img.shape[:2]: 154 | print('RESHAPE SEMSEG') 155 | _semseg = cv2.resize(_semseg, _img.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) 156 | sample['semseg'] = _semseg 157 | 158 | if self.do_normals: 159 | _normals = self._load_normals(index) 160 | if _normals.shape[:2] != _img.shape[:2]: 161 | _normals = cv2.resize(_normals, _img.shape[:2][::-1], interpolation=cv2.INTER_CUBIC) 162 | sample['normals'] = _normals 163 | 164 | if self.do_depth: 165 | _depth = self._load_depth(index) 166 | if _depth.shape[:2] != _img.shape[:2]: 167 | print('RESHAPE DEPTH') 168 | _depth = cv2.resize(_depth, _img.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) 169 | sample['depth'] = _depth 170 | 171 | if self.retname: 172 | sample['meta'] = {'img_name': str(self.im_ids[index]), 173 | 'img_size': (_img.shape[0], _img.shape[1])} 174 | 175 | if self.transform is not None: 176 | sample = self.transform(sample) 177 | sample['index'] = index 178 | return sample 179 | 180 | def __len__(self): 181 | return len(self.images) 182 | 183 | def _load_img(self, index): 184 | _img = Image.open(self.images[index]).convert('RGB') 185 | _img = np.array(_img, dtype=np.float32, copy=False) 186 | return _img 187 | 188 | def _load_edge(self, index): 189 | _edge = Image.open(self.edges[index]) 190 | _edge = np.expand_dims(np.array(_edge, dtype=np.float32, copy=False), axis=2) / 255. 191 | return _edge 192 | 193 | def _load_semseg(self, index): 194 | # Note: We ignore the background class (40-way classification), as in related work: 195 | _semseg = Image.open(self.semsegs[index]) 196 | _semseg = np.expand_dims(np.array(_semseg, dtype=np.float32, copy=False), axis=2) - 1 197 | _semseg[_semseg == -1] = 255 198 | return _semseg 199 | 200 | def _load_depth(self, index): 201 | _depth = np.load(self.depths[index]) 202 | _depth = np.expand_dims(_depth.astype(np.float32), axis=2) 203 | return _depth 204 | 205 | def _load_normals(self, index): 206 | _normals = Image.open(self.normals[index]) 207 | _normals = 2 * np.array(_normals, dtype=np.float32, copy=False) / 255. - 1 208 | return _normals 209 | 210 | def __str__(self): 211 | return 'NYUD Multitask (split=' + str(self.split) + ')' 212 | 213 | -------------------------------------------------------------------------------- /data/pascal_context.py: -------------------------------------------------------------------------------- 1 | # This code is referenced from 2 | # https://github.com/facebookresearch/astmt/ 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # All rights reserved. 6 | # 7 | # License: Attribution-NonCommercial 4.0 International 8 | 9 | import os 10 | import sys 11 | import tarfile 12 | import json 13 | import cv2 14 | 15 | import numpy as np 16 | import scipy.io as sio 17 | import torch.utils.data as data 18 | from PIL import Image 19 | from skimage.morphology import thin 20 | from six.moves import urllib 21 | 22 | from configs.mypath import PROJECT_ROOT_DIR 23 | 24 | class PASCALContext(data.Dataset): 25 | """ 26 | from MTI-Net 27 | PASCAL-Context dataset, for multiple tasks 28 | Included tasks: 29 | 1. Edge detection, 30 | 2. Semantic Segmentation, 31 | 3. Human Part Segmentation, 32 | 4. Surface Normal prediction (distilled), 33 | 5. 
Saliency (distilled) 34 | """ 35 | 36 | URL = 'https://data.vision.ee.ethz.ch/kmaninis/share/MTL/PASCAL_MT.tgz' 37 | FILE = 'PASCAL_MT.tgz' 38 | 39 | HUMAN_PART = {1: {'hair': 1, 'head': 1, 'lear': 1, 'lebrow': 1, 'leye': 1, 'lfoot': 1, 40 | 'lhand': 1, 'llarm': 1, 'llleg': 1, 'luarm': 1, 'luleg': 1, 'mouth': 1, 41 | 'neck': 1, 'nose': 1, 'rear': 1, 'rebrow': 1, 'reye': 1, 'rfoot': 1, 42 | 'rhand': 1, 'rlarm': 1, 'rlleg': 1, 'ruarm': 1, 'ruleg': 1, 'torso': 1}, 43 | 4: {'hair': 1, 'head': 1, 'lear': 1, 'lebrow': 1, 'leye': 1, 'lfoot': 4, 44 | 'lhand': 3, 'llarm': 3, 'llleg': 4, 'luarm': 3, 'luleg': 4, 'mouth': 1, 45 | 'neck': 2, 'nose': 1, 'rear': 1, 'rebrow': 1, 'reye': 1, 'rfoot': 4, 46 | 'rhand': 3, 'rlarm': 3, 'rlleg': 4, 'ruarm': 3, 'ruleg': 4, 'torso': 2}, 47 | 6: {'hair': 1, 'head': 1, 'lear': 1, 'lebrow': 1, 'leye': 1, 'lfoot': 6, 48 | 'lhand': 4, 'llarm': 4, 'llleg': 6, 'luarm': 3, 'luleg': 5, 'mouth': 1, 49 | 'neck': 2, 'nose': 1, 'rear': 1, 'rebrow': 1, 'reye': 1, 'rfoot': 6, 50 | 'rhand': 4, 'rlarm': 4, 'rlleg': 6, 'ruarm': 3, 'ruleg': 5, 'torso': 2}, 51 | 14: {'hair': 1, 'head': 1, 'lear': 1, 'lebrow': 1, 'leye': 1, 'lfoot': 14, 52 | 'lhand': 8, 'llarm': 7, 'llleg': 13, 'luarm': 6, 'luleg': 12, 'mouth': 1, 53 | 'neck': 2, 'nose': 1, 'rear': 1, 'rebrow': 1, 'reye': 1, 'rfoot': 11, 54 | 'rhand': 5, 'rlarm': 4, 'rlleg': 10, 'ruarm': 3, 'ruleg': 9, 'torso': 2} 55 | } 56 | 57 | VOC_CATEGORY_NAMES = ['background', 58 | 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 59 | 'bus', 'car', 'cat', 'chair', 'cow', 60 | 'diningtable', 'dog', 'horse', 'motorbike', 'person', 61 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] 62 | 63 | CONTEXT_CATEGORY_LABELS = [0, 64 | 2, 23, 25, 31, 34, 65 | 45, 59, 65, 72, 98, 66 | 397, 113, 207, 258, 284, 67 | 308, 347, 368, 416, 427] 68 | 69 | def __init__(self, 70 | root=None, 71 | download=True, 72 | split='val', 73 | transform=None, 74 | area_thres=0, 75 | retname=True, 76 | overfit=False, 77 | do_edge=True, 78 | do_human_parts=False, 79 | do_semseg=False, 80 | do_normals=False, 81 | do_sal=False, 82 | num_human_parts=6, 83 | ): 84 | 85 | self.root = root 86 | if download: 87 | self._download() 88 | 89 | image_dir = os.path.join(self.root, 'JPEGImages') 90 | self.transform = transform 91 | 92 | if isinstance(split, str): 93 | self.split = [split] 94 | else: 95 | split.sort() 96 | self.split = split 97 | 98 | self.area_thres = area_thres 99 | self.retname = retname 100 | 101 | # Edge Detection 102 | self.do_edge = do_edge 103 | self.edges = [] 104 | edge_gt_dir = os.path.join(self.root, 'pascal-context', 'trainval') 105 | 106 | # Semantic Segmentation 107 | self.do_semseg = do_semseg 108 | self.semsegs = [] 109 | 110 | # Human Part Segmentation 111 | self.do_human_parts = do_human_parts 112 | part_gt_dir = os.path.join(self.root, 'human_parts') 113 | self.parts = [] 114 | self.human_parts_category = 15 115 | print(PROJECT_ROOT_DIR) 116 | self.cat_part = json.load(open(os.path.join(PROJECT_ROOT_DIR, 'data/db_info/pascal_part.json'), 'r')) 117 | self.cat_part["15"] = self.HUMAN_PART[num_human_parts] 118 | self.parts_file = os.path.join(os.path.join(self.root, 'ImageSets', 'Parts'), 119 | ''.join(self.split) + '.txt') 120 | 121 | # Surface Normal Estimation 122 | self.do_normals = do_normals 123 | _normal_gt_dir = os.path.join(self.root, 'normals_distill') 124 | self.normals = [] 125 | if self.do_normals: 126 | with open(os.path.join(PROJECT_ROOT_DIR, 'data/db_info/nyu_classes.json')) as f: 127 | cls_nyu = json.load(f) 128 | with 
open(os.path.join(PROJECT_ROOT_DIR, 'data/db_info/context_classes.json')) as f: 129 | cls_context = json.load(f) 130 | 131 | self.normals_valid_classes = [] 132 | for cl_nyu in cls_nyu: 133 | if cl_nyu in cls_context and cl_nyu != 'unknown': 134 | self.normals_valid_classes.append(cls_context[cl_nyu]) 135 | 136 | # Custom additions due to incompatibilities 137 | self.normals_valid_classes.append(cls_context['tvmonitor']) 138 | 139 | # Saliency 140 | self.do_sal = do_sal 141 | _sal_gt_dir = os.path.join(self.root, 'sal_distill') 142 | self.sals = [] 143 | 144 | # train/val/test splits are pre-cut 145 | _splits_dir = os.path.join(self.root, 'ImageSets', 'Context') 146 | 147 | self.im_ids = [] 148 | self.images = [] 149 | 150 | print("Initializing dataloader for PASCAL {} set".format(''.join(self.split))) 151 | for splt in self.split: 152 | with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f: 153 | lines = f.read().splitlines() 154 | 155 | for ii, line in enumerate(lines): 156 | # Images 157 | _image = os.path.join(image_dir, line + ".jpg") 158 | assert os.path.isfile(_image) 159 | self.images.append(_image) 160 | self.im_ids.append(line.rstrip('\n')) 161 | 162 | # Edges 163 | _edge = os.path.join(edge_gt_dir, line + ".mat") 164 | assert os.path.isfile(_edge) 165 | self.edges.append(_edge) 166 | 167 | # Semantic Segmentation 168 | _semseg = self._get_semseg_fname(line) 169 | assert os.path.isfile(_semseg) 170 | self.semsegs.append(_semseg) 171 | 172 | # Human Parts 173 | _human_part = os.path.join(self.root, part_gt_dir, line + ".mat") 174 | assert os.path.isfile(_human_part) 175 | self.parts.append(_human_part) 176 | 177 | _normal = os.path.join(self.root, _normal_gt_dir, line + ".png") 178 | assert os.path.isfile(_normal) 179 | self.normals.append(_normal) 180 | 181 | _sal = os.path.join(self.root, _sal_gt_dir, line + ".png") 182 | assert os.path.isfile(_sal) 183 | self.sals.append(_sal) 184 | 185 | if self.do_edge: 186 | assert (len(self.images) == len(self.edges)) 187 | if self.do_human_parts: 188 | assert (len(self.images) == len(self.parts)) 189 | if self.do_semseg: 190 | assert (len(self.images) == len(self.semsegs)) 191 | if self.do_normals: 192 | assert (len(self.images) == len(self.normals)) 193 | if self.do_sal: 194 | assert (len(self.images) == len(self.sals)) 195 | 196 | if not self._check_preprocess_parts(): 197 | print('Pre-processing PASCAL dataset for human parts, this will take long, but will be done only once.') 198 | self._preprocess_parts() 199 | 200 | if self.do_human_parts: 201 | # Find images which have human parts 202 | self.has_human_parts = [] 203 | for ii in range(len(self.im_ids)): 204 | if self.human_parts_category in self.part_obj_dict[self.im_ids[ii]]: 205 | self.has_human_parts.append(1) 206 | else: 207 | self.has_human_parts.append(0) 208 | 209 | # If the other tasks are disabled, select only the images that contain human parts, to allow batching 210 | if not self.do_edge and not self.do_semseg and not self.do_sal and not self.do_normals: 211 | print('Ignoring images that do not contain human parts') 212 | for i in range(len(self.parts) - 1, -1, -1): 213 | if self.has_human_parts[i] == 0: 214 | del self.im_ids[i] 215 | del self.images[i] 216 | del self.parts[i] 217 | del self.has_human_parts[i] 218 | print('Number of images with human parts: {:d}'.format(np.sum(self.has_human_parts))) 219 | 220 | # Overfit to n_of images 221 | if overfit: 222 | n_of = 64 223 | self.images = self.images[:n_of] 224 | self.im_ids = self.im_ids[:n_of] 225 
| if self.do_edge: 226 | self.edges = self.edges[:n_of] 227 | if self.do_semseg: 228 | self.semsegs = self.semsegs[:n_of] 229 | if self.do_human_parts: 230 | self.parts = self.parts[:n_of] 231 | if self.do_normals: 232 | self.normals = self.normals[:n_of] 233 | if self.do_sal: 234 | self.sals = self.sals[:n_of] 235 | 236 | # Display stats 237 | print('Number of dataset images: {:d}'.format(len(self.images))) 238 | 239 | def __getitem__(self, index): 240 | sample = {} 241 | 242 | _img = self._load_img(index) 243 | sample['image'] = _img 244 | 245 | if self.do_edge: 246 | _edge = self._load_edge(index) 247 | if _edge.shape != _img.shape[:2]: 248 | _edge = cv2.resize(_edge, _img.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) 249 | sample['edge'] = np.expand_dims(_edge, -1) 250 | 251 | if self.do_human_parts: 252 | _human_parts, _ = self._load_human_parts(index) 253 | if _human_parts.shape != _img.shape[:2]: 254 | _human_parts = cv2.resize(_human_parts, _img.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) 255 | sample['human_parts'] = np.expand_dims(_human_parts, -1) 256 | 257 | if self.do_semseg: 258 | _semseg = self._load_semseg(index) 259 | if _semseg.shape != _img.shape[:2]: 260 | _semseg = cv2.resize(_semseg, _img.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) 261 | sample['semseg'] = np.expand_dims(_semseg, -1) 262 | 263 | if self.do_normals: 264 | _normals = self._load_normals_distilled(index) 265 | if _normals.shape[:2] != _img.shape[:2]: 266 | _normals = cv2.resize(_normals, _img.shape[:2][::-1], interpolation=cv2.INTER_CUBIC) 267 | sample['normals'] = _normals 268 | 269 | if self.do_sal: 270 | _sal = self._load_sal_distilled(index) 271 | if _sal.shape[:2] != _img.shape[:2]: 272 | _sal = cv2.resize(_sal, _img.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) 273 | sample['sal'] = np.expand_dims(_sal, -1) 274 | 275 | if self.retname: 276 | sample['meta'] = {'img_name': str(self.im_ids[index]), 277 | 'img_size': (_img.shape[0], _img.shape[1])} 278 | 279 | if self.transform is not None: 280 | sample = self.transform(sample) 281 | 282 | return sample 283 | 284 | def __len__(self): 285 | return len(self.images) 286 | 287 | def _load_img(self, index): 288 | _img = np.array(Image.open(self.images[index]).convert('RGB')).astype(np.float32) 289 | return _img 290 | 291 | def _load_edge(self, index): 292 | # Read Target object 293 | _tmp = sio.loadmat(self.edges[index]) 294 | _edge = cv2.Laplacian(_tmp['LabelMap'], cv2.CV_64F) 295 | _edge = thin(np.abs(_edge) > 0).astype(np.float32) 296 | return _edge 297 | 298 | def _load_human_parts(self, index): 299 | if self.has_human_parts[index]: 300 | 301 | # Read Target object 302 | _part_mat = sio.loadmat(self.parts[index])['anno'][0][0][1][0] 303 | 304 | _inst_mask = _target = None 305 | 306 | for _obj_ii in range(len(_part_mat)): 307 | 308 | has_human = _part_mat[_obj_ii][1][0][0] == self.human_parts_category 309 | has_parts = len(_part_mat[_obj_ii][3]) != 0 310 | 311 | if has_human and has_parts: 312 | if _inst_mask is None: 313 | _inst_mask = _part_mat[_obj_ii][2].astype(np.float32) 314 | _target = np.zeros(_inst_mask.shape) 315 | else: 316 | _inst_mask = np.maximum(_inst_mask, _part_mat[_obj_ii][2].astype(np.float32)) 317 | 318 | n_parts = len(_part_mat[_obj_ii][3][0]) 319 | for part_i in range(n_parts): 320 | cat_part = str(_part_mat[_obj_ii][3][0][part_i][0][0]) 321 | mask_id = self.cat_part[str(self.human_parts_category)][cat_part] 322 | mask = _part_mat[_obj_ii][3][0][part_i][1].astype(bool) 323 | _target[mask] = mask_id 324 | 325 | if 
_target is not None: 326 | _target, _inst_mask = _target.astype(np.float32), _inst_mask.astype(np.float32) 327 | else: 328 | _target, _inst_mask = np.zeros((512, 512), dtype=np.float32), np.zeros((512, 512), dtype=np.float32) 329 | 330 | return _target, _inst_mask 331 | 332 | else: 333 | return np.zeros((512, 512), dtype=np.float32), np.zeros((512, 512), dtype=np.float32) 334 | 335 | def _load_semseg(self, index): 336 | _semseg = np.array(Image.open(self.semsegs[index])).astype(np.float32) 337 | 338 | return _semseg 339 | 340 | def _load_normals_distilled(self, index): 341 | _tmp = np.array(Image.open(self.normals[index])).astype(np.float32) 342 | _tmp = 2.0 * _tmp / 255.0 - 1.0 343 | 344 | labels = sio.loadmat(os.path.join(self.root, 'pascal-context', 'trainval', self.im_ids[index] + '.mat')) 345 | labels = labels['LabelMap'] 346 | 347 | _normals = np.zeros(_tmp.shape) 348 | 349 | for x in np.unique(labels): 350 | if x in self.normals_valid_classes: 351 | _normals[labels == x, :] = _tmp[labels == x, :] 352 | 353 | return _normals 354 | 355 | def _load_sal_distilled(self, index): 356 | _sal = np.array(Image.open(self.sals[index])).astype(np.float32) / 255. 357 | _sal = (_sal > 0.5).astype(np.float32) 358 | 359 | return _sal 360 | 361 | def _get_semseg_fname(self, fname): 362 | fname_voc = os.path.join(self.root, 'semseg', 'VOC12', fname + '.png') 363 | fname_context = os.path.join(self.root, 'semseg', 'pascal-context', fname + '.png') 364 | if os.path.isfile(fname_voc): 365 | seg = fname_voc 366 | elif os.path.isfile(fname_context): 367 | seg = fname_context 368 | else: 369 | seg = None 370 | print('Segmentation for im: {} was not found'.format(fname)) 371 | 372 | return seg 373 | 374 | def _check_preprocess_parts(self): 375 | _obj_list_file = self.parts_file 376 | if not os.path.isfile(_obj_list_file): 377 | return False 378 | else: 379 | self.part_obj_dict = json.load(open(_obj_list_file, 'r')) 380 | 381 | return list(np.sort([str(x) for x in self.part_obj_dict.keys()])) == list(np.sort(self.im_ids)) 382 | 383 | def _preprocess_parts(self): 384 | self.part_obj_dict = {} 385 | obj_counter = 0 386 | for ii in range(len(self.im_ids)): 387 | # Read object masks and get number of objects 388 | if ii % 100 == 0: 389 | print("Processing image: {}".format(ii)) 390 | part_mat = sio.loadmat( 391 | os.path.join(self.root, 'human_parts', '{}.mat'.format(self.im_ids[ii]))) 392 | n_obj = len(part_mat['anno'][0][0][1][0]) 393 | 394 | # Get the categories from these objects 395 | _cat_ids = [] 396 | for jj in range(n_obj): 397 | obj_area = np.sum(part_mat['anno'][0][0][1][0][jj][2]) 398 | obj_cat = int(part_mat['anno'][0][0][1][0][jj][1]) 399 | if obj_area > self.area_thres: 400 | _cat_ids.append(int(part_mat['anno'][0][0][1][0][jj][1])) 401 | else: 402 | _cat_ids.append(-1) 403 | obj_counter += 1 404 | 405 | self.part_obj_dict[self.im_ids[ii]] = _cat_ids 406 | 407 | with open(self.parts_file, 'w') as outfile: 408 | outfile.write('{{\n\t"{:s}": {:s}'.format(self.im_ids[0], json.dumps(self.part_obj_dict[self.im_ids[0]]))) 409 | for ii in range(1, len(self.im_ids)): 410 | outfile.write( 411 | ',\n\t"{:s}": {:s}'.format(self.im_ids[ii], json.dumps(self.part_obj_dict[self.im_ids[ii]]))) 412 | outfile.write('\n}\n') 413 | 414 | print('Preprocessing for parts finished') 415 | 416 | def _download(self): 417 | _fpath = os.path.join(self.root, self.FILE) 418 | 419 | if os.path.isfile(_fpath): 420 | print('Files already downloaded') 421 | return 422 | else: 423 | print('Downloading ' + self.URL + ' to ' + _fpath) 
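            # _progress is passed to urllib.request.urlretrieve below as the reporthook
            # callback, printing the download percentage as blocks arrive.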
424 | 425 | def _progress(count, block_size, total_size): 426 | sys.stdout.write('\r>> %s %.1f%%' % 427 | (_fpath, float(count * block_size) / 428 | float(total_size) * 100.0)) 429 | sys.stdout.flush() 430 | 431 | urllib.request.urlretrieve(self.URL, _fpath, _progress) 432 | 433 | # extract file 434 | cwd = os.getcwd() 435 | print('\nExtracting tar file') 436 | tar = tarfile.open(_fpath) 437 | os.chdir(self.root) 438 | tar.extractall() 439 | tar.close() 440 | os.chdir(cwd) 441 | print('Done!') 442 | 443 | def __str__(self): 444 | return 'PASCAL_MT(split=' + str(self.split) + ')' 445 | -------------------------------------------------------------------------------- /data/transforms.py: -------------------------------------------------------------------------------- 1 | # same transform as ATRC 2 | 3 | import numpy as np 4 | import random 5 | import cv2 6 | import torch 7 | 8 | 9 | class RandomScaling: 10 | """Random scale the input. 11 | Args: 12 | min_scale_factor: Minimum scale value. 13 | max_scale_factor: Maximum scale value. 14 | step_size: The step size from minimum to maximum value. 15 | Returns: 16 | sample: The input sample scaled 17 | """ 18 | 19 | def __init__(self, scale_factors=(0.5, 2.0), discrete=False): 20 | self.scale_factors = scale_factors 21 | self.discrete = discrete 22 | self.mode = { 23 | 'semseg': cv2.INTER_NEAREST, 24 | 'depth': cv2.INTER_NEAREST, 25 | 'normals': cv2.INTER_NEAREST, 26 | 'edge': cv2.INTER_NEAREST, 27 | 'sal': cv2.INTER_NEAREST, 28 | 'human_parts': cv2.INTER_NEAREST, 29 | 'image': cv2.INTER_LINEAR 30 | } 31 | 32 | def get_scale_factor(self): 33 | if self.discrete: 34 | # choose one option out of the list 35 | random_scale = random.choice(self.scale_factors) 36 | else: 37 | assert len(self.scale_factors) == 2 38 | random_scale = random.uniform(*self.scale_factors) 39 | return random_scale 40 | 41 | def scale(self, key, unscaled, scale=1.0): 42 | """Randomly scales image and label. 43 | Args: 44 | key: Key indicating the uscaled input origin 45 | unscaled: Image or target to be scaled. 46 | scale: The value to scale image and label. 47 | Returns: 48 | scaled: The scaled image or target 49 | """ 50 | # No random scaling if scale == 1. 
51 | if scale == 1.0: 52 | return unscaled 53 | image_shape = np.shape(unscaled)[0:2] 54 | new_dim = tuple([int(x * scale) for x in image_shape]) 55 | 56 | unscaled = np.squeeze(unscaled) 57 | scaled = cv2.resize(unscaled, new_dim[::-1], interpolation=self.mode[key]) 58 | if scaled.ndim == 2: 59 | scaled = np.expand_dims(scaled, axis=2) 60 | 61 | if key == 'depth': 62 | # ignore regions for depth are 0 63 | scaled /= scale 64 | 65 | return scaled 66 | 67 | def __call__(self, sample): 68 | random_scale = self.get_scale_factor() 69 | for key, val in sample.items(): 70 | if key == 'meta': 71 | continue 72 | sample[key] = self.scale(key, val, scale=random_scale) 73 | return sample 74 | 75 | def __repr__(self): 76 | return self.__class__.__name__ + '()' 77 | 78 | 79 | class PadImage: 80 | """Pad image and label to have dimensions >= [size_height, size_width] 81 | Args: 82 | size: Desired size 83 | Returns: 84 | sample: The input sample padded 85 | """ 86 | 87 | def __init__(self, size): 88 | if isinstance(size, int): 89 | self.size = tuple([size, size]) 90 | elif isinstance(size, (list, tuple)): 91 | self.size = size 92 | else: 93 | raise ValueError('Crop size must be an int, tuple or list') 94 | self.fill_index = {'edge': 255, 95 | 'human_parts': 255, 96 | 'semseg': 255, 97 | 'depth': 0, 98 | 'normals': [0, 0, 0], 99 | 'sal': 255, 100 | 'image': [0, 0, 0]} 101 | 102 | def pad(self, key, unpadded): 103 | unpadded_shape = np.shape(unpadded) 104 | delta_height = max(self.size[0] - unpadded_shape[0], 0) 105 | delta_width = max(self.size[1] - unpadded_shape[1], 0) 106 | 107 | if delta_height == 0 and delta_width == 0: 108 | return unpadded 109 | 110 | # Location to place image 111 | height_location = [delta_height // 2, 112 | (delta_height // 2) + unpadded_shape[0]] 113 | width_location = [delta_width // 2, 114 | (delta_width // 2) + unpadded_shape[1]] 115 | 116 | pad_value = self.fill_index[key] 117 | max_height = max(self.size[0], unpadded_shape[0]) 118 | max_width = max(self.size[1], unpadded_shape[1]) 119 | 120 | padded = np.full((max_height, max_width, unpadded_shape[2]), 121 | pad_value, dtype=np.float32) 122 | padded[height_location[0]:height_location[1], 123 | width_location[0]:width_location[1], :] = unpadded 124 | # else: 125 | # padded = np.full((max_height, max_width), 126 | # pad_value, dtype=np.float32) 127 | # padded[height_location[0]:height_location[1], 128 | # width_location[0]:width_location[1]] = unpadded 129 | 130 | return padded 131 | 132 | def __call__(self, sample): 133 | for key, val in sample.items(): 134 | if key == 'meta': 135 | continue 136 | sample[key] = self.pad(key, val) 137 | return sample 138 | 139 | def __repr__(self): 140 | return self.__class__.__name__ + '()' 141 | 142 | 143 | class RandomCrop: 144 | """Random crop image if it exceeds desired size 145 | Args: 146 | size: Desired size 147 | Returns: 148 | sample: The input sample randomly cropped 149 | """ 150 | 151 | def __init__(self, size, cat_max_ratio=1): 152 | if isinstance(size, int): 153 | self.size = tuple([size, size]) 154 | elif isinstance(size, (list, tuple)): 155 | self.size = size 156 | else: 157 | raise ValueError('Crop size must be an int, tuple or list') 158 | self.cat_max_ratio = cat_max_ratio # need semantic labels for this 159 | 160 | def get_random_crop_loc(self, uncropped): 161 | """Gets a random crop location. 162 | Args: 163 | key: Key indicating the uncropped input origin 164 | uncropped: Image or target to be cropped. 165 | Returns: 166 | Cropping region. 
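            Returns None when the image already matches the crop size; random_crop then
            returns the input unchanged.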
167 | """ 168 | uncropped_shape = np.shape(uncropped) 169 | img_height = uncropped_shape[0] 170 | img_width = uncropped_shape[1] 171 | 172 | crop_height = self.size[0] 173 | crop_width = self.size[1] 174 | if img_height == crop_height and img_width == crop_width: 175 | return None 176 | # Get random offset uniformly from [0, max_offset] 177 | max_offset_height = max(img_height - crop_height, 0) 178 | max_offset_width = max(img_width - crop_width, 0) 179 | 180 | offset_height = random.randint(0, max_offset_height) 181 | offset_width = random.randint(0, max_offset_width) 182 | crop_loc = [offset_height, offset_height + crop_height, 183 | offset_width, offset_width + crop_width] 184 | 185 | return crop_loc 186 | 187 | def random_crop(self, key, uncropped, crop_loc): 188 | if crop_loc is None: 189 | return uncropped 190 | 191 | cropped = uncropped[crop_loc[0]:crop_loc[1], 192 | crop_loc[2]:crop_loc[3], :] 193 | return cropped 194 | 195 | def __call__(self, sample): 196 | crop_location = self.get_random_crop_loc(sample['image']) 197 | if self.cat_max_ratio < 1.: 198 | # Repeat 10 times 199 | for _ in range(10): 200 | try: 201 | seg_tmp = self.random_crop('semseg', sample['semseg'], crop_location) 202 | labels, cnt = np.unique(seg_tmp, return_counts=True) 203 | cnt = cnt[labels != 255] 204 | if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < self.cat_max_ratio: 205 | break 206 | except: 207 | pass 208 | crop_location = self.get_random_crop_loc(sample['image']) 209 | 210 | for key, val in sample.items(): 211 | if key == 'meta': 212 | continue 213 | sample[key] = self.random_crop(key, val, crop_location) 214 | return sample 215 | 216 | def __repr__(self): 217 | return self.__class__.__name__ + '()' 218 | 219 | 220 | class RandomHorizontalFlip: 221 | """Horizontally flip the given image and ground truth randomly.""" 222 | 223 | def __init__(self, p=0.5): 224 | self.p = p 225 | 226 | def __call__(self, sample): 227 | if random.random() < self.p: 228 | for key, val in sample.items(): 229 | if key == 'meta': 230 | continue 231 | sample[key] = np.fliplr(val).copy() 232 | if key == 'normals': 233 | sample[key][:, :, 0] *= -1 234 | return sample 235 | 236 | def __repr__(self): 237 | return self.__class__.__name__ + '()' 238 | 239 | 240 | class Normalize: 241 | """ Normalize image values by first mapping from [0, 255] to [0, 1] and then 242 | applying standardization. 243 | """ 244 | 245 | def __init__(self, mean, std): 246 | self.mean = np.array(mean, dtype=np.float32).reshape(1, 1, 3) 247 | self.std = np.array(std, dtype=np.float32).reshape(1, 1, 3) 248 | 249 | def normalize_img(self, img): 250 | assert img.dtype == np.float32 251 | scaled = img.copy() / 255. 252 | scaled -= self.mean 253 | scaled /= self.std 254 | return scaled 255 | 256 | def __call__(self, sample): 257 | """Call function to normalize images. 258 | Args: 259 | results (dict): Result dict from loading pipeline. 260 | Returns: 261 | dict: Normalized results, 'img_norm_cfg' key is added into 262 | result dict. 
263 | """ 264 | sample['image'] = self.normalize_img(sample['image']) 265 | return sample 266 | 267 | 268 | class ToTensor: 269 | """Convert ndarrays in sample to Tensors.""" 270 | 271 | def __call__(self, sample): 272 | for key, val in sample.items(): 273 | if key == 'meta': 274 | continue 275 | sample[key] = torch.from_numpy(val.transpose((2, 0, 1))).float() 276 | return sample 277 | 278 | def __repr__(self): 279 | return self.__class__.__name__ + '()' 280 | 281 | 282 | class AddIgnoreRegions: 283 | """Add Ignore Regions""" 284 | 285 | def __call__(self, sample): 286 | for elem in sample.keys(): 287 | tmp = sample[elem] 288 | if elem == 'normals': 289 | # Check areas with norm 0 290 | norm = np.sqrt(tmp[:, :, 0] ** 2 + 291 | tmp[:, :, 1] ** 2 + tmp[:, :, 2] ** 2) 292 | tmp[norm == 0, :] = 255 293 | sample[elem] = tmp 294 | elif elem == 'human_parts': 295 | # Check for images without human part annotations 296 | if ((tmp == 0) | (tmp == 255)).all(): 297 | tmp = np.full(tmp.shape, 255, dtype=tmp.dtype) 298 | sample[elem] = tmp 299 | elif elem == 'depth': 300 | tmp[tmp == 0] = 255 301 | sample[elem] = tmp 302 | return sample 303 | 304 | def __repr__(self): 305 | return self.__class__.__name__ + '()' 306 | 307 | 308 | class PhotoMetricDistortion: 309 | """Apply photometric distortion to image sequentially, every transformation 310 | is applied with a probability of 0.5. The position of random contrast is in 311 | second or second to last. 312 | 1. random brightness 313 | 2. random contrast (mode 0) 314 | 3. convert color from BGR to HSV 315 | 4. random saturation 316 | 5. random hue 317 | 6. convert color from HSV to BGR 318 | 7. random contrast (mode 1) 319 | 8. randomly swap channels 320 | Args: 321 | brightness_delta (int): delta of brightness. 322 | contrast_range (tuple): range of contrast. 323 | saturation_range (tuple): range of saturation. 324 | hue_delta (int): delta of hue. 
325 | """ 326 | 327 | def __init__(self, 328 | brightness_delta=32, 329 | contrast_range=(0.5, 1.5), 330 | saturation_range=(0.5, 1.5), 331 | hue_delta=18): 332 | self.brightness_delta = brightness_delta 333 | self.contrast_lower, self.contrast_upper = contrast_range 334 | self.saturation_lower, self.saturation_upper = saturation_range 335 | self.hue_delta = hue_delta 336 | 337 | def convert(self, img, alpha=1, beta=0): 338 | """Multiple with alpha and add beat with clip.""" 339 | img = img.astype(np.float32) * alpha + beta 340 | img = np.clip(img, 0, 255) 341 | return img.astype(np.uint8) 342 | 343 | def brightness(self, img): 344 | """Brightness distortion.""" 345 | if random.random() < 0.5: 346 | return self.convert( 347 | img, 348 | beta=random.uniform(-self.brightness_delta, 349 | self.brightness_delta)) 350 | return img 351 | 352 | def contrast(self, img): 353 | """Contrast distortion.""" 354 | if random.random() < 0.5: 355 | return self.convert( 356 | img, 357 | alpha=random.uniform(self.contrast_lower, self.contrast_upper)) 358 | return img 359 | 360 | def saturation(self, img): 361 | """Saturation distortion.""" 362 | if random.random() < 0.5: 363 | img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) 364 | img[:, :, 1] = self.convert( 365 | img[:, :, 1], 366 | alpha=random.uniform(self.saturation_lower, 367 | self.saturation_upper)) 368 | img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) 369 | return img 370 | 371 | def hue(self, img): 372 | """Hue distortion.""" 373 | if random.random() < 0.5: 374 | img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) 375 | img[:, :, 0] = (img[:, :, 0].astype(int) + random.randint(-self.hue_delta, self.hue_delta - 1)) % 180 376 | img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) 377 | return img 378 | 379 | def __call__(self, sample): 380 | """Call function to perform photometric distortion on images. 381 | Args: 382 | results (dict): Result dict from loading pipeline. 383 | Returns: 384 | dict: Result dict with images distorted. 385 | """ 386 | 387 | img = sample['image'] 388 | img = img.astype(np.uint8) # functions need a uint8 image 389 | 390 | # random brightness 391 | img = self.brightness(img) 392 | 393 | # f_mode == True --> do random contrast first 394 | # else --> do random contrast last 395 | f_mode = random.random() < 0.5 396 | if f_mode: 397 | img = self.contrast(img) 398 | 399 | # random saturation 400 | img = self.saturation(img) 401 | 402 | # random hue 403 | img = self.hue(img) 404 | 405 | # random contrast 406 | if not f_mode: 407 | img = self.contrast(img) 408 | 409 | sample['image'] = img.astype(np.float32) 410 | return sample 411 | 412 | def __repr__(self): 413 | repr_str = self.__class__.__name__ 414 | repr_str += (f'(brightness_delta={self.brightness_delta}, ' 415 | f'contrast_range=({self.contrast_lower}, ' 416 | f'{self.contrast_upper}), ' 417 | f'saturation_range=({self.saturation_lower}, ' 418 | f'{self.saturation_upper}), ' 419 | f'hue_delta={self.hue_delta})') 420 | return repr_str 421 | 422 | class DirectResize: 423 | """Resize samples so that the max dimension is the same as the giving one. The aspect ratio is kept. 
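        Note: in the implementation below, cv2.resize is called with the full target size
        and only the 'image' entry is resized, so the output spatial size is exactly `size`.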
424 | """ 425 | 426 | def __init__(self, size): 427 | self.size = size 428 | 429 | self.mode = { 430 | 'image': cv2.INTER_LINEAR 431 | } 432 | 433 | 434 | def resize(self, key, ori): 435 | new = cv2.resize(ori, self.size[::-1], interpolation=self.mode[key]) 436 | 437 | return new 438 | 439 | def __call__(self, sample): 440 | 441 | for key, val in sample.items(): 442 | if key == 'image': 443 | sample[key] = self.resize(key, val) 444 | 445 | return sample 446 | 447 | def __repr__(self): 448 | return self.__class__.__name__ + '()' -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/MTMamba/17d132d2a7a1ef80c35ffe40c8dd3ff72fa07d77/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/eval_depth.py: -------------------------------------------------------------------------------- 1 | # This code is referenced from 2 | # https://github.com/facebookresearch/astmt/ 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # All rights reserved. 6 | # 7 | # License: Attribution-NonCommercial 4.0 International 8 | 9 | from shutil import ignore_patterns 10 | import warnings 11 | import cv2 12 | import os.path 13 | import numpy as np 14 | import glob 15 | import torch 16 | import json 17 | import scipy.io as sio 18 | 19 | class DepthMeter(object): 20 | def __init__(self, ignore_index=255, max_depth=None, min_depth=None): 21 | self.total_rmses = 0.0 22 | self.total_log_rmses = 0.0 23 | self.n_valid = 0.0 24 | self.ignore_index = ignore_index 25 | self.max_depth = max_depth 26 | self.min_depth = min_depth 27 | 28 | self.abs_rel = 0.0 29 | self.sq_rel = 0.0 30 | 31 | @torch.no_grad() 32 | def update(self, pred, gt): 33 | pred, gt = pred.squeeze(), gt.squeeze() 34 | 35 | # Determine valid mask 36 | if self.max_depth is not None and self.min_depth is not None: 37 | mask = torch.logical_and(gt < self.max_depth, gt > self.min_depth) 38 | else: 39 | mask = (gt != self.ignore_index).bool() 40 | self.n_valid += mask.float().sum().item() # Valid pixels per image 41 | 42 | # Only positive depth values are possible 43 | # pred = torch.clamp(pred, min=1e-9) 44 | gt[gt <=0 ] = 1e-9 45 | pred[pred <= 0] = 1e-9 46 | 47 | # Per pixel rmse and log-rmse. 
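        # Squared (log-)errors are summed over the valid pixels here and divided by
        # n_valid in get_score(), i.e. rmse = sqrt(total_rmses / n_valid).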
48 | log_rmse_tmp = torch.pow(torch.log(gt[mask]) - torch.log(pred[mask]), 2) 49 | self.total_log_rmses += log_rmse_tmp.sum().item() 50 | 51 | rmse_tmp = torch.pow(gt[mask] - pred[mask], 2) 52 | self.total_rmses += rmse_tmp.sum().item() 53 | 54 | # abs rel 55 | self.abs_rel += (torch.abs(gt[mask] - pred[mask]) / gt[mask]).sum().item() 56 | # sq_rel 57 | self.sq_rel += (((gt[mask] - pred[mask]) ** 2) / gt[mask]).sum().item() 58 | 59 | def reset(self): 60 | self.rmses = [] 61 | self.log_rmses = [] 62 | 63 | def get_score(self, verbose=True): 64 | eval_result = dict() 65 | eval_result['rmse'] = np.sqrt(self.total_rmses / self.n_valid) 66 | eval_result['log_rmse'] = np.sqrt(self.total_log_rmses / self.n_valid) 67 | eval_result['abs_rel'] = self.abs_rel / self.n_valid 68 | eval_result['sq_rel'] = self.sq_rel / self.n_valid 69 | 70 | if verbose: 71 | print('Results for depth prediction') 72 | for x in eval_result: 73 | spaces = '' 74 | for j in range(0, 15 - len(x)): 75 | spaces += ' ' 76 | print('{0:s}{1:s}{2:.4f}'.format(x, spaces, eval_result[x])) 77 | 78 | return eval_result 79 | -------------------------------------------------------------------------------- /evaluation/eval_edge.py: -------------------------------------------------------------------------------- 1 | # This code is referenced from MTI-Net 2 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 3 | 4 | import os 5 | import glob 6 | import json 7 | import torch 8 | import numpy as np 9 | from utils.utils import mkdir_if_missing 10 | from losses.loss_functions import BalancedBinaryCrossEntropyLoss 11 | from configs.mypath import PROJECT_ROOT_DIR 12 | 13 | class EdgeMeter(object): 14 | def __init__(self, pos_weight, ignore_index): 15 | self.loss = 0 16 | self.n = 0 17 | self.loss_function = BalancedBinaryCrossEntropyLoss(pos_weight=pos_weight, ignore_index=ignore_index) 18 | self.ignore_index = ignore_index 19 | 20 | @torch.no_grad() 21 | def update(self, pred, gt): 22 | gt = gt.squeeze() 23 | valid_mask = (gt != self.ignore_index) 24 | pred = pred[valid_mask] 25 | gt = gt[valid_mask] 26 | 27 | pred = pred.float().squeeze() / 255. 28 | loss = self.loss_function(pred, gt).item() 29 | numel = gt.numel() 30 | self.n += numel 31 | self.loss += numel * loss 32 | 33 | def reset(self): 34 | self.loss = 0 35 | self.n = 0 36 | 37 | def get_score(self, verbose=True): 38 | eval_dict = {'loss': self.loss / self.n} 39 | 40 | if verbose: 41 | print('\n Edge Detection Evaluation') 42 | print('Edge Detection Loss %.3f' %(eval_dict['loss'])) 43 | 44 | return eval_dict 45 | 46 | -------------------------------------------------------------------------------- /evaluation/eval_human_parts.py: -------------------------------------------------------------------------------- 1 | # This code is referenced from 2 | # https://github.com/facebookresearch/astmt/ 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # All rights reserved. 
6 | # 7 | # License: Attribution-NonCommercial 4.0 International 8 | 9 | import warnings 10 | import cv2 11 | import glob 12 | import json 13 | import os.path 14 | import numpy as np 15 | import torch 16 | from PIL import Image 17 | 18 | PART_CATEGORY_NAMES = ['background', 'head', 'torso', 'uarm', 'larm', 'uleg', 'lleg'] 19 | 20 | class HumanPartsMeter(object): 21 | def __init__(self, database, ignore_idx=255): 22 | assert(database == 'PASCALContext') 23 | self.database = database 24 | self.cat_names = PART_CATEGORY_NAMES 25 | self.n_parts = 6 26 | self.tp = [0] * (self.n_parts + 1) 27 | self.fp = [0] * (self.n_parts + 1) 28 | self.fn = [0] * (self.n_parts + 1) 29 | 30 | self.ignore_idx = ignore_idx 31 | 32 | @torch.no_grad() 33 | def update(self, pred, gt): 34 | pred, gt = pred.squeeze(), gt.squeeze() 35 | valid = (gt != self.ignore_idx) 36 | 37 | for i_part in range(self.n_parts + 1): 38 | tmp_gt = (gt == i_part) 39 | tmp_pred = (pred == i_part) 40 | self.tp[i_part] += torch.sum(tmp_gt & tmp_pred & (valid)).item() 41 | self.fp[i_part] += torch.sum(~tmp_gt & tmp_pred & (valid)).item() 42 | self.fn[i_part] += torch.sum(tmp_gt & ~tmp_pred & (valid)).item() 43 | 44 | def reset(self): 45 | self.tp = [0] * (self.n_parts + 1) 46 | self.fp = [0] * (self.n_parts + 1) 47 | self.fn = [0] * (self.n_parts + 1) 48 | 49 | def get_score(self, verbose=True): 50 | jac = [0] * (self.n_parts + 1) 51 | for i_part in range(0, self.n_parts + 1): 52 | jac[i_part] = float(self.tp[i_part]) / max(float(self.tp[i_part] + self.fp[i_part] + self.fn[i_part]), 1e-8) 53 | 54 | eval_result = dict() 55 | # eval_result['jaccards_all_categs'] = jac 56 | eval_result['mIoU'] = np.mean(jac) 57 | 58 | print('\nHuman Parts mIoU: {0:.4f}\n'.format(100 * eval_result['mIoU'])) 59 | class_IoU = jac 60 | for i in range(len(class_IoU)): 61 | spaces = '' 62 | for j in range(0, 15 - len(self.cat_names[i])): 63 | spaces += ' ' 64 | print('{0:s}{1:s}{2:.4f}'.format(self.cat_names[i], spaces, 100 * class_IoU[i])) 65 | 66 | return eval_result 67 | -------------------------------------------------------------------------------- /evaluation/eval_normals.py: -------------------------------------------------------------------------------- 1 | # This code is referenced from 2 | # https://github.com/facebookresearch/astmt/ 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # All rights reserved. 
6 | # 7 | # License: Attribution-NonCommercial 4.0 International 8 | 9 | import warnings 10 | import cv2 11 | import os.path 12 | import numpy as np 13 | import glob 14 | import math 15 | import torch 16 | import json 17 | 18 | 19 | def normalize_tensor(input_tensor, dim): 20 | norm = torch.norm(input_tensor, p='fro', dim=dim, keepdim=True) 21 | zero_mask = (norm == 0) 22 | norm[zero_mask] = 1 23 | out = input_tensor.div(norm) 24 | out[zero_mask.expand_as(out)] = 0 25 | return out 26 | 27 | class NormalsMeter(object): 28 | def __init__(self, ignore_index=255): 29 | self.sum_deg_diff = 0 30 | self.total = 0 31 | self.ignore_index = ignore_index 32 | 33 | @torch.no_grad() 34 | def update(self, pred, gt): 35 | pred = pred.permute(0, 3, 1, 2) # [B, C, H, W] 36 | pred = 2 * pred / 255 - 1 # reverse post-processing 37 | valid_mask = (gt != self.ignore_index).all(dim=1) 38 | 39 | pred = normalize_tensor(pred, dim=1) 40 | gt = normalize_tensor(gt, dim=1) 41 | deg_diff = torch.rad2deg(2 * torch.atan2(torch.norm(pred - gt, dim=1), torch.norm(pred + gt, dim=1))) 42 | deg_diff = torch.masked_select(deg_diff, valid_mask) 43 | 44 | self.sum_deg_diff += torch.sum(deg_diff).cpu().item() 45 | self.total += deg_diff.numel() 46 | 47 | def get_score(self, verbose=False): 48 | eval_result = dict() 49 | eval_result['mean'] = self.sum_deg_diff / self.total 50 | 51 | return eval_result -------------------------------------------------------------------------------- /evaluation/eval_sal.py: -------------------------------------------------------------------------------- 1 | # This code is referenced from 2 | # https://github.com/facebookresearch/astmt/ 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # All rights reserved. 6 | # 7 | # License: Attribution-NonCommercial 4.0 International 8 | 9 | import torch 10 | from torch import nn 11 | 12 | class SaliencyMeter(object): 13 | def __init__(self, ignore_index=255, threshold_step=None, beta_squared=1): 14 | self.ignore_index = ignore_index 15 | self.beta_squared = beta_squared 16 | self.thresholds = torch.arange(threshold_step, 1, threshold_step) 17 | self.true_positives = torch.zeros(len(self.thresholds)) 18 | self.predicted_positives = torch.zeros(len(self.thresholds)) 19 | self.actual_positives = torch.zeros(len(self.thresholds)) 20 | 21 | @torch.no_grad() 22 | def update(self, preds, target): 23 | """ 24 | Update state with predictions and targets. 25 | 26 | Args: 27 | preds: Predictions from model [B, H, W] 28 | target: Ground truth values 29 | """ 30 | preds = preds.float() / 255. 
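        # Predictions come in as saved 8-bit saliency maps in [0, 255]; map them back
        # to [0, 1] first.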
31 | 32 | if target.shape[1] == 1: 33 | target = target.squeeze(1) 34 | 35 | assert preds.shape == target.shape 36 | 37 | if len(preds.shape) == len(target.shape) + 1: 38 | assert preds.shape[1] == 2 39 | # two class probabilites 40 | preds = nn.functional.softmax(preds, dim=1)[:, 1, :, :] 41 | else: 42 | # squash logits into probabilities 43 | preds = torch.sigmoid(preds) 44 | 45 | if not len(preds.shape) == len(target.shape): 46 | raise ValueError("preds and target must have same number of dimensions, or preds one more") 47 | 48 | valid_mask = (target != self.ignore_index) 49 | 50 | for idx, thresh in enumerate(self.thresholds): 51 | # threshold probablities 52 | f_preds = (preds >= thresh).long() 53 | f_target = target.long() 54 | 55 | f_preds = torch.masked_select(f_preds, valid_mask) 56 | f_target = torch.masked_select(f_target, valid_mask) 57 | 58 | self.true_positives[idx] += torch.sum(f_preds * f_target).cpu() 59 | self.predicted_positives[idx] += torch.sum(f_preds).cpu() 60 | self.actual_positives[idx] += torch.sum(f_target).cpu() 61 | 62 | 63 | def get_score(self, verbose=False): 64 | """ 65 | Computes F-scores over state and returns the max. 66 | """ 67 | precision = self.true_positives.float() / self.predicted_positives 68 | recall = self.true_positives.float() / self.actual_positives 69 | 70 | num = (1 + self.beta_squared) * precision * recall 71 | denom = self.beta_squared * precision + recall 72 | 73 | # For the rest we need to take care of instances where the denom can be 0 74 | # for some classes which will produce nans for that class 75 | fscore = num / denom 76 | fscore[fscore != fscore] = 0 77 | 78 | eval_result = {'maxF': fscore.max().item()} 79 | return eval_result 80 | -------------------------------------------------------------------------------- /evaluation/eval_semseg.py: -------------------------------------------------------------------------------- 1 | # This code is referenced from 2 | # https://github.com/facebookresearch/astmt/ 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # All rights reserved. 6 | # 7 | # License: Attribution-NonCommercial 4.0 International 8 | 9 | import warnings 10 | import cv2 11 | import os.path 12 | import glob 13 | import json 14 | import numpy as np 15 | import torch 16 | from PIL import Image 17 | import pdb 18 | 19 | VOC_CATEGORY_NAMES = ['background', 20 | 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 21 | 'bus', 'car', 'cat', 'chair', 'cow', 22 | 'diningtable', 'dog', 'horse', 'motorbike', 'person', 23 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] 24 | 25 | 26 | NYU_CATEGORY_NAMES = ['wall', 'floor', 'cabinet', 'bed', 'chair', 27 | 'sofa', 'table', 'door', 'window', 'bookshelf', 28 | 'picture', 'counter', 'blinds', 'desk', 'shelves', 29 | 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 30 | 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 31 | 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 32 | 'person', 'night stand', 'toilet', 'sink', 'lamp', 33 | 'bathtub', 'bag', 'otherstructure', 'otherfurniture', 'otherprop'] 34 | 35 | CITYSCAPES_CATEGORY_NAMES = ['road', 'sidewalk', 'building', 'wall', 'fence',\ 36 | 'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',\ 37 | 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', \ 38 | 'motorcycle', 'bicycle'] 39 | 40 | class SemsegMeter(object): 41 | def __init__(self, database, ignore_idx=255): 42 | ''' "marco" way in ATRC evaluation code. 
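        (i.e. IoU is accumulated per class from TP/FP/FN counts and then macro-averaged over classes)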
43 | ''' 44 | if database == 'PASCALContext': 45 | n_classes = 20 46 | cat_names = VOC_CATEGORY_NAMES 47 | has_bg = True 48 | 49 | elif database == 'NYUD': 50 | n_classes = 40 51 | cat_names = NYU_CATEGORY_NAMES 52 | has_bg = False 53 | 54 | elif database == 'Cityscapes': 55 | n_classes = 19 56 | cat_names = CITYSCAPES_CATEGORY_NAMES 57 | has_bg = False 58 | 59 | else: 60 | raise NotImplementedError 61 | 62 | self.n_classes = n_classes + int(has_bg) 63 | self.cat_names = cat_names 64 | self.tp = [0] * self.n_classes 65 | self.fp = [0] * self.n_classes 66 | self.fn = [0] * self.n_classes 67 | 68 | self.ignore_idx = ignore_idx 69 | 70 | @torch.no_grad() 71 | def update(self, pred, gt): 72 | pred = pred.squeeze() 73 | gt = gt.squeeze() 74 | valid = (gt != self.ignore_idx) 75 | 76 | for i_part in range(0, self.n_classes): 77 | tmp_gt = (gt == i_part) 78 | tmp_pred = (pred == i_part) 79 | self.tp[i_part] += torch.sum(tmp_gt & tmp_pred & valid).item() 80 | self.fp[i_part] += torch.sum(~tmp_gt & tmp_pred & valid).item() 81 | self.fn[i_part] += torch.sum(tmp_gt & ~tmp_pred & valid).item() 82 | 83 | def reset(self): 84 | self.tp = [0] * self.n_classes 85 | self.fp = [0] * self.n_classes 86 | self.fn = [0] * self.n_classes 87 | 88 | def get_score(self, verbose=True): 89 | jac = [0] * self.n_classes 90 | for i_part in range(self.n_classes): 91 | jac[i_part] = float(self.tp[i_part]) / max(float(self.tp[i_part] + self.fp[i_part] + self.fn[i_part]), 1e-8) 92 | 93 | eval_result = dict() 94 | # eval_result['jaccards_all_categs'] = jac 95 | eval_result['mIoU'] = np.mean(jac) 96 | 97 | 98 | if verbose: 99 | print('\nSemantic Segmentation mIoU: {0:.4f}\n'.format(100 * eval_result['mIoU'])) 100 | class_IoU = jac #eval_result['jaccards_all_categs'] 101 | for i in range(len(class_IoU)): 102 | spaces = '' 103 | for j in range(0, 20 - len(self.cat_names[i])): 104 | spaces += ' ' 105 | print('{0:s}{1:s}{2:.4f}'.format(self.cat_names[i], spaces, 100 * class_IoU[i])) 106 | 107 | return eval_result 108 | -------------------------------------------------------------------------------- /evaluation/evaluate_utils.py: -------------------------------------------------------------------------------- 1 | # This code is referenced from MTI-Net 2 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 3 | 4 | import os 5 | import cv2 6 | import imageio 7 | import numpy as np 8 | import json 9 | import torch 10 | import scipy.io as sio 11 | from utils.utils import get_output 12 | import pdb 13 | 14 | def count_improvement(dataset_name, eval_results, task_list): 15 | if dataset_name == 'NYUD': 16 | base_result_dict = {'semseg': 0.55, 'depth': 0.49, 'normals': 18.40, 'edge': 0.0465} 17 | task_metric = {'semseg': 'mIoU', 'depth': 'rmse', 'normals': 'mean', 'edge': 'loss'} 18 | weight_dict = {'semseg': 1, 'depth': 0, 'normals': 0, 'edge': 0} 19 | elif dataset_name == 'PASCALContext': 20 | base_result_dict = {'semseg': 80.89, 'human_parts': 71.71, 'sal': 85.28, 'normals': 13.47, 'edge': 0.04} 21 | task_metric = {'semseg': 'mIoU', 'human_parts': 'mIoU', 'sal': 'maxF', 'normals': 'mean', 'edge': 'loss'} 22 | weight_dict = {'semseg': 1, 'human_parts': 1, 'sal': 1, 'normals': 0, 'edge': 0} 23 | elif dataset_name == 'Cityscapes': 24 | base_result_dict = {'semseg': 80, 'depth': 1} 25 | task_metric = {'semseg': 'mIoU', 'depth': 'rmse'} 26 | weight_dict = {'semseg': 1, 'depth': 0} 27 | else: 28 | raise ValueError 29 | 30 | base_result, new_result, weight = [], [], [] 31 | for tname in task_list: 32 | 
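        # Gather per-task baseline and new scores; the improvement computed below is
        # mean over tasks of (-1)^w * (base - new) / base, where w = 1 marks a
        # higher-is-better metric (e.g. mIoU) and w = 0 a lower-is-better one (e.g. rmse).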
base_result.append(base_result_dict[tname]) 33 | new_result.append(eval_results[tname][task_metric[tname]]) 34 | weight.append(weight_dict[tname]) 35 | 36 | assert len(weight) == len(base_result) == len(new_result) 37 | improvement = (((-1)**np.array(weight))*\ 38 | (np.array(base_result)-np.array(new_result))/\ 39 | np.array(base_result)).mean() 40 | return improvement 41 | 42 | 43 | class PerformanceMeter(object): 44 | """ A general performance meter which shows performance across one or more tasks """ 45 | def __init__(self, p, tasks): 46 | self.database = p['train_db_name'] 47 | self.tasks = tasks 48 | self.meters = {t: get_single_task_meter(p, self.database, t) for t in self.tasks} 49 | 50 | def reset(self): 51 | for t in self.tasks: 52 | self.meters[t].reset() 53 | 54 | def update(self, pred, gt): 55 | for t in self.tasks: 56 | self.meters[t].update(pred[t], gt[t]) 57 | 58 | def get_score(self, verbose=True): 59 | eval_dict = {} 60 | for t in self.tasks: 61 | eval_dict[t] = self.meters[t].get_score(verbose) 62 | 63 | return eval_dict 64 | 65 | def get_single_task_meter(p, database, task): 66 | """ Retrieve a meter to measure the single-task performance """ 67 | 68 | # ignore index based on transforms.AddIgnoreRegions 69 | if task == 'semseg': 70 | from evaluation.eval_semseg import SemsegMeter 71 | return SemsegMeter(database, ignore_idx=p.ignore_index) 72 | 73 | elif task == 'human_parts': 74 | from evaluation.eval_human_parts import HumanPartsMeter 75 | return HumanPartsMeter(database, ignore_idx=p.ignore_index) 76 | 77 | elif task == 'normals': 78 | from evaluation.eval_normals import NormalsMeter 79 | return NormalsMeter(ignore_index=p.ignore_index) 80 | 81 | elif task == 'sal': 82 | from evaluation.eval_sal import SaliencyMeter 83 | return SaliencyMeter(ignore_index=p.ignore_index, threshold_step=0.05, beta_squared=0.3) 84 | 85 | elif task == 'depth': 86 | from evaluation.eval_depth import DepthMeter 87 | return DepthMeter(ignore_index=p.ignore_index, max_depth=p.TASKS.depth_max, min_depth=p.TASKS.depth_min) 88 | 89 | elif task == 'edge': # just for reference 90 | from evaluation.eval_edge import EdgeMeter 91 | return EdgeMeter(pos_weight=p['edge_w'], ignore_index=p.ignore_index) 92 | 93 | else: 94 | raise NotImplementedError 95 | 96 | @torch.no_grad() 97 | def save_model_pred_for_one_task(p, sample, output, save_dirs, task=None, epoch=None): 98 | """ Save model predictions for one task""" 99 | 100 | inputs, meta = sample['image'].cuda(non_blocking=True), sample['meta'] 101 | output_task = get_output(output[task], task) 102 | 103 | for jj in range(int(inputs.size()[0])): 104 | if len(sample[task][jj].unique()) == 1 and sample[task][jj].unique() == p.ignore_index: 105 | continue 106 | fname = meta['img_name'][jj] 107 | 108 | im_height = meta['img_size'][jj][0] 109 | im_width = meta['img_size'][jj][1] 110 | pred = output_task[jj] # (H, W) or (H, W, C) 111 | # if we used padding on the input, we crop the prediction accordingly 112 | if (im_height, im_width) != pred.shape[:2]: 113 | delta_height = max(pred.shape[0] - im_height, 0) 114 | delta_width = max(pred.shape[1] - im_width, 0) 115 | if delta_height > 0 or delta_width > 0: 116 | # deprecated by python 117 | # height_location = [delta_height // 2, 118 | # (delta_height // 2) + im_height] 119 | # width_location = [delta_width // 2, 120 | # (delta_width // 2) + im_width] 121 | height_begin = torch.div(delta_height, 2, rounding_mode="trunc") 122 | height_location = [height_begin, height_begin + im_height] 123 | width_begin 
=torch.div(delta_width, 2, rounding_mode="trunc") 124 | width_location = [width_begin, width_begin + im_width] 125 | pred = pred[height_location[0]:height_location[1], 126 | width_location[0]:width_location[1]] 127 | assert pred.shape[:2] == (im_height, im_width) 128 | if pred.ndim == 3: 129 | raise 130 | result = pred.cpu().numpy() 131 | if task == 'depth': 132 | sio.savemat(os.path.join(save_dirs[task], fname + '.mat'), {'depth': result}) 133 | else: 134 | imageio.imwrite(os.path.join(save_dirs[task], fname + '.png'), result.astype(np.uint8)) 135 | -------------------------------------------------------------------------------- /evaluation/jaccard.py: -------------------------------------------------------------------------------- 1 | # This code is referenced from 2 | # https://github.com/facebookresearch/astmt/ 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # All rights reserved. 6 | # 7 | # License: Attribution-NonCommercial 4.0 International 8 | 9 | import numpy as np 10 | 11 | 12 | def jaccard(gt, pred, void_pixels=None): 13 | 14 | assert(gt.shape == pred.shape) 15 | 16 | if void_pixels is None: 17 | void_pixels = np.zeros_like(gt) 18 | assert(void_pixels.shape == gt.shape) 19 | 20 | gt = gt.astype(np.bool) 21 | pred = pred.astype(np.bool) 22 | void_pixels = void_pixels.astype(np.bool) 23 | if np.isclose(np.sum(gt & np.logical_not(void_pixels)), 0) and np.isclose(np.sum(pred & np.logical_not(void_pixels)), 0): 24 | return 1 25 | 26 | else: 27 | return np.sum(((gt & pred) & np.logical_not(void_pixels))) / \ 28 | np.sum(((gt | pred) & np.logical_not(void_pixels)), dtype=np.float32) 29 | 30 | 31 | def precision_recall(gt, pred, void_pixels=None): 32 | 33 | if void_pixels is None: 34 | void_pixels = np.zeros_like(gt) 35 | 36 | gt = gt.astype(np.bool) 37 | pred = pred.astype(np.bool) 38 | void_pixels = void_pixels.astype(np.bool) 39 | 40 | tp = ((pred & gt) & ~void_pixels).sum() 41 | fn = ((~pred & gt) & ~void_pixels).sum() 42 | 43 | fp = ((pred & ~gt) & ~void_pixels).sum() 44 | 45 | prec = tp / (tp + fp + 1e-12) 46 | rec = tp / (tp + fn + 1e-12) 47 | 48 | return prec, rec 49 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/MTMamba/17d132d2a7a1ef80c35ffe40c8dd3ff72fa07d77/losses/__init__.py -------------------------------------------------------------------------------- /losses/loss_functions.py: -------------------------------------------------------------------------------- 1 | # This code is referenced from 2 | # https://github.com/facebookresearch/astmt/ 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # All rights reserved. 6 | # 7 | # License: Attribution-NonCommercial 4.0 International 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn.modules.module import Module 13 | import numpy as np 14 | 15 | class CrossEntropyLoss(nn.Module): 16 | """ 17 | Cross entropy loss with ignore regions. 
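    With reduction='mean', the summed loss is divided by the number of non-ignored
    pixels rather than by the total number of pixels.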
18 | """ 19 | def __init__(self, ignore_index=255, class_weight=None, balanced=False): 20 | super().__init__() 21 | self.ignore_index = ignore_index 22 | if balanced: 23 | assert class_weight is None 24 | self.balanced = balanced 25 | if class_weight is not None: 26 | self.register_buffer('class_weight', class_weight) 27 | else: 28 | self.class_weight = None 29 | 30 | def forward(self, out, label, reduction='mean'): 31 | label = torch.squeeze(label, dim=1).long() 32 | if self.balanced: 33 | mask = (label != self.ignore_index) 34 | masked_label = torch.masked_select(label, mask) 35 | assert torch.max(masked_label) < 2 # binary 36 | num_labels_neg = torch.sum(1.0 - masked_label) 37 | num_total = torch.numel(masked_label) 38 | w_pos = num_labels_neg / num_total 39 | class_weight = torch.stack((1. - w_pos, w_pos), dim=0) 40 | loss = nn.functional.cross_entropy( 41 | out, label, weight=class_weight, ignore_index=self.ignore_index, reduction='none') 42 | else: 43 | loss = nn.functional.cross_entropy(out, 44 | label, 45 | weight=self.class_weight, 46 | ignore_index=self.ignore_index, 47 | reduction='none') 48 | if reduction == 'mean': 49 | n_valid = (label != self.ignore_index).sum() 50 | return (loss.sum() / max(n_valid, 1)).float() 51 | elif reduction == 'sum': 52 | return loss.sum() 53 | elif reduction == 'none': 54 | return loss 55 | 56 | class BalancedBinaryCrossEntropyLoss(nn.Module): 57 | """ 58 | Balanced binary cross entropy loss with ignore regions. 59 | """ 60 | def __init__(self, pos_weight=None, ignore_index=255): 61 | super().__init__() 62 | self.pos_weight = pos_weight 63 | self.ignore_index = ignore_index 64 | 65 | def forward(self, output, label, reduction='mean'): 66 | 67 | mask = (label != self.ignore_index) 68 | masked_label = torch.masked_select(label, mask) 69 | masked_output = torch.masked_select(output, mask) 70 | 71 | # weighting of the loss, default is HED-style 72 | if self.pos_weight is None: 73 | num_labels_neg = torch.sum(1.0 - masked_label) 74 | num_total = torch.numel(masked_label) 75 | w = num_labels_neg / num_total 76 | if w == 1.0: 77 | return 0 78 | else: 79 | w = torch.as_tensor(self.pos_weight, device=output.device) 80 | factor = 1. 
/ (1 - w) 81 | 82 | loss = nn.functional.binary_cross_entropy_with_logits( 83 | masked_output, 84 | masked_label, 85 | pos_weight=w*factor, 86 | reduction=reduction) 87 | loss /= factor 88 | return loss 89 | 90 | 91 | class Normalize(nn.Module): 92 | def __init__(self): 93 | super(Normalize, self).__init__() 94 | 95 | def forward(self, bottom): 96 | qn = torch.norm(bottom, p=2, dim=1).unsqueeze(dim=1) + 1e-12 97 | top = bottom.div(qn) 98 | 99 | return top 100 | 101 | 102 | class NormalsLoss(Module): 103 | """ 104 | L1 loss with ignore labels 105 | normalize: normalization for surface normals 106 | """ 107 | def __init__(self, size_average=True, normalize=False, norm=1): 108 | super(NormalsLoss, self).__init__() 109 | 110 | self.size_average = size_average 111 | 112 | if normalize: 113 | self.normalize = Normalize() 114 | else: 115 | self.normalize = None 116 | 117 | if norm == 1: 118 | print('Using L1 loss for surface normals') 119 | self.loss_func = F.l1_loss 120 | elif norm == 2: 121 | print('Using L2 loss for surface normals') 122 | self.loss_func = F.mse_loss 123 | else: 124 | raise NotImplementedError 125 | 126 | def forward(self, out, label, ignore_label=255): 127 | assert not label.requires_grad 128 | mask = (label != ignore_label) 129 | n_valid = torch.sum(mask).item() 130 | 131 | if self.normalize is not None: 132 | out_norm = self.normalize(out) 133 | loss = self.loss_func(torch.masked_select(out_norm, mask), torch.masked_select(label, mask), reduction='sum') 134 | else: 135 | loss = self.loss_func(torch.masked_select(out, mask), torch.masked_select(label, mask), reduction='sum') 136 | 137 | if self.size_average: 138 | if ignore_label: 139 | ret_loss = torch.div(loss, max(n_valid, 1e-6)) 140 | return ret_loss 141 | else: 142 | ret_loss = torch.div(loss, float(np.prod(label.size()))) 143 | return ret_loss 144 | 145 | return loss 146 | 147 | class L1Loss(nn.Module): 148 | """ 149 | from ATRC 150 | L1 loss with ignore regions. 
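Used with normalize=True for surface normals and without normalization for depth regression (see utils/common_config.py).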
151 | normalize: normalization for surface normals 152 | """ 153 | def __init__(self, normalize=False, ignore_index=255): 154 | super().__init__() 155 | self.normalize = normalize 156 | self.ignore_index = ignore_index 157 | 158 | def forward(self, out, label, reduction='mean'): 159 | 160 | if self.normalize: 161 | out = nn.functional.normalize(out, p=2, dim=1) 162 | 163 | mask = (label != self.ignore_index).all(dim=1, keepdim=True) 164 | n_valid = torch.sum(mask).item() 165 | masked_out = torch.masked_select(out, mask) 166 | masked_label = torch.masked_select(label, mask) 167 | if reduction == 'mean': 168 | return nn.functional.l1_loss(masked_out, masked_label, reduction='sum') / max(n_valid, 1) 169 | elif reduction == 'sum': 170 | return nn.functional.l1_loss(masked_out, masked_label, reduction='sum') 171 | elif reduction == 'none': 172 | return nn.functional.l1_loss(masked_out, masked_label, reduction='none') -------------------------------------------------------------------------------- /losses/loss_schemes.py: -------------------------------------------------------------------------------- 1 | # 2 | # Authors: Simon Vandenhende 3 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class MultiTaskLoss(nn.Module): 10 | def __init__(self, p, tasks: list, loss_ft: nn.ModuleDict, loss_weights: dict): 11 | super(MultiTaskLoss, self).__init__() 12 | assert(set(tasks) == set(loss_ft.keys())) 13 | assert(set(tasks) == set(loss_weights.keys())) 14 | self.p = p 15 | self.tasks = tasks 16 | self.loss_ft = loss_ft 17 | self.loss_weights = loss_weights 18 | 19 | 20 | def forward(self, pred, gt, tasks): 21 | out = {task: self.loss_ft[task](pred[task], gt[task]) for task in tasks} 22 | 23 | out['total'] = torch.sum(torch.stack([self.loss_weights[t] * out[t] for t in tasks])) 24 | 25 | if self.p.intermediate_supervision: 26 | inter_preds = pred['inter_preds'] 27 | losses_inter = {t: self.loss_ft[t](inter_preds[t], gt[t]) for t in self.tasks} 28 | for k, v in losses_inter.items(): 29 | out['inter_%s' %(k)] = v 30 | out['total'] += self.loss_weights[k] * v #* 0.5 31 | 32 | return out 33 | 34 | class MultiTaskLoss_log(nn.Module): 35 | def __init__(self, p, tasks: list, loss_ft: nn.ModuleDict, loss_weights: dict): 36 | super(MultiTaskLoss_log, self).__init__() 37 | assert(set(tasks) == set(loss_ft.keys())) 38 | assert(set(tasks) == set(loss_weights.keys())) 39 | self.p = p 40 | self.tasks = tasks 41 | self.loss_ft = loss_ft 42 | self.loss_weights = loss_weights 43 | 44 | 45 | def forward(self, pred, gt, tasks): 46 | out = {task: self.loss_ft[task](pred[task], gt[task]) for task in tasks} 47 | 48 | out['total'] = torch.sum(torch.stack([self.loss_weights[t] * torch.log(out[t]+1e-8) for t in tasks])) 49 | 50 | return out -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import os 4 | import numpy as np 5 | import sys 6 | import torch 7 | import pdb 8 | import pprint 9 | from utils.utils import mkdir_if_missing 10 | from utils.config import create_config 11 | from utils.common_config import get_train_dataset, get_transformations,\ 12 | get_test_dataset, get_train_dataloader, get_test_dataloader,\ 13 | get_optimizer, get_model, get_criterion 14 | from utils.logger import Logger 15 | from utils.train_utils import 
train_phase 16 | from utils.test_utils import test_phase 17 | from evaluation.evaluate_utils import PerformanceMeter 18 | 19 | from torch.utils.tensorboard import SummaryWriter 20 | import time 21 | start_time = time.time() 22 | 23 | # DDP 24 | import torch.distributed as dist 25 | import datetime 26 | dist.init_process_group(backend='nccl', init_method='env://', timeout=datetime.timedelta(0, 3600*2)) 27 | 28 | # Parser 29 | parser = argparse.ArgumentParser(description='Vanilla Training') 30 | parser.add_argument('--config_exp', 31 | help='Config file for the experiment') 32 | parser.add_argument('--local-rank', default=0, type=int, 33 | help='node rank for distributed training') 34 | parser.add_argument('--run_mode', 35 | help='Config file for the experiment') 36 | parser.add_argument('--trained_model', default=None, 37 | help='Config file for the experiment') 38 | parser.add_argument('--seed', default=0, type=int, help='') 39 | args = parser.parse_args() 40 | 41 | print('local rank: %s' %args.local_rank) 42 | torch.cuda.set_device(args.local_rank) 43 | 44 | # CUDNN 45 | torch.backends.cudnn.benchmark = True 46 | # opencv 47 | cv2.setNumThreads(0) 48 | 49 | def set_seed(seed): 50 | import random 51 | random.seed(seed) 52 | np.random.seed(seed) 53 | torch.manual_seed(seed) 54 | torch.cuda.manual_seed_all(seed) 55 | 56 | def main(): 57 | set_seed(args.seed) 58 | # Retrieve config file 59 | params = {'run_mode': args.run_mode} 60 | p = create_config(args.config_exp, params) 61 | if args.local_rank == 0: 62 | sys.stdout = Logger(os.path.join(p['output_dir'], 'log_file.txt')) 63 | pprint.pprint(p) 64 | 65 | # tensorboard 66 | tb_log_dir = p.root_dir + '/tb_dir' #os.path.join(p['output_dir'], 'tensorboard_logdir') 67 | p.tb_log_dir = tb_log_dir 68 | if args.local_rank == 0: 69 | train_tb_log_dir = tb_log_dir + '/train' 70 | test_tb_log_dir = tb_log_dir + '/test' 71 | tb_writer_train = SummaryWriter(train_tb_log_dir) 72 | tb_writer_test = SummaryWriter(test_tb_log_dir) 73 | if args.run_mode == 'train': 74 | mkdir_if_missing(tb_log_dir) 75 | mkdir_if_missing(train_tb_log_dir) 76 | mkdir_if_missing(test_tb_log_dir) 77 | print(f"Tensorboard dir: {tb_log_dir}") 78 | else: 79 | tb_writer_train = None 80 | tb_writer_test = None 81 | 82 | 83 | # Get model 84 | model = get_model(p) 85 | num = 0 86 | for mp in model.parameters(): 87 | num += mp.numel() 88 | print(f'Model Param Size: {num}') 89 | # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda() 90 | model = model.cuda() 91 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) 92 | 93 | # Get criterion 94 | criterion = get_criterion(p).cuda() 95 | 96 | # Optimizer 97 | scheduler, optimizer = get_optimizer(p, model) 98 | 99 | # Performance meter init 100 | performance_meter = PerformanceMeter(p, p.TASKS.NAMES) 101 | 102 | # Transforms 103 | train_transforms, val_transforms = get_transformations(p) 104 | if args.run_mode == 'train': 105 | train_dataset = get_train_dataset(p, train_transforms) 106 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, drop_last=True) 107 | train_dataloader = get_train_dataloader(p, train_dataset, train_sampler) 108 | test_dataset = get_test_dataset(p, val_transforms) 109 | test_dataloader = get_test_dataloader(p, test_dataset) 110 | elif args.run_mode == 'infer': 111 | test_dataset = get_test_dataset(p, val_transforms) 112 | test_dataloader = get_test_dataloader(p, test_dataset) 113 | 
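    # In 'infer' mode only the test split is constructed. The block below restores weights
    # from --trained_model when it is given, otherwise from the run's own checkpoint, together
    # with optimizer/scheduler/epoch/iteration state whenever those entries are present.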
114 | 115 | # Resume from checkpoint 116 | if os.path.exists(p['checkpoint']) or args.run_mode in ['infer']: 117 | if args.trained_model != None: 118 | checkpoint_path = args.trained_model 119 | else: 120 | checkpoint_path = p['checkpoint'] 121 | if args.local_rank == 0: 122 | print('Use checkpoint {}'.format(checkpoint_path)) 123 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 124 | model.load_state_dict(checkpoint['model']) 125 | if 'optimizer' in checkpoint.keys(): 126 | optimizer.load_state_dict(checkpoint['optimizer']) 127 | if 'scheduler' in checkpoint.keys(): 128 | scheduler.load_state_dict(checkpoint['scheduler']) 129 | if 'epoch' in checkpoint.keys(): 130 | start_epoch = checkpoint['epoch'] + 1 # epoch count is not used 131 | else: 132 | start_epoch = 0 133 | if 'iter_count' in checkpoint.keys(): 134 | iter_count = checkpoint['iter_count'] # already + 1 when saving 135 | else: 136 | iter_count = 0 137 | else: 138 | if args.local_rank == 0: 139 | print('Fresh start...') 140 | start_epoch = 0 141 | iter_count = 0 142 | best_imp = -100 143 | if DEBUG_FLAG and args.local_rank == 0: 144 | print("\nFirst Testing...") 145 | if True: 146 | eval_test = test_phase(p, test_dataloader, model, criterion, iter_count) 147 | else: 148 | eval_test = {} 149 | print(eval_test) 150 | 151 | # Train loop 152 | if args.run_mode == 'train': 153 | for epoch in range(start_epoch, p['epochs']): 154 | train_sampler.set_epoch(epoch) 155 | if args.local_rank == 0: 156 | print('Epoch %d/%d' %(epoch+1, p['epochs'])) 157 | print('-'*10) 158 | 159 | end_signal, iter_count, best_imp = train_phase(p, args, train_dataloader, test_dataloader, model, criterion, 160 | optimizer, scheduler, epoch, tb_writer_train, tb_writer_test, iter_count, best_imp) 161 | 162 | if end_signal: 163 | break 164 | 165 | # running eval 166 | if args.local_rank == 0: 167 | eval_epoch = iter_count # start_epoch 168 | eval_test = test_phase(p, test_dataloader, model, criterion, eval_epoch, save_edge=True) 169 | print('Infer test restuls:') 170 | print(eval_test) 171 | 172 | end_time = time.time() 173 | run_time = (end_time-start_time) / 3600 174 | print('Total running time: {} h.'.format(run_time)) 175 | 176 | if __name__ == "__main__": 177 | # IMPORTANT VARIABLES 178 | DEBUG_FLAG = False # When True, test the evaluation code when started 179 | 180 | assert args.run_mode in ['train', 'infer'] 181 | main() 182 | -------------------------------------------------------------------------------- /models/CTM.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint as checkpoint 6 | from einops import rearrange, repeat 7 | from timm.models.layers import DropPath, trunc_normal_ 8 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref 9 | from functools import partial 10 | from typing import Optional, Callable, Any 11 | from collections import OrderedDict 12 | 13 | class SS2D(nn.Module): 14 | def __init__( 15 | self, 16 | d_model, 17 | d_model_rate=1, 18 | d_state=16, 19 | d_conv=3, 20 | ssm_ratio=2, 21 | dt_rank="auto", 22 | # ====================== 23 | dropout=0., 24 | conv_bias=True, 25 | bias=False, 26 | dtype=None, 27 | # ====================== 28 | dt_min=0.001, 29 | dt_max=0.1, 30 | dt_init="random", 31 | dt_scale=1.0, 32 | dt_init_floor=1e-4, 33 | # ====================== 34 | shared_ssm=False, 35 | softmax_version=False, 36 | # 
====================== 37 | **kwargs, 38 | ): 39 | factory_kwargs = {"device": None, "dtype": dtype} 40 | super().__init__() 41 | self.softmax_version = softmax_version 42 | self.d_model = d_model 43 | self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_state # 20240109 44 | self.d_conv = d_conv 45 | self.expand = ssm_ratio 46 | self.d_inner = int(self.expand * self.d_model) 47 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 48 | self.K = 4 if not shared_ssm else 1 49 | 50 | self.in_proj = nn.Linear(self.d_model*d_model_rate, self.d_inner, bias=bias, **factory_kwargs) 51 | self.conv2d = nn.Conv2d( 52 | in_channels=self.d_inner, 53 | out_channels=self.d_inner, 54 | groups=self.d_inner, 55 | bias=conv_bias, 56 | kernel_size=d_conv, 57 | padding=(d_conv - 1) // 2, 58 | **factory_kwargs, 59 | ) 60 | self.act = nn.SiLU() 61 | 62 | self.x_proj = [ 63 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs) 64 | for _ in range(self.K) 65 | ] 66 | self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) 67 | del self.x_proj 68 | 69 | self.dt_projs = [ 70 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) 71 | for _ in range(self.K) 72 | ] 73 | self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) 74 | self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K * inner) 75 | del self.dt_projs 76 | 77 | self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=self.K, merge=True) # (K * D, N) 78 | self.Ds = self.D_init(self.d_inner, copies=self.K, merge=True) # (K * D) 79 | 80 | if not self.softmax_version: 81 | self.out_norm = nn.LayerNorm(self.d_inner) 82 | # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 83 | # self.dropout = nn.Dropout(dropout) if dropout > 0. 
else None 84 | 85 | @staticmethod 86 | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): 87 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) 88 | 89 | # Initialize special dt projection to preserve variance at initialization 90 | dt_init_std = dt_rank**-0.5 * dt_scale 91 | if dt_init == "constant": 92 | nn.init.constant_(dt_proj.weight, dt_init_std) 93 | elif dt_init == "random": 94 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 95 | else: 96 | raise NotImplementedError 97 | 98 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 99 | dt = torch.exp( 100 | torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 101 | + math.log(dt_min) 102 | ).clamp(min=dt_init_floor) 103 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 104 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 105 | with torch.no_grad(): 106 | dt_proj.bias.copy_(inv_dt) 107 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 108 | dt_proj.bias._no_reinit = True 109 | 110 | return dt_proj 111 | 112 | @staticmethod 113 | def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): 114 | # S4D real initialization 115 | A = repeat( 116 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device), 117 | "n -> d n", 118 | d=d_inner, 119 | ).contiguous() 120 | A_log = torch.log(A) # Keep A_log in fp32 121 | if copies > 0: 122 | A_log = repeat(A_log, "d n -> r d n", r=copies) 123 | if merge: 124 | A_log = A_log.flatten(0, 1) 125 | A_log = nn.Parameter(A_log) 126 | A_log._no_weight_decay = True 127 | return A_log 128 | 129 | @staticmethod 130 | def D_init(d_inner, copies=-1, device=None, merge=True): 131 | # D "skip" parameter 132 | D = torch.ones(d_inner, device=device) 133 | if copies > 0: 134 | D = repeat(D, "n1 -> r n1", r=copies) 135 | if merge: 136 | D = D.flatten(0, 1) 137 | D = nn.Parameter(D) # Keep in fp32 138 | D._no_weight_decay = True 139 | return D 140 | 141 | def forward_corev0(self, x: torch.Tensor): 142 | self.selective_scan = selective_scan_fn 143 | 144 | B, C, H, W = x.shape 145 | L = H * W 146 | K = 4 147 | 148 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 149 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) 150 | 151 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) 152 | # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) 153 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) 154 | dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) 155 | 156 | xs = xs.float().view(B, -1, L) # (b, k * d, l) 157 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) 158 | Bs = Bs.float() # (b, k, d_state, l) 159 | Cs = Cs.float() # (b, k, d_state, l) 160 | 161 | As = -torch.exp(self.A_logs.float()) # (k * d, d_state) 162 | Ds = self.Ds.float() # (k * d) 163 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 164 | 165 | # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4 166 | # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1 167 | 168 | out_y = self.selective_scan( 169 | xs, dts, 170 | As, Bs, Cs, Ds, z=None, 171 | delta_bias=dt_projs_bias, 172 | delta_softplus=True, 173 | return_last_state=False, 
174 | ).view(B, K, -1, L) 175 | assert out_y.dtype == torch.float 176 | 177 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 178 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 179 | invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 180 | y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y 181 | y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) 182 | y = self.out_norm(y) 183 | 184 | return y 185 | 186 | forward_core = forward_corev0 187 | 188 | def forward(self, x: torch.Tensor, **kwargs): 189 | x = self.in_proj(x) 190 | # x, z = xz.chunk(2, dim=-1) # (b, h, w, d) 191 | 192 | x = x.permute(0, 3, 1, 2).contiguous() 193 | x = self.act(self.conv2d(x)) # (b, d, h, w) 194 | y = self.forward_core(x) 195 | # y = y * F.silu(z) 196 | # out = self.out_proj(y) 197 | # if self.dropout is not None: 198 | # out = self.dropout(out) 199 | return y 200 | 201 | 202 | class CTMBlock(nn.Module): 203 | def __init__( 204 | self, 205 | tasks, 206 | hidden_dim: int = 0, 207 | drop_path: float = 0, 208 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 209 | attn_drop_rate: float = 0, 210 | d_state: int = 16, 211 | dt_rank: Any = "auto", 212 | ssm_ratio=2.0, 213 | shared_ssm=False, 214 | softmax_version=False, 215 | use_checkpoint: bool = False, 216 | mlp_ratio=4.0, 217 | act_layer=nn.GELU, 218 | drop: float = 0.0, 219 | **kwargs, 220 | ): 221 | super().__init__() 222 | self.use_checkpoint = use_checkpoint 223 | self.tasks = tasks 224 | 225 | self.norm = nn.ModuleDict() 226 | self.op = nn.ModuleDict() 227 | self.in_proj = nn.ModuleDict() 228 | self.out_proj = nn.ModuleDict() 229 | for t in self.tasks: 230 | self.norm[t] = norm_layer(hidden_dim) 231 | self.in_proj[t] = nn.Linear(hidden_dim, hidden_dim*ssm_ratio, bias=False) 232 | self.out_proj[t] = nn.Linear(hidden_dim*ssm_ratio, hidden_dim, bias=False) 233 | 234 | self.op[t] = SS2D( 235 | d_model=hidden_dim, 236 | d_model_rate=1, 237 | dropout=attn_drop_rate, 238 | d_state=d_state, 239 | ssm_ratio=ssm_ratio, 240 | dt_rank=dt_rank, 241 | shared_ssm=shared_ssm, 242 | softmax_version=softmax_version, 243 | **kwargs 244 | ) 245 | 246 | d_model_rate = len(self.tasks) 247 | 248 | self.norm_share = norm_layer(hidden_dim*d_model_rate) 249 | self.op_share = SS2D( 250 | d_model=hidden_dim, 251 | d_model_rate=d_model_rate, 252 | dropout=attn_drop_rate, 253 | d_state=d_state, 254 | ssm_ratio=ssm_ratio, 255 | dt_rank=dt_rank, 256 | shared_ssm=shared_ssm, 257 | softmax_version=softmax_version, 258 | **kwargs) 259 | 260 | self.drop_path = DropPath(drop_path) 261 | 262 | def _forward_pre(self, input: dict): 263 | x = torch.cat([input[t] for t in self.tasks], dim=-1) 264 | x = self.op_share(self.norm_share(x)) 265 | 266 | out = {} 267 | for t in self.tasks: 268 | z = self.norm[t](input[t]) 269 | x_t = self.op[t](z) 270 | g = F.sigmoid(self.in_proj[t](z)) 271 | z = g * x + (1 - g) * x_t 272 | out[t] = self.out_proj[t](z) 273 | return out 274 | 275 | def _forward(self, input: dict): 276 | z = self._forward_pre(input) 277 | 278 | out = {} 279 | for t in self.tasks: 280 | out[t] = input[t] + self.drop_path(z[t]) 281 | return out 282 | 283 | def forward(self, input: torch.Tensor): 284 | if self.use_checkpoint: 285 | return checkpoint.checkpoint(self._forward, input) 286 | else: 287 | return self._forward(input) -------------------------------------------------------------------------------- /models/MTMamba.py: 
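# --------------------------------------------------------------------------------------
# Illustrative usage sketch (an assumption-laden example, not code from this repository):
# it drives the CTMBlock defined in models/CTM.py above in isolation. The task names, batch
# size, channel width and spatial size are made up for the sketch, and a CUDA device with
# mamba-ssm installed is required because SS2D relies on selective_scan_fn.
import torch
from models.CTM import CTMBlock

tasks = ['semseg', 'depth']                                      # hypothetical task set
ctm = CTMBlock(tasks=tasks, hidden_dim=96, ssm_ratio=2).cuda()   # ssm_ratio kept an int, as MTMamba passes it
feats = {t: torch.randn(2, 14, 18, 96, device='cuda') for t in tasks}   # (B, H, W, C) per task
fused = ctm(feats)                                               # same keys and shapes as the input
# --------------------------------------------------------------------------------------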
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .utils import PatchExpand, FinalPatchExpand_X4, STMBlock 5 | from .CTM import CTMBlock 6 | 7 | INTERPOLATE_MODE = 'bilinear' 8 | 9 | class MTMamba(nn.Module): 10 | def __init__(self, p, backbone, d_state=16, dt_rank="auto", ssm_ratio=2, mlp_ratio=0): 11 | super().__init__() 12 | self.tasks = p.TASKS.NAMES 13 | self.backbone = backbone 14 | self.feature_channel = backbone.num_features 15 | self.img_size = p.IMAGE_ORI_SIZE 16 | 17 | total_depth = 3 18 | dpr = [x.item() for x in torch.linspace(0.2, 0, (len(self.feature_channel)-1)*total_depth)] 19 | 20 | self.expand_layers = nn.ModuleDict() 21 | self.concat_layers = nn.ModuleDict() 22 | self.block_1 = nn.ModuleDict() 23 | self.block_2 = nn.ModuleDict() 24 | self.final_project = nn.ModuleDict() 25 | self.final_expand = nn.ModuleDict() 26 | for t in self.tasks: 27 | for stage in range(len(self.feature_channel) - 1): 28 | current_channel = self.feature_channel[::-1][stage] 29 | skip_channel = self.feature_channel[::-1][stage+1] 30 | 31 | self.expand_layers[f'{t}_{stage}'] = PatchExpand(input_resolution=None, 32 | dim=current_channel, 33 | dim_scale=2, 34 | norm_layer=nn.LayerNorm) 35 | self.concat_layers[f'{t}_{stage}'] = nn.Linear(2*skip_channel, skip_channel) 36 | 37 | self.block_1[f'{t}_{stage}'] = STMBlock(hidden_dim=skip_channel, 38 | drop_path=dpr[total_depth*(stage)+0], 39 | norm_layer=nn.LayerNorm, 40 | ssm_ratio=ssm_ratio, 41 | d_state=d_state, 42 | mlp_ratio=mlp_ratio, 43 | dt_rank=dt_rank) 44 | self.block_2[f'{t}_{stage}'] = STMBlock(hidden_dim=skip_channel, 45 | drop_path=dpr[total_depth*(stage)+1], 46 | norm_layer=nn.LayerNorm, 47 | ssm_ratio=ssm_ratio, 48 | d_state=d_state, 49 | mlp_ratio=mlp_ratio, 50 | dt_rank=dt_rank) 51 | 52 | self.final_expand[t] = FinalPatchExpand_X4( 53 | input_resolution=None, 54 | dim=self.feature_channel[0], 55 | dim_scale=4, 56 | norm_layer=nn.LayerNorm, 57 | ) 58 | self.final_project[t] = nn.Conv2d(self.feature_channel[0], p.TASKS.NUM_OUTPUT[t], 1, 1, 0, bias=True) 59 | 60 | self.block_3 = nn.ModuleDict() 61 | for stage in range(len(self.feature_channel) - 1): 62 | skip_channel = self.feature_channel[::-1][stage+1] 63 | self.block_3[f'{stage}'] = CTMBlock(tasks=self.tasks, 64 | hidden_dim=skip_channel, 65 | drop_path=dpr[total_depth*(stage)+2], 66 | norm_layer=nn.LayerNorm, 67 | ssm_ratio=ssm_ratio, 68 | d_state=d_state, 69 | mlp_ratio=mlp_ratio, 70 | dt_rank=dt_rank) 71 | 72 | def _forward_expand(self, x_dict, selected_fea: list, stage: int) -> dict: 73 | if stage == 0: 74 | x_dict = {t: selected_fea[-1] for t in self.tasks} 75 | 76 | skip = selected_fea[::-1][stage+1] 77 | out = {} 78 | for t in self.tasks: 79 | x = self.expand_layers[f'{t}_{stage}'](x_dict[t]) 80 | x = torch.cat((x, skip.permute(0,2,3,1)), -1) 81 | x = self.concat_layers[f'{t}_{stage}'](x) 82 | out[t] = x 83 | return out # B,H,W,C 84 | 85 | def _forward_block1(self, x_dict: dict, stage: int) -> dict: 86 | out = {} 87 | for t in self.tasks: 88 | out[t] = self.block_1[f'{t}_{stage}'](x_dict[t]) 89 | return out 90 | 91 | def _forward_block2(self, x_dict: dict, stage: int) -> dict: 92 | out = {} 93 | for t in self.tasks: 94 | out[t] = self.block_2[f'{t}_{stage}'](x_dict[t]) 95 | return out 96 | 97 | def _forward_block3(self, x_dict: dict, stage: int) -> dict: 98 | return self.block_3[f'{stage}'](x_dict) 99 | 100 | def forward(self, x): 101 | # img_size = x.size()[-2:] 102 
| 103 | # Backbone 104 | selected_fea = self.backbone(x) 105 | 106 | x_dict = None 107 | for stage in range(len(self.feature_channel) - 1): 108 | x_dict = self._forward_expand(x_dict, selected_fea, stage) 109 | x_dict = self._forward_block1(x_dict, stage) 110 | x_dict = self._forward_block2(x_dict, stage) 111 | x_dict = self._forward_block3(x_dict, stage) 112 | x_dict = {t: xx.permute(0,3,1,2) for t, xx in x_dict.items()} 113 | 114 | out = {} 115 | for t in self.tasks: 116 | z = self.final_expand[t](x_dict[t]) 117 | z = self.final_project[t](z.permute(0,3,1,2)) 118 | out[t] = F.interpolate(z, self.img_size, mode=INTERPOLATE_MODE) 119 | 120 | return out -------------------------------------------------------------------------------- /models/MTMamba_plus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from functools import partial 6 | from einops import rearrange, repeat 7 | from typing import Optional, Callable, Any 8 | from collections import OrderedDict 9 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn 10 | from timm.models.layers import DropPath, trunc_normal_ 11 | 12 | from .utils import PatchExpand, STMBlock 13 | 14 | INTERPOLATE_MODE = 'bilinear' 15 | 16 | class CSS2D(nn.Module): 17 | def __init__( 18 | self, 19 | d_model, 20 | d_state=16, 21 | d_conv=3, 22 | ssm_ratio=2, 23 | dt_rank="auto", 24 | # ====================== 25 | dropout=0., 26 | conv_bias=True, 27 | bias=False, 28 | dtype=None, 29 | # ====================== 30 | dt_min=0.001, 31 | dt_max=0.1, 32 | dt_init="random", 33 | dt_scale=1.0, 34 | dt_init_floor=1e-4, 35 | # ====================== 36 | shared_ssm=False, 37 | softmax_version=False, 38 | # ====================== 39 | **kwargs, 40 | ): 41 | factory_kwargs = {"device": None, "dtype": dtype} 42 | super().__init__() 43 | self.softmax_version = softmax_version 44 | self.d_model = d_model 45 | self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_state # 20240109 46 | self.d_conv = d_conv 47 | self.expand = ssm_ratio 48 | self.d_inner = int(self.expand * self.d_model) 49 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 50 | self.K = 4 if not shared_ssm else 1 51 | 52 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 53 | self.in_proj_cross = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs) 54 | self.conv2d = nn.Conv2d( 55 | in_channels=self.d_inner, 56 | out_channels=self.d_inner, 57 | groups=self.d_inner, 58 | bias=conv_bias, 59 | kernel_size=d_conv, 60 | padding=(d_conv - 1) // 2, 61 | **factory_kwargs, 62 | ) 63 | self.act = nn.SiLU() 64 | 65 | self.x_proj = [ 66 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs) 67 | for _ in range(self.K) 68 | ] 69 | self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) 70 | del self.x_proj 71 | 72 | self.dt_projs = [ 73 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) 74 | for _ in range(self.K) 75 | ] 76 | self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) 77 | self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K * inner) 78 | del self.dt_projs 79 | 80 | self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=self.K, 
merge=True) # (K * D, N) 81 | self.Ds = self.D_init(self.d_inner, copies=self.K, merge=True) # (K * D) 82 | 83 | if not self.softmax_version: 84 | self.out_norm = nn.LayerNorm(self.d_inner) 85 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 86 | self.dropout = nn.Dropout(dropout) if dropout > 0. else None 87 | 88 | @staticmethod 89 | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): 90 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) 91 | 92 | # Initialize special dt projection to preserve variance at initialization 93 | dt_init_std = dt_rank**-0.5 * dt_scale 94 | if dt_init == "constant": 95 | nn.init.constant_(dt_proj.weight, dt_init_std) 96 | elif dt_init == "random": 97 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 98 | else: 99 | raise NotImplementedError 100 | 101 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 102 | dt = torch.exp( 103 | torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 104 | + math.log(dt_min) 105 | ).clamp(min=dt_init_floor) 106 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 107 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 108 | with torch.no_grad(): 109 | dt_proj.bias.copy_(inv_dt) 110 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 111 | dt_proj.bias._no_reinit = True 112 | 113 | return dt_proj 114 | 115 | @staticmethod 116 | def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): 117 | # S4D real initialization 118 | A = repeat( 119 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device), 120 | "n -> d n", 121 | d=d_inner, 122 | ).contiguous() 123 | A_log = torch.log(A) # Keep A_log in fp32 124 | if copies > 0: 125 | A_log = repeat(A_log, "d n -> r d n", r=copies) 126 | if merge: 127 | A_log = A_log.flatten(0, 1) 128 | A_log = nn.Parameter(A_log) 129 | A_log._no_weight_decay = True 130 | return A_log 131 | 132 | @staticmethod 133 | def D_init(d_inner, copies=-1, device=None, merge=True): 134 | # D "skip" parameter 135 | D = torch.ones(d_inner, device=device) 136 | if copies > 0: 137 | D = repeat(D, "n1 -> r n1", r=copies) 138 | if merge: 139 | D = D.flatten(0, 1) 140 | D = nn.Parameter(D) # Keep in fp32 141 | D._no_weight_decay = True 142 | return D 143 | 144 | def forward_corev0(self, x: torch.Tensor, x_cross: torch.Tensor): 145 | self.selective_scan = selective_scan_fn 146 | assert x.shape == x_cross.shape 147 | B, C, H, W = x.shape 148 | L = H * W 149 | K = 4 150 | 151 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 152 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) 153 | 154 | x_hwwh_cross = torch.stack([x_cross.view(B, -1, L), torch.transpose(x_cross, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 155 | xs_cross = torch.cat([x_hwwh_cross, torch.flip(x_hwwh_cross, dims=[-1])], dim=1) # (b, k, d, l) 156 | 157 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs_cross, self.x_proj_weight) 158 | del x_cross, xs_cross, x_hwwh_cross 159 | 160 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) 161 | dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) 162 | 163 | xs = xs.float().view(B, -1, L) # (b, k * d, l) 164 | dts = dts.contiguous().float().view(B, -1, L) # (b, 
k * d, l) 165 | Bs = Bs.float() # (b, k, d_state, l) 166 | Cs = Cs.float() # (b, k, d_state, l) 167 | 168 | As = -torch.exp(self.A_logs.float()) # (k * d, d_state) 169 | Ds = self.Ds.float() # (k * d) 170 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 171 | 172 | # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4 173 | # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1 174 | 175 | out_y = self.selective_scan( 176 | xs, dts, 177 | As, Bs, Cs, Ds, z=None, 178 | delta_bias=dt_projs_bias, 179 | delta_softplus=True, 180 | return_last_state=False, 181 | ).view(B, K, -1, L) 182 | assert out_y.dtype == torch.float 183 | 184 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 185 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 186 | invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 187 | y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y 188 | y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) 189 | y = self.out_norm(y) 190 | 191 | return y 192 | 193 | forward_core = forward_corev0 194 | 195 | def forward(self, x: torch.Tensor, x_cross: torch.Tensor, **kwargs): 196 | xz = self.in_proj(x) 197 | x_cross = self.in_proj_cross(x_cross) 198 | x, z = xz.chunk(2, dim=-1) # (b, h, w, d) 199 | 200 | x = x.permute(0, 3, 1, 2).contiguous() 201 | x = self.act(self.conv2d(x)) # (b, d, h, w) 202 | y = self.forward_core(x, x_cross.permute(0,3,1,2)) 203 | y = y * F.silu(z) 204 | out = self.out_proj(y) 205 | if self.dropout is not None: 206 | out = self.dropout(out) 207 | return out 208 | 209 | class SCTM(nn.Module): 210 | def __init__( 211 | self, 212 | tasks, 213 | hidden_dim: int = 0, 214 | drop_path: float = 0, 215 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 216 | attn_drop_rate: float = 0, 217 | d_state: int = 16, 218 | dt_rank: Any = "auto", 219 | ssm_ratio=2.0, 220 | shared_ssm=False, 221 | softmax_version=False, 222 | mlp_ratio=4.0, 223 | act_layer=nn.GELU, 224 | drop: float = 0.0, 225 | **kwargs, 226 | ): 227 | super().__init__() 228 | self.tasks = tasks 229 | self.norm_share = norm_layer(hidden_dim*len(tasks)) 230 | self.conv_share = nn.Sequential(nn.Conv2d(hidden_dim*len(tasks), hidden_dim, 1), 231 | nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1)) 232 | 233 | self.norm = nn.ModuleDict() 234 | self.op = nn.ModuleDict() 235 | for tname in self.tasks: 236 | self.norm[tname] = norm_layer(hidden_dim) 237 | self.op[tname] = CSS2D( 238 | d_model=hidden_dim, 239 | dropout=attn_drop_rate, 240 | d_state=d_state, 241 | ssm_ratio=ssm_ratio, 242 | dt_rank=dt_rank, 243 | shared_ssm=shared_ssm, 244 | softmax_version=softmax_version, 245 | **kwargs 246 | ) 247 | self.drop_path = DropPath(drop_path) 248 | 249 | def forward(self, input: dict): 250 | x_share = torch.cat([input[t] for t in self.tasks], dim=-1) 251 | x_share = self.conv_share(self.norm_share(x_share).permute(0,3,1,2)).permute(0,2,3,1) 252 | out = {} 253 | for t in self.tasks: 254 | x_t = input[t] 255 | x = x_t + self.drop_path(self.op[t](self.norm[t](x_t), x_share)) 256 | out[t] = x 257 | return out 258 | 259 | class MTMamba_plus(nn.Module): 260 | def __init__(self, p, backbone, d_state=16, dt_rank="auto", ssm_ratio=2, mlp_ratio=0): 261 | super().__init__() 262 | self.tasks = p.TASKS.NAMES 263 | self.backbone = backbone 264 | self.feature_channel = backbone.num_features 265 | self.img_size = 
p.IMAGE_ORI_SIZE 266 | 267 | each_stage_depth = 3 268 | stage_num = len(self.feature_channel) - 1 269 | 270 | dpr = [x.item() for x in torch.linspace(0.2, 0, stage_num*each_stage_depth)] 271 | 272 | self.expand_layers = nn.ModuleDict() 273 | self.concat_layers = nn.ModuleDict() 274 | self.block_stm = nn.ModuleDict() 275 | self.final_project = nn.ModuleDict() 276 | self.final_expand = nn.ModuleDict() 277 | for t in self.tasks: 278 | for stage in range(len(self.feature_channel) - 1): 279 | current_channel = self.feature_channel[::-1][stage] 280 | skip_channel = self.feature_channel[::-1][stage+1] 281 | 282 | self.expand_layers[f'{t}_{stage}'] = PatchExpand(input_resolution=None, 283 | dim=current_channel, 284 | dim_scale=2, 285 | norm_layer=nn.LayerNorm) 286 | self.concat_layers[f'{t}_{stage}'] = nn.Conv2d(2*skip_channel, skip_channel, 1) 287 | 288 | stm_layer = [STMBlock(hidden_dim=skip_channel, 289 | drop_path=dpr[each_stage_depth*(stage)+stm_idx], 290 | norm_layer=nn.LayerNorm, 291 | ssm_ratio=ssm_ratio, 292 | d_state=d_state, 293 | mlp_ratio=mlp_ratio, 294 | dt_rank=dt_rank) for stm_idx in range(2)] 295 | 296 | self.block_stm[f'{t}_{stage}'] = nn.Sequential(*stm_layer) 297 | 298 | self.final_expand[t] = nn.Sequential( 299 | nn.Conv2d(self.feature_channel[0], 96, 3, padding=1), 300 | nn.SyncBatchNorm(96), 301 | nn.ReLU(True) 302 | ) 303 | trunc_normal_(self.final_expand[t][0].weight, std=0.02) 304 | 305 | self.final_project[t] = nn.Conv2d(96, p.TASKS.NUM_OUTPUT[t], 1) 306 | 307 | self.block_ctm = nn.ModuleDict() 308 | for stage in range(len(self.feature_channel) - 1): 309 | skip_channel = self.feature_channel[::-1][stage+1] 310 | 311 | ctm_layer = [SCTM(tasks=self.tasks, 312 | hidden_dim=skip_channel, 313 | drop_path=dpr[each_stage_depth*(stage)+2], 314 | norm_layer=nn.LayerNorm, 315 | ssm_ratio=ssm_ratio, 316 | d_state=d_state, 317 | mlp_ratio=mlp_ratio, 318 | dt_rank=dt_rank)] 319 | self.block_ctm[f'{stage}'] = nn.Sequential(*ctm_layer) 320 | 321 | def _forward_expand(self, x_dict, selected_fea: list, stage: int) -> dict: 322 | if stage == 0: 323 | x_dict = {t: selected_fea[-1] for t in self.tasks} 324 | 325 | skip = selected_fea[::-1][stage+1] 326 | out = {} 327 | for t in self.tasks: 328 | x = self.expand_layers[f'{t}_{stage}'](x_dict[t]) 329 | x = torch.cat((x.permute(0,3,1,2), skip), 1) 330 | x = self.concat_layers[f'{t}_{stage}'](x) 331 | x = x.permute(0,2,3,1) 332 | out[t] = x 333 | return out # B,H,W,C 334 | 335 | def _forward_block_stm(self, x_dict: dict, stage: int) -> dict: 336 | out = {} 337 | for t in self.tasks: 338 | out[t] = self.block_stm[f'{t}_{stage}'](x_dict[t]) 339 | return out 340 | 341 | def _forward_block_ctm(self, x_dict: dict, stage: int) -> dict: 342 | return self.block_ctm[f'{stage}'](x_dict) 343 | 344 | 345 | def forward(self, x): 346 | # img_size = x.size()[-2:] 347 | 348 | # Backbone 349 | selected_fea = self.backbone(x) 350 | 351 | x_dict = None 352 | for stage in range(len(self.feature_channel) - 1): 353 | x_dict = self._forward_expand(x_dict, selected_fea, stage) 354 | x_dict = self._forward_block_stm(x_dict, stage) 355 | x_dict = self._forward_block_ctm(x_dict, stage) 356 | x_dict = {t: xx.permute(0,3,1,2) for t, xx in x_dict.items()} 357 | 358 | out = {} 359 | for t in self.tasks: 360 | z = self.final_expand[t](x_dict[t]) 361 | z = self.final_project[t](z) 362 | out[t] = F.interpolate(z, self.img_size, mode=INTERPOLATE_MODE) 363 | 364 | return out 365 | -------------------------------------------------------------------------------- /models/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch, math 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | from functools import partial 6 | from typing import Optional, Callable, Any 7 | from collections import OrderedDict 8 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn 9 | from timm.models.layers import DropPath, trunc_normal_, to_2tuple 10 | DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" 11 | 12 | class SS2D(nn.Module): 13 | def __init__( 14 | self, 15 | d_model, 16 | d_state=16, 17 | d_conv=3, 18 | ssm_ratio=2, 19 | dt_rank="auto", 20 | # ====================== 21 | dropout=0., 22 | conv_bias=True, 23 | bias=False, 24 | dtype=None, 25 | # ====================== 26 | dt_min=0.001, 27 | dt_max=0.1, 28 | dt_init="random", 29 | dt_scale=1.0, 30 | dt_init_floor=1e-4, 31 | # ====================== 32 | shared_ssm=False, 33 | softmax_version=False, 34 | # ====================== 35 | **kwargs, 36 | ): 37 | factory_kwargs = {"device": None, "dtype": dtype} 38 | super().__init__() 39 | self.softmax_version = softmax_version 40 | self.d_model = d_model 41 | self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_state # 20240109 42 | self.d_conv = d_conv 43 | self.expand = ssm_ratio 44 | self.d_inner = int(self.expand * self.d_model) 45 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 46 | self.K = 4 if not shared_ssm else 1 47 | 48 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 49 | self.conv2d = nn.Conv2d( 50 | in_channels=self.d_inner, 51 | out_channels=self.d_inner, 52 | groups=self.d_inner, 53 | bias=conv_bias, 54 | kernel_size=d_conv, 55 | padding=(d_conv - 1) // 2, 56 | **factory_kwargs, 57 | ) 58 | self.act = nn.SiLU() 59 | 60 | self.x_proj = [ 61 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs) 62 | for _ in range(self.K) 63 | ] 64 | self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) 65 | del self.x_proj 66 | 67 | self.dt_projs = [ 68 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) 69 | for _ in range(self.K) 70 | ] 71 | self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) 72 | self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K * inner) 73 | del self.dt_projs 74 | 75 | self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=self.K, merge=True) # (K * D, N) 76 | self.Ds = self.D_init(self.d_inner, copies=self.K, merge=True) # (K * D) 77 | 78 | if not self.softmax_version: 79 | self.out_norm = nn.LayerNorm(self.d_inner) 80 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 81 | self.dropout = nn.Dropout(dropout) if dropout > 0. 
else None 82 | 83 | @staticmethod 84 | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): 85 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) 86 | 87 | # Initialize special dt projection to preserve variance at initialization 88 | dt_init_std = dt_rank**-0.5 * dt_scale 89 | if dt_init == "constant": 90 | nn.init.constant_(dt_proj.weight, dt_init_std) 91 | elif dt_init == "random": 92 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 93 | else: 94 | raise NotImplementedError 95 | 96 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 97 | dt = torch.exp( 98 | torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 99 | + math.log(dt_min) 100 | ).clamp(min=dt_init_floor) 101 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 102 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 103 | with torch.no_grad(): 104 | dt_proj.bias.copy_(inv_dt) 105 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 106 | dt_proj.bias._no_reinit = True 107 | 108 | return dt_proj 109 | 110 | @staticmethod 111 | def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): 112 | # S4D real initialization 113 | A = repeat( 114 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device), 115 | "n -> d n", 116 | d=d_inner, 117 | ).contiguous() 118 | A_log = torch.log(A) # Keep A_log in fp32 119 | if copies > 0: 120 | A_log = repeat(A_log, "d n -> r d n", r=copies) 121 | if merge: 122 | A_log = A_log.flatten(0, 1) 123 | A_log = nn.Parameter(A_log) 124 | A_log._no_weight_decay = True 125 | return A_log 126 | 127 | @staticmethod 128 | def D_init(d_inner, copies=-1, device=None, merge=True): 129 | # D "skip" parameter 130 | D = torch.ones(d_inner, device=device) 131 | if copies > 0: 132 | D = repeat(D, "n1 -> r n1", r=copies) 133 | if merge: 134 | D = D.flatten(0, 1) 135 | D = nn.Parameter(D) # Keep in fp32 136 | D._no_weight_decay = True 137 | return D 138 | 139 | def forward_corev0(self, x: torch.Tensor): 140 | self.selective_scan = selective_scan_fn 141 | 142 | B, C, H, W = x.shape 143 | L = H * W 144 | K = 4 145 | 146 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 147 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) 148 | 149 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) 150 | # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) 151 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) 152 | dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) 153 | 154 | xs = xs.float().view(B, -1, L) # (b, k * d, l) 155 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) 156 | Bs = Bs.float() # (b, k, d_state, l) 157 | Cs = Cs.float() # (b, k, d_state, l) 158 | 159 | As = -torch.exp(self.A_logs.float()) # (k * d, d_state) 160 | Ds = self.Ds.float() # (k * d) 161 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 162 | 163 | # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4 164 | # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1 165 | 166 | out_y = self.selective_scan( 167 | xs, dts, 168 | As, Bs, Cs, Ds, z=None, 169 | delta_bias=dt_projs_bias, 170 | delta_softplus=True, 171 | return_last_state=False, 172 
| ).view(B, K, -1, L) 173 | assert out_y.dtype == torch.float 174 | 175 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 176 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 177 | invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 178 | y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y 179 | y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) 180 | y = self.out_norm(y) 181 | 182 | return y 183 | 184 | forward_core = forward_corev0 185 | 186 | def forward(self, x: torch.Tensor, **kwargs): 187 | xz = self.in_proj(x) 188 | x, z = xz.chunk(2, dim=-1) # (b, h, w, d) 189 | 190 | x = x.permute(0, 3, 1, 2).contiguous() 191 | x = self.act(self.conv2d(x)) # (b, d, h, w) 192 | y = self.forward_core(x) 193 | y = y * F.silu(z) 194 | out = self.out_proj(y) 195 | if self.dropout is not None: 196 | out = self.dropout(out) 197 | return out 198 | 199 | class Mlp(nn.Module): 200 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False): 201 | super().__init__() 202 | out_features = out_features or in_features 203 | hidden_features = hidden_features or in_features 204 | 205 | Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear 206 | self.fc1 = Linear(in_features, hidden_features) 207 | self.act = act_layer() 208 | self.fc2 = Linear(hidden_features, out_features) 209 | self.drop = nn.Dropout(drop) 210 | 211 | def forward(self, x): 212 | x = self.fc1(x) 213 | x = self.act(x) 214 | x = self.drop(x) 215 | x = self.fc2(x) 216 | x = self.drop(x) 217 | return x 218 | 219 | 220 | class STMBlock(nn.Module): 221 | def __init__( 222 | self, 223 | hidden_dim: int = 0, 224 | drop_path: float = 0, 225 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 226 | attn_drop_rate: float = 0, 227 | d_state: int = 16, 228 | dt_rank: Any = "auto", 229 | ssm_ratio=2.0, 230 | shared_ssm=False, 231 | softmax_version=False, 232 | mlp_ratio=4.0, 233 | act_layer=nn.GELU, 234 | drop: float = 0.0, 235 | **kwargs, 236 | ): 237 | super().__init__() 238 | self.norm = norm_layer(hidden_dim) 239 | self.op = SS2D( 240 | d_model=hidden_dim, 241 | dropout=attn_drop_rate, 242 | d_state=d_state, 243 | ssm_ratio=ssm_ratio, 244 | dt_rank=dt_rank, 245 | shared_ssm=shared_ssm, 246 | softmax_version=softmax_version, 247 | **kwargs 248 | ) 249 | self.drop_path = DropPath(drop_path) 250 | 251 | self.mlp_branch = mlp_ratio > 0 252 | if self.mlp_branch: 253 | self.norm2 = norm_layer(hidden_dim) 254 | mlp_hidden_dim = int(hidden_dim * mlp_ratio) 255 | self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, channels_first=False) 256 | 257 | def forward(self, input: torch.Tensor): 258 | x = input + self.drop_path(self.op(self.norm(input))) 259 | if self.mlp_branch: 260 | x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN 261 | return x 262 | 263 | class PatchExpand(nn.Module): 264 | """ 265 | Reference: https://arxiv.org/pdf/2105.05537.pdf 266 | """ 267 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 268 | super().__init__() 269 | self.dim = dim 270 | self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() 271 | self.norm = norm_layer(dim // dim_scale) 272 | 273 | def forward(self, x): 274 | x = x.permute(0, 2, 3, 1) # B, C, H, W ==> B, H, W, C 275 | x = self.expand(x) 276 | B, H, W, C = x.shape 
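        # After self.expand the channel dimension is 2*dim; the rearrange below redistributes
        # those channels over a 2x2 spatial neighbourhood (pixel-shuffle style), doubling H and W
        # and leaving dim//2 channels, which self.norm then normalizes.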
277 | 278 | x = x.view(B, H, W, C) 279 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) 280 | x = x.view(B,-1,C//4) 281 | x = self.norm(x) 282 | x = x.reshape(B, H*2, W*2, C//4) 283 | 284 | return x 285 | 286 | class FinalPatchExpand_X4(nn.Module): 287 | """ 288 | Reference: 289 | - GitHub: https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/swin_transformer_unet_skip_expand_decoder_sys.py 290 | - Paper: https://arxiv.org/pdf/2105.05537.pdf 291 | """ 292 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): 293 | super().__init__() 294 | # self.input_resolution = input_resolution 295 | self.dim = dim 296 | self.dim_scale = dim_scale 297 | self.expand = nn.Linear(dim, 16*dim, bias=False) 298 | self.output_dim = dim 299 | self.norm = norm_layer(self.output_dim) 300 | 301 | def forward(self, x): 302 | """ 303 | x: B, H*W, C 304 | """ 305 | # H, W = self.input_resolution 306 | x = x.permute(0, 2, 3, 1) # B, C, H, W ==> B, H, W, C 307 | x = self.expand(x) 308 | B, H, W, C = x.shape 309 | 310 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2)) 311 | x = x.view(B,-1,self.output_dim) 312 | x = self.norm(x) 313 | x = x.reshape(B, H*self.dim_scale, W*self.dim_scale, self.output_dim) 314 | 315 | return x#.permute(0, 3, 1, 2) -------------------------------------------------------------------------------- /pretrained_ckpts/run.sh: -------------------------------------------------------------------------------- 1 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth 2 | python swin2mmseg.py swin_large_patch4_window12_384_22kto1k.pth 3 | -------------------------------------------------------------------------------- /pretrained_ckpts/swin2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmengine 7 | import torch 8 | from mmengine.runner import CheckpointLoader 9 | 10 | 11 | def convert_swin(ckpt): 12 | new_ckpt = OrderedDict() 13 | 14 | def correct_unfold_reduction_order(x): 15 | out_channel, in_channel = x.shape 16 | x = x.reshape(out_channel, 4, in_channel // 4) 17 | x = x[:, [0, 2, 1, 3], :].transpose(1, 18 | 2).reshape(out_channel, in_channel) 19 | return x 20 | 21 | def correct_unfold_norm_order(x): 22 | in_channel = x.shape[0] 23 | x = x.reshape(4, in_channel // 4) 24 | x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) 25 | return x 26 | 27 | for k, v in ckpt.items(): 28 | if k.startswith('head'): 29 | continue 30 | elif k.startswith('layers'): 31 | new_v = v 32 | if 'attn.' in k: 33 | new_k = k.replace('attn.', 'attn.w_msa.') 34 | elif 'mlp.' in k: 35 | if 'mlp.fc1.' in k: 36 | new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') 37 | elif 'mlp.fc2.' in k: 38 | new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') 39 | else: 40 | new_k = k.replace('mlp.', 'ffn.') 41 | elif 'downsample' in k: 42 | new_k = k 43 | if 'reduction.' in k: 44 | new_v = correct_unfold_reduction_order(v) 45 | elif 'norm.' 
in k: 46 | new_v = correct_unfold_norm_order(v) 47 | else: 48 | new_k = k 49 | new_k = new_k.replace('layers', 'stages', 1) 50 | elif k.startswith('patch_embed'): 51 | new_v = v 52 | if 'proj' in k: 53 | new_k = k.replace('proj', 'projection') 54 | else: 55 | new_k = k 56 | else: 57 | new_v = v 58 | new_k = k 59 | 60 | new_ckpt[new_k] = new_v 61 | 62 | return new_ckpt 63 | 64 | 65 | def main(): 66 | parser = argparse.ArgumentParser( 67 | description='Convert keys in official pretrained swin models to' 68 | 'MMSegmentation style.') 69 | parser.add_argument('src', help='src model path or url') 70 | # The dst path must be a full path of the new checkpoint. 71 | # parser.add_argument('dst', help='save path') 72 | args = parser.parse_args() 73 | 74 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 75 | if 'state_dict' in checkpoint: 76 | state_dict = checkpoint['state_dict'] 77 | elif 'model' in checkpoint: 78 | state_dict = checkpoint['model'] 79 | else: 80 | state_dict = checkpoint 81 | weight = convert_swin(state_dict) 82 | # mmengine.mkdir_or_exist(osp.dirname(args.dst)) 83 | # torch.save(weight, args.dst) 84 | torch.save(weight, f'mmseg_{args.src}') 85 | 86 | if __name__ == '__main__': 87 | main() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/MTMamba/17d132d2a7a1ef80c35ffe40c8dd3ff72fa07d77/utils/__init__.py -------------------------------------------------------------------------------- /utils/common_config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.utils.data import DataLoader 4 | from utils.custom_collate import collate_mil 5 | import pdb 6 | 7 | 8 | def get_backbone(p): 9 | """ Return the backbone """ 10 | 11 | if p['backbone'] == 'swin_large': 12 | from mmseg.models.backbones.swin import SwinTransformer 13 | backbone_channels = None 14 | ppath = './pretrained_ckpts/' 15 | backbone = SwinTransformer(patch_size=4, window_size=12, embed_dims=192, 16 | depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), pretrain_img_size=p.TRAIN.SCALE, 17 | pretrained=f'{ppath}/mmseg_swin_large_patch4_window12_384_22kto1k.pth') 18 | backbone.init_weights() 19 | else: 20 | raise NotImplementedError 21 | 22 | return backbone, backbone_channels 23 | 24 | 25 | def get_model(p): 26 | """ Return the model """ 27 | 28 | if p['model'] == 'MTMamba': 29 | backbone, backbone_channels = get_backbone(p) 30 | from models.MTMamba import MTMamba 31 | model = MTMamba(p, backbone) 32 | elif p['model'] == 'MTMamba_plus': 33 | backbone, backbone_channels = get_backbone(p) 34 | from models.MTMamba_plus import MTMamba_plus 35 | model = MTMamba_plus(p, backbone) 36 | else: 37 | raise NotImplementedError('Unknown model {}'.format(p['model'])) 38 | return model 39 | 40 | 41 | """ 42 | Transformations, datasets and dataloaders 43 | """ 44 | def get_transformations(p): 45 | """ Return transformations for training and evaluationg """ 46 | from data import transforms 47 | import torchvision 48 | 49 | # Training transformations 50 | if p['train_db_name'] == 'NYUD' or p['train_db_name'] == 'PASCALContext': 51 | train_transforms = torchvision.transforms.Compose([ # from ATRC 52 | transforms.RandomScaling(scale_factors=[0.5, 2.0], discrete=False), 53 | transforms.RandomCrop(size=p.TRAIN.SCALE, cat_max_ratio=0.75), 54 | 
transforms.RandomHorizontalFlip(p=0.5), 55 | transforms.PhotoMetricDistortion(), 56 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 57 | transforms.PadImage(size=p.TRAIN.SCALE), 58 | transforms.AddIgnoreRegions(), 59 | transforms.ToTensor(), 60 | ]) 61 | 62 | # Testing 63 | valid_transforms = torchvision.transforms.Compose([ 64 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 65 | transforms.PadImage(size=p.TEST.SCALE), 66 | transforms.AddIgnoreRegions(), 67 | transforms.ToTensor(), 68 | ]) 69 | return train_transforms, valid_transforms 70 | 71 | else: 72 | return None, None 73 | 74 | 75 | def get_train_dataset(p, transforms=None): 76 | """ Return the train dataset """ 77 | 78 | db_name = p['train_db_name'] 79 | print('Preparing train dataset for db: {}'.format(db_name)) 80 | 81 | if db_name == 'PASCALContext': 82 | from data.pascal_context import PASCALContext 83 | database = PASCALContext(p.db_paths['PASCALContext'], download=False, split=['train'], transform=transforms, retname=True, 84 | do_semseg='semseg' in p.TASKS.NAMES, 85 | do_edge='edge' in p.TASKS.NAMES, 86 | do_normals='normals' in p.TASKS.NAMES, 87 | do_sal='sal' in p.TASKS.NAMES, 88 | do_human_parts='human_parts' in p.TASKS.NAMES, 89 | overfit=False) 90 | 91 | if db_name == 'NYUD': 92 | from data.nyud import NYUD_MT 93 | database = NYUD_MT(p.db_paths['NYUD_MT'], download=False, split='train', transform=transforms, do_edge='edge' in p.TASKS.NAMES, 94 | do_semseg='semseg' in p.TASKS.NAMES, 95 | do_normals='normals' in p.TASKS.NAMES, 96 | do_depth='depth' in p.TASKS.NAMES, overfit=False) 97 | 98 | if db_name == 'Cityscapes': 99 | from data.cityscapes import CITYSCAPES 100 | database = CITYSCAPES(p, p.db_paths['Cityscapes'], split=["train"], is_transform=True, 101 | img_size=p.TRAIN.SCALE, augmentations=None, 102 | task_list=p.TASKS.NAMES) 103 | 104 | return database 105 | 106 | 107 | def get_train_dataloader(p, dataset, sampler): 108 | """ Return the train dataloader """ 109 | collate = collate_mil 110 | trainloader = DataLoader(dataset, batch_size=p['trBatch'], drop_last=True, 111 | num_workers=p['nworkers'], collate_fn=collate, pin_memory=True, sampler=sampler) 112 | return trainloader 113 | 114 | 115 | def get_test_dataset(p, transforms=None): 116 | """ Return the test dataset """ 117 | 118 | db_name = p['val_db_name'] 119 | print('Preparing test dataset for db: {}'.format(db_name)) 120 | 121 | if db_name == 'PASCALContext': 122 | from data.pascal_context import PASCALContext 123 | database = PASCALContext(p.db_paths['PASCALContext'], download=False, split=['val'], transform=transforms, retname=True, 124 | do_semseg='semseg' in p.TASKS.NAMES, 125 | do_edge='edge' in p.TASKS.NAMES, 126 | do_normals='normals' in p.TASKS.NAMES, 127 | do_sal='sal' in p.TASKS.NAMES, 128 | do_human_parts='human_parts' in p.TASKS.NAMES, 129 | overfit=False) 130 | 131 | elif db_name == 'NYUD': 132 | from data.nyud import NYUD_MT 133 | database = NYUD_MT(p.db_paths['NYUD_MT'], download=False, split='val', transform=transforms, do_edge='edge' in p.TASKS.NAMES, 134 | do_semseg='semseg' in p.TASKS.NAMES, 135 | do_normals='normals' in p.TASKS.NAMES, 136 | do_depth='depth' in p.TASKS.NAMES) 137 | 138 | elif db_name == 'Cityscapes': 139 | from data.cityscapes import CITYSCAPES 140 | database = CITYSCAPES(p, p.db_paths['Cityscapes'], split=["val"], is_transform=True, 141 | img_size=p.TEST.SCALE, augmentations=None, 142 | task_list=p.TASKS.NAMES) 143 | 144 | else: 145 | raise 
NotImplemented("test_db_name: Choose among PASCALContext and NYUD") 146 | 147 | return database 148 | 149 | 150 | def get_test_dataloader(p, dataset): 151 | """ Return the validation dataloader """ 152 | collate = collate_mil 153 | testloader = DataLoader(dataset, batch_size=p['valBatch'], shuffle=False, drop_last=False, 154 | num_workers=p['nworkers'], pin_memory=True, collate_fn=collate) 155 | return testloader 156 | 157 | 158 | """ 159 | Loss functions 160 | """ 161 | def get_loss(p, task=None): 162 | """ Return loss function for a specific task """ 163 | 164 | if task == 'edge': 165 | from losses.loss_functions import BalancedBinaryCrossEntropyLoss 166 | criterion = BalancedBinaryCrossEntropyLoss(pos_weight=p['edge_w'], ignore_index=p.ignore_index) 167 | 168 | elif task == 'semseg' or task == 'human_parts': 169 | from losses.loss_functions import CrossEntropyLoss 170 | criterion = CrossEntropyLoss(ignore_index=p.ignore_index) 171 | 172 | elif task == 'normals': 173 | from losses.loss_functions import L1Loss 174 | criterion = L1Loss(normalize=True, ignore_index=p.ignore_index) 175 | 176 | elif task == 'sal': 177 | from losses.loss_functions import CrossEntropyLoss 178 | criterion = CrossEntropyLoss(balanced=True, ignore_index=p.ignore_index) 179 | 180 | elif task == 'depth': 181 | from losses.loss_functions import L1Loss 182 | criterion = L1Loss() 183 | 184 | else: 185 | criterion = None 186 | 187 | return criterion 188 | 189 | 190 | def get_criterion(p): 191 | if p['loss_kwargs']['loss_scheme'] == 'log': 192 | from losses.loss_schemes import MultiTaskLoss_log 193 | loss_ft = torch.nn.ModuleDict({task: get_loss(p, task) for task in p.TASKS.NAMES}) 194 | loss_weights = p['loss_kwargs']['loss_weights'] 195 | return MultiTaskLoss_log(p, p.TASKS.NAMES, loss_ft, loss_weights) 196 | else: 197 | from losses.loss_schemes import MultiTaskLoss 198 | loss_ft = torch.nn.ModuleDict({task: get_loss(p, task) for task in p.TASKS.NAMES}) 199 | loss_weights = p['loss_kwargs']['loss_weights'] 200 | return MultiTaskLoss(p, p.TASKS.NAMES, loss_ft, loss_weights) 201 | 202 | 203 | """ 204 | Optimizers and schedulers 205 | """ 206 | def get_optimizer(p, model): 207 | """ Return optimizer for a given model and setup """ 208 | 209 | print('Optimizer uses a single parameter group - (Default)') 210 | params = model.parameters() 211 | 212 | backbone_params, others_params = [], [] 213 | for pn, pp in model.named_parameters(): 214 | if 'backbone' in pn: 215 | backbone_params.append(pp) 216 | else: 217 | others_params.append(pp) 218 | 219 | if p['optimizer'] == 'sgd': 220 | optimizer = torch.optim.SGD(params, **p['optimizer_kwargs']) 221 | 222 | elif p['optimizer'] == 'adam': 223 | optimizer = torch.optim.Adam(params, **p['optimizer_kwargs']) 224 | 225 | elif p['optimizer'] == 'adamw': 226 | optimizer = torch.optim.AdamW( 227 | params, 228 | lr=p['optimizer_kwargs']['lr'], 229 | betas=(0.9, 0.999), eps=1e-08, 230 | weight_decay=p['optimizer_kwargs']['weight_decay'], amsgrad=False) 231 | 232 | else: 233 | raise ValueError('Invalid optimizer {}'.format(p['optimizer'])) 234 | 235 | # get scheduler 236 | if p.scheduler == 'poly': 237 | from utils.train_utils import PolynomialLR 238 | scheduler = PolynomialLR(optimizer, p.max_iter, gamma=0.9, min_lr=0) 239 | else: 240 | scheduler = None 241 | 242 | return scheduler, optimizer 243 | 244 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | # 
This code is referenced from 2 | # https://github.com/facebookresearch/astmt/ 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # All rights reserved. 6 | # 7 | # License: Attribution-NonCommercial 4.0 International 8 | 9 | import os 10 | import cv2 11 | import yaml 12 | from easydict import EasyDict as edict 13 | from utils.utils import mkdir_if_missing 14 | import pdb 15 | 16 | 17 | def parse_task_dictionary(db_name, task_dictionary): 18 | """ 19 | Return a dictionary with task information. 20 | Additionally we return a dict with key, values to be added to the main dictionary 21 | """ 22 | 23 | task_cfg = edict() 24 | other_args = dict() 25 | task_cfg.NAMES = [] 26 | task_cfg.NUM_OUTPUT = {} 27 | task_cfg.FLAGVALS = {'image': cv2.INTER_CUBIC} 28 | task_cfg.INFER_FLAGVALS = {} 29 | 30 | if 'include_semseg' in task_dictionary.keys() and task_dictionary['include_semseg']: 31 | tmp = 'semseg' 32 | task_cfg.NAMES.append('semseg') 33 | if db_name == 'PASCALContext': 34 | task_cfg.NUM_OUTPUT[tmp] = 21 35 | elif db_name == 'NYUD': 36 | task_cfg.NUM_OUTPUT[tmp] = 40 37 | elif db_name == 'Cityscapes': 38 | task_cfg.NUM_OUTPUT[tmp] = 19 39 | else: 40 | raise NotImplementedError 41 | task_cfg.FLAGVALS[tmp] = cv2.INTER_NEAREST 42 | task_cfg.INFER_FLAGVALS[tmp] = cv2.INTER_NEAREST 43 | 44 | if 'include_depth' in task_dictionary.keys() and task_dictionary['include_depth']: 45 | tmp = 'depth' 46 | task_cfg.NAMES.append(tmp) 47 | task_cfg.NUM_OUTPUT[tmp] = 1 48 | task_cfg.FLAGVALS[tmp] = cv2.INTER_NEAREST 49 | task_cfg.INFER_FLAGVALS[tmp] = cv2.INTER_LINEAR 50 | 51 | if db_name == 'Cityscapes': 52 | task_cfg.depth_max = 80.0 53 | task_cfg.depth_min = 0. 54 | else: 55 | task_cfg.depth_max = None 56 | task_cfg.depth_min = None 57 | 58 | if 'include_human_parts' in task_dictionary.keys() and task_dictionary['include_human_parts']: 59 | # Human Parts Segmentation 60 | assert(db_name == 'PASCALContext') 61 | tmp = 'human_parts' 62 | task_cfg.NAMES.append(tmp) 63 | task_cfg.NUM_OUTPUT[tmp] = 7 64 | task_cfg.FLAGVALS[tmp] = cv2.INTER_NEAREST 65 | task_cfg.INFER_FLAGVALS[tmp] = cv2.INTER_NEAREST 66 | 67 | if 'include_sal' in task_dictionary.keys() and task_dictionary['include_sal']: 68 | # Saliency Estimation 69 | assert(db_name == 'PASCALContext') 70 | tmp = 'sal' 71 | task_cfg.NAMES.append(tmp) 72 | task_cfg.NUM_OUTPUT[tmp] = 2 73 | task_cfg.FLAGVALS[tmp] = cv2.INTER_NEAREST 74 | task_cfg.INFER_FLAGVALS[tmp] = cv2.INTER_LINEAR 75 | 76 | if 'include_normals' in task_dictionary.keys() and task_dictionary['include_normals']: 77 | # Surface Normals 78 | tmp = 'normals' 79 | assert(db_name in ['PASCALContext', 'NYUD']) 80 | task_cfg.NAMES.append(tmp) 81 | task_cfg.NUM_OUTPUT[tmp] = 3 82 | task_cfg.FLAGVALS[tmp] = cv2.INTER_CUBIC 83 | task_cfg.INFER_FLAGVALS[tmp] = cv2.INTER_LINEAR 84 | task_cfg.INFER_FLAGVALS['normals'] = cv2.INTER_LINEAR 85 | 86 | if 'include_edge' in task_dictionary.keys() and task_dictionary['include_edge']: 87 | # Edge Detection 88 | assert(db_name in ['PASCALContext', 'NYUD']) 89 | tmp = 'edge' 90 | task_cfg.NAMES.append(tmp) 91 | task_cfg.NUM_OUTPUT[tmp] = 1 92 | task_cfg.FLAGVALS[tmp] = cv2.INTER_NEAREST 93 | task_cfg.INFER_FLAGVALS[tmp] = cv2.INTER_LINEAR 94 | other_args['edge_w'] = task_dictionary['edge_w'] 95 | other_args['eval_edge'] = False 96 | task_cfg.INFER_FLAGVALS['edge'] = cv2.INTER_LINEAR 97 | 98 | return task_cfg, other_args 99 | 100 | 101 | def create_config(exp_file, params): 102 | 103 | with open(exp_file, 'r') as stream: 104 | config = yaml.safe_load(stream) 
105 | 106 | # Copy all the arguments 107 | cfg = edict() 108 | for k, v in config.items(): 109 | cfg[k] = v 110 | 111 | # set root dir 112 | root_dir = cfg["out_dir"] + cfg['version_name'] 113 | 114 | # Parse the task dictionary separately 115 | cfg.TASKS, extra_args = parse_task_dictionary(cfg['train_db_name'], cfg['task_dictionary']) 116 | 117 | for k, v in extra_args.items(): 118 | cfg[k] = v 119 | 120 | # Other arguments 121 | if cfg['train_db_name'] == 'PASCALContext': 122 | cfg.TRAIN = edict() 123 | cfg.TRAIN.SCALE = (512, 512) 124 | cfg.TEST = edict() 125 | cfg.TEST.SCALE = (512, 512) 126 | cfg.IMAGE_ORI_SIZE = (512, 512) 127 | 128 | elif cfg['train_db_name'] == 'NYUD': 129 | cfg.TRAIN = edict() 130 | cfg.TEST = edict() 131 | cfg.IMAGE_ORI_SIZE = (448, 576) 132 | cfg.TRAIN.SCALE = (448, 576) 133 | cfg.TEST.SCALE = (448, 576) 134 | elif cfg['train_db_name'] == 'Cityscapes': 135 | cfg.IMAGE_ORI_SIZE = (1024, 2048) 136 | cfg.TRAIN = edict() 137 | cfg.TRAIN.SCALE = (512, 1024) 138 | cfg.TEST = edict() 139 | cfg.TEST.SCALE = (512, 1024) # original size 140 | 141 | else: 142 | raise NotImplementedError 143 | 144 | # set log dir 145 | output_dir = root_dir 146 | cfg['root_dir'] = root_dir 147 | cfg['output_dir'] = output_dir 148 | cfg['save_dir'] = os.path.join(output_dir, 'results') 149 | cfg['checkpoint'] = os.path.join(output_dir, 'checkpoint.pth.tar') 150 | if params['run_mode'] != 'infer': 151 | mkdir_if_missing(cfg['output_dir']) 152 | mkdir_if_missing(cfg['save_dir']) 153 | 154 | from configs.mypath import db_paths, PROJECT_ROOT_DIR 155 | params['db_paths'] = db_paths 156 | params['PROJECT_ROOT_DIR'] = PROJECT_ROOT_DIR 157 | 158 | cfg.update(params) 159 | 160 | return cfg 161 | -------------------------------------------------------------------------------- /utils/custom_collate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import collections 3 | import re 4 | int_classes = int 5 | string_classes = str 6 | 7 | _use_shared_memory = False 8 | r"""Whether to use shared memory in default_collate""" 9 | 10 | 11 | numpy_type_map = { 12 | 'float64': torch.DoubleTensor, 13 | 'float32': torch.FloatTensor, 14 | 'float16': torch.HalfTensor, 15 | 'int64': torch.LongTensor, 16 | 'int32': torch.IntTensor, 17 | 'int16': torch.ShortTensor, 18 | 'int8': torch.CharTensor, 19 | 'uint8': torch.ByteTensor, 20 | } 21 | 22 | 23 | def collate_mil(batch): 24 | """ 25 | Puts each data field into a tensor with outer dimension batch size. 
26 | Custom-made for supporting MIL 27 | """ 28 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 29 | if len(batch) == 0: # probably when there is no valid bbox in this sample 30 | return batch 31 | elem_type = type(batch[0]) 32 | if isinstance(batch[0], torch.Tensor): 33 | out = None 34 | if _use_shared_memory: 35 | # If we're in a background process, concatenate directly into a 36 | # shared memory tensor to avoid an extra copy 37 | numel = sum([x.numel() for x in batch]) 38 | storage = batch[0].storage()._new_shared(numel) 39 | out = batch[0].new(storage) 40 | return torch.stack(batch, 0, out=out) 41 | 42 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 43 | and elem_type.__name__ != 'string_': 44 | elem = batch[0] 45 | if elem_type.__name__ == 'ndarray': 46 | # array of string classes and object 47 | if re.search('[SaUO]', elem.dtype.str) is not None: 48 | raise TypeError(error_msg.format(elem.dtype)) 49 | 50 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 51 | if elem.shape == (): # scalars 52 | py_type = float if elem.dtype.name.startswith('float') else int 53 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 54 | 55 | elif isinstance(batch[0], int_classes): 56 | return torch.LongTensor(batch) 57 | 58 | elif isinstance(batch[0], float): 59 | return torch.DoubleTensor(batch) 60 | 61 | elif isinstance(batch[0], string_classes): 62 | return batch 63 | 64 | elif isinstance(batch[0], collections.Mapping): 65 | batch_modified = {key: collate_mil([d[key] for d in batch]) for key in batch[0] if key.find('idx') < 0} 66 | if 'edgeidx' in batch[0]: 67 | batch_modified['edgeidx'] = [batch[x]['edgeidx'] for x in range(len(batch))] 68 | return batch_modified 69 | 70 | elif isinstance(batch[0], collections.Sequence): 71 | # transposed = zip(*batch) 72 | # return [collate_mil(samples) for samples in transposed] 73 | 74 | # yhr: change this for tolerating lists with different lengths from different samples. 
75 | out = [] 76 | for samples in batch: 77 | out.append(collate_mil(samples)) 78 | return out 79 | 80 | raise TypeError((error_msg.format(type(batch[0])))) 81 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # 2 | # Authors: Simon Vandenhende 3 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | 5 | import os 6 | import sys 7 | 8 | 9 | class Logger(object): 10 | def __init__(self, fpath=None): 11 | self.console = sys.stdout 12 | self.file = None 13 | self.fpath = fpath 14 | if fpath is not None: 15 | if not os.path.exists(os.path.dirname(fpath)): 16 | os.makedirs(os.path.dirname(fpath)) 17 | self.file = open(fpath, 'w') 18 | else: 19 | self.file = open(fpath, 'a') 20 | 21 | def __del__(self): 22 | self.close() 23 | 24 | def __enter__(self): 25 | pass 26 | 27 | def __exit__(self, *args): 28 | self.close() 29 | 30 | def write(self, msg): 31 | self.console.write(msg) 32 | if self.file is not None: 33 | self.file.write(msg) 34 | 35 | def flush(self): 36 | self.console.flush() 37 | if self.file is not None: 38 | self.file.flush() 39 | os.fsync(self.file.fileno()) 40 | 41 | def close(self): 42 | self.console.close() 43 | if self.file is not None: 44 | self.file.close() 45 | -------------------------------------------------------------------------------- /utils/test_utils.py: -------------------------------------------------------------------------------- 1 | # By Hanrong Ye 2 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 3 | 4 | from evaluation.evaluate_utils import PerformanceMeter 5 | from tqdm import tqdm 6 | from utils.utils import get_output, mkdir_if_missing 7 | from evaluation.evaluate_utils import save_model_pred_for_one_task 8 | import torch 9 | import os 10 | 11 | @torch.no_grad() 12 | def test_phase(p, test_loader, model, criterion, epoch, save_edge=False): 13 | tasks = p.TASKS.NAMES 14 | 15 | performance_meter = PerformanceMeter(p, tasks) 16 | 17 | model.eval() 18 | 19 | if save_edge: 20 | tasks_to_save = ['edge'] 21 | save_dirs = {task: os.path.join(p['save_dir'], task) for task in tasks_to_save} 22 | for save_dir in save_dirs.values(): 23 | mkdir_if_missing(save_dir) 24 | 25 | for i, batch in enumerate(tqdm(test_loader)): 26 | # Forward pass 27 | with torch.no_grad(): 28 | images = batch['image'].cuda(non_blocking=True) 29 | targets = {task: batch[task].cuda(non_blocking=True) for task in tasks} 30 | 31 | output = model.module(images) 32 | 33 | # Measure loss and performance 34 | performance_meter.update({t: get_output(output[t], t) for t in tasks}, 35 | {t: targets[t] for t in tasks}) 36 | 37 | if save_edge: 38 | for task in tasks_to_save: 39 | try: 40 | save_model_pred_for_one_task(p, batch, output, save_dirs, task, epoch=epoch) 41 | except: 42 | pass 43 | 44 | 45 | eval_results = performance_meter.get_score(verbose = True) 46 | 47 | return eval_results -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | # Rewritten based on MTI-Net by Hanrong Ye 2 | # Original authors: Simon Vandenhende 3 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | 5 | 6 | import os, json 7 | from evaluation.evaluate_utils import PerformanceMeter, count_improvement 8 | from utils.utils import 
to_cuda 9 | import torch 10 | from tqdm import tqdm 11 | from utils.test_utils import test_phase 12 | import pdb 13 | 14 | # import torch.profiler 15 | 16 | def update_tb(tb_writer, tag, loss_dict, iter_no): 17 | for k, v in loss_dict.items(): 18 | tb_writer.add_scalar(f'{tag}/{k}', v.item(), iter_no) 19 | 20 | 21 | def train_phase(p, args, train_loader, test_dataloader, model, criterion, optimizer, 22 | scheduler, epoch, tb_writer, tb_writer_test, iter_count, best_imp): 23 | """ Vanilla training with fixed loss weights """ 24 | model.train() 25 | 26 | for i, cpu_batch in enumerate(tqdm(train_loader)): 27 | # Forward pass 28 | batch = to_cuda(cpu_batch) 29 | images = batch['image'] 30 | 31 | output = model(images) 32 | iter_count += 1 33 | # Measure loss 34 | loss_dict = criterion(output, batch, tasks=p.TASKS.NAMES) 35 | # get learning rate 36 | if scheduler is not None: 37 | lr = scheduler.get_lr()[0] 38 | else: 39 | lr = optimizer.param_groups[0]['lr'] 40 | loss_dict['lr'] = torch.tensor(lr) 41 | 42 | if tb_writer is not None: 43 | update_tb(tb_writer, 'Train_Loss', loss_dict, iter_count) 44 | 45 | if args.local_rank == 0: 46 | print(f'Iter {iter_count}, ', end="") 47 | for k, v in loss_dict.items(): 48 | print('{}: {:.7f} | '.format(k, v), end="") 49 | print() 50 | 51 | # Backward 52 | optimizer.zero_grad() 53 | loss_dict['total'].backward() 54 | try: 55 | torch.nn.utils.clip_grad_norm_(model.parameters(), **p.grad_clip_param) 56 | except: 57 | pass 58 | optimizer.step() 59 | if scheduler is not None: 60 | scheduler.step() 61 | 62 | # end condition 63 | if iter_count >= p.max_iter: 64 | print('Max itereaction achieved.') 65 | if args.local_rank == 0: 66 | curr_result = test_phase(p, test_dataloader, model, criterion, epoch) 67 | torch.save({'model': model.state_dict()}, p['checkpoint'].replace('checkpoint.pth.tar', 'last_model.pth.tar')) 68 | with open(p['checkpoint'].replace('checkpoint.pth.tar', 'last.txt'), 'w') as f: 69 | json.dump(curr_result, f, indent=4) 70 | end_signal = True 71 | return True, iter_count, best_imp 72 | else: 73 | end_signal = False 74 | 75 | # Perform evaluation 76 | begin_eva = 1 77 | if args.local_rank == 0 and epoch >= begin_eva: 78 | print('Evaluate at epoch {}'.format(epoch)) 79 | curr_result = test_phase(p, test_dataloader, model, criterion, epoch) 80 | # tb_update_perf(p, tb_writer_test, curr_result, iter_count) 81 | print('Evaluate results at epoch {}: \n'.format(epoch)) 82 | print(curr_result) 83 | with open(os.path.join(p['save_dir'], p.version_name + '_' + str(epoch) + '.txt'), 'w') as f: 84 | json.dump(curr_result, f, indent=4) 85 | 86 | current_imp = count_improvement(p['train_db_name'], curr_result, p['TASKS']['NAMES']) 87 | if current_imp > best_imp: 88 | best_imp = current_imp 89 | # Checkpoint after evaluation 90 | print('Checkpoint starts at epoch {}....'.format(epoch)) 91 | torch.save({'model': model.state_dict()}, p['checkpoint']) 92 | print('Checkpoint finishs.') 93 | 94 | curr_result.update({'epoch': epoch, 'best_imp': best_imp}) 95 | with open(p['checkpoint'].replace('checkpoint.pth.tar', 'best.txt'), 'w') as f: 96 | json.dump(curr_result, f, indent=4) 97 | model.train() # set model back to train status 98 | 99 | # if end_signal: 100 | # return True, iter_count 101 | 102 | return False, iter_count, best_imp 103 | 104 | 105 | class PolynomialLR(torch.optim.lr_scheduler._LRScheduler): 106 | def __init__(self, optimizer, max_iterations, gamma=0.9, min_lr=0., last_epoch=-1): 107 | self.max_iterations = max_iterations 108 | self.gamma = gamma 
109 | self.min_lr = min_lr 110 | super().__init__(optimizer, last_epoch) 111 | 112 | def get_lr(self): 113 | # slight abuse: last_epoch refers to last iteration 114 | factor = (1 - self.last_epoch / 115 | float(self.max_iterations)) ** self.gamma 116 | return [(base_lr - self.min_lr) * factor + self.min_lr for base_lr in self.base_lrs] 117 | 118 | def tb_update_perf(p, tb_writer_test, curr_result, cur_iter): 119 | if 'semseg' in p.TASKS.NAMES: 120 | tb_writer_test.add_scalar('perf/semseg_miou', curr_result['semseg']['mIoU'], cur_iter) 121 | if 'human_parts' in p.TASKS.NAMES: 122 | tb_writer_test.add_scalar('perf/human_parts_mIoU', curr_result['human_parts']['mIoU'], cur_iter) 123 | if 'sal' in p.TASKS.NAMES: 124 | tb_writer_test.add_scalar('perf/sal_maxF', curr_result['sal']['maxF'], cur_iter) 125 | if 'edge' in p.TASKS.NAMES: 126 | tb_writer_test.add_scalar('perf/edge_val_loss', curr_result['edge']['loss'], cur_iter) 127 | if 'normals' in p.TASKS.NAMES: 128 | tb_writer_test.add_scalar('perf/normals_mean', curr_result['normals']['mean'], cur_iter) 129 | if 'depth' in p.TASKS.NAMES: 130 | tb_writer_test.add_scalar('perf/depth_rmse', curr_result['depth']['rmse'], cur_iter) 131 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import pdb 5 | 6 | def mkdir_if_missing(directory): 7 | if not os.path.exists(directory): 8 | try: 9 | os.makedirs(directory) 10 | except OSError as e: 11 | pass 12 | 13 | 14 | def get_output(output, task): 15 | """Borrow from MTI-Net""" 16 | 17 | if task == 'normals': 18 | output = output.permute(0, 2, 3, 1) 19 | output = (F.normalize(output, p = 2, dim = 3) + 1.0) * 255 / 2.0 20 | 21 | elif task in {'semseg'}: 22 | output = output.permute(0, 2, 3, 1) 23 | _, output = torch.max(output, dim=3) 24 | 25 | elif task in {'human_parts'}: 26 | output = output.permute(0, 2, 3, 1) 27 | _, output = torch.max(output, dim=3) 28 | 29 | elif task in {'edge'}: 30 | output = output.permute(0, 2, 3, 1) 31 | output = torch.squeeze(255 * 1 / (1 + torch.exp(-output)), dim=3) 32 | 33 | elif task in {'sal'}: 34 | output = output.permute(0, 2, 3, 1) 35 | output = F.softmax(output, dim=3)[:, :, :, 1] *255 # torch.squeeze(255 * 1 / (1 + torch.exp(-output))) 36 | 37 | elif task in {'depth'}: 38 | output.clamp_(min=0.) 39 | output = output.permute(0, 2, 3, 1) 40 | 41 | else: 42 | raise ValueError('Select one of the valid tasks') 43 | 44 | return output 45 | 46 | def to_cuda(batch): 47 | if type(batch) == dict: 48 | out = {} 49 | for k, v in batch.items(): 50 | if k == 'meta': 51 | out[k] = v 52 | else: 53 | out[k] = to_cuda(v) 54 | return out 55 | elif type(batch) == torch.Tensor: 56 | return batch.cuda(non_blocking=True) 57 | elif type(batch) == list: 58 | return [to_cuda(v) for v in batch] 59 | else: 60 | return batch 61 | 62 | # From PyTorch internals 63 | import collections.abc as container_abcs 64 | from itertools import repeat 65 | def _ntuple(n): 66 | def parse(x): 67 | if isinstance(x, container_abcs.Iterable): 68 | return x 69 | return tuple(repeat(x, n)) 70 | 71 | return parse 72 | 73 | to_2tuple = _ntuple(2) --------------------------------------------------------------------------------
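For reference, `parse_task_dictionary` in `utils/config.py` expands the `task_dictionary` block of a yml config into the `TASKS` namespace used throughout the code. Below is a hypothetical call mirroring the NYUD branches shown above; the `edge_w` value is made up for illustration and is normally read from the shipped config file.

```python
# Hypothetical call to parse_task_dictionary mirroring the NYUD branches above;
# the edge_w value is made up here and is normally taken from the yml config.
from utils.config import parse_task_dictionary

task_cfg, extra = parse_task_dictionary('NYUD', {
    'include_semseg': True,
    'include_depth': True,
    'include_normals': True,
    'include_edge': True,
    'edge_w': 0.95,
})

print(task_cfg.NAMES)       # ['semseg', 'depth', 'normals', 'edge']
print(task_cfg.NUM_OUTPUT)  # {'semseg': 40, 'depth': 1, 'normals': 3, 'edge': 1}
print(extra)                # {'edge_w': 0.95, 'eval_edge': False}
```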
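The launch script (`main.py`, not included in this part of the listing) ties the helpers from `utils/config.py`, `utils/common_config.py`, and `utils/train_utils.py` together. The sketch below is a hypothetical, single-process condensation of one training step; it deliberately omits the distributed pieces (`DistributedDataParallel`, a distributed sampler, `args.local_rank`) that `train_phase` and `test_phase` assume, so treat it as illustrative rather than the actual entry point.

```python
# Hypothetical single-process condensation of one training step; the real
# pipeline wraps the model in DistributedDataParallel and feeds a distributed
# sampler into get_train_dataloader, which this sketch skips.
import torch
from utils.config import create_config
from utils.common_config import (get_model, get_transformations, get_train_dataset,
                                 get_train_dataloader, get_criterion, get_optimizer)
from utils.utils import to_cuda

p = create_config('./configs/mtmamba_nyud.yml', {'run_mode': 'train'})

model = get_model(p).cuda()
criterion = get_criterion(p).cuda()
scheduler, optimizer = get_optimizer(p, model)    # note the (scheduler, optimizer) return order

train_tf, _ = get_transformations(p)
train_dl = get_train_dataloader(p, get_train_dataset(p, train_tf), sampler=None)

model.train()
batch = to_cuda(next(iter(train_dl)))             # moves every tensor in the batch dict to GPU
output = model(batch['image'])                    # dict of per-task predictions
loss_dict = criterion(output, batch, tasks=p.TASKS.NAMES)

optimizer.zero_grad()
loss_dict['total'].backward()
# train_phase additionally clips gradients with p.grad_clip_param when the config provides it
optimizer.step()
if scheduler is not None:                         # PolynomialLR when p.scheduler == 'poly'
    scheduler.step()
```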
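Evaluation mirrors `test_phase` in `utils/test_utils.py`. The hypothetical snippet below reuses `p` and `model` from the sketch above and runs a single batch through a bare (non-DDP) model, using `get_output` and `PerformanceMeter` in the same way `test_phase` does.

```python
# Hypothetical single-batch evaluation, reusing `p` and `model` from the sketch
# above; test_phase() expects a DDP-wrapped model (hence model.module there),
# while this sketch calls the bare model directly.
import torch
from utils.common_config import get_transformations, get_test_dataset, get_test_dataloader
from utils.utils import get_output
from evaluation.evaluate_utils import PerformanceMeter

_, valid_tf = get_transformations(p)
test_dl = get_test_dataloader(p, get_test_dataset(p, valid_tf))
meter = PerformanceMeter(p, p.TASKS.NAMES)

model.eval()
with torch.no_grad():
    batch = next(iter(test_dl))
    images = batch['image'].cuda(non_blocking=True)
    targets = {t: batch[t].cuda(non_blocking=True) for t in p.TASKS.NAMES}
    output = model(images)
    # get_output maps raw outputs to evaluable predictions: argmax for
    # semseg/human_parts, scaled sigmoid for edge, clamped values for depth, etc.
    meter.update({t: get_output(output[t], t) for t in p.TASKS.NAMES}, targets)

print(meter.get_score(verbose=True))
```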