├── .gitignore ├── LICENSE ├── README.MD ├── additional_utils ├── encoding_models.py └── models.py ├── data └── __init__.py ├── fewshot_data ├── README.md ├── common │ ├── evaluation.py │ ├── logger.py │ ├── utils.py │ └── vis.py ├── data │ ├── assets │ │ ├── architecture.png │ │ └── qualitative_results.png │ ├── coco.py │ ├── dataset.py │ ├── fss.py │ ├── pascal.py │ └── splits │ │ ├── coco │ │ ├── trn │ │ │ ├── fold0.pkl │ │ │ ├── fold1.pkl │ │ │ ├── fold2.pkl │ │ │ └── fold3.pkl │ │ └── val │ │ │ ├── fold0.pkl │ │ │ ├── fold1.pkl │ │ │ ├── fold2.pkl │ │ │ └── fold3.pkl │ │ ├── fss │ │ ├── test.txt │ │ ├── trn.txt │ │ └── val.txt │ │ └── pascal │ │ ├── trn │ │ ├── fold0.txt │ │ ├── fold1.txt │ │ ├── fold2.txt │ │ └── fold3.txt │ │ └── val │ │ ├── fold0.txt │ │ ├── fold1.txt │ │ ├── fold2.txt │ │ └── fold3.txt ├── model │ ├── base │ │ ├── conv4d.py │ │ ├── correlation.py │ │ └── feature.py │ ├── hsnet.py │ └── learner.py ├── sbatch_run.sh ├── test.py └── train.py ├── inputs └── cat1.jpeg ├── label_files ├── ade20k_objectInfo150.txt ├── fewshot_coco.txt ├── fewshot_fss.txt └── fewshot_pascal.txt ├── lseg_app.py ├── lseg_demo.ipynb ├── modules ├── lseg_module.py ├── lseg_module_zs.py ├── lsegmentation_module.py ├── lsegmentation_module_zs.py └── models │ ├── lseg_blocks.py │ ├── lseg_blocks_zs.py │ ├── lseg_net.py │ ├── lseg_net_zs.py │ ├── lseg_vit.py │ └── lseg_vit_zs.py ├── prepare_ade20k.py ├── requirements.txt ├── test.sh ├── test_lseg.py ├── test_lseg_zs.py ├── train.sh ├── train_lseg.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | checkpoints/ 131 | logs/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Intelligent Systems Lab Org 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | # PROJECT NOT UNDER ACTIVE MANAGEMENT 2 | This project will no longer be maintained by Intel. 3 | Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project. 4 | Intel no longer accepts patches to this project. 5 | If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project. 6 | 7 | # Language-driven Semantic Segmentation (LSeg) 8 | The repo contains official PyTorch Implementation of paper [Language-driven Semantic Segmentation](https://arxiv.org/abs/2201.03546). 9 | 10 | ICLR 2022 11 | 12 | #### Authors: 13 | * [Boyi Li](https://sites.google.com/site/boyilics/home) 14 | * [Kilian Q. Weinberger](http://kilian.cs.cornell.edu/index.html) 15 | * [Serge Belongie](https://scholar.google.com/citations?user=ORr4XJYAAAAJ&hl=zh-CN) 16 | * [Vladlen Koltun](http://vladlen.info/) 17 | * [Rene Ranftl](https://scholar.google.at/citations?user=cwKg158AAAAJ&hl=de) 18 | 19 | 20 | ### Overview 21 | 22 | 23 | We present LSeg, a novel model for language-driven semantic image segmentation. 
LSeg uses a text encoder to compute embeddings of descriptive input labels (e.g., ''grass'' or 'building'') together with a transformer-based image encoder that computes dense per-pixel embeddings of the input image. The image encoder is trained with a contrastive objective to align pixel embeddings to the text embedding of the corresponding semantic class. The text embeddings provide a flexible label representation in which semantically similar labels map to similar regions in the embedding space (e.g., ''cat'' and ''furry''). This allows LSeg to generalize to previously unseen categories at test time, without retraining or even requiring a single additional training sample. We demonstrate that our approach achieves highly competitive zero-shot performance compared to existing zero- and few-shot semantic segmentation methods, and even matches the accuracy of traditional segmentation algorithms when a fixed label set is provided. 24 | 25 | Please check our [Video Demo (4k)](https://www.youtube.com/watch?v=bmU75rsmv6s) to further showcase the capabilities of LSeg. 26 | 27 | ## Usage 28 | ### Installation 29 | Option 1: 30 | 31 | ``` pip install -r requirements.txt ``` 32 | 33 | Option 2: 34 | ``` 35 | conda install ipython 36 | pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 37 | pip install git+https://github.com/zhanghang1989/PyTorch-Encoding/ 38 | pip install pytorch-lightning==1.3.5 39 | pip install opencv-python 40 | pip install imageio 41 | pip install ftfy regex tqdm 42 | pip install git+https://github.com/openai/CLIP.git 43 | pip install altair 44 | pip install streamlit 45 | pip install --upgrade protobuf 46 | pip install timm 47 | pip install tensorboardX 48 | pip install matplotlib 49 | pip install test-tube 50 | pip install wandb 51 | ``` 52 | 53 | ### Data Preparation 54 | By default, for training, testing and demo, we use [ADE20k](https://groups.csail.mit.edu/vision/datasets/ADE20K/). 55 | 56 | ``` 57 | python prepare_ade20k.py 58 | unzip ../datasets/ADEChallengeData2016.zip 59 | ``` 60 | 61 | Note: for demo, if you want to use random inputs, you can ignore data loading and comment the code at [link](https://github.com/isl-org/lang-seg/blob/main/modules/lseg_module.py#L55). 62 | 63 | 64 | ### 🌻 Try demo now 65 | 66 | #### Download Demo Model 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 |
| name | backbone | text encoder | url |
| --- | --- | --- | --- |
| Model for demo | ViT-L/16 | CLIP ViT-B/32 | download |
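Both demo options below are driven by the `LSegModule` LightningModule in `modules/lseg_module.py`. If you prefer to script against the checkpoint directly, one way is PyTorch Lightning's standard `load_from_checkpoint` classmethod; the hyperparameter keyword arguments in this sketch are placeholders only, so copy the exact ones used for the demo from `lseg_app.py` or `lseg_demo.ipynb`.

```python
# Hypothetical sketch: the hyperparameter keyword arguments are placeholders.
# See lseg_app.py / lseg_demo.ipynb for the exact values used with the demo model.
from modules.lseg_module import LSegModule

module = LSegModule.load_from_checkpoint(
    checkpoint_path="checkpoints/demo_e200.ckpt",
    # backbone=..., dataset=..., ...  # module-specific hyperparameters (placeholders)
)
module.eval()  # inference mode; the notebook shows the full preprocessing and inference loop
```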
85 | 86 | #### 👉 Option 1: Running interactive app 87 | Download the model for demo and put it under folder `checkpoints` as `checkpoints/demo_e200.ckpt`. 88 | 89 | Then ``` streamlit run lseg_app.py ``` 90 | 91 | #### 👉 Option 2: Jupyter Notebook 92 | Download the model for demo and put it under folder `checkpoints` as `checkpoints/demo_e200.ckpt`. 93 | 94 | Then follow [lseg_demo.ipynb](https://github.com/isl-org/lang-seg/blob/main/lseg_demo.ipynb) to play around with LSeg. Enjoy! 95 | 96 | 97 | 98 | ### Training and Testing Example 99 | Training: Backbone = ViT-L/16, Text Encoder from CLIP ViT-B/32 100 | 101 | ``` bash train.sh ``` 102 | 103 | Testing: Backbone = ViT-L/16, Text Encoder from CLIP ViT-B/32 104 | 105 | ``` bash test.sh ``` 106 | 107 | ### Zero-shot Experiments 108 | #### Data Preparation 109 | Please follow [HSNet](https://github.com/juhongm999/hsnet) and put all dataset in `data/Dataset_HSN` 110 | 111 | #### Pascal-5i 112 | ``` 113 | for fold in 0 1 2 3; do 114 | python -u test_lseg_zs.py --backbone clip_resnet101 --module clipseg_DPT_test_v2 --dataset pascal \ 115 | --widehead --no-scaleinv --arch_option 0 --ignore_index 255 --fold ${fold} --nshot 0 \ 116 | --weights checkpoints/pascal_fold${fold}.ckpt 117 | done 118 | ``` 119 | #### COCO-20i 120 | ``` 121 | for fold in 0 1 2 3; do 122 | python -u test_lseg_zs.py --backbone clip_resnet101 --module clipseg_DPT_test_v2 --dataset coco \ 123 | --widehead --no-scaleinv --arch_option 0 --ignore_index 255 --fold ${fold} --nshot 0 \ 124 | --weights checkpoints/pascal_fold${fold}.ckpt 125 | done 126 | ``` 127 | #### FSS 128 | ``` 129 | python -u test_lseg_zs.py --backbone clip_vitl16_384 --module clipseg_DPT_test_v2 --dataset fss \ 130 | --widehead --no-scaleinv --arch_option 0 --ignore_index 255 --fold 0 --nshot 0 \ 131 | --weights checkpoints/fss_l16.ckpt 132 | ``` 133 | 134 | ``` 135 | python -u test_lseg_zs.py --backbone clip_resnet101 --module clipseg_DPT_test_v2 --dataset fss \ 136 | --widehead --no-scaleinv --arch_option 0 --ignore_index 255 --fold 0 --nshot 0 \ 137 | --weights checkpoints/fss_rn101.ckpt 138 | ``` 139 | 140 | #### Model Zoo 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 |
| dataset | fold | backbone | text encoder | performance | url |
| --- | --- | --- | --- | --- | --- |
| pascal | 0 | ResNet101 | CLIP ViT-B/32 | 52.8 | download |
| pascal | 1 | ResNet101 | CLIP ViT-B/32 | 53.8 | download |
| pascal | 2 | ResNet101 | CLIP ViT-B/32 | 44.4 | download |
| pascal | 3 | ResNet101 | CLIP ViT-B/32 | 38.5 | download |
| coco | 0 | ResNet101 | CLIP ViT-B/32 | 22.1 | download |
| coco | 1 | ResNet101 | CLIP ViT-B/32 | 25.1 | download |
| coco | 2 | ResNet101 | CLIP ViT-B/32 | 24.9 | download |
| coco | 3 | ResNet101 | CLIP ViT-B/32 | 21.5 | download |
| fss | - | ResNet101 | CLIP ViT-B/32 | 84.7 | download |
| fss | - | ViT-L/16 | CLIP ViT-B/32 | 87.8 | download |
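All of the zero-shot results above rely on the mechanism described in the overview: the image encoder produces dense per-pixel embeddings, and each pixel is assigned the label whose CLIP text embedding it matches best. The snippet below is a self-contained conceptual sketch of that matching step with random tensors; it is illustrative only and not the repository's implementation (the actual model lives in `modules/models/lseg_net.py`).

```python
import torch
import torch.nn.functional as F

# Conceptual sketch of LSeg-style pixel-text matching (illustrative only).
B, C, H, W, N = 1, 512, 60, 80, 4                          # batch, embed dim, height, width, #labels
pixel_emb = F.normalize(torch.randn(B, C, H, W), dim=1)    # dense per-pixel embeddings
text_emb = F.normalize(torch.randn(N, C), dim=1)           # one CLIP text embedding per label

# cosine similarity between every pixel and every label -> (B, N, H, W) logits
logits = torch.einsum("bchw,nc->bnhw", pixel_emb, text_emb)
pred = logits.argmax(dim=1)                                 # (B, H, W) label index per pixel
```

Changing the label list at test time only changes `text_emb`, which is what allows evaluation on unseen categories without retraining.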
235 | 236 | If you find this repo useful, please cite: 237 | ``` 238 | @inproceedings{ 239 | li2022languagedriven, 240 | title={Language-driven Semantic Segmentation}, 241 | author={Boyi Li and Kilian Q Weinberger and Serge Belongie and Vladlen Koltun and Rene Ranftl}, 242 | booktitle={International Conference on Learning Representations}, 243 | year={2022}, 244 | url={https://openreview.net/forum?id=RriDjddCLN} 245 | } 246 | ``` 247 | 248 | ## Acknowledgement 249 | Thanks to the code base from [DPT](https://github.com/isl-org/DPT), [Pytorch_lightning](https://github.com/PyTorchLightning/pytorch-lightning), [CLIP](https://github.com/openai/CLIP), [Pytorch Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), [Streamlit](https://streamlit.io/), [Wandb](https://wandb.ai/site) 250 | -------------------------------------------------------------------------------- /additional_utils/encoding_models.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Referred to: https://github.com/zhanghang1989/PyTorch-Encoding 3 | ########################################################################### 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn.parallel.data_parallel import DataParallel 11 | from torch.nn.parallel.scatter_gather import scatter 12 | import threading 13 | import torch 14 | from torch.cuda._utils import _get_device_index 15 | from torch.cuda.amp import autocast 16 | from torch._utils import ExceptionWrapper 17 | 18 | up_kwargs = {'mode': 'bilinear', 'align_corners': True} 19 | 20 | __all__ = ['MultiEvalModule'] 21 | 22 | class MultiEvalModule(DataParallel): 23 | """Multi-size Segmentation Eavluator""" 24 | def __init__(self, module, nclass, device_ids=None, flip=True, 25 | scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]): 26 | super(MultiEvalModule, self).__init__(module, device_ids) 27 | self.nclass = nclass 28 | self.base_size = module.base_size 29 | self.crop_size = module.crop_size 30 | self.scales = scales 31 | self.flip = flip 32 | print('MultiEvalModule: base_size {}, crop_size {}'. 
\ 33 | format(self.base_size, self.crop_size)) 34 | 35 | def parallel_forward(self, inputs, **kwargs): 36 | """Multi-GPU Mult-size Evaluation 37 | 38 | Args: 39 | inputs: list of Tensors 40 | """ 41 | inputs = [(input.unsqueeze(0).cuda(device),) 42 | for input, device in zip(inputs, self.device_ids)] 43 | replicas = self.replicate(self, self.device_ids[:len(inputs)]) 44 | kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] 45 | if len(inputs) < len(kwargs): 46 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 47 | elif len(kwargs) < len(inputs): 48 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 49 | outputs = self.parallel_apply(replicas, inputs, kwargs) 50 | #for out in outputs: 51 | # print('out.size()', out.size()) 52 | return outputs 53 | 54 | def forward(self, image): 55 | """Mult-size Evaluation""" 56 | # only single image is supported for evaluation 57 | batch, _, h, w = image.size() 58 | assert(batch == 1) 59 | stride_rate = 2.0/3.0 60 | crop_size = self.crop_size 61 | stride = int(crop_size * stride_rate) 62 | with torch.cuda.device_of(image): 63 | scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda() 64 | 65 | for scale in self.scales: 66 | long_size = int(math.ceil(self.base_size * scale)) 67 | if h > w: 68 | height = long_size 69 | width = int(1.0 * w * long_size / h + 0.5) 70 | short_size = width 71 | else: 72 | width = long_size 73 | height = int(1.0 * h * long_size / w + 0.5) 74 | short_size = height 75 | """ 76 | short_size = int(math.ceil(self.base_size * scale)) 77 | if h > w: 78 | width = short_size 79 | height = int(1.0 * h * short_size / w) 80 | long_size = height 81 | else: 82 | height = short_size 83 | width = int(1.0 * w * short_size / h) 84 | long_size = width 85 | """ 86 | # resize image to current size 87 | cur_img = resize_image(image, height, width, **self.module._up_kwargs) 88 | if long_size <= crop_size: 89 | pad_img = pad_image(cur_img, self.module.mean, 90 | self.module.std, crop_size) 91 | outputs = module_inference(self.module, pad_img, self.flip) 92 | outputs = crop_image(outputs, 0, height, 0, width) 93 | else: 94 | if short_size < crop_size: 95 | # pad if needed 96 | pad_img = pad_image(cur_img, self.module.mean, 97 | self.module.std, crop_size) 98 | else: 99 | pad_img = cur_img 100 | _,_,ph,pw = pad_img.size() 101 | assert(ph >= height and pw >= width) 102 | # grid forward and normalize 103 | h_grids = int(math.ceil(1.0 * (ph-crop_size)/stride)) + 1 104 | w_grids = int(math.ceil(1.0 * (pw-crop_size)/stride)) + 1 105 | with torch.cuda.device_of(image): 106 | outputs = image.new().resize_(batch,self.nclass,ph,pw).zero_().cuda() 107 | count_norm = image.new().resize_(batch,1,ph,pw).zero_().cuda() 108 | # grid evaluation 109 | for idh in range(h_grids): 110 | for idw in range(w_grids): 111 | h0 = idh * stride 112 | w0 = idw * stride 113 | h1 = min(h0 + crop_size, ph) 114 | w1 = min(w0 + crop_size, pw) 115 | crop_img = crop_image(pad_img, h0, h1, w0, w1) 116 | # pad if needed 117 | pad_crop_img = pad_image(crop_img, self.module.mean, 118 | self.module.std, crop_size) 119 | output = module_inference(self.module, pad_crop_img, self.flip) 120 | outputs[:,:,h0:h1,w0:w1] += crop_image(output, 121 | 0, h1-h0, 0, w1-w0) 122 | count_norm[:,:,h0:h1,w0:w1] += 1 123 | assert((count_norm==0).sum()==0) 124 | outputs = outputs / count_norm 125 | outputs = outputs[:,:,:height,:width] 126 | 127 | score = resize_image(outputs, h, w, **self.module._up_kwargs) 128 | scores += score 129 | 130 | return scores 131 | 132 | 133 
| def module_inference(module, image, flip=True): 134 | output = module.evaluate(image) 135 | if flip: 136 | fimg = flip_image(image) 137 | foutput = module.evaluate(fimg) 138 | output += flip_image(foutput) 139 | return output 140 | 141 | def resize_image(img, h, w, **up_kwargs): 142 | return F.interpolate(img, (h, w), **up_kwargs) 143 | 144 | def pad_image(img, mean, std, crop_size): 145 | b,c,h,w = img.size() 146 | assert(c==3) 147 | padh = crop_size - h if h < crop_size else 0 148 | padw = crop_size - w if w < crop_size else 0 149 | pad_values = -np.array(mean) / np.array(std) 150 | img_pad = img.new().resize_(b,c,h+padh,w+padw) 151 | for i in range(c): 152 | # note that pytorch pad params is in reversed orders 153 | img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i]) 154 | assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size) 155 | return img_pad 156 | 157 | def crop_image(img, h0, h1, w0, w1): 158 | return img[:,:,h0:h1,w0:w1] 159 | 160 | def flip_image(img): 161 | assert(img.dim()==4) 162 | with torch.cuda.device_of(img): 163 | idx = torch.arange(img.size(3)-1, -1, -1).type_as(img).long() 164 | return img.index_select(3, idx) -------------------------------------------------------------------------------- /additional_utils/models.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Referred to: https://github.com/zhanghang1989/PyTorch-Encoding 3 | ########################################################################### 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn.parallel.data_parallel import DataParallel 11 | from torch.nn.parallel.scatter_gather import scatter 12 | import threading 13 | import torch 14 | from torch.cuda._utils import _get_device_index 15 | from torch.cuda.amp import autocast 16 | from torch._utils import ExceptionWrapper 17 | 18 | up_kwargs = {'mode': 'bilinear', 'align_corners': True} 19 | 20 | __all__ = ['LSeg_MultiEvalModule'] 21 | 22 | 23 | class LSeg_MultiEvalModule(DataParallel): 24 | """Multi-size Segmentation Eavluator""" 25 | def __init__(self, module, device_ids=None, flip=True, 26 | scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]): 27 | super(LSeg_MultiEvalModule, self).__init__(module, device_ids) 28 | self.base_size = module.base_size 29 | self.crop_size = module.crop_size 30 | self.scales = scales 31 | self.flip = flip 32 | print('MultiEvalModule: base_size {}, crop_size {}'. 
\ 33 | format(self.base_size, self.crop_size)) 34 | 35 | def parallel_forward(self, inputs, label_set='', **kwargs): 36 | """Multi-GPU Mult-size Evaluation 37 | 38 | Args: 39 | inputs: list of Tensors 40 | """ 41 | if len(label_set) < 10: 42 | print('** MultiEvalModule parallel_forward phase: {} **'.format(label_set)) 43 | self.nclass = len(label_set) 44 | inputs = [(input.unsqueeze(0).cuda(device),) 45 | for input, device in zip(inputs, self.device_ids)] 46 | replicas = self.replicate(self, self.device_ids[:len(inputs)]) 47 | kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] 48 | if len(inputs) < len(kwargs): 49 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 50 | elif len(kwargs) < len(inputs): 51 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 52 | outputs = parallel_apply(replicas, inputs, label_set, kwargs) 53 | return outputs 54 | 55 | def forward(self, image, label_set=''): 56 | """Mult-size Evaluation""" 57 | # only single image is supported for evaluation 58 | if len(label_set) < 10: 59 | print('** MultiEvalModule forward phase: {} **'.format(label_set)) 60 | batch, _, h, w = image.size() 61 | assert(batch == 1) 62 | self.nclass = len(label_set) 63 | stride_rate = 2.0/3.0 64 | crop_size = self.crop_size 65 | stride = int(crop_size * stride_rate) 66 | with torch.cuda.device_of(image): 67 | scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda() 68 | 69 | for scale in self.scales: 70 | long_size = int(math.ceil(self.base_size * scale)) 71 | if h > w: 72 | height = long_size 73 | width = int(1.0 * w * long_size / h + 0.5) 74 | short_size = width 75 | else: 76 | width = long_size 77 | height = int(1.0 * h * long_size / w + 0.5) 78 | short_size = height 79 | """ 80 | short_size = int(math.ceil(self.base_size * scale)) 81 | if h > w: 82 | width = short_size 83 | height = int(1.0 * h * short_size / w) 84 | long_size = height 85 | else: 86 | height = short_size 87 | width = int(1.0 * w * short_size / h) 88 | long_size = width 89 | """ 90 | # resize image to current size 91 | cur_img = resize_image(image, height, width, **self.module._up_kwargs) 92 | if long_size <= crop_size: 93 | pad_img = pad_image(cur_img, self.module.mean, 94 | self.module.std, crop_size) 95 | outputs = module_inference(self.module, pad_img, label_set, self.flip) 96 | outputs = crop_image(outputs, 0, height, 0, width) 97 | else: 98 | if short_size < crop_size: 99 | # pad if needed 100 | pad_img = pad_image(cur_img, self.module.mean, 101 | self.module.std, crop_size) 102 | else: 103 | pad_img = cur_img 104 | _,_,ph,pw = pad_img.shape #.size() 105 | assert(ph >= height and pw >= width) 106 | # grid forward and normalize 107 | h_grids = int(math.ceil(1.0 * (ph-crop_size)/stride)) + 1 108 | w_grids = int(math.ceil(1.0 * (pw-crop_size)/stride)) + 1 109 | with torch.cuda.device_of(image): 110 | outputs = image.new().resize_(batch,self.nclass,ph,pw).zero_().cuda() 111 | count_norm = image.new().resize_(batch,1,ph,pw).zero_().cuda() 112 | # grid evaluation 113 | for idh in range(h_grids): 114 | for idw in range(w_grids): 115 | h0 = idh * stride 116 | w0 = idw * stride 117 | h1 = min(h0 + crop_size, ph) 118 | w1 = min(w0 + crop_size, pw) 119 | crop_img = crop_image(pad_img, h0, h1, w0, w1) 120 | # pad if needed 121 | pad_crop_img = pad_image(crop_img, self.module.mean, 122 | self.module.std, crop_size) 123 | output = module_inference(self.module, pad_crop_img, label_set, self.flip) 124 | outputs[:,:,h0:h1,w0:w1] += crop_image(output, 125 | 0, h1-h0, 0, w1-w0) 126 | 
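                        # Overlapping crops accumulate logits into `outputs`; `count_norm` counts how
                        # many crops covered each pixel so the sums can be averaged below (outputs / count_norm).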
count_norm[:,:,h0:h1,w0:w1] += 1 127 | assert((count_norm==0).sum()==0) 128 | outputs = outputs / count_norm 129 | outputs = outputs[:,:,:height,:width] 130 | score = resize_image(outputs, h, w, **self.module._up_kwargs) 131 | scores += score 132 | return scores 133 | 134 | def module_inference(module, image, label_set, flip=True): 135 | output = module.evaluate_random(image, label_set) 136 | if flip: 137 | fimg = flip_image(image) 138 | foutput = module.evaluate_random(fimg, label_set) 139 | output += flip_image(foutput) 140 | return output 141 | 142 | def resize_image(img, h, w, **up_kwargs): 143 | return F.interpolate(img, (h, w), **up_kwargs) 144 | 145 | def pad_image(img, mean, std, crop_size): 146 | b,c,h,w = img.shape #.size() 147 | assert(c==3) 148 | padh = crop_size - h if h < crop_size else 0 149 | padw = crop_size - w if w < crop_size else 0 150 | pad_values = -np.array(mean) / np.array(std) 151 | img_pad = img.new().resize_(b,c,h+padh,w+padw) 152 | for i in range(c): 153 | # note that pytorch pad params is in reversed orders 154 | img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i]) 155 | assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size) 156 | return img_pad 157 | 158 | def crop_image(img, h0, h1, w0, w1): 159 | return img[:,:,h0:h1,w0:w1] 160 | 161 | def flip_image(img): 162 | assert(img.dim()==4) 163 | with torch.cuda.device_of(img): 164 | idx = torch.arange(img.size(3)-1, -1, -1).type_as(img).long() 165 | return img.index_select(3, idx) 166 | 167 | 168 | def get_a_var(obj): 169 | if isinstance(obj, torch.Tensor): 170 | return obj 171 | 172 | if isinstance(obj, list) or isinstance(obj, tuple): 173 | for result in map(get_a_var, obj): 174 | if isinstance(result, torch.Tensor): 175 | return result 176 | if isinstance(obj, dict): 177 | for result in map(get_a_var, obj.items()): 178 | if isinstance(result, torch.Tensor): 179 | return result 180 | return None 181 | 182 | 183 | def parallel_apply(modules, inputs, label_set, kwargs_tup=None, devices=None): 184 | r"""Applies each `module` in :attr:`modules` in parallel on arguments 185 | contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) 186 | on each of :attr:`devices`. 187 | 188 | Args: 189 | modules (Module): modules to be parallelized 190 | inputs (tensor): inputs to the modules 191 | devices (list of int or torch.device): CUDA devices 192 | 193 | :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and 194 | :attr:`devices` (if given) should all have same length. Moreover, each 195 | element of :attr:`inputs` can either be a single object as the only argument 196 | to a module, or a collection of positional arguments. 
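        label_set: label set (list of class names) that is forwarded to every replica
            as an extra positional argument; this is essentially the only change
            relative to the stock ``torch.nn.parallel.parallel_apply``.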
197 | """ 198 | assert len(modules) == len(inputs) 199 | if kwargs_tup is not None: 200 | assert len(modules) == len(kwargs_tup) 201 | else: 202 | kwargs_tup = ({},) * len(modules) 203 | if devices is not None: 204 | assert len(modules) == len(devices) 205 | else: 206 | devices = [None] * len(modules) 207 | devices = [_get_device_index(x, True) for x in devices] 208 | lock = threading.Lock() 209 | results = {} 210 | grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled() 211 | 212 | def _worker(i, module, input, label_set, kwargs, device=None): 213 | torch.set_grad_enabled(grad_enabled) 214 | if device is None: 215 | device = get_a_var(input).get_device() 216 | try: 217 | with torch.cuda.device(device), autocast(enabled=autocast_enabled): 218 | # this also avoids accidental slicing of `input` if it is a Tensor 219 | if not isinstance(input, (list, tuple)): 220 | input = (input,) 221 | output = module(*input, label_set, **kwargs) 222 | with lock: 223 | results[i] = output 224 | except Exception: 225 | with lock: 226 | results[i] = ExceptionWrapper( 227 | where="in replica {} on device {}".format(i, device)) 228 | 229 | if len(modules) > 1: 230 | threads = [threading.Thread(target=_worker, 231 | args=(i, module, input, label_set, kwargs, device)) 232 | for i, (module, input, kwargs, device) in 233 | enumerate(zip(modules, inputs, kwargs_tup, devices))] 234 | 235 | for thread in threads: 236 | thread.start() 237 | for thread in threads: 238 | thread.join() 239 | else: 240 | _worker(0, modules[0], inputs[0], label_set, kwargs_tup[0], devices[0]) 241 | 242 | outputs = [] 243 | for i in range(len(inputs)): 244 | output = results[i] 245 | if isinstance(output, ExceptionWrapper): 246 | output.reraise() 247 | outputs.append(output) 248 | return outputs 249 | 250 | 251 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import itertools 4 | import functools 5 | import numpy as np 6 | import torch 7 | import torch.utils.data 8 | import torchvision.transforms as torch_transforms 9 | import encoding.datasets as enc_ds 10 | 11 | encoding_datasets = { 12 | x: functools.partial(enc_ds.get_dataset, x) 13 | for x in ["coco", "ade20k", "pascal_voc", "pascal_aug", "pcontext", "citys"] 14 | } 15 | 16 | 17 | def get_dataset(name, **kwargs): 18 | if name in encoding_datasets: 19 | return encoding_datasets[name.lower()](**kwargs) 20 | assert False, f"dataset {name} not found" 21 | 22 | 23 | def get_available_datasets(): 24 | return list(encoding_datasets.keys()) 25 | -------------------------------------------------------------------------------- /fewshot_data/README.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hypercorrelation-squeeze-for-few-shot/few-shot-semantic-segmentation-on-pascal-5i-1)](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-pascal-5i-1?p=hypercorrelation-squeeze-for-few-shot) 2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hypercorrelation-squeeze-for-few-shot/few-shot-semantic-segmentation-on-pascal-5i-5)](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-pascal-5i-5?p=hypercorrelation-squeeze-for-few-shot) 3 | 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hypercorrelation-squeeze-for-few-shot/few-shot-semantic-segmentation-on-pascal-5i)](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-pascal-5i?p=hypercorrelation-squeeze-for-few-shot) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hypercorrelation-squeeze-for-few-shot/few-shot-semantic-segmentation-on-coco-20i-1)](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-coco-20i-1?p=hypercorrelation-squeeze-for-few-shot) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hypercorrelation-squeeze-for-few-shot/few-shot-semantic-segmentation-on-coco-20i-5)](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-coco-20i-5?p=hypercorrelation-squeeze-for-few-shot) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hypercorrelation-squeeze-for-few-shot/few-shot-semantic-segmentation-on-coco-20i-10)](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-coco-20i-10?p=hypercorrelation-squeeze-for-few-shot) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hypercorrelation-squeeze-for-few-shot/few-shot-semantic-segmentation-on-fss-1000-1)](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-fss-1000-1?p=hypercorrelation-squeeze-for-few-shot) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hypercorrelation-squeeze-for-few-shot/few-shot-semantic-segmentation-on-fss-1000-5)](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-fss-1000-5?p=hypercorrelation-squeeze-for-few-shot) 9 | 10 | 11 | ## Hypercorrelation Squeeze for Few-Shot Segmentation 12 | This is the implementation of the paper "Hypercorrelation Squeeze for Few-Shot Segmentation" by Juhong Min, Dahyun Kang, and Minsu Cho. Implemented on Python 3.7 and Pytorch 1.5.1. 13 | 14 |

15 | 16 |

17 | 18 | For more information, check out project [[website](http://cvlab.postech.ac.kr/research/HSNet/)] and the paper on [[arXiv](https://arxiv.org/abs/2104.01538)]. 19 | 20 | ## Requirements 21 | 22 | - Python 3.7 23 | - PyTorch 1.5.1 24 | - cuda 10.1 25 | - tensorboard 1.14 26 | 27 | Conda environment settings: 28 | ```bash 29 | conda create -n hsnet python=3.7 30 | conda activate hsnet 31 | 32 | conda install pytorch=1.5.1 torchvision cudatoolkit=10.1 -c pytorch 33 | conda install -c conda-forge tensorflow 34 | pip install tensorboardX 35 | ``` 36 | ## Preparing Few-Shot Segmentation Datasets 37 | Download following datasets: 38 | 39 | > #### 1. PASCAL-5i 40 | > Download PASCAL VOC2012 devkit (train/val data): 41 | > ```bash 42 | > wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 43 | > ``` 44 | > Download PASCAL VOC2012 SDS extended mask annotations from our [[Google Drive](https://drive.google.com/file/d/10zxG2VExoEZUeyQl_uXga2OWHjGeZaf2/view?usp=sharing)]. 45 | 46 | > #### 2. COCO-20i 47 | > Download COCO2014 train/val images and annotations: 48 | > ```bash 49 | > wget http://images.cocodataset.org/zips/train2014.zip 50 | > wget http://images.cocodataset.org/zips/val2014.zip 51 | > wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip 52 | > ``` 53 | > Download COCO2014 train/val annotations from our Google Drive: [[train2014.zip](https://drive.google.com/file/d/1cwup51kcr4m7v9jO14ArpxKMA4O3-Uge/view?usp=sharing)], [[val2014.zip](https://drive.google.com/file/d/1PNw4U3T2MhzAEBWGGgceXvYU3cZ7mJL1/view?usp=sharing)]. 54 | > (and locate both train2014/ and val2014/ under annotations/ directory). 55 | 56 | > #### 3. FSS-1000 57 | > Download FSS-1000 images and annotations from our [[Google Drive](https://drive.google.com/file/d/1Fn-cUESMMF1pQy8Xff-vPQvXJdZoUlP3/view?usp=sharing)]. 58 | 59 | Create a directory '../Datasets_HSN' for the above three few-shot segmentation datasets and appropriately place each dataset to have following directory structure: 60 | 61 | ../ # parent directory 62 | ├── ./ # current (project) directory 63 | │ ├── common/ # (dir.) helper functions 64 | │ ├── data/ # (dir.) dataloaders and splits for each FSSS dataset 65 | │ ├── model/ # (dir.) implementation of Hypercorrelation Squeeze Network model 66 | │ ├── README.md # intstruction for reproduction 67 | │ ├── train.py # code for training HSNet 68 | │ └── test.py # code for testing HSNet 69 | └── Datasets_HSN/ 70 | ├── VOC2012/ # PASCAL VOC2012 devkit 71 | │ ├── Annotations/ 72 | │ ├── ImageSets/ 73 | │ ├── ... 74 | │ └── SegmentationClassAug/ 75 | ├── COCO2014/ 76 | │ ├── annotations/ 77 | │ │ ├── train2014/ # (dir.) training masks (from Google Drive) 78 | │ │ ├── val2014/ # (dir.) validation masks (from Google Drive) 79 | │ │ └── ..some json files.. 80 | │ ├── train2014/ 81 | │ └── val2014/ 82 | └── FSS-1000/ # (dir.) contains 1000 object classes 83 | ├── abacus/ 84 | ├── ... 85 | └── zucchini/ 86 | 87 | ## Training 88 | > ### 1. PASCAL-5i 89 | > ```bash 90 | > python train.py --backbone {vgg16, resnet50, resnet101} 91 | > --fold {0, 1, 2, 3} 92 | > --benchmark pascal 93 | > --lr 1e-3 94 | > --bsz 20 95 | > --logpath "your_experiment_name" 96 | > ``` 97 | > * Training takes approx. 2 days until convergence (trained with four 2080 Ti GPUs). 98 | 99 | 100 | > ### 2. 
COCO-20i 101 | > ```bash 102 | > python train.py --backbone {resnet50, resnet101} 103 | > --fold {0, 1, 2, 3} 104 | > --benchmark coco 105 | > --lr 1e-3 106 | > --bsz 40 107 | > --logpath "your_experiment_name" 108 | > ``` 109 | > * Training takes approx. 1 week until convergence (trained four Titan RTX GPUs). 110 | 111 | > ### 3. FSS-1000 112 | > ```bash 113 | > python train.py --backbone {vgg16, resnet50, resnet101} 114 | > --benchmark fss 115 | > --lr 1e-3 116 | > --bsz 20 117 | > --logpath "your_experiment_name" 118 | > ``` 119 | > * Training takes approx. 3 days until convergence (trained with four 2080 Ti GPUs). 120 | 121 | > ### Babysitting training: 122 | > Use tensorboard to babysit training progress: 123 | > - For each experiment, a directory that logs training progress will be automatically generated under logs/ directory. 124 | > - From terminal, run 'tensorboard --logdir logs/' to monitor the training progress. 125 | > - Choose the best model when the validation (mIoU) curve starts to saturate. 126 | 127 | 128 | 129 | ## Testing 130 | 131 | > ### 1. PASCAL-5i 132 | > Pretrained models with tensorboard logs are available on our [[Google Drive](https://drive.google.com/drive/folders/1z4KgjgOu--k6YuIj3qWrGg264GRcMis2?usp=sharing)]. 133 | > ```bash 134 | > python test.py --backbone {vgg16, resnet50, resnet101} 135 | > --fold {0, 1, 2, 3} 136 | > --benchmark pascal 137 | > --nshot {1, 5} 138 | > --load "path_to_trained_model/best_model.pt" 139 | > ``` 140 | 141 | 142 | > ### 2. COCO-20i 143 | > Pretrained models with tensorboard logs are available on our [[Google Drive](https://drive.google.com/drive/folders/1WpwmCQzxTWhJD5aLQhsgJASaoxxqmFUk?usp=sharing)]. 144 | > ```bash 145 | > python test.py --backbone {resnet50, resnet101} 146 | > --fold {0, 1, 2, 3} 147 | > --benchmark coco 148 | > --nshot {1, 5} 149 | > --load "path_to_trained_model/best_model.pt" 150 | > ``` 151 | 152 | > ### 3. FSS-1000 153 | > Pretrained models with tensorboard logs are available on our [[Google Drive](https://drive.google.com/drive/folders/1JOaaJknGwsrSEPoLF3x6_lDiy4XfAe99?usp=sharing)]. 154 | > ```bash 155 | > python test.py --backbone {vgg16, resnet50, resnet101} 156 | > --benchmark fss 157 | > --nshot {1, 5} 158 | > --load "path_to_trained_model/best_model.pt" 159 | > ``` 160 | 161 | > ### 4. Evaluation *without support feature masking* on PASCAL-5i 162 | > * To reproduce the results in Tab.1 of our main paper, **COMMENT OUT line 51 in hsnet.py**: support_feats = self.mask_feature(support_feats, support_mask.clone()) 163 | > 164 | > Pretrained models with tensorboard logs are available on our [[Google Drive](https://drive.google.com/drive/folders/18YWMCePIrza194pZvVMqQBuYqhwBmJwd?usp=sharing)]. 165 | > ```bash 166 | > python test.py --backbone resnet101 167 | > --fold {0, 1, 2, 3} 168 | > --benchmark pascal 169 | > --nshot {1, 5} 170 | > --load "path_to_trained_model/best_model.pt" 171 | > ``` 172 | 173 | 174 | ## Visualization 175 | 176 | * To visualize mask predictions, add command line argument **--visualize**: 177 | (prediction results will be saved under vis/ directory) 178 | ```bash 179 | python test.py '...other arguments...' --visualize 180 | ``` 181 | 182 | #### Example qualitative results (1-shot): 183 | 184 |

185 | 186 |

187 | 188 | ## BibTeX 189 | If you use this code for your research, please consider citing: 190 | ````BibTeX 191 | @InProceedings{min2021hypercorrelation, 192 | title={Hypercorrelation Squeeze for Few-Shot Segmentation}, 193 | author={Juhong Min and Dahyun Kang and Minsu Cho}, 194 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 195 | year={2021} 196 | } 197 | ```` 198 | -------------------------------------------------------------------------------- /fewshot_data/common/evaluation.py: -------------------------------------------------------------------------------- 1 | r""" Evaluate mask prediction """ 2 | import torch 3 | 4 | 5 | class Evaluator: 6 | r""" Computes intersection and union between prediction and ground-truth """ 7 | @classmethod 8 | def initialize(cls): 9 | cls.ignore_index = 255 10 | 11 | @classmethod 12 | def classify_prediction(cls, pred_mask, gt_mask, query_ignore_idx=None): 13 | # gt_mask = batch.get('query_mask') 14 | 15 | # # Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020)) 16 | # query_ignore_idx = batch.get('query_ignore_idx') 17 | if query_ignore_idx is not None: 18 | assert torch.logical_and(query_ignore_idx, gt_mask).sum() == 0 19 | query_ignore_idx *= cls.ignore_index 20 | gt_mask = gt_mask + query_ignore_idx 21 | pred_mask[gt_mask == cls.ignore_index] = cls.ignore_index 22 | 23 | # compute intersection and union of each episode in a batch 24 | area_inter, area_pred, area_gt = [], [], [] 25 | for _pred_mask, _gt_mask in zip(pred_mask, gt_mask): 26 | _inter = _pred_mask[_pred_mask == _gt_mask] 27 | if _inter.size(0) == 0: # as torch.histc returns error if it gets empty tensor (pytorch 1.5.1) 28 | _area_inter = torch.tensor([0, 0], device=_pred_mask.device) 29 | else: 30 | _area_inter = torch.histc(_inter, bins=2, min=0, max=1) 31 | area_inter.append(_area_inter) 32 | area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1)) 33 | area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1)) 34 | area_inter = torch.stack(area_inter).t() 35 | area_pred = torch.stack(area_pred).t() 36 | area_gt = torch.stack(area_gt).t() 37 | area_union = area_pred + area_gt - area_inter 38 | 39 | return area_inter, area_union 40 | -------------------------------------------------------------------------------- /fewshot_data/common/logger.py: -------------------------------------------------------------------------------- 1 | r""" Logging during training/testing """ 2 | import datetime 3 | import logging 4 | import os 5 | 6 | from tensorboardX import SummaryWriter 7 | import torch 8 | 9 | 10 | class AverageMeter: 11 | r""" Stores loss, evaluation results """ 12 | def __init__(self, dataset): 13 | self.benchmark = dataset.benchmark 14 | self.class_ids_interest = dataset.class_ids 15 | self.class_ids_interest = torch.tensor(self.class_ids_interest).cuda() 16 | 17 | if self.benchmark == 'pascal': 18 | self.nclass = 20 19 | elif self.benchmark == 'coco': 20 | self.nclass = 80 21 | elif self.benchmark == 'fss': 22 | self.nclass = 1000 23 | 24 | self.intersection_buf = torch.zeros([2, self.nclass]).float().cuda() 25 | self.union_buf = torch.zeros([2, self.nclass]).float().cuda() 26 | self.ones = torch.ones_like(self.union_buf) 27 | self.loss_buf = [] 28 | 29 | def update(self, inter_b, union_b, class_id, loss): 30 | self.intersection_buf.index_add_(1, class_id, inter_b.float()) 31 | self.union_buf.index_add_(1, class_id, union_b.float()) 32 | if loss is None: 33 | loss = torch.tensor(0.0) 34 | 
self.loss_buf.append(loss) 35 | 36 | def compute_iou(self): 37 | iou = self.intersection_buf.float() / \ 38 | torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0] 39 | iou = iou.index_select(1, self.class_ids_interest) 40 | miou = iou[1].mean() * 100 41 | 42 | fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) / 43 | self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100 44 | 45 | return miou, fb_iou 46 | 47 | def write_result(self, split, epoch): 48 | iou, fb_iou = self.compute_iou() 49 | 50 | loss_buf = torch.stack(self.loss_buf) 51 | msg = '\n*** %s ' % split 52 | msg += '[@Epoch %02d] ' % epoch 53 | msg += 'Avg L: %6.5f ' % loss_buf.mean() 54 | msg += 'mIoU: %5.2f ' % iou 55 | msg += 'FB-IoU: %5.2f ' % fb_iou 56 | 57 | msg += '***\n' 58 | Logger.info(msg) 59 | 60 | def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20): 61 | if batch_idx % write_batch_idx == 0: 62 | msg = '[Epoch: %02d] ' % epoch if epoch != -1 else '' 63 | msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen) 64 | iou, fb_iou = self.compute_iou() 65 | if epoch != -1: 66 | loss_buf = torch.stack(self.loss_buf) 67 | msg += 'L: %6.5f ' % loss_buf[-1] 68 | msg += 'Avg L: %6.5f ' % loss_buf.mean() 69 | msg += 'mIoU: %5.2f | ' % iou 70 | msg += 'FB-IoU: %5.2f' % fb_iou 71 | Logger.info(msg) 72 | return iou, fb_iou 73 | 74 | 75 | class Logger: 76 | r""" Writes evaluation results of training/testing """ 77 | @classmethod 78 | def initialize(cls, args, training): 79 | logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S') 80 | logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-2].split('.')[0] + logtime 81 | if logpath == '': logpath = logtime 82 | 83 | cls.logpath = os.path.join('logs', logpath + '.log') 84 | cls.benchmark = args.benchmark 85 | if not os.path.exists(cls.logpath): 86 | os.makedirs(cls.logpath) 87 | 88 | logging.basicConfig(filemode='w', 89 | filename=os.path.join(cls.logpath, 'log.txt'), 90 | level=logging.INFO, 91 | format='%(message)s', 92 | datefmt='%m-%d %H:%M:%S') 93 | 94 | # Console log config 95 | console = logging.StreamHandler() 96 | console.setLevel(logging.INFO) 97 | formatter = logging.Formatter('%(message)s') 98 | console.setFormatter(formatter) 99 | logging.getLogger('').addHandler(console) 100 | 101 | # Tensorboard writer 102 | cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs')) 103 | 104 | # Log arguments 105 | logging.info('\n:=========== Few-shot Seg. with HSNet ===========') 106 | for arg_key in args.__dict__: 107 | logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key]))) 108 | logging.info(':================================================\n') 109 | 110 | @classmethod 111 | def info(cls, msg): 112 | r""" Writes log message to log.txt """ 113 | logging.info(msg) 114 | 115 | @classmethod 116 | def save_model_miou(cls, model, epoch, val_miou): 117 | torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt')) 118 | cls.info('Model saved @%d w/ val. 
mIoU: %5.2f.\n' % (epoch, val_miou)) 119 | 120 | @classmethod 121 | def log_params(cls, model): 122 | backbone_param = 0 123 | learner_param = 0 124 | for k in model.state_dict().keys(): 125 | n_param = model.state_dict()[k].view(-1).size(0) 126 | if k.split('.')[0] in 'backbone': 127 | if k.split('.')[1] in ['classifier', 'fc']: # as fc layers are not used in HSNet 128 | continue 129 | backbone_param += n_param 130 | else: 131 | learner_param += n_param 132 | Logger.info('Backbone # param.: %d' % backbone_param) 133 | Logger.info('Learnable # param.: %d' % learner_param) 134 | Logger.info('Total # param.: %d' % (backbone_param + learner_param)) 135 | 136 | -------------------------------------------------------------------------------- /fewshot_data/common/utils.py: -------------------------------------------------------------------------------- 1 | r""" Helper functions """ 2 | import random 3 | 4 | import torch 5 | import numpy as np 6 | 7 | 8 | def fix_randseed(seed): 9 | r""" Set random seeds for reproducibility """ 10 | if seed is None: 11 | seed = int(random.random() * 1e5) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | torch.backends.cudnn.benchmark = False 17 | torch.backends.cudnn.deterministic = True 18 | 19 | 20 | def mean(x): 21 | return sum(x) / len(x) if len(x) > 0 else 0.0 22 | 23 | 24 | def to_cuda(batch): 25 | for key, value in batch.items(): 26 | if isinstance(value, torch.Tensor): 27 | batch[key] = value.cuda() 28 | return batch 29 | 30 | 31 | def to_cpu(tensor): 32 | return tensor.detach().clone().cpu() 33 | -------------------------------------------------------------------------------- /fewshot_data/common/vis.py: -------------------------------------------------------------------------------- 1 | r""" Visualize model predictions """ 2 | import os 3 | 4 | from PIL import Image 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | 8 | from fewshot_data.common import utils 9 | 10 | 11 | class Visualizer: 12 | 13 | @classmethod 14 | def initialize(cls, visualize): 15 | cls.visualize = visualize 16 | if not visualize: 17 | return 18 | 19 | cls.colors = {'red': (255, 50, 50), 'blue': (102, 140, 255)} 20 | for key, value in cls.colors.items(): 21 | cls.colors[key] = tuple([c / 255 for c in cls.colors[key]]) 22 | 23 | # cls.mean_img = [0.485, 0.456, 0.406] 24 | # cls.std_img = [0.229, 0.224, 0.225] 25 | cls.mean_img = [0.5] * 3 26 | cls.std_img = [0.5] * 3 27 | cls.to_pil = transforms.ToPILImage() 28 | cls.vis_path = './vis/' 29 | if not os.path.exists(cls.vis_path): os.makedirs(cls.vis_path) 30 | 31 | @classmethod 32 | def visualize_prediction_batch(cls, spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b, batch_idx, iou_b=None): 33 | spt_img_b = utils.to_cpu(spt_img_b) 34 | spt_mask_b = utils.to_cpu(spt_mask_b) 35 | qry_img_b = utils.to_cpu(qry_img_b) 36 | qry_mask_b = utils.to_cpu(qry_mask_b) 37 | pred_mask_b = utils.to_cpu(pred_mask_b) 38 | cls_id_b = utils.to_cpu(cls_id_b) 39 | 40 | for sample_idx, (spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id) in \ 41 | enumerate(zip(spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b)): 42 | iou = iou_b[sample_idx] if iou_b is not None else None 43 | cls.visualize_prediction(spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, True, iou) 44 | 45 | @classmethod 46 | def to_numpy(cls, tensor, type): 47 | if type == 'img': 48 | return 
np.array(cls.to_pil(cls.unnormalize(tensor))).astype(np.uint8) 49 | elif type == 'mask': 50 | return np.array(tensor).astype(np.uint8) 51 | else: 52 | raise Exception('Undefined tensor type: %s' % type) 53 | 54 | @classmethod 55 | def visualize_prediction(cls, spt_imgs, spt_masks, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, label, iou=None): 56 | 57 | spt_color = cls.colors['blue'] 58 | qry_color = cls.colors['red'] 59 | pred_color = cls.colors['red'] 60 | 61 | spt_imgs = [cls.to_numpy(spt_img, 'img') for spt_img in spt_imgs] 62 | spt_pils = [cls.to_pil(spt_img) for spt_img in spt_imgs] 63 | spt_masks = [cls.to_numpy(spt_mask, 'mask') for spt_mask in spt_masks] 64 | spt_masked_pils = [Image.fromarray(cls.apply_mask(spt_img, spt_mask, spt_color)) for spt_img, spt_mask in zip(spt_imgs, spt_masks)] 65 | 66 | qry_img = cls.to_numpy(qry_img, 'img') 67 | qry_pil = cls.to_pil(qry_img) 68 | qry_mask = cls.to_numpy(qry_mask, 'mask') 69 | pred_mask = cls.to_numpy(pred_mask, 'mask') 70 | pred_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), pred_mask.astype(np.uint8), pred_color)) 71 | qry_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), qry_mask.astype(np.uint8), qry_color)) 72 | 73 | merged_pil = cls.merge_image_pair(spt_masked_pils + [pred_masked_pil, qry_masked_pil]) 74 | 75 | iou = iou.item() if iou else 0.0 76 | merged_pil.save(cls.vis_path + '%d_%d_class-%d_iou-%.2f' % (batch_idx, sample_idx, cls_id, iou) + '.jpg') 77 | 78 | @classmethod 79 | def merge_image_pair(cls, pil_imgs): 80 | r""" Horizontally aligns a pair of pytorch tensor images (3, H, W) and returns PIL object """ 81 | 82 | canvas_width = sum([pil.size[0] for pil in pil_imgs]) 83 | canvas_height = max([pil.size[1] for pil in pil_imgs]) 84 | canvas = Image.new('RGB', (canvas_width, canvas_height)) 85 | 86 | xpos = 0 87 | for pil in pil_imgs: 88 | canvas.paste(pil, (xpos, 0)) 89 | xpos += pil.size[0] 90 | 91 | return canvas 92 | 93 | @classmethod 94 | def apply_mask(cls, image, mask, color, alpha=0.5): 95 | r""" Apply mask to the given image. 
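        Blends ``color`` into ``image`` wherever ``mask == 1`` with weight ``alpha``
        (default 0.5); the input array is modified in place and also returned.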
""" 96 | for c in range(3): 97 | image[:, :, c] = np.where(mask == 1, 98 | image[:, :, c] * 99 | (1 - alpha) + alpha * color[c] * 255, 100 | image[:, :, c]) 101 | return image 102 | 103 | @classmethod 104 | def unnormalize(cls, img): 105 | img = img.clone() 106 | for im_channel, mean, std in zip(img, cls.mean_img, cls.std_img): 107 | im_channel.mul_(std).add_(mean) 108 | return img 109 | -------------------------------------------------------------------------------- /fewshot_data/data/assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/assets/architecture.png -------------------------------------------------------------------------------- /fewshot_data/data/assets/qualitative_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/assets/qualitative_results.png -------------------------------------------------------------------------------- /fewshot_data/data/coco.py: -------------------------------------------------------------------------------- 1 | r""" COCO-20i few-shot semantic segmentation dataset """ 2 | import os 3 | import pickle 4 | 5 | from torch.utils.data import Dataset 6 | import torch.nn.functional as F 7 | import torch 8 | import PIL.Image as Image 9 | import numpy as np 10 | 11 | 12 | class DatasetCOCO(Dataset): 13 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize): 14 | self.split = 'val' if split in ['val', 'test'] else 'trn' 15 | self.fold = fold 16 | self.nfolds = 4 17 | self.nclass = 80 18 | self.benchmark = 'coco' 19 | self.shot = shot 20 | self.split_coco = split if split == 'val2014' else 'train2014' 21 | self.base_path = os.path.join(datapath, 'COCO2014') 22 | self.transform = transform 23 | self.use_original_imgsize = use_original_imgsize 24 | 25 | self.class_ids = self.build_class_ids() 26 | self.img_metadata_classwise = self.build_img_metadata_classwise() 27 | self.img_metadata = self.build_img_metadata() 28 | 29 | def __len__(self): 30 | return len(self.img_metadata) if self.split == 'trn' else 1000 31 | 32 | def __getitem__(self, idx): 33 | # ignores idx during training & testing and perform uniform sampling over object classes to form an episode 34 | # (due to the large size of the COCO dataset) 35 | query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize = self.load_frame() 36 | 37 | query_img = self.transform(query_img) 38 | query_mask = query_mask.float() 39 | if not self.use_original_imgsize: 40 | query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze() 41 | 42 | if self.shot: 43 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs]) 44 | for midx, smask in enumerate(support_masks): 45 | support_masks[midx] = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze() 46 | support_masks = torch.stack(support_masks) 47 | 48 | 49 | batch = {'query_img': query_img, 50 | 'query_mask': query_mask, 51 | 'query_name': query_name, 52 | 53 | 'org_query_imsize': org_qry_imsize, 54 | 55 | 'support_imgs': support_imgs, 56 | 'support_masks': support_masks, 57 | 'support_names': support_names, 58 | 'class_id': 
torch.tensor(class_sample)} 59 | 60 | return batch 61 | 62 | def build_class_ids(self): 63 | nclass_trn = self.nclass // self.nfolds 64 | class_ids_val = [self.fold + self.nfolds * v for v in range(nclass_trn)] 65 | class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val] 66 | class_ids = class_ids_trn if self.split == 'trn' else class_ids_val 67 | 68 | return class_ids 69 | 70 | def build_img_metadata_classwise(self): 71 | with open('fewshot_data/data/splits/coco/%s/fold%d.pkl' % (self.split, self.fold), 'rb') as f: 72 | img_metadata_classwise = pickle.load(f) 73 | return img_metadata_classwise 74 | 75 | def build_img_metadata(self): 76 | img_metadata = [] 77 | for k in self.img_metadata_classwise.keys(): 78 | img_metadata += self.img_metadata_classwise[k] 79 | return sorted(list(set(img_metadata))) 80 | 81 | def read_mask(self, name): 82 | mask_path = os.path.join(self.base_path, 'annotations', name) 83 | mask = torch.tensor(np.array(Image.open(mask_path[:mask_path.index('.jpg')] + '.png'))) 84 | return mask 85 | 86 | def load_frame(self): 87 | class_sample = np.random.choice(self.class_ids, 1, replace=False)[0] 88 | query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0] 89 | query_img = Image.open(os.path.join(self.base_path, query_name)).convert('RGB') 90 | query_mask = self.read_mask(query_name) 91 | 92 | org_qry_imsize = query_img.size 93 | 94 | query_mask[query_mask != class_sample + 1] = 0 95 | query_mask[query_mask == class_sample + 1] = 1 96 | 97 | support_names = [] 98 | if self.shot: 99 | while True: # keep sampling support set if query == support 100 | support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0] 101 | if query_name != support_name: support_names.append(support_name) 102 | if len(support_names) == self.shot: break 103 | 104 | support_imgs = [] 105 | support_masks = [] 106 | if self.shot: 107 | for support_name in support_names: 108 | support_imgs.append(Image.open(os.path.join(self.base_path, support_name)).convert('RGB')) 109 | support_mask = self.read_mask(support_name) 110 | support_mask[support_mask != class_sample + 1] = 0 111 | support_mask[support_mask == class_sample + 1] = 1 112 | support_masks.append(support_mask) 113 | 114 | return query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize 115 | 116 | -------------------------------------------------------------------------------- /fewshot_data/data/dataset.py: -------------------------------------------------------------------------------- 1 | r""" Dataloader builder for few-shot semantic segmentation dataset """ 2 | from torchvision import transforms 3 | from torch.utils.data import DataLoader 4 | 5 | from fewshot_data.data.pascal import DatasetPASCAL 6 | from fewshot_data.data.coco import DatasetCOCO 7 | from fewshot_data.data.fss import DatasetFSS 8 | 9 | 10 | class FSSDataset: 11 | @classmethod 12 | def initialize(cls, img_size, datapath, use_original_imgsize, imagenet_norm=False): 13 | cls.datasets = { 14 | 'pascal': DatasetPASCAL, 15 | 'coco': DatasetCOCO, 16 | 'fss': DatasetFSS, 17 | } 18 | 19 | if imagenet_norm: 20 | cls.img_mean = [0.485, 0.456, 0.406] 21 | cls.img_std = [0.229, 0.224, 0.225] 22 | print('use norm: {}, {}'.format(cls.img_mean, cls.img_std)) 23 | else: 24 | cls.img_mean = [0.5] * 3 25 | cls.img_std = [0.5] * 3 26 | print('use norm: {}, {}'.format(cls.img_mean, cls.img_std)) 27 | 28 | cls.datapath = datapath 29 | cls.use_original_imgsize = 
use_original_imgsize 30 | 31 | cls.transform = transforms.Compose([transforms.Resize(size=(img_size, img_size)), 32 | transforms.ToTensor(), 33 | transforms.Normalize(cls.img_mean, cls.img_std)]) 34 | 35 | @classmethod 36 | def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1): 37 | shuffle = split == 'trn' 38 | nworker = nworker if split == 'trn' else 0 39 | dataset = cls.datasets[benchmark](cls.datapath, fold=fold, transform=cls.transform, split=split, shot=shot, use_original_imgsize=cls.use_original_imgsize) 40 | dataloader = DataLoader(dataset, batch_size=bsz, shuffle=shuffle, num_workers=nworker) 41 | 42 | return dataloader 43 | -------------------------------------------------------------------------------- /fewshot_data/data/fss.py: -------------------------------------------------------------------------------- 1 | r""" FSS-1000 few-shot semantic segmentation dataset """ 2 | import os 3 | import glob 4 | 5 | from torch.utils.data import Dataset 6 | import torch.nn.functional as F 7 | import torch 8 | import PIL.Image as Image 9 | import numpy as np 10 | 11 | 12 | class DatasetFSS(Dataset): 13 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize=None): 14 | self.split = split 15 | self.benchmark = 'fss' 16 | self.shot = shot 17 | 18 | self.base_path = os.path.join(datapath, 'FSS-1000') 19 | 20 | # Given predefined test split, load randomly generated training/val splits: 21 | # (reference regarding trn/val/test splits: https://github.com/HKUSTCV/FSS-1000/issues/7)) 22 | with open('fewshot_data/data/splits/fss/%s.txt' % split, 'r') as f: 23 | self.categories = f.read().split('\n')[:-1] 24 | self.categories = sorted(self.categories) 25 | 26 | self.class_ids = self.build_class_ids() 27 | self.img_metadata = self.build_img_metadata() 28 | 29 | self.transform = transform 30 | 31 | def __len__(self): 32 | return len(self.img_metadata) 33 | 34 | def __getitem__(self, idx): 35 | query_name, support_names, class_sample = self.sample_episode(idx) 36 | query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names) 37 | 38 | query_img = self.transform(query_img) 39 | query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze() 40 | if self.shot: 41 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs]) 42 | 43 | support_masks_tmp = [] 44 | for smask in support_masks: 45 | smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze() 46 | support_masks_tmp.append(smask) 47 | support_masks = torch.stack(support_masks_tmp) 48 | 49 | batch = {'query_img': query_img, 50 | 'query_mask': query_mask, 51 | 'query_name': query_name, 52 | 53 | 'support_imgs': support_imgs, 54 | 'support_masks': support_masks, 55 | 'support_names': support_names, 56 | 57 | 'class_id': torch.tensor(class_sample)} 58 | 59 | return batch 60 | 61 | def load_frame(self, query_name, support_names): 62 | query_img = Image.open(query_name).convert('RGB') 63 | if self.shot: 64 | support_imgs = [Image.open(name).convert('RGB') for name in support_names] 65 | else: 66 | support_imgs = [] 67 | 68 | query_id = query_name.split('/')[-1].split('.')[0] 69 | query_name = os.path.join(os.path.dirname(query_name), query_id) + '.png' 70 | 71 | if self.shot: 72 | support_ids = [name.split('/')[-1].split('.')[0] for name in support_names] 73 | support_names = [os.path.join(os.path.dirname(name), sid) + '.png' for name, 
sid in zip(support_names, support_ids)] 74 | 75 | query_mask = self.read_mask(query_name) 76 | if self.shot: 77 | support_masks = [self.read_mask(name) for name in support_names] 78 | else: 79 | support_masks = [] 80 | 81 | return query_img, query_mask, support_imgs, support_masks 82 | 83 | def read_mask(self, img_name): 84 | mask = torch.tensor(np.array(Image.open(img_name).convert('L'))) 85 | mask[mask < 128] = 0 86 | mask[mask >= 128] = 1 87 | return mask 88 | 89 | def sample_episode(self, idx): 90 | query_name = self.img_metadata[idx] 91 | class_sample = self.categories.index(query_name.split('/')[-2]) 92 | if self.split == 'val': 93 | class_sample += 520 94 | elif self.split == 'test': 95 | class_sample += 760 96 | 97 | support_names = [] 98 | # here we only test with shot=1 99 | if self.split == 'test' and self.shot == 1: 100 | while True: 101 | support_name = 1 102 | support_name = os.path.join(os.path.dirname(query_name), str(support_name)) + '.jpg' 103 | if query_name != support_name: 104 | support_names.append(support_name) 105 | else: 106 | print('Error in sample_episode!') 107 | exit() 108 | if len(support_names) == self.shot: break 109 | elif self.shot: 110 | while True: # keep sampling support set if query == support 111 | support_name = np.random.choice(range(1, 11), 1, replace=False)[0] 112 | support_name = os.path.join(os.path.dirname(query_name), str(support_name)) + '.jpg' 113 | if query_name != support_name: support_names.append(support_name) 114 | if len(support_names) == self.shot: break 115 | 116 | return query_name, support_names, class_sample 117 | 118 | def build_class_ids(self): 119 | if self.split == 'trn': 120 | class_ids = range(0, 520) 121 | elif self.split == 'val': 122 | class_ids = range(520, 760) 123 | elif self.split == 'test': 124 | class_ids = range(760, 1000) 125 | return class_ids 126 | 127 | def build_img_metadata(self): 128 | img_metadata = [] 129 | for cat in self.categories: 130 | img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, cat))]) 131 | if self.split == 'test' and self.shot == 1: 132 | for i in range(1, len(img_paths)): 133 | img_path = img_paths[i] 134 | if os.path.basename(img_path).split('.')[1] == 'jpg': 135 | img_metadata.append(img_path) 136 | else: 137 | for img_path in img_paths: 138 | if os.path.basename(img_path).split('.')[1] == 'jpg': 139 | img_metadata.append(img_path) 140 | return img_metadata 141 | -------------------------------------------------------------------------------- /fewshot_data/data/pascal.py: -------------------------------------------------------------------------------- 1 | r""" PASCAL-5i few-shot semantic segmentation dataset """ 2 | import os 3 | 4 | from torch.utils.data import Dataset 5 | import torch.nn.functional as F 6 | import torch 7 | import PIL.Image as Image 8 | import numpy as np 9 | 10 | 11 | class DatasetPASCAL(Dataset): 12 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize): 13 | self.split = 'val' if split in ['val', 'test'] else 'trn' 14 | self.fold = fold 15 | self.nfolds = 4 16 | self.nclass = 20 17 | self.benchmark = 'pascal' 18 | self.shot = shot 19 | self.use_original_imgsize = use_original_imgsize 20 | 21 | self.img_path = os.path.join(datapath, 'VOC2012/JPEGImages/') 22 | self.ann_path = os.path.join(datapath, 'VOC2012/SegmentationClassAug/') 23 | self.transform = transform 24 | 25 | self.class_ids = self.build_class_ids() 26 | self.img_metadata = self.build_img_metadata() 27 | self.img_metadata_classwise = 
self.build_img_metadata_classwise() 28 | 29 | def __len__(self): 30 | return len(self.img_metadata) if self.split == 'trn' else 1000 31 | 32 | def __getitem__(self, idx): 33 | idx %= len(self.img_metadata) # for testing, as n_images < 1000 34 | query_name, support_names, class_sample = self.sample_episode(idx) 35 | query_img, query_cmask, support_imgs, support_cmasks, org_qry_imsize = self.load_frame(query_name, support_names) 36 | 37 | query_img = self.transform(query_img) 38 | if not self.use_original_imgsize: 39 | query_cmask = F.interpolate(query_cmask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze() 40 | query_mask, query_ignore_idx = self.extract_ignore_idx(query_cmask.float(), class_sample) 41 | 42 | if self.shot: 43 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs]) 44 | 45 | support_masks = [] 46 | support_ignore_idxs = [] 47 | for scmask in support_cmasks: 48 | scmask = F.interpolate(scmask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze() 49 | support_mask, support_ignore_idx = self.extract_ignore_idx(scmask, class_sample) 50 | support_masks.append(support_mask) 51 | support_ignore_idxs.append(support_ignore_idx) 52 | support_masks = torch.stack(support_masks) 53 | support_ignore_idxs = torch.stack(support_ignore_idxs) 54 | else: 55 | support_masks = [] 56 | support_ignore_idxs = [] 57 | batch = {'query_img': query_img, 58 | 'query_mask': query_mask, 59 | 'query_name': query_name, 60 | 'query_ignore_idx': query_ignore_idx, 61 | 62 | 'org_query_imsize': org_qry_imsize, 63 | 64 | 'support_imgs': support_imgs, 65 | 'support_masks': support_masks, 66 | 'support_names': support_names, 67 | 'support_ignore_idxs': support_ignore_idxs, 68 | 69 | 'class_id': torch.tensor(class_sample)} 70 | 71 | return batch 72 | 73 | def extract_ignore_idx(self, mask, class_id): 74 | boundary = (mask / 255).floor() 75 | mask[mask != class_id + 1] = 0 76 | mask[mask == class_id + 1] = 1 77 | 78 | return mask, boundary 79 | 80 | def load_frame(self, query_name, support_names): 81 | query_img = self.read_img(query_name) 82 | query_mask = self.read_mask(query_name) 83 | support_imgs = [self.read_img(name) for name in support_names] 84 | support_masks = [self.read_mask(name) for name in support_names] 85 | 86 | org_qry_imsize = query_img.size 87 | 88 | return query_img, query_mask, support_imgs, support_masks, org_qry_imsize 89 | 90 | def read_mask(self, img_name): 91 | r"""Return segmentation mask in PIL Image""" 92 | mask = torch.tensor(np.array(Image.open(os.path.join(self.ann_path, img_name) + '.png'))) 93 | return mask 94 | 95 | def read_img(self, img_name): 96 | r"""Return RGB image in PIL Image""" 97 | return Image.open(os.path.join(self.img_path, img_name) + '.jpg') 98 | 99 | def sample_episode(self, idx): 100 | query_name, class_sample = self.img_metadata[idx] 101 | 102 | support_names = [] 103 | if self.shot: 104 | while True: # keep sampling support set if query == support 105 | support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0] 106 | if query_name != support_name: support_names.append(support_name) 107 | if len(support_names) == self.shot: break 108 | 109 | return query_name, support_names, class_sample 110 | 111 | def build_class_ids(self): 112 | nclass_trn = self.nclass // self.nfolds 113 | class_ids_val = [self.fold * nclass_trn + i for i in range(nclass_trn)] 114 | class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val] 
115 | 116 | if self.split == 'trn': 117 | return class_ids_trn 118 | else: 119 | return class_ids_val 120 | 121 | def build_img_metadata(self): 122 | 123 | def read_metadata(split, fold_id): 124 | fold_n_metadata = os.path.join('fewshot_data/data/splits/pascal/%s/fold%d.txt' % (split, fold_id)) 125 | with open(fold_n_metadata, 'r') as f: 126 | fold_n_metadata = f.read().split('\n')[:-1] 127 | fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata] 128 | return fold_n_metadata 129 | 130 | img_metadata = [] 131 | if self.split == 'trn': # For training, read image-metadata of "the other" folds 132 | for fold_id in range(self.nfolds): 133 | if fold_id == self.fold: # Skip validation fold 134 | continue 135 | img_metadata += read_metadata(self.split, fold_id) 136 | elif self.split == 'val': # For validation, read image-metadata of "current" fold 137 | img_metadata = read_metadata(self.split, self.fold) 138 | else: 139 | raise Exception('Undefined split %s: ' % self.split) 140 | 141 | print('Total (%s) images are : %d' % (self.split, len(img_metadata))) 142 | 143 | return img_metadata 144 | 145 | def build_img_metadata_classwise(self): 146 | img_metadata_classwise = {} 147 | for class_id in range(self.nclass): 148 | img_metadata_classwise[class_id] = [] 149 | 150 | for img_name, img_class in self.img_metadata: 151 | img_metadata_classwise[img_class] += [img_name] 152 | return img_metadata_classwise 153 | -------------------------------------------------------------------------------- /fewshot_data/data/splits/coco/trn/fold0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/trn/fold0.pkl -------------------------------------------------------------------------------- /fewshot_data/data/splits/coco/trn/fold1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/trn/fold1.pkl -------------------------------------------------------------------------------- /fewshot_data/data/splits/coco/trn/fold2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/trn/fold2.pkl -------------------------------------------------------------------------------- /fewshot_data/data/splits/coco/trn/fold3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/trn/fold3.pkl -------------------------------------------------------------------------------- /fewshot_data/data/splits/coco/val/fold0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/val/fold0.pkl -------------------------------------------------------------------------------- /fewshot_data/data/splits/coco/val/fold1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/val/fold1.pkl 
-------------------------------------------------------------------------------- /fewshot_data/data/splits/coco/val/fold2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/val/fold2.pkl -------------------------------------------------------------------------------- /fewshot_data/data/splits/coco/val/fold3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/val/fold3.pkl -------------------------------------------------------------------------------- /fewshot_data/data/splits/fss/test.txt: -------------------------------------------------------------------------------- 1 | bus 2 | hotel_slipper 3 | burj_al 4 | reflex_camera 5 | abe's_flyingfish 6 | oiltank_car 7 | doormat 8 | fish_eagle 9 | barber_shaver 10 | motorbike 11 | feather_clothes 12 | wandering_albatross 13 | rice_cooker 14 | delta_wing 15 | fish 16 | nintendo_switch 17 | bustard 18 | diver 19 | minicooper 20 | cathedrale_paris 21 | big_ben 22 | combination_lock 23 | villa_savoye 24 | american_alligator 25 | gym_ball 26 | andean_condor 27 | leggings 28 | pyramid_cube 29 | jet_aircraft 30 | meatloaf 31 | reel 32 | swan 33 | osprey 34 | crt_screen 35 | microscope 36 | rubber_eraser 37 | arrow 38 | monkey 39 | mitten 40 | spiderman 41 | parthenon 42 | bat 43 | chess_king 44 | sulphur_butterfly 45 | quail_egg 46 | oriole 47 | iron_man 48 | wooden_boat 49 | anise 50 | steering_wheel 51 | groenendael 52 | dwarf_beans 53 | pteropus 54 | chalk_brush 55 | bloodhound 56 | moon 57 | english_foxhound 58 | boxing_gloves 59 | peregine_falcon 60 | pyraminx 61 | cicada 62 | screw 63 | shower_curtain 64 | tredmill 65 | bulb 66 | bell_pepper 67 | lemur_catta 68 | doughnut 69 | twin_tower 70 | astronaut 71 | nintendo_3ds 72 | fennel_bulb 73 | indri 74 | captain_america_shield 75 | kunai 76 | broom 77 | iphone 78 | earphone1 79 | flying_squirrel 80 | onion 81 | vinyl 82 | sydney_opera_house 83 | oyster 84 | harmonica 85 | egg 86 | breast_pump 87 | guitar 88 | potato_chips 89 | tunnel 90 | cuckoo 91 | rubick_cube 92 | plastic_bag 93 | phonograph 94 | net_surface_shoes 95 | goldfinch 96 | ipad 97 | mite_predator 98 | coffee_mug 99 | golden_plover 100 | f1_racing 101 | lapwing 102 | nintendo_gba 103 | pizza 104 | rally_car 105 | drilling_platform 106 | cd 107 | fly 108 | magpie_bird 109 | leaf_fan 110 | little_blue_heron 111 | carriage 112 | moist_proof_pad 113 | flying_snakes 114 | dart_target 115 | warehouse_tray 116 | nintendo_wiiu 117 | chiffon_cake 118 | bath_ball 119 | manatee 120 | cloud 121 | marimba 122 | eagle 123 | ruler 124 | soymilk_machine 125 | sled 126 | seagull 127 | glider_flyingfish 128 | doublebus 129 | transport_helicopter 130 | window_screen 131 | truss_bridge 132 | wasp 133 | snowman 134 | poached_egg 135 | strawberry 136 | spinach 137 | earphone2 138 | downy_pitch 139 | taj_mahal 140 | rocking_chair 141 | cablestayed_bridge 142 | sealion 143 | banana_boat 144 | pheasant 145 | stone_lion 146 | electronic_stove 147 | fox 148 | iguana 149 | rugby_ball 150 | hang_glider 151 | water_buffalo 152 | lotus 153 | paper_plane 154 | missile 155 | flamingo 156 | american_chamelon 157 | kart 158 | chinese_knot 159 | cabbage_butterfly 160 | key 161 | church 162 | tiltrotor 163 | helicopter 164 | french_fries 165 | water_heater 166 | 
snow_leopard 167 | goblet 168 | fan 169 | snowplow 170 | leafhopper 171 | pspgo 172 | black_bear 173 | quail 174 | condor 175 | chandelier 176 | hair_razor 177 | white_wolf 178 | toaster 179 | pidan 180 | pyramid 181 | chicken_leg 182 | letter_opener 183 | apple_icon 184 | porcupine 185 | chicken 186 | stingray 187 | warplane 188 | windmill 189 | bamboo_slip 190 | wig 191 | flying_geckos 192 | stonechat 193 | haddock 194 | australian_terrier 195 | hover_board 196 | siamang 197 | canton_tower 198 | santa_sledge 199 | arch_bridge 200 | curlew 201 | sushi 202 | beet_root 203 | accordion 204 | leaf_egg 205 | stealth_aircraft 206 | stork 207 | bucket 208 | hawk 209 | chess_queen 210 | ocarina 211 | knife 212 | whippet 213 | cantilever_bridge 214 | may_bug 215 | wagtail 216 | leather_shoes 217 | wheelchair 218 | shumai 219 | speedboat 220 | vacuum_cup 221 | chess_knight 222 | pumpkin_pie 223 | wooden_spoon 224 | bamboo_dragonfly 225 | ganeva_chair 226 | soap 227 | clearwing_flyingfish 228 | pencil_sharpener1 229 | cricket 230 | photocopier 231 | nintendo_sp 232 | samarra_mosque 233 | clam 234 | charge_battery 235 | flying_frog 236 | ferrari911 237 | polo_shirt 238 | echidna 239 | coin 240 | tower_pisa 241 | -------------------------------------------------------------------------------- /fewshot_data/data/splits/fss/trn.txt: -------------------------------------------------------------------------------- 1 | fountain 2 | taxi 3 | assult_rifle 4 | radio 5 | comb 6 | box_turtle 7 | igloo 8 | head_cabbage 9 | cottontail 10 | coho 11 | ashtray 12 | joystick 13 | sleeping_bag 14 | jackfruit 15 | trailer_truck 16 | shower_cap 17 | ibex 18 | kinguin 19 | squirrel 20 | ac_wall 21 | sidewinder 22 | remote_control 23 | marshmallow 24 | bolotie 25 | polar_bear 26 | rock_beauty 27 | tokyo_tower 28 | wafer 29 | red_bayberry 30 | electronic_toothbrush 31 | hartebeest 32 | cassette 33 | oil_filter 34 | bomb 35 | walnut 36 | toilet_tissue 37 | memory_stick 38 | wild_boar 39 | cableways 40 | chihuahua 41 | envelope 42 | bison 43 | poker 44 | pubg_lvl3helmet 45 | indian_cobra 46 | staffordshire 47 | park_bench 48 | wombat 49 | black_grouse 50 | submarine 51 | washer 52 | agama 53 | coyote 54 | feeder 55 | sarong 56 | buckingham_palace 57 | frog 58 | steam_locomotive 59 | acorn 60 | german_pointer 61 | obelisk 62 | polecat 63 | black_swan 64 | butterfly 65 | mountain_tent 66 | gorilla 67 | sloth_bear 68 | aubergine 69 | stinkhorn 70 | stole 71 | owl 72 | mooli 73 | pool_table 74 | collar 75 | lhasa_apso 76 | ambulance 77 | spade 78 | pufferfish 79 | paint_brush 80 | lark 81 | golf_ball 82 | hock 83 | fork 84 | drake 85 | bee_house 86 | mooncake 87 | wok 88 | cocacola 89 | water_bike 90 | ladder 91 | psp 92 | bassoon 93 | bear 94 | border_terrier 95 | petri_dish 96 | pill_bottle 97 | aircraft_carrier 98 | panther 99 | canoe 100 | baseball_player 101 | turtle 102 | espresso 103 | throne 104 | cornet 105 | coucal 106 | eletrical_switch 107 | bra 108 | snail 109 | backpack 110 | jacamar 111 | scroll_brush 112 | gliding_lizard 113 | raft 114 | pinwheel 115 | grasshopper 116 | green_mamba 117 | eft_newt 118 | computer_mouse 119 | vine_snake 120 | recreational_vehicle 121 | llama 122 | meerkat 123 | chainsaw 124 | ferret 125 | garbage_can 126 | kangaroo 127 | litchi 128 | carbonara 129 | housefinch 130 | modem 131 | tebby_cat 132 | thatch 133 | face_powder 134 | tomb 135 | apple 136 | ladybug 137 | killer_whale 138 | rocket 139 | airship 140 | surfboard 141 | lesser_panda 142 | jordan_logo 143 | banana 144 | nail_scissor 
145 | swab 146 | perfume 147 | punching_bag 148 | victor_icon 149 | waffle_iron 150 | trimaran 151 | garlic 152 | flute 153 | langur 154 | starfish 155 | parallel_bars 156 | dandie_dinmont 157 | cosmetic_brush 158 | screwdriver 159 | brick_card 160 | balance_weight 161 | hornet 162 | carton 163 | toothpaste 164 | bracelet 165 | egg_tart 166 | pencil_sharpener2 167 | swimming_glasses 168 | howler_monkey 169 | camel 170 | dragonfly 171 | lionfish 172 | convertible 173 | mule 174 | usb 175 | conch 176 | papaya 177 | garbage_truck 178 | dingo 179 | radiator 180 | solar_dish 181 | streetcar 182 | trilobite 183 | bouzouki 184 | ringlet_butterfly 185 | space_shuttle 186 | waffle 187 | american_staffordshire 188 | violin 189 | flowerpot 190 | forklift 191 | manx 192 | sundial 193 | snowmobile 194 | chickadee_bird 195 | ruffed_grouse 196 | brick_tea 197 | paddle 198 | stove 199 | carousel 200 | spatula 201 | beaker 202 | gas_pump 203 | lawn_mower 204 | speaker 205 | tank 206 | tresher 207 | kappa_logo 208 | hare 209 | tennis_racket 210 | shopping_cart 211 | thimble 212 | tractor 213 | anemone_fish 214 | trolleybus 215 | steak 216 | capuchin 217 | red_breasted_merganser 218 | golden_retriever 219 | light_tube 220 | flatworm 221 | melon_seed 222 | digital_watch 223 | jacko_lantern 224 | brown_bear 225 | cairn 226 | mushroom 227 | chalk 228 | skull 229 | stapler 230 | potato 231 | telescope 232 | proboscis 233 | microphone 234 | torii 235 | baseball_bat 236 | dhole 237 | excavator 238 | fig 239 | snake 240 | bradypod 241 | pepitas 242 | prairie_chicken 243 | scorpion 244 | shotgun 245 | bottle_cap 246 | file_cabinet 247 | grey_whale 248 | one-armed_bandit 249 | banded_gecko 250 | flying_disc 251 | croissant 252 | toothbrush 253 | miniskirt 254 | pokermon_ball 255 | gazelle 256 | grey_fox 257 | esport_chair 258 | necklace 259 | ptarmigan 260 | watermelon 261 | besom 262 | pomelo 263 | radio_telescope 264 | studio_couch 265 | black_stork 266 | vestment 267 | koala 268 | brambling 269 | muscle_car 270 | window_shade 271 | space_heater 272 | sunglasses 273 | motor_scooter 274 | ladyfinger 275 | pencil_box 276 | titi_monkey 277 | chicken_wings 278 | mount_fuji 279 | giant_panda 280 | dart 281 | fire_engine 282 | running_shoe 283 | dumbbell 284 | donkey 285 | loafer 286 | hard_disk 287 | globe 288 | lifeboat 289 | medical_kit 290 | brain_coral 291 | paper_towel 292 | dugong 293 | seatbelt 294 | skunk 295 | military_vest 296 | cocktail_shaker 297 | zucchini 298 | quad_drone 299 | ocicat 300 | shih-tzu 301 | teapot 302 | tile_roof 303 | cheese_burger 304 | handshower 305 | red_wolf 306 | stop_sign 307 | mouse 308 | battery 309 | adidas_logo2 310 | earplug 311 | hummingbird 312 | brush_pen 313 | pistachio 314 | hamster 315 | air_strip 316 | indian_elephant 317 | otter 318 | cucumber 319 | scabbard 320 | hawthorn 321 | bullet_train 322 | leopard 323 | whale 324 | cream 325 | chinese_date 326 | jellyfish 327 | lobster 328 | skua 329 | single_log 330 | chicory 331 | bagel 332 | beacon 333 | pingpong_racket 334 | spoon 335 | yurt 336 | wallaby 337 | egret 338 | christmas_stocking 339 | mcdonald_uncle 340 | wrench 341 | spark_plug 342 | triceratops 343 | wall_clock 344 | jinrikisha 345 | pickup 346 | rhinoceros 347 | swimming_trunk 348 | band-aid 349 | spotted_salamander 350 | leeks 351 | marmot 352 | warthog 353 | cello 354 | stool 355 | chest 356 | toilet_plunger 357 | wardrobe 358 | cannon 359 | adidas_logo1 360 | drumstick 361 | lady_slipper 362 | puma_logo 363 | great_wall 364 | white_shark 365 | witch_hat 366 
| vending_machine 367 | wreck 368 | chopsticks 369 | garfish 370 | african_elephant 371 | children_slide 372 | hornbill 373 | zebra 374 | boa_constrictor 375 | armour 376 | pineapple 377 | angora 378 | brick 379 | car_wheel 380 | wallet 381 | boston_bull 382 | hyena 383 | lynx 384 | crash_helmet 385 | terrapin_turtle 386 | persian_cat 387 | shift_gear 388 | cactus_ball 389 | fur_coat 390 | plate 391 | pen 392 | okra 393 | mario 394 | airedale 395 | cowboy_hat 396 | celery 397 | macaque 398 | candle 399 | goose 400 | raccoon 401 | brasscica 402 | almond 403 | maotai_bottle 404 | soccer_ball 405 | sports_car 406 | tobacco_pipe 407 | water_polo 408 | eggnog 409 | hook 410 | ostrich 411 | patas 412 | table_lamp 413 | teddy 414 | mongoose 415 | spoonbill 416 | redheart 417 | crane 418 | dinosaur 419 | kitchen_knife 420 | seal 421 | baboon 422 | golfcart 423 | roller_coaster 424 | avocado 425 | birdhouse 426 | yorkshire_terrier 427 | saluki 428 | basketball 429 | buckler 430 | harvester 431 | afghan_hound 432 | beam_bridge 433 | guinea_pig 434 | lorikeet 435 | shakuhachi 436 | motarboard 437 | statue_liberty 438 | police_car 439 | sulphur_crested 440 | gourd 441 | sombrero 442 | mailbox 443 | adhensive_tape 444 | night_snake 445 | bushtit 446 | mouthpiece 447 | beaver 448 | bathtub 449 | printer 450 | cumquat 451 | orange 452 | cleaver 453 | quill_pen 454 | panpipe 455 | diamond 456 | gypsy_moth 457 | cauliflower 458 | lampshade 459 | cougar 460 | traffic_light 461 | briefcase 462 | ballpoint 463 | african_grey 464 | kremlin 465 | barometer 466 | peacock 467 | paper_crane 468 | sunscreen 469 | tofu 470 | bedlington_terrier 471 | snowball 472 | carrot 473 | tiger 474 | mink 475 | cristo_redentor 476 | ladle 477 | keyboard 478 | maraca 479 | monitor 480 | water_snake 481 | can_opener 482 | mud_turtle 483 | bald_eagle 484 | carp 485 | cn_tower 486 | egyptian_cat 487 | hen_of_the_woods 488 | measuring_cup 489 | roller_skate 490 | kite 491 | sandwich_cookies 492 | sandwich 493 | persimmon 494 | chess_bishop 495 | coffin 496 | ruddy_turnstone 497 | prayer_rug 498 | rain_barrel 499 | neck_brace 500 | nematode 501 | rosehip 502 | dutch_oven 503 | goldfish 504 | blossom_card 505 | dough 506 | trench_coat 507 | sponge 508 | stupa 509 | wash_basin 510 | electric_fan 511 | spring_scroll 512 | potted_plant 513 | sparrow 514 | car_mirror 515 | gecko 516 | diaper 517 | leatherback_turtle 518 | strainer 519 | guacamole 520 | microwave 521 | -------------------------------------------------------------------------------- /fewshot_data/data/splits/fss/val.txt: -------------------------------------------------------------------------------- 1 | handcuff 2 | mortar 3 | matchstick 4 | wine_bottle 5 | dowitcher 6 | triumphal_arch 7 | gyromitra 8 | hatchet 9 | airliner 10 | broccoli 11 | olive 12 | pubg_lvl3backpack 13 | calculator 14 | toucan 15 | shovel 16 | sewing_machine 17 | icecream 18 | woodpecker 19 | pig 20 | relay_stick 21 | mcdonald_sign 22 | cpu 23 | peanut 24 | pumpkin 25 | sturgeon 26 | hammer 27 | hami_melon 28 | squirrel_monkey 29 | shuriken 30 | power_drill 31 | pingpong_ball 32 | crocodile 33 | carambola 34 | monarch_butterfly 35 | drum 36 | water_tower 37 | panda 38 | toilet_brush 39 | pay_phone 40 | yonex_icon 41 | cricketball 42 | revolver 43 | chimpanzee 44 | crab 45 | corn 46 | baseball 47 | rabbit 48 | croquet_ball 49 | artichoke 50 | abacus 51 | harp 52 | bell 53 | gas_tank 54 | scissors 55 | vase 56 | upright_piano 57 | typewriter 58 | bittern 59 | impala 60 | tray 61 | fire_hydrant 62 | 
beer_bottle 63 | sock 64 | soup_bowl 65 | spider 66 | cherry 67 | macaw 68 | toilet_seat 69 | fire_balloon 70 | french_ball 71 | fox_squirrel 72 | volleyball 73 | cornmeal 74 | folding_chair 75 | pubg_airdrop 76 | beagle 77 | skateboard 78 | narcissus 79 | whiptail 80 | cup 81 | arabian_camel 82 | badger 83 | stopwatch 84 | ab_wheel 85 | ox 86 | lettuce 87 | monocycle 88 | redshank 89 | vulture 90 | whistle 91 | smoothing_iron 92 | mashed_potato 93 | conveyor 94 | yoga_pad 95 | tow_truck 96 | siamese_cat 97 | cigar 98 | white_stork 99 | sniper_rifle 100 | stretcher 101 | tulip 102 | handkerchief 103 | basset 104 | iceberg 105 | gibbon 106 | lacewing 107 | thrush 108 | cheetah 109 | bighorn_sheep 110 | espresso_maker 111 | pretzel 112 | english_setter 113 | sandbar 114 | cheese 115 | daisy 116 | arctic_fox 117 | briard 118 | colubus 119 | balance_beam 120 | coffeepot 121 | soap_dispenser 122 | yawl 123 | consomme 124 | parking_meter 125 | cactus 126 | turnstile 127 | taro 128 | fire_screen 129 | digital_clock 130 | rose 131 | pomegranate 132 | bee_eater 133 | schooner 134 | ski_mask 135 | jay_bird 136 | plaice 137 | red_fox 138 | syringe 139 | camomile 140 | pickelhaube 141 | blenheim_spaniel 142 | pear 143 | parachute 144 | common_newt 145 | bowtie 146 | cigarette 147 | oscilloscope 148 | laptop 149 | african_crocodile 150 | apron 151 | coconut 152 | sandal 153 | kwanyin 154 | lion 155 | eel 156 | balloon 157 | crepe 158 | armadillo 159 | kazoo 160 | lemon 161 | spider_monkey 162 | tape_player 163 | ipod 164 | bee 165 | sea_cucumber 166 | suitcase 167 | television 168 | pillow 169 | banjo 170 | rock_snake 171 | partridge 172 | platypus 173 | lycaenid_butterfly 174 | pinecone 175 | conversion_plug 176 | wolf 177 | frying_pan 178 | timber_wolf 179 | bluetick 180 | crayon 181 | giant_schnauzer 182 | orang 183 | scarerow 184 | kobe_logo 185 | loguat 186 | saxophone 187 | ceiling_fan 188 | cardoon 189 | equestrian_helmet 190 | louvre_pyramid 191 | hotdog 192 | ironing_board 193 | razor 194 | nagoya_castle 195 | loggerhead_turtle 196 | lipstick 197 | cradle 198 | strongbox 199 | raven 200 | kit_fox 201 | albatross 202 | flat-coated_retriever 203 | beer_glass 204 | ice_lolly 205 | sungnyemun 206 | totem_pole 207 | vacuum 208 | bolete 209 | mango 210 | ginger 211 | weasel 212 | cabbage 213 | refrigerator 214 | school_bus 215 | hippo 216 | tiger_cat 217 | saltshaker 218 | piano_keyboard 219 | windsor_tie 220 | sea_urchin 221 | microsd 222 | barbell 223 | swim_ring 224 | bulbul_bird 225 | water_ouzel 226 | ac_ground 227 | sweatshirt 228 | umbrella 229 | hair_drier 230 | hammerhead_shark 231 | tomato 232 | projector 233 | cushion 234 | dishwasher 235 | three-toed_sloth 236 | tiger_shark 237 | har_gow 238 | baby 239 | thor's_hammer 240 | nike_logo 241 | -------------------------------------------------------------------------------- /fewshot_data/data/splits/pascal/val/fold0.txt: -------------------------------------------------------------------------------- 1 | 2007_000033__01 2 | 2007_000061__04 3 | 2007_000129__02 4 | 2007_000346__05 5 | 2007_000529__04 6 | 2007_000559__05 7 | 2007_000572__02 8 | 2007_000762__05 9 | 2007_001288__01 10 | 2007_001289__03 11 | 2007_001311__02 12 | 2007_001408__05 13 | 2007_001568__01 14 | 2007_001630__02 15 | 2007_001761__01 16 | 2007_001884__01 17 | 2007_002094__03 18 | 2007_002266__01 19 | 2007_002376__01 20 | 2007_002400__03 21 | 2007_002619__01 22 | 2007_002719__04 23 | 2007_003088__05 24 | 2007_003131__04 25 | 2007_003188__02 26 | 2007_003349__03 27 | 
2007_003571__04 28 | 2007_003621__02 29 | 2007_003682__03 30 | 2007_003861__04 31 | 2007_004052__01 32 | 2007_004143__03 33 | 2007_004241__04 34 | 2007_004468__05 35 | 2007_005074__04 36 | 2007_005107__02 37 | 2007_005294__05 38 | 2007_005304__05 39 | 2007_005428__05 40 | 2007_005509__01 41 | 2007_005600__01 42 | 2007_005705__04 43 | 2007_005828__01 44 | 2007_006076__03 45 | 2007_006086__05 46 | 2007_006449__02 47 | 2007_006946__01 48 | 2007_007084__03 49 | 2007_007235__02 50 | 2007_007341__01 51 | 2007_007470__01 52 | 2007_007477__04 53 | 2007_007836__02 54 | 2007_008051__03 55 | 2007_008084__03 56 | 2007_008204__05 57 | 2007_008670__03 58 | 2007_009088__03 59 | 2007_009258__02 60 | 2007_009323__03 61 | 2007_009458__05 62 | 2007_009687__05 63 | 2007_009817__03 64 | 2007_009911__01 65 | 2008_000120__04 66 | 2008_000123__03 67 | 2008_000533__03 68 | 2008_000725__02 69 | 2008_000911__05 70 | 2008_001013__04 71 | 2008_001040__04 72 | 2008_001135__04 73 | 2008_001260__04 74 | 2008_001404__02 75 | 2008_001514__03 76 | 2008_001531__02 77 | 2008_001546__01 78 | 2008_001580__04 79 | 2008_001966__03 80 | 2008_001971__01 81 | 2008_002043__03 82 | 2008_002269__02 83 | 2008_002358__01 84 | 2008_002429__03 85 | 2008_002467__05 86 | 2008_002504__04 87 | 2008_002775__05 88 | 2008_002864__05 89 | 2008_003034__04 90 | 2008_003076__05 91 | 2008_003108__02 92 | 2008_003110__03 93 | 2008_003155__01 94 | 2008_003270__02 95 | 2008_003369__01 96 | 2008_003858__04 97 | 2008_003876__01 98 | 2008_003886__04 99 | 2008_003926__01 100 | 2008_003976__01 101 | 2008_004363__02 102 | 2008_004654__02 103 | 2008_004659__05 104 | 2008_004704__01 105 | 2008_004758__02 106 | 2008_004995__02 107 | 2008_005262__05 108 | 2008_005338__01 109 | 2008_005628__04 110 | 2008_005727__02 111 | 2008_005812__05 112 | 2008_005904__05 113 | 2008_006216__01 114 | 2008_006229__04 115 | 2008_006254__02 116 | 2008_006703__01 117 | 2008_007120__03 118 | 2008_007143__04 119 | 2008_007219__05 120 | 2008_007350__01 121 | 2008_007498__03 122 | 2008_007811__05 123 | 2008_007994__03 124 | 2008_008268__03 125 | 2008_008629__02 126 | 2008_008711__02 127 | 2008_008746__03 128 | 2009_000032__01 129 | 2009_000037__03 130 | 2009_000121__05 131 | 2009_000149__02 132 | 2009_000201__05 133 | 2009_000205__01 134 | 2009_000318__03 135 | 2009_000354__02 136 | 2009_000387__01 137 | 2009_000421__04 138 | 2009_000440__01 139 | 2009_000446__04 140 | 2009_000457__02 141 | 2009_000469__04 142 | 2009_000573__02 143 | 2009_000619__03 144 | 2009_000664__03 145 | 2009_000723__04 146 | 2009_000828__04 147 | 2009_000840__05 148 | 2009_000879__03 149 | 2009_000991__03 150 | 2009_000998__03 151 | 2009_001108__03 152 | 2009_001160__03 153 | 2009_001255__02 154 | 2009_001278__05 155 | 2009_001314__03 156 | 2009_001332__01 157 | 2009_001565__03 158 | 2009_001607__03 159 | 2009_001683__03 160 | 2009_001718__02 161 | 2009_001765__03 162 | 2009_001818__05 163 | 2009_001850__01 164 | 2009_001851__01 165 | 2009_001941__04 166 | 2009_002185__05 167 | 2009_002295__02 168 | 2009_002320__01 169 | 2009_002372__05 170 | 2009_002521__05 171 | 2009_002594__05 172 | 2009_002604__03 173 | 2009_002649__05 174 | 2009_002727__04 175 | 2009_002732__05 176 | 2009_002749__05 177 | 2009_002808__01 178 | 2009_002856__05 179 | 2009_002888__01 180 | 2009_002928__02 181 | 2009_003003__05 182 | 2009_003005__01 183 | 2009_003043__04 184 | 2009_003080__04 185 | 2009_003193__02 186 | 2009_003224__02 187 | 2009_003269__05 188 | 2009_003273__03 189 | 2009_003343__02 190 | 2009_003378__03 191 | 2009_003450__03 
192 | 2009_003498__03 193 | 2009_003504__04 194 | 2009_003517__05 195 | 2009_003640__03 196 | 2009_003696__01 197 | 2009_003707__04 198 | 2009_003806__01 199 | 2009_003858__03 200 | 2009_003971__02 201 | 2009_004021__03 202 | 2009_004084__03 203 | 2009_004125__04 204 | 2009_004247__05 205 | 2009_004324__05 206 | 2009_004509__03 207 | 2009_004540__03 208 | 2009_004568__03 209 | 2009_004579__05 210 | 2009_004635__04 211 | 2009_004653__01 212 | 2009_004848__02 213 | 2009_004882__02 214 | 2009_004886__03 215 | 2009_004895__03 216 | 2009_004969__01 217 | 2009_005038__05 218 | 2009_005137__03 219 | 2009_005156__02 220 | 2009_005189__01 221 | 2009_005190__05 222 | 2009_005260__03 223 | 2009_005262__03 224 | 2009_005302__05 225 | 2010_000065__02 226 | 2010_000083__02 227 | 2010_000084__04 228 | 2010_000238__01 229 | 2010_000241__03 230 | 2010_000272__04 231 | 2010_000342__02 232 | 2010_000426__05 233 | 2010_000572__01 234 | 2010_000622__01 235 | 2010_000814__03 236 | 2010_000906__04 237 | 2010_000961__03 238 | 2010_001016__03 239 | 2010_001017__01 240 | 2010_001024__01 241 | 2010_001036__04 242 | 2010_001061__03 243 | 2010_001069__03 244 | 2010_001174__01 245 | 2010_001367__02 246 | 2010_001367__05 247 | 2010_001448__01 248 | 2010_001830__05 249 | 2010_001995__03 250 | 2010_002017__05 251 | 2010_002030__02 252 | 2010_002142__03 253 | 2010_002147__01 254 | 2010_002150__04 255 | 2010_002200__01 256 | 2010_002310__01 257 | 2010_002536__02 258 | 2010_002546__04 259 | 2010_002693__02 260 | 2010_002939__01 261 | 2010_003127__01 262 | 2010_003132__01 263 | 2010_003168__03 264 | 2010_003362__03 265 | 2010_003365__01 266 | 2010_003418__03 267 | 2010_003468__05 268 | 2010_003473__03 269 | 2010_003495__01 270 | 2010_003547__04 271 | 2010_003716__01 272 | 2010_003771__03 273 | 2010_003781__05 274 | 2010_003820__03 275 | 2010_003912__02 276 | 2010_003915__01 277 | 2010_004041__04 278 | 2010_004056__05 279 | 2010_004208__04 280 | 2010_004314__01 281 | 2010_004419__01 282 | 2010_004520__05 283 | 2010_004529__05 284 | 2010_004551__05 285 | 2010_004556__03 286 | 2010_004559__03 287 | 2010_004662__04 288 | 2010_004772__04 289 | 2010_004828__05 290 | 2010_004994__03 291 | 2010_005252__04 292 | 2010_005401__04 293 | 2010_005428__03 294 | 2010_005496__05 295 | 2010_005531__03 296 | 2010_005534__01 297 | 2010_005582__05 298 | 2010_005664__02 299 | 2010_005705__04 300 | 2010_005718__01 301 | 2010_005762__05 302 | 2010_005877__01 303 | 2010_005888__01 304 | 2010_006034__01 305 | 2010_006070__02 306 | 2011_000066__05 307 | 2011_000112__03 308 | 2011_000185__03 309 | 2011_000234__04 310 | 2011_000238__04 311 | 2011_000412__02 312 | 2011_000435__04 313 | 2011_000456__03 314 | 2011_000482__03 315 | 2011_000585__02 316 | 2011_000669__03 317 | 2011_000747__05 318 | 2011_000874__01 319 | 2011_001114__01 320 | 2011_001161__04 321 | 2011_001263__01 322 | 2011_001287__03 323 | 2011_001407__01 324 | 2011_001421__03 325 | 2011_001434__01 326 | 2011_001589__04 327 | 2011_001624__01 328 | 2011_001793__04 329 | 2011_001880__01 330 | 2011_001988__02 331 | 2011_002064__02 332 | 2011_002098__05 333 | 2011_002223__02 334 | 2011_002295__03 335 | 2011_002327__01 336 | 2011_002515__01 337 | 2011_002675__01 338 | 2011_002713__02 339 | 2011_002754__04 340 | 2011_002863__05 341 | 2011_002929__01 342 | 2011_002975__04 343 | 2011_003003__02 344 | 2011_003030__03 345 | 2011_003145__03 346 | 2011_003271__05 347 | -------------------------------------------------------------------------------- /fewshot_data/data/splits/pascal/val/fold1.txt: 
-------------------------------------------------------------------------------- 1 | 2007_000452__09 2 | 2007_000464__10 3 | 2007_000491__10 4 | 2007_000663__06 5 | 2007_000663__07 6 | 2007_000727__06 7 | 2007_000727__07 8 | 2007_000804__09 9 | 2007_000830__09 10 | 2007_001299__10 11 | 2007_001321__07 12 | 2007_001457__09 13 | 2007_001677__09 14 | 2007_001717__09 15 | 2007_001763__08 16 | 2007_001774__08 17 | 2007_001884__06 18 | 2007_002268__08 19 | 2007_002387__10 20 | 2007_002445__08 21 | 2007_002470__08 22 | 2007_002539__06 23 | 2007_002597__08 24 | 2007_002643__07 25 | 2007_002903__10 26 | 2007_003011__09 27 | 2007_003051__07 28 | 2007_003101__06 29 | 2007_003106__08 30 | 2007_003137__06 31 | 2007_003143__07 32 | 2007_003169__08 33 | 2007_003195__06 34 | 2007_003201__10 35 | 2007_003503__06 36 | 2007_003503__07 37 | 2007_003621__06 38 | 2007_003711__06 39 | 2007_003786__06 40 | 2007_003841__10 41 | 2007_003917__07 42 | 2007_003991__08 43 | 2007_004193__09 44 | 2007_004392__09 45 | 2007_004405__09 46 | 2007_004510__09 47 | 2007_004712__09 48 | 2007_004856__08 49 | 2007_004866__08 50 | 2007_005074__07 51 | 2007_005114__10 52 | 2007_005296__07 53 | 2007_005331__07 54 | 2007_005460__08 55 | 2007_005547__07 56 | 2007_005547__10 57 | 2007_005844__09 58 | 2007_005845__08 59 | 2007_005911__06 60 | 2007_005978__06 61 | 2007_006035__07 62 | 2007_006086__09 63 | 2007_006241__09 64 | 2007_006260__08 65 | 2007_006277__07 66 | 2007_006348__09 67 | 2007_006553__09 68 | 2007_006761__10 69 | 2007_006841__10 70 | 2007_007414__07 71 | 2007_007417__08 72 | 2007_007524__08 73 | 2007_007815__07 74 | 2007_007818__07 75 | 2007_007996__09 76 | 2007_008106__09 77 | 2007_008110__09 78 | 2007_008543__09 79 | 2007_008722__10 80 | 2007_008747__06 81 | 2007_008815__08 82 | 2007_008897__09 83 | 2007_008973__10 84 | 2007_009015__06 85 | 2007_009015__07 86 | 2007_009068__09 87 | 2007_009084__09 88 | 2007_009096__07 89 | 2007_009221__08 90 | 2007_009245__10 91 | 2007_009346__08 92 | 2007_009392__06 93 | 2007_009392__07 94 | 2007_009413__09 95 | 2007_009521__09 96 | 2007_009764__06 97 | 2007_009794__08 98 | 2007_009897__10 99 | 2007_009923__08 100 | 2007_009938__07 101 | 2008_000009__10 102 | 2008_000073__10 103 | 2008_000075__06 104 | 2008_000107__09 105 | 2008_000149__09 106 | 2008_000182__08 107 | 2008_000345__08 108 | 2008_000401__08 109 | 2008_000464__08 110 | 2008_000501__07 111 | 2008_000673__09 112 | 2008_000853__08 113 | 2008_000919__10 114 | 2008_001078__08 115 | 2008_001433__08 116 | 2008_001439__09 117 | 2008_001513__08 118 | 2008_001640__08 119 | 2008_001715__09 120 | 2008_001885__08 121 | 2008_002152__08 122 | 2008_002205__06 123 | 2008_002212__07 124 | 2008_002379__09 125 | 2008_002521__09 126 | 2008_002623__08 127 | 2008_002681__08 128 | 2008_002778__10 129 | 2008_002958__07 130 | 2008_003141__06 131 | 2008_003141__07 132 | 2008_003333__07 133 | 2008_003477__09 134 | 2008_003499__08 135 | 2008_003577__07 136 | 2008_003777__06 137 | 2008_003821__09 138 | 2008_003846__07 139 | 2008_004069__07 140 | 2008_004339__07 141 | 2008_004552__07 142 | 2008_004612__09 143 | 2008_004701__10 144 | 2008_005097__10 145 | 2008_005105__10 146 | 2008_005245__07 147 | 2008_005676__06 148 | 2008_006008__09 149 | 2008_006063__10 150 | 2008_006254__07 151 | 2008_006325__08 152 | 2008_006341__08 153 | 2008_006480__08 154 | 2008_006528__10 155 | 2008_006554__06 156 | 2008_006986__07 157 | 2008_007025__10 158 | 2008_007031__10 159 | 2008_007048__09 160 | 2008_007123__10 161 | 2008_007194__09 162 | 2008_007273__10 163 | 
2008_007378__09 164 | 2008_007402__09 165 | 2008_007527__09 166 | 2008_007548__08 167 | 2008_007596__10 168 | 2008_007737__09 169 | 2008_007797__06 170 | 2008_007804__07 171 | 2008_007828__09 172 | 2008_008252__06 173 | 2008_008301__06 174 | 2008_008469__06 175 | 2008_008682__06 176 | 2009_000013__08 177 | 2009_000080__08 178 | 2009_000219__10 179 | 2009_000309__10 180 | 2009_000335__06 181 | 2009_000335__07 182 | 2009_000426__06 183 | 2009_000455__06 184 | 2009_000457__07 185 | 2009_000523__07 186 | 2009_000641__10 187 | 2009_000716__08 188 | 2009_000731__10 189 | 2009_000771__10 190 | 2009_000825__07 191 | 2009_000964__08 192 | 2009_001008__08 193 | 2009_001082__06 194 | 2009_001240__07 195 | 2009_001255__07 196 | 2009_001299__09 197 | 2009_001391__08 198 | 2009_001411__08 199 | 2009_001536__07 200 | 2009_001775__09 201 | 2009_001804__06 202 | 2009_001816__06 203 | 2009_001854__06 204 | 2009_002035__10 205 | 2009_002122__10 206 | 2009_002150__10 207 | 2009_002164__07 208 | 2009_002171__10 209 | 2009_002221__10 210 | 2009_002238__06 211 | 2009_002238__07 212 | 2009_002239__07 213 | 2009_002268__08 214 | 2009_002346__09 215 | 2009_002415__09 216 | 2009_002487__09 217 | 2009_002527__08 218 | 2009_002535__06 219 | 2009_002549__10 220 | 2009_002571__09 221 | 2009_002618__07 222 | 2009_002635__10 223 | 2009_002753__08 224 | 2009_002936__08 225 | 2009_002990__07 226 | 2009_003003__07 227 | 2009_003059__10 228 | 2009_003071__09 229 | 2009_003269__07 230 | 2009_003304__06 231 | 2009_003387__07 232 | 2009_003406__07 233 | 2009_003494__09 234 | 2009_003507__09 235 | 2009_003542__10 236 | 2009_003549__07 237 | 2009_003569__10 238 | 2009_003589__07 239 | 2009_003703__06 240 | 2009_003771__08 241 | 2009_003773__10 242 | 2009_003849__09 243 | 2009_003895__09 244 | 2009_003904__08 245 | 2009_004072__06 246 | 2009_004140__09 247 | 2009_004217__09 248 | 2009_004248__08 249 | 2009_004455__07 250 | 2009_004504__08 251 | 2009_004590__06 252 | 2009_004594__07 253 | 2009_004687__09 254 | 2009_004721__08 255 | 2009_004732__06 256 | 2009_004748__07 257 | 2009_004789__06 258 | 2009_004859__09 259 | 2009_004867__06 260 | 2009_005158__08 261 | 2009_005219__08 262 | 2009_005231__06 263 | 2010_000003__09 264 | 2010_000160__07 265 | 2010_000163__08 266 | 2010_000372__07 267 | 2010_000427__10 268 | 2010_000530__07 269 | 2010_000552__08 270 | 2010_000573__06 271 | 2010_000628__07 272 | 2010_000639__09 273 | 2010_000682__06 274 | 2010_000683__08 275 | 2010_000724__08 276 | 2010_000907__10 277 | 2010_000941__08 278 | 2010_000952__07 279 | 2010_001000__10 280 | 2010_001010__10 281 | 2010_001070__08 282 | 2010_001206__06 283 | 2010_001292__08 284 | 2010_001331__08 285 | 2010_001351__08 286 | 2010_001403__06 287 | 2010_001403__07 288 | 2010_001534__08 289 | 2010_001553__07 290 | 2010_001579__09 291 | 2010_001646__06 292 | 2010_001656__08 293 | 2010_001692__10 294 | 2010_001699__09 295 | 2010_001767__07 296 | 2010_001851__09 297 | 2010_001913__08 298 | 2010_002017__07 299 | 2010_002017__09 300 | 2010_002025__08 301 | 2010_002137__08 302 | 2010_002146__08 303 | 2010_002305__08 304 | 2010_002336__09 305 | 2010_002348__08 306 | 2010_002361__07 307 | 2010_002390__10 308 | 2010_002422__08 309 | 2010_002512__08 310 | 2010_002531__08 311 | 2010_002546__06 312 | 2010_002623__09 313 | 2010_002693__08 314 | 2010_002693__09 315 | 2010_002763__08 316 | 2010_002763__10 317 | 2010_002868__06 318 | 2010_002900__08 319 | 2010_002902__07 320 | 2010_002921__09 321 | 2010_002929__07 322 | 2010_002988__07 323 | 2010_003123__07 324 | 
2010_003183__10 325 | 2010_003231__07 326 | 2010_003239__10 327 | 2010_003275__08 328 | 2010_003276__07 329 | 2010_003293__06 330 | 2010_003302__09 331 | 2010_003325__09 332 | 2010_003381__07 333 | 2010_003402__08 334 | 2010_003409__09 335 | 2010_003446__07 336 | 2010_003453__07 337 | 2010_003468__08 338 | 2010_003531__09 339 | 2010_003675__08 340 | 2010_003746__07 341 | 2010_003758__08 342 | 2010_003764__08 343 | 2010_003768__07 344 | 2010_003772__06 345 | 2010_003781__08 346 | 2010_003813__07 347 | 2010_003854__07 348 | 2010_003971__08 349 | 2010_003971__09 350 | 2010_004104__08 351 | 2010_004120__08 352 | 2010_004320__08 353 | 2010_004322__10 354 | 2010_004348__06 355 | 2010_004369__08 356 | 2010_004472__07 357 | 2010_004479__08 358 | 2010_004635__10 359 | 2010_004763__09 360 | 2010_004783__09 361 | 2010_004789__10 362 | 2010_004815__08 363 | 2010_004825__09 364 | 2010_004861__08 365 | 2010_004946__07 366 | 2010_005013__07 367 | 2010_005021__08 368 | 2010_005021__09 369 | 2010_005063__06 370 | 2010_005108__08 371 | 2010_005118__06 372 | 2010_005160__06 373 | 2010_005166__10 374 | 2010_005284__06 375 | 2010_005344__08 376 | 2010_005421__08 377 | 2010_005432__07 378 | 2010_005501__07 379 | 2010_005508__08 380 | 2010_005606__08 381 | 2010_005709__08 382 | 2010_005718__07 383 | 2010_005860__07 384 | 2010_005899__08 385 | 2010_006070__07 386 | 2011_000178__06 387 | 2011_000226__09 388 | 2011_000239__06 389 | 2011_000248__06 390 | 2011_000312__06 391 | 2011_000338__09 392 | 2011_000419__08 393 | 2011_000503__07 394 | 2011_000548__10 395 | 2011_000566__10 396 | 2011_000607__09 397 | 2011_000661__08 398 | 2011_000661__09 399 | 2011_000780__08 400 | 2011_000789__08 401 | 2011_000809__09 402 | 2011_000813__08 403 | 2011_000813__09 404 | 2011_000830__06 405 | 2011_000843__09 406 | 2011_000888__06 407 | 2011_000900__07 408 | 2011_000969__06 409 | 2011_001047__10 410 | 2011_001064__06 411 | 2011_001071__09 412 | 2011_001110__07 413 | 2011_001159__10 414 | 2011_001232__10 415 | 2011_001292__08 416 | 2011_001341__06 417 | 2011_001346__09 418 | 2011_001447__09 419 | 2011_001530__10 420 | 2011_001534__08 421 | 2011_001546__10 422 | 2011_001567__09 423 | 2011_001597__08 424 | 2011_001601__08 425 | 2011_001607__08 426 | 2011_001665__09 427 | 2011_001708__10 428 | 2011_001775__08 429 | 2011_001782__10 430 | 2011_001812__09 431 | 2011_002041__09 432 | 2011_002064__07 433 | 2011_002124__09 434 | 2011_002200__09 435 | 2011_002298__09 436 | 2011_002322__07 437 | 2011_002343__09 438 | 2011_002358__09 439 | 2011_002391__09 440 | 2011_002509__09 441 | 2011_002592__07 442 | 2011_002644__09 443 | 2011_002685__08 444 | 2011_002812__07 445 | 2011_002885__10 446 | 2011_003011__09 447 | 2011_003019__07 448 | 2011_003019__10 449 | 2011_003055__07 450 | 2011_003103__09 451 | 2011_003114__06 452 | -------------------------------------------------------------------------------- /fewshot_data/data/splits/pascal/val/fold3.txt: -------------------------------------------------------------------------------- 1 | 2007_000042__19 2 | 2007_000123__19 3 | 2007_000175__17 4 | 2007_000187__20 5 | 2007_000452__18 6 | 2007_000559__20 7 | 2007_000629__19 8 | 2007_000636__19 9 | 2007_000661__18 10 | 2007_000676__17 11 | 2007_000804__18 12 | 2007_000925__17 13 | 2007_001154__18 14 | 2007_001175__20 15 | 2007_001408__16 16 | 2007_001430__16 17 | 2007_001430__20 18 | 2007_001457__18 19 | 2007_001458__18 20 | 2007_001585__18 21 | 2007_001594__17 22 | 2007_001678__20 23 | 2007_001717__20 24 | 2007_001733__17 25 | 2007_001763__18 26 | 
2007_001763__20 27 | 2007_002119__20 28 | 2007_002132__20 29 | 2007_002268__18 30 | 2007_002284__16 31 | 2007_002378__16 32 | 2007_002426__18 33 | 2007_002427__18 34 | 2007_002565__19 35 | 2007_002618__17 36 | 2007_002648__17 37 | 2007_002728__19 38 | 2007_003011__18 39 | 2007_003011__20 40 | 2007_003169__18 41 | 2007_003367__16 42 | 2007_003499__19 43 | 2007_003506__16 44 | 2007_003530__18 45 | 2007_003587__19 46 | 2007_003714__17 47 | 2007_003848__19 48 | 2007_003957__19 49 | 2007_004190__20 50 | 2007_004193__20 51 | 2007_004275__16 52 | 2007_004281__19 53 | 2007_004483__19 54 | 2007_004510__20 55 | 2007_004558__16 56 | 2007_004649__19 57 | 2007_004712__16 58 | 2007_004969__17 59 | 2007_005469__17 60 | 2007_005626__19 61 | 2007_005689__19 62 | 2007_005813__16 63 | 2007_005857__16 64 | 2007_005915__17 65 | 2007_006171__18 66 | 2007_006348__20 67 | 2007_006373__18 68 | 2007_006678__17 69 | 2007_006680__19 70 | 2007_006802__19 71 | 2007_007130__20 72 | 2007_007165__17 73 | 2007_007168__19 74 | 2007_007195__19 75 | 2007_007196__20 76 | 2007_007203__20 77 | 2007_007417__18 78 | 2007_007534__17 79 | 2007_007624__16 80 | 2007_007795__16 81 | 2007_007881__19 82 | 2007_007996__18 83 | 2007_008204__20 84 | 2007_008260__18 85 | 2007_008339__19 86 | 2007_008374__20 87 | 2007_008543__18 88 | 2007_008547__16 89 | 2007_009068__18 90 | 2007_009252__18 91 | 2007_009320__17 92 | 2007_009419__16 93 | 2007_009446__20 94 | 2007_009521__18 95 | 2007_009521__20 96 | 2007_009592__18 97 | 2007_009655__18 98 | 2007_009684__18 99 | 2007_009750__16 100 | 2008_000016__20 101 | 2008_000149__18 102 | 2008_000270__18 103 | 2008_000391__16 104 | 2008_000589__18 105 | 2008_000657__19 106 | 2008_001078__16 107 | 2008_001283__16 108 | 2008_001688__16 109 | 2008_001688__20 110 | 2008_001966__16 111 | 2008_002273__16 112 | 2008_002379__16 113 | 2008_002464__20 114 | 2008_002536__17 115 | 2008_002680__20 116 | 2008_002900__19 117 | 2008_002929__18 118 | 2008_003003__20 119 | 2008_003026__20 120 | 2008_003105__19 121 | 2008_003135__16 122 | 2008_003676__16 123 | 2008_003709__18 124 | 2008_003733__18 125 | 2008_003885__20 126 | 2008_004172__18 127 | 2008_004212__19 128 | 2008_004279__20 129 | 2008_004367__19 130 | 2008_004453__17 131 | 2008_004477__16 132 | 2008_004562__18 133 | 2008_004610__19 134 | 2008_004621__17 135 | 2008_004754__20 136 | 2008_004854__17 137 | 2008_004910__20 138 | 2008_005089__20 139 | 2008_005217__16 140 | 2008_005242__16 141 | 2008_005254__20 142 | 2008_005439__20 143 | 2008_005445__20 144 | 2008_005544__19 145 | 2008_005633__17 146 | 2008_005680__16 147 | 2008_006055__19 148 | 2008_006159__20 149 | 2008_006327__17 150 | 2008_006523__19 151 | 2008_006553__19 152 | 2008_006752__19 153 | 2008_006784__18 154 | 2008_006835__17 155 | 2008_007497__17 156 | 2008_007527__20 157 | 2008_007677__17 158 | 2008_007814__17 159 | 2008_007828__20 160 | 2008_008103__18 161 | 2008_008221__19 162 | 2008_008434__16 163 | 2009_000022__19 164 | 2009_000039__17 165 | 2009_000087__18 166 | 2009_000096__18 167 | 2009_000136__20 168 | 2009_000242__18 169 | 2009_000391__20 170 | 2009_000418__16 171 | 2009_000418__18 172 | 2009_000487__18 173 | 2009_000488__16 174 | 2009_000488__20 175 | 2009_000628__19 176 | 2009_000675__17 177 | 2009_000704__20 178 | 2009_000712__19 179 | 2009_000732__18 180 | 2009_000845__19 181 | 2009_000924__17 182 | 2009_001300__19 183 | 2009_001333__19 184 | 2009_001363__20 185 | 2009_001505__17 186 | 2009_001644__16 187 | 2009_001644__18 188 | 2009_001644__20 189 | 2009_001684__16 190 | 2009_001731__18 191 
| 2009_001768__17 192 | 2009_001775__16 193 | 2009_001775__18 194 | 2009_001991__17 195 | 2009_002082__17 196 | 2009_002094__20 197 | 2009_002202__19 198 | 2009_002265__19 199 | 2009_002291__19 200 | 2009_002346__18 201 | 2009_002366__20 202 | 2009_002390__18 203 | 2009_002487__16 204 | 2009_002562__20 205 | 2009_002568__19 206 | 2009_002571__16 207 | 2009_002571__18 208 | 2009_002573__20 209 | 2009_002584__16 210 | 2009_002638__19 211 | 2009_002732__18 212 | 2009_002887__19 213 | 2009_002982__19 214 | 2009_003105__19 215 | 2009_003123__18 216 | 2009_003299__19 217 | 2009_003311__19 218 | 2009_003433__19 219 | 2009_003523__20 220 | 2009_003551__20 221 | 2009_003564__16 222 | 2009_003564__18 223 | 2009_003607__18 224 | 2009_003666__17 225 | 2009_003857__20 226 | 2009_003895__18 227 | 2009_003895__20 228 | 2009_003938__19 229 | 2009_004099__18 230 | 2009_004140__18 231 | 2009_004255__19 232 | 2009_004298__18 233 | 2009_004687__18 234 | 2009_004730__19 235 | 2009_004799__19 236 | 2009_004993__18 237 | 2009_004993__20 238 | 2009_005148__19 239 | 2009_005220__19 240 | 2010_000256__18 241 | 2010_000284__18 242 | 2010_000309__17 243 | 2010_000318__20 244 | 2010_000330__16 245 | 2010_000639__16 246 | 2010_000738__20 247 | 2010_000764__19 248 | 2010_001011__17 249 | 2010_001079__17 250 | 2010_001104__19 251 | 2010_001149__18 252 | 2010_001151__19 253 | 2010_001246__16 254 | 2010_001256__17 255 | 2010_001327__18 256 | 2010_001367__20 257 | 2010_001522__17 258 | 2010_001557__17 259 | 2010_001577__17 260 | 2010_001699__16 261 | 2010_001734__19 262 | 2010_001752__20 263 | 2010_001767__18 264 | 2010_001773__16 265 | 2010_001851__16 266 | 2010_001951__19 267 | 2010_001962__18 268 | 2010_002106__17 269 | 2010_002137__16 270 | 2010_002137__18 271 | 2010_002232__17 272 | 2010_002531__18 273 | 2010_002682__19 274 | 2010_002921__20 275 | 2010_003014__18 276 | 2010_003123__16 277 | 2010_003302__16 278 | 2010_003514__19 279 | 2010_003541__17 280 | 2010_003597__18 281 | 2010_003781__16 282 | 2010_003956__19 283 | 2010_004149__19 284 | 2010_004226__17 285 | 2010_004382__16 286 | 2010_004479__20 287 | 2010_004757__16 288 | 2010_004757__18 289 | 2010_004783__18 290 | 2010_004825__16 291 | 2010_004857__20 292 | 2010_004951__19 293 | 2010_004980__19 294 | 2010_005180__18 295 | 2010_005187__16 296 | 2010_005305__20 297 | 2010_005606__18 298 | 2010_005706__19 299 | 2010_005719__17 300 | 2010_005727__19 301 | 2010_005788__17 302 | 2010_005860__16 303 | 2010_005871__19 304 | 2010_005991__18 305 | 2010_006054__19 306 | 2011_000070__18 307 | 2011_000173__18 308 | 2011_000283__19 309 | 2011_000291__19 310 | 2011_000310__18 311 | 2011_000436__17 312 | 2011_000521__19 313 | 2011_000747__16 314 | 2011_001005__18 315 | 2011_001060__19 316 | 2011_001281__19 317 | 2011_001350__17 318 | 2011_001567__18 319 | 2011_001601__18 320 | 2011_001614__19 321 | 2011_001674__18 322 | 2011_001713__16 323 | 2011_001713__18 324 | 2011_001726__20 325 | 2011_001794__18 326 | 2011_001862__18 327 | 2011_001863__16 328 | 2011_001910__20 329 | 2011_002124__18 330 | 2011_002156__20 331 | 2011_002178__17 332 | 2011_002247__19 333 | 2011_002379__19 334 | 2011_002391__18 335 | 2011_002532__20 336 | 2011_002535__19 337 | 2011_002644__18 338 | 2011_002644__20 339 | 2011_002879__18 340 | 2011_002879__20 341 | 2011_003103__16 342 | 2011_003103__18 343 | 2011_003146__19 344 | 2011_003182__18 345 | 2011_003197__19 346 | 2011_003256__18 347 | -------------------------------------------------------------------------------- /fewshot_data/model/base/conv4d.py: 
-------------------------------------------------------------------------------- 1 | r""" Implementation of center-pivot 4D convolution """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class CenterPivotConv4d(nn.Module): 8 | r""" CenterPivot 4D conv""" 9 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True): 10 | super(CenterPivotConv4d, self).__init__() 11 | 12 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size[:2], stride=stride[:2], 13 | bias=bias, padding=padding[:2]) 14 | self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size[2:], stride=stride[2:], 15 | bias=bias, padding=padding[2:]) 16 | 17 | self.stride34 = stride[2:] 18 | self.kernel_size = kernel_size 19 | self.stride = stride 20 | self.padding = padding 21 | self.idx_initialized = False 22 | 23 | def prune(self, ct): 24 | bsz, ch, ha, wa, hb, wb = ct.size() 25 | if not self.idx_initialized: 26 | idxh = torch.arange(start=0, end=hb, step=self.stride[2:][0], device=ct.device) 27 | idxw = torch.arange(start=0, end=wb, step=self.stride[2:][1], device=ct.device) 28 | self.len_h = len(idxh) 29 | self.len_w = len(idxw) 30 | self.idx = (idxw.repeat(self.len_h, 1) + idxh.repeat(self.len_w, 1).t() * wb).view(-1) 31 | self.idx_initialized = True 32 | ct_pruned = ct.view(bsz, ch, ha, wa, -1).index_select(4, self.idx).view(bsz, ch, ha, wa, self.len_h, self.len_w) 33 | 34 | return ct_pruned 35 | 36 | def forward(self, x): 37 | if self.stride[2:][-1] > 1: 38 | out1 = self.prune(x) 39 | else: 40 | out1 = x 41 | bsz, inch, ha, wa, hb, wb = out1.size() 42 | out1 = out1.permute(0, 4, 5, 1, 2, 3).contiguous().view(-1, inch, ha, wa) 43 | out1 = self.conv1(out1) 44 | outch, o_ha, o_wa = out1.size(-3), out1.size(-2), out1.size(-1) 45 | out1 = out1.view(bsz, hb, wb, outch, o_ha, o_wa).permute(0, 3, 4, 5, 1, 2).contiguous() 46 | 47 | bsz, inch, ha, wa, hb, wb = x.size() 48 | out2 = x.permute(0, 2, 3, 1, 4, 5).contiguous().view(-1, inch, hb, wb) 49 | out2 = self.conv2(out2) 50 | outch, o_hb, o_wb = out2.size(-3), out2.size(-2), out2.size(-1) 51 | out2 = out2.view(bsz, ha, wa, outch, o_hb, o_wb).permute(0, 3, 1, 2, 4, 5).contiguous() 52 | 53 | if out1.size()[-2:] != out2.size()[-2:] and self.padding[-2:] == (0, 0): 54 | out1 = out1.view(bsz, outch, o_ha, o_wa, -1).sum(dim=-1) 55 | out2 = out2.squeeze() 56 | 57 | y = out1 + out2 58 | return y 59 | -------------------------------------------------------------------------------- /fewshot_data/model/base/correlation.py: -------------------------------------------------------------------------------- 1 | r""" Provides functions that builds/manipulates correlation tensors """ 2 | import torch 3 | 4 | 5 | class Correlation: 6 | 7 | @classmethod 8 | def multilayer_correlation(cls, query_feats, support_feats, stack_ids): 9 | eps = 1e-5 10 | 11 | corrs = [] 12 | for idx, (query_feat, support_feat) in enumerate(zip(query_feats, support_feats)): 13 | bsz, ch, hb, wb = support_feat.size() 14 | support_feat = support_feat.view(bsz, ch, -1) 15 | support_feat = support_feat / (support_feat.norm(dim=1, p=2, keepdim=True) + eps) 16 | 17 | bsz, ch, ha, wa = query_feat.size() 18 | query_feat = query_feat.view(bsz, ch, -1) 19 | query_feat = query_feat / (query_feat.norm(dim=1, p=2, keepdim=True) + eps) 20 | 21 | corr = torch.bmm(query_feat.transpose(1, 2), support_feat).view(bsz, ha, wa, hb, wb) 22 | corr = corr.clamp(min=0) 23 | corrs.append(corr) 24 | 25 | corr_l4 = torch.stack(corrs[-stack_ids[0]:]).transpose(0, 1).contiguous() 26 | corr_l3 = 
torch.stack(corrs[-stack_ids[1]:-stack_ids[0]]).transpose(0, 1).contiguous() 27 | corr_l2 = torch.stack(corrs[-stack_ids[2]:-stack_ids[1]]).transpose(0, 1).contiguous() 28 | 29 | return [corr_l4, corr_l3, corr_l2] 30 | -------------------------------------------------------------------------------- /fewshot_data/model/base/feature.py: -------------------------------------------------------------------------------- 1 | r""" Extracts intermediate features from given backbone network & layer ids """ 2 | 3 | 4 | def extract_feat_vgg(img, backbone, feat_ids, bottleneck_ids=None, lids=None): 5 | r""" Extract intermediate features from VGG """ 6 | feats = [] 7 | feat = img 8 | for lid, module in enumerate(backbone.features): 9 | feat = module(feat) 10 | if lid in feat_ids: 11 | feats.append(feat.clone()) 12 | return feats 13 | 14 | 15 | def extract_feat_res(img, backbone, feat_ids, bottleneck_ids, lids): 16 | r""" Extract intermediate features from ResNet""" 17 | feats = [] 18 | 19 | # Layer 0 20 | feat = backbone.conv1.forward(img) 21 | feat = backbone.bn1.forward(feat) 22 | feat = backbone.relu.forward(feat) 23 | feat = backbone.maxpool.forward(feat) 24 | 25 | # Layer 1-4 26 | for hid, (bid, lid) in enumerate(zip(bottleneck_ids, lids)): 27 | res = feat 28 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv1.forward(feat) 29 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn1.forward(feat) 30 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 31 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv2.forward(feat) 32 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn2.forward(feat) 33 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 34 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv3.forward(feat) 35 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn3.forward(feat) 36 | 37 | if bid == 0: 38 | res = backbone.__getattr__('layer%d' % lid)[bid].downsample.forward(res) 39 | 40 | feat += res 41 | 42 | if hid + 1 in feat_ids: 43 | feats.append(feat.clone()) 44 | 45 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 46 | 47 | return feats -------------------------------------------------------------------------------- /fewshot_data/model/hsnet.py: -------------------------------------------------------------------------------- 1 | r""" Hypercorrelation Squeeze Network """ 2 | from functools import reduce 3 | from operator import add 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision.models import resnet 9 | from torchvision.models import vgg 10 | 11 | from fewshot_data.model.base.feature import extract_feat_vgg, extract_feat_res 12 | from fewshot_data.model.base.correlation import Correlation 13 | from fewshot_data.model.learner import HPNLearner 14 | 15 | 16 | class HypercorrSqueezeNetwork(nn.Module): 17 | def __init__(self, backbone, use_original_imgsize): 18 | super(HypercorrSqueezeNetwork, self).__init__() 19 | 20 | # 1. 
Backbone network initialization 21 | self.backbone_type = backbone 22 | self.use_original_imgsize = use_original_imgsize 23 | if backbone == 'vgg16': 24 | self.backbone = vgg.vgg16(pretrained=True) 25 | self.feat_ids = [17, 19, 21, 24, 26, 28, 30] 26 | self.extract_feats = extract_feat_vgg 27 | nbottlenecks = [2, 2, 3, 3, 3, 1] 28 | elif backbone == 'resnet50': 29 | self.backbone = resnet.resnet50(pretrained=True) 30 | self.feat_ids = list(range(4, 17)) 31 | self.extract_feats = extract_feat_res 32 | nbottlenecks = [3, 4, 6, 3] 33 | elif backbone == 'resnet101': 34 | self.backbone = resnet.resnet101(pretrained=True) 35 | self.feat_ids = list(range(4, 34)) 36 | self.extract_feats = extract_feat_res 37 | nbottlenecks = [3, 4, 23, 3] 38 | else: 39 | raise Exception('Unavailable backbone: %s' % backbone) 40 | 41 | self.bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), nbottlenecks))) 42 | self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(nbottlenecks)]) 43 | self.stack_ids = torch.tensor(self.lids).bincount().__reversed__().cumsum(dim=0)[:3] 44 | self.backbone.eval() 45 | self.hpn_learner = HPNLearner(list(reversed(nbottlenecks[-3:]))) 46 | self.cross_entropy_loss = nn.CrossEntropyLoss() 47 | 48 | def forward(self, query_img, support_img, support_mask): 49 | with torch.no_grad(): 50 | query_feats = self.extract_feats(query_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids) 51 | support_feats = self.extract_feats(support_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids) 52 | support_feats = self.mask_feature(support_feats, support_mask.clone()) 53 | corr = Correlation.multilayer_correlation(query_feats, support_feats, self.stack_ids) 54 | 55 | logit_mask = self.hpn_learner(corr) 56 | if not self.use_original_imgsize: 57 | logit_mask = F.interpolate(logit_mask, support_img.size()[2:], mode='bilinear', align_corners=True) 58 | 59 | return logit_mask 60 | 61 | def mask_feature(self, features, support_mask): 62 | for idx, feature in enumerate(features): 63 | mask = F.interpolate(support_mask.unsqueeze(1).float(), feature.size()[2:], mode='bilinear', align_corners=True) 64 | features[idx] = features[idx] * mask 65 | return features 66 | 67 | def predict_mask_nshot(self, batch, nshot): 68 | 69 | # Perform multiple prediction given (nshot) number of different support sets 70 | logit_mask_agg = 0 71 | for s_idx in range(nshot): 72 | logit_mask = self(batch['query_img'], batch['support_imgs'][:, s_idx], batch['support_masks'][:, s_idx]) 73 | 74 | if self.use_original_imgsize: 75 | org_qry_imsize = tuple([batch['org_query_imsize'][1].item(), batch['org_query_imsize'][0].item()]) 76 | logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True) 77 | 78 | logit_mask_agg += logit_mask.argmax(dim=1).clone() 79 | if nshot == 1: return logit_mask_agg 80 | 81 | # Average & quantize predictions given threshold (=0.5) 82 | bsz = logit_mask_agg.size(0) 83 | max_vote = logit_mask_agg.view(bsz, -1).max(dim=1)[0] 84 | max_vote = torch.stack([max_vote, torch.ones_like(max_vote).long()]) 85 | max_vote = max_vote.max(dim=0)[0].view(bsz, 1, 1) 86 | pred_mask = logit_mask_agg.float() / max_vote 87 | pred_mask[pred_mask < 0.5] = 0 88 | pred_mask[pred_mask >= 0.5] = 1 89 | 90 | return pred_mask 91 | 92 | def compute_objective(self, logit_mask, gt_mask): 93 | bsz = logit_mask.size(0) 94 | logit_mask = logit_mask.view(bsz, 2, -1) 95 | gt_mask = gt_mask.view(bsz, -1).long() 96 | 97 | return self.cross_entropy_loss(logit_mask, gt_mask) 98 | 
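# Worked example (illustrative, derived from the constructor above) of the
# stack_ids computation for the resnet50 backbone:
#   nbottlenecks = [3, 4, 6, 3]  ->  lids = [1]*3 + [2]*4 + [3]*6 + [4]*3
#   torch.tensor(lids).bincount()        -> [0, 3, 4, 6, 3]
#   .__reversed__()                      -> [3, 6, 4, 3, 0]
#   .cumsum(dim=0)[:3]                   -> [3, 9, 13]
# So of the 13 intermediate features extracted per image, the last 3 come from
# layer4, the 6 before that from layer3, and the 4 before that from layer2 --
# which is exactly how Correlation.multilayer_correlation slices `corrs` into
# corr_l4 / corr_l3 / corr_l2.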
99 | def train_mode(self): 100 | self.train() 101 | self.backbone.eval() # to prevent BN from learning data statistics with exponential averaging 102 | -------------------------------------------------------------------------------- /fewshot_data/model/learner.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from fewshot_data.model.base.conv4d import CenterPivotConv4d as Conv4d 6 | 7 | 8 | class HPNLearner(nn.Module): 9 | def __init__(self, inch): 10 | super(HPNLearner, self).__init__() 11 | 12 | def make_building_block(in_channel, out_channels, kernel_sizes, spt_strides, group=4): 13 | assert len(out_channels) == len(kernel_sizes) == len(spt_strides) 14 | 15 | building_block_layers = [] 16 | for idx, (outch, ksz, stride) in enumerate(zip(out_channels, kernel_sizes, spt_strides)): 17 | inch = in_channel if idx == 0 else out_channels[idx - 1] 18 | ksz4d = (ksz,) * 4 19 | str4d = (1, 1) + (stride,) * 2 20 | pad4d = (ksz // 2,) * 4 21 | 22 | building_block_layers.append(Conv4d(inch, outch, ksz4d, str4d, pad4d)) 23 | building_block_layers.append(nn.GroupNorm(group, outch)) 24 | building_block_layers.append(nn.ReLU(inplace=True)) 25 | 26 | return nn.Sequential(*building_block_layers) 27 | 28 | outch1, outch2, outch3 = 16, 64, 128 29 | 30 | # Squeezing building blocks 31 | self.encoder_layer4 = make_building_block(inch[0], [outch1, outch2, outch3], [3, 3, 3], [2, 2, 2]) 32 | self.encoder_layer3 = make_building_block(inch[1], [outch1, outch2, outch3], [5, 3, 3], [4, 2, 2]) 33 | self.encoder_layer2 = make_building_block(inch[2], [outch1, outch2, outch3], [5, 5, 3], [4, 4, 2]) 34 | 35 | # Mixing building blocks 36 | self.encoder_layer4to3 = make_building_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1]) 37 | self.encoder_layer3to2 = make_building_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1]) 38 | 39 | # Decoder layers 40 | self.decoder1 = nn.Sequential(nn.Conv2d(outch3, outch3, (3, 3), padding=(1, 1), bias=True), 41 | nn.ReLU(), 42 | nn.Conv2d(outch3, outch2, (3, 3), padding=(1, 1), bias=True), 43 | nn.ReLU()) 44 | 45 | self.decoder2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True), 46 | nn.ReLU(), 47 | nn.Conv2d(outch2, 2, (3, 3), padding=(1, 1), bias=True)) 48 | 49 | def interpolate_support_dims(self, hypercorr, spatial_size=None): 50 | bsz, ch, ha, wa, hb, wb = hypercorr.size() 51 | hypercorr = hypercorr.permute(0, 4, 5, 1, 2, 3).contiguous().view(bsz * hb * wb, ch, ha, wa) 52 | hypercorr = F.interpolate(hypercorr, spatial_size, mode='bilinear', align_corners=True) 53 | o_hb, o_wb = spatial_size 54 | hypercorr = hypercorr.view(bsz, hb, wb, ch, o_hb, o_wb).permute(0, 3, 4, 5, 1, 2).contiguous() 55 | return hypercorr 56 | 57 | def forward(self, hypercorr_pyramid): 58 | 59 | # Encode hypercorrelations from each layer (Squeezing building blocks) 60 | hypercorr_sqz4 = self.encoder_layer4(hypercorr_pyramid[0]) 61 | hypercorr_sqz3 = self.encoder_layer3(hypercorr_pyramid[1]) 62 | hypercorr_sqz2 = self.encoder_layer2(hypercorr_pyramid[2]) 63 | 64 | # Propagate encoded 4D-tensor (Mixing building blocks) 65 | hypercorr_sqz4 = self.interpolate_support_dims(hypercorr_sqz4, hypercorr_sqz3.size()[-4:-2]) 66 | hypercorr_mix43 = hypercorr_sqz4 + hypercorr_sqz3 67 | hypercorr_mix43 = self.encoder_layer4to3(hypercorr_mix43) 68 | 69 | hypercorr_mix43 = self.interpolate_support_dims(hypercorr_mix43, hypercorr_sqz2.size()[-4:-2]) 70 | hypercorr_mix432 = 
hypercorr_mix43 + hypercorr_sqz2 71 | hypercorr_mix432 = self.encoder_layer3to2(hypercorr_mix432) 72 | 73 | bsz, ch, ha, wa, hb, wb = hypercorr_mix432.size() 74 | hypercorr_encoded = hypercorr_mix432.view(bsz, ch, ha, wa, -1).mean(dim=-1) 75 | 76 | # Decode the encoded 4D-tensor 77 | hypercorr_decoded = self.decoder1(hypercorr_encoded) 78 | upsample_size = (hypercorr_decoded.size(-1) * 2,) * 2 79 | hypercorr_decoded = F.interpolate(hypercorr_decoded, upsample_size, mode='bilinear', align_corners=True) 80 | logit_mask = self.decoder2(hypercorr_decoded) 81 | 82 | return logit_mask 83 | -------------------------------------------------------------------------------- /fewshot_data/sbatch_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATE=`date -d now` 3 | EXP=hsnet 4 | NGPU=4 5 | partition=g24 6 | JOB=fewshot_${EXP} 7 | SAVE_ROOT="save/${EXP}" 8 | SCRIPT_ROOT="sweep_scripts/${EXP}" 9 | mkdir -p $SCRIPT_ROOT 10 | NCPU=$((NGPU * 10)) 11 | qos=normal # high normal low 12 | 13 | function print_append { 14 | echo $@ >> $SCRIPT 15 | } 16 | 17 | function slurm_append { 18 | echo $@ >> $SLURM 19 | } 20 | 21 | function print_setup { 22 | SAVE="${SAVE_ROOT}/${JOB}" 23 | SCRIPT="${SCRIPT_ROOT}/${JOB}.sh" 24 | SLURM="${SCRIPT_ROOT}/${JOB}.slrm" 25 | mkdir -p $SAVE 26 | echo `date -d now` $SAVE >> 'submitted.txt' 27 | echo "#!/bin/bash" > $SLURM 28 | slurm_append "#SBATCH --job-name=job1111_${JOB}" 29 | slurm_append "#SBATCH --output=${SAVE}/stdout.txt" 30 | slurm_append "#SBATCH --error=${SAVE}/stderr.txt" 31 | slurm_append "#SBATCH --open-mode=append" 32 | slurm_append "#SBATCH --signal=B:USR1@120" 33 | 34 | slurm_append "#SBATCH -p ${partition}" 35 | slurm_append "#SBATCH --gres=gpu:${NGPU}" 36 | slurm_append "#SBATCH -c ${NCPU}" 37 | slurm_append "#SBATCH -t 02-00" 38 | # slurm_append "#SBATCH -t 01-00" 39 | # slurm_append "#SBATCH -t 00-06" 40 | slurm_append "#SBATCH --qos=${qos}" 41 | slurm_append "srun sh ${SCRIPT}" 42 | 43 | echo "#!/bin/bash" > $SCRIPT 44 | print_append "trap_handler () {" 45 | print_append "echo \"Caught signal: \" \$1" 46 | print_append "# SIGTERM must be bypassed" 47 | print_append "if [ "$1" = "TERM" ]; then" 48 | print_append "echo \"bypass sigterm\"" 49 | print_append "else" 50 | print_append "# Submit a new job to the queue" 51 | print_append "echo \"Requeuing \" \$SLURM_JOB_ID" 52 | print_append "scontrol requeue \$SLURM_JOB_ID" 53 | print_append "fi" 54 | print_append "}" 55 | print_append "trap 'trap_handler USR1' USR1" 56 | print_append "trap 'trap_handler TERM' TERM" 57 | 58 | print_append "{" 59 | print_append "source activate pytorch" 60 | print_append "conda activate pytorch" 61 | print_append "export PATH=/home/boyili/programfiles/anaconda3/envs/pytorch/bin:/home/boyili/programfiles/anaconda3/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin" 62 | print_append "which python" 63 | print_append "echo \$PATH" 64 | print_append "export NCCL_DEBUG=INFO" 65 | print_append "export PYTHONFAULTHANDLER=1" 66 | 67 | echo $JOB 68 | } 69 | 70 | function print_after { 71 | print_append "} & " 72 | print_append "wait \$!" 73 | print_append "sleep 610 &" 74 | print_append "wait \$!" 
75 | } 76 | 77 | print_setup 78 | print_append stdbuf -o0 -e0 \ 79 | python train.py --log 'log_pascal' 80 | print_after 81 | sbatch $SLURM 82 | -------------------------------------------------------------------------------- /fewshot_data/test.py: -------------------------------------------------------------------------------- 1 | r""" Hypercorrelation Squeeze testing code """ 2 | import argparse 3 | 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | import torch 7 | 8 | from fewshot_data.model.hsnet import HypercorrSqueezeNetwork 9 | from fewshot_data.common.logger import Logger, AverageMeter 10 | from fewshot_data.common.vis import Visualizer 11 | from fewshot_data.common.evaluation import Evaluator 12 | from fewshot_data.common import utils 13 | from fewshot_data.data.dataset import FSSDataset 14 | 15 | 16 | def test(model, dataloader, nshot): 17 | r""" Test HSNet """ 18 | 19 | # Freeze randomness during testing for reproducibility 20 | utils.fix_randseed(0) 21 | average_meter = AverageMeter(dataloader.dataset) 22 | 23 | for idx, batch in enumerate(dataloader): 24 | 25 | # 1. Hypercorrelation Squeeze Networks forward pass 26 | batch = utils.to_cuda(batch) 27 | pred_mask = model.module.predict_mask_nshot(batch, nshot=nshot) 28 | 29 | assert pred_mask.size() == batch['query_mask'].size() 30 | 31 | # 2. Evaluate prediction 32 | area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch) 33 | average_meter.update(area_inter, area_union, batch['class_id'], loss=None) 34 | average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1) 35 | 36 | # Visualize predictions 37 | if Visualizer.visualize: 38 | Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'], 39 | batch['query_img'], batch['query_mask'], 40 | pred_mask, batch['class_id'], idx, 41 | area_inter[1].float() / area_union[1].float()) 42 | # Write evaluation results 43 | average_meter.write_result('Test', 0) 44 | miou, fb_iou = average_meter.compute_iou() 45 | 46 | return miou, fb_iou 47 | 48 | 49 | if __name__ == '__main__': 50 | 51 | # Arguments parsing 52 | parser = argparse.ArgumentParser(description='Hypercorrelation Squeeze Pytorch Implementation') 53 | parser.add_argument('--datapath', type=str, default='fewshot_data/Datasets_HSN') 54 | parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss']) 55 | parser.add_argument('--logpath', type=str, default='') 56 | parser.add_argument('--bsz', type=int, default=1) 57 | parser.add_argument('--nworker', type=int, default=0) 58 | parser.add_argument('--load', type=str, default='') 59 | parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3]) 60 | parser.add_argument('--nshot', type=int, default=1) 61 | parser.add_argument('--backbone', type=str, default='resnet101', choices=['vgg16', 'resnet50', 'resnet101']) 62 | parser.add_argument('--visualize', action='store_true') 63 | parser.add_argument('--use_original_imgsize', action='store_true') 64 | args = parser.parse_args() 65 | Logger.initialize(args, training=False) 66 | 67 | # Model initialization 68 | model = HypercorrSqueezeNetwork(args.backbone, args.use_original_imgsize) 69 | model.eval() 70 | Logger.log_params(model) 71 | 72 | # Device setup 73 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 74 | Logger.info('# available GPUs: %d' % torch.cuda.device_count()) 75 | model = nn.DataParallel(model) 76 | model.to(device) 77 | 78 | # Load trained model 79 | if args.load == '': 
raise Exception('Pretrained model not specified.') 80 | model.load_state_dict(torch.load(args.load)) 81 | 82 | # Helper classes (for testing) initialization 83 | Evaluator.initialize() 84 | Visualizer.initialize(args.visualize) 85 | 86 | # Dataset initialization 87 | FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize) 88 | dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot) 89 | 90 | # Test HSNet 91 | with torch.no_grad(): 92 | test_miou, test_fb_iou = test(model, dataloader_test, args.nshot) 93 | Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item())) 94 | Logger.info('==================== Finished Testing ====================') 95 | -------------------------------------------------------------------------------- /fewshot_data/train.py: -------------------------------------------------------------------------------- 1 | r""" Hypercorrelation Squeeze training (validation) code """ 2 | import argparse 3 | 4 | import torch.optim as optim 5 | import torch.nn as nn 6 | import torch 7 | 8 | from fewshot_data.model.hsnet import HypercorrSqueezeNetwork 9 | from fewshot_data.common.logger import Logger, AverageMeter 10 | from fewshot_data.common.evaluation import Evaluator 11 | from fewshot_data.common import utils 12 | from fewshot_data.data.dataset import FSSDataset 13 | 14 | 15 | def train(epoch, model, dataloader, optimizer, training): 16 | r""" Train HSNet """ 17 | 18 | # Force randomness during training / freeze randomness during testing 19 | utils.fix_randseed(None) if training else utils.fix_randseed(0) 20 | model.module.train_mode() if training else model.module.eval() 21 | average_meter = AverageMeter(dataloader.dataset) 22 | 23 | for idx, batch in enumerate(dataloader): 24 | # 1. Hypercorrelation Squeeze Networks forward pass 25 | batch = utils.to_cuda(batch) 26 | logit_mask = model(batch['query_img'], batch['support_imgs'].squeeze(1), batch['support_masks'].squeeze(1)) 27 | pred_mask = logit_mask.argmax(dim=1) 28 | 29 | # 2. Compute loss & update model parameters 30 | loss = model.module.compute_objective(logit_mask, batch['query_mask']) 31 | if training: 32 | optimizer.zero_grad() 33 | loss.backward() 34 | optimizer.step() 35 | 36 | # 3. 
Evaluate prediction 37 | area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch) 38 | average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone()) 39 | average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50) 40 | 41 | # Write evaluation results 42 | average_meter.write_result('Training' if training else 'Validation', epoch) 43 | avg_loss = utils.mean(average_meter.loss_buf) 44 | miou, fb_iou = average_meter.compute_iou() 45 | 46 | return avg_loss, miou, fb_iou 47 | 48 | 49 | if __name__ == '__main__': 50 | 51 | # Arguments parsing 52 | parser = argparse.ArgumentParser(description='Hypercorrelation Squeeze Pytorch Implementation') 53 | parser.add_argument('--datapath', type=str, default='fewshot_data/Datasets_HSN') 54 | parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss']) 55 | parser.add_argument('--logpath', type=str, default='') 56 | parser.add_argument('--bsz', type=int, default=20) 57 | parser.add_argument('--lr', type=float, default=1e-3) 58 | parser.add_argument('--niter', type=int, default=2000) 59 | parser.add_argument('--nworker', type=int, default=8) 60 | parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3]) 61 | parser.add_argument('--backbone', type=str, default='resnet101', choices=['vgg16', 'resnet50', 'resnet101']) 62 | args = parser.parse_args() 63 | Logger.initialize(args, training=True) 64 | 65 | # Model initialization 66 | model = HypercorrSqueezeNetwork(args.backbone, False) 67 | Logger.log_params(model) 68 | 69 | # Device setup 70 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 71 | Logger.info('# available GPUs: %d' % torch.cuda.device_count()) 72 | model = nn.DataParallel(model) 73 | model.to(device) 74 | 75 | # Helper classes (for training) initialization 76 | optimizer = optim.Adam([{"params": model.parameters(), "lr": args.lr}]) 77 | Evaluator.initialize() 78 | 79 | # Dataset initialization 80 | FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=False) 81 | dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn') 82 | dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val') 83 | 84 | # Train HSNet 85 | best_val_miou = float('-inf') 86 | best_val_loss = float('inf') 87 | for epoch in range(args.niter): 88 | 89 | trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True) 90 | with torch.no_grad(): 91 | val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False) 92 | 93 | # Save the best model 94 | if val_miou > best_val_miou: 95 | best_val_miou = val_miou 96 | Logger.save_model_miou(model, epoch, val_miou) 97 | 98 | Logger.tbd_writer.add_scalars('fewshot_data/data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch) 99 | Logger.tbd_writer.add_scalars('fewshot_data/data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch) 100 | Logger.tbd_writer.add_scalars('fewshot_data/data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch) 101 | Logger.tbd_writer.flush() 102 | Logger.tbd_writer.close() 103 | Logger.info('==================== Finished Training ====================') 104 | -------------------------------------------------------------------------------- /inputs/cat1.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/inputs/cat1.jpeg -------------------------------------------------------------------------------- /label_files/ade20k_objectInfo150.txt: -------------------------------------------------------------------------------- 1 | Idx,Ratio,Train,Val,Stuff,Name 2 | 1,0.1576,11664,1172,1,wall 3 | 2,0.1072,6046,612,1,building;edifice 4 | 3,0.0878,8265,796,1,sky 5 | 4,0.0621,9336,917,1,floor;flooring 6 | 5,0.0480,6678,641,0,tree 7 | 6,0.0450,6604,643,1,ceiling 8 | 7,0.0398,4023,408,1,road;route 9 | 8,0.0231,1906,199,0,bed 10 | 9,0.0198,4688,460,0,windowpane;window 11 | 10,0.0183,2423,225,1,grass 12 | 11,0.0181,2874,294,0,cabinet 13 | 12,0.0166,3068,310,1,sidewalk;pavement 14 | 13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul 15 | 14,0.0151,1804,190,1,earth;ground 16 | 15,0.0118,6666,796,0,door;double;door 17 | 16,0.0110,4269,411,0,table 18 | 17,0.0109,1691,160,1,mountain;mount 19 | 18,0.0104,3999,441,0,plant;flora;plant;life 20 | 19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall 21 | 20,0.0103,3261,318,0,chair 22 | 21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar 23 | 22,0.0074,709,75,1,water 24 | 23,0.0067,3296,315,0,painting;picture 25 | 24,0.0065,1191,106,0,sofa;couch;lounge 26 | 25,0.0061,1516,162,0,shelf 27 | 26,0.0060,667,69,1,house 28 | 27,0.0053,651,57,1,sea 29 | 28,0.0052,1847,224,0,mirror 30 | 29,0.0046,1158,128,1,rug;carpet;carpeting 31 | 30,0.0044,480,44,1,field 32 | 31,0.0044,1172,98,0,armchair 33 | 32,0.0044,1292,184,0,seat 34 | 33,0.0033,1386,138,0,fence;fencing 35 | 34,0.0031,698,61,0,desk 36 | 35,0.0030,781,73,0,rock;stone 37 | 36,0.0027,380,43,0,wardrobe;closet;press 38 | 37,0.0026,3089,302,0,lamp 39 | 38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub 40 | 39,0.0024,804,99,0,railing;rail 41 | 40,0.0023,1453,153,0,cushion 42 | 41,0.0023,411,37,0,base;pedestal;stand 43 | 42,0.0022,1440,162,0,box 44 | 43,0.0022,800,77,0,column;pillar 45 | 44,0.0020,2650,298,0,signboard;sign 46 | 45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser 47 | 46,0.0019,367,36,0,counter 48 | 47,0.0018,311,30,1,sand 49 | 48,0.0018,1181,122,0,sink 50 | 49,0.0018,287,23,1,skyscraper 51 | 50,0.0018,468,38,0,fireplace;hearth;open;fireplace 52 | 51,0.0018,402,43,0,refrigerator;icebox 53 | 52,0.0018,130,12,1,grandstand;covered;stand 54 | 53,0.0018,561,64,1,path 55 | 54,0.0017,880,102,0,stairs;steps 56 | 55,0.0017,86,12,1,runway 57 | 56,0.0017,172,11,0,case;display;case;showcase;vitrine 58 | 57,0.0017,198,18,0,pool;table;billiard;table;snooker;table 59 | 58,0.0017,930,109,0,pillow 60 | 59,0.0015,139,18,0,screen;door;screen 61 | 60,0.0015,564,52,1,stairway;staircase 62 | 61,0.0015,320,26,1,river 63 | 62,0.0015,261,29,1,bridge;span 64 | 63,0.0014,275,22,0,bookcase 65 | 64,0.0014,335,60,0,blind;screen 66 | 65,0.0014,792,75,0,coffee;table;cocktail;table 67 | 66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne 68 | 67,0.0014,1309,138,0,flower 69 | 68,0.0013,1112,113,0,book 70 | 69,0.0013,266,27,1,hill 71 | 70,0.0013,659,66,0,bench 72 | 71,0.0012,331,31,0,countertop 73 | 72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove 74 | 73,0.0012,369,36,0,palm;palm;tree 75 | 74,0.0012,144,9,0,kitchen;island 76 | 75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system 77 | 76,0.0010,324,33,0,swivel;chair 78 | 77,0.0009,304,27,0,boat 79 | 78,0.0009,170,20,0,bar 80 | 79,0.0009,68,6,0,arcade;machine 81 | 
80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty 82 | 81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle 83 | 82,0.0008,492,49,0,towel 84 | 83,0.0008,2510,269,0,light;light;source 85 | 84,0.0008,440,39,0,truck;motortruck 86 | 85,0.0008,147,18,1,tower 87 | 86,0.0008,583,56,0,chandelier;pendant;pendent 88 | 87,0.0007,533,61,0,awning;sunshade;sunblind 89 | 88,0.0007,1989,239,0,streetlight;street;lamp 90 | 89,0.0007,71,5,0,booth;cubicle;stall;kiosk 91 | 90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box 92 | 91,0.0007,135,12,0,airplane;aeroplane;plane 93 | 92,0.0007,83,5,1,dirt;track 94 | 93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes 95 | 94,0.0006,1003,104,0,pole 96 | 95,0.0006,182,12,1,land;ground;soil 97 | 96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail 98 | 97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway 99 | 98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock 100 | 99,0.0006,965,114,0,bottle 101 | 100,0.0006,117,13,0,buffet;counter;sideboard 102 | 101,0.0006,354,35,0,poster;posting;placard;notice;bill;card 103 | 102,0.0006,108,9,1,stage 104 | 103,0.0006,557,55,0,van 105 | 104,0.0006,52,4,0,ship 106 | 105,0.0005,99,5,0,fountain 107 | 106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter 108 | 107,0.0005,292,31,0,canopy 109 | 108,0.0005,77,9,0,washer;automatic;washer;washing;machine 110 | 109,0.0005,340,38,0,plaything;toy 111 | 110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium 112 | 111,0.0005,465,49,0,stool 113 | 112,0.0005,50,4,0,barrel;cask 114 | 113,0.0005,622,75,0,basket;handbasket 115 | 114,0.0005,80,9,1,waterfall;falls 116 | 115,0.0005,59,3,0,tent;collapsible;shelter 117 | 116,0.0005,531,72,0,bag 118 | 117,0.0005,282,30,0,minibike;motorbike 119 | 118,0.0005,73,7,0,cradle 120 | 119,0.0005,435,44,0,oven 121 | 120,0.0005,136,25,0,ball 122 | 121,0.0005,116,24,0,food;solid;food 123 | 122,0.0004,266,31,0,step;stair 124 | 123,0.0004,58,12,0,tank;storage;tank 125 | 124,0.0004,418,83,0,trade;name;brand;name;brand;marque 126 | 125,0.0004,319,43,0,microwave;microwave;oven 127 | 126,0.0004,1193,139,0,pot;flowerpot 128 | 127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna 129 | 128,0.0004,347,36,0,bicycle;bike;wheel;cycle 130 | 129,0.0004,52,5,1,lake 131 | 130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine 132 | 131,0.0004,108,13,0,screen;silver;screen;projection;screen 133 | 132,0.0004,201,30,0,blanket;cover 134 | 133,0.0004,285,21,0,sculpture 135 | 134,0.0004,268,27,0,hood;exhaust;hood 136 | 135,0.0003,1020,108,0,sconce 137 | 136,0.0003,1282,122,0,vase 138 | 137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight 139 | 138,0.0003,453,57,0,tray 140 | 139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin 141 | 140,0.0003,397,44,0,fan 142 | 141,0.0003,92,8,1,pier;wharf;wharfage;dock 143 | 142,0.0003,228,18,0,crt;screen 144 | 143,0.0003,570,59,0,plate 145 | 144,0.0003,217,22,0,monitor;monitoring;device 146 | 145,0.0003,206,19,0,bulletin;board;notice;board 147 | 146,0.0003,130,14,0,shower 148 | 147,0.0003,178,28,0,radiator 149 | 148,0.0002,504,57,0,glass;drinking;glass 150 | 149,0.0002,775,96,0,clock 151 | 150,0.0002,421,56,0,flag -------------------------------------------------------------------------------- /label_files/fewshot_coco.txt: -------------------------------------------------------------------------------- 1 
| person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | trafficlight 11 | firehydrant 12 | stopsign 13 | parkingmeter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sportsball 34 | kite 35 | baseballbat 36 | baseballglove 37 | skateboard 38 | surfboard 39 | tennisracket 40 | bottle 41 | wineglass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hotdog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cellphone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddybear 79 | hairdrier 80 | toothbrush -------------------------------------------------------------------------------- /label_files/fewshot_pascal.txt: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bicycle 3 | bird 4 | boat 5 | bottle 6 | bus 7 | car 8 | cat 9 | chair 10 | cow 11 | diningtable 12 | dog 13 | horse 14 | motorbike 15 | person 16 | pottedplant 17 | sheep 18 | sofa 19 | train 20 | tvmonitor -------------------------------------------------------------------------------- /modules/lseg_module.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms as transforms 5 | from argparse import ArgumentParser 6 | import pytorch_lightning as pl 7 | from .lsegmentation_module import LSegmentationModule 8 | from .models.lseg_net import LSegNet 9 | from encoding.models.sseg.base import up_kwargs 10 | 11 | import os 12 | import clip 13 | import numpy as np 14 | 15 | from scipy import signal 16 | import glob 17 | 18 | from PIL import Image 19 | import matplotlib.pyplot as plt 20 | import pandas as pd 21 | 22 | 23 | class LSegModule(LSegmentationModule): 24 | def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs): 25 | super(LSegModule, self).__init__( 26 | data_path, dataset, batch_size, base_lr, max_epochs, **kwargs 27 | ) 28 | 29 | if dataset == "citys": 30 | self.base_size = 2048 31 | self.crop_size = 768 32 | else: 33 | self.base_size = 520 34 | self.crop_size = 480 35 | 36 | use_pretrained = True 37 | norm_mean= [0.5, 0.5, 0.5] 38 | norm_std = [0.5, 0.5, 0.5] 39 | 40 | print('** Use norm {}, {} as the mean and std **'.format(norm_mean, norm_std)) 41 | 42 | train_transform = [ 43 | transforms.ToTensor(), 44 | transforms.Normalize(norm_mean, norm_std), 45 | ] 46 | 47 | val_transform = [ 48 | transforms.ToTensor(), 49 | transforms.Normalize(norm_mean, norm_std), 50 | ] 51 | 52 | self.train_transform = transforms.Compose(train_transform) 53 | self.val_transform = transforms.Compose(val_transform) 54 | 55 | self.trainset = self.get_trainset( 56 | dataset, 57 | augment=kwargs["augment"], 58 | base_size=self.base_size, 59 | crop_size=self.crop_size, 60 | ) 61 | 62 | self.valset = self.get_valset( 63 | dataset, 64 | augment=kwargs["augment"], 65 | base_size=self.base_size, 66 | crop_size=self.crop_size, 67 | ) 68 | 69 | use_batchnorm = ( 70 | (not kwargs["no_batchnorm"]) if "no_batchnorm" in kwargs else True 71 | ) 72 | # print(kwargs) 73 | 74 | labels = 
self.get_labels('ade20k') 75 | 76 | self.net = LSegNet( 77 | labels=labels, 78 | backbone=kwargs["backbone"], 79 | features=kwargs["num_features"], 80 | crop_size=self.crop_size, 81 | arch_option=kwargs["arch_option"], 82 | block_depth=kwargs["block_depth"], 83 | activation=kwargs["activation"], 84 | ) 85 | 86 | self.net.pretrained.model.patch_embed.img_size = ( 87 | self.crop_size, 88 | self.crop_size, 89 | ) 90 | 91 | self._up_kwargs = up_kwargs 92 | self.mean = norm_mean 93 | self.std = norm_std 94 | 95 | self.criterion = self.get_criterion(**kwargs) 96 | 97 | def get_labels(self, dataset): 98 | labels = [] 99 | path = 'label_files/{}_objectInfo150.txt'.format(dataset) 100 | assert os.path.exists(path), '*** Error : {} not exist !!!'.format(path) 101 | f = open(path, 'r') 102 | lines = f.readlines() 103 | for line in lines: 104 | label = line.strip().split(',')[-1].split(';')[0] 105 | labels.append(label) 106 | f.close() 107 | if dataset in ['ade20k']: 108 | labels = labels[1:] 109 | return labels 110 | 111 | 112 | @staticmethod 113 | def add_model_specific_args(parent_parser): 114 | parser = LSegmentationModule.add_model_specific_args(parent_parser) 115 | parser = ArgumentParser(parents=[parser]) 116 | 117 | parser.add_argument( 118 | "--backbone", 119 | type=str, 120 | default="clip_vitl16_384", 121 | help="backbone network", 122 | ) 123 | 124 | parser.add_argument( 125 | "--num_features", 126 | type=int, 127 | default=256, 128 | help="number of featurs that go from encoder to decoder", 129 | ) 130 | 131 | parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate") 132 | 133 | parser.add_argument( 134 | "--finetune_weights", type=str, help="load weights to finetune from" 135 | ) 136 | 137 | parser.add_argument( 138 | "--no-scaleinv", 139 | default=True, 140 | action="store_false", 141 | help="turn off scaleinv layers", 142 | ) 143 | 144 | parser.add_argument( 145 | "--no-batchnorm", 146 | default=False, 147 | action="store_true", 148 | help="turn off batchnorm", 149 | ) 150 | 151 | parser.add_argument( 152 | "--widehead", default=False, action="store_true", help="wider output head" 153 | ) 154 | 155 | parser.add_argument( 156 | "--widehead_hr", 157 | default=False, 158 | action="store_true", 159 | help="wider output head", 160 | ) 161 | 162 | parser.add_argument( 163 | "--arch_option", 164 | type=int, 165 | default=0, 166 | help="which kind of architecture to be used", 167 | ) 168 | 169 | parser.add_argument( 170 | "--block_depth", 171 | type=int, 172 | default=0, 173 | help="how many blocks should be used", 174 | ) 175 | 176 | parser.add_argument( 177 | "--activation", 178 | choices=['lrelu', 'tanh'], 179 | default="lrelu", 180 | help="use which activation to activate the block", 181 | ) 182 | 183 | return parser 184 | -------------------------------------------------------------------------------- /modules/lseg_module_zs.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms as transforms 5 | from argparse import ArgumentParser 6 | import pytorch_lightning as pl 7 | from .lsegmentation_module_zs import LSegmentationModuleZS 8 | from .models.lseg_net_zs import LSegNetZS, LSegRNNetZS 9 | from encoding.models.sseg.base import up_kwargs 10 | import os 11 | import clip 12 | import numpy as np 13 | from scipy import signal 14 | import glob 15 | from PIL import Image 16 | import matplotlib.pyplot as plt 17 | import pandas as pd 18 | 19 | 20 | class 
LSegModuleZS(LSegmentationModuleZS): 21 | def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs): 22 | super(LSegModuleZS, self).__init__( 23 | data_path, dataset, batch_size, base_lr, max_epochs, **kwargs 24 | ) 25 | label_list = self.get_labels(dataset) 26 | self.len_dataloader = len(label_list) 27 | 28 | # print(kwargs) 29 | if kwargs["use_pretrained"] in ['False', False]: 30 | use_pretrained = False 31 | elif kwargs["use_pretrained"] in ['True', True]: 32 | use_pretrained = True 33 | 34 | if kwargs["backbone"] in ["clip_resnet101"]: 35 | self.net = LSegRNNetZS( 36 | label_list=label_list, 37 | backbone=kwargs["backbone"], 38 | features=kwargs["num_features"], 39 | aux=kwargs["aux"], 40 | use_pretrained=use_pretrained, 41 | arch_option=kwargs["arch_option"], 42 | block_depth=kwargs["block_depth"], 43 | activation=kwargs["activation"], 44 | ) 45 | else: 46 | self.net = LSegNetZS( 47 | label_list=label_list, 48 | backbone=kwargs["backbone"], 49 | features=kwargs["num_features"], 50 | aux=kwargs["aux"], 51 | use_pretrained=use_pretrained, 52 | arch_option=kwargs["arch_option"], 53 | block_depth=kwargs["block_depth"], 54 | activation=kwargs["activation"], 55 | ) 56 | 57 | def get_labels(self, dataset): 58 | labels = [] 59 | path = 'label_files/fewshot_{}.txt'.format(dataset) 60 | assert os.path.exists(path), '*** Error : {} not exist !!!'.format(path) 61 | f = open(path, 'r') 62 | lines = f.readlines() 63 | for line in lines: 64 | label = line.strip() 65 | labels.append(label) 66 | f.close() 67 | print(labels) 68 | return labels 69 | 70 | @staticmethod 71 | def add_model_specific_args(parent_parser): 72 | parser = LSegmentationModuleZS.add_model_specific_args(parent_parser) 73 | parser = ArgumentParser(parents=[parser]) 74 | 75 | parser.add_argument( 76 | "--backbone", 77 | type=str, 78 | default="vitb16_384", 79 | help="backbone network", 80 | ) 81 | 82 | parser.add_argument( 83 | "--num_features", 84 | type=int, 85 | default=256, 86 | help="number of featurs that go from encoder to decoder", 87 | ) 88 | 89 | parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate") 90 | 91 | parser.add_argument( 92 | "--finetune_weights", type=str, help="load weights to finetune from" 93 | ) 94 | 95 | parser.add_argument( 96 | "--no-scaleinv", 97 | default=True, 98 | action="store_false", 99 | help="turn off scaleinv layers", 100 | ) 101 | 102 | parser.add_argument( 103 | "--no-batchnorm", 104 | default=False, 105 | action="store_true", 106 | help="turn off batchnorm", 107 | ) 108 | 109 | parser.add_argument( 110 | "--widehead", default=False, action="store_true", help="wider output head" 111 | ) 112 | 113 | parser.add_argument( 114 | "--widehead_hr", 115 | default=False, 116 | action="store_true", 117 | help="wider output head", 118 | ) 119 | 120 | parser.add_argument( 121 | "--use_pretrained", 122 | type=str, 123 | default="True", 124 | help="whether use the default model to intialize the model", 125 | ) 126 | 127 | parser.add_argument( 128 | "--arch_option", 129 | type=int, 130 | default=0, 131 | help="which kind of architecture to be used", 132 | ) 133 | 134 | parser.add_argument( 135 | "--block_depth", 136 | type=int, 137 | default=0, 138 | help="how many blocks should be used", 139 | ) 140 | 141 | parser.add_argument( 142 | "--activation", 143 | choices=['relu', 'lrelu', 'tanh'], 144 | default="relu", 145 | help="use which activation to activate the block", 146 | ) 147 | 148 | return parser 149 | 
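For reference, a minimal sketch (illustrative only; the helper name load_fewshot_labels is hypothetical, the label_files/ layout is the one shown earlier) of what LSegModuleZS.get_labels above resolves: one class name per line, used as the text prompts for the zero-shot few-shot benchmarks.

import os

def load_fewshot_labels(dataset):
    # Mirrors LSegModuleZS.get_labels: one class name per line in
    # label_files/fewshot_<dataset>.txt ('pascal', 'coco', or 'fss').
    path = 'label_files/fewshot_{}.txt'.format(dataset)
    assert os.path.exists(path), '{} does not exist'.format(path)
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]

# The PASCAL list shown earlier yields the 20 VOC classes,
# ['aeroplane', 'bicycle', 'bird', ..., 'tvmonitor']:
labels = load_fewshot_labels('pascal')
assert len(labels) == 20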
-------------------------------------------------------------------------------- /modules/lsegmentation_module.py: -------------------------------------------------------------------------------- 1 | import types 2 | import time 3 | import random 4 | import clip 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms as transforms 8 | 9 | from argparse import ArgumentParser 10 | 11 | import pytorch_lightning as pl 12 | 13 | from data import get_dataset, get_available_datasets 14 | 15 | from encoding.models import get_segmentation_model 16 | from encoding.nn import SegmentationLosses 17 | 18 | from encoding.utils import batch_pix_accuracy, batch_intersection_union 19 | 20 | # add mixed precision 21 | import torch.cuda.amp as amp 22 | import numpy as np 23 | 24 | from encoding.utils import SegmentationMetric 25 | 26 | class LSegmentationModule(pl.LightningModule): 27 | def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs): 28 | super().__init__() 29 | 30 | self.data_path = data_path 31 | self.batch_size = batch_size 32 | self.base_lr = base_lr / 16 * batch_size 33 | self.lr = self.base_lr 34 | 35 | self.epochs = max_epochs 36 | self.other_kwargs = kwargs 37 | self.enabled = False #True mixed precision will make things complicated and leading to NAN error 38 | self.scaler = amp.GradScaler(enabled=self.enabled) 39 | 40 | def forward(self, x): 41 | return self.net(x) 42 | 43 | def evaluate(self, x, target=None): 44 | pred = self.net.forward(x) 45 | if isinstance(pred, (tuple, list)): 46 | pred = pred[0] 47 | if target is None: 48 | return pred 49 | correct, labeled = batch_pix_accuracy(pred.data, target.data) 50 | inter, union = batch_intersection_union(pred.data, target.data, self.nclass) 51 | 52 | return correct, labeled, inter, union 53 | 54 | def evaluate_random(self, x, labelset, target=None): 55 | pred = self.net.forward(x, labelset) 56 | if isinstance(pred, (tuple, list)): 57 | pred = pred[0] 58 | if target is None: 59 | return pred 60 | correct, labeled = batch_pix_accuracy(pred.data, target.data) 61 | inter, union = batch_intersection_union(pred.data, target.data, self.nclass) 62 | 63 | return correct, labeled, inter, union 64 | 65 | 66 | def training_step(self, batch, batch_nb): 67 | img, target = batch 68 | with amp.autocast(enabled=self.enabled): 69 | out = self(img) 70 | multi_loss = isinstance(out, tuple) 71 | if multi_loss: 72 | loss = self.criterion(*out, target) 73 | else: 74 | loss = self.criterion(out, target) 75 | loss = self.scaler.scale(loss) 76 | final_output = out[0] if multi_loss else out 77 | train_pred, train_gt = self._filter_invalid(final_output, target) 78 | if train_gt.nelement() != 0: 79 | self.train_accuracy(train_pred, train_gt) 80 | self.log("train_loss", loss) 81 | return loss 82 | 83 | def training_epoch_end(self, outs): 84 | self.log("train_acc_epoch", self.train_accuracy.compute()) 85 | 86 | def validation_step(self, batch, batch_nb): 87 | img, target = batch 88 | out = self(img) 89 | multi_loss = isinstance(out, tuple) 90 | if multi_loss: 91 | val_loss = self.criterion(*out, target) 92 | else: 93 | val_loss = self.criterion(out, target) 94 | final_output = out[0] if multi_loss else out 95 | valid_pred, valid_gt = self._filter_invalid(final_output, target) 96 | self.val_iou.update(target, final_output) 97 | pixAcc, iou = self.val_iou.get() 98 | self.log("val_loss_step", val_loss) 99 | self.log("pix_acc_step", pixAcc) 100 | self.log( 101 | "val_acc_step", 102 | self.val_accuracy(valid_pred, valid_gt), 103 | ) 
104 | self.log("val_iou", iou) 105 | 106 | def validation_epoch_end(self, outs): 107 | pixAcc, iou = self.val_iou.get() 108 | self.log("val_acc_epoch", self.val_accuracy.compute()) 109 | self.log("val_iou_epoch", iou) 110 | self.log("pix_acc_epoch", pixAcc) 111 | 112 | self.val_iou.reset() 113 | 114 | def _filter_invalid(self, pred, target): 115 | valid = target != self.other_kwargs["ignore_index"] 116 | _, mx = torch.max(pred, dim=1) 117 | return mx[valid], target[valid] 118 | 119 | def configure_optimizers(self): 120 | params_list = [ 121 | {"params": self.net.pretrained.parameters(), "lr": self.base_lr}, 122 | ] 123 | if hasattr(self.net, "scratch"): 124 | print("Found output scratch") 125 | params_list.append( 126 | {"params": self.net.scratch.parameters(), "lr": self.base_lr * 10} 127 | ) 128 | if hasattr(self.net, "auxlayer"): 129 | print("Found auxlayer") 130 | params_list.append( 131 | {"params": self.net.auxlayer.parameters(), "lr": self.base_lr * 10} 132 | ) 133 | if hasattr(self.net, "scale_inv_conv"): 134 | print(self.net.scale_inv_conv) 135 | print("Found scaleinv layers") 136 | params_list.append( 137 | { 138 | "params": self.net.scale_inv_conv.parameters(), 139 | "lr": self.base_lr * 10, 140 | } 141 | ) 142 | params_list.append( 143 | {"params": self.net.scale2_conv.parameters(), "lr": self.base_lr * 10} 144 | ) 145 | params_list.append( 146 | {"params": self.net.scale3_conv.parameters(), "lr": self.base_lr * 10} 147 | ) 148 | params_list.append( 149 | {"params": self.net.scale4_conv.parameters(), "lr": self.base_lr * 10} 150 | ) 151 | 152 | if self.other_kwargs["midasproto"]: 153 | print("Using midas optimization protocol") 154 | 155 | opt = torch.optim.Adam( 156 | params_list, 157 | lr=self.base_lr, 158 | betas=(0.9, 0.999), 159 | weight_decay=self.other_kwargs["weight_decay"], 160 | ) 161 | sch = torch.optim.lr_scheduler.LambdaLR( 162 | opt, lambda x: pow(1.0 - x / self.epochs, 0.9) 163 | ) 164 | 165 | else: 166 | opt = torch.optim.SGD( 167 | params_list, 168 | lr=self.base_lr, 169 | momentum=0.9, 170 | weight_decay=self.other_kwargs["weight_decay"], 171 | ) 172 | sch = torch.optim.lr_scheduler.LambdaLR( 173 | opt, lambda x: pow(1.0 - x / self.epochs, 0.9) 174 | ) 175 | return [opt], [sch] 176 | 177 | def train_dataloader(self): 178 | return torch.utils.data.DataLoader( 179 | self.trainset, 180 | batch_size=self.batch_size, 181 | shuffle=True, 182 | num_workers=16, 183 | worker_init_fn=lambda x: random.seed(time.time() + x), 184 | ) 185 | 186 | def val_dataloader(self): 187 | return torch.utils.data.DataLoader( 188 | self.valset, 189 | batch_size=self.batch_size, 190 | shuffle=False, 191 | num_workers=16, 192 | ) 193 | 194 | def get_trainset(self, dset, augment=False, **kwargs): 195 | print(kwargs) 196 | if augment == True: 197 | mode = "train_x" 198 | else: 199 | mode = "train" 200 | 201 | print(mode) 202 | dset = get_dataset( 203 | dset, 204 | root=self.data_path, 205 | split="train", 206 | mode=mode, 207 | transform=self.train_transform, 208 | **kwargs 209 | ) 210 | 211 | self.num_classes = dset.num_class 212 | self.train_accuracy = pl.metrics.Accuracy() 213 | 214 | return dset 215 | 216 | def get_valset(self, dset, augment=False, **kwargs): 217 | self.val_accuracy = pl.metrics.Accuracy() 218 | self.val_iou = SegmentationMetric(self.num_classes) 219 | 220 | if augment == True: 221 | mode = "val_x" 222 | else: 223 | mode = "val" 224 | 225 | print(mode) 226 | return get_dataset( 227 | dset, 228 | root=self.data_path, 229 | split="val", 230 | mode=mode, 231 | 
transform=self.val_transform, 232 | **kwargs 233 | ) 234 | 235 | 236 | def get_criterion(self, **kwargs): 237 | return SegmentationLosses( 238 | se_loss=kwargs["se_loss"], 239 | aux=kwargs["aux"], 240 | nclass=self.num_classes, 241 | se_weight=kwargs["se_weight"], 242 | aux_weight=kwargs["aux_weight"], 243 | ignore_index=kwargs["ignore_index"], 244 | ) 245 | 246 | @staticmethod 247 | def add_model_specific_args(parent_parser): 248 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 249 | parser.add_argument( 250 | "--data_path", type=str, help="path where dataset is stored" 251 | ) 252 | parser.add_argument( 253 | "--dataset", 254 | choices=get_available_datasets(), 255 | default="ade20k", 256 | help="dataset to train on", 257 | ) 258 | parser.add_argument( 259 | "--batch_size", type=int, default=16, help="size of the batches" 260 | ) 261 | parser.add_argument( 262 | "--base_lr", type=float, default=0.004, help="learning rate" 263 | ) 264 | parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum") 265 | parser.add_argument( 266 | "--weight_decay", type=float, default=1e-4, help="weight_decay" 267 | ) 268 | parser.add_argument( 269 | "--aux", action="store_true", default=False, help="Auxilary Loss" 270 | ) 271 | parser.add_argument( 272 | "--aux-weight", 273 | type=float, 274 | default=0.2, 275 | help="Auxilary loss weight (default: 0.2)", 276 | ) 277 | parser.add_argument( 278 | "--se-loss", 279 | action="store_true", 280 | default=False, 281 | help="Semantic Encoding Loss SE-loss", 282 | ) 283 | parser.add_argument( 284 | "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)" 285 | ) 286 | 287 | parser.add_argument( 288 | "--midasproto", action="store_true", default=False, help="midasprotocol" 289 | ) 290 | 291 | parser.add_argument( 292 | "--ignore_index", 293 | type=int, 294 | default=-1, 295 | help="numeric value of ignore label in gt", 296 | ) 297 | parser.add_argument( 298 | "--augment", 299 | action="store_true", 300 | default=False, 301 | help="Use extended augmentations", 302 | ) 303 | 304 | return parser 305 | -------------------------------------------------------------------------------- /modules/models/lseg_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .lseg_vit import ( 5 | _make_pretrained_clip_vitl16_384, 6 | _make_pretrained_clip_vitb32_384, 7 | _make_pretrained_clipRN50x16_vitl16_384, 8 | forward_vit, 9 | ) 10 | 11 | 12 | def _make_encoder( 13 | backbone, 14 | features, 15 | use_pretrained=True, 16 | groups=1, 17 | expand=False, 18 | exportable=True, 19 | hooks=None, 20 | use_vit_only=False, 21 | use_readout="ignore", 22 | enable_attention_hooks=False, 23 | ): 24 | if backbone == "clip_vitl16_384": 25 | clip_pretrained, pretrained = _make_pretrained_clip_vitl16_384( 26 | use_pretrained, 27 | hooks=hooks, 28 | use_readout=use_readout, 29 | enable_attention_hooks=enable_attention_hooks, 30 | ) 31 | scratch = _make_scratch( 32 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 33 | ) 34 | elif backbone == "clipRN50x16_vitl16_384": 35 | clip_pretrained, pretrained = _make_pretrained_clipRN50x16_vitl16_384( 36 | use_pretrained, 37 | hooks=hooks, 38 | use_readout=use_readout, 39 | enable_attention_hooks=enable_attention_hooks, 40 | ) 41 | scratch = _make_scratch( 42 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 43 | ) 44 | elif backbone == "clip_vitb32_384": 45 | clip_pretrained, 
pretrained = _make_pretrained_clip_vitb32_384( 46 | use_pretrained, 47 | hooks=hooks, 48 | use_readout=use_readout, 49 | ) 50 | scratch = _make_scratch( 51 | [96, 192, 384, 768], features, groups=groups, expand=expand 52 | ) 53 | else: 54 | print(f"Backbone '{backbone}' not implemented") 55 | assert False 56 | 57 | return clip_pretrained, pretrained, scratch 58 | 59 | 60 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 61 | scratch = nn.Module() 62 | 63 | out_shape1 = out_shape 64 | out_shape2 = out_shape 65 | out_shape3 = out_shape 66 | out_shape4 = out_shape 67 | if expand == True: 68 | out_shape1 = out_shape 69 | out_shape2 = out_shape * 2 70 | out_shape3 = out_shape * 4 71 | out_shape4 = out_shape * 8 72 | 73 | scratch.layer1_rn = nn.Conv2d( 74 | in_shape[0], 75 | out_shape1, 76 | kernel_size=3, 77 | stride=1, 78 | padding=1, 79 | bias=False, 80 | groups=groups, 81 | ) 82 | scratch.layer2_rn = nn.Conv2d( 83 | in_shape[1], 84 | out_shape2, 85 | kernel_size=3, 86 | stride=1, 87 | padding=1, 88 | bias=False, 89 | groups=groups, 90 | ) 91 | scratch.layer3_rn = nn.Conv2d( 92 | in_shape[2], 93 | out_shape3, 94 | kernel_size=3, 95 | stride=1, 96 | padding=1, 97 | bias=False, 98 | groups=groups, 99 | ) 100 | scratch.layer4_rn = nn.Conv2d( 101 | in_shape[3], 102 | out_shape4, 103 | kernel_size=3, 104 | stride=1, 105 | padding=1, 106 | bias=False, 107 | groups=groups, 108 | ) 109 | 110 | return scratch 111 | 112 | 113 | class Interpolate(nn.Module): 114 | """Interpolation module.""" 115 | 116 | def __init__(self, scale_factor, mode, align_corners=False): 117 | """Init. 118 | 119 | Args: 120 | scale_factor (float): scaling 121 | mode (str): interpolation mode 122 | """ 123 | super(Interpolate, self).__init__() 124 | 125 | self.interp = nn.functional.interpolate 126 | self.scale_factor = scale_factor 127 | self.mode = mode 128 | self.align_corners = align_corners 129 | 130 | def forward(self, x): 131 | """Forward pass. 132 | 133 | Args: 134 | x (tensor): input 135 | 136 | Returns: 137 | tensor: interpolated data 138 | """ 139 | 140 | x = self.interp( 141 | x, 142 | scale_factor=self.scale_factor, 143 | mode=self.mode, 144 | align_corners=self.align_corners, 145 | ) 146 | 147 | return x 148 | 149 | 150 | class ResidualConvUnit(nn.Module): 151 | """Residual convolution module.""" 152 | 153 | def __init__(self, features): 154 | """Init. 155 | 156 | Args: 157 | features (int): number of features 158 | """ 159 | super().__init__() 160 | 161 | self.conv1 = nn.Conv2d( 162 | features, features, kernel_size=3, stride=1, padding=1, bias=True 163 | ) 164 | 165 | self.conv2 = nn.Conv2d( 166 | features, features, kernel_size=3, stride=1, padding=1, bias=True 167 | ) 168 | 169 | self.relu = nn.ReLU(inplace=True) 170 | 171 | def forward(self, x): 172 | """Forward pass. 173 | 174 | Args: 175 | x (tensor): input 176 | 177 | Returns: 178 | tensor: output 179 | """ 180 | out = self.relu(x) 181 | out = self.conv1(out) 182 | out = self.relu(out) 183 | out = self.conv2(out) 184 | 185 | return out + x 186 | 187 | 188 | class FeatureFusionBlock(nn.Module): 189 | """Feature fusion block.""" 190 | 191 | def __init__(self, features): 192 | """Init. 193 | 194 | Args: 195 | features (int): number of features 196 | """ 197 | super(FeatureFusionBlock, self).__init__() 198 | 199 | self.resConfUnit1 = ResidualConvUnit(features) 200 | self.resConfUnit2 = ResidualConvUnit(features) 201 | 202 | def forward(self, *xs): 203 | """Forward pass. 
204 | 205 | Returns: 206 | tensor: output 207 | """ 208 | output = xs[0] 209 | 210 | if len(xs) == 2: 211 | output += self.resConfUnit1(xs[1]) 212 | 213 | output = self.resConfUnit2(output) 214 | 215 | output = nn.functional.interpolate( 216 | output, scale_factor=2, mode="bilinear", align_corners=True 217 | ) 218 | 219 | return output 220 | 221 | 222 | class ResidualConvUnit_custom(nn.Module): 223 | """Residual convolution module.""" 224 | 225 | def __init__(self, features, activation, bn): 226 | """Init. 227 | 228 | Args: 229 | features (int): number of features 230 | """ 231 | super().__init__() 232 | 233 | self.bn = bn 234 | 235 | self.groups = 1 236 | 237 | self.conv1 = nn.Conv2d( 238 | features, 239 | features, 240 | kernel_size=3, 241 | stride=1, 242 | padding=1, 243 | bias=not self.bn, 244 | groups=self.groups, 245 | ) 246 | 247 | self.conv2 = nn.Conv2d( 248 | features, 249 | features, 250 | kernel_size=3, 251 | stride=1, 252 | padding=1, 253 | bias=not self.bn, 254 | groups=self.groups, 255 | ) 256 | 257 | if self.bn == True: 258 | self.bn1 = nn.BatchNorm2d(features) 259 | self.bn2 = nn.BatchNorm2d(features) 260 | 261 | self.activation = activation 262 | 263 | self.skip_add = nn.quantized.FloatFunctional() 264 | 265 | def forward(self, x): 266 | """Forward pass. 267 | 268 | Args: 269 | x (tensor): input 270 | 271 | Returns: 272 | tensor: output 273 | """ 274 | 275 | out = self.activation(x) 276 | out = self.conv1(out) 277 | if self.bn == True: 278 | out = self.bn1(out) 279 | 280 | out = self.activation(out) 281 | out = self.conv2(out) 282 | if self.bn == True: 283 | out = self.bn2(out) 284 | 285 | if self.groups > 1: 286 | out = self.conv_merge(out) 287 | 288 | return self.skip_add.add(out, x) 289 | 290 | # return out + x 291 | 292 | 293 | class FeatureFusionBlock_custom(nn.Module): 294 | """Feature fusion block.""" 295 | 296 | def __init__( 297 | self, 298 | features, 299 | activation, 300 | deconv=False, 301 | bn=False, 302 | expand=False, 303 | align_corners=True, 304 | ): 305 | """Init. 306 | 307 | Args: 308 | features (int): number of features 309 | """ 310 | super(FeatureFusionBlock_custom, self).__init__() 311 | 312 | self.deconv = deconv 313 | self.align_corners = align_corners 314 | 315 | self.groups = 1 316 | 317 | self.expand = expand 318 | out_features = features 319 | if self.expand == True: 320 | out_features = features // 2 321 | 322 | self.out_conv = nn.Conv2d( 323 | features, 324 | out_features, 325 | kernel_size=1, 326 | stride=1, 327 | padding=0, 328 | bias=True, 329 | groups=1, 330 | ) 331 | 332 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 333 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 334 | 335 | self.skip_add = nn.quantized.FloatFunctional() 336 | 337 | def forward(self, *xs): 338 | """Forward pass. 
339 | 340 | Returns: 341 | tensor: output 342 | """ 343 | output = xs[0] 344 | 345 | if len(xs) == 2: 346 | res = self.resConfUnit1(xs[1]) 347 | output = self.skip_add.add(output, res) 348 | # output += res 349 | 350 | output = self.resConfUnit2(output) 351 | 352 | output = nn.functional.interpolate( 353 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 354 | ) 355 | 356 | output = self.out_conv(output) 357 | 358 | return output 359 | 360 | -------------------------------------------------------------------------------- /modules/models/lseg_blocks_zs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .lseg_vit_zs import ( 5 | _make_pretrained_clip_vitl16_384, 6 | _make_pretrained_clip_vitb32_384, 7 | _make_pretrained_clip_rn101, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder( 12 | backbone, 13 | features, 14 | use_pretrained, 15 | groups=1, 16 | expand=False, 17 | exportable=True, 18 | hooks=None, 19 | use_vit_only=False, 20 | use_readout="ignore", 21 | enable_attention_hooks=False, 22 | ): 23 | if backbone == "clip_vitl16_384": 24 | clip_pretrained, pretrained = _make_pretrained_clip_vitl16_384( 25 | use_pretrained, 26 | hooks=hooks, 27 | use_readout=use_readout, 28 | enable_attention_hooks=enable_attention_hooks, 29 | ) 30 | scratch = _make_scratch( 31 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 32 | ) 33 | elif backbone == "clip_vitb32_384": 34 | clip_pretrained, pretrained = _make_pretrained_clip_vitb32_384( 35 | use_pretrained, 36 | hooks=hooks, 37 | use_readout=use_readout, 38 | ) 39 | scratch = _make_scratch( 40 | [96, 192, 384, 768], features, groups=groups, expand=expand 41 | ) 42 | elif backbone == "clip_resnet101": 43 | clip_pretrained, pretrained = _make_pretrained_clip_rn101( 44 | use_pretrained, 45 | ) 46 | scratch = _make_scratch( 47 | [256, 512, 1024, 2048], features, groups=groups, expand=expand 48 | ) 49 | else: 50 | print(f"Backbone '{backbone}' not implemented") 51 | assert False 52 | 53 | return clip_pretrained, pretrained, scratch 54 | 55 | 56 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 57 | scratch = nn.Module() 58 | 59 | out_shape1 = out_shape 60 | out_shape2 = out_shape 61 | out_shape3 = out_shape 62 | out_shape4 = out_shape 63 | if expand == True: 64 | out_shape1 = out_shape 65 | out_shape2 = out_shape * 2 66 | out_shape3 = out_shape * 4 67 | out_shape4 = out_shape * 8 68 | 69 | scratch.layer1_rn = nn.Conv2d( 70 | in_shape[0], 71 | out_shape1, 72 | kernel_size=3, 73 | stride=1, 74 | padding=1, 75 | bias=False, 76 | groups=groups, 77 | ) 78 | scratch.layer2_rn = nn.Conv2d( 79 | in_shape[1], 80 | out_shape2, 81 | kernel_size=3, 82 | stride=1, 83 | padding=1, 84 | bias=False, 85 | groups=groups, 86 | ) 87 | scratch.layer3_rn = nn.Conv2d( 88 | in_shape[2], 89 | out_shape3, 90 | kernel_size=3, 91 | stride=1, 92 | padding=1, 93 | bias=False, 94 | groups=groups, 95 | ) 96 | scratch.layer4_rn = nn.Conv2d( 97 | in_shape[3], 98 | out_shape4, 99 | kernel_size=3, 100 | stride=1, 101 | padding=1, 102 | bias=False, 103 | groups=groups, 104 | ) 105 | 106 | return scratch 107 | 108 | 109 | def _make_resnet_backbone(resnet): 110 | pretrained = nn.Module() 111 | pretrained.layer1 = nn.Sequential( 112 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 113 | ) 114 | 115 | pretrained.layer2 = resnet.layer2 116 | pretrained.layer3 = resnet.layer3 117 | pretrained.layer4 = resnet.layer4 118 | 119 | return 
pretrained 120 | 121 | 122 | def _make_pretrained_resnext101_wsl(use_pretrained): 123 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 124 | return _make_resnet_backbone(resnet) 125 | 126 | 127 | class Interpolate(nn.Module): 128 | """Interpolation module.""" 129 | 130 | def __init__(self, scale_factor, mode, align_corners=False): 131 | """Init. 132 | 133 | Args: 134 | scale_factor (float): scaling 135 | mode (str): interpolation mode 136 | """ 137 | super(Interpolate, self).__init__() 138 | 139 | self.interp = nn.functional.interpolate 140 | self.scale_factor = scale_factor 141 | self.mode = mode 142 | self.align_corners = align_corners 143 | 144 | def forward(self, x): 145 | """Forward pass. 146 | 147 | Args: 148 | x (tensor): input 149 | 150 | Returns: 151 | tensor: interpolated data 152 | """ 153 | 154 | x = self.interp( 155 | x, 156 | scale_factor=self.scale_factor, 157 | mode=self.mode, 158 | align_corners=self.align_corners, 159 | ) 160 | 161 | return x 162 | 163 | 164 | class ResidualConvUnit(nn.Module): 165 | """Residual convolution module.""" 166 | 167 | def __init__(self, features): 168 | """Init. 169 | 170 | Args: 171 | features (int): number of features 172 | """ 173 | super().__init__() 174 | 175 | self.conv1 = nn.Conv2d( 176 | features, features, kernel_size=3, stride=1, padding=1, bias=True 177 | ) 178 | 179 | self.conv2 = nn.Conv2d( 180 | features, features, kernel_size=3, stride=1, padding=1, bias=True 181 | ) 182 | 183 | self.relu = nn.ReLU(inplace=True) 184 | 185 | def forward(self, x): 186 | """Forward pass. 187 | 188 | Args: 189 | x (tensor): input 190 | 191 | Returns: 192 | tensor: output 193 | """ 194 | out = self.relu(x) 195 | out = self.conv1(out) 196 | out = self.relu(out) 197 | out = self.conv2(out) 198 | 199 | return out + x 200 | 201 | 202 | class FeatureFusionBlock(nn.Module): 203 | """Feature fusion block.""" 204 | 205 | def __init__(self, features): 206 | """Init. 207 | 208 | Args: 209 | features (int): number of features 210 | """ 211 | super(FeatureFusionBlock, self).__init__() 212 | 213 | self.resConfUnit1 = ResidualConvUnit(features) 214 | self.resConfUnit2 = ResidualConvUnit(features) 215 | 216 | def forward(self, *xs): 217 | """Forward pass. 218 | 219 | Returns: 220 | tensor: output 221 | """ 222 | output = xs[0] 223 | 224 | if len(xs) == 2: 225 | output += self.resConfUnit1(xs[1]) 226 | 227 | output = self.resConfUnit2(output) 228 | 229 | output = nn.functional.interpolate( 230 | output, scale_factor=2, mode="bilinear", align_corners=True 231 | ) 232 | 233 | return output 234 | 235 | 236 | class ResidualConvUnit_custom(nn.Module): 237 | """Residual convolution module.""" 238 | 239 | def __init__(self, features, activation, bn): 240 | """Init. 241 | 242 | Args: 243 | features (int): number of features 244 | """ 245 | super().__init__() 246 | 247 | self.bn = bn 248 | 249 | self.groups = 1 250 | 251 | self.conv1 = nn.Conv2d( 252 | features, 253 | features, 254 | kernel_size=3, 255 | stride=1, 256 | padding=1, 257 | bias=not self.bn, 258 | groups=self.groups, 259 | ) 260 | 261 | self.conv2 = nn.Conv2d( 262 | features, 263 | features, 264 | kernel_size=3, 265 | stride=1, 266 | padding=1, 267 | bias=not self.bn, 268 | groups=self.groups, 269 | ) 270 | 271 | if self.bn == True: 272 | self.bn1 = nn.BatchNorm2d(features) 273 | self.bn2 = nn.BatchNorm2d(features) 274 | 275 | self.activation = activation 276 | 277 | self.skip_add = nn.quantized.FloatFunctional() 278 | 279 | def forward(self, x): 280 | """Forward pass. 
281 | 282 | Args: 283 | x (tensor): input 284 | 285 | Returns: 286 | tensor: output 287 | """ 288 | 289 | out = self.activation(x) 290 | out = self.conv1(out) 291 | if self.bn == True: 292 | out = self.bn1(out) 293 | 294 | out = self.activation(out) 295 | out = self.conv2(out) 296 | if self.bn == True: 297 | out = self.bn2(out) 298 | 299 | if self.groups > 1: 300 | out = self.conv_merge(out) 301 | 302 | return self.skip_add.add(out, x) 303 | 304 | # return out + x 305 | 306 | 307 | class FeatureFusionBlock_custom(nn.Module): 308 | """Feature fusion block.""" 309 | 310 | def __init__( 311 | self, 312 | features, 313 | activation, 314 | deconv=False, 315 | bn=False, 316 | expand=False, 317 | align_corners=True, 318 | ): 319 | """Init. 320 | 321 | Args: 322 | features (int): number of features 323 | """ 324 | super(FeatureFusionBlock_custom, self).__init__() 325 | 326 | self.deconv = deconv 327 | self.align_corners = align_corners 328 | 329 | self.groups = 1 330 | 331 | self.expand = expand 332 | out_features = features 333 | if self.expand == True: 334 | out_features = features // 2 335 | 336 | self.out_conv = nn.Conv2d( 337 | features, 338 | out_features, 339 | kernel_size=1, 340 | stride=1, 341 | padding=0, 342 | bias=True, 343 | groups=1, 344 | ) 345 | 346 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 347 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 348 | 349 | self.skip_add = nn.quantized.FloatFunctional() 350 | 351 | def forward(self, *xs): 352 | """Forward pass. 353 | 354 | Returns: 355 | tensor: output 356 | """ 357 | output = xs[0] 358 | 359 | if len(xs) == 2: 360 | res = self.resConfUnit1(xs[1]) 361 | output = self.skip_add.add(output, res) 362 | # output += res 363 | 364 | output = self.resConfUnit2(output) 365 | 366 | output = nn.functional.interpolate( 367 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 368 | ) 369 | 370 | output = self.out_conv(output) 371 | 372 | return output 373 | 374 | -------------------------------------------------------------------------------- /modules/models/lseg_net.py: -------------------------------------------------------------------------------- 1 | import math 2 | import types 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .lseg_blocks import FeatureFusionBlock, Interpolate, _make_encoder, FeatureFusionBlock_custom, forward_vit 9 | import clip 10 | import numpy as np 11 | import pandas as pd 12 | import os 13 | 14 | class depthwise_clipseg_conv(nn.Module): 15 | def __init__(self): 16 | super(depthwise_clipseg_conv, self).__init__() 17 | self.depthwise = nn.Conv2d(1, 1, kernel_size=3, padding=1) 18 | 19 | def depthwise_clipseg(self, x, channels): 20 | x = torch.cat([self.depthwise(x[:, i].unsqueeze(1)) for i in range(channels)], dim=1) 21 | return x 22 | 23 | def forward(self, x): 24 | channels = x.shape[1] 25 | out = self.depthwise_clipseg(x, channels) 26 | return out 27 | 28 | 29 | class depthwise_conv(nn.Module): 30 | def __init__(self, kernel_size=3, stride=1, padding=1): 31 | super(depthwise_conv, self).__init__() 32 | self.depthwise = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding) 33 | 34 | def forward(self, x): 35 | # support for 4D tensor with NCHW 36 | C, H, W = x.shape[1:] 37 | x = x.reshape(-1, 1, H, W) 38 | x = self.depthwise(x) 39 | x = x.view(-1, C, H, W) 40 | return x 41 | 42 | 43 | class depthwise_block(nn.Module): 44 | def __init__(self, kernel_size=3, stride=1, padding=1, 
activation='relu'): 45 | super(depthwise_block, self).__init__() 46 | self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1) 47 | if activation == 'relu': 48 | self.activation = nn.ReLU() 49 | elif activation == 'lrelu': 50 | self.activation = nn.LeakyReLU() 51 | elif activation == 'tanh': 52 | self.activation = nn.Tanh() 53 | 54 | def forward(self, x, act=True): 55 | x = self.depthwise(x) 56 | if act: 57 | x = self.activation(x) 58 | return x 59 | 60 | 61 | class bottleneck_block(nn.Module): 62 | def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'): 63 | super(bottleneck_block, self).__init__() 64 | self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1) 65 | if activation == 'relu': 66 | self.activation = nn.ReLU() 67 | elif activation == 'lrelu': 68 | self.activation = nn.LeakyReLU() 69 | elif activation == 'tanh': 70 | self.activation = nn.Tanh() 71 | 72 | 73 | def forward(self, x, act=True): 74 | sum_layer = x.max(dim=1, keepdim=True)[0] 75 | x = self.depthwise(x) 76 | x = x + sum_layer 77 | if act: 78 | x = self.activation(x) 79 | return x 80 | 81 | class BaseModel(torch.nn.Module): 82 | def load(self, path): 83 | """Load model from file. 84 | Args: 85 | path (str): file path 86 | """ 87 | parameters = torch.load(path, map_location=torch.device("cpu")) 88 | 89 | if "optimizer" in parameters: 90 | parameters = parameters["model"] 91 | 92 | self.load_state_dict(parameters) 93 | 94 | def _make_fusion_block(features, use_bn): 95 | return FeatureFusionBlock_custom( 96 | features, 97 | activation=nn.ReLU(False), 98 | deconv=False, 99 | bn=use_bn, 100 | expand=False, 101 | align_corners=True, 102 | ) 103 | 104 | class LSeg(BaseModel): 105 | def __init__( 106 | self, 107 | head, 108 | features=256, 109 | backbone="clip_vitl16_384", 110 | readout="project", 111 | channels_last=False, 112 | use_bn=False, 113 | **kwargs, 114 | ): 115 | super(LSeg, self).__init__() 116 | 117 | self.channels_last = channels_last 118 | 119 | hooks = { 120 | "clip_vitl16_384": [5, 11, 17, 23], 121 | "clipRN50x16_vitl16_384": [5, 11, 17, 23], 122 | "clip_vitb32_384": [2, 5, 8, 11], 123 | } 124 | 125 | # Instantiate backbone and reassemble blocks 126 | self.clip_pretrained, self.pretrained, self.scratch = _make_encoder( 127 | backbone, 128 | features, 129 | groups=1, 130 | expand=False, 131 | exportable=False, 132 | hooks=hooks[backbone], 133 | use_readout=readout, 134 | ) 135 | 136 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 137 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 138 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 139 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 140 | 141 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)).exp() 142 | if backbone in ["clipRN50x16_vitl16_384"]: 143 | self.out_c = 768 144 | else: 145 | self.out_c = 512 146 | self.scratch.head1 = nn.Conv2d(features, self.out_c, kernel_size=1) 147 | 148 | self.arch_option = kwargs["arch_option"] 149 | if self.arch_option == 1: 150 | self.scratch.head_block = bottleneck_block(activation=kwargs["activation"]) 151 | self.block_depth = kwargs['block_depth'] 152 | elif self.arch_option == 2: 153 | self.scratch.head_block = depthwise_block(activation=kwargs["activation"]) 154 | self.block_depth = kwargs['block_depth'] 155 | 156 | self.scratch.output_conv = head 157 | 158 | self.text = clip.tokenize(self.labels) 159 | 160 | def forward(self, x, labelset=''): 161 | if labelset == '': 162 | text = self.text 163 | 
else: 164 | text = clip.tokenize(labelset) 165 | 166 | if self.channels_last == True: 167 | x.contiguous(memory_format=torch.channels_last) 168 | 169 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 170 | 171 | layer_1_rn = self.scratch.layer1_rn(layer_1) 172 | layer_2_rn = self.scratch.layer2_rn(layer_2) 173 | layer_3_rn = self.scratch.layer3_rn(layer_3) 174 | layer_4_rn = self.scratch.layer4_rn(layer_4) 175 | 176 | path_4 = self.scratch.refinenet4(layer_4_rn) 177 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 178 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 179 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 180 | 181 | text = text.to(x.device) 182 | self.logit_scale = self.logit_scale.to(x.device) 183 | text_features = self.clip_pretrained.encode_text(text) 184 | 185 | image_features = self.scratch.head1(path_1) 186 | 187 | imshape = image_features.shape 188 | image_features = image_features.permute(0,2,3,1).reshape(-1, self.out_c) 189 | 190 | # normalized features 191 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 192 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 193 | 194 | logits_per_image = self.logit_scale * image_features.half() @ text_features.t() 195 | 196 | out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3], -1).permute(0,3,1,2) 197 | 198 | if self.arch_option in [1, 2]: 199 | for _ in range(self.block_depth - 1): 200 | out = self.scratch.head_block(out) 201 | out = self.scratch.head_block(out, False) 202 | 203 | out = self.scratch.output_conv(out) 204 | 205 | return out 206 | 207 | 208 | class LSegNet(LSeg): 209 | """Network for semantic segmentation.""" 210 | def __init__(self, labels, path=None, scale_factor=0.5, crop_size=480, **kwargs): 211 | 212 | features = kwargs["features"] if "features" in kwargs else 256 213 | kwargs["use_bn"] = True 214 | 215 | self.crop_size = crop_size 216 | self.scale_factor = scale_factor 217 | self.labels = labels 218 | 219 | head = nn.Sequential( 220 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 221 | ) 222 | 223 | super().__init__(head, **kwargs) 224 | 225 | if path is not None: 226 | self.load(path) 227 | 228 | 229 | 230 | 231 | -------------------------------------------------------------------------------- /prepare_ade20k.py: -------------------------------------------------------------------------------- 1 | # + 2 | # revised from https://github.com/zhanghang1989/PyTorch-Encoding/blob/331ecdd5306104614cb414b16fbcd9d1a8d40e1e/scripts/prepare_ade20k.py 3 | 4 | """Prepare ADE20K dataset""" 5 | import os 6 | import shutil 7 | import argparse 8 | import zipfile 9 | from encoding.utils import download, mkdir 10 | # - 11 | 12 | _TARGET_DIR = os.path.expanduser('../datasets/') 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Initialize ADE20K dataset.', 17 | epilog='Example: python prepare_ade20k.py', 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | parser.add_argument('--download-dir', default=None, help='dataset directory on disk') 20 | args = parser.parse_args() 21 | return args 22 | 23 | def download_ade(path, overwrite=False): 24 | _AUG_DOWNLOAD_URLS = [ 25 | ('http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip', '219e1696abb36c8ba3a3afe7fb2f4b4606a897c7'), 26 | ('http://data.csail.mit.edu/places/ADEchallenge/release_test.zip', 'e05747892219d10e9243933371a497e905a4860c'),] 27 | download_dir = path 28 | mkdir(download_dir) 
29 | for url, checksum in _AUG_DOWNLOAD_URLS: 30 | filename = download(url, path=download_dir, overwrite=overwrite, sha1_hash=checksum) 31 | # extract 32 | with zipfile.ZipFile(filename,"r") as zip_ref: 33 | zip_ref.extractall(path=path) 34 | 35 | 36 | if __name__ == '__main__': 37 | args = parse_args() 38 | mkdir(os.path.expanduser('../datasets/')) 39 | if args.download_dir is not None: 40 | if os.path.isdir(_TARGET_DIR): 41 | os.remove(_TARGET_DIR) 42 | # make symlink 43 | os.symlink(args.download_dir, _TARGET_DIR) 44 | else: 45 | download_ade(_TARGET_DIR, overwrite=False) 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.14.1 2 | aiohttp==3.7.4.post0 3 | anyio==3.3.4 4 | argon2-cffi==21.1.0 5 | async-timeout==3.0.1 6 | attrs==21.2.0 7 | Babel==2.9.1 8 | backcall==0.2.0 9 | bleach==4.1.0 10 | cachetools==4.2.4 11 | certifi==2021.10.8 12 | cffi==1.15.0 13 | chardet==4.0.0 14 | charset-normalizer==2.0.7 15 | clip @ git+https://github.com/openai/CLIP.git@04f4dc2ca1ed0acc9893bd1a3b526a7e02c4bb10 16 | cycler==0.10.0 17 | debugpy==1.5.0 18 | decorator==5.1.0 19 | defusedxml==0.7.1 20 | entrypoints==0.3 21 | fsspec==2021.10.1 22 | ftfy==6.0.3 23 | future==0.18.2 24 | google-auth==2.3.0 25 | google-auth-oauthlib==0.4.6 26 | grpcio==1.41.0 27 | idna==3.3 28 | imageio==2.9.0 29 | ipykernel==6.4.1 30 | ipython==7.28.0 31 | ipython-genutils==0.2.0 32 | ipywidgets==7.6.5 33 | jedi==0.18.0 34 | Jinja2==3.0.2 35 | json5==0.9.6 36 | jsonschema==4.1.0 37 | jupyter==1.0.0 38 | jupyter-client==7.0.6 39 | jupyter-console==6.4.0 40 | jupyter-core==4.8.1 41 | jupyter-server==1.11.1 42 | jupyterlab==3.2.0 43 | jupyterlab-pygments==0.1.2 44 | jupyterlab-server==2.8.2 45 | jupyterlab-widgets==1.0.2 46 | kiwisolver==1.3.2 47 | Markdown==3.3.4 48 | MarkupSafe==2.0.1 49 | matplotlib==3.4.3 50 | matplotlib-inline==0.1.3 51 | mistune==0.8.4 52 | multidict==5.2.0 53 | nbclassic==0.3.2 54 | nbclient==0.5.4 55 | nbconvert==6.2.0 56 | nbformat==5.1.3 57 | nest-asyncio==1.5.1 58 | nose==1.3.7 59 | notebook==6.4.4 60 | numpy==1.21.2 61 | oauthlib==3.1.1 62 | packaging==21.0 63 | pandas==1.3.4 64 | pandocfilters==1.5.0 65 | parso==0.8.2 66 | pexpect==4.8.0 67 | pickleshare==0.7.5 68 | Pillow==8.4.0 69 | portalocker==2.3.2 70 | prometheus-client==0.11.0 71 | prompt-toolkit==3.0.20 72 | protobuf==3.18.1 73 | ptyprocess==0.7.0 74 | pyasn1==0.4.8 75 | pyasn1-modules==0.2.8 76 | pycparser==2.20 77 | pyDeprecate==0.3.1 78 | Pygments==2.10.0 79 | pyparsing==2.4.7 80 | pyrsistent==0.18.0 81 | python-dateutil==2.8.2 82 | pytorch-lightning==1.4.9 83 | pytz==2021.3 84 | PyYAML==6.0 85 | pyzmq==22.3.0 86 | qtconsole==5.1.1 87 | QtPy==1.11.2 88 | regex==2021.10.8 89 | requests==2.26.0 90 | requests-oauthlib==1.3.0 91 | requests-unixsocket==0.2.0 92 | rsa==4.7.2 93 | scipy==1.7.1 94 | Send2Trash==1.8.0 95 | six==1.16.0 96 | sniffio==1.2.0 97 | tensorboard==2.7.0 98 | tensorboard-data-server==0.6.1 99 | tensorboard-plugin-wit==1.8.0 100 | terminado==0.12.1 101 | testpath==0.5.0 102 | timm==0.4.12 103 | torch==1.9.1+cu111 104 | torch-encoding @ git+https://github.com/zhanghang1989/PyTorch-Encoding/@331ecdd5306104614cb414b16fbcd9d1a8d40e1e 105 | torchaudio==0.9.1 106 | torchmetrics==0.5.1 107 | torchvision==0.10.1+cu111 108 | tornado==6.1 109 | tqdm==4.62.3 110 | traitlets==5.1.0 111 | typing-extensions==3.10.0.2 112 | urllib3==1.26.7 113 | wcwidth==0.2.5 114 | webencodings==0.5.1 115 | 
websocket-client==1.2.1 116 | Werkzeug==2.0.2 117 | widgetsnbextension==3.5.1 118 | yarl==1.7.0 119 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0; python test_lseg.py --backbone clip_vitl16_384 --eval --dataset ade20k --data-path ../datasets/ \ 2 | --weights checkpoints/lseg_ade20k_l16.ckpt --widehead --no-scaleinv 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /test_lseg_zs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | from modules.lseg_module_zs import LSegModuleZS 9 | from additional_utils.models import LSeg_MultiEvalModule 10 | from fewshot_data.common.logger import Logger, AverageMeter 11 | from fewshot_data.common.vis import Visualizer 12 | from fewshot_data.common.evaluation import Evaluator 13 | from fewshot_data.common import utils 14 | from fewshot_data.data.dataset import FSSDataset 15 | 16 | 17 | class Options: 18 | def __init__(self): 19 | parser = argparse.ArgumentParser(description="PyTorch Segmentation") 20 | # model and dataset 21 | parser.add_argument( 22 | "--model", type=str, default="encnet", help="model name (default: encnet)" 23 | ) 24 | parser.add_argument( 25 | "--backbone", 26 | type=str, 27 | default="resnet50", 28 | help="backbone name (default: resnet50)", 29 | ) 30 | parser.add_argument( 31 | "--dataset", 32 | type=str, 33 | default="ade20k", 34 | help="dataset name (default: pascal12)", 35 | ) 36 | parser.add_argument( 37 | "--workers", type=int, default=16, metavar="N", help="dataloader threads" 38 | ) 39 | parser.add_argument( 40 | "--base-size", type=int, default=520, help="base image size" 41 | ) 42 | parser.add_argument( 43 | "--crop-size", type=int, default=480, help="crop image size" 44 | ) 45 | parser.add_argument( 46 | "--train-split", 47 | type=str, 48 | default="train", 49 | help="dataset train split (default: train)", 50 | ) 51 | # training hyper params 52 | parser.add_argument( 53 | "--aux", action="store_true", default=False, help="Auxilary Loss" 54 | ) 55 | parser.add_argument( 56 | "--se-loss", 57 | action="store_true", 58 | default=False, 59 | help="Semantic Encoding Loss SE-loss", 60 | ) 61 | parser.add_argument( 62 | "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)" 63 | ) 64 | parser.add_argument( 65 | "--batch-size", 66 | type=int, 67 | default=16, 68 | metavar="N", 69 | help="input batch size for \ 70 | training (default: auto)", 71 | ) 72 | parser.add_argument( 73 | "--test-batch-size", 74 | type=int, 75 | default=16, 76 | metavar="N", 77 | help="input batch size for \ 78 | testing (default: same as batch size)", 79 | ) 80 | # cuda, seed and logging 81 | parser.add_argument( 82 | "--no-cuda", 83 | action="store_true", 84 | default=False, 85 | help="disables CUDA training", 86 | ) 87 | parser.add_argument( 88 | "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" 89 | ) 90 | # checking point 91 | parser.add_argument( 92 | "--weights", type=str, default=None, help="checkpoint to test" 93 | ) 94 | # evaluation option 95 | parser.add_argument( 96 | "--eval", action="store_true", default=False, help="evaluating mIoU" 97 | ) 98 | 99 | parser.add_argument( 100 | "--acc-bn", 101 | 
action="store_true", 102 | default=False, 103 | help="Re-accumulate BN statistics", 104 | ) 105 | parser.add_argument( 106 | "--test-val", 107 | action="store_true", 108 | default=False, 109 | help="generate masks on val set", 110 | ) 111 | parser.add_argument( 112 | "--no-val", 113 | action="store_true", 114 | default=False, 115 | help="skip validation during training", 116 | ) 117 | 118 | parser.add_argument( 119 | "--module", 120 | default='', 121 | help="select model definition", 122 | ) 123 | 124 | # test option 125 | parser.add_argument( 126 | "--no-scaleinv", 127 | dest="scale_inv", 128 | default=True, 129 | action="store_false", 130 | help="turn off scaleinv layers", 131 | ) 132 | 133 | parser.add_argument( 134 | "--widehead", default=False, action="store_true", help="wider output head" 135 | ) 136 | 137 | parser.add_argument( 138 | "--widehead_hr", 139 | default=False, 140 | action="store_true", 141 | help="wider output head", 142 | ) 143 | 144 | parser.add_argument( 145 | "--ignore_index", 146 | type=int, 147 | default=-1, 148 | help="numeric value of ignore label in gt", 149 | ) 150 | 151 | parser.add_argument( 152 | "--jobname", 153 | type=str, 154 | default="default", 155 | help="select which dataset", 156 | ) 157 | 158 | parser.add_argument( 159 | "--no-strict", 160 | dest="strict", 161 | default=True, 162 | action="store_false", 163 | help="no-strict copy the model", 164 | ) 165 | 166 | parser.add_argument( 167 | "--use_pretrained", 168 | type=str, 169 | default="True", 170 | help="whether use the default model to intialize the model", 171 | ) 172 | 173 | parser.add_argument( 174 | "--arch_option", 175 | type=int, 176 | default=0, 177 | help="which kind of architecture to be used", 178 | ) 179 | 180 | # fewshot options 181 | parser.add_argument( 182 | '--nshot', 183 | type=int, 184 | default=1 185 | ) 186 | parser.add_argument( 187 | '--fold', 188 | type=int, 189 | default=0, 190 | choices=[0, 1, 2, 3] 191 | ) 192 | parser.add_argument( 193 | '--nworker', 194 | type=int, 195 | default=0 196 | ) 197 | parser.add_argument( 198 | '--bsz', 199 | type=int, 200 | default=1 201 | ) 202 | parser.add_argument( 203 | '--benchmark', 204 | type=str, 205 | default='pascal', 206 | choices=['pascal', 'coco', 'fss', 'c2p'] 207 | ) 208 | parser.add_argument( 209 | '--datapath', 210 | type=str, 211 | default='fewshot_data/Datasets_HSN' 212 | ) 213 | 214 | parser.add_argument( 215 | "--activation", 216 | choices=['relu', 'lrelu', 'tanh'], 217 | default="relu", 218 | help="use which activation to activate the block", 219 | ) 220 | 221 | 222 | self.parser = parser 223 | 224 | def parse(self): 225 | args = self.parser.parse_args() 226 | args.cuda = not args.no_cuda and torch.cuda.is_available() 227 | print(args) 228 | return args 229 | 230 | 231 | def test(args): 232 | module_def = LSegModuleZS 233 | 234 | module = module_def.load_from_checkpoint( 235 | checkpoint_path=args.weights, 236 | data_path=args.datapath, 237 | dataset=args.dataset, 238 | backbone=args.backbone, 239 | aux=args.aux, 240 | num_features=256, 241 | aux_weight=0, 242 | se_loss=False, 243 | se_weight=0, 244 | base_lr=0, 245 | batch_size=1, 246 | max_epochs=0, 247 | ignore_index=args.ignore_index, 248 | dropout=0.0, 249 | scale_inv=args.scale_inv, 250 | augment=False, 251 | no_batchnorm=False, 252 | widehead=args.widehead, 253 | widehead_hr=args.widehead_hr, 254 | map_locatin="cpu", 255 | arch_option=args.arch_option, 256 | use_pretrained=args.use_pretrained, 257 | strict=args.strict, 258 | logpath='fewshot/logpath_4T/', 259 | 
fold=args.fold, 260 | block_depth=0, 261 | nshot=args.nshot, 262 | finetune_mode=False, 263 | activation=args.activation, 264 | ) 265 | 266 | Evaluator.initialize() 267 | if args.backbone in ["clip_resnet101"]: 268 | FSSDataset.initialize(img_size=480, datapath=args.datapath, use_original_imgsize=False, imagenet_norm=True) 269 | else: 270 | FSSDataset.initialize(img_size=480, datapath=args.datapath, use_original_imgsize=False) 271 | # dataloader 272 | args.benchmark = args.dataset 273 | dataloader = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot) 274 | 275 | model = module.net.eval().cuda() 276 | # model = module.net.model.cpu() 277 | 278 | print(model) 279 | 280 | scales = ( 281 | [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] 282 | if args.dataset == "citys" 283 | else [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] 284 | ) 285 | 286 | f = open("logs/fewshot/log_fewshot-test_nshot{}_{}.txt".format(args.nshot, args.dataset), "a+") 287 | 288 | utils.fix_randseed(0) 289 | average_meter = AverageMeter(dataloader.dataset) 290 | for idx, batch in enumerate(dataloader): 291 | batch = utils.to_cuda(batch) 292 | image = batch['query_img'] 293 | target = batch['query_mask'] 294 | class_info = batch['class_id'] 295 | # pred_mask = evaluator.parallel_forward(image, class_info) 296 | pred_mask = model(image, class_info) 297 | # assert pred_mask.argmax(dim=1).size() == batch['query_mask'].size() 298 | # 2. Evaluate prediction 299 | if args.benchmark == 'pascal' and batch['query_ignore_idx'] is not None: 300 | query_ignore_idx = batch['query_ignore_idx'] 301 | area_inter, area_union = Evaluator.classify_prediction(pred_mask.argmax(dim=1), target, query_ignore_idx) 302 | else: 303 | area_inter, area_union = Evaluator.classify_prediction(pred_mask.argmax(dim=1), target) 304 | 305 | average_meter.update(area_inter, area_union, class_info, loss=None) 306 | average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1) 307 | 308 | # Write evaluation results 309 | average_meter.write_result('Test', 0) 310 | test_miou, test_fb_iou = average_meter.compute_iou() 311 | 312 | Logger.info('Fold %d, %d-shot ==> mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, args.nshot, test_miou.item(), test_fb_iou.item())) 313 | Logger.info('==================== Finished Testing ====================') 314 | f.write('{}\n'.format(args.weights)) 315 | f.write('Fold %d, %d-shot ==> mIoU: %5.2f \t FB-IoU: %5.2f\n' % (args.fold, args.nshot, test_miou.item(), test_fb_iou.item())) 316 | f.close() 317 | 318 | 319 | 320 | if __name__ == "__main__": 321 | args = Options().parse() 322 | torch.manual_seed(args.seed) 323 | test(args) 324 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #python -u train_lseg.py --dataset ade20k --data_path ../datasets --batch_size 4 --exp_name lseg_ade20k_l16 \ 3 | #--base_lr 0.004 --weight_decay 1e-4 --no-scaleinv --max_epochs 240 --widehead --accumulate_grad_batches 2 --backbone clip_vitl16_384 4 | 5 | python -u train_lseg.py --dataset ade20k --data_path ../datasets --batch_size 1 --exp_name lseg_ade20k_l16 \ 6 | --base_lr 0.004 --weight_decay 1e-4 --no-scaleinv --max_epochs 240 --widehead --accumulate_grad_batches 2 --backbone clip_vitl16_384 -------------------------------------------------------------------------------- /train_lseg.py: -------------------------------------------------------------------------------- 1 | 
from modules.lseg_module import LSegModule 2 | from utils import do_training, get_default_argument_parser 3 | 4 | if __name__ == "__main__": 5 | parser = LSegModule.add_model_specific_args(get_default_argument_parser()) 6 | args = parser.parse_args() 7 | do_training(args, LSegModule) 8 | --------------------------------------------------------------------------------
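Closing illustration (not part of the repository files listed above): the heart of LSeg's forward pass in modules/models/lseg_net.py is a dense pixel-text correlation. Per-pixel image embeddings produced by scratch.head1 and CLIP text embeddings are L2-normalized and multiplied to yield one logit map per label. The sketch below reproduces only that step with random tensors and illustrative shapes (batch 1, 512-d embeddings, a 60x60 feature map, 4 labels are assumptions), so it runs standalone without CLIP weights or a checkpoint.

import torch

B, C, H, W, K = 1, 512, 60, 60, 4          # batch, embedding dim, feature map size, number of labels
image_features = torch.randn(B, C, H, W)   # stand-in for self.scratch.head1(path_1)
text_features = torch.randn(K, C)          # stand-in for self.clip_pretrained.encode_text(text)
logit_scale = (torch.ones([]) * torch.log(torch.tensor(1 / 0.07))).exp()  # exp(log(1/0.07)), as in LSeg.__init__

# Flatten the spatial grid, normalize both modalities, and correlate.
pixels = image_features.permute(0, 2, 3, 1).reshape(-1, C)
pixels = pixels / pixels.norm(dim=-1, keepdim=True)
text = text_features / text_features.norm(dim=-1, keepdim=True)

logits_per_image = logit_scale * pixels @ text.t()                 # (B*H*W, K)
out = logits_per_image.view(B, H, W, K).permute(0, 3, 1, 2)        # (B, K, H, W)
print(out.shape)  # torch.Size([1, 4, 60, 60]) -- one logit map per label

In the full model these logits are optionally refined by the head_block (arch_option 1 or 2) and then passed through the output_conv head before being compared against the ground-truth segmentation.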