├── .gitignore
├── LICENSE
├── README.MD
├── additional_utils
│   ├── encoding_models.py
│   └── models.py
├── data
│   └── __init__.py
├── fewshot_data
│   ├── README.md
│   ├── common
│   │   ├── evaluation.py
│   │   ├── logger.py
│   │   ├── utils.py
│   │   └── vis.py
│   ├── data
│   │   ├── assets
│   │   │   ├── architecture.png
│   │   │   └── qualitative_results.png
│   │   ├── coco.py
│   │   ├── dataset.py
│   │   ├── fss.py
│   │   ├── pascal.py
│   │   └── splits
│   │       ├── coco
│   │       │   ├── trn
│   │       │   │   ├── fold0.pkl
│   │       │   │   ├── fold1.pkl
│   │       │   │   ├── fold2.pkl
│   │       │   │   └── fold3.pkl
│   │       │   └── val
│   │       │       ├── fold0.pkl
│   │       │       ├── fold1.pkl
│   │       │       ├── fold2.pkl
│   │       │       └── fold3.pkl
│   │       ├── fss
│   │       │   ├── test.txt
│   │       │   ├── trn.txt
│   │       │   └── val.txt
│   │       └── pascal
│   │           ├── trn
│   │           │   ├── fold0.txt
│   │           │   ├── fold1.txt
│   │           │   ├── fold2.txt
│   │           │   └── fold3.txt
│   │           └── val
│   │               ├── fold0.txt
│   │               ├── fold1.txt
│   │               ├── fold2.txt
│   │               └── fold3.txt
│   ├── model
│   │   ├── base
│   │   │   ├── conv4d.py
│   │   │   ├── correlation.py
│   │   │   └── feature.py
│   │   ├── hsnet.py
│   │   └── learner.py
│   ├── sbatch_run.sh
│   ├── test.py
│   └── train.py
├── inputs
│   └── cat1.jpeg
├── label_files
│   ├── ade20k_objectInfo150.txt
│   ├── fewshot_coco.txt
│   ├── fewshot_fss.txt
│   └── fewshot_pascal.txt
├── lseg_app.py
├── lseg_demo.ipynb
├── modules
│   ├── lseg_module.py
│   ├── lseg_module_zs.py
│   ├── lsegmentation_module.py
│   ├── lsegmentation_module_zs.py
│   └── models
│       ├── lseg_blocks.py
│       ├── lseg_blocks_zs.py
│       ├── lseg_net.py
│       ├── lseg_net_zs.py
│       ├── lseg_vit.py
│       └── lseg_vit_zs.py
├── prepare_ade20k.py
├── requirements.txt
├── test.sh
├── test_lseg.py
├── test_lseg_zs.py
├── train.sh
├── train_lseg.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | checkpoints/
131 | logs/
132 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Intelligent Systems Lab Org
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
1 | # PROJECT NOT UNDER ACTIVE MANAGEMENT
2 | This project will no longer be maintained by Intel.
3 | Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project.
4 | Intel no longer accepts patches to this project.
5 | If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project.
6 |
7 | # Language-driven Semantic Segmentation (LSeg)
8 | This repo contains the official PyTorch implementation of the paper [Language-driven Semantic Segmentation](https://arxiv.org/abs/2201.03546).
9 |
10 | ICLR 2022
11 |
12 | #### Authors:
13 | * [Boyi Li](https://sites.google.com/site/boyilics/home)
14 | * [Kilian Q. Weinberger](http://kilian.cs.cornell.edu/index.html)
15 | * [Serge Belongie](https://scholar.google.com/citations?user=ORr4XJYAAAAJ&hl=zh-CN)
16 | * [Vladlen Koltun](http://vladlen.info/)
17 | * [Rene Ranftl](https://scholar.google.at/citations?user=cwKg158AAAAJ&hl=de)
18 |
19 |
20 | ### Overview
21 |
22 |
23 | We present LSeg, a novel model for language-driven semantic image segmentation. LSeg uses a text encoder to compute embeddings of descriptive input labels (e.g., "grass" or "building") together with a transformer-based image encoder that computes dense per-pixel embeddings of the input image. The image encoder is trained with a contrastive objective to align pixel embeddings to the text embedding of the corresponding semantic class. The text embeddings provide a flexible label representation in which semantically similar labels map to similar regions in the embedding space (e.g., "cat" and "furry"). This allows LSeg to generalize to previously unseen categories at test time, without retraining or even requiring a single additional training sample. We demonstrate that our approach achieves highly competitive zero-shot performance compared to existing zero- and few-shot semantic segmentation methods, and even matches the accuracy of traditional segmentation algorithms when a fixed label set is provided.
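To make the pixel-text matching concrete, below is a minimal sketch of the core idea (random tensors stand in for the CLIP text encoder and the transformer-based image encoder outputs; the shapes and names are illustrative, not this repo's API):

```python
import torch
import torch.nn.functional as F

# Stand-ins for the real encoders: in LSeg, text_emb comes from the CLIP text
# encoder and pixel_emb from the transformer-based image encoder.
labels = ["grass", "building", "cat", "other"]
text_emb = F.normalize(torch.randn(len(labels), 512), dim=-1)   # (K, C): one embedding per label
pixel_emb = F.normalize(torch.randn(1, 512, 240, 240), dim=1)   # (B, C, H, W): dense per-pixel embeddings

# Per-pixel class logits are inner products between pixel and label embeddings;
# the predicted label for each pixel is the argmax over the provided label set.
logits = torch.einsum("bchw,kc->bkhw", pixel_emb, text_emb)     # (B, K, H, W)
pred = logits.argmax(dim=1)                                     # (B, H, W)
print(pred.shape, labels[int(pred[0, 0, 0])])
```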
24 |
25 | Please check our [Video Demo (4k)](https://www.youtube.com/watch?v=bmU75rsmv6s) to further showcase the capabilities of LSeg.
26 |
27 | ## Usage
28 | ### Installation
29 | Option 1:
30 |
31 | ``` pip install -r requirements.txt ```
32 |
33 | Option 2:
34 | ```
35 | conda install ipython
36 | pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
37 | pip install git+https://github.com/zhanghang1989/PyTorch-Encoding/
38 | pip install pytorch-lightning==1.3.5
39 | pip install opencv-python
40 | pip install imageio
41 | pip install ftfy regex tqdm
42 | pip install git+https://github.com/openai/CLIP.git
43 | pip install altair
44 | pip install streamlit
45 | pip install --upgrade protobuf
46 | pip install timm
47 | pip install tensorboardX
48 | pip install matplotlib
49 | pip install test-tube
50 | pip install wandb
51 | ```
52 |
53 | ### Data Preparation
54 | By default, for training, testing and demo, we use [ADE20k](https://groups.csail.mit.edu/vision/datasets/ADE20K/).
55 |
56 | ```
57 | python prepare_ade20k.py
58 | unzip ../datasets/ADEChallengeData2016.zip
59 | ```
60 |
61 | Note: for the demo, if you want to use random inputs, you can skip data loading and comment out the code at [link](https://github.com/isl-org/lang-seg/blob/main/modules/lseg_module.py#L55).
62 |
63 |
64 | ### 🌻 Try demo now
65 |
66 | #### Download Demo Model
67 |
68 |
69 |
70 | | name | backbone | text encoder | url |
71 | | --- | --- | --- | --- |
72 | | Model for demo | ViT-L/16 | CLIP ViT-B/32 | download |
73 |
86 | #### 👉 Option 1: Running interactive app
87 | Download the model for demo and put it under folder `checkpoints` as `checkpoints/demo_e200.ckpt`.
88 |
89 | Then ``` streamlit run lseg_app.py ```
90 |
91 | #### 👉 Option 2: Jupyter Notebook
92 | Download the model for demo and put it under folder `checkpoints` as `checkpoints/demo_e200.ckpt`.
93 |
94 | Then follow [lseg_demo.ipynb](https://github.com/isl-org/lang-seg/blob/main/lseg_demo.ipynb) to play around with LSeg. Enjoy!
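If you just want to see how a checkpoint gets loaded, here is a minimal sketch under the assumption that `LSegModule` exposes the standard PyTorch Lightning `load_from_checkpoint` interface; the notebook passes additional constructor arguments, so consult `lseg_demo.ipynb` for the exact call:

```python
import torch
from modules.lseg_module import LSegModule

# Assumed minimal call; the demo notebook supplies further keyword arguments
# (backbone, label paths, etc.) when restoring the checkpoint.
model = LSegModule.load_from_checkpoint("checkpoints/demo_e200.ckpt")
model = model.eval()
if torch.cuda.is_available():
    model = model.cuda()
```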
95 |
96 |
97 |
98 | ### Training and Testing Example
99 | Training: Backbone = ViT-L/16, Text Encoder from CLIP ViT-B/32
100 |
101 | ``` bash train.sh ```
102 |
103 | Testing: Backbone = ViT-L/16, Text Encoder from CLIP ViT-B/32
104 |
105 | ``` bash test.sh ```
106 |
107 | ### Zero-shot Experiments
108 | #### Data Preparation
109 | Please follow [HSNet](https://github.com/juhongm999/hsnet) and put all datasets in `data/Dataset_HSN`.
110 |
111 | #### Pascal-5i
112 | ```
113 | for fold in 0 1 2 3; do
114 | python -u test_lseg_zs.py --backbone clip_resnet101 --module clipseg_DPT_test_v2 --dataset pascal \
115 | --widehead --no-scaleinv --arch_option 0 --ignore_index 255 --fold ${fold} --nshot 0 \
116 | --weights checkpoints/pascal_fold${fold}.ckpt
117 | done
118 | ```
119 | #### COCO-20i
120 | ```
121 | for fold in 0 1 2 3; do
122 | python -u test_lseg_zs.py --backbone clip_resnet101 --module clipseg_DPT_test_v2 --dataset coco \
123 | --widehead --no-scaleinv --arch_option 0 --ignore_index 255 --fold ${fold} --nshot 0 \
124 | --weights checkpoints/coco_fold${fold}.ckpt
125 | done
126 | ```
127 | #### FSS
128 | ```
129 | python -u test_lseg_zs.py --backbone clip_vitl16_384 --module clipseg_DPT_test_v2 --dataset fss \
130 | --widehead --no-scaleinv --arch_option 0 --ignore_index 255 --fold 0 --nshot 0 \
131 | --weights checkpoints/fss_l16.ckpt
132 | ```
133 |
134 | ```
135 | python -u test_lseg_zs.py --backbone clip_resnet101 --module clipseg_DPT_test_v2 --dataset fss \
136 | --widehead --no-scaleinv --arch_option 0 --ignore_index 255 --fold 0 --nshot 0 \
137 | --weights checkpoints/fss_rn101.ckpt
138 | ```
139 |
140 | #### Model Zoo
141 |
142 |
143 |
144 | | dataset | fold | backbone | text encoder | performance (mIoU) | url |
145 | | --- | --- | --- | --- | --- | --- |
146 | | pascal | 0 | ResNet101 | CLIP ViT-B/32 | 52.8 | download |
147 | | pascal | 1 | ResNet101 | CLIP ViT-B/32 | 53.8 | download |
148 | | pascal | 2 | ResNet101 | CLIP ViT-B/32 | 44.4 | download |
149 | | pascal | 3 | ResNet101 | CLIP ViT-B/32 | 38.5 | download |
150 | | coco | 0 | ResNet101 | CLIP ViT-B/32 | 22.1 | download |
151 | | coco | 1 | ResNet101 | CLIP ViT-B/32 | 25.1 | download |
152 | | coco | 2 | ResNet101 | CLIP ViT-B/32 | 24.9 | download |
153 | | coco | 3 | ResNet101 | CLIP ViT-B/32 | 21.5 | download |
154 | | fss | - | ResNet101 | CLIP ViT-B/32 | 84.7 | download |
155 | | fss | - | ViT-L/16 | CLIP ViT-B/32 | 87.8 | download |
232 |
233 |
234 |
235 |
236 | If you find this repo useful, please cite:
237 | ```
238 | @inproceedings{
239 | li2022languagedriven,
240 | title={Language-driven Semantic Segmentation},
241 | author={Boyi Li and Kilian Q Weinberger and Serge Belongie and Vladlen Koltun and Rene Ranftl},
242 | booktitle={International Conference on Learning Representations},
243 | year={2022},
244 | url={https://openreview.net/forum?id=RriDjddCLN}
245 | }
246 | ```
247 |
248 | ## Acknowledgement
249 | Thanks to the code bases from [DPT](https://github.com/isl-org/DPT), [Pytorch_lightning](https://github.com/PyTorchLightning/pytorch-lightning), [CLIP](https://github.com/openai/CLIP), [Pytorch Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), [Streamlit](https://streamlit.io/), and [Wandb](https://wandb.ai/site).
250 |
--------------------------------------------------------------------------------
/additional_utils/encoding_models.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Referred to: https://github.com/zhanghang1989/PyTorch-Encoding
3 | ###########################################################################
4 | import math
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.nn.parallel.data_parallel import DataParallel
11 | from torch.nn.parallel.scatter_gather import scatter
12 | import threading
13 | import torch
14 | from torch.cuda._utils import _get_device_index
15 | from torch.cuda.amp import autocast
16 | from torch._utils import ExceptionWrapper
17 |
18 | up_kwargs = {'mode': 'bilinear', 'align_corners': True}
19 |
20 | __all__ = ['MultiEvalModule']
21 |
22 | class MultiEvalModule(DataParallel):
23 | """Multi-size Segmentation Eavluator"""
24 | def __init__(self, module, nclass, device_ids=None, flip=True,
25 | scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]):
26 | super(MultiEvalModule, self).__init__(module, device_ids)
27 | self.nclass = nclass
28 | self.base_size = module.base_size
29 | self.crop_size = module.crop_size
30 | self.scales = scales
31 | self.flip = flip
32 | print('MultiEvalModule: base_size {}, crop_size {}'. \
33 | format(self.base_size, self.crop_size))
34 |
35 | def parallel_forward(self, inputs, **kwargs):
36 | """Multi-GPU Mult-size Evaluation
37 |
38 | Args:
39 | inputs: list of Tensors
40 | """
41 | inputs = [(input.unsqueeze(0).cuda(device),)
42 | for input, device in zip(inputs, self.device_ids)]
43 | replicas = self.replicate(self, self.device_ids[:len(inputs)])
44 | kwargs = scatter(kwargs, self.device_ids[:len(inputs)], dim=0) if kwargs else []  # scatter kwargs across the devices in use
45 | if len(inputs) < len(kwargs):
46 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
47 | elif len(kwargs) < len(inputs):
48 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
49 | outputs = self.parallel_apply(replicas, inputs, kwargs)
50 | #for out in outputs:
51 | # print('out.size()', out.size())
52 | return outputs
53 |
54 | def forward(self, image):
55 | """Mult-size Evaluation"""
56 | # only single image is supported for evaluation
57 | batch, _, h, w = image.size()
58 | assert(batch == 1)
59 | stride_rate = 2.0/3.0
60 | crop_size = self.crop_size
61 | stride = int(crop_size * stride_rate)
62 | with torch.cuda.device_of(image):
63 | scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda()
64 |
65 | for scale in self.scales:
66 | long_size = int(math.ceil(self.base_size * scale))
67 | if h > w:
68 | height = long_size
69 | width = int(1.0 * w * long_size / h + 0.5)
70 | short_size = width
71 | else:
72 | width = long_size
73 | height = int(1.0 * h * long_size / w + 0.5)
74 | short_size = height
75 | """
76 | short_size = int(math.ceil(self.base_size * scale))
77 | if h > w:
78 | width = short_size
79 | height = int(1.0 * h * short_size / w)
80 | long_size = height
81 | else:
82 | height = short_size
83 | width = int(1.0 * w * short_size / h)
84 | long_size = width
85 | """
86 | # resize image to current size
87 | cur_img = resize_image(image, height, width, **self.module._up_kwargs)
88 | if long_size <= crop_size:
89 | pad_img = pad_image(cur_img, self.module.mean,
90 | self.module.std, crop_size)
91 | outputs = module_inference(self.module, pad_img, self.flip)
92 | outputs = crop_image(outputs, 0, height, 0, width)
93 | else:
94 | if short_size < crop_size:
95 | # pad if needed
96 | pad_img = pad_image(cur_img, self.module.mean,
97 | self.module.std, crop_size)
98 | else:
99 | pad_img = cur_img
100 | _,_,ph,pw = pad_img.size()
101 | assert(ph >= height and pw >= width)
102 | # grid forward and normalize
103 | h_grids = int(math.ceil(1.0 * (ph-crop_size)/stride)) + 1
104 | w_grids = int(math.ceil(1.0 * (pw-crop_size)/stride)) + 1
105 | with torch.cuda.device_of(image):
106 | outputs = image.new().resize_(batch,self.nclass,ph,pw).zero_().cuda()
107 | count_norm = image.new().resize_(batch,1,ph,pw).zero_().cuda()
108 | # grid evaluation
109 | for idh in range(h_grids):
110 | for idw in range(w_grids):
111 | h0 = idh * stride
112 | w0 = idw * stride
113 | h1 = min(h0 + crop_size, ph)
114 | w1 = min(w0 + crop_size, pw)
115 | crop_img = crop_image(pad_img, h0, h1, w0, w1)
116 | # pad if needed
117 | pad_crop_img = pad_image(crop_img, self.module.mean,
118 | self.module.std, crop_size)
119 | output = module_inference(self.module, pad_crop_img, self.flip)
120 | outputs[:,:,h0:h1,w0:w1] += crop_image(output,
121 | 0, h1-h0, 0, w1-w0)
122 | count_norm[:,:,h0:h1,w0:w1] += 1
123 | assert((count_norm==0).sum()==0)
124 | outputs = outputs / count_norm
125 | outputs = outputs[:,:,:height,:width]
126 |
127 | score = resize_image(outputs, h, w, **self.module._up_kwargs)
128 | scores += score
129 |
130 | return scores
131 |
132 |
133 | def module_inference(module, image, flip=True):
134 | output = module.evaluate(image)
135 | if flip:
136 | fimg = flip_image(image)
137 | foutput = module.evaluate(fimg)
138 | output += flip_image(foutput)
139 | return output
140 |
141 | def resize_image(img, h, w, **up_kwargs):
142 | return F.interpolate(img, (h, w), **up_kwargs)
143 |
144 | def pad_image(img, mean, std, crop_size):
145 | b,c,h,w = img.size()
146 | assert(c==3)
147 | padh = crop_size - h if h < crop_size else 0
148 | padw = crop_size - w if w < crop_size else 0
149 | pad_values = -np.array(mean) / np.array(std)
150 | img_pad = img.new().resize_(b,c,h+padh,w+padw)
151 | for i in range(c):
152 | # note that pytorch pad params are in reversed order
153 | img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i])
154 | assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size)
155 | return img_pad
156 |
157 | def crop_image(img, h0, h1, w0, w1):
158 | return img[:,:,h0:h1,w0:w1]
159 |
160 | def flip_image(img):
161 | assert(img.dim()==4)
162 | with torch.cuda.device_of(img):
163 | idx = torch.arange(img.size(3)-1, -1, -1).type_as(img).long()
164 | return img.index_select(3, idx)
--------------------------------------------------------------------------------
/additional_utils/models.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Referred to: https://github.com/zhanghang1989/PyTorch-Encoding
3 | ###########################################################################
4 | import math
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.nn.parallel.data_parallel import DataParallel
11 | from torch.nn.parallel.scatter_gather import scatter
12 | import threading
13 | import torch
14 | from torch.cuda._utils import _get_device_index
15 | from torch.cuda.amp import autocast
16 | from torch._utils import ExceptionWrapper
17 |
18 | up_kwargs = {'mode': 'bilinear', 'align_corners': True}
19 |
20 | __all__ = ['LSeg_MultiEvalModule']
21 |
22 |
23 | class LSeg_MultiEvalModule(DataParallel):
24 | """Multi-size Segmentation Eavluator"""
25 | def __init__(self, module, device_ids=None, flip=True,
26 | scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]):
27 | super(LSeg_MultiEvalModule, self).__init__(module, device_ids)
28 | self.base_size = module.base_size
29 | self.crop_size = module.crop_size
30 | self.scales = scales
31 | self.flip = flip
32 | print('MultiEvalModule: base_size {}, crop_size {}'. \
33 | format(self.base_size, self.crop_size))
34 |
35 | def parallel_forward(self, inputs, label_set='', **kwargs):
36 | """Multi-GPU Mult-size Evaluation
37 |
38 | Args:
39 | inputs: list of Tensors
40 | """
41 | if len(label_set) < 10:
42 | print('** MultiEvalModule parallel_forward phase: {} **'.format(label_set))
43 | self.nclass = len(label_set)
44 | inputs = [(input.unsqueeze(0).cuda(device),)
45 | for input, device in zip(inputs, self.device_ids)]
46 | replicas = self.replicate(self, self.device_ids[:len(inputs)])
47 | kwargs = scatter(kwargs, self.device_ids[:len(inputs)], dim=0) if kwargs else []  # scatter kwargs across the devices in use
48 | if len(inputs) < len(kwargs):
49 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
50 | elif len(kwargs) < len(inputs):
51 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
52 | outputs = parallel_apply(replicas, inputs, label_set, kwargs)
53 | return outputs
54 |
55 | def forward(self, image, label_set=''):
56 | """Mult-size Evaluation"""
57 | # only single image is supported for evaluation
58 | if len(label_set) < 10:
59 | print('** MultiEvalModule forward phase: {} **'.format(label_set))
60 | batch, _, h, w = image.size()
61 | assert(batch == 1)
62 | self.nclass = len(label_set)
63 | stride_rate = 2.0/3.0
64 | crop_size = self.crop_size
65 | stride = int(crop_size * stride_rate)
66 | with torch.cuda.device_of(image):
67 | scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda()
68 |
69 | for scale in self.scales:
70 | long_size = int(math.ceil(self.base_size * scale))
71 | if h > w:
72 | height = long_size
73 | width = int(1.0 * w * long_size / h + 0.5)
74 | short_size = width
75 | else:
76 | width = long_size
77 | height = int(1.0 * h * long_size / w + 0.5)
78 | short_size = height
79 | """
80 | short_size = int(math.ceil(self.base_size * scale))
81 | if h > w:
82 | width = short_size
83 | height = int(1.0 * h * short_size / w)
84 | long_size = height
85 | else:
86 | height = short_size
87 | width = int(1.0 * w * short_size / h)
88 | long_size = width
89 | """
90 | # resize image to current size
91 | cur_img = resize_image(image, height, width, **self.module._up_kwargs)
92 | if long_size <= crop_size:
93 | pad_img = pad_image(cur_img, self.module.mean,
94 | self.module.std, crop_size)
95 | outputs = module_inference(self.module, pad_img, label_set, self.flip)
96 | outputs = crop_image(outputs, 0, height, 0, width)
97 | else:
98 | if short_size < crop_size:
99 | # pad if needed
100 | pad_img = pad_image(cur_img, self.module.mean,
101 | self.module.std, crop_size)
102 | else:
103 | pad_img = cur_img
104 | _,_,ph,pw = pad_img.shape #.size()
105 | assert(ph >= height and pw >= width)
106 | # grid forward and normalize
107 | h_grids = int(math.ceil(1.0 * (ph-crop_size)/stride)) + 1
108 | w_grids = int(math.ceil(1.0 * (pw-crop_size)/stride)) + 1
109 | with torch.cuda.device_of(image):
110 | outputs = image.new().resize_(batch,self.nclass,ph,pw).zero_().cuda()
111 | count_norm = image.new().resize_(batch,1,ph,pw).zero_().cuda()
112 | # grid evaluation
113 | for idh in range(h_grids):
114 | for idw in range(w_grids):
115 | h0 = idh * stride
116 | w0 = idw * stride
117 | h1 = min(h0 + crop_size, ph)
118 | w1 = min(w0 + crop_size, pw)
119 | crop_img = crop_image(pad_img, h0, h1, w0, w1)
120 | # pad if needed
121 | pad_crop_img = pad_image(crop_img, self.module.mean,
122 | self.module.std, crop_size)
123 | output = module_inference(self.module, pad_crop_img, label_set, self.flip)
124 | outputs[:,:,h0:h1,w0:w1] += crop_image(output,
125 | 0, h1-h0, 0, w1-w0)
126 | count_norm[:,:,h0:h1,w0:w1] += 1
127 | assert((count_norm==0).sum()==0)
128 | outputs = outputs / count_norm
129 | outputs = outputs[:,:,:height,:width]
130 | score = resize_image(outputs, h, w, **self.module._up_kwargs)
131 | scores += score
132 | return scores
133 |
134 | def module_inference(module, image, label_set, flip=True):
135 | output = module.evaluate_random(image, label_set)
136 | if flip:
137 | fimg = flip_image(image)
138 | foutput = module.evaluate_random(fimg, label_set)
139 | output += flip_image(foutput)
140 | return output
141 |
142 | def resize_image(img, h, w, **up_kwargs):
143 | return F.interpolate(img, (h, w), **up_kwargs)
144 |
145 | def pad_image(img, mean, std, crop_size):
146 | b,c,h,w = img.shape #.size()
147 | assert(c==3)
148 | padh = crop_size - h if h < crop_size else 0
149 | padw = crop_size - w if w < crop_size else 0
150 | pad_values = -np.array(mean) / np.array(std)
151 | img_pad = img.new().resize_(b,c,h+padh,w+padw)
152 | for i in range(c):
153 | # note that pytorch pad params are in reversed order
154 | img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i])
155 | assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size)
156 | return img_pad
157 |
158 | def crop_image(img, h0, h1, w0, w1):
159 | return img[:,:,h0:h1,w0:w1]
160 |
161 | def flip_image(img):
162 | assert(img.dim()==4)
163 | with torch.cuda.device_of(img):
164 | idx = torch.arange(img.size(3)-1, -1, -1).type_as(img).long()
165 | return img.index_select(3, idx)
166 |
167 |
168 | def get_a_var(obj):
169 | if isinstance(obj, torch.Tensor):
170 | return obj
171 |
172 | if isinstance(obj, list) or isinstance(obj, tuple):
173 | for result in map(get_a_var, obj):
174 | if isinstance(result, torch.Tensor):
175 | return result
176 | if isinstance(obj, dict):
177 | for result in map(get_a_var, obj.items()):
178 | if isinstance(result, torch.Tensor):
179 | return result
180 | return None
181 |
182 |
183 | def parallel_apply(modules, inputs, label_set, kwargs_tup=None, devices=None):
184 | r"""Applies each `module` in :attr:`modules` in parallel on arguments
185 | contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
186 | on each of :attr:`devices`.
187 |
188 | Args:
189 | modules (Module): modules to be parallelized
190 | inputs (tensor): inputs to the modules
191 | devices (list of int or torch.device): CUDA devices
192 |
193 | :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
194 | :attr:`devices` (if given) should all have same length. Moreover, each
195 | element of :attr:`inputs` can either be a single object as the only argument
196 | to a module, or a collection of positional arguments.
197 | """
198 | assert len(modules) == len(inputs)
199 | if kwargs_tup is not None:
200 | assert len(modules) == len(kwargs_tup)
201 | else:
202 | kwargs_tup = ({},) * len(modules)
203 | if devices is not None:
204 | assert len(modules) == len(devices)
205 | else:
206 | devices = [None] * len(modules)
207 | devices = [_get_device_index(x, True) for x in devices]
208 | lock = threading.Lock()
209 | results = {}
210 | grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()
211 |
212 | def _worker(i, module, input, label_set, kwargs, device=None):
213 | torch.set_grad_enabled(grad_enabled)
214 | if device is None:
215 | device = get_a_var(input).get_device()
216 | try:
217 | with torch.cuda.device(device), autocast(enabled=autocast_enabled):
218 | # this also avoids accidental slicing of `input` if it is a Tensor
219 | if not isinstance(input, (list, tuple)):
220 | input = (input,)
221 | output = module(*input, label_set, **kwargs)
222 | with lock:
223 | results[i] = output
224 | except Exception:
225 | with lock:
226 | results[i] = ExceptionWrapper(
227 | where="in replica {} on device {}".format(i, device))
228 |
229 | if len(modules) > 1:
230 | threads = [threading.Thread(target=_worker,
231 | args=(i, module, input, label_set, kwargs, device))
232 | for i, (module, input, kwargs, device) in
233 | enumerate(zip(modules, inputs, kwargs_tup, devices))]
234 |
235 | for thread in threads:
236 | thread.start()
237 | for thread in threads:
238 | thread.join()
239 | else:
240 | _worker(0, modules[0], inputs[0], label_set, kwargs_tup[0], devices[0])
241 |
242 | outputs = []
243 | for i in range(len(inputs)):
244 | output = results[i]
245 | if isinstance(output, ExceptionWrapper):
246 | output.reraise()
247 | outputs.append(output)
248 | return outputs
249 |
250 |
251 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import itertools
4 | import functools
5 | import numpy as np
6 | import torch
7 | import torch.utils.data
8 | import torchvision.transforms as torch_transforms
9 | import encoding.datasets as enc_ds
10 |
11 | encoding_datasets = {
12 | x: functools.partial(enc_ds.get_dataset, x)
13 | for x in ["coco", "ade20k", "pascal_voc", "pascal_aug", "pcontext", "citys"]
14 | }
15 |
16 |
17 | def get_dataset(name, **kwargs):
18 | if name.lower() in encoding_datasets:
19 | return encoding_datasets[name.lower()](**kwargs)
20 | assert False, f"dataset {name} not found"
21 |
22 |
23 | def get_available_datasets():
24 | return list(encoding_datasets.keys())
25 |
--------------------------------------------------------------------------------
/fewshot_data/README.md:
--------------------------------------------------------------------------------
1 | [](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-pascal-5i-1?p=hypercorrelation-squeeze-for-few-shot)
2 | [](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-pascal-5i-5?p=hypercorrelation-squeeze-for-few-shot)
3 | [](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-pascal-5i?p=hypercorrelation-squeeze-for-few-shot)
4 | [](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-coco-20i-1?p=hypercorrelation-squeeze-for-few-shot)
5 | [](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-coco-20i-5?p=hypercorrelation-squeeze-for-few-shot)
6 | [](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-coco-20i-10?p=hypercorrelation-squeeze-for-few-shot)
7 | [](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-fss-1000-1?p=hypercorrelation-squeeze-for-few-shot)
8 | [](https://paperswithcode.com/sota/few-shot-semantic-segmentation-on-fss-1000-5?p=hypercorrelation-squeeze-for-few-shot)
9 |
10 |
11 | ## Hypercorrelation Squeeze for Few-Shot Segmentation
12 | This is the implementation of the paper "Hypercorrelation Squeeze for Few-Shot Segmentation" by Juhong Min, Dahyun Kang, and Minsu Cho. Implemented with Python 3.7 and PyTorch 1.5.1.
13 |
14 |
15 |
16 |
17 |
18 | For more information, check out project [[website](http://cvlab.postech.ac.kr/research/HSNet/)] and the paper on [[arXiv](https://arxiv.org/abs/2104.01538)].
19 |
20 | ## Requirements
21 |
22 | - Python 3.7
23 | - PyTorch 1.5.1
24 | - cuda 10.1
25 | - tensorboard 1.14
26 |
27 | Conda environment settings:
28 | ```bash
29 | conda create -n hsnet python=3.7
30 | conda activate hsnet
31 |
32 | conda install pytorch=1.5.1 torchvision cudatoolkit=10.1 -c pytorch
33 | conda install -c conda-forge tensorflow
34 | pip install tensorboardX
35 | ```
36 | ## Preparing Few-Shot Segmentation Datasets
37 | Download the following datasets:
38 |
39 | > #### 1. PASCAL-5i
40 | > Download PASCAL VOC2012 devkit (train/val data):
41 | > ```bash
42 | > wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
43 | > ```
44 | > Download PASCAL VOC2012 SDS extended mask annotations from our [[Google Drive](https://drive.google.com/file/d/10zxG2VExoEZUeyQl_uXga2OWHjGeZaf2/view?usp=sharing)].
45 |
46 | > #### 2. COCO-20i
47 | > Download COCO2014 train/val images and annotations:
48 | > ```bash
49 | > wget http://images.cocodataset.org/zips/train2014.zip
50 | > wget http://images.cocodataset.org/zips/val2014.zip
51 | > wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
52 | > ```
53 | > Download COCO2014 train/val annotations from our Google Drive: [[train2014.zip](https://drive.google.com/file/d/1cwup51kcr4m7v9jO14ArpxKMA4O3-Uge/view?usp=sharing)], [[val2014.zip](https://drive.google.com/file/d/1PNw4U3T2MhzAEBWGGgceXvYU3cZ7mJL1/view?usp=sharing)].
54 | > (and place both train2014/ and val2014/ under the annotations/ directory).
55 |
56 | > #### 3. FSS-1000
57 | > Download FSS-1000 images and annotations from our [[Google Drive](https://drive.google.com/file/d/1Fn-cUESMMF1pQy8Xff-vPQvXJdZoUlP3/view?usp=sharing)].
58 |
59 | Create a directory '../Datasets_HSN' for the above three few-shot segmentation datasets and place each dataset appropriately so that it matches the following directory structure (a quick layout check is sketched after the tree):
60 |
61 | ../ # parent directory
62 | ├── ./ # current (project) directory
63 | │ ├── common/ # (dir.) helper functions
64 | │ ├── data/ # (dir.) dataloaders and splits for each FSSS dataset
65 | │ ├── model/ # (dir.) implementation of Hypercorrelation Squeeze Network model
66 | │ ├── README.md # instructions for reproduction
67 | │ ├── train.py # code for training HSNet
68 | │ └── test.py # code for testing HSNet
69 | └── Datasets_HSN/
70 | ├── VOC2012/ # PASCAL VOC2012 devkit
71 | │ ├── Annotations/
72 | │ ├── ImageSets/
73 | │ ├── ...
74 | │ └── SegmentationClassAug/
75 | ├── COCO2014/
76 | │ ├── annotations/
77 | │ │ ├── train2014/ # (dir.) training masks (from Google Drive)
78 | │ │ ├── val2014/ # (dir.) validation masks (from Google Drive)
79 | │ │ └── ..some json files..
80 | │ ├── train2014/
81 | │ └── val2014/
82 | └── FSS-1000/ # (dir.) contains 1000 object classes
83 | ├── abacus/
84 | ├── ...
85 | └── zucchini/
86 |
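A quick way to sanity-check the layout before training (a minimal sketch; adjust `root` if you placed the datasets elsewhere):

```python
import os

root = "../Datasets_HSN"  # parent-level dataset directory described above
expected = [
    "VOC2012/SegmentationClassAug",
    "COCO2014/annotations/train2014",
    "COCO2014/annotations/val2014",
    "COCO2014/train2014",
    "COCO2014/val2014",
    "FSS-1000",
]
for rel in expected:
    path = os.path.join(root, rel)
    print(("ok      " if os.path.isdir(path) else "MISSING ") + path)
```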
87 | ## Training
88 | > ### 1. PASCAL-5i
89 | > ```bash
90 | > python train.py --backbone {vgg16, resnet50, resnet101}
91 | > --fold {0, 1, 2, 3}
92 | > --benchmark pascal
93 | > --lr 1e-3
94 | > --bsz 20
95 | > --logpath "your_experiment_name"
96 | > ```
97 | > * Training takes approx. 2 days until convergence (trained with four 2080 Ti GPUs).
98 |
99 |
100 | > ### 2. COCO-20i
101 | > ```bash
102 | > python train.py --backbone {resnet50, resnet101}
103 | > --fold {0, 1, 2, 3}
104 | > --benchmark coco
105 | > --lr 1e-3
106 | > --bsz 40
107 | > --logpath "your_experiment_name"
108 | > ```
109 | > * Training takes approx. 1 week until convergence (trained with four Titan RTX GPUs).
110 |
111 | > ### 3. FSS-1000
112 | > ```bash
113 | > python train.py --backbone {vgg16, resnet50, resnet101}
114 | > --benchmark fss
115 | > --lr 1e-3
116 | > --bsz 20
117 | > --logpath "your_experiment_name"
118 | > ```
119 | > * Training takes approx. 3 days until convergence (trained with four 2080 Ti GPUs).
120 |
121 | > ### Babysitting training:
122 | > Use tensorboard to babysit training progress:
123 | > - For each experiment, a directory that logs training progress will be automatically generated under logs/ directory.
124 | > - From terminal, run 'tensorboard --logdir logs/' to monitor the training progress.
125 | > - Choose the best model when the validation (mIoU) curve starts to saturate.
126 |
127 |
128 |
129 | ## Testing
130 |
131 | > ### 1. PASCAL-5i
132 | > Pretrained models with tensorboard logs are available on our [[Google Drive](https://drive.google.com/drive/folders/1z4KgjgOu--k6YuIj3qWrGg264GRcMis2?usp=sharing)].
133 | > ```bash
134 | > python test.py --backbone {vgg16, resnet50, resnet101}
135 | > --fold {0, 1, 2, 3}
136 | > --benchmark pascal
137 | > --nshot {1, 5}
138 | > --load "path_to_trained_model/best_model.pt"
139 | > ```
140 |
141 |
142 | > ### 2. COCO-20i
143 | > Pretrained models with tensorboard logs are available on our [[Google Drive](https://drive.google.com/drive/folders/1WpwmCQzxTWhJD5aLQhsgJASaoxxqmFUk?usp=sharing)].
144 | > ```bash
145 | > python test.py --backbone {resnet50, resnet101}
146 | > --fold {0, 1, 2, 3}
147 | > --benchmark coco
148 | > --nshot {1, 5}
149 | > --load "path_to_trained_model/best_model.pt"
150 | > ```
151 |
152 | > ### 3. FSS-1000
153 | > Pretrained models with tensorboard logs are available on our [[Google Drive](https://drive.google.com/drive/folders/1JOaaJknGwsrSEPoLF3x6_lDiy4XfAe99?usp=sharing)].
154 | > ```bash
155 | > python test.py --backbone {vgg16, resnet50, resnet101}
156 | > --benchmark fss
157 | > --nshot {1, 5}
158 | > --load "path_to_trained_model/best_model.pt"
159 | > ```
160 |
161 | > ### 4. Evaluation *without support feature masking* on PASCAL-5i
162 | > * To reproduce the results in Tab.1 of our main paper, **COMMENT OUT line 51 in hsnet.py**: support_feats = self.mask_feature(support_feats, support_mask.clone())
163 | >
164 | > Pretrained models with tensorboard logs are available on our [[Google Drive](https://drive.google.com/drive/folders/18YWMCePIrza194pZvVMqQBuYqhwBmJwd?usp=sharing)].
165 | > ```bash
166 | > python test.py --backbone resnet101
167 | > --fold {0, 1, 2, 3}
168 | > --benchmark pascal
169 | > --nshot {1, 5}
170 | > --load "path_to_trained_model/best_model.pt"
171 | > ```
172 |
173 |
174 | ## Visualization
175 |
176 | * To visualize mask predictions, add command line argument **--visualize**:
177 | (prediction results will be saved under vis/ directory)
178 | ```bash
179 | python test.py '...other arguments...' --visualize
180 | ```
181 |
182 | #### Example qualitative results (1-shot):
183 |
184 |
185 |
186 |
187 |
188 | ## BibTeX
189 | If you use this code for your research, please consider citing:
190 | ````BibTeX
191 | @InProceedings{min2021hypercorrelation,
192 | title={Hypercorrelation Squeeze for Few-Shot Segmentation},
193 | author={Juhong Min and Dahyun Kang and Minsu Cho},
194 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
195 | year={2021}
196 | }
197 | ````
198 |
--------------------------------------------------------------------------------
/fewshot_data/common/evaluation.py:
--------------------------------------------------------------------------------
1 | r""" Evaluate mask prediction """
2 | import torch
3 |
4 |
5 | class Evaluator:
6 | r""" Computes intersection and union between prediction and ground-truth """
7 | @classmethod
8 | def initialize(cls):
9 | cls.ignore_index = 255
10 |
11 | @classmethod
12 | def classify_prediction(cls, pred_mask, gt_mask, query_ignore_idx=None):
13 | # gt_mask = batch.get('query_mask')
14 |
15 | # # Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020))
16 | # query_ignore_idx = batch.get('query_ignore_idx')
17 | if query_ignore_idx is not None:
18 | assert torch.logical_and(query_ignore_idx, gt_mask).sum() == 0
19 | query_ignore_idx *= cls.ignore_index
20 | gt_mask = gt_mask + query_ignore_idx
21 | pred_mask[gt_mask == cls.ignore_index] = cls.ignore_index
22 |
23 | # compute intersection and union of each episode in a batch
24 | area_inter, area_pred, area_gt = [], [], []
25 | for _pred_mask, _gt_mask in zip(pred_mask, gt_mask):
26 | _inter = _pred_mask[_pred_mask == _gt_mask]
27 | if _inter.size(0) == 0: # as torch.histc returns error if it gets empty tensor (pytorch 1.5.1)
28 | _area_inter = torch.tensor([0, 0], device=_pred_mask.device)
29 | else:
30 | _area_inter = torch.histc(_inter, bins=2, min=0, max=1)
31 | area_inter.append(_area_inter)
32 | area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1))
33 | area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1))
34 | area_inter = torch.stack(area_inter).t()
35 | area_pred = torch.stack(area_pred).t()
36 | area_gt = torch.stack(area_gt).t()
37 | area_union = area_pred + area_gt - area_inter
38 |
39 | return area_inter, area_union
40 |
--------------------------------------------------------------------------------
/fewshot_data/common/logger.py:
--------------------------------------------------------------------------------
1 | r""" Logging during training/testing """
2 | import datetime
3 | import logging
4 | import os
5 |
6 | from tensorboardX import SummaryWriter
7 | import torch
8 |
9 |
10 | class AverageMeter:
11 | r""" Stores loss, evaluation results """
12 | def __init__(self, dataset):
13 | self.benchmark = dataset.benchmark
14 | self.class_ids_interest = dataset.class_ids
15 | self.class_ids_interest = torch.tensor(self.class_ids_interest).cuda()
16 |
17 | if self.benchmark == 'pascal':
18 | self.nclass = 20
19 | elif self.benchmark == 'coco':
20 | self.nclass = 80
21 | elif self.benchmark == 'fss':
22 | self.nclass = 1000
23 |
24 | self.intersection_buf = torch.zeros([2, self.nclass]).float().cuda()
25 | self.union_buf = torch.zeros([2, self.nclass]).float().cuda()
26 | self.ones = torch.ones_like(self.union_buf)
27 | self.loss_buf = []
28 |
29 | def update(self, inter_b, union_b, class_id, loss):
30 | self.intersection_buf.index_add_(1, class_id, inter_b.float())
31 | self.union_buf.index_add_(1, class_id, union_b.float())
32 | if loss is None:
33 | loss = torch.tensor(0.0)
34 | self.loss_buf.append(loss)
35 |
36 | def compute_iou(self):
37 | iou = self.intersection_buf.float() / \
38 | torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0]
39 | iou = iou.index_select(1, self.class_ids_interest)
40 | miou = iou[1].mean() * 100
41 |
42 | fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) /
43 | self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100
44 |
45 | return miou, fb_iou
46 |
47 | def write_result(self, split, epoch):
48 | iou, fb_iou = self.compute_iou()
49 |
50 | loss_buf = torch.stack(self.loss_buf)
51 | msg = '\n*** %s ' % split
52 | msg += '[@Epoch %02d] ' % epoch
53 | msg += 'Avg L: %6.5f ' % loss_buf.mean()
54 | msg += 'mIoU: %5.2f ' % iou
55 | msg += 'FB-IoU: %5.2f ' % fb_iou
56 |
57 | msg += '***\n'
58 | Logger.info(msg)
59 |
60 | def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
61 | if batch_idx % write_batch_idx == 0:
62 | msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
63 | msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
64 | iou, fb_iou = self.compute_iou()
65 | if epoch != -1:
66 | loss_buf = torch.stack(self.loss_buf)
67 | msg += 'L: %6.5f ' % loss_buf[-1]
68 | msg += 'Avg L: %6.5f ' % loss_buf.mean()
69 | msg += 'mIoU: %5.2f | ' % iou
70 | msg += 'FB-IoU: %5.2f' % fb_iou
71 | Logger.info(msg)
72 | return iou, fb_iou
73 |
74 |
75 | class Logger:
76 | r""" Writes evaluation results of training/testing """
77 | @classmethod
78 | def initialize(cls, args, training):
79 | logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
80 | logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-2].split('.')[0] + logtime
81 | if logpath == '': logpath = logtime
82 |
83 | cls.logpath = os.path.join('logs', logpath + '.log')
84 | cls.benchmark = args.benchmark
85 | if not os.path.exists(cls.logpath):
86 | os.makedirs(cls.logpath)
87 |
88 | logging.basicConfig(filemode='w',
89 | filename=os.path.join(cls.logpath, 'log.txt'),
90 | level=logging.INFO,
91 | format='%(message)s',
92 | datefmt='%m-%d %H:%M:%S')
93 |
94 | # Console log config
95 | console = logging.StreamHandler()
96 | console.setLevel(logging.INFO)
97 | formatter = logging.Formatter('%(message)s')
98 | console.setFormatter(formatter)
99 | logging.getLogger('').addHandler(console)
100 |
101 | # Tensorboard writer
102 | cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))
103 |
104 | # Log arguments
105 | logging.info('\n:=========== Few-shot Seg. with HSNet ===========')
106 | for arg_key in args.__dict__:
107 | logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
108 | logging.info(':================================================\n')
109 |
110 | @classmethod
111 | def info(cls, msg):
112 | r""" Writes log message to log.txt """
113 | logging.info(msg)
114 |
115 | @classmethod
116 | def save_model_miou(cls, model, epoch, val_miou):
117 | torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt'))
118 | cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))
119 |
120 | @classmethod
121 | def log_params(cls, model):
122 | backbone_param = 0
123 | learner_param = 0
124 | for k in model.state_dict().keys():
125 | n_param = model.state_dict()[k].view(-1).size(0)
126 | if k.split('.')[0] == 'backbone':
127 | if k.split('.')[1] in ['classifier', 'fc']: # as fc layers are not used in HSNet
128 | continue
129 | backbone_param += n_param
130 | else:
131 | learner_param += n_param
132 | Logger.info('Backbone # param.: %d' % backbone_param)
133 | Logger.info('Learnable # param.: %d' % learner_param)
134 | Logger.info('Total # param.: %d' % (backbone_param + learner_param))
135 |
136 |
--------------------------------------------------------------------------------
/fewshot_data/common/utils.py:
--------------------------------------------------------------------------------
1 | r""" Helper functions """
2 | import random
3 |
4 | import torch
5 | import numpy as np
6 |
7 |
8 | def fix_randseed(seed):
9 | r""" Set random seeds for reproducibility """
10 | if seed is None:
11 | seed = int(random.random() * 1e5)
12 | np.random.seed(seed)
13 | torch.manual_seed(seed)
14 | torch.cuda.manual_seed(seed)
15 | torch.cuda.manual_seed_all(seed)
16 | torch.backends.cudnn.benchmark = False
17 | torch.backends.cudnn.deterministic = True
18 |
19 |
20 | def mean(x):
21 | return sum(x) / len(x) if len(x) > 0 else 0.0
22 |
23 |
24 | def to_cuda(batch):
25 | for key, value in batch.items():
26 | if isinstance(value, torch.Tensor):
27 | batch[key] = value.cuda()
28 | return batch
29 |
30 |
31 | def to_cpu(tensor):
32 | return tensor.detach().clone().cpu()
33 |
--------------------------------------------------------------------------------
/fewshot_data/common/vis.py:
--------------------------------------------------------------------------------
1 | r""" Visualize model predictions """
2 | import os
3 |
4 | from PIL import Image
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 |
8 | from fewshot_data.common import utils
9 |
10 |
11 | class Visualizer:
12 |
13 | @classmethod
14 | def initialize(cls, visualize):
15 | cls.visualize = visualize
16 | if not visualize:
17 | return
18 |
19 | cls.colors = {'red': (255, 50, 50), 'blue': (102, 140, 255)}
20 | for key, value in cls.colors.items():
21 | cls.colors[key] = tuple([c / 255 for c in cls.colors[key]])
22 |
23 | # cls.mean_img = [0.485, 0.456, 0.406]
24 | # cls.std_img = [0.229, 0.224, 0.225]
25 | cls.mean_img = [0.5] * 3
26 | cls.std_img = [0.5] * 3
27 | cls.to_pil = transforms.ToPILImage()
28 | cls.vis_path = './vis/'
29 | if not os.path.exists(cls.vis_path): os.makedirs(cls.vis_path)
30 |
31 | @classmethod
32 | def visualize_prediction_batch(cls, spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b, batch_idx, iou_b=None):
33 | spt_img_b = utils.to_cpu(spt_img_b)
34 | spt_mask_b = utils.to_cpu(spt_mask_b)
35 | qry_img_b = utils.to_cpu(qry_img_b)
36 | qry_mask_b = utils.to_cpu(qry_mask_b)
37 | pred_mask_b = utils.to_cpu(pred_mask_b)
38 | cls_id_b = utils.to_cpu(cls_id_b)
39 |
40 | for sample_idx, (spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id) in \
41 | enumerate(zip(spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b)):
42 | iou = iou_b[sample_idx] if iou_b is not None else None
43 | cls.visualize_prediction(spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, True, iou)
44 |
45 | @classmethod
46 | def to_numpy(cls, tensor, type):
47 | if type == 'img':
48 | return np.array(cls.to_pil(cls.unnormalize(tensor))).astype(np.uint8)
49 | elif type == 'mask':
50 | return np.array(tensor).astype(np.uint8)
51 | else:
52 | raise Exception('Undefined tensor type: %s' % type)
53 |
54 | @classmethod
55 | def visualize_prediction(cls, spt_imgs, spt_masks, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, label, iou=None):
56 |
57 | spt_color = cls.colors['blue']
58 | qry_color = cls.colors['red']
59 | pred_color = cls.colors['red']
60 |
61 | spt_imgs = [cls.to_numpy(spt_img, 'img') for spt_img in spt_imgs]
62 | spt_pils = [cls.to_pil(spt_img) for spt_img in spt_imgs]
63 | spt_masks = [cls.to_numpy(spt_mask, 'mask') for spt_mask in spt_masks]
64 | spt_masked_pils = [Image.fromarray(cls.apply_mask(spt_img, spt_mask, spt_color)) for spt_img, spt_mask in zip(spt_imgs, spt_masks)]
65 |
66 | qry_img = cls.to_numpy(qry_img, 'img')
67 | qry_pil = cls.to_pil(qry_img)
68 | qry_mask = cls.to_numpy(qry_mask, 'mask')
69 | pred_mask = cls.to_numpy(pred_mask, 'mask')
70 | pred_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), pred_mask.astype(np.uint8), pred_color))
71 | qry_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), qry_mask.astype(np.uint8), qry_color))
72 |
73 | merged_pil = cls.merge_image_pair(spt_masked_pils + [pred_masked_pil, qry_masked_pil])
74 |
75 | iou = iou.item() if iou else 0.0
76 | merged_pil.save(cls.vis_path + '%d_%d_class-%d_iou-%.2f' % (batch_idx, sample_idx, cls_id, iou) + '.jpg')
77 |
78 | @classmethod
79 | def merge_image_pair(cls, pil_imgs):
80 | r""" Horizontally aligns a pair of pytorch tensor images (3, H, W) and returns PIL object """
81 |
82 | canvas_width = sum([pil.size[0] for pil in pil_imgs])
83 | canvas_height = max([pil.size[1] for pil in pil_imgs])
84 | canvas = Image.new('RGB', (canvas_width, canvas_height))
85 |
86 | xpos = 0
87 | for pil in pil_imgs:
88 | canvas.paste(pil, (xpos, 0))
89 | xpos += pil.size[0]
90 |
91 | return canvas
92 |
93 | @classmethod
94 | def apply_mask(cls, image, mask, color, alpha=0.5):
95 | r""" Apply mask to the given image. """
96 | for c in range(3):
97 | image[:, :, c] = np.where(mask == 1,
98 | image[:, :, c] *
99 | (1 - alpha) + alpha * color[c] * 255,
100 | image[:, :, c])
101 | return image
102 |
103 | @classmethod
104 | def unnormalize(cls, img):
105 | img = img.clone()
106 | for im_channel, mean, std in zip(img, cls.mean_img, cls.std_img):
107 | im_channel.mul_(std).add_(mean)
108 | return img
109 |
--------------------------------------------------------------------------------
/fewshot_data/data/assets/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/assets/architecture.png
--------------------------------------------------------------------------------
/fewshot_data/data/assets/qualitative_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/assets/qualitative_results.png
--------------------------------------------------------------------------------
/fewshot_data/data/coco.py:
--------------------------------------------------------------------------------
1 | r""" COCO-20i few-shot semantic segmentation dataset """
2 | import os
3 | import pickle
4 |
5 | from torch.utils.data import Dataset
6 | import torch.nn.functional as F
7 | import torch
8 | import PIL.Image as Image
9 | import numpy as np
10 |
11 |
12 | class DatasetCOCO(Dataset):
13 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize):
14 | self.split = 'val' if split in ['val', 'test'] else 'trn'
15 | self.fold = fold
16 | self.nfolds = 4
17 | self.nclass = 80
18 | self.benchmark = 'coco'
19 | self.shot = shot
20 | self.split_coco = split if split == 'val2014' else 'train2014'
21 | self.base_path = os.path.join(datapath, 'COCO2014')
22 | self.transform = transform
23 | self.use_original_imgsize = use_original_imgsize
24 |
25 | self.class_ids = self.build_class_ids()
26 | self.img_metadata_classwise = self.build_img_metadata_classwise()
27 | self.img_metadata = self.build_img_metadata()
28 |
29 | def __len__(self):
30 | return len(self.img_metadata) if self.split == 'trn' else 1000
31 |
32 | def __getitem__(self, idx):
33 | # ignores idx during training & testing and performs uniform sampling over object classes to form an episode
34 | # (due to the large size of the COCO dataset)
35 | query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize = self.load_frame()
36 |
37 | query_img = self.transform(query_img)
38 | query_mask = query_mask.float()
39 | if not self.use_original_imgsize:
40 | query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()
41 |
42 | if self.shot:
43 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])
44 | for midx, smask in enumerate(support_masks):
45 | support_masks[midx] = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
46 | support_masks = torch.stack(support_masks)
47 |
48 |
49 | batch = {'query_img': query_img,
50 | 'query_mask': query_mask,
51 | 'query_name': query_name,
52 |
53 | 'org_query_imsize': org_qry_imsize,
54 |
55 | 'support_imgs': support_imgs,
56 | 'support_masks': support_masks,
57 | 'support_names': support_names,
58 | 'class_id': torch.tensor(class_sample)}
59 |
60 | return batch
61 |
62 | def build_class_ids(self):
63 | nclass_trn = self.nclass // self.nfolds
64 | class_ids_val = [self.fold + self.nfolds * v for v in range(nclass_trn)]
65 | class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val]
66 | class_ids = class_ids_trn if self.split == 'trn' else class_ids_val
67 |
68 | return class_ids
69 |
70 | def build_img_metadata_classwise(self):
71 | with open('fewshot_data/data/splits/coco/%s/fold%d.pkl' % (self.split, self.fold), 'rb') as f:
72 | img_metadata_classwise = pickle.load(f)
73 | return img_metadata_classwise
74 |
75 | def build_img_metadata(self):
76 | img_metadata = []
77 | for k in self.img_metadata_classwise.keys():
78 | img_metadata += self.img_metadata_classwise[k]
79 | return sorted(list(set(img_metadata)))
80 |
81 | def read_mask(self, name):
82 | mask_path = os.path.join(self.base_path, 'annotations', name)
83 | mask = torch.tensor(np.array(Image.open(mask_path[:mask_path.index('.jpg')] + '.png')))
84 | return mask
85 |
86 | def load_frame(self):
87 | class_sample = np.random.choice(self.class_ids, 1, replace=False)[0]
88 | query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
89 | query_img = Image.open(os.path.join(self.base_path, query_name)).convert('RGB')
90 | query_mask = self.read_mask(query_name)
91 |
92 | org_qry_imsize = query_img.size
93 |
94 | query_mask[query_mask != class_sample + 1] = 0
95 | query_mask[query_mask == class_sample + 1] = 1
96 |
97 | support_names = []
98 | if self.shot:
99 | while True: # keep sampling support set if query == support
100 | support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
101 | if query_name != support_name: support_names.append(support_name)
102 | if len(support_names) == self.shot: break
103 |
104 | support_imgs = []
105 | support_masks = []
106 | if self.shot:
107 | for support_name in support_names:
108 | support_imgs.append(Image.open(os.path.join(self.base_path, support_name)).convert('RGB'))
109 | support_mask = self.read_mask(support_name)
110 | support_mask[support_mask != class_sample + 1] = 0
111 | support_mask[support_mask == class_sample + 1] = 1
112 | support_masks.append(support_mask)
113 |
114 | return query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize
115 |
116 |
--------------------------------------------------------------------------------
/fewshot_data/data/dataset.py:
--------------------------------------------------------------------------------
1 | r""" Dataloader builder for few-shot semantic segmentation dataset """
2 | from torchvision import transforms
3 | from torch.utils.data import DataLoader
4 |
5 | from fewshot_data.data.pascal import DatasetPASCAL
6 | from fewshot_data.data.coco import DatasetCOCO
7 | from fewshot_data.data.fss import DatasetFSS
8 |
9 |
10 | class FSSDataset:
11 | @classmethod
12 | def initialize(cls, img_size, datapath, use_original_imgsize, imagenet_norm=False):
13 | cls.datasets = {
14 | 'pascal': DatasetPASCAL,
15 | 'coco': DatasetCOCO,
16 | 'fss': DatasetFSS,
17 | }
18 |
19 | if imagenet_norm:
20 | cls.img_mean = [0.485, 0.456, 0.406]
21 | cls.img_std = [0.229, 0.224, 0.225]
22 | print('use norm: {}, {}'.format(cls.img_mean, cls.img_std))
23 | else:
24 | cls.img_mean = [0.5] * 3
25 | cls.img_std = [0.5] * 3
26 | print('use norm: {}, {}'.format(cls.img_mean, cls.img_std))
27 |
28 | cls.datapath = datapath
29 | cls.use_original_imgsize = use_original_imgsize
30 |
31 | cls.transform = transforms.Compose([transforms.Resize(size=(img_size, img_size)),
32 | transforms.ToTensor(),
33 | transforms.Normalize(cls.img_mean, cls.img_std)])
34 |
35 | @classmethod
36 | def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1):
37 | shuffle = split == 'trn'
38 | nworker = nworker if split == 'trn' else 0
39 | dataset = cls.datasets[benchmark](cls.datapath, fold=fold, transform=cls.transform, split=split, shot=shot, use_original_imgsize=cls.use_original_imgsize)
40 | dataloader = DataLoader(dataset, batch_size=bsz, shuffle=shuffle, num_workers=nworker)
41 |
42 | return dataloader
43 |
--------------------------------------------------------------------------------
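A minimal usage sketch of the FSSDataset builder above, assuming the repository root is on PYTHONPATH; the image size and dataset path are placeholders and not prescribed by this file.

from fewshot_data.data.dataset import FSSDataset

# Illustrative values; for the 'pascal' benchmark, datapath must contain VOC2012/ (see pascal.py).
FSSDataset.initialize(img_size=480, datapath='/path/to/datasets',
                      use_original_imgsize=False, imagenet_norm=True)

dataloader_trn = FSSDataset.build_dataloader(benchmark='pascal', bsz=8, nworker=4,
                                             fold=0, split='trn', shot=1)
dataloader_val = FSSDataset.build_dataloader(benchmark='pascal', bsz=1, nworker=0,
                                             fold=0, split='val', shot=1)

for batch in dataloader_val:
    # query_img: [bsz, 3, H, W]; support_imgs: [bsz, shot, 3, H, W]
    print(batch['query_img'].shape, batch['support_imgs'].shape, batch['class_id'])
    break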
/fewshot_data/data/fss.py:
--------------------------------------------------------------------------------
1 | r""" FSS-1000 few-shot semantic segmentation dataset """
2 | import os
3 | import glob
4 |
5 | from torch.utils.data import Dataset
6 | import torch.nn.functional as F
7 | import torch
8 | import PIL.Image as Image
9 | import numpy as np
10 |
11 |
12 | class DatasetFSS(Dataset):
13 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize=None):
14 | self.split = split
15 | self.benchmark = 'fss'
16 | self.shot = shot
17 |
18 | self.base_path = os.path.join(datapath, 'FSS-1000')
19 |
20 | # Given predefined test split, load randomly generated training/val splits:
21 |         # (reference regarding trn/val/test splits: https://github.com/HKUSTCV/FSS-1000/issues/7)
22 | with open('fewshot_data/data/splits/fss/%s.txt' % split, 'r') as f:
23 | self.categories = f.read().split('\n')[:-1]
24 | self.categories = sorted(self.categories)
25 |
26 | self.class_ids = self.build_class_ids()
27 | self.img_metadata = self.build_img_metadata()
28 |
29 | self.transform = transform
30 |
31 | def __len__(self):
32 | return len(self.img_metadata)
33 |
34 | def __getitem__(self, idx):
35 | query_name, support_names, class_sample = self.sample_episode(idx)
36 | query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)
37 |
38 | query_img = self.transform(query_img)
39 | query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()
40 | if self.shot:
41 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])
42 |
43 | support_masks_tmp = []
44 | for smask in support_masks:
45 | smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
46 | support_masks_tmp.append(smask)
47 | support_masks = torch.stack(support_masks_tmp)
48 |
49 | batch = {'query_img': query_img,
50 | 'query_mask': query_mask,
51 | 'query_name': query_name,
52 |
53 | 'support_imgs': support_imgs,
54 | 'support_masks': support_masks,
55 | 'support_names': support_names,
56 |
57 | 'class_id': torch.tensor(class_sample)}
58 |
59 | return batch
60 |
61 | def load_frame(self, query_name, support_names):
62 | query_img = Image.open(query_name).convert('RGB')
63 | if self.shot:
64 | support_imgs = [Image.open(name).convert('RGB') for name in support_names]
65 | else:
66 | support_imgs = []
67 |
68 | query_id = query_name.split('/')[-1].split('.')[0]
69 | query_name = os.path.join(os.path.dirname(query_name), query_id) + '.png'
70 |
71 | if self.shot:
72 | support_ids = [name.split('/')[-1].split('.')[0] for name in support_names]
73 | support_names = [os.path.join(os.path.dirname(name), sid) + '.png' for name, sid in zip(support_names, support_ids)]
74 |
75 | query_mask = self.read_mask(query_name)
76 | if self.shot:
77 | support_masks = [self.read_mask(name) for name in support_names]
78 | else:
79 | support_masks = []
80 |
81 | return query_img, query_mask, support_imgs, support_masks
82 |
83 | def read_mask(self, img_name):
84 | mask = torch.tensor(np.array(Image.open(img_name).convert('L')))
85 | mask[mask < 128] = 0
86 | mask[mask >= 128] = 1
87 | return mask
88 |
89 | def sample_episode(self, idx):
90 | query_name = self.img_metadata[idx]
91 | class_sample = self.categories.index(query_name.split('/')[-2])
92 | if self.split == 'val':
93 | class_sample += 520
94 | elif self.split == 'test':
95 | class_sample += 760
96 |
97 | support_names = []
98 | # here we only test with shot=1
99 | if self.split == 'test' and self.shot == 1:
100 | while True:
101 |                 support_name = 1  # in the test split, '1.jpg' in the query's category folder is always the support
102 | support_name = os.path.join(os.path.dirname(query_name), str(support_name)) + '.jpg'
103 | if query_name != support_name:
104 | support_names.append(support_name)
105 | else:
106 | print('Error in sample_episode!')
107 | exit()
108 | if len(support_names) == self.shot: break
109 | elif self.shot:
110 | while True: # keep sampling support set if query == support
111 | support_name = np.random.choice(range(1, 11), 1, replace=False)[0]
112 | support_name = os.path.join(os.path.dirname(query_name), str(support_name)) + '.jpg'
113 | if query_name != support_name: support_names.append(support_name)
114 | if len(support_names) == self.shot: break
115 |
116 | return query_name, support_names, class_sample
117 |
118 | def build_class_ids(self):
119 | if self.split == 'trn':
120 | class_ids = range(0, 520)
121 | elif self.split == 'val':
122 | class_ids = range(520, 760)
123 | elif self.split == 'test':
124 | class_ids = range(760, 1000)
125 | return class_ids
126 |
127 | def build_img_metadata(self):
128 | img_metadata = []
129 | for cat in self.categories:
130 | img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, cat))])
131 | if self.split == 'test' and self.shot == 1:
132 |                 for i in range(1, len(img_paths)):  # skip the first sorted image ('1.jpg'), reserved as the support
133 | img_path = img_paths[i]
134 | if os.path.basename(img_path).split('.')[1] == 'jpg':
135 | img_metadata.append(img_path)
136 | else:
137 | for img_path in img_paths:
138 | if os.path.basename(img_path).split('.')[1] == 'jpg':
139 | img_metadata.append(img_path)
140 | return img_metadata
141 |
--------------------------------------------------------------------------------
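The class indices in DatasetFSS follow a fixed global layout: the 1000 categories are split into 520 training, 240 validation and 240 test categories (see build_class_ids), and sample_episode offsets val/test indices so ids never collide across splits. A small standalone sketch of that offsetting, for illustration only:

# Global class-id layout used by DatasetFSS (520 trn / 240 val / 240 test).
def fss_class_id(split, index_within_split):
    offset = {'trn': 0, 'val': 520, 'test': 760}[split]
    return offset + index_within_split

assert fss_class_id('trn', 0) == 0
assert fss_class_id('val', 0) == 520
assert fss_class_id('test', 239) == 999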
/fewshot_data/data/pascal.py:
--------------------------------------------------------------------------------
1 | r""" PASCAL-5i few-shot semantic segmentation dataset """
2 | import os
3 |
4 | from torch.utils.data import Dataset
5 | import torch.nn.functional as F
6 | import torch
7 | import PIL.Image as Image
8 | import numpy as np
9 |
10 |
11 | class DatasetPASCAL(Dataset):
12 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize):
13 | self.split = 'val' if split in ['val', 'test'] else 'trn'
14 | self.fold = fold
15 | self.nfolds = 4
16 | self.nclass = 20
17 | self.benchmark = 'pascal'
18 | self.shot = shot
19 | self.use_original_imgsize = use_original_imgsize
20 |
21 | self.img_path = os.path.join(datapath, 'VOC2012/JPEGImages/')
22 | self.ann_path = os.path.join(datapath, 'VOC2012/SegmentationClassAug/')
23 | self.transform = transform
24 |
25 | self.class_ids = self.build_class_ids()
26 | self.img_metadata = self.build_img_metadata()
27 | self.img_metadata_classwise = self.build_img_metadata_classwise()
28 |
29 | def __len__(self):
30 | return len(self.img_metadata) if self.split == 'trn' else 1000
31 |
32 | def __getitem__(self, idx):
33 | idx %= len(self.img_metadata) # for testing, as n_images < 1000
34 | query_name, support_names, class_sample = self.sample_episode(idx)
35 | query_img, query_cmask, support_imgs, support_cmasks, org_qry_imsize = self.load_frame(query_name, support_names)
36 |
37 | query_img = self.transform(query_img)
38 | if not self.use_original_imgsize:
39 | query_cmask = F.interpolate(query_cmask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()
40 | query_mask, query_ignore_idx = self.extract_ignore_idx(query_cmask.float(), class_sample)
41 |
42 | if self.shot:
43 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])
44 |
45 | support_masks = []
46 | support_ignore_idxs = []
47 | for scmask in support_cmasks:
48 | scmask = F.interpolate(scmask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
49 | support_mask, support_ignore_idx = self.extract_ignore_idx(scmask, class_sample)
50 | support_masks.append(support_mask)
51 | support_ignore_idxs.append(support_ignore_idx)
52 | support_masks = torch.stack(support_masks)
53 | support_ignore_idxs = torch.stack(support_ignore_idxs)
54 | else:
55 | support_masks = []
56 | support_ignore_idxs = []
57 | batch = {'query_img': query_img,
58 | 'query_mask': query_mask,
59 | 'query_name': query_name,
60 | 'query_ignore_idx': query_ignore_idx,
61 |
62 | 'org_query_imsize': org_qry_imsize,
63 |
64 | 'support_imgs': support_imgs,
65 | 'support_masks': support_masks,
66 | 'support_names': support_names,
67 | 'support_ignore_idxs': support_ignore_idxs,
68 |
69 | 'class_id': torch.tensor(class_sample)}
70 |
71 | return batch
72 |
73 | def extract_ignore_idx(self, mask, class_id):
74 | boundary = (mask / 255).floor()
75 | mask[mask != class_id + 1] = 0
76 | mask[mask == class_id + 1] = 1
77 |
78 | return mask, boundary
79 |
80 | def load_frame(self, query_name, support_names):
81 | query_img = self.read_img(query_name)
82 | query_mask = self.read_mask(query_name)
83 | support_imgs = [self.read_img(name) for name in support_names]
84 | support_masks = [self.read_mask(name) for name in support_names]
85 |
86 | org_qry_imsize = query_img.size
87 |
88 | return query_img, query_mask, support_imgs, support_masks, org_qry_imsize
89 |
90 | def read_mask(self, img_name):
91 | r"""Return segmentation mask in PIL Image"""
92 | mask = torch.tensor(np.array(Image.open(os.path.join(self.ann_path, img_name) + '.png')))
93 | return mask
94 |
95 | def read_img(self, img_name):
96 | r"""Return RGB image in PIL Image"""
97 | return Image.open(os.path.join(self.img_path, img_name) + '.jpg')
98 |
99 | def sample_episode(self, idx):
100 | query_name, class_sample = self.img_metadata[idx]
101 |
102 | support_names = []
103 | if self.shot:
104 | while True: # keep sampling support set if query == support
105 | support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
106 | if query_name != support_name: support_names.append(support_name)
107 | if len(support_names) == self.shot: break
108 |
109 | return query_name, support_names, class_sample
110 |
111 | def build_class_ids(self):
112 | nclass_trn = self.nclass // self.nfolds
113 | class_ids_val = [self.fold * nclass_trn + i for i in range(nclass_trn)]
114 | class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val]
115 |
116 | if self.split == 'trn':
117 | return class_ids_trn
118 | else:
119 | return class_ids_val
120 |
121 | def build_img_metadata(self):
122 |
123 | def read_metadata(split, fold_id):
124 | fold_n_metadata = os.path.join('fewshot_data/data/splits/pascal/%s/fold%d.txt' % (split, fold_id))
125 | with open(fold_n_metadata, 'r') as f:
126 | fold_n_metadata = f.read().split('\n')[:-1]
127 | fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata]
128 | return fold_n_metadata
129 |
130 | img_metadata = []
131 | if self.split == 'trn': # For training, read image-metadata of "the other" folds
132 | for fold_id in range(self.nfolds):
133 | if fold_id == self.fold: # Skip validation fold
134 | continue
135 | img_metadata += read_metadata(self.split, fold_id)
136 | elif self.split == 'val': # For validation, read image-metadata of "current" fold
137 | img_metadata = read_metadata(self.split, self.fold)
138 | else:
139 |             raise Exception('Undefined split: %s' % self.split)
140 |
141 |         print('Total (%s) images: %d' % (self.split, len(img_metadata)))
142 |
143 | return img_metadata
144 |
145 | def build_img_metadata_classwise(self):
146 | img_metadata_classwise = {}
147 | for class_id in range(self.nclass):
148 | img_metadata_classwise[class_id] = []
149 |
150 | for img_name, img_class in self.img_metadata:
151 | img_metadata_classwise[img_class] += [img_name]
152 | return img_metadata_classwise
153 |
--------------------------------------------------------------------------------
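For reference, each line in the PASCAL split files listed below encodes an (image, class) pair as '<image id>__<1-based class index>'; build_img_metadata converts the suffix to a 0-based class id. A small parsing sketch, illustrative only:

# Parse one line of fewshot_data/data/splits/pascal/*/fold*.txt.
def parse_pascal_line(line):
    img_name, class_idx = line.split('__')
    return img_name, int(class_idx) - 1   # 1-based in the file, 0-based in the loader

assert parse_pascal_line('2007_000033__01') == ('2007_000033', 0)
assert parse_pascal_line('2007_000452__09') == ('2007_000452', 8)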
/fewshot_data/data/splits/coco/trn/fold0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/trn/fold0.pkl
--------------------------------------------------------------------------------
/fewshot_data/data/splits/coco/trn/fold1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/trn/fold1.pkl
--------------------------------------------------------------------------------
/fewshot_data/data/splits/coco/trn/fold2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/trn/fold2.pkl
--------------------------------------------------------------------------------
/fewshot_data/data/splits/coco/trn/fold3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/trn/fold3.pkl
--------------------------------------------------------------------------------
/fewshot_data/data/splits/coco/val/fold0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/val/fold0.pkl
--------------------------------------------------------------------------------
/fewshot_data/data/splits/coco/val/fold1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/val/fold1.pkl
--------------------------------------------------------------------------------
/fewshot_data/data/splits/coco/val/fold2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/val/fold2.pkl
--------------------------------------------------------------------------------
/fewshot_data/data/splits/coco/val/fold3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/fewshot_data/data/splits/coco/val/fold3.pkl
--------------------------------------------------------------------------------
/fewshot_data/data/splits/fss/test.txt:
--------------------------------------------------------------------------------
1 | bus
2 | hotel_slipper
3 | burj_al
4 | reflex_camera
5 | abe's_flyingfish
6 | oiltank_car
7 | doormat
8 | fish_eagle
9 | barber_shaver
10 | motorbike
11 | feather_clothes
12 | wandering_albatross
13 | rice_cooker
14 | delta_wing
15 | fish
16 | nintendo_switch
17 | bustard
18 | diver
19 | minicooper
20 | cathedrale_paris
21 | big_ben
22 | combination_lock
23 | villa_savoye
24 | american_alligator
25 | gym_ball
26 | andean_condor
27 | leggings
28 | pyramid_cube
29 | jet_aircraft
30 | meatloaf
31 | reel
32 | swan
33 | osprey
34 | crt_screen
35 | microscope
36 | rubber_eraser
37 | arrow
38 | monkey
39 | mitten
40 | spiderman
41 | parthenon
42 | bat
43 | chess_king
44 | sulphur_butterfly
45 | quail_egg
46 | oriole
47 | iron_man
48 | wooden_boat
49 | anise
50 | steering_wheel
51 | groenendael
52 | dwarf_beans
53 | pteropus
54 | chalk_brush
55 | bloodhound
56 | moon
57 | english_foxhound
58 | boxing_gloves
59 | peregine_falcon
60 | pyraminx
61 | cicada
62 | screw
63 | shower_curtain
64 | tredmill
65 | bulb
66 | bell_pepper
67 | lemur_catta
68 | doughnut
69 | twin_tower
70 | astronaut
71 | nintendo_3ds
72 | fennel_bulb
73 | indri
74 | captain_america_shield
75 | kunai
76 | broom
77 | iphone
78 | earphone1
79 | flying_squirrel
80 | onion
81 | vinyl
82 | sydney_opera_house
83 | oyster
84 | harmonica
85 | egg
86 | breast_pump
87 | guitar
88 | potato_chips
89 | tunnel
90 | cuckoo
91 | rubick_cube
92 | plastic_bag
93 | phonograph
94 | net_surface_shoes
95 | goldfinch
96 | ipad
97 | mite_predator
98 | coffee_mug
99 | golden_plover
100 | f1_racing
101 | lapwing
102 | nintendo_gba
103 | pizza
104 | rally_car
105 | drilling_platform
106 | cd
107 | fly
108 | magpie_bird
109 | leaf_fan
110 | little_blue_heron
111 | carriage
112 | moist_proof_pad
113 | flying_snakes
114 | dart_target
115 | warehouse_tray
116 | nintendo_wiiu
117 | chiffon_cake
118 | bath_ball
119 | manatee
120 | cloud
121 | marimba
122 | eagle
123 | ruler
124 | soymilk_machine
125 | sled
126 | seagull
127 | glider_flyingfish
128 | doublebus
129 | transport_helicopter
130 | window_screen
131 | truss_bridge
132 | wasp
133 | snowman
134 | poached_egg
135 | strawberry
136 | spinach
137 | earphone2
138 | downy_pitch
139 | taj_mahal
140 | rocking_chair
141 | cablestayed_bridge
142 | sealion
143 | banana_boat
144 | pheasant
145 | stone_lion
146 | electronic_stove
147 | fox
148 | iguana
149 | rugby_ball
150 | hang_glider
151 | water_buffalo
152 | lotus
153 | paper_plane
154 | missile
155 | flamingo
156 | american_chamelon
157 | kart
158 | chinese_knot
159 | cabbage_butterfly
160 | key
161 | church
162 | tiltrotor
163 | helicopter
164 | french_fries
165 | water_heater
166 | snow_leopard
167 | goblet
168 | fan
169 | snowplow
170 | leafhopper
171 | pspgo
172 | black_bear
173 | quail
174 | condor
175 | chandelier
176 | hair_razor
177 | white_wolf
178 | toaster
179 | pidan
180 | pyramid
181 | chicken_leg
182 | letter_opener
183 | apple_icon
184 | porcupine
185 | chicken
186 | stingray
187 | warplane
188 | windmill
189 | bamboo_slip
190 | wig
191 | flying_geckos
192 | stonechat
193 | haddock
194 | australian_terrier
195 | hover_board
196 | siamang
197 | canton_tower
198 | santa_sledge
199 | arch_bridge
200 | curlew
201 | sushi
202 | beet_root
203 | accordion
204 | leaf_egg
205 | stealth_aircraft
206 | stork
207 | bucket
208 | hawk
209 | chess_queen
210 | ocarina
211 | knife
212 | whippet
213 | cantilever_bridge
214 | may_bug
215 | wagtail
216 | leather_shoes
217 | wheelchair
218 | shumai
219 | speedboat
220 | vacuum_cup
221 | chess_knight
222 | pumpkin_pie
223 | wooden_spoon
224 | bamboo_dragonfly
225 | ganeva_chair
226 | soap
227 | clearwing_flyingfish
228 | pencil_sharpener1
229 | cricket
230 | photocopier
231 | nintendo_sp
232 | samarra_mosque
233 | clam
234 | charge_battery
235 | flying_frog
236 | ferrari911
237 | polo_shirt
238 | echidna
239 | coin
240 | tower_pisa
241 |
--------------------------------------------------------------------------------
/fewshot_data/data/splits/fss/trn.txt:
--------------------------------------------------------------------------------
1 | fountain
2 | taxi
3 | assult_rifle
4 | radio
5 | comb
6 | box_turtle
7 | igloo
8 | head_cabbage
9 | cottontail
10 | coho
11 | ashtray
12 | joystick
13 | sleeping_bag
14 | jackfruit
15 | trailer_truck
16 | shower_cap
17 | ibex
18 | kinguin
19 | squirrel
20 | ac_wall
21 | sidewinder
22 | remote_control
23 | marshmallow
24 | bolotie
25 | polar_bear
26 | rock_beauty
27 | tokyo_tower
28 | wafer
29 | red_bayberry
30 | electronic_toothbrush
31 | hartebeest
32 | cassette
33 | oil_filter
34 | bomb
35 | walnut
36 | toilet_tissue
37 | memory_stick
38 | wild_boar
39 | cableways
40 | chihuahua
41 | envelope
42 | bison
43 | poker
44 | pubg_lvl3helmet
45 | indian_cobra
46 | staffordshire
47 | park_bench
48 | wombat
49 | black_grouse
50 | submarine
51 | washer
52 | agama
53 | coyote
54 | feeder
55 | sarong
56 | buckingham_palace
57 | frog
58 | steam_locomotive
59 | acorn
60 | german_pointer
61 | obelisk
62 | polecat
63 | black_swan
64 | butterfly
65 | mountain_tent
66 | gorilla
67 | sloth_bear
68 | aubergine
69 | stinkhorn
70 | stole
71 | owl
72 | mooli
73 | pool_table
74 | collar
75 | lhasa_apso
76 | ambulance
77 | spade
78 | pufferfish
79 | paint_brush
80 | lark
81 | golf_ball
82 | hock
83 | fork
84 | drake
85 | bee_house
86 | mooncake
87 | wok
88 | cocacola
89 | water_bike
90 | ladder
91 | psp
92 | bassoon
93 | bear
94 | border_terrier
95 | petri_dish
96 | pill_bottle
97 | aircraft_carrier
98 | panther
99 | canoe
100 | baseball_player
101 | turtle
102 | espresso
103 | throne
104 | cornet
105 | coucal
106 | eletrical_switch
107 | bra
108 | snail
109 | backpack
110 | jacamar
111 | scroll_brush
112 | gliding_lizard
113 | raft
114 | pinwheel
115 | grasshopper
116 | green_mamba
117 | eft_newt
118 | computer_mouse
119 | vine_snake
120 | recreational_vehicle
121 | llama
122 | meerkat
123 | chainsaw
124 | ferret
125 | garbage_can
126 | kangaroo
127 | litchi
128 | carbonara
129 | housefinch
130 | modem
131 | tebby_cat
132 | thatch
133 | face_powder
134 | tomb
135 | apple
136 | ladybug
137 | killer_whale
138 | rocket
139 | airship
140 | surfboard
141 | lesser_panda
142 | jordan_logo
143 | banana
144 | nail_scissor
145 | swab
146 | perfume
147 | punching_bag
148 | victor_icon
149 | waffle_iron
150 | trimaran
151 | garlic
152 | flute
153 | langur
154 | starfish
155 | parallel_bars
156 | dandie_dinmont
157 | cosmetic_brush
158 | screwdriver
159 | brick_card
160 | balance_weight
161 | hornet
162 | carton
163 | toothpaste
164 | bracelet
165 | egg_tart
166 | pencil_sharpener2
167 | swimming_glasses
168 | howler_monkey
169 | camel
170 | dragonfly
171 | lionfish
172 | convertible
173 | mule
174 | usb
175 | conch
176 | papaya
177 | garbage_truck
178 | dingo
179 | radiator
180 | solar_dish
181 | streetcar
182 | trilobite
183 | bouzouki
184 | ringlet_butterfly
185 | space_shuttle
186 | waffle
187 | american_staffordshire
188 | violin
189 | flowerpot
190 | forklift
191 | manx
192 | sundial
193 | snowmobile
194 | chickadee_bird
195 | ruffed_grouse
196 | brick_tea
197 | paddle
198 | stove
199 | carousel
200 | spatula
201 | beaker
202 | gas_pump
203 | lawn_mower
204 | speaker
205 | tank
206 | tresher
207 | kappa_logo
208 | hare
209 | tennis_racket
210 | shopping_cart
211 | thimble
212 | tractor
213 | anemone_fish
214 | trolleybus
215 | steak
216 | capuchin
217 | red_breasted_merganser
218 | golden_retriever
219 | light_tube
220 | flatworm
221 | melon_seed
222 | digital_watch
223 | jacko_lantern
224 | brown_bear
225 | cairn
226 | mushroom
227 | chalk
228 | skull
229 | stapler
230 | potato
231 | telescope
232 | proboscis
233 | microphone
234 | torii
235 | baseball_bat
236 | dhole
237 | excavator
238 | fig
239 | snake
240 | bradypod
241 | pepitas
242 | prairie_chicken
243 | scorpion
244 | shotgun
245 | bottle_cap
246 | file_cabinet
247 | grey_whale
248 | one-armed_bandit
249 | banded_gecko
250 | flying_disc
251 | croissant
252 | toothbrush
253 | miniskirt
254 | pokermon_ball
255 | gazelle
256 | grey_fox
257 | esport_chair
258 | necklace
259 | ptarmigan
260 | watermelon
261 | besom
262 | pomelo
263 | radio_telescope
264 | studio_couch
265 | black_stork
266 | vestment
267 | koala
268 | brambling
269 | muscle_car
270 | window_shade
271 | space_heater
272 | sunglasses
273 | motor_scooter
274 | ladyfinger
275 | pencil_box
276 | titi_monkey
277 | chicken_wings
278 | mount_fuji
279 | giant_panda
280 | dart
281 | fire_engine
282 | running_shoe
283 | dumbbell
284 | donkey
285 | loafer
286 | hard_disk
287 | globe
288 | lifeboat
289 | medical_kit
290 | brain_coral
291 | paper_towel
292 | dugong
293 | seatbelt
294 | skunk
295 | military_vest
296 | cocktail_shaker
297 | zucchini
298 | quad_drone
299 | ocicat
300 | shih-tzu
301 | teapot
302 | tile_roof
303 | cheese_burger
304 | handshower
305 | red_wolf
306 | stop_sign
307 | mouse
308 | battery
309 | adidas_logo2
310 | earplug
311 | hummingbird
312 | brush_pen
313 | pistachio
314 | hamster
315 | air_strip
316 | indian_elephant
317 | otter
318 | cucumber
319 | scabbard
320 | hawthorn
321 | bullet_train
322 | leopard
323 | whale
324 | cream
325 | chinese_date
326 | jellyfish
327 | lobster
328 | skua
329 | single_log
330 | chicory
331 | bagel
332 | beacon
333 | pingpong_racket
334 | spoon
335 | yurt
336 | wallaby
337 | egret
338 | christmas_stocking
339 | mcdonald_uncle
340 | wrench
341 | spark_plug
342 | triceratops
343 | wall_clock
344 | jinrikisha
345 | pickup
346 | rhinoceros
347 | swimming_trunk
348 | band-aid
349 | spotted_salamander
350 | leeks
351 | marmot
352 | warthog
353 | cello
354 | stool
355 | chest
356 | toilet_plunger
357 | wardrobe
358 | cannon
359 | adidas_logo1
360 | drumstick
361 | lady_slipper
362 | puma_logo
363 | great_wall
364 | white_shark
365 | witch_hat
366 | vending_machine
367 | wreck
368 | chopsticks
369 | garfish
370 | african_elephant
371 | children_slide
372 | hornbill
373 | zebra
374 | boa_constrictor
375 | armour
376 | pineapple
377 | angora
378 | brick
379 | car_wheel
380 | wallet
381 | boston_bull
382 | hyena
383 | lynx
384 | crash_helmet
385 | terrapin_turtle
386 | persian_cat
387 | shift_gear
388 | cactus_ball
389 | fur_coat
390 | plate
391 | pen
392 | okra
393 | mario
394 | airedale
395 | cowboy_hat
396 | celery
397 | macaque
398 | candle
399 | goose
400 | raccoon
401 | brasscica
402 | almond
403 | maotai_bottle
404 | soccer_ball
405 | sports_car
406 | tobacco_pipe
407 | water_polo
408 | eggnog
409 | hook
410 | ostrich
411 | patas
412 | table_lamp
413 | teddy
414 | mongoose
415 | spoonbill
416 | redheart
417 | crane
418 | dinosaur
419 | kitchen_knife
420 | seal
421 | baboon
422 | golfcart
423 | roller_coaster
424 | avocado
425 | birdhouse
426 | yorkshire_terrier
427 | saluki
428 | basketball
429 | buckler
430 | harvester
431 | afghan_hound
432 | beam_bridge
433 | guinea_pig
434 | lorikeet
435 | shakuhachi
436 | motarboard
437 | statue_liberty
438 | police_car
439 | sulphur_crested
440 | gourd
441 | sombrero
442 | mailbox
443 | adhensive_tape
444 | night_snake
445 | bushtit
446 | mouthpiece
447 | beaver
448 | bathtub
449 | printer
450 | cumquat
451 | orange
452 | cleaver
453 | quill_pen
454 | panpipe
455 | diamond
456 | gypsy_moth
457 | cauliflower
458 | lampshade
459 | cougar
460 | traffic_light
461 | briefcase
462 | ballpoint
463 | african_grey
464 | kremlin
465 | barometer
466 | peacock
467 | paper_crane
468 | sunscreen
469 | tofu
470 | bedlington_terrier
471 | snowball
472 | carrot
473 | tiger
474 | mink
475 | cristo_redentor
476 | ladle
477 | keyboard
478 | maraca
479 | monitor
480 | water_snake
481 | can_opener
482 | mud_turtle
483 | bald_eagle
484 | carp
485 | cn_tower
486 | egyptian_cat
487 | hen_of_the_woods
488 | measuring_cup
489 | roller_skate
490 | kite
491 | sandwich_cookies
492 | sandwich
493 | persimmon
494 | chess_bishop
495 | coffin
496 | ruddy_turnstone
497 | prayer_rug
498 | rain_barrel
499 | neck_brace
500 | nematode
501 | rosehip
502 | dutch_oven
503 | goldfish
504 | blossom_card
505 | dough
506 | trench_coat
507 | sponge
508 | stupa
509 | wash_basin
510 | electric_fan
511 | spring_scroll
512 | potted_plant
513 | sparrow
514 | car_mirror
515 | gecko
516 | diaper
517 | leatherback_turtle
518 | strainer
519 | guacamole
520 | microwave
521 |
--------------------------------------------------------------------------------
/fewshot_data/data/splits/fss/val.txt:
--------------------------------------------------------------------------------
1 | handcuff
2 | mortar
3 | matchstick
4 | wine_bottle
5 | dowitcher
6 | triumphal_arch
7 | gyromitra
8 | hatchet
9 | airliner
10 | broccoli
11 | olive
12 | pubg_lvl3backpack
13 | calculator
14 | toucan
15 | shovel
16 | sewing_machine
17 | icecream
18 | woodpecker
19 | pig
20 | relay_stick
21 | mcdonald_sign
22 | cpu
23 | peanut
24 | pumpkin
25 | sturgeon
26 | hammer
27 | hami_melon
28 | squirrel_monkey
29 | shuriken
30 | power_drill
31 | pingpong_ball
32 | crocodile
33 | carambola
34 | monarch_butterfly
35 | drum
36 | water_tower
37 | panda
38 | toilet_brush
39 | pay_phone
40 | yonex_icon
41 | cricketball
42 | revolver
43 | chimpanzee
44 | crab
45 | corn
46 | baseball
47 | rabbit
48 | croquet_ball
49 | artichoke
50 | abacus
51 | harp
52 | bell
53 | gas_tank
54 | scissors
55 | vase
56 | upright_piano
57 | typewriter
58 | bittern
59 | impala
60 | tray
61 | fire_hydrant
62 | beer_bottle
63 | sock
64 | soup_bowl
65 | spider
66 | cherry
67 | macaw
68 | toilet_seat
69 | fire_balloon
70 | french_ball
71 | fox_squirrel
72 | volleyball
73 | cornmeal
74 | folding_chair
75 | pubg_airdrop
76 | beagle
77 | skateboard
78 | narcissus
79 | whiptail
80 | cup
81 | arabian_camel
82 | badger
83 | stopwatch
84 | ab_wheel
85 | ox
86 | lettuce
87 | monocycle
88 | redshank
89 | vulture
90 | whistle
91 | smoothing_iron
92 | mashed_potato
93 | conveyor
94 | yoga_pad
95 | tow_truck
96 | siamese_cat
97 | cigar
98 | white_stork
99 | sniper_rifle
100 | stretcher
101 | tulip
102 | handkerchief
103 | basset
104 | iceberg
105 | gibbon
106 | lacewing
107 | thrush
108 | cheetah
109 | bighorn_sheep
110 | espresso_maker
111 | pretzel
112 | english_setter
113 | sandbar
114 | cheese
115 | daisy
116 | arctic_fox
117 | briard
118 | colubus
119 | balance_beam
120 | coffeepot
121 | soap_dispenser
122 | yawl
123 | consomme
124 | parking_meter
125 | cactus
126 | turnstile
127 | taro
128 | fire_screen
129 | digital_clock
130 | rose
131 | pomegranate
132 | bee_eater
133 | schooner
134 | ski_mask
135 | jay_bird
136 | plaice
137 | red_fox
138 | syringe
139 | camomile
140 | pickelhaube
141 | blenheim_spaniel
142 | pear
143 | parachute
144 | common_newt
145 | bowtie
146 | cigarette
147 | oscilloscope
148 | laptop
149 | african_crocodile
150 | apron
151 | coconut
152 | sandal
153 | kwanyin
154 | lion
155 | eel
156 | balloon
157 | crepe
158 | armadillo
159 | kazoo
160 | lemon
161 | spider_monkey
162 | tape_player
163 | ipod
164 | bee
165 | sea_cucumber
166 | suitcase
167 | television
168 | pillow
169 | banjo
170 | rock_snake
171 | partridge
172 | platypus
173 | lycaenid_butterfly
174 | pinecone
175 | conversion_plug
176 | wolf
177 | frying_pan
178 | timber_wolf
179 | bluetick
180 | crayon
181 | giant_schnauzer
182 | orang
183 | scarerow
184 | kobe_logo
185 | loguat
186 | saxophone
187 | ceiling_fan
188 | cardoon
189 | equestrian_helmet
190 | louvre_pyramid
191 | hotdog
192 | ironing_board
193 | razor
194 | nagoya_castle
195 | loggerhead_turtle
196 | lipstick
197 | cradle
198 | strongbox
199 | raven
200 | kit_fox
201 | albatross
202 | flat-coated_retriever
203 | beer_glass
204 | ice_lolly
205 | sungnyemun
206 | totem_pole
207 | vacuum
208 | bolete
209 | mango
210 | ginger
211 | weasel
212 | cabbage
213 | refrigerator
214 | school_bus
215 | hippo
216 | tiger_cat
217 | saltshaker
218 | piano_keyboard
219 | windsor_tie
220 | sea_urchin
221 | microsd
222 | barbell
223 | swim_ring
224 | bulbul_bird
225 | water_ouzel
226 | ac_ground
227 | sweatshirt
228 | umbrella
229 | hair_drier
230 | hammerhead_shark
231 | tomato
232 | projector
233 | cushion
234 | dishwasher
235 | three-toed_sloth
236 | tiger_shark
237 | har_gow
238 | baby
239 | thor's_hammer
240 | nike_logo
241 |
--------------------------------------------------------------------------------
/fewshot_data/data/splits/pascal/val/fold0.txt:
--------------------------------------------------------------------------------
1 | 2007_000033__01
2 | 2007_000061__04
3 | 2007_000129__02
4 | 2007_000346__05
5 | 2007_000529__04
6 | 2007_000559__05
7 | 2007_000572__02
8 | 2007_000762__05
9 | 2007_001288__01
10 | 2007_001289__03
11 | 2007_001311__02
12 | 2007_001408__05
13 | 2007_001568__01
14 | 2007_001630__02
15 | 2007_001761__01
16 | 2007_001884__01
17 | 2007_002094__03
18 | 2007_002266__01
19 | 2007_002376__01
20 | 2007_002400__03
21 | 2007_002619__01
22 | 2007_002719__04
23 | 2007_003088__05
24 | 2007_003131__04
25 | 2007_003188__02
26 | 2007_003349__03
27 | 2007_003571__04
28 | 2007_003621__02
29 | 2007_003682__03
30 | 2007_003861__04
31 | 2007_004052__01
32 | 2007_004143__03
33 | 2007_004241__04
34 | 2007_004468__05
35 | 2007_005074__04
36 | 2007_005107__02
37 | 2007_005294__05
38 | 2007_005304__05
39 | 2007_005428__05
40 | 2007_005509__01
41 | 2007_005600__01
42 | 2007_005705__04
43 | 2007_005828__01
44 | 2007_006076__03
45 | 2007_006086__05
46 | 2007_006449__02
47 | 2007_006946__01
48 | 2007_007084__03
49 | 2007_007235__02
50 | 2007_007341__01
51 | 2007_007470__01
52 | 2007_007477__04
53 | 2007_007836__02
54 | 2007_008051__03
55 | 2007_008084__03
56 | 2007_008204__05
57 | 2007_008670__03
58 | 2007_009088__03
59 | 2007_009258__02
60 | 2007_009323__03
61 | 2007_009458__05
62 | 2007_009687__05
63 | 2007_009817__03
64 | 2007_009911__01
65 | 2008_000120__04
66 | 2008_000123__03
67 | 2008_000533__03
68 | 2008_000725__02
69 | 2008_000911__05
70 | 2008_001013__04
71 | 2008_001040__04
72 | 2008_001135__04
73 | 2008_001260__04
74 | 2008_001404__02
75 | 2008_001514__03
76 | 2008_001531__02
77 | 2008_001546__01
78 | 2008_001580__04
79 | 2008_001966__03
80 | 2008_001971__01
81 | 2008_002043__03
82 | 2008_002269__02
83 | 2008_002358__01
84 | 2008_002429__03
85 | 2008_002467__05
86 | 2008_002504__04
87 | 2008_002775__05
88 | 2008_002864__05
89 | 2008_003034__04
90 | 2008_003076__05
91 | 2008_003108__02
92 | 2008_003110__03
93 | 2008_003155__01
94 | 2008_003270__02
95 | 2008_003369__01
96 | 2008_003858__04
97 | 2008_003876__01
98 | 2008_003886__04
99 | 2008_003926__01
100 | 2008_003976__01
101 | 2008_004363__02
102 | 2008_004654__02
103 | 2008_004659__05
104 | 2008_004704__01
105 | 2008_004758__02
106 | 2008_004995__02
107 | 2008_005262__05
108 | 2008_005338__01
109 | 2008_005628__04
110 | 2008_005727__02
111 | 2008_005812__05
112 | 2008_005904__05
113 | 2008_006216__01
114 | 2008_006229__04
115 | 2008_006254__02
116 | 2008_006703__01
117 | 2008_007120__03
118 | 2008_007143__04
119 | 2008_007219__05
120 | 2008_007350__01
121 | 2008_007498__03
122 | 2008_007811__05
123 | 2008_007994__03
124 | 2008_008268__03
125 | 2008_008629__02
126 | 2008_008711__02
127 | 2008_008746__03
128 | 2009_000032__01
129 | 2009_000037__03
130 | 2009_000121__05
131 | 2009_000149__02
132 | 2009_000201__05
133 | 2009_000205__01
134 | 2009_000318__03
135 | 2009_000354__02
136 | 2009_000387__01
137 | 2009_000421__04
138 | 2009_000440__01
139 | 2009_000446__04
140 | 2009_000457__02
141 | 2009_000469__04
142 | 2009_000573__02
143 | 2009_000619__03
144 | 2009_000664__03
145 | 2009_000723__04
146 | 2009_000828__04
147 | 2009_000840__05
148 | 2009_000879__03
149 | 2009_000991__03
150 | 2009_000998__03
151 | 2009_001108__03
152 | 2009_001160__03
153 | 2009_001255__02
154 | 2009_001278__05
155 | 2009_001314__03
156 | 2009_001332__01
157 | 2009_001565__03
158 | 2009_001607__03
159 | 2009_001683__03
160 | 2009_001718__02
161 | 2009_001765__03
162 | 2009_001818__05
163 | 2009_001850__01
164 | 2009_001851__01
165 | 2009_001941__04
166 | 2009_002185__05
167 | 2009_002295__02
168 | 2009_002320__01
169 | 2009_002372__05
170 | 2009_002521__05
171 | 2009_002594__05
172 | 2009_002604__03
173 | 2009_002649__05
174 | 2009_002727__04
175 | 2009_002732__05
176 | 2009_002749__05
177 | 2009_002808__01
178 | 2009_002856__05
179 | 2009_002888__01
180 | 2009_002928__02
181 | 2009_003003__05
182 | 2009_003005__01
183 | 2009_003043__04
184 | 2009_003080__04
185 | 2009_003193__02
186 | 2009_003224__02
187 | 2009_003269__05
188 | 2009_003273__03
189 | 2009_003343__02
190 | 2009_003378__03
191 | 2009_003450__03
192 | 2009_003498__03
193 | 2009_003504__04
194 | 2009_003517__05
195 | 2009_003640__03
196 | 2009_003696__01
197 | 2009_003707__04
198 | 2009_003806__01
199 | 2009_003858__03
200 | 2009_003971__02
201 | 2009_004021__03
202 | 2009_004084__03
203 | 2009_004125__04
204 | 2009_004247__05
205 | 2009_004324__05
206 | 2009_004509__03
207 | 2009_004540__03
208 | 2009_004568__03
209 | 2009_004579__05
210 | 2009_004635__04
211 | 2009_004653__01
212 | 2009_004848__02
213 | 2009_004882__02
214 | 2009_004886__03
215 | 2009_004895__03
216 | 2009_004969__01
217 | 2009_005038__05
218 | 2009_005137__03
219 | 2009_005156__02
220 | 2009_005189__01
221 | 2009_005190__05
222 | 2009_005260__03
223 | 2009_005262__03
224 | 2009_005302__05
225 | 2010_000065__02
226 | 2010_000083__02
227 | 2010_000084__04
228 | 2010_000238__01
229 | 2010_000241__03
230 | 2010_000272__04
231 | 2010_000342__02
232 | 2010_000426__05
233 | 2010_000572__01
234 | 2010_000622__01
235 | 2010_000814__03
236 | 2010_000906__04
237 | 2010_000961__03
238 | 2010_001016__03
239 | 2010_001017__01
240 | 2010_001024__01
241 | 2010_001036__04
242 | 2010_001061__03
243 | 2010_001069__03
244 | 2010_001174__01
245 | 2010_001367__02
246 | 2010_001367__05
247 | 2010_001448__01
248 | 2010_001830__05
249 | 2010_001995__03
250 | 2010_002017__05
251 | 2010_002030__02
252 | 2010_002142__03
253 | 2010_002147__01
254 | 2010_002150__04
255 | 2010_002200__01
256 | 2010_002310__01
257 | 2010_002536__02
258 | 2010_002546__04
259 | 2010_002693__02
260 | 2010_002939__01
261 | 2010_003127__01
262 | 2010_003132__01
263 | 2010_003168__03
264 | 2010_003362__03
265 | 2010_003365__01
266 | 2010_003418__03
267 | 2010_003468__05
268 | 2010_003473__03
269 | 2010_003495__01
270 | 2010_003547__04
271 | 2010_003716__01
272 | 2010_003771__03
273 | 2010_003781__05
274 | 2010_003820__03
275 | 2010_003912__02
276 | 2010_003915__01
277 | 2010_004041__04
278 | 2010_004056__05
279 | 2010_004208__04
280 | 2010_004314__01
281 | 2010_004419__01
282 | 2010_004520__05
283 | 2010_004529__05
284 | 2010_004551__05
285 | 2010_004556__03
286 | 2010_004559__03
287 | 2010_004662__04
288 | 2010_004772__04
289 | 2010_004828__05
290 | 2010_004994__03
291 | 2010_005252__04
292 | 2010_005401__04
293 | 2010_005428__03
294 | 2010_005496__05
295 | 2010_005531__03
296 | 2010_005534__01
297 | 2010_005582__05
298 | 2010_005664__02
299 | 2010_005705__04
300 | 2010_005718__01
301 | 2010_005762__05
302 | 2010_005877__01
303 | 2010_005888__01
304 | 2010_006034__01
305 | 2010_006070__02
306 | 2011_000066__05
307 | 2011_000112__03
308 | 2011_000185__03
309 | 2011_000234__04
310 | 2011_000238__04
311 | 2011_000412__02
312 | 2011_000435__04
313 | 2011_000456__03
314 | 2011_000482__03
315 | 2011_000585__02
316 | 2011_000669__03
317 | 2011_000747__05
318 | 2011_000874__01
319 | 2011_001114__01
320 | 2011_001161__04
321 | 2011_001263__01
322 | 2011_001287__03
323 | 2011_001407__01
324 | 2011_001421__03
325 | 2011_001434__01
326 | 2011_001589__04
327 | 2011_001624__01
328 | 2011_001793__04
329 | 2011_001880__01
330 | 2011_001988__02
331 | 2011_002064__02
332 | 2011_002098__05
333 | 2011_002223__02
334 | 2011_002295__03
335 | 2011_002327__01
336 | 2011_002515__01
337 | 2011_002675__01
338 | 2011_002713__02
339 | 2011_002754__04
340 | 2011_002863__05
341 | 2011_002929__01
342 | 2011_002975__04
343 | 2011_003003__02
344 | 2011_003030__03
345 | 2011_003145__03
346 | 2011_003271__05
347 |
--------------------------------------------------------------------------------
/fewshot_data/data/splits/pascal/val/fold1.txt:
--------------------------------------------------------------------------------
1 | 2007_000452__09
2 | 2007_000464__10
3 | 2007_000491__10
4 | 2007_000663__06
5 | 2007_000663__07
6 | 2007_000727__06
7 | 2007_000727__07
8 | 2007_000804__09
9 | 2007_000830__09
10 | 2007_001299__10
11 | 2007_001321__07
12 | 2007_001457__09
13 | 2007_001677__09
14 | 2007_001717__09
15 | 2007_001763__08
16 | 2007_001774__08
17 | 2007_001884__06
18 | 2007_002268__08
19 | 2007_002387__10
20 | 2007_002445__08
21 | 2007_002470__08
22 | 2007_002539__06
23 | 2007_002597__08
24 | 2007_002643__07
25 | 2007_002903__10
26 | 2007_003011__09
27 | 2007_003051__07
28 | 2007_003101__06
29 | 2007_003106__08
30 | 2007_003137__06
31 | 2007_003143__07
32 | 2007_003169__08
33 | 2007_003195__06
34 | 2007_003201__10
35 | 2007_003503__06
36 | 2007_003503__07
37 | 2007_003621__06
38 | 2007_003711__06
39 | 2007_003786__06
40 | 2007_003841__10
41 | 2007_003917__07
42 | 2007_003991__08
43 | 2007_004193__09
44 | 2007_004392__09
45 | 2007_004405__09
46 | 2007_004510__09
47 | 2007_004712__09
48 | 2007_004856__08
49 | 2007_004866__08
50 | 2007_005074__07
51 | 2007_005114__10
52 | 2007_005296__07
53 | 2007_005331__07
54 | 2007_005460__08
55 | 2007_005547__07
56 | 2007_005547__10
57 | 2007_005844__09
58 | 2007_005845__08
59 | 2007_005911__06
60 | 2007_005978__06
61 | 2007_006035__07
62 | 2007_006086__09
63 | 2007_006241__09
64 | 2007_006260__08
65 | 2007_006277__07
66 | 2007_006348__09
67 | 2007_006553__09
68 | 2007_006761__10
69 | 2007_006841__10
70 | 2007_007414__07
71 | 2007_007417__08
72 | 2007_007524__08
73 | 2007_007815__07
74 | 2007_007818__07
75 | 2007_007996__09
76 | 2007_008106__09
77 | 2007_008110__09
78 | 2007_008543__09
79 | 2007_008722__10
80 | 2007_008747__06
81 | 2007_008815__08
82 | 2007_008897__09
83 | 2007_008973__10
84 | 2007_009015__06
85 | 2007_009015__07
86 | 2007_009068__09
87 | 2007_009084__09
88 | 2007_009096__07
89 | 2007_009221__08
90 | 2007_009245__10
91 | 2007_009346__08
92 | 2007_009392__06
93 | 2007_009392__07
94 | 2007_009413__09
95 | 2007_009521__09
96 | 2007_009764__06
97 | 2007_009794__08
98 | 2007_009897__10
99 | 2007_009923__08
100 | 2007_009938__07
101 | 2008_000009__10
102 | 2008_000073__10
103 | 2008_000075__06
104 | 2008_000107__09
105 | 2008_000149__09
106 | 2008_000182__08
107 | 2008_000345__08
108 | 2008_000401__08
109 | 2008_000464__08
110 | 2008_000501__07
111 | 2008_000673__09
112 | 2008_000853__08
113 | 2008_000919__10
114 | 2008_001078__08
115 | 2008_001433__08
116 | 2008_001439__09
117 | 2008_001513__08
118 | 2008_001640__08
119 | 2008_001715__09
120 | 2008_001885__08
121 | 2008_002152__08
122 | 2008_002205__06
123 | 2008_002212__07
124 | 2008_002379__09
125 | 2008_002521__09
126 | 2008_002623__08
127 | 2008_002681__08
128 | 2008_002778__10
129 | 2008_002958__07
130 | 2008_003141__06
131 | 2008_003141__07
132 | 2008_003333__07
133 | 2008_003477__09
134 | 2008_003499__08
135 | 2008_003577__07
136 | 2008_003777__06
137 | 2008_003821__09
138 | 2008_003846__07
139 | 2008_004069__07
140 | 2008_004339__07
141 | 2008_004552__07
142 | 2008_004612__09
143 | 2008_004701__10
144 | 2008_005097__10
145 | 2008_005105__10
146 | 2008_005245__07
147 | 2008_005676__06
148 | 2008_006008__09
149 | 2008_006063__10
150 | 2008_006254__07
151 | 2008_006325__08
152 | 2008_006341__08
153 | 2008_006480__08
154 | 2008_006528__10
155 | 2008_006554__06
156 | 2008_006986__07
157 | 2008_007025__10
158 | 2008_007031__10
159 | 2008_007048__09
160 | 2008_007123__10
161 | 2008_007194__09
162 | 2008_007273__10
163 | 2008_007378__09
164 | 2008_007402__09
165 | 2008_007527__09
166 | 2008_007548__08
167 | 2008_007596__10
168 | 2008_007737__09
169 | 2008_007797__06
170 | 2008_007804__07
171 | 2008_007828__09
172 | 2008_008252__06
173 | 2008_008301__06
174 | 2008_008469__06
175 | 2008_008682__06
176 | 2009_000013__08
177 | 2009_000080__08
178 | 2009_000219__10
179 | 2009_000309__10
180 | 2009_000335__06
181 | 2009_000335__07
182 | 2009_000426__06
183 | 2009_000455__06
184 | 2009_000457__07
185 | 2009_000523__07
186 | 2009_000641__10
187 | 2009_000716__08
188 | 2009_000731__10
189 | 2009_000771__10
190 | 2009_000825__07
191 | 2009_000964__08
192 | 2009_001008__08
193 | 2009_001082__06
194 | 2009_001240__07
195 | 2009_001255__07
196 | 2009_001299__09
197 | 2009_001391__08
198 | 2009_001411__08
199 | 2009_001536__07
200 | 2009_001775__09
201 | 2009_001804__06
202 | 2009_001816__06
203 | 2009_001854__06
204 | 2009_002035__10
205 | 2009_002122__10
206 | 2009_002150__10
207 | 2009_002164__07
208 | 2009_002171__10
209 | 2009_002221__10
210 | 2009_002238__06
211 | 2009_002238__07
212 | 2009_002239__07
213 | 2009_002268__08
214 | 2009_002346__09
215 | 2009_002415__09
216 | 2009_002487__09
217 | 2009_002527__08
218 | 2009_002535__06
219 | 2009_002549__10
220 | 2009_002571__09
221 | 2009_002618__07
222 | 2009_002635__10
223 | 2009_002753__08
224 | 2009_002936__08
225 | 2009_002990__07
226 | 2009_003003__07
227 | 2009_003059__10
228 | 2009_003071__09
229 | 2009_003269__07
230 | 2009_003304__06
231 | 2009_003387__07
232 | 2009_003406__07
233 | 2009_003494__09
234 | 2009_003507__09
235 | 2009_003542__10
236 | 2009_003549__07
237 | 2009_003569__10
238 | 2009_003589__07
239 | 2009_003703__06
240 | 2009_003771__08
241 | 2009_003773__10
242 | 2009_003849__09
243 | 2009_003895__09
244 | 2009_003904__08
245 | 2009_004072__06
246 | 2009_004140__09
247 | 2009_004217__09
248 | 2009_004248__08
249 | 2009_004455__07
250 | 2009_004504__08
251 | 2009_004590__06
252 | 2009_004594__07
253 | 2009_004687__09
254 | 2009_004721__08
255 | 2009_004732__06
256 | 2009_004748__07
257 | 2009_004789__06
258 | 2009_004859__09
259 | 2009_004867__06
260 | 2009_005158__08
261 | 2009_005219__08
262 | 2009_005231__06
263 | 2010_000003__09
264 | 2010_000160__07
265 | 2010_000163__08
266 | 2010_000372__07
267 | 2010_000427__10
268 | 2010_000530__07
269 | 2010_000552__08
270 | 2010_000573__06
271 | 2010_000628__07
272 | 2010_000639__09
273 | 2010_000682__06
274 | 2010_000683__08
275 | 2010_000724__08
276 | 2010_000907__10
277 | 2010_000941__08
278 | 2010_000952__07
279 | 2010_001000__10
280 | 2010_001010__10
281 | 2010_001070__08
282 | 2010_001206__06
283 | 2010_001292__08
284 | 2010_001331__08
285 | 2010_001351__08
286 | 2010_001403__06
287 | 2010_001403__07
288 | 2010_001534__08
289 | 2010_001553__07
290 | 2010_001579__09
291 | 2010_001646__06
292 | 2010_001656__08
293 | 2010_001692__10
294 | 2010_001699__09
295 | 2010_001767__07
296 | 2010_001851__09
297 | 2010_001913__08
298 | 2010_002017__07
299 | 2010_002017__09
300 | 2010_002025__08
301 | 2010_002137__08
302 | 2010_002146__08
303 | 2010_002305__08
304 | 2010_002336__09
305 | 2010_002348__08
306 | 2010_002361__07
307 | 2010_002390__10
308 | 2010_002422__08
309 | 2010_002512__08
310 | 2010_002531__08
311 | 2010_002546__06
312 | 2010_002623__09
313 | 2010_002693__08
314 | 2010_002693__09
315 | 2010_002763__08
316 | 2010_002763__10
317 | 2010_002868__06
318 | 2010_002900__08
319 | 2010_002902__07
320 | 2010_002921__09
321 | 2010_002929__07
322 | 2010_002988__07
323 | 2010_003123__07
324 | 2010_003183__10
325 | 2010_003231__07
326 | 2010_003239__10
327 | 2010_003275__08
328 | 2010_003276__07
329 | 2010_003293__06
330 | 2010_003302__09
331 | 2010_003325__09
332 | 2010_003381__07
333 | 2010_003402__08
334 | 2010_003409__09
335 | 2010_003446__07
336 | 2010_003453__07
337 | 2010_003468__08
338 | 2010_003531__09
339 | 2010_003675__08
340 | 2010_003746__07
341 | 2010_003758__08
342 | 2010_003764__08
343 | 2010_003768__07
344 | 2010_003772__06
345 | 2010_003781__08
346 | 2010_003813__07
347 | 2010_003854__07
348 | 2010_003971__08
349 | 2010_003971__09
350 | 2010_004104__08
351 | 2010_004120__08
352 | 2010_004320__08
353 | 2010_004322__10
354 | 2010_004348__06
355 | 2010_004369__08
356 | 2010_004472__07
357 | 2010_004479__08
358 | 2010_004635__10
359 | 2010_004763__09
360 | 2010_004783__09
361 | 2010_004789__10
362 | 2010_004815__08
363 | 2010_004825__09
364 | 2010_004861__08
365 | 2010_004946__07
366 | 2010_005013__07
367 | 2010_005021__08
368 | 2010_005021__09
369 | 2010_005063__06
370 | 2010_005108__08
371 | 2010_005118__06
372 | 2010_005160__06
373 | 2010_005166__10
374 | 2010_005284__06
375 | 2010_005344__08
376 | 2010_005421__08
377 | 2010_005432__07
378 | 2010_005501__07
379 | 2010_005508__08
380 | 2010_005606__08
381 | 2010_005709__08
382 | 2010_005718__07
383 | 2010_005860__07
384 | 2010_005899__08
385 | 2010_006070__07
386 | 2011_000178__06
387 | 2011_000226__09
388 | 2011_000239__06
389 | 2011_000248__06
390 | 2011_000312__06
391 | 2011_000338__09
392 | 2011_000419__08
393 | 2011_000503__07
394 | 2011_000548__10
395 | 2011_000566__10
396 | 2011_000607__09
397 | 2011_000661__08
398 | 2011_000661__09
399 | 2011_000780__08
400 | 2011_000789__08
401 | 2011_000809__09
402 | 2011_000813__08
403 | 2011_000813__09
404 | 2011_000830__06
405 | 2011_000843__09
406 | 2011_000888__06
407 | 2011_000900__07
408 | 2011_000969__06
409 | 2011_001047__10
410 | 2011_001064__06
411 | 2011_001071__09
412 | 2011_001110__07
413 | 2011_001159__10
414 | 2011_001232__10
415 | 2011_001292__08
416 | 2011_001341__06
417 | 2011_001346__09
418 | 2011_001447__09
419 | 2011_001530__10
420 | 2011_001534__08
421 | 2011_001546__10
422 | 2011_001567__09
423 | 2011_001597__08
424 | 2011_001601__08
425 | 2011_001607__08
426 | 2011_001665__09
427 | 2011_001708__10
428 | 2011_001775__08
429 | 2011_001782__10
430 | 2011_001812__09
431 | 2011_002041__09
432 | 2011_002064__07
433 | 2011_002124__09
434 | 2011_002200__09
435 | 2011_002298__09
436 | 2011_002322__07
437 | 2011_002343__09
438 | 2011_002358__09
439 | 2011_002391__09
440 | 2011_002509__09
441 | 2011_002592__07
442 | 2011_002644__09
443 | 2011_002685__08
444 | 2011_002812__07
445 | 2011_002885__10
446 | 2011_003011__09
447 | 2011_003019__07
448 | 2011_003019__10
449 | 2011_003055__07
450 | 2011_003103__09
451 | 2011_003114__06
452 |
--------------------------------------------------------------------------------
/fewshot_data/data/splits/pascal/val/fold3.txt:
--------------------------------------------------------------------------------
1 | 2007_000042__19
2 | 2007_000123__19
3 | 2007_000175__17
4 | 2007_000187__20
5 | 2007_000452__18
6 | 2007_000559__20
7 | 2007_000629__19
8 | 2007_000636__19
9 | 2007_000661__18
10 | 2007_000676__17
11 | 2007_000804__18
12 | 2007_000925__17
13 | 2007_001154__18
14 | 2007_001175__20
15 | 2007_001408__16
16 | 2007_001430__16
17 | 2007_001430__20
18 | 2007_001457__18
19 | 2007_001458__18
20 | 2007_001585__18
21 | 2007_001594__17
22 | 2007_001678__20
23 | 2007_001717__20
24 | 2007_001733__17
25 | 2007_001763__18
26 | 2007_001763__20
27 | 2007_002119__20
28 | 2007_002132__20
29 | 2007_002268__18
30 | 2007_002284__16
31 | 2007_002378__16
32 | 2007_002426__18
33 | 2007_002427__18
34 | 2007_002565__19
35 | 2007_002618__17
36 | 2007_002648__17
37 | 2007_002728__19
38 | 2007_003011__18
39 | 2007_003011__20
40 | 2007_003169__18
41 | 2007_003367__16
42 | 2007_003499__19
43 | 2007_003506__16
44 | 2007_003530__18
45 | 2007_003587__19
46 | 2007_003714__17
47 | 2007_003848__19
48 | 2007_003957__19
49 | 2007_004190__20
50 | 2007_004193__20
51 | 2007_004275__16
52 | 2007_004281__19
53 | 2007_004483__19
54 | 2007_004510__20
55 | 2007_004558__16
56 | 2007_004649__19
57 | 2007_004712__16
58 | 2007_004969__17
59 | 2007_005469__17
60 | 2007_005626__19
61 | 2007_005689__19
62 | 2007_005813__16
63 | 2007_005857__16
64 | 2007_005915__17
65 | 2007_006171__18
66 | 2007_006348__20
67 | 2007_006373__18
68 | 2007_006678__17
69 | 2007_006680__19
70 | 2007_006802__19
71 | 2007_007130__20
72 | 2007_007165__17
73 | 2007_007168__19
74 | 2007_007195__19
75 | 2007_007196__20
76 | 2007_007203__20
77 | 2007_007417__18
78 | 2007_007534__17
79 | 2007_007624__16
80 | 2007_007795__16
81 | 2007_007881__19
82 | 2007_007996__18
83 | 2007_008204__20
84 | 2007_008260__18
85 | 2007_008339__19
86 | 2007_008374__20
87 | 2007_008543__18
88 | 2007_008547__16
89 | 2007_009068__18
90 | 2007_009252__18
91 | 2007_009320__17
92 | 2007_009419__16
93 | 2007_009446__20
94 | 2007_009521__18
95 | 2007_009521__20
96 | 2007_009592__18
97 | 2007_009655__18
98 | 2007_009684__18
99 | 2007_009750__16
100 | 2008_000016__20
101 | 2008_000149__18
102 | 2008_000270__18
103 | 2008_000391__16
104 | 2008_000589__18
105 | 2008_000657__19
106 | 2008_001078__16
107 | 2008_001283__16
108 | 2008_001688__16
109 | 2008_001688__20
110 | 2008_001966__16
111 | 2008_002273__16
112 | 2008_002379__16
113 | 2008_002464__20
114 | 2008_002536__17
115 | 2008_002680__20
116 | 2008_002900__19
117 | 2008_002929__18
118 | 2008_003003__20
119 | 2008_003026__20
120 | 2008_003105__19
121 | 2008_003135__16
122 | 2008_003676__16
123 | 2008_003709__18
124 | 2008_003733__18
125 | 2008_003885__20
126 | 2008_004172__18
127 | 2008_004212__19
128 | 2008_004279__20
129 | 2008_004367__19
130 | 2008_004453__17
131 | 2008_004477__16
132 | 2008_004562__18
133 | 2008_004610__19
134 | 2008_004621__17
135 | 2008_004754__20
136 | 2008_004854__17
137 | 2008_004910__20
138 | 2008_005089__20
139 | 2008_005217__16
140 | 2008_005242__16
141 | 2008_005254__20
142 | 2008_005439__20
143 | 2008_005445__20
144 | 2008_005544__19
145 | 2008_005633__17
146 | 2008_005680__16
147 | 2008_006055__19
148 | 2008_006159__20
149 | 2008_006327__17
150 | 2008_006523__19
151 | 2008_006553__19
152 | 2008_006752__19
153 | 2008_006784__18
154 | 2008_006835__17
155 | 2008_007497__17
156 | 2008_007527__20
157 | 2008_007677__17
158 | 2008_007814__17
159 | 2008_007828__20
160 | 2008_008103__18
161 | 2008_008221__19
162 | 2008_008434__16
163 | 2009_000022__19
164 | 2009_000039__17
165 | 2009_000087__18
166 | 2009_000096__18
167 | 2009_000136__20
168 | 2009_000242__18
169 | 2009_000391__20
170 | 2009_000418__16
171 | 2009_000418__18
172 | 2009_000487__18
173 | 2009_000488__16
174 | 2009_000488__20
175 | 2009_000628__19
176 | 2009_000675__17
177 | 2009_000704__20
178 | 2009_000712__19
179 | 2009_000732__18
180 | 2009_000845__19
181 | 2009_000924__17
182 | 2009_001300__19
183 | 2009_001333__19
184 | 2009_001363__20
185 | 2009_001505__17
186 | 2009_001644__16
187 | 2009_001644__18
188 | 2009_001644__20
189 | 2009_001684__16
190 | 2009_001731__18
191 | 2009_001768__17
192 | 2009_001775__16
193 | 2009_001775__18
194 | 2009_001991__17
195 | 2009_002082__17
196 | 2009_002094__20
197 | 2009_002202__19
198 | 2009_002265__19
199 | 2009_002291__19
200 | 2009_002346__18
201 | 2009_002366__20
202 | 2009_002390__18
203 | 2009_002487__16
204 | 2009_002562__20
205 | 2009_002568__19
206 | 2009_002571__16
207 | 2009_002571__18
208 | 2009_002573__20
209 | 2009_002584__16
210 | 2009_002638__19
211 | 2009_002732__18
212 | 2009_002887__19
213 | 2009_002982__19
214 | 2009_003105__19
215 | 2009_003123__18
216 | 2009_003299__19
217 | 2009_003311__19
218 | 2009_003433__19
219 | 2009_003523__20
220 | 2009_003551__20
221 | 2009_003564__16
222 | 2009_003564__18
223 | 2009_003607__18
224 | 2009_003666__17
225 | 2009_003857__20
226 | 2009_003895__18
227 | 2009_003895__20
228 | 2009_003938__19
229 | 2009_004099__18
230 | 2009_004140__18
231 | 2009_004255__19
232 | 2009_004298__18
233 | 2009_004687__18
234 | 2009_004730__19
235 | 2009_004799__19
236 | 2009_004993__18
237 | 2009_004993__20
238 | 2009_005148__19
239 | 2009_005220__19
240 | 2010_000256__18
241 | 2010_000284__18
242 | 2010_000309__17
243 | 2010_000318__20
244 | 2010_000330__16
245 | 2010_000639__16
246 | 2010_000738__20
247 | 2010_000764__19
248 | 2010_001011__17
249 | 2010_001079__17
250 | 2010_001104__19
251 | 2010_001149__18
252 | 2010_001151__19
253 | 2010_001246__16
254 | 2010_001256__17
255 | 2010_001327__18
256 | 2010_001367__20
257 | 2010_001522__17
258 | 2010_001557__17
259 | 2010_001577__17
260 | 2010_001699__16
261 | 2010_001734__19
262 | 2010_001752__20
263 | 2010_001767__18
264 | 2010_001773__16
265 | 2010_001851__16
266 | 2010_001951__19
267 | 2010_001962__18
268 | 2010_002106__17
269 | 2010_002137__16
270 | 2010_002137__18
271 | 2010_002232__17
272 | 2010_002531__18
273 | 2010_002682__19
274 | 2010_002921__20
275 | 2010_003014__18
276 | 2010_003123__16
277 | 2010_003302__16
278 | 2010_003514__19
279 | 2010_003541__17
280 | 2010_003597__18
281 | 2010_003781__16
282 | 2010_003956__19
283 | 2010_004149__19
284 | 2010_004226__17
285 | 2010_004382__16
286 | 2010_004479__20
287 | 2010_004757__16
288 | 2010_004757__18
289 | 2010_004783__18
290 | 2010_004825__16
291 | 2010_004857__20
292 | 2010_004951__19
293 | 2010_004980__19
294 | 2010_005180__18
295 | 2010_005187__16
296 | 2010_005305__20
297 | 2010_005606__18
298 | 2010_005706__19
299 | 2010_005719__17
300 | 2010_005727__19
301 | 2010_005788__17
302 | 2010_005860__16
303 | 2010_005871__19
304 | 2010_005991__18
305 | 2010_006054__19
306 | 2011_000070__18
307 | 2011_000173__18
308 | 2011_000283__19
309 | 2011_000291__19
310 | 2011_000310__18
311 | 2011_000436__17
312 | 2011_000521__19
313 | 2011_000747__16
314 | 2011_001005__18
315 | 2011_001060__19
316 | 2011_001281__19
317 | 2011_001350__17
318 | 2011_001567__18
319 | 2011_001601__18
320 | 2011_001614__19
321 | 2011_001674__18
322 | 2011_001713__16
323 | 2011_001713__18
324 | 2011_001726__20
325 | 2011_001794__18
326 | 2011_001862__18
327 | 2011_001863__16
328 | 2011_001910__20
329 | 2011_002124__18
330 | 2011_002156__20
331 | 2011_002178__17
332 | 2011_002247__19
333 | 2011_002379__19
334 | 2011_002391__18
335 | 2011_002532__20
336 | 2011_002535__19
337 | 2011_002644__18
338 | 2011_002644__20
339 | 2011_002879__18
340 | 2011_002879__20
341 | 2011_003103__16
342 | 2011_003103__18
343 | 2011_003146__19
344 | 2011_003182__18
345 | 2011_003197__19
346 | 2011_003256__18
347 |
--------------------------------------------------------------------------------
/fewshot_data/model/base/conv4d.py:
--------------------------------------------------------------------------------
1 | r""" Implementation of center-pivot 4D convolution """
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class CenterPivotConv4d(nn.Module):
8 | r""" CenterPivot 4D conv"""
9 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True):
10 | super(CenterPivotConv4d, self).__init__()
11 |
12 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size[:2], stride=stride[:2],
13 | bias=bias, padding=padding[:2])
14 | self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size[2:], stride=stride[2:],
15 | bias=bias, padding=padding[2:])
16 |
17 | self.stride34 = stride[2:]
18 | self.kernel_size = kernel_size
19 | self.stride = stride
20 | self.padding = padding
21 | self.idx_initialized = False
22 |
23 | def prune(self, ct):
24 | bsz, ch, ha, wa, hb, wb = ct.size()
25 | if not self.idx_initialized:
26 | idxh = torch.arange(start=0, end=hb, step=self.stride[2:][0], device=ct.device)
27 | idxw = torch.arange(start=0, end=wb, step=self.stride[2:][1], device=ct.device)
28 | self.len_h = len(idxh)
29 | self.len_w = len(idxw)
30 | self.idx = (idxw.repeat(self.len_h, 1) + idxh.repeat(self.len_w, 1).t() * wb).view(-1)
31 | self.idx_initialized = True
32 | ct_pruned = ct.view(bsz, ch, ha, wa, -1).index_select(4, self.idx).view(bsz, ch, ha, wa, self.len_h, self.len_w)
33 |
34 | return ct_pruned
35 |
36 | def forward(self, x):
37 | if self.stride[2:][-1] > 1:
38 | out1 = self.prune(x)
39 | else:
40 | out1 = x
41 | bsz, inch, ha, wa, hb, wb = out1.size()
42 | out1 = out1.permute(0, 4, 5, 1, 2, 3).contiguous().view(-1, inch, ha, wa)
43 | out1 = self.conv1(out1)
44 | outch, o_ha, o_wa = out1.size(-3), out1.size(-2), out1.size(-1)
45 | out1 = out1.view(bsz, hb, wb, outch, o_ha, o_wa).permute(0, 3, 4, 5, 1, 2).contiguous()
46 |
47 | bsz, inch, ha, wa, hb, wb = x.size()
48 | out2 = x.permute(0, 2, 3, 1, 4, 5).contiguous().view(-1, inch, hb, wb)
49 | out2 = self.conv2(out2)
50 | outch, o_hb, o_wb = out2.size(-3), out2.size(-2), out2.size(-1)
51 | out2 = out2.view(bsz, ha, wa, outch, o_hb, o_wb).permute(0, 3, 1, 2, 4, 5).contiguous()
52 |
53 | if out1.size()[-2:] != out2.size()[-2:] and self.padding[-2:] == (0, 0):
54 | out1 = out1.view(bsz, outch, o_ha, o_wa, -1).sum(dim=-1)
55 | out2 = out2.squeeze()
56 |
57 | y = out1 + out2
58 | return y
59 |
--------------------------------------------------------------------------------
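
A minimal usage sketch for the center-pivot convolution defined above, assuming a dummy 6-dimensional correlation tensor; the batch size, channel counts, and spatial sizes are illustrative only and are chosen so the query-side and support-side convolutions keep matching shapes.

import torch
from fewshot_data.model.base.conv4d import CenterPivotConv4d

# Dummy hypercorrelation tensor: (batch, channels, H_q, W_q, H_s, W_s)
corr = torch.randn(2, 4, 13, 13, 13, 13)

# A 3x3x3x3 kernel with unit stride and padding 1 keeps every spatial dim unchanged,
# so the query-side (conv1) and support-side (conv2) outputs can be summed directly.
conv = CenterPivotConv4d(in_channels=4, out_channels=16,
                         kernel_size=(3, 3, 3, 3),
                         stride=(1, 1, 1, 1),
                         padding=(1, 1, 1, 1))

out = conv(corr)
print(out.shape)  # torch.Size([2, 16, 13, 13, 13, 13])
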
/fewshot_data/model/base/correlation.py:
--------------------------------------------------------------------------------
1 | r""" Provides functions that build/manipulate correlation tensors """
2 | import torch
3 |
4 |
5 | class Correlation:
6 |
7 | @classmethod
8 | def multilayer_correlation(cls, query_feats, support_feats, stack_ids):
9 | eps = 1e-5
10 |
11 | corrs = []
12 | for idx, (query_feat, support_feat) in enumerate(zip(query_feats, support_feats)):
13 | bsz, ch, hb, wb = support_feat.size()
14 | support_feat = support_feat.view(bsz, ch, -1)
15 | support_feat = support_feat / (support_feat.norm(dim=1, p=2, keepdim=True) + eps)
16 |
17 | bsz, ch, ha, wa = query_feat.size()
18 | query_feat = query_feat.view(bsz, ch, -1)
19 | query_feat = query_feat / (query_feat.norm(dim=1, p=2, keepdim=True) + eps)
20 |
21 | corr = torch.bmm(query_feat.transpose(1, 2), support_feat).view(bsz, ha, wa, hb, wb)
22 | corr = corr.clamp(min=0)
23 | corrs.append(corr)
24 |
25 | corr_l4 = torch.stack(corrs[-stack_ids[0]:]).transpose(0, 1).contiguous()
26 | corr_l3 = torch.stack(corrs[-stack_ids[1]:-stack_ids[0]]).transpose(0, 1).contiguous()
27 | corr_l2 = torch.stack(corrs[-stack_ids[2]:-stack_ids[1]]).transpose(0, 1).contiguous()
28 |
29 | return [corr_l4, corr_l3, corr_l2]
30 |
--------------------------------------------------------------------------------
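
A small sketch of the input/output contract of multilayer_correlation above. In the full model the stack_ids come from the backbone's bottleneck layout (see hsnet.py); the value used here is purely illustrative and simply places one feature map into each pyramid level.

import torch
from fewshot_data.model.base.correlation import Correlation

# Three query/support feature pairs of shape (batch, channels, h, w);
# stack_ids = [1, 2, 3] is an illustrative split, not the backbone-derived one.
query_feats = [torch.randn(2, 64, 8, 8) for _ in range(3)]
support_feats = [torch.randn(2, 64, 8, 8) for _ in range(3)]
stack_ids = torch.tensor([1, 2, 3])

corr_l4, corr_l3, corr_l2 = Correlation.multilayer_correlation(query_feats, support_feats, stack_ids)
print(corr_l4.shape)  # torch.Size([2, 1, 8, 8, 8, 8])
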
/fewshot_data/model/base/feature.py:
--------------------------------------------------------------------------------
1 | r""" Extracts intermediate features from given backbone network & layer ids """
2 |
3 |
4 | def extract_feat_vgg(img, backbone, feat_ids, bottleneck_ids=None, lids=None):
5 | r""" Extract intermediate features from VGG """
6 | feats = []
7 | feat = img
8 | for lid, module in enumerate(backbone.features):
9 | feat = module(feat)
10 | if lid in feat_ids:
11 | feats.append(feat.clone())
12 | return feats
13 |
14 |
15 | def extract_feat_res(img, backbone, feat_ids, bottleneck_ids, lids):
16 | r""" Extract intermediate features from ResNet"""
17 | feats = []
18 |
19 | # Layer 0
20 | feat = backbone.conv1.forward(img)
21 | feat = backbone.bn1.forward(feat)
22 | feat = backbone.relu.forward(feat)
23 | feat = backbone.maxpool.forward(feat)
24 |
25 | # Layer 1-4
26 | for hid, (bid, lid) in enumerate(zip(bottleneck_ids, lids)):
27 | res = feat
28 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv1.forward(feat)
29 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn1.forward(feat)
30 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
31 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv2.forward(feat)
32 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn2.forward(feat)
33 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
34 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv3.forward(feat)
35 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn3.forward(feat)
36 |
37 | if bid == 0:
38 | res = backbone.__getattr__('layer%d' % lid)[bid].downsample.forward(res)
39 |
40 | feat += res
41 |
42 | if hid + 1 in feat_ids:
43 | feats.append(feat.clone())
44 |
45 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
46 |
47 | return feats
--------------------------------------------------------------------------------
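
A sketch of calling the VGG16 extractor above on a random image; the feat_ids mirror the hook positions registered in hsnet.py, and pretrained=False is used only to keep the snippet light (the actual model loads ImageNet weights).

import torch
from torchvision.models import vgg
from fewshot_data.model.base.feature import extract_feat_vgg

backbone = vgg.vgg16(pretrained=False).eval()  # hsnet.py uses pretrained=True
img = torch.randn(1, 3, 400, 400)

# Same layer ids as hsnet.py uses for the VGG16 backbone
feats = extract_feat_vgg(img, backbone, feat_ids=[17, 19, 21, 24, 26, 28, 30])
print(len(feats))  # 7 intermediate feature maps
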
/fewshot_data/model/hsnet.py:
--------------------------------------------------------------------------------
1 | r""" Hypercorrelation Squeeze Network """
2 | from functools import reduce
3 | from operator import add
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torchvision.models import resnet
9 | from torchvision.models import vgg
10 |
11 | from fewshot_data.model.base.feature import extract_feat_vgg, extract_feat_res
12 | from fewshot_data.model.base.correlation import Correlation
13 | from fewshot_data.model.learner import HPNLearner
14 |
15 |
16 | class HypercorrSqueezeNetwork(nn.Module):
17 | def __init__(self, backbone, use_original_imgsize):
18 | super(HypercorrSqueezeNetwork, self).__init__()
19 |
20 | # 1. Backbone network initialization
21 | self.backbone_type = backbone
22 | self.use_original_imgsize = use_original_imgsize
23 | if backbone == 'vgg16':
24 | self.backbone = vgg.vgg16(pretrained=True)
25 | self.feat_ids = [17, 19, 21, 24, 26, 28, 30]
26 | self.extract_feats = extract_feat_vgg
27 | nbottlenecks = [2, 2, 3, 3, 3, 1]
28 | elif backbone == 'resnet50':
29 | self.backbone = resnet.resnet50(pretrained=True)
30 | self.feat_ids = list(range(4, 17))
31 | self.extract_feats = extract_feat_res
32 | nbottlenecks = [3, 4, 6, 3]
33 | elif backbone == 'resnet101':
34 | self.backbone = resnet.resnet101(pretrained=True)
35 | self.feat_ids = list(range(4, 34))
36 | self.extract_feats = extract_feat_res
37 | nbottlenecks = [3, 4, 23, 3]
38 | else:
39 | raise Exception('Unavailable backbone: %s' % backbone)
40 |
41 | self.bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), nbottlenecks)))
42 | self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(nbottlenecks)])
43 | self.stack_ids = torch.tensor(self.lids).bincount().__reversed__().cumsum(dim=0)[:3]
44 | self.backbone.eval()
45 | self.hpn_learner = HPNLearner(list(reversed(nbottlenecks[-3:])))
46 | self.cross_entropy_loss = nn.CrossEntropyLoss()
47 |
48 | def forward(self, query_img, support_img, support_mask):
49 | with torch.no_grad():
50 | query_feats = self.extract_feats(query_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
51 | support_feats = self.extract_feats(support_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
52 | support_feats = self.mask_feature(support_feats, support_mask.clone())
53 | corr = Correlation.multilayer_correlation(query_feats, support_feats, self.stack_ids)
54 |
55 | logit_mask = self.hpn_learner(corr)
56 | if not self.use_original_imgsize:
57 | logit_mask = F.interpolate(logit_mask, support_img.size()[2:], mode='bilinear', align_corners=True)
58 |
59 | return logit_mask
60 |
61 | def mask_feature(self, features, support_mask):
62 | for idx, feature in enumerate(features):
63 | mask = F.interpolate(support_mask.unsqueeze(1).float(), feature.size()[2:], mode='bilinear', align_corners=True)
64 | features[idx] = features[idx] * mask
65 | return features
66 |
67 | def predict_mask_nshot(self, batch, nshot):
68 |
69 | # Perform multiple prediction given (nshot) number of different support sets
70 | logit_mask_agg = 0
71 | for s_idx in range(nshot):
72 | logit_mask = self(batch['query_img'], batch['support_imgs'][:, s_idx], batch['support_masks'][:, s_idx])
73 |
74 | if self.use_original_imgsize:
75 | org_qry_imsize = tuple([batch['org_query_imsize'][1].item(), batch['org_query_imsize'][0].item()])
76 | logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True)
77 |
78 | logit_mask_agg += logit_mask.argmax(dim=1).clone()
79 | if nshot == 1: return logit_mask_agg
80 |
81 | # Average & quantize predictions given threshold (=0.5)
82 | bsz = logit_mask_agg.size(0)
83 | max_vote = logit_mask_agg.view(bsz, -1).max(dim=1)[0]
84 | max_vote = torch.stack([max_vote, torch.ones_like(max_vote).long()])
85 | max_vote = max_vote.max(dim=0)[0].view(bsz, 1, 1)
86 | pred_mask = logit_mask_agg.float() / max_vote
87 | pred_mask[pred_mask < 0.5] = 0
88 | pred_mask[pred_mask >= 0.5] = 1
89 |
90 | return pred_mask
91 |
92 | def compute_objective(self, logit_mask, gt_mask):
93 | bsz = logit_mask.size(0)
94 | logit_mask = logit_mask.view(bsz, 2, -1)
95 | gt_mask = gt_mask.view(bsz, -1).long()
96 |
97 | return self.cross_entropy_loss(logit_mask, gt_mask)
98 |
99 | def train_mode(self):
100 | self.train()
101 | self.backbone.eval() # to prevent BN from learning data statistics with exponential averaging
102 |
--------------------------------------------------------------------------------
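
A minimal 1-shot forward-pass sketch for the network above, assuming random tensors in place of a real episode; the 400x400 input size matches the dataloaders in train.py/test.py, and constructing the model downloads the ImageNet-pretrained backbone.

import torch
from fewshot_data.model.hsnet import HypercorrSqueezeNetwork

model = HypercorrSqueezeNetwork(backbone='resnet50', use_original_imgsize=False)
model.eval()

query_img = torch.randn(1, 3, 400, 400)
support_img = torch.randn(1, 3, 400, 400)
support_mask = (torch.rand(1, 400, 400) > 0.5).float()  # binary support mask

with torch.no_grad():
    logit_mask = model(query_img, support_img, support_mask)  # (1, 2, 400, 400)
pred_mask = logit_mask.argmax(dim=1)                          # (1, 400, 400)
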
/fewshot_data/model/learner.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from fewshot_data.model.base.conv4d import CenterPivotConv4d as Conv4d
6 |
7 |
8 | class HPNLearner(nn.Module):
9 | def __init__(self, inch):
10 | super(HPNLearner, self).__init__()
11 |
12 | def make_building_block(in_channel, out_channels, kernel_sizes, spt_strides, group=4):
13 | assert len(out_channels) == len(kernel_sizes) == len(spt_strides)
14 |
15 | building_block_layers = []
16 | for idx, (outch, ksz, stride) in enumerate(zip(out_channels, kernel_sizes, spt_strides)):
17 | inch = in_channel if idx == 0 else out_channels[idx - 1]
18 | ksz4d = (ksz,) * 4
19 | str4d = (1, 1) + (stride,) * 2
20 | pad4d = (ksz // 2,) * 4
21 |
22 | building_block_layers.append(Conv4d(inch, outch, ksz4d, str4d, pad4d))
23 | building_block_layers.append(nn.GroupNorm(group, outch))
24 | building_block_layers.append(nn.ReLU(inplace=True))
25 |
26 | return nn.Sequential(*building_block_layers)
27 |
28 | outch1, outch2, outch3 = 16, 64, 128
29 |
30 | # Squeezing building blocks
31 | self.encoder_layer4 = make_building_block(inch[0], [outch1, outch2, outch3], [3, 3, 3], [2, 2, 2])
32 | self.encoder_layer3 = make_building_block(inch[1], [outch1, outch2, outch3], [5, 3, 3], [4, 2, 2])
33 | self.encoder_layer2 = make_building_block(inch[2], [outch1, outch2, outch3], [5, 5, 3], [4, 4, 2])
34 |
35 | # Mixing building blocks
36 | self.encoder_layer4to3 = make_building_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1])
37 | self.encoder_layer3to2 = make_building_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1])
38 |
39 | # Decoder layers
40 | self.decoder1 = nn.Sequential(nn.Conv2d(outch3, outch3, (3, 3), padding=(1, 1), bias=True),
41 | nn.ReLU(),
42 | nn.Conv2d(outch3, outch2, (3, 3), padding=(1, 1), bias=True),
43 | nn.ReLU())
44 |
45 | self.decoder2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True),
46 | nn.ReLU(),
47 | nn.Conv2d(outch2, 2, (3, 3), padding=(1, 1), bias=True))
48 |
49 | def interpolate_support_dims(self, hypercorr, spatial_size=None):
50 | bsz, ch, ha, wa, hb, wb = hypercorr.size()
51 | hypercorr = hypercorr.permute(0, 4, 5, 1, 2, 3).contiguous().view(bsz * hb * wb, ch, ha, wa)
52 | hypercorr = F.interpolate(hypercorr, spatial_size, mode='bilinear', align_corners=True)
53 | o_hb, o_wb = spatial_size
54 | hypercorr = hypercorr.view(bsz, hb, wb, ch, o_hb, o_wb).permute(0, 3, 4, 5, 1, 2).contiguous()
55 | return hypercorr
56 |
57 | def forward(self, hypercorr_pyramid):
58 |
59 | # Encode hypercorrelations from each layer (Squeezing building blocks)
60 | hypercorr_sqz4 = self.encoder_layer4(hypercorr_pyramid[0])
61 | hypercorr_sqz3 = self.encoder_layer3(hypercorr_pyramid[1])
62 | hypercorr_sqz2 = self.encoder_layer2(hypercorr_pyramid[2])
63 |
64 | # Propagate encoded 4D-tensor (Mixing building blocks)
65 | hypercorr_sqz4 = self.interpolate_support_dims(hypercorr_sqz4, hypercorr_sqz3.size()[-4:-2])
66 | hypercorr_mix43 = hypercorr_sqz4 + hypercorr_sqz3
67 | hypercorr_mix43 = self.encoder_layer4to3(hypercorr_mix43)
68 |
69 | hypercorr_mix43 = self.interpolate_support_dims(hypercorr_mix43, hypercorr_sqz2.size()[-4:-2])
70 | hypercorr_mix432 = hypercorr_mix43 + hypercorr_sqz2
71 | hypercorr_mix432 = self.encoder_layer3to2(hypercorr_mix432)
72 |
73 | bsz, ch, ha, wa, hb, wb = hypercorr_mix432.size()
74 | hypercorr_encoded = hypercorr_mix432.view(bsz, ch, ha, wa, -1).mean(dim=-1)
75 |
76 | # Decode the encoded 4D-tensor
77 | hypercorr_decoded = self.decoder1(hypercorr_encoded)
78 | upsample_size = (hypercorr_decoded.size(-1) * 2,) * 2
79 | hypercorr_decoded = F.interpolate(hypercorr_decoded, upsample_size, mode='bilinear', align_corners=True)
80 | logit_mask = self.decoder2(hypercorr_decoded)
81 |
82 | return logit_mask
83 |
--------------------------------------------------------------------------------
/fewshot_data/sbatch_run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATE=`date -d now`
3 | EXP=hsnet
4 | NGPU=4
5 | partition=g24
6 | JOB=fewshot_${EXP}
7 | SAVE_ROOT="save/${EXP}"
8 | SCRIPT_ROOT="sweep_scripts/${EXP}"
9 | mkdir -p $SCRIPT_ROOT
10 | NCPU=$((NGPU * 10))
11 | qos=normal # high normal low
12 |
13 | function print_append {
14 | echo $@ >> $SCRIPT
15 | }
16 |
17 | function slurm_append {
18 | echo $@ >> $SLURM
19 | }
20 |
21 | function print_setup {
22 | SAVE="${SAVE_ROOT}/${JOB}"
23 | SCRIPT="${SCRIPT_ROOT}/${JOB}.sh"
24 | SLURM="${SCRIPT_ROOT}/${JOB}.slrm"
25 | mkdir -p $SAVE
26 | echo `date -d now` $SAVE >> 'submitted.txt'
27 | echo "#!/bin/bash" > $SLURM
28 | slurm_append "#SBATCH --job-name=job1111_${JOB}"
29 | slurm_append "#SBATCH --output=${SAVE}/stdout.txt"
30 | slurm_append "#SBATCH --error=${SAVE}/stderr.txt"
31 | slurm_append "#SBATCH --open-mode=append"
32 | slurm_append "#SBATCH --signal=B:USR1@120"
33 |
34 | slurm_append "#SBATCH -p ${partition}"
35 | slurm_append "#SBATCH --gres=gpu:${NGPU}"
36 | slurm_append "#SBATCH -c ${NCPU}"
37 | slurm_append "#SBATCH -t 02-00"
38 | # slurm_append "#SBATCH -t 01-00"
39 | # slurm_append "#SBATCH -t 00-06"
40 | slurm_append "#SBATCH --qos=${qos}"
41 | slurm_append "srun sh ${SCRIPT}"
42 |
43 | echo "#!/bin/bash" > $SCRIPT
44 | print_append "trap_handler () {"
45 | print_append "echo \"Caught signal: \" \$1"
46 | print_append "# SIGTERM must be bypassed"
47 | print_append "if [ \"\$1\" = \"TERM\" ]; then"
48 | print_append "echo \"bypass sigterm\""
49 | print_append "else"
50 | print_append "# Submit a new job to the queue"
51 | print_append "echo \"Requeuing \" \$SLURM_JOB_ID"
52 | print_append "scontrol requeue \$SLURM_JOB_ID"
53 | print_append "fi"
54 | print_append "}"
55 | print_append "trap 'trap_handler USR1' USR1"
56 | print_append "trap 'trap_handler TERM' TERM"
57 |
58 | print_append "{"
59 | print_append "source activate pytorch"
60 | print_append "conda activate pytorch"
61 | print_append "export PATH=/home/boyili/programfiles/anaconda3/envs/pytorch/bin:/home/boyili/programfiles/anaconda3/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin"
62 | print_append "which python"
63 | print_append "echo \$PATH"
64 | print_append "export NCCL_DEBUG=INFO"
65 | print_append "export PYTHONFAULTHANDLER=1"
66 |
67 | echo $JOB
68 | }
69 |
70 | function print_after {
71 | print_append "} & "
72 | print_append "wait \$!"
73 | print_append "sleep 610 &"
74 | print_append "wait \$!"
75 | }
76 |
77 | print_setup
78 | print_append stdbuf -o0 -e0 \
79 | python train.py --log 'log_pascal'
80 | print_after
81 | sbatch $SLURM
82 |
--------------------------------------------------------------------------------
/fewshot_data/test.py:
--------------------------------------------------------------------------------
1 | r""" Hypercorrelation Squeeze testing code """
2 | import argparse
3 |
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | import torch
7 |
8 | from fewshot_data.model.hsnet import HypercorrSqueezeNetwork
9 | from fewshot_data.common.logger import Logger, AverageMeter
10 | from fewshot_data.common.vis import Visualizer
11 | from fewshot_data.common.evaluation import Evaluator
12 | from fewshot_data.common import utils
13 | from fewshot_data.data.dataset import FSSDataset
14 |
15 |
16 | def test(model, dataloader, nshot):
17 | r""" Test HSNet """
18 |
19 | # Freeze randomness during testing for reproducibility
20 | utils.fix_randseed(0)
21 | average_meter = AverageMeter(dataloader.dataset)
22 |
23 | for idx, batch in enumerate(dataloader):
24 |
25 | # 1. Hypercorrelation Squeeze Networks forward pass
26 | batch = utils.to_cuda(batch)
27 | pred_mask = model.module.predict_mask_nshot(batch, nshot=nshot)
28 |
29 | assert pred_mask.size() == batch['query_mask'].size()
30 |
31 | # 2. Evaluate prediction
32 | area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)
33 | average_meter.update(area_inter, area_union, batch['class_id'], loss=None)
34 | average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)
35 |
36 | # Visualize predictions
37 | if Visualizer.visualize:
38 | Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],
39 | batch['query_img'], batch['query_mask'],
40 | pred_mask, batch['class_id'], idx,
41 | area_inter[1].float() / area_union[1].float())
42 | # Write evaluation results
43 | average_meter.write_result('Test', 0)
44 | miou, fb_iou = average_meter.compute_iou()
45 |
46 | return miou, fb_iou
47 |
48 |
49 | if __name__ == '__main__':
50 |
51 | # Arguments parsing
52 | parser = argparse.ArgumentParser(description='Hypercorrelation Squeeze Pytorch Implementation')
53 | parser.add_argument('--datapath', type=str, default='fewshot_data/Datasets_HSN')
54 | parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss'])
55 | parser.add_argument('--logpath', type=str, default='')
56 | parser.add_argument('--bsz', type=int, default=1)
57 | parser.add_argument('--nworker', type=int, default=0)
58 | parser.add_argument('--load', type=str, default='')
59 | parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3])
60 | parser.add_argument('--nshot', type=int, default=1)
61 | parser.add_argument('--backbone', type=str, default='resnet101', choices=['vgg16', 'resnet50', 'resnet101'])
62 | parser.add_argument('--visualize', action='store_true')
63 | parser.add_argument('--use_original_imgsize', action='store_true')
64 | args = parser.parse_args()
65 | Logger.initialize(args, training=False)
66 |
67 | # Model initialization
68 | model = HypercorrSqueezeNetwork(args.backbone, args.use_original_imgsize)
69 | model.eval()
70 | Logger.log_params(model)
71 |
72 | # Device setup
73 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
74 | Logger.info('# available GPUs: %d' % torch.cuda.device_count())
75 | model = nn.DataParallel(model)
76 | model.to(device)
77 |
78 | # Load trained model
79 | if args.load == '': raise Exception('Pretrained model not specified.')
80 | model.load_state_dict(torch.load(args.load))
81 |
82 | # Helper classes (for testing) initialization
83 | Evaluator.initialize()
84 | Visualizer.initialize(args.visualize)
85 |
86 | # Dataset initialization
87 | FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize)
88 | dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)
89 |
90 | # Test HSNet
91 | with torch.no_grad():
92 | test_miou, test_fb_iou = test(model, dataloader_test, args.nshot)
93 | Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item()))
94 | Logger.info('==================== Finished Testing ====================')
95 |
--------------------------------------------------------------------------------
/fewshot_data/train.py:
--------------------------------------------------------------------------------
1 | r""" Hypercorrelation Squeeze training (validation) code """
2 | import argparse
3 |
4 | import torch.optim as optim
5 | import torch.nn as nn
6 | import torch
7 |
8 | from fewshot_data.model.hsnet import HypercorrSqueezeNetwork
9 | from fewshot_data.common.logger import Logger, AverageMeter
10 | from fewshot_data.common.evaluation import Evaluator
11 | from fewshot_data.common import utils
12 | from fewshot_data.data.dataset import FSSDataset
13 |
14 |
15 | def train(epoch, model, dataloader, optimizer, training):
16 | r""" Train HSNet """
17 |
18 | # Force randomness during training / freeze randomness during testing
19 | utils.fix_randseed(None) if training else utils.fix_randseed(0)
20 | model.module.train_mode() if training else model.module.eval()
21 | average_meter = AverageMeter(dataloader.dataset)
22 |
23 | for idx, batch in enumerate(dataloader):
24 | # 1. Hypercorrelation Squeeze Networks forward pass
25 | batch = utils.to_cuda(batch)
26 | logit_mask = model(batch['query_img'], batch['support_imgs'].squeeze(1), batch['support_masks'].squeeze(1))
27 | pred_mask = logit_mask.argmax(dim=1)
28 |
29 | # 2. Compute loss & update model parameters
30 | loss = model.module.compute_objective(logit_mask, batch['query_mask'])
31 | if training:
32 | optimizer.zero_grad()
33 | loss.backward()
34 | optimizer.step()
35 |
36 | # 3. Evaluate prediction
37 | area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
38 | average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
39 | average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50)
40 |
41 | # Write evaluation results
42 | average_meter.write_result('Training' if training else 'Validation', epoch)
43 | avg_loss = utils.mean(average_meter.loss_buf)
44 | miou, fb_iou = average_meter.compute_iou()
45 |
46 | return avg_loss, miou, fb_iou
47 |
48 |
49 | if __name__ == '__main__':
50 |
51 | # Arguments parsing
52 | parser = argparse.ArgumentParser(description='Hypercorrelation Squeeze Pytorch Implementation')
53 | parser.add_argument('--datapath', type=str, default='fewshot_data/Datasets_HSN')
54 | parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss'])
55 | parser.add_argument('--logpath', type=str, default='')
56 | parser.add_argument('--bsz', type=int, default=20)
57 | parser.add_argument('--lr', type=float, default=1e-3)
58 | parser.add_argument('--niter', type=int, default=2000)
59 | parser.add_argument('--nworker', type=int, default=8)
60 | parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3])
61 | parser.add_argument('--backbone', type=str, default='resnet101', choices=['vgg16', 'resnet50', 'resnet101'])
62 | args = parser.parse_args()
63 | Logger.initialize(args, training=True)
64 |
65 | # Model initialization
66 | model = HypercorrSqueezeNetwork(args.backbone, False)
67 | Logger.log_params(model)
68 |
69 | # Device setup
70 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
71 | Logger.info('# available GPUs: %d' % torch.cuda.device_count())
72 | model = nn.DataParallel(model)
73 | model.to(device)
74 |
75 | # Helper classes (for training) initialization
76 | optimizer = optim.Adam([{"params": model.parameters(), "lr": args.lr}])
77 | Evaluator.initialize()
78 |
79 | # Dataset initialization
80 | FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=False)
81 | dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn')
82 | dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val')
83 |
84 | # Train HSNet
85 | best_val_miou = float('-inf')
86 | best_val_loss = float('inf')
87 | for epoch in range(args.niter):
88 |
89 | trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True)
90 | with torch.no_grad():
91 | val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False)
92 |
93 | # Save the best model
94 | if val_miou > best_val_miou:
95 | best_val_miou = val_miou
96 | Logger.save_model_miou(model, epoch, val_miou)
97 |
98 | Logger.tbd_writer.add_scalars('fewshot_data/data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
99 | Logger.tbd_writer.add_scalars('fewshot_data/data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
100 | Logger.tbd_writer.add_scalars('fewshot_data/data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
101 | Logger.tbd_writer.flush()
102 | Logger.tbd_writer.close()
103 | Logger.info('==================== Finished Training ====================')
104 |
--------------------------------------------------------------------------------
/inputs/cat1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/lang-seg/bb00b219f82dc1a18a01f9b0b647bae1d41f3aeb/inputs/cat1.jpeg
--------------------------------------------------------------------------------
/label_files/ade20k_objectInfo150.txt:
--------------------------------------------------------------------------------
1 | Idx,Ratio,Train,Val,Stuff,Name
2 | 1,0.1576,11664,1172,1,wall
3 | 2,0.1072,6046,612,1,building;edifice
4 | 3,0.0878,8265,796,1,sky
5 | 4,0.0621,9336,917,1,floor;flooring
6 | 5,0.0480,6678,641,0,tree
7 | 6,0.0450,6604,643,1,ceiling
8 | 7,0.0398,4023,408,1,road;route
9 | 8,0.0231,1906,199,0,bed
10 | 9,0.0198,4688,460,0,windowpane;window
11 | 10,0.0183,2423,225,1,grass
12 | 11,0.0181,2874,294,0,cabinet
13 | 12,0.0166,3068,310,1,sidewalk;pavement
14 | 13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul
15 | 14,0.0151,1804,190,1,earth;ground
16 | 15,0.0118,6666,796,0,door;double;door
17 | 16,0.0110,4269,411,0,table
18 | 17,0.0109,1691,160,1,mountain;mount
19 | 18,0.0104,3999,441,0,plant;flora;plant;life
20 | 19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall
21 | 20,0.0103,3261,318,0,chair
22 | 21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar
23 | 22,0.0074,709,75,1,water
24 | 23,0.0067,3296,315,0,painting;picture
25 | 24,0.0065,1191,106,0,sofa;couch;lounge
26 | 25,0.0061,1516,162,0,shelf
27 | 26,0.0060,667,69,1,house
28 | 27,0.0053,651,57,1,sea
29 | 28,0.0052,1847,224,0,mirror
30 | 29,0.0046,1158,128,1,rug;carpet;carpeting
31 | 30,0.0044,480,44,1,field
32 | 31,0.0044,1172,98,0,armchair
33 | 32,0.0044,1292,184,0,seat
34 | 33,0.0033,1386,138,0,fence;fencing
35 | 34,0.0031,698,61,0,desk
36 | 35,0.0030,781,73,0,rock;stone
37 | 36,0.0027,380,43,0,wardrobe;closet;press
38 | 37,0.0026,3089,302,0,lamp
39 | 38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub
40 | 39,0.0024,804,99,0,railing;rail
41 | 40,0.0023,1453,153,0,cushion
42 | 41,0.0023,411,37,0,base;pedestal;stand
43 | 42,0.0022,1440,162,0,box
44 | 43,0.0022,800,77,0,column;pillar
45 | 44,0.0020,2650,298,0,signboard;sign
46 | 45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser
47 | 46,0.0019,367,36,0,counter
48 | 47,0.0018,311,30,1,sand
49 | 48,0.0018,1181,122,0,sink
50 | 49,0.0018,287,23,1,skyscraper
51 | 50,0.0018,468,38,0,fireplace;hearth;open;fireplace
52 | 51,0.0018,402,43,0,refrigerator;icebox
53 | 52,0.0018,130,12,1,grandstand;covered;stand
54 | 53,0.0018,561,64,1,path
55 | 54,0.0017,880,102,0,stairs;steps
56 | 55,0.0017,86,12,1,runway
57 | 56,0.0017,172,11,0,case;display;case;showcase;vitrine
58 | 57,0.0017,198,18,0,pool;table;billiard;table;snooker;table
59 | 58,0.0017,930,109,0,pillow
60 | 59,0.0015,139,18,0,screen;door;screen
61 | 60,0.0015,564,52,1,stairway;staircase
62 | 61,0.0015,320,26,1,river
63 | 62,0.0015,261,29,1,bridge;span
64 | 63,0.0014,275,22,0,bookcase
65 | 64,0.0014,335,60,0,blind;screen
66 | 65,0.0014,792,75,0,coffee;table;cocktail;table
67 | 66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne
68 | 67,0.0014,1309,138,0,flower
69 | 68,0.0013,1112,113,0,book
70 | 69,0.0013,266,27,1,hill
71 | 70,0.0013,659,66,0,bench
72 | 71,0.0012,331,31,0,countertop
73 | 72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove
74 | 73,0.0012,369,36,0,palm;palm;tree
75 | 74,0.0012,144,9,0,kitchen;island
76 | 75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system
77 | 76,0.0010,324,33,0,swivel;chair
78 | 77,0.0009,304,27,0,boat
79 | 78,0.0009,170,20,0,bar
80 | 79,0.0009,68,6,0,arcade;machine
81 | 80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty
82 | 81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle
83 | 82,0.0008,492,49,0,towel
84 | 83,0.0008,2510,269,0,light;light;source
85 | 84,0.0008,440,39,0,truck;motortruck
86 | 85,0.0008,147,18,1,tower
87 | 86,0.0008,583,56,0,chandelier;pendant;pendent
88 | 87,0.0007,533,61,0,awning;sunshade;sunblind
89 | 88,0.0007,1989,239,0,streetlight;street;lamp
90 | 89,0.0007,71,5,0,booth;cubicle;stall;kiosk
91 | 90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box
92 | 91,0.0007,135,12,0,airplane;aeroplane;plane
93 | 92,0.0007,83,5,1,dirt;track
94 | 93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes
95 | 94,0.0006,1003,104,0,pole
96 | 95,0.0006,182,12,1,land;ground;soil
97 | 96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail
98 | 97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway
99 | 98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock
100 | 99,0.0006,965,114,0,bottle
101 | 100,0.0006,117,13,0,buffet;counter;sideboard
102 | 101,0.0006,354,35,0,poster;posting;placard;notice;bill;card
103 | 102,0.0006,108,9,1,stage
104 | 103,0.0006,557,55,0,van
105 | 104,0.0006,52,4,0,ship
106 | 105,0.0005,99,5,0,fountain
107 | 106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter
108 | 107,0.0005,292,31,0,canopy
109 | 108,0.0005,77,9,0,washer;automatic;washer;washing;machine
110 | 109,0.0005,340,38,0,plaything;toy
111 | 110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium
112 | 111,0.0005,465,49,0,stool
113 | 112,0.0005,50,4,0,barrel;cask
114 | 113,0.0005,622,75,0,basket;handbasket
115 | 114,0.0005,80,9,1,waterfall;falls
116 | 115,0.0005,59,3,0,tent;collapsible;shelter
117 | 116,0.0005,531,72,0,bag
118 | 117,0.0005,282,30,0,minibike;motorbike
119 | 118,0.0005,73,7,0,cradle
120 | 119,0.0005,435,44,0,oven
121 | 120,0.0005,136,25,0,ball
122 | 121,0.0005,116,24,0,food;solid;food
123 | 122,0.0004,266,31,0,step;stair
124 | 123,0.0004,58,12,0,tank;storage;tank
125 | 124,0.0004,418,83,0,trade;name;brand;name;brand;marque
126 | 125,0.0004,319,43,0,microwave;microwave;oven
127 | 126,0.0004,1193,139,0,pot;flowerpot
128 | 127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna
129 | 128,0.0004,347,36,0,bicycle;bike;wheel;cycle
130 | 129,0.0004,52,5,1,lake
131 | 130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine
132 | 131,0.0004,108,13,0,screen;silver;screen;projection;screen
133 | 132,0.0004,201,30,0,blanket;cover
134 | 133,0.0004,285,21,0,sculpture
135 | 134,0.0004,268,27,0,hood;exhaust;hood
136 | 135,0.0003,1020,108,0,sconce
137 | 136,0.0003,1282,122,0,vase
138 | 137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight
139 | 138,0.0003,453,57,0,tray
140 | 139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin
141 | 140,0.0003,397,44,0,fan
142 | 141,0.0003,92,8,1,pier;wharf;wharfage;dock
143 | 142,0.0003,228,18,0,crt;screen
144 | 143,0.0003,570,59,0,plate
145 | 144,0.0003,217,22,0,monitor;monitoring;device
146 | 145,0.0003,206,19,0,bulletin;board;notice;board
147 | 146,0.0003,130,14,0,shower
148 | 147,0.0003,178,28,0,radiator
149 | 148,0.0002,504,57,0,glass;drinking;glass
150 | 149,0.0002,775,96,0,clock
151 | 150,0.0002,421,56,0,flag
--------------------------------------------------------------------------------
/label_files/fewshot_coco.txt:
--------------------------------------------------------------------------------
1 | person
2 | bicycle
3 | car
4 | motorbike
5 | aeroplane
6 | bus
7 | train
8 | truck
9 | boat
10 | trafficlight
11 | firehydrant
12 | stopsign
13 | parkingmeter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sportsball
34 | kite
35 | baseballbat
36 | baseballglove
37 | skateboard
38 | surfboard
39 | tennisracket
40 | bottle
41 | wineglass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hotdog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | sofa
59 | pottedplant
60 | bed
61 | diningtable
62 | toilet
63 | tvmonitor
64 | laptop
65 | mouse
66 | remote
67 | keyboard
68 | cellphone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddybear
79 | hairdrier
80 | toothbrush
--------------------------------------------------------------------------------
/label_files/fewshot_pascal.txt:
--------------------------------------------------------------------------------
1 | aeroplane
2 | bicycle
3 | bird
4 | boat
5 | bottle
6 | bus
7 | car
8 | cat
9 | chair
10 | cow
11 | diningtable
12 | dog
13 | horse
14 | motorbike
15 | person
16 | pottedplant
17 | sheep
18 | sofa
19 | train
20 | tvmonitor
--------------------------------------------------------------------------------
/modules/lseg_module.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.transforms as transforms
5 | from argparse import ArgumentParser
6 | import pytorch_lightning as pl
7 | from .lsegmentation_module import LSegmentationModule
8 | from .models.lseg_net import LSegNet
9 | from encoding.models.sseg.base import up_kwargs
10 |
11 | import os
12 | import clip
13 | import numpy as np
14 |
15 | from scipy import signal
16 | import glob
17 |
18 | from PIL import Image
19 | import matplotlib.pyplot as plt
20 | import pandas as pd
21 |
22 |
23 | class LSegModule(LSegmentationModule):
24 | def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
25 | super(LSegModule, self).__init__(
26 | data_path, dataset, batch_size, base_lr, max_epochs, **kwargs
27 | )
28 |
29 | if dataset == "citys":
30 | self.base_size = 2048
31 | self.crop_size = 768
32 | else:
33 | self.base_size = 520
34 | self.crop_size = 480
35 |
36 | use_pretrained = True
37 | norm_mean= [0.5, 0.5, 0.5]
38 | norm_std = [0.5, 0.5, 0.5]
39 |
40 | print('** Using {} and {} as the normalization mean and std **'.format(norm_mean, norm_std))
41 |
42 | train_transform = [
43 | transforms.ToTensor(),
44 | transforms.Normalize(norm_mean, norm_std),
45 | ]
46 |
47 | val_transform = [
48 | transforms.ToTensor(),
49 | transforms.Normalize(norm_mean, norm_std),
50 | ]
51 |
52 | self.train_transform = transforms.Compose(train_transform)
53 | self.val_transform = transforms.Compose(val_transform)
54 |
55 | self.trainset = self.get_trainset(
56 | dataset,
57 | augment=kwargs["augment"],
58 | base_size=self.base_size,
59 | crop_size=self.crop_size,
60 | )
61 |
62 | self.valset = self.get_valset(
63 | dataset,
64 | augment=kwargs["augment"],
65 | base_size=self.base_size,
66 | crop_size=self.crop_size,
67 | )
68 |
69 | use_batchnorm = (
70 | (not kwargs["no_batchnorm"]) if "no_batchnorm" in kwargs else True
71 | )
72 | # print(kwargs)
73 |
74 | labels = self.get_labels('ade20k')
75 |
76 | self.net = LSegNet(
77 | labels=labels,
78 | backbone=kwargs["backbone"],
79 | features=kwargs["num_features"],
80 | crop_size=self.crop_size,
81 | arch_option=kwargs["arch_option"],
82 | block_depth=kwargs["block_depth"],
83 | activation=kwargs["activation"],
84 | )
85 |
86 | self.net.pretrained.model.patch_embed.img_size = (
87 | self.crop_size,
88 | self.crop_size,
89 | )
90 |
91 | self._up_kwargs = up_kwargs
92 | self.mean = norm_mean
93 | self.std = norm_std
94 |
95 | self.criterion = self.get_criterion(**kwargs)
96 |
97 | def get_labels(self, dataset):
98 | labels = []
99 | path = 'label_files/{}_objectInfo150.txt'.format(dataset)
100 | assert os.path.exists(path), '*** Error: {} does not exist !!!'.format(path)
101 | f = open(path, 'r')
102 | lines = f.readlines()
103 | for line in lines:
104 | label = line.strip().split(',')[-1].split(';')[0]
105 | labels.append(label)
106 | f.close()
107 | if dataset in ['ade20k']:
108 | labels = labels[1:]
109 | return labels
110 |
111 |
112 | @staticmethod
113 | def add_model_specific_args(parent_parser):
114 | parser = LSegmentationModule.add_model_specific_args(parent_parser)
115 | parser = ArgumentParser(parents=[parser])
116 |
117 | parser.add_argument(
118 | "--backbone",
119 | type=str,
120 | default="clip_vitl16_384",
121 | help="backbone network",
122 | )
123 |
124 | parser.add_argument(
125 | "--num_features",
126 | type=int,
127 | default=256,
128 | help="number of features passed from the encoder to the decoder",
129 | )
130 |
131 | parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate")
132 |
133 | parser.add_argument(
134 | "--finetune_weights", type=str, help="load weights to finetune from"
135 | )
136 |
137 | parser.add_argument(
138 | "--no-scaleinv",
139 | default=True,
140 | action="store_false",
141 | help="turn off scaleinv layers",
142 | )
143 |
144 | parser.add_argument(
145 | "--no-batchnorm",
146 | default=False,
147 | action="store_true",
148 | help="turn off batchnorm",
149 | )
150 |
151 | parser.add_argument(
152 | "--widehead", default=False, action="store_true", help="wider output head"
153 | )
154 |
155 | parser.add_argument(
156 | "--widehead_hr",
157 | default=False,
158 | action="store_true",
159 | help="wider output head",
160 | )
161 |
162 | parser.add_argument(
163 | "--arch_option",
164 | type=int,
165 | default=0,
166 | help="which architecture variant to use",
167 | )
168 |
169 | parser.add_argument(
170 | "--block_depth",
171 | type=int,
172 | default=0,
173 | help="how many blocks should be used",
174 | )
175 |
176 | parser.add_argument(
177 | "--activation",
178 | choices=['lrelu', 'tanh'],
179 | default="lrelu",
180 | help="which activation function to use in the block",
181 | )
182 |
183 | return parser
184 |
--------------------------------------------------------------------------------
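
A standalone sketch of the per-line parsing performed by get_labels above; the sample row is copied from label_files/ade20k_objectInfo150.txt, and the parsed CSV header row is later discarded by the labels[1:] slice.

# Keep only the first synonym of the last comma-separated field
line = "2,0.1072,6046,612,1,building;edifice"
label = line.strip().split(',')[-1].split(';')[0]
print(label)  # building
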
/modules/lseg_module_zs.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.transforms as transforms
5 | from argparse import ArgumentParser
6 | import pytorch_lightning as pl
7 | from .lsegmentation_module_zs import LSegmentationModuleZS
8 | from .models.lseg_net_zs import LSegNetZS, LSegRNNetZS
9 | from encoding.models.sseg.base import up_kwargs
10 | import os
11 | import clip
12 | import numpy as np
13 | from scipy import signal
14 | import glob
15 | from PIL import Image
16 | import matplotlib.pyplot as plt
17 | import pandas as pd
18 |
19 |
20 | class LSegModuleZS(LSegmentationModuleZS):
21 | def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
22 | super(LSegModuleZS, self).__init__(
23 | data_path, dataset, batch_size, base_lr, max_epochs, **kwargs
24 | )
25 | label_list = self.get_labels(dataset)
26 | self.len_dataloader = len(label_list)
27 |
28 | # print(kwargs)
29 | if kwargs["use_pretrained"] in ['False', False]:
30 | use_pretrained = False
31 | elif kwargs["use_pretrained"] in ['True', True]:
32 | use_pretrained = True
33 |
34 | if kwargs["backbone"] in ["clip_resnet101"]:
35 | self.net = LSegRNNetZS(
36 | label_list=label_list,
37 | backbone=kwargs["backbone"],
38 | features=kwargs["num_features"],
39 | aux=kwargs["aux"],
40 | use_pretrained=use_pretrained,
41 | arch_option=kwargs["arch_option"],
42 | block_depth=kwargs["block_depth"],
43 | activation=kwargs["activation"],
44 | )
45 | else:
46 | self.net = LSegNetZS(
47 | label_list=label_list,
48 | backbone=kwargs["backbone"],
49 | features=kwargs["num_features"],
50 | aux=kwargs["aux"],
51 | use_pretrained=use_pretrained,
52 | arch_option=kwargs["arch_option"],
53 | block_depth=kwargs["block_depth"],
54 | activation=kwargs["activation"],
55 | )
56 |
57 | def get_labels(self, dataset):
58 | labels = []
59 | path = 'label_files/fewshot_{}.txt'.format(dataset)
60 | assert os.path.exists(path), '*** Error: {} does not exist !!!'.format(path)
61 | f = open(path, 'r')
62 | lines = f.readlines()
63 | for line in lines:
64 | label = line.strip()
65 | labels.append(label)
66 | f.close()
67 | print(labels)
68 | return labels
69 |
70 | @staticmethod
71 | def add_model_specific_args(parent_parser):
72 | parser = LSegmentationModuleZS.add_model_specific_args(parent_parser)
73 | parser = ArgumentParser(parents=[parser])
74 |
75 | parser.add_argument(
76 | "--backbone",
77 | type=str,
78 | default="vitb16_384",
79 | help="backbone network",
80 | )
81 |
82 | parser.add_argument(
83 | "--num_features",
84 | type=int,
85 | default=256,
86 | help="number of features passed from the encoder to the decoder",
87 | )
88 |
89 | parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate")
90 |
91 | parser.add_argument(
92 | "--finetune_weights", type=str, help="load weights to finetune from"
93 | )
94 |
95 | parser.add_argument(
96 | "--no-scaleinv",
97 | default=True,
98 | action="store_false",
99 | help="turn off scaleinv layers",
100 | )
101 |
102 | parser.add_argument(
103 | "--no-batchnorm",
104 | default=False,
105 | action="store_true",
106 | help="turn off batchnorm",
107 | )
108 |
109 | parser.add_argument(
110 | "--widehead", default=False, action="store_true", help="wider output head"
111 | )
112 |
113 | parser.add_argument(
114 | "--widehead_hr",
115 | default=False,
116 | action="store_true",
117 | help="wider output head",
118 | )
119 |
120 | parser.add_argument(
121 | "--use_pretrained",
122 | type=str,
123 | default="True",
124 | help="whether to use the default pretrained weights to initialize the model",
125 | )
126 |
127 | parser.add_argument(
128 | "--arch_option",
129 | type=int,
130 | default=0,
131 | help="which architecture variant to use",
132 | )
133 |
134 | parser.add_argument(
135 | "--block_depth",
136 | type=int,
137 | default=0,
138 | help="how many blocks should be used",
139 | )
140 |
141 | parser.add_argument(
142 | "--activation",
143 | choices=['relu', 'lrelu', 'tanh'],
144 | default="relu",
145 | help="which activation function to use in the block",
146 | )
147 |
148 | return parser
149 |
--------------------------------------------------------------------------------
/modules/lsegmentation_module.py:
--------------------------------------------------------------------------------
1 | import types
2 | import time
3 | import random
4 | import clip
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.transforms as transforms
8 |
9 | from argparse import ArgumentParser
10 |
11 | import pytorch_lightning as pl
12 |
13 | from data import get_dataset, get_available_datasets
14 |
15 | from encoding.models import get_segmentation_model
16 | from encoding.nn import SegmentationLosses
17 |
18 | from encoding.utils import batch_pix_accuracy, batch_intersection_union
19 |
20 | # add mixed precision
21 | import torch.cuda.amp as amp
22 | import numpy as np
23 |
24 | from encoding.utils import SegmentationMetric
25 |
26 | class LSegmentationModule(pl.LightningModule):
27 | def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
28 | super().__init__()
29 |
30 | self.data_path = data_path
31 | self.batch_size = batch_size
32 | self.base_lr = base_lr / 16 * batch_size
33 | self.lr = self.base_lr
34 |
35 | self.epochs = max_epochs
36 | self.other_kwargs = kwargs
37 | self.enabled = False  # True enables mixed precision, which complicates training and can lead to NaN errors
38 | self.scaler = amp.GradScaler(enabled=self.enabled)
39 |
40 | def forward(self, x):
41 | return self.net(x)
42 |
43 | def evaluate(self, x, target=None):
44 | pred = self.net.forward(x)
45 | if isinstance(pred, (tuple, list)):
46 | pred = pred[0]
47 | if target is None:
48 | return pred
49 | correct, labeled = batch_pix_accuracy(pred.data, target.data)
50 | inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
51 |
52 | return correct, labeled, inter, union
53 |
54 | def evaluate_random(self, x, labelset, target=None):
55 | pred = self.net.forward(x, labelset)
56 | if isinstance(pred, (tuple, list)):
57 | pred = pred[0]
58 | if target is None:
59 | return pred
60 | correct, labeled = batch_pix_accuracy(pred.data, target.data)
61 | inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
62 |
63 | return correct, labeled, inter, union
64 |
65 |
66 | def training_step(self, batch, batch_nb):
67 | img, target = batch
68 | with amp.autocast(enabled=self.enabled):
69 | out = self(img)
70 | multi_loss = isinstance(out, tuple)
71 | if multi_loss:
72 | loss = self.criterion(*out, target)
73 | else:
74 | loss = self.criterion(out, target)
75 | loss = self.scaler.scale(loss)
76 | final_output = out[0] if multi_loss else out
77 | train_pred, train_gt = self._filter_invalid(final_output, target)
78 | if train_gt.nelement() != 0:
79 | self.train_accuracy(train_pred, train_gt)
80 | self.log("train_loss", loss)
81 | return loss
82 |
83 | def training_epoch_end(self, outs):
84 | self.log("train_acc_epoch", self.train_accuracy.compute())
85 |
86 | def validation_step(self, batch, batch_nb):
87 | img, target = batch
88 | out = self(img)
89 | multi_loss = isinstance(out, tuple)
90 | if multi_loss:
91 | val_loss = self.criterion(*out, target)
92 | else:
93 | val_loss = self.criterion(out, target)
94 | final_output = out[0] if multi_loss else out
95 | valid_pred, valid_gt = self._filter_invalid(final_output, target)
96 | self.val_iou.update(target, final_output)
97 | pixAcc, iou = self.val_iou.get()
98 | self.log("val_loss_step", val_loss)
99 | self.log("pix_acc_step", pixAcc)
100 | self.log(
101 | "val_acc_step",
102 | self.val_accuracy(valid_pred, valid_gt),
103 | )
104 | self.log("val_iou", iou)
105 |
106 | def validation_epoch_end(self, outs):
107 | pixAcc, iou = self.val_iou.get()
108 | self.log("val_acc_epoch", self.val_accuracy.compute())
109 | self.log("val_iou_epoch", iou)
110 | self.log("pix_acc_epoch", pixAcc)
111 |
112 | self.val_iou.reset()
113 |
114 | def _filter_invalid(self, pred, target):
115 | valid = target != self.other_kwargs["ignore_index"]
116 | _, mx = torch.max(pred, dim=1)
117 | return mx[valid], target[valid]
118 |
119 | def configure_optimizers(self):
120 | params_list = [
121 | {"params": self.net.pretrained.parameters(), "lr": self.base_lr},
122 | ]
123 | if hasattr(self.net, "scratch"):
124 | print("Found output scratch")
125 | params_list.append(
126 | {"params": self.net.scratch.parameters(), "lr": self.base_lr * 10}
127 | )
128 | if hasattr(self.net, "auxlayer"):
129 | print("Found auxlayer")
130 | params_list.append(
131 | {"params": self.net.auxlayer.parameters(), "lr": self.base_lr * 10}
132 | )
133 | if hasattr(self.net, "scale_inv_conv"):
134 | print(self.net.scale_inv_conv)
135 | print("Found scaleinv layers")
136 | params_list.append(
137 | {
138 | "params": self.net.scale_inv_conv.parameters(),
139 | "lr": self.base_lr * 10,
140 | }
141 | )
142 | params_list.append(
143 | {"params": self.net.scale2_conv.parameters(), "lr": self.base_lr * 10}
144 | )
145 | params_list.append(
146 | {"params": self.net.scale3_conv.parameters(), "lr": self.base_lr * 10}
147 | )
148 | params_list.append(
149 | {"params": self.net.scale4_conv.parameters(), "lr": self.base_lr * 10}
150 | )
151 |
152 | if self.other_kwargs["midasproto"]:
153 | print("Using midas optimization protocol")
154 |
155 | opt = torch.optim.Adam(
156 | params_list,
157 | lr=self.base_lr,
158 | betas=(0.9, 0.999),
159 | weight_decay=self.other_kwargs["weight_decay"],
160 | )
161 | sch = torch.optim.lr_scheduler.LambdaLR(
162 | opt, lambda x: pow(1.0 - x / self.epochs, 0.9)
163 | )
164 |
165 | else:
166 | opt = torch.optim.SGD(
167 | params_list,
168 | lr=self.base_lr,
169 | momentum=0.9,
170 | weight_decay=self.other_kwargs["weight_decay"],
171 | )
172 | sch = torch.optim.lr_scheduler.LambdaLR(
173 | opt, lambda x: pow(1.0 - x / self.epochs, 0.9)
174 | )
175 | return [opt], [sch]
176 |
177 | def train_dataloader(self):
178 | return torch.utils.data.DataLoader(
179 | self.trainset,
180 | batch_size=self.batch_size,
181 | shuffle=True,
182 | num_workers=16,
183 | worker_init_fn=lambda x: random.seed(time.time() + x),
184 | )
185 |
186 | def val_dataloader(self):
187 | return torch.utils.data.DataLoader(
188 | self.valset,
189 | batch_size=self.batch_size,
190 | shuffle=False,
191 | num_workers=16,
192 | )
193 |
194 | def get_trainset(self, dset, augment=False, **kwargs):
195 | print(kwargs)
196 | if augment == True:
197 | mode = "train_x"
198 | else:
199 | mode = "train"
200 |
201 | print(mode)
202 | dset = get_dataset(
203 | dset,
204 | root=self.data_path,
205 | split="train",
206 | mode=mode,
207 | transform=self.train_transform,
208 | **kwargs
209 | )
210 |
211 | self.num_classes = dset.num_class
212 | self.train_accuracy = pl.metrics.Accuracy()
213 |
214 | return dset
215 |
216 | def get_valset(self, dset, augment=False, **kwargs):
217 | self.val_accuracy = pl.metrics.Accuracy()
218 | self.val_iou = SegmentationMetric(self.num_classes)
219 |
220 | if augment == True:
221 | mode = "val_x"
222 | else:
223 | mode = "val"
224 |
225 | print(mode)
226 | return get_dataset(
227 | dset,
228 | root=self.data_path,
229 | split="val",
230 | mode=mode,
231 | transform=self.val_transform,
232 | **kwargs
233 | )
234 |
235 |
236 | def get_criterion(self, **kwargs):
237 | return SegmentationLosses(
238 | se_loss=kwargs["se_loss"],
239 | aux=kwargs["aux"],
240 | nclass=self.num_classes,
241 | se_weight=kwargs["se_weight"],
242 | aux_weight=kwargs["aux_weight"],
243 | ignore_index=kwargs["ignore_index"],
244 | )
245 |
246 | @staticmethod
247 | def add_model_specific_args(parent_parser):
248 | parser = ArgumentParser(parents=[parent_parser], add_help=False)
249 | parser.add_argument(
250 | "--data_path", type=str, help="path where dataset is stored"
251 | )
252 | parser.add_argument(
253 | "--dataset",
254 | choices=get_available_datasets(),
255 | default="ade20k",
256 | help="dataset to train on",
257 | )
258 | parser.add_argument(
259 | "--batch_size", type=int, default=16, help="size of the batches"
260 | )
261 | parser.add_argument(
262 | "--base_lr", type=float, default=0.004, help="learning rate"
263 | )
264 | parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
265 | parser.add_argument(
266 | "--weight_decay", type=float, default=1e-4, help="weight_decay"
267 | )
268 | parser.add_argument(
269 | "--aux", action="store_true", default=False, help="Auxiliary Loss"
270 | )
271 | parser.add_argument(
272 | "--aux-weight",
273 | type=float,
274 | default=0.2,
275 | help="Auxiliary loss weight (default: 0.2)",
276 | )
277 | parser.add_argument(
278 | "--se-loss",
279 | action="store_true",
280 | default=False,
281 | help="Semantic Encoding Loss SE-loss",
282 | )
283 | parser.add_argument(
284 | "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
285 | )
286 |
287 | parser.add_argument(
288 | "--midasproto", action="store_true", default=False, help="use the midas optimization protocol"
289 | )
290 |
291 | parser.add_argument(
292 | "--ignore_index",
293 | type=int,
294 | default=-1,
295 | help="numeric value of ignore label in gt",
296 | )
297 | parser.add_argument(
298 | "--augment",
299 | action="store_true",
300 | default=False,
301 | help="Use extended augmentations",
302 | )
303 |
304 | return parser
305 |
--------------------------------------------------------------------------------
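
A small sketch of the learning-rate handling in LSegmentationModule above: the base learning rate is scaled linearly with batch size (relative to a reference batch size of 16), and both optimizer branches decay it with a poly schedule of power 0.9. The batch size and epoch count below are illustrative values, not defaults from this repository.

base_lr, batch_size, max_epochs = 0.004, 8, 240  # illustrative batch size / epoch count

scaled_lr = base_lr / 16 * batch_size            # 0.002 for a batch size of 8
poly = lambda epoch: pow(1.0 - epoch / max_epochs, 0.9)

for epoch in (0, 120, 239):
    print(epoch, scaled_lr * poly(epoch))
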
/modules/models/lseg_blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .lseg_vit import (
5 | _make_pretrained_clip_vitl16_384,
6 | _make_pretrained_clip_vitb32_384,
7 | _make_pretrained_clipRN50x16_vitl16_384,
8 | forward_vit,
9 | )
10 |
11 |
12 | def _make_encoder(
13 | backbone,
14 | features,
15 | use_pretrained=True,
16 | groups=1,
17 | expand=False,
18 | exportable=True,
19 | hooks=None,
20 | use_vit_only=False,
21 | use_readout="ignore",
22 | enable_attention_hooks=False,
23 | ):
24 | if backbone == "clip_vitl16_384":
25 | clip_pretrained, pretrained = _make_pretrained_clip_vitl16_384(
26 | use_pretrained,
27 | hooks=hooks,
28 | use_readout=use_readout,
29 | enable_attention_hooks=enable_attention_hooks,
30 | )
31 | scratch = _make_scratch(
32 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
33 | )
34 | elif backbone == "clipRN50x16_vitl16_384":
35 | clip_pretrained, pretrained = _make_pretrained_clipRN50x16_vitl16_384(
36 | use_pretrained,
37 | hooks=hooks,
38 | use_readout=use_readout,
39 | enable_attention_hooks=enable_attention_hooks,
40 | )
41 | scratch = _make_scratch(
42 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
43 | )
44 | elif backbone == "clip_vitb32_384":
45 | clip_pretrained, pretrained = _make_pretrained_clip_vitb32_384(
46 | use_pretrained,
47 | hooks=hooks,
48 | use_readout=use_readout,
49 | )
50 | scratch = _make_scratch(
51 | [96, 192, 384, 768], features, groups=groups, expand=expand
52 | )
53 | else:
54 | print(f"Backbone '{backbone}' not implemented")
55 | assert False
56 |
57 | return clip_pretrained, pretrained, scratch
58 |
59 |
60 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
61 | scratch = nn.Module()
62 |
63 | out_shape1 = out_shape
64 | out_shape2 = out_shape
65 | out_shape3 = out_shape
66 | out_shape4 = out_shape
67 | if expand == True:
68 | out_shape1 = out_shape
69 | out_shape2 = out_shape * 2
70 | out_shape3 = out_shape * 4
71 | out_shape4 = out_shape * 8
72 |
73 | scratch.layer1_rn = nn.Conv2d(
74 | in_shape[0],
75 | out_shape1,
76 | kernel_size=3,
77 | stride=1,
78 | padding=1,
79 | bias=False,
80 | groups=groups,
81 | )
82 | scratch.layer2_rn = nn.Conv2d(
83 | in_shape[1],
84 | out_shape2,
85 | kernel_size=3,
86 | stride=1,
87 | padding=1,
88 | bias=False,
89 | groups=groups,
90 | )
91 | scratch.layer3_rn = nn.Conv2d(
92 | in_shape[2],
93 | out_shape3,
94 | kernel_size=3,
95 | stride=1,
96 | padding=1,
97 | bias=False,
98 | groups=groups,
99 | )
100 | scratch.layer4_rn = nn.Conv2d(
101 | in_shape[3],
102 | out_shape4,
103 | kernel_size=3,
104 | stride=1,
105 | padding=1,
106 | bias=False,
107 | groups=groups,
108 | )
109 |
110 | return scratch
111 |
112 |
113 | class Interpolate(nn.Module):
114 | """Interpolation module."""
115 |
116 | def __init__(self, scale_factor, mode, align_corners=False):
117 | """Init.
118 |
119 | Args:
120 | scale_factor (float): scaling
121 | mode (str): interpolation mode
122 | """
123 | super(Interpolate, self).__init__()
124 |
125 | self.interp = nn.functional.interpolate
126 | self.scale_factor = scale_factor
127 | self.mode = mode
128 | self.align_corners = align_corners
129 |
130 | def forward(self, x):
131 | """Forward pass.
132 |
133 | Args:
134 | x (tensor): input
135 |
136 | Returns:
137 | tensor: interpolated data
138 | """
139 |
140 | x = self.interp(
141 | x,
142 | scale_factor=self.scale_factor,
143 | mode=self.mode,
144 | align_corners=self.align_corners,
145 | )
146 |
147 | return x
148 |
149 |
150 | class ResidualConvUnit(nn.Module):
151 | """Residual convolution module."""
152 |
153 | def __init__(self, features):
154 | """Init.
155 |
156 | Args:
157 | features (int): number of features
158 | """
159 | super().__init__()
160 |
161 | self.conv1 = nn.Conv2d(
162 | features, features, kernel_size=3, stride=1, padding=1, bias=True
163 | )
164 |
165 | self.conv2 = nn.Conv2d(
166 | features, features, kernel_size=3, stride=1, padding=1, bias=True
167 | )
168 |
169 | self.relu = nn.ReLU(inplace=True)
170 |
171 | def forward(self, x):
172 | """Forward pass.
173 |
174 | Args:
175 | x (tensor): input
176 |
177 | Returns:
178 | tensor: output
179 | """
180 | out = self.relu(x)
181 | out = self.conv1(out)
182 | out = self.relu(out)
183 | out = self.conv2(out)
184 |
185 | return out + x
186 |
187 |
188 | class FeatureFusionBlock(nn.Module):
189 | """Feature fusion block."""
190 |
191 | def __init__(self, features):
192 | """Init.
193 |
194 | Args:
195 | features (int): number of features
196 | """
197 | super(FeatureFusionBlock, self).__init__()
198 |
199 | self.resConfUnit1 = ResidualConvUnit(features)
200 | self.resConfUnit2 = ResidualConvUnit(features)
201 |
202 | def forward(self, *xs):
203 | """Forward pass.
204 |
205 | Returns:
206 | tensor: output
207 | """
208 | output = xs[0]
209 |
210 | if len(xs) == 2:
211 | output += self.resConfUnit1(xs[1])
212 |
213 | output = self.resConfUnit2(output)
214 |
215 | output = nn.functional.interpolate(
216 | output, scale_factor=2, mode="bilinear", align_corners=True
217 | )
218 |
219 | return output
220 |
221 |
222 | class ResidualConvUnit_custom(nn.Module):
223 | """Residual convolution module."""
224 |
225 | def __init__(self, features, activation, bn):
226 | """Init.
227 |
228 | Args:
229 | features (int): number of features
230 | """
231 | super().__init__()
232 |
233 | self.bn = bn
234 |
235 | self.groups = 1
236 |
237 | self.conv1 = nn.Conv2d(
238 | features,
239 | features,
240 | kernel_size=3,
241 | stride=1,
242 | padding=1,
243 | bias=not self.bn,
244 | groups=self.groups,
245 | )
246 |
247 | self.conv2 = nn.Conv2d(
248 | features,
249 | features,
250 | kernel_size=3,
251 | stride=1,
252 | padding=1,
253 | bias=not self.bn,
254 | groups=self.groups,
255 | )
256 |
257 | if self.bn == True:
258 | self.bn1 = nn.BatchNorm2d(features)
259 | self.bn2 = nn.BatchNorm2d(features)
260 |
261 | self.activation = activation
262 |
263 | self.skip_add = nn.quantized.FloatFunctional()
264 |
265 | def forward(self, x):
266 | """Forward pass.
267 |
268 | Args:
269 | x (tensor): input
270 |
271 | Returns:
272 | tensor: output
273 | """
274 |
275 | out = self.activation(x)
276 | out = self.conv1(out)
277 | if self.bn == True:
278 | out = self.bn1(out)
279 |
280 | out = self.activation(out)
281 | out = self.conv2(out)
282 | if self.bn == True:
283 | out = self.bn2(out)
284 |
285 | if self.groups > 1:
286 |             out = self.conv_merge(out)  # note: conv_merge is never defined; unreachable since self.groups is fixed to 1
287 |
288 | return self.skip_add.add(out, x)
289 |
290 | # return out + x
291 |
292 |
293 | class FeatureFusionBlock_custom(nn.Module):
294 | """Feature fusion block."""
295 |
296 | def __init__(
297 | self,
298 | features,
299 | activation,
300 | deconv=False,
301 | bn=False,
302 | expand=False,
303 | align_corners=True,
304 | ):
305 | """Init.
306 |
307 | Args:
308 | features (int): number of features
309 | """
310 | super(FeatureFusionBlock_custom, self).__init__()
311 |
312 | self.deconv = deconv
313 | self.align_corners = align_corners
314 |
315 | self.groups = 1
316 |
317 | self.expand = expand
318 | out_features = features
319 | if self.expand == True:
320 | out_features = features // 2
321 |
322 | self.out_conv = nn.Conv2d(
323 | features,
324 | out_features,
325 | kernel_size=1,
326 | stride=1,
327 | padding=0,
328 | bias=True,
329 | groups=1,
330 | )
331 |
332 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
333 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
334 |
335 | self.skip_add = nn.quantized.FloatFunctional()
336 |
337 | def forward(self, *xs):
338 | """Forward pass.
339 |
340 | Returns:
341 | tensor: output
342 | """
343 | output = xs[0]
344 |
345 | if len(xs) == 2:
346 | res = self.resConfUnit1(xs[1])
347 | output = self.skip_add.add(output, res)
348 | # output += res
349 |
350 | output = self.resConfUnit2(output)
351 |
352 | output = nn.functional.interpolate(
353 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
354 | )
355 |
356 | output = self.out_conv(output)
357 |
358 | return output
359 |
360 |
--------------------------------------------------------------------------------
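
The decoder pieces above compose as follows: _make_scratch projects the four backbone feature maps to a common width, and each FeatureFusionBlock_custom fuses an optional skip connection and upsamples by 2x, recovering resolution stage by stage. A minimal shape check, assuming it is run from the repository root; the tensor sizes below are illustrative (ViT-L/16-like features for a 480x480 input), not values taken from the repo:

    import torch
    import torch.nn as nn
    from modules.models.lseg_blocks import FeatureFusionBlock_custom, _make_scratch

    scratch = _make_scratch([256, 512, 1024, 1024], 256)      # project all stages to 256 channels
    refine3 = FeatureFusionBlock_custom(256, nn.ReLU(False))
    refine2 = FeatureFusionBlock_custom(256, nn.ReLU(False))

    layer_3 = torch.randn(1, 1024, 30, 30)                    # coarse stage (stride 16)
    layer_2 = torch.randn(1, 512, 60, 60)                     # skip connection (stride 8)

    path = refine3(scratch.layer3_rn(layer_3))                # refine + 2x upsample -> 60x60
    path = refine2(path, scratch.layer2_rn(layer_2))          # add skip, refine, 2x upsample -> 120x120
    print(path.shape)                                         # torch.Size([1, 256, 120, 120])
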
/modules/models/lseg_blocks_zs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .lseg_vit_zs import (
5 | _make_pretrained_clip_vitl16_384,
6 | _make_pretrained_clip_vitb32_384,
7 | _make_pretrained_clip_rn101,
8 | forward_vit,
9 | )
10 |
11 | def _make_encoder(
12 | backbone,
13 | features,
14 | use_pretrained,
15 | groups=1,
16 | expand=False,
17 | exportable=True,
18 | hooks=None,
19 | use_vit_only=False,
20 | use_readout="ignore",
21 | enable_attention_hooks=False,
22 | ):
23 | if backbone == "clip_vitl16_384":
24 | clip_pretrained, pretrained = _make_pretrained_clip_vitl16_384(
25 | use_pretrained,
26 | hooks=hooks,
27 | use_readout=use_readout,
28 | enable_attention_hooks=enable_attention_hooks,
29 | )
30 | scratch = _make_scratch(
31 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
32 | )
33 | elif backbone == "clip_vitb32_384":
34 | clip_pretrained, pretrained = _make_pretrained_clip_vitb32_384(
35 | use_pretrained,
36 | hooks=hooks,
37 | use_readout=use_readout,
38 | )
39 | scratch = _make_scratch(
40 | [96, 192, 384, 768], features, groups=groups, expand=expand
41 | )
42 | elif backbone == "clip_resnet101":
43 | clip_pretrained, pretrained = _make_pretrained_clip_rn101(
44 | use_pretrained,
45 | )
46 | scratch = _make_scratch(
47 | [256, 512, 1024, 2048], features, groups=groups, expand=expand
48 | )
49 | else:
50 |         # raise instead of assert so the error is not stripped under python -O
51 |         raise NotImplementedError(f"Backbone '{backbone}' not implemented")
52 |
53 | return clip_pretrained, pretrained, scratch
54 |
55 |
56 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
57 | scratch = nn.Module()
58 |
59 | out_shape1 = out_shape
60 | out_shape2 = out_shape
61 | out_shape3 = out_shape
62 | out_shape4 = out_shape
63 | if expand == True:
64 | out_shape1 = out_shape
65 | out_shape2 = out_shape * 2
66 | out_shape3 = out_shape * 4
67 | out_shape4 = out_shape * 8
68 |
69 | scratch.layer1_rn = nn.Conv2d(
70 | in_shape[0],
71 | out_shape1,
72 | kernel_size=3,
73 | stride=1,
74 | padding=1,
75 | bias=False,
76 | groups=groups,
77 | )
78 | scratch.layer2_rn = nn.Conv2d(
79 | in_shape[1],
80 | out_shape2,
81 | kernel_size=3,
82 | stride=1,
83 | padding=1,
84 | bias=False,
85 | groups=groups,
86 | )
87 | scratch.layer3_rn = nn.Conv2d(
88 | in_shape[2],
89 | out_shape3,
90 | kernel_size=3,
91 | stride=1,
92 | padding=1,
93 | bias=False,
94 | groups=groups,
95 | )
96 | scratch.layer4_rn = nn.Conv2d(
97 | in_shape[3],
98 | out_shape4,
99 | kernel_size=3,
100 | stride=1,
101 | padding=1,
102 | bias=False,
103 | groups=groups,
104 | )
105 |
106 | return scratch
107 |
108 |
109 | def _make_resnet_backbone(resnet):
110 | pretrained = nn.Module()
111 | pretrained.layer1 = nn.Sequential(
112 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
113 | )
114 |
115 | pretrained.layer2 = resnet.layer2
116 | pretrained.layer3 = resnet.layer3
117 | pretrained.layer4 = resnet.layer4
118 |
119 | return pretrained
120 |
121 |
122 | def _make_pretrained_resnext101_wsl(use_pretrained):
123 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
124 | return _make_resnet_backbone(resnet)
125 |
126 |
127 | class Interpolate(nn.Module):
128 | """Interpolation module."""
129 |
130 | def __init__(self, scale_factor, mode, align_corners=False):
131 | """Init.
132 |
133 | Args:
134 | scale_factor (float): scaling
135 | mode (str): interpolation mode
136 | """
137 | super(Interpolate, self).__init__()
138 |
139 | self.interp = nn.functional.interpolate
140 | self.scale_factor = scale_factor
141 | self.mode = mode
142 | self.align_corners = align_corners
143 |
144 | def forward(self, x):
145 | """Forward pass.
146 |
147 | Args:
148 | x (tensor): input
149 |
150 | Returns:
151 | tensor: interpolated data
152 | """
153 |
154 | x = self.interp(
155 | x,
156 | scale_factor=self.scale_factor,
157 | mode=self.mode,
158 | align_corners=self.align_corners,
159 | )
160 |
161 | return x
162 |
163 |
164 | class ResidualConvUnit(nn.Module):
165 | """Residual convolution module."""
166 |
167 | def __init__(self, features):
168 | """Init.
169 |
170 | Args:
171 | features (int): number of features
172 | """
173 | super().__init__()
174 |
175 | self.conv1 = nn.Conv2d(
176 | features, features, kernel_size=3, stride=1, padding=1, bias=True
177 | )
178 |
179 | self.conv2 = nn.Conv2d(
180 | features, features, kernel_size=3, stride=1, padding=1, bias=True
181 | )
182 |
183 | self.relu = nn.ReLU(inplace=True)
184 |
185 | def forward(self, x):
186 | """Forward pass.
187 |
188 | Args:
189 | x (tensor): input
190 |
191 | Returns:
192 | tensor: output
193 | """
194 | out = self.relu(x)
195 | out = self.conv1(out)
196 | out = self.relu(out)
197 | out = self.conv2(out)
198 |
199 | return out + x
200 |
201 |
202 | class FeatureFusionBlock(nn.Module):
203 | """Feature fusion block."""
204 |
205 | def __init__(self, features):
206 | """Init.
207 |
208 | Args:
209 | features (int): number of features
210 | """
211 | super(FeatureFusionBlock, self).__init__()
212 |
213 | self.resConfUnit1 = ResidualConvUnit(features)
214 | self.resConfUnit2 = ResidualConvUnit(features)
215 |
216 | def forward(self, *xs):
217 | """Forward pass.
218 |
219 | Returns:
220 | tensor: output
221 | """
222 | output = xs[0]
223 |
224 | if len(xs) == 2:
225 | output += self.resConfUnit1(xs[1])
226 |
227 | output = self.resConfUnit2(output)
228 |
229 | output = nn.functional.interpolate(
230 | output, scale_factor=2, mode="bilinear", align_corners=True
231 | )
232 |
233 | return output
234 |
235 |
236 | class ResidualConvUnit_custom(nn.Module):
237 | """Residual convolution module."""
238 |
239 | def __init__(self, features, activation, bn):
240 | """Init.
241 |
242 | Args:
243 | features (int): number of features
244 | """
245 | super().__init__()
246 |
247 | self.bn = bn
248 |
249 | self.groups = 1
250 |
251 | self.conv1 = nn.Conv2d(
252 | features,
253 | features,
254 | kernel_size=3,
255 | stride=1,
256 | padding=1,
257 | bias=not self.bn,
258 | groups=self.groups,
259 | )
260 |
261 | self.conv2 = nn.Conv2d(
262 | features,
263 | features,
264 | kernel_size=3,
265 | stride=1,
266 | padding=1,
267 | bias=not self.bn,
268 | groups=self.groups,
269 | )
270 |
271 | if self.bn == True:
272 | self.bn1 = nn.BatchNorm2d(features)
273 | self.bn2 = nn.BatchNorm2d(features)
274 |
275 | self.activation = activation
276 |
277 | self.skip_add = nn.quantized.FloatFunctional()
278 |
279 | def forward(self, x):
280 | """Forward pass.
281 |
282 | Args:
283 | x (tensor): input
284 |
285 | Returns:
286 | tensor: output
287 | """
288 |
289 | out = self.activation(x)
290 | out = self.conv1(out)
291 | if self.bn == True:
292 | out = self.bn1(out)
293 |
294 | out = self.activation(out)
295 | out = self.conv2(out)
296 | if self.bn == True:
297 | out = self.bn2(out)
298 |
299 | if self.groups > 1:
300 |             out = self.conv_merge(out)  # note: conv_merge is never defined; unreachable since self.groups is fixed to 1
301 |
302 | return self.skip_add.add(out, x)
303 |
304 | # return out + x
305 |
306 |
307 | class FeatureFusionBlock_custom(nn.Module):
308 | """Feature fusion block."""
309 |
310 | def __init__(
311 | self,
312 | features,
313 | activation,
314 | deconv=False,
315 | bn=False,
316 | expand=False,
317 | align_corners=True,
318 | ):
319 | """Init.
320 |
321 | Args:
322 | features (int): number of features
323 | """
324 | super(FeatureFusionBlock_custom, self).__init__()
325 |
326 | self.deconv = deconv
327 | self.align_corners = align_corners
328 |
329 | self.groups = 1
330 |
331 | self.expand = expand
332 | out_features = features
333 | if self.expand == True:
334 | out_features = features // 2
335 |
336 | self.out_conv = nn.Conv2d(
337 | features,
338 | out_features,
339 | kernel_size=1,
340 | stride=1,
341 | padding=0,
342 | bias=True,
343 | groups=1,
344 | )
345 |
346 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
347 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
348 |
349 | self.skip_add = nn.quantized.FloatFunctional()
350 |
351 | def forward(self, *xs):
352 | """Forward pass.
353 |
354 | Returns:
355 | tensor: output
356 | """
357 | output = xs[0]
358 |
359 | if len(xs) == 2:
360 | res = self.resConfUnit1(xs[1])
361 | output = self.skip_add.add(output, res)
362 | # output += res
363 |
364 | output = self.resConfUnit2(output)
365 |
366 | output = nn.functional.interpolate(
367 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
368 | )
369 |
370 | output = self.out_conv(output)
371 |
372 | return output
373 |
374 |
--------------------------------------------------------------------------------
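
The zero-shot variant adds a clip_resnet101 backbone whose scratch widths follow the ResNet stage channels (256/512/1024/2048). _make_resnet_backbone only regroups an existing ResNet into four feature stages; a quick shape check using torchvision's resnet101 as a stand-in (the actual CLIP ResNet comes from _make_pretrained_clip_rn101 in lseg_vit_zs and is not used here):

    import torch
    import torchvision
    from modules.models.lseg_blocks_zs import _make_resnet_backbone

    backbone = _make_resnet_backbone(torchvision.models.resnet101(pretrained=False))
    x = torch.randn(1, 3, 480, 480)
    l1 = backbone.layer1(x)    # [1,  256, 120, 120]  (stride 4: conv1 + maxpool + layer1)
    l2 = backbone.layer2(l1)   # [1,  512,  60,  60]  (stride 8)
    l3 = backbone.layer3(l2)   # [1, 1024,  30,  30]  (stride 16)
    l4 = backbone.layer4(l3)   # [1, 2048,  15,  15]  (stride 32)
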
/modules/models/lseg_net.py:
--------------------------------------------------------------------------------
1 | import math
2 | import types
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .lseg_blocks import FeatureFusionBlock, Interpolate, _make_encoder, FeatureFusionBlock_custom, forward_vit
9 | import clip
10 | import numpy as np
11 | import pandas as pd
12 | import os
13 |
14 | class depthwise_clipseg_conv(nn.Module):
15 | def __init__(self):
16 | super(depthwise_clipseg_conv, self).__init__()
17 | self.depthwise = nn.Conv2d(1, 1, kernel_size=3, padding=1)
18 |
19 | def depthwise_clipseg(self, x, channels):
20 | x = torch.cat([self.depthwise(x[:, i].unsqueeze(1)) for i in range(channels)], dim=1)
21 | return x
22 |
23 | def forward(self, x):
24 | channels = x.shape[1]
25 | out = self.depthwise_clipseg(x, channels)
26 | return out
27 |
28 |
29 | class depthwise_conv(nn.Module):
30 | def __init__(self, kernel_size=3, stride=1, padding=1):
31 | super(depthwise_conv, self).__init__()
32 | self.depthwise = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding)
33 |
34 | def forward(self, x):
35 | # support for 4D tensor with NCHW
36 | C, H, W = x.shape[1:]
37 | x = x.reshape(-1, 1, H, W)
38 | x = self.depthwise(x)
39 | x = x.view(-1, C, H, W)
40 | return x
41 |
42 |
43 | class depthwise_block(nn.Module):
44 | def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'):
45 | super(depthwise_block, self).__init__()
46 | self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1)
47 | if activation == 'relu':
48 | self.activation = nn.ReLU()
49 | elif activation == 'lrelu':
50 | self.activation = nn.LeakyReLU()
51 | elif activation == 'tanh':
52 | self.activation = nn.Tanh()
53 |
54 | def forward(self, x, act=True):
55 | x = self.depthwise(x)
56 | if act:
57 | x = self.activation(x)
58 | return x
59 |
60 |
61 | class bottleneck_block(nn.Module):
62 | def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'):
63 | super(bottleneck_block, self).__init__()
64 | self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1)
65 | if activation == 'relu':
66 | self.activation = nn.ReLU()
67 | elif activation == 'lrelu':
68 | self.activation = nn.LeakyReLU()
69 | elif activation == 'tanh':
70 | self.activation = nn.Tanh()
71 |
72 |
73 | def forward(self, x, act=True):
74 | sum_layer = x.max(dim=1, keepdim=True)[0]
75 | x = self.depthwise(x)
76 | x = x + sum_layer
77 | if act:
78 | x = self.activation(x)
79 | return x
80 |
81 | class BaseModel(torch.nn.Module):
82 | def load(self, path):
83 | """Load model from file.
84 | Args:
85 | path (str): file path
86 | """
87 | parameters = torch.load(path, map_location=torch.device("cpu"))
88 |
89 | if "optimizer" in parameters:
90 | parameters = parameters["model"]
91 |
92 | self.load_state_dict(parameters)
93 |
94 | def _make_fusion_block(features, use_bn):
95 | return FeatureFusionBlock_custom(
96 | features,
97 | activation=nn.ReLU(False),
98 | deconv=False,
99 | bn=use_bn,
100 | expand=False,
101 | align_corners=True,
102 | )
103 |
104 | class LSeg(BaseModel):
105 | def __init__(
106 | self,
107 | head,
108 | features=256,
109 | backbone="clip_vitl16_384",
110 | readout="project",
111 | channels_last=False,
112 | use_bn=False,
113 | **kwargs,
114 | ):
115 | super(LSeg, self).__init__()
116 |
117 | self.channels_last = channels_last
118 |
119 | hooks = {
120 | "clip_vitl16_384": [5, 11, 17, 23],
121 | "clipRN50x16_vitl16_384": [5, 11, 17, 23],
122 | "clip_vitb32_384": [2, 5, 8, 11],
123 | }
124 |
125 | # Instantiate backbone and reassemble blocks
126 | self.clip_pretrained, self.pretrained, self.scratch = _make_encoder(
127 | backbone,
128 | features,
129 | groups=1,
130 | expand=False,
131 | exportable=False,
132 | hooks=hooks[backbone],
133 | use_readout=readout,
134 | )
135 |
136 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
137 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
138 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
139 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
140 |
141 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)).exp()
142 | if backbone in ["clipRN50x16_vitl16_384"]:
143 | self.out_c = 768
144 | else:
145 | self.out_c = 512
146 | self.scratch.head1 = nn.Conv2d(features, self.out_c, kernel_size=1)
147 |
148 | self.arch_option = kwargs["arch_option"]
149 | if self.arch_option == 1:
150 | self.scratch.head_block = bottleneck_block(activation=kwargs["activation"])
151 | self.block_depth = kwargs['block_depth']
152 | elif self.arch_option == 2:
153 | self.scratch.head_block = depthwise_block(activation=kwargs["activation"])
154 | self.block_depth = kwargs['block_depth']
155 |
156 | self.scratch.output_conv = head
157 |
158 |         self.text = clip.tokenize(self.labels)  # self.labels is set by the LSegNet subclass before super().__init__()
159 |
160 | def forward(self, x, labelset=''):
161 | if labelset == '':
162 | text = self.text
163 | else:
164 | text = clip.tokenize(labelset)
165 |
166 | if self.channels_last == True:
167 |             x = x.contiguous(memory_format=torch.channels_last)  # contiguous() is not in-place; keep the result
168 |
169 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
170 |
171 | layer_1_rn = self.scratch.layer1_rn(layer_1)
172 | layer_2_rn = self.scratch.layer2_rn(layer_2)
173 | layer_3_rn = self.scratch.layer3_rn(layer_3)
174 | layer_4_rn = self.scratch.layer4_rn(layer_4)
175 |
176 | path_4 = self.scratch.refinenet4(layer_4_rn)
177 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
178 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
179 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
180 |
181 | text = text.to(x.device)
182 | self.logit_scale = self.logit_scale.to(x.device)
183 | text_features = self.clip_pretrained.encode_text(text)
184 |
185 | image_features = self.scratch.head1(path_1)
186 |
187 | imshape = image_features.shape
188 | image_features = image_features.permute(0,2,3,1).reshape(-1, self.out_c)
189 |
190 | # normalized features
191 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
192 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
193 |
194 | logits_per_image = self.logit_scale * image_features.half() @ text_features.t()
195 |
196 | out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3], -1).permute(0,3,1,2)
197 |
198 | if self.arch_option in [1, 2]:
199 | for _ in range(self.block_depth - 1):
200 | out = self.scratch.head_block(out)
201 | out = self.scratch.head_block(out, False)
202 |
203 | out = self.scratch.output_conv(out)
204 |
205 | return out
206 |
207 |
208 | class LSegNet(LSeg):
209 | """Network for semantic segmentation."""
210 | def __init__(self, labels, path=None, scale_factor=0.5, crop_size=480, **kwargs):
211 |
212 | features = kwargs["features"] if "features" in kwargs else 256
213 | kwargs["use_bn"] = True
214 |
215 | self.crop_size = crop_size
216 | self.scale_factor = scale_factor
217 | self.labels = labels
218 |
219 | head = nn.Sequential(
220 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
221 | )
222 |
223 | super().__init__(head, **kwargs)
224 |
225 | if path is not None:
226 | self.load(path)
227 |
228 |
229 |
230 |
231 |
--------------------------------------------------------------------------------
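
LSeg.forward computes a dense image-text similarity: the refined feature map is projected to the CLIP embedding width by scratch.head1, L2-normalised, and multiplied with the normalised CLIP text embeddings of the label set, so the channel dimension of the output logits indexes the labels. A minimal inference sketch, assuming a CUDA device (the similarity is computed in fp16) and that it is run from the repository root; without a trained checkpoint the predictions are meaningless, so this only illustrates the API and output shape:

    import torch
    from modules.models.lseg_net import LSegNet

    labels = ["other", "cat", "grass", "sky"]             # open-vocabulary label set
    net = LSegNet(
        labels=labels,
        backbone="clip_vitl16_384",
        features=256,
        arch_option=0,
        block_depth=0,
        activation="lrelu",
        # path="checkpoints/lseg_ade20k_l16.ckpt",        # hypothetical checkpoint path
    ).cuda().eval()

    image = torch.randn(1, 3, 480, 480).cuda()            # stand-in for a normalised RGB crop
    with torch.no_grad():
        logits = net(image, labelset=labels)              # [1, len(labels), 480, 480]
        pred = logits.argmax(dim=1)                       # per-pixel label indices
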
/prepare_ade20k.py:
--------------------------------------------------------------------------------
1 | # +
2 | # revised from https://github.com/zhanghang1989/PyTorch-Encoding/blob/331ecdd5306104614cb414b16fbcd9d1a8d40e1e/scripts/prepare_ade20k.py
3 |
4 | """Prepare ADE20K dataset"""
5 | import os
6 | import shutil
7 | import argparse
8 | import zipfile
9 | from encoding.utils import download, mkdir
10 | # -
11 |
12 | _TARGET_DIR = os.path.expanduser('../datasets/')
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser(
16 | description='Initialize ADE20K dataset.',
17 | epilog='Example: python prepare_ade20k.py',
18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
19 | parser.add_argument('--download-dir', default=None, help='dataset directory on disk')
20 | args = parser.parse_args()
21 | return args
22 |
23 | def download_ade(path, overwrite=False):
24 | _AUG_DOWNLOAD_URLS = [
25 | ('http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip', '219e1696abb36c8ba3a3afe7fb2f4b4606a897c7'),
26 | ('http://data.csail.mit.edu/places/ADEchallenge/release_test.zip', 'e05747892219d10e9243933371a497e905a4860c'),]
27 | download_dir = path
28 | mkdir(download_dir)
29 | for url, checksum in _AUG_DOWNLOAD_URLS:
30 | filename = download(url, path=download_dir, overwrite=overwrite, sha1_hash=checksum)
31 | # extract
32 | with zipfile.ZipFile(filename,"r") as zip_ref:
33 | zip_ref.extractall(path=path)
34 |
35 |
36 | if __name__ == '__main__':
37 | args = parse_args()
38 | mkdir(os.path.expanduser('../datasets/'))
39 | if args.download_dir is not None:
40 |         if os.path.islink(_TARGET_DIR):
41 |             os.remove(_TARGET_DIR)  # drop a stale symlink
42 |         elif os.path.isdir(_TARGET_DIR):
43 |             shutil.rmtree(_TARGET_DIR)  # os.remove() cannot delete a real directory
44 |         # make symlink
45 |         os.symlink(args.download_dir, _TARGET_DIR)
46 |     else:
47 |         download_ade(_TARGET_DIR, overwrite=False)
48 |
--------------------------------------------------------------------------------
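
download() from torch-encoding is passed the SHA-1 checksum listed next to each URL, so the archive can be verified before extraction. A rough standalone equivalent of that verification step (a sketch; the file path below is illustrative):

    import hashlib

    def sha1_matches(path, expected_sha1, chunk_size=1 << 20):
        """Return True if the file at `path` has the expected SHA-1 digest."""
        sha1 = hashlib.sha1()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                sha1.update(chunk)
        return sha1.hexdigest() == expected_sha1

    # e.g. after downloading ADEChallengeData2016.zip into ../datasets/
    # assert sha1_matches("../datasets/ADEChallengeData2016.zip",
    #                     "219e1696abb36c8ba3a3afe7fb2f4b4606a897c7")
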
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.14.1
2 | aiohttp==3.7.4.post0
3 | anyio==3.3.4
4 | argon2-cffi==21.1.0
5 | async-timeout==3.0.1
6 | attrs==21.2.0
7 | Babel==2.9.1
8 | backcall==0.2.0
9 | bleach==4.1.0
10 | cachetools==4.2.4
11 | certifi==2021.10.8
12 | cffi==1.15.0
13 | chardet==4.0.0
14 | charset-normalizer==2.0.7
15 | clip @ git+https://github.com/openai/CLIP.git@04f4dc2ca1ed0acc9893bd1a3b526a7e02c4bb10
16 | cycler==0.10.0
17 | debugpy==1.5.0
18 | decorator==5.1.0
19 | defusedxml==0.7.1
20 | entrypoints==0.3
21 | fsspec==2021.10.1
22 | ftfy==6.0.3
23 | future==0.18.2
24 | google-auth==2.3.0
25 | google-auth-oauthlib==0.4.6
26 | grpcio==1.41.0
27 | idna==3.3
28 | imageio==2.9.0
29 | ipykernel==6.4.1
30 | ipython==7.28.0
31 | ipython-genutils==0.2.0
32 | ipywidgets==7.6.5
33 | jedi==0.18.0
34 | Jinja2==3.0.2
35 | json5==0.9.6
36 | jsonschema==4.1.0
37 | jupyter==1.0.0
38 | jupyter-client==7.0.6
39 | jupyter-console==6.4.0
40 | jupyter-core==4.8.1
41 | jupyter-server==1.11.1
42 | jupyterlab==3.2.0
43 | jupyterlab-pygments==0.1.2
44 | jupyterlab-server==2.8.2
45 | jupyterlab-widgets==1.0.2
46 | kiwisolver==1.3.2
47 | Markdown==3.3.4
48 | MarkupSafe==2.0.1
49 | matplotlib==3.4.3
50 | matplotlib-inline==0.1.3
51 | mistune==0.8.4
52 | multidict==5.2.0
53 | nbclassic==0.3.2
54 | nbclient==0.5.4
55 | nbconvert==6.2.0
56 | nbformat==5.1.3
57 | nest-asyncio==1.5.1
58 | nose==1.3.7
59 | notebook==6.4.4
60 | numpy==1.21.2
61 | oauthlib==3.1.1
62 | packaging==21.0
63 | pandas==1.3.4
64 | pandocfilters==1.5.0
65 | parso==0.8.2
66 | pexpect==4.8.0
67 | pickleshare==0.7.5
68 | Pillow==8.4.0
69 | portalocker==2.3.2
70 | prometheus-client==0.11.0
71 | prompt-toolkit==3.0.20
72 | protobuf==3.18.1
73 | ptyprocess==0.7.0
74 | pyasn1==0.4.8
75 | pyasn1-modules==0.2.8
76 | pycparser==2.20
77 | pyDeprecate==0.3.1
78 | Pygments==2.10.0
79 | pyparsing==2.4.7
80 | pyrsistent==0.18.0
81 | python-dateutil==2.8.2
82 | pytorch-lightning==1.4.9
83 | pytz==2021.3
84 | PyYAML==6.0
85 | pyzmq==22.3.0
86 | qtconsole==5.1.1
87 | QtPy==1.11.2
88 | regex==2021.10.8
89 | requests==2.26.0
90 | requests-oauthlib==1.3.0
91 | requests-unixsocket==0.2.0
92 | rsa==4.7.2
93 | scipy==1.7.1
94 | Send2Trash==1.8.0
95 | six==1.16.0
96 | sniffio==1.2.0
97 | tensorboard==2.7.0
98 | tensorboard-data-server==0.6.1
99 | tensorboard-plugin-wit==1.8.0
100 | terminado==0.12.1
101 | testpath==0.5.0
102 | timm==0.4.12
103 | torch==1.9.1+cu111
104 | torch-encoding @ git+https://github.com/zhanghang1989/PyTorch-Encoding/@331ecdd5306104614cb414b16fbcd9d1a8d40e1e
105 | torchaudio==0.9.1
106 | torchmetrics==0.5.1
107 | torchvision==0.10.1+cu111
108 | tornado==6.1
109 | tqdm==4.62.3
110 | traitlets==5.1.0
111 | typing-extensions==3.10.0.2
112 | urllib3==1.26.7
113 | wcwidth==0.2.5
114 | webencodings==0.5.1
115 | websocket-client==1.2.1
116 | Werkzeug==2.0.2
117 | widgetsnbextension==3.5.1
118 | yarl==1.7.0
119 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0; python test_lseg.py --backbone clip_vitl16_384 --eval --dataset ade20k --data-path ../datasets/ \
2 | --weights checkpoints/lseg_ade20k_l16.ckpt --widehead --no-scaleinv
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/test_lseg_zs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | from tqdm import tqdm
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 | from modules.lseg_module_zs import LSegModuleZS
9 | from additional_utils.models import LSeg_MultiEvalModule
10 | from fewshot_data.common.logger import Logger, AverageMeter
11 | from fewshot_data.common.vis import Visualizer
12 | from fewshot_data.common.evaluation import Evaluator
13 | from fewshot_data.common import utils
14 | from fewshot_data.data.dataset import FSSDataset
15 |
16 |
17 | class Options:
18 | def __init__(self):
19 | parser = argparse.ArgumentParser(description="PyTorch Segmentation")
20 | # model and dataset
21 | parser.add_argument(
22 | "--model", type=str, default="encnet", help="model name (default: encnet)"
23 | )
24 | parser.add_argument(
25 | "--backbone",
26 | type=str,
27 | default="resnet50",
28 | help="backbone name (default: resnet50)",
29 | )
30 | parser.add_argument(
31 | "--dataset",
32 | type=str,
33 | default="ade20k",
34 | help="dataset name (default: pascal12)",
35 | )
36 | parser.add_argument(
37 | "--workers", type=int, default=16, metavar="N", help="dataloader threads"
38 | )
39 | parser.add_argument(
40 | "--base-size", type=int, default=520, help="base image size"
41 | )
42 | parser.add_argument(
43 | "--crop-size", type=int, default=480, help="crop image size"
44 | )
45 | parser.add_argument(
46 | "--train-split",
47 | type=str,
48 | default="train",
49 | help="dataset train split (default: train)",
50 | )
51 | # training hyper params
52 | parser.add_argument(
53 |             "--aux", action="store_true", default=False, help="auxiliary loss"
54 | )
55 | parser.add_argument(
56 | "--se-loss",
57 | action="store_true",
58 | default=False,
59 | help="Semantic Encoding Loss SE-loss",
60 | )
61 | parser.add_argument(
62 | "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
63 | )
64 | parser.add_argument(
65 | "--batch-size",
66 | type=int,
67 | default=16,
68 | metavar="N",
69 | help="input batch size for \
70 |                             training (default: 16)",
71 | )
72 | parser.add_argument(
73 | "--test-batch-size",
74 | type=int,
75 | default=16,
76 | metavar="N",
77 | help="input batch size for \
78 | testing (default: same as batch size)",
79 | )
80 | # cuda, seed and logging
81 | parser.add_argument(
82 | "--no-cuda",
83 | action="store_true",
84 | default=False,
85 | help="disables CUDA training",
86 | )
87 | parser.add_argument(
88 | "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
89 | )
90 | # checking point
91 | parser.add_argument(
92 | "--weights", type=str, default=None, help="checkpoint to test"
93 | )
94 | # evaluation option
95 | parser.add_argument(
96 | "--eval", action="store_true", default=False, help="evaluating mIoU"
97 | )
98 |
99 | parser.add_argument(
100 | "--acc-bn",
101 | action="store_true",
102 | default=False,
103 | help="Re-accumulate BN statistics",
104 | )
105 | parser.add_argument(
106 | "--test-val",
107 | action="store_true",
108 | default=False,
109 | help="generate masks on val set",
110 | )
111 | parser.add_argument(
112 | "--no-val",
113 | action="store_true",
114 | default=False,
115 | help="skip validation during training",
116 | )
117 |
118 | parser.add_argument(
119 | "--module",
120 | default='',
121 | help="select model definition",
122 | )
123 |
124 | # test option
125 | parser.add_argument(
126 | "--no-scaleinv",
127 | dest="scale_inv",
128 | default=True,
129 | action="store_false",
130 | help="turn off scaleinv layers",
131 | )
132 |
133 | parser.add_argument(
134 | "--widehead", default=False, action="store_true", help="wider output head"
135 | )
136 |
137 | parser.add_argument(
138 | "--widehead_hr",
139 | default=False,
140 | action="store_true",
141 | help="wider output head",
142 | )
143 |
144 | parser.add_argument(
145 | "--ignore_index",
146 | type=int,
147 | default=-1,
148 | help="numeric value of ignore label in gt",
149 | )
150 |
151 | parser.add_argument(
152 | "--jobname",
153 | type=str,
154 | default="default",
155 |             help="name for this run/job",
156 | )
157 |
158 | parser.add_argument(
159 | "--no-strict",
160 | dest="strict",
161 | default=True,
162 | action="store_false",
163 |             help="load the checkpoint with strict=False",
164 | )
165 |
166 | parser.add_argument(
167 | "--use_pretrained",
168 | type=str,
169 | default="True",
170 |             help="whether to use the default pretrained model for initialization",
171 | )
172 |
173 | parser.add_argument(
174 | "--arch_option",
175 | type=int,
176 | default=0,
177 |             help="head architecture variant (0: none, 1: bottleneck block, 2: depthwise block)",
178 | )
179 |
180 | # fewshot options
181 | parser.add_argument(
182 | '--nshot',
183 | type=int,
184 | default=1
185 | )
186 | parser.add_argument(
187 | '--fold',
188 | type=int,
189 | default=0,
190 | choices=[0, 1, 2, 3]
191 | )
192 | parser.add_argument(
193 | '--nworker',
194 | type=int,
195 | default=0
196 | )
197 | parser.add_argument(
198 | '--bsz',
199 | type=int,
200 | default=1
201 | )
202 | parser.add_argument(
203 | '--benchmark',
204 | type=str,
205 | default='pascal',
206 | choices=['pascal', 'coco', 'fss', 'c2p']
207 | )
208 | parser.add_argument(
209 | '--datapath',
210 | type=str,
211 | default='fewshot_data/Datasets_HSN'
212 | )
213 |
214 | parser.add_argument(
215 | "--activation",
216 | choices=['relu', 'lrelu', 'tanh'],
217 | default="relu",
218 |             help="activation used inside the head block (arch_option 1 or 2)",
219 | )
220 |
221 |
222 | self.parser = parser
223 |
224 | def parse(self):
225 | args = self.parser.parse_args()
226 | args.cuda = not args.no_cuda and torch.cuda.is_available()
227 | print(args)
228 | return args
229 |
230 |
231 | def test(args):
232 | module_def = LSegModuleZS
233 |
234 | module = module_def.load_from_checkpoint(
235 | checkpoint_path=args.weights,
236 | data_path=args.datapath,
237 | dataset=args.dataset,
238 | backbone=args.backbone,
239 | aux=args.aux,
240 | num_features=256,
241 | aux_weight=0,
242 | se_loss=False,
243 | se_weight=0,
244 | base_lr=0,
245 | batch_size=1,
246 | max_epochs=0,
247 | ignore_index=args.ignore_index,
248 | dropout=0.0,
249 | scale_inv=args.scale_inv,
250 | augment=False,
251 | no_batchnorm=False,
252 | widehead=args.widehead,
253 | widehead_hr=args.widehead_hr,
254 |         map_location="cpu",
255 | arch_option=args.arch_option,
256 | use_pretrained=args.use_pretrained,
257 | strict=args.strict,
258 | logpath='fewshot/logpath_4T/',
259 | fold=args.fold,
260 | block_depth=0,
261 | nshot=args.nshot,
262 | finetune_mode=False,
263 | activation=args.activation,
264 | )
265 |
266 | Evaluator.initialize()
267 | if args.backbone in ["clip_resnet101"]:
268 | FSSDataset.initialize(img_size=480, datapath=args.datapath, use_original_imgsize=False, imagenet_norm=True)
269 | else:
270 | FSSDataset.initialize(img_size=480, datapath=args.datapath, use_original_imgsize=False)
271 | # dataloader
272 | args.benchmark = args.dataset
273 | dataloader = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)
274 |
275 | model = module.net.eval().cuda()
276 | # model = module.net.model.cpu()
277 |
278 | print(model)
279 |
280 | scales = (
281 | [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25]
282 | if args.dataset == "citys"
283 | else [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
284 | )
285 |     os.makedirs("logs/fewshot", exist_ok=True)  # make sure the log directory exists before appending
286 |     f = open("logs/fewshot/log_fewshot-test_nshot{}_{}.txt".format(args.nshot, args.dataset), "a+")
287 |
288 | utils.fix_randseed(0)
289 | average_meter = AverageMeter(dataloader.dataset)
290 | for idx, batch in enumerate(dataloader):
291 | batch = utils.to_cuda(batch)
292 | image = batch['query_img']
293 | target = batch['query_mask']
294 | class_info = batch['class_id']
295 | # pred_mask = evaluator.parallel_forward(image, class_info)
296 | pred_mask = model(image, class_info)
297 | # assert pred_mask.argmax(dim=1).size() == batch['query_mask'].size()
298 | # 2. Evaluate prediction
299 | if args.benchmark == 'pascal' and batch['query_ignore_idx'] is not None:
300 | query_ignore_idx = batch['query_ignore_idx']
301 | area_inter, area_union = Evaluator.classify_prediction(pred_mask.argmax(dim=1), target, query_ignore_idx)
302 | else:
303 | area_inter, area_union = Evaluator.classify_prediction(pred_mask.argmax(dim=1), target)
304 |
305 | average_meter.update(area_inter, area_union, class_info, loss=None)
306 | average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)
307 |
308 | # Write evaluation results
309 | average_meter.write_result('Test', 0)
310 | test_miou, test_fb_iou = average_meter.compute_iou()
311 |
312 | Logger.info('Fold %d, %d-shot ==> mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, args.nshot, test_miou.item(), test_fb_iou.item()))
313 | Logger.info('==================== Finished Testing ====================')
314 | f.write('{}\n'.format(args.weights))
315 | f.write('Fold %d, %d-shot ==> mIoU: %5.2f \t FB-IoU: %5.2f\n' % (args.fold, args.nshot, test_miou.item(), test_fb_iou.item()))
316 | f.close()
317 |
318 |
319 |
320 | if __name__ == "__main__":
321 | args = Options().parse()
322 | torch.manual_seed(args.seed)
323 | test(args)
324 |
--------------------------------------------------------------------------------
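
Evaluator.classify_prediction returns per-class intersection and union pixel counts for each query, and AverageMeter accumulates them across the fold; mIoU is the mean of per-class IoU, while FB-IoU roughly averages the foreground and background IoU. A self-contained sketch of the per-image computation for a binary mask (it mirrors the idea, not the exact fewshot_data API, and skips the query_ignore_idx handling):

    import torch

    def binary_iou_areas(pred, target):
        """Intersection / union pixel counts for foreground (1) and background (0)."""
        areas = []
        for cls in (1, 0):
            p, t = (pred == cls), (target == cls)
            areas.append(((p & t).sum(), (p | t).sum()))
        return areas

    pred = torch.randint(0, 2, (480, 480))      # stand-in for pred_mask.argmax(dim=1)
    target = torch.randint(0, 2, (480, 480))    # stand-in for batch['query_mask']
    (fg_i, fg_u), (bg_i, bg_u) = binary_iou_areas(pred, target)
    fg_iou = fg_i / fg_u.clamp(min=1)
    bg_iou = bg_i / bg_u.clamp(min=1)
    print("FB-IoU ~ %.2f" % (100 * (fg_iou + bg_iou).item() / 2))
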
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #python -u train_lseg.py --dataset ade20k --data_path ../datasets --batch_size 4 --exp_name lseg_ade20k_l16 \
3 | #--base_lr 0.004 --weight_decay 1e-4 --no-scaleinv --max_epochs 240 --widehead --accumulate_grad_batches 2 --backbone clip_vitl16_384
4 |
5 | python -u train_lseg.py --dataset ade20k --data_path ../datasets --batch_size 1 --exp_name lseg_ade20k_l16 \
6 | --base_lr 0.004 --weight_decay 1e-4 --no-scaleinv --max_epochs 240 --widehead --accumulate_grad_batches 2 --backbone clip_vitl16_384
--------------------------------------------------------------------------------
/train_lseg.py:
--------------------------------------------------------------------------------
1 | from modules.lseg_module import LSegModule
2 | from utils import do_training, get_default_argument_parser
3 |
4 | if __name__ == "__main__":
5 | parser = LSegModule.add_model_specific_args(get_default_argument_parser())
6 | args = parser.parse_args()
7 | do_training(args, LSegModule)
8 |
--------------------------------------------------------------------------------
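
do_training and get_default_argument_parser are defined in utils.py at the repository root. Under the usual PyTorch Lightning 1.4.x pattern, the flow is roughly: build the LightningModule from the parsed args, then hand it to a Trainer. A hedged sketch of that wiring; everything except LSegModule and pl.Trainer is an assumption about utils.py, not its actual contents:

    import pytorch_lightning as pl
    from modules.lseg_module import LSegModule

    def do_training_sketch(args):
        # hypothetical stand-in for utils.do_training
        module = LSegModule(**vars(args))                 # assumes the module accepts the parsed args as kwargs
        trainer = pl.Trainer(
            gpus=1,                                       # PL 1.4.x-style device flag
            max_epochs=args.max_epochs,
            accumulate_grad_batches=args.accumulate_grad_batches,
        )
        trainer.fit(module)
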