├── README.md
├── __pycache__
│   ├── clip.cpython-38.pyc
│   ├── model_clip.cpython-38.pyc
│   └── simple_tokenizer.cpython-38.pyc
├── bpe_simple_vocab_16e6.txt.gz
├── clip.py
├── common
│   ├── __pycache__
│   │   ├── evaluation.cpython-38.pyc
│   │   ├── logger.cpython-38.pyc
│   │   ├── utils.cpython-38.pyc
│   │   └── vis.cpython-38.pyc
│   ├── evaluation.py
│   ├── logger.py
│   ├── utils.py
│   └── vis.py
├── data
│   ├── __pycache__
│   │   ├── coco.cpython-38.pyc
│   │   ├── dataset.cpython-38.pyc
│   │   ├── fss.cpython-38.pyc
│   │   └── pascal.cpython-38.pyc
│   ├── assets
│   │   ├── architecture.png
│   │   └── qualitative_results.png
│   ├── coco.py
│   ├── dataset.py
│   ├── fss.py
│   ├── pascal.py
│   └── splits
│       ├── coco
│       │   ├── trn
│       │   │   ├── fold0.pkl
│       │   │   ├── fold1.pkl
│       │   │   ├── fold2.pkl
│       │   │   └── fold3.pkl
│       │   └── val
│       │       ├── fold0.pkl
│       │       ├── fold1.pkl
│       │       ├── fold2.pkl
│       │       └── fold3.pkl
│       ├── fss
│       │   ├── test.txt
│       │   ├── trn.txt
│       │   └── val.txt
│       └── pascal
│           ├── trn
│           │   ├── fold0.txt
│           │   ├── fold1.txt
│           │   ├── fold2.txt
│           │   └── fold3.txt
│           └── val
│               ├── fold0.txt
│               ├── fold1.txt
│               ├── fold2.txt
│               └── fold3.txt
├── generate_cam_coco.py
├── generate_cam_voc.py
├── model
│   ├── __pycache__
│   │   ├── hsnet.cpython-38.pyc
│   │   ├── hsnet_imr.cpython-38.pyc
│   │   ├── hsnet_raft_res_multi_group_xiaorong_grouponly.cpython-38.pyc
│   │   └── learner.cpython-38.pyc
│   ├── base
│   │   ├── __pycache__
│   │   │   ├── conv4d.cpython-38.pyc
│   │   │   ├── correlation.cpython-38.pyc
│   │   │   └── feature.cpython-38.pyc
│   │   ├── conv4d.py
│   │   ├── correlation.py
│   │   └── feature.py
│   ├── hsnet_imr.py
│   └── learner.py
├── model_clip.py
├── pytorch_grad_cam
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── ablation_cam.cpython-38.pyc
│   │   ├── activations_and_gradients.cpython-38.pyc
│   │   ├── base_cam.cpython-38.pyc
│   │   ├── eigen_cam.cpython-38.pyc
│   │   ├── eigen_grad_cam.cpython-38.pyc
│   │   ├── fullgrad_cam.cpython-38.pyc
│   │   ├── grad_cam.cpython-38.pyc
│   │   ├── grad_cam_plusplus.cpython-38.pyc
│   │   ├── guided_backprop.cpython-38.pyc
│   │   ├── layer_cam.cpython-38.pyc
│   │   ├── score_cam.cpython-38.pyc
│   │   └── xgrad_cam.cpython-38.pyc
│   ├── ablation_cam.py
│   ├── activations_and_gradients.py
│   ├── base_cam.py
│   ├── eigen_cam.py
│   ├── eigen_grad_cam.py
│   ├── fullgrad_cam.py
│   ├── grad_cam.py
│   ├── grad_cam_plusplus.py
│   ├── guided_backprop.py
│   ├── layer_cam.py
│   ├── score_cam.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── find_layers.cpython-38.pyc
│   │   │   ├── image.cpython-38.pyc
│   │   │   └── svd_on_activations.cpython-38.pyc
│   │   ├── find_layers.py
│   │   ├── image.py
│   │   └── svd_on_activations.py
│   └── xgrad_cam.py
├── simple_tokenizer.py
├── test.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Iterative Few-shot Semantic Segmentation from Image Label Text
3 | This is the implementation of the paper "Iterative Few-shot Semantic Segmentation from Image Label Text" (IJCAI 2022).
4 | The code is implemented based on HSNet (https://github.com/juhongm999/hsnet), CLIP (https://github.com/openai/CLIP), and pytorch-grad-cam (https://github.com/jacobgil/pytorch-grad-cam). Thanks for their great work!
5 |
6 | ## Requirements
7 | Following HSNet:
8 | - Python 3.7
9 | - PyTorch 1.5.1
10 | - cuda 10.1
11 | - tensorboard 1.14
12 |
13 | Conda environment settings:
14 | ```bash
15 | conda create -n hsnet python=3.7
16 | conda activate hsnet
17 |
18 | conda install pytorch=1.5.1 torchvision cudatoolkit=10.1 -c pytorch
19 | conda install -c conda-forge tensorflow
20 | pip install tensorboardX
21 | ```
22 | ## Preparing Few-Shot Segmentation Datasets
23 | Download following datasets:
24 |
25 | > #### 1. PASCAL-5i
26 | > Download PASCAL VOC2012 devkit (train/val data):
27 | > ```bash
28 | > wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
29 | > ```
30 | > Download PASCAL VOC2012 SDS extended mask annotations from HSNet [[Google Drive](https://drive.google.com/file/d/10zxG2VExoEZUeyQl_uXga2OWHjGeZaf2/view?usp=sharing)].
31 |
32 | > #### 2. COCO-20i
33 | > Download COCO2014 train/val images and annotations:
34 | > ```bash
35 | > wget http://images.cocodataset.org/zips/train2014.zip
36 | > wget http://images.cocodataset.org/zips/val2014.zip
37 | > wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
38 | > ```
39 | > Download COCO2014 train/val annotations from HSNet Google Drive: [[train2014.zip](https://drive.google.com/file/d/1cwup51kcr4m7v9jO14ArpxKMA4O3-Uge/view?usp=sharing)], [[val2014.zip](https://drive.google.com/file/d/1PNw4U3T2MhzAEBWGGgceXvYU3cZ7mJL1/view?usp=sharing)].
40 | > (and place both train2014/ and val2014/ under the annotations/ directory).
41 |
42 |
43 |
44 | Create a directory '../Datasets_HSN' for the above few-shot segmentation datasets and place each dataset so that you have the following directory structure:
45 |
46 | ../ # parent directory
47 | ├── ./ # current (project) directory
48 | │ ├── common/ # (dir.) helper functions
49 | │ ├── data/ # (dir.) dataloaders and splits for each FSSS dataset
50 | │ ├── model/ # (dir.) implementation of Hypercorrelation Squeeze Network model
51 | │ ├── README.md # instructions for reproduction
52 | │ ├── train.py # code for training HSNet
53 | │ └── test.py # code for testing HSNet
54 | └── Datasets_HSN/
55 | ├── VOC2012/ # PASCAL VOC2012 devkit
56 | │ ├── Annotations/
57 | │ ├── ImageSets/
58 | │ ├── ...
59 | │ └── SegmentationClassAug/
60 | ├── COCO2014/
61 | │ ├── annotations/
62 | │ │ ├── train2014/ # (dir.) training masks (from Google Drive)
63 | │ │ ├── val2014/ # (dir.) validation masks (from Google Drive)
64 | │ │ └── ..some json files..
65 | │ ├── train2014/
66 | │ └── val2014/
67 | ├── CAM_VOC_Train/
68 | ├── CAM_VOC_Val/
69 | └── CAM_COCO/
70 |
71 |
72 | ## Preparing CAM for Few-Shot Segmentation Datasets
73 | > ### 1. PASCAL-5i
74 | > * Generate Grad CAM for images
75 | > ```bash
76 | > python generate_cam_voc.py --traincampath ../Datasets_HSN/CAM_VOC_Train/
77 | > --valcampath ../Datasets_HSN/CAM_VOC_Val/
78 | > ```
79 | > ### 2. COCO-20i
80 | > * Generate Grad CAM for images
81 | > ```bash
82 | > python generate_cam_coco.py --campath ../Datasets_HSN/CAM_COCO/
83 | > ```
84 |
85 |
86 | ## Training
87 | > ### 1. PASCAL-5i
88 | > ```bash
89 | > python train.py --backbone {vgg16, resnet50}
90 | > --fold {0, 1, 2, 3}
91 | > --benchmark pascal
92 | > --lr 4e-4
93 | > --bsz 40
94 | > --stage 2
95 | > --logpath "your_experiment_name"
96 | > --traincampath ../Datasets_HSN/CAM_VOC_Train/
97 | > --valcampath ../Datasets_HSN/CAM_VOC_Val/
98 | > ```
99 | > * Training takes approx. 1 day until convergence (trained with four V100 GPUs).
100 |
101 |
102 | > ### 2. COCO-20i
103 | > ```bash
104 | > python train.py --backbone {vgg16, resnet50}
105 | > --fold {0, 1, 2, 3}
106 | > --benchmark coco
107 | > --lr 2e-4
108 | > --bsz 20
109 | > --stage 3
110 | > --logpath "your_experiment_name"
111 | > --traincampath ../Datasets_HSN/CAM_COCO/
112 | > --valcampath ../Datasets_HSN/CAM_COCO/
113 | > ```
114 | > * Training takes approx. 1 week until convergence (trained with four V100 GPUs).
115 |
116 |
117 | > ### Babysitting training:
118 | > Use tensorboard to babysit training progress:
119 | > - For each experiment, a directory that logs training progress is automatically created under the logs/ directory.
120 | > - From the terminal, run 'tensorboard --logdir logs/' to monitor training progress.
121 | > - Choose the best model when the validation (mIoU) curve starts to saturate.
122 |
123 |
124 |
125 | ## Testing
126 |
127 | > ### 1. PASCAL-5i
128 | > Pretrained models with tensorboard logs are available on our [[Google Drive](https://drive.google.com/drive/folders/1fB3_jUEw972lDZIs3_S7lj2F5rZVq4Nu?usp=sharing)].
129 | > ```bash
130 | > python test.py --backbone {vgg16, resnet50}
131 | > --fold {0, 1, 2, 3}
132 | > --benchmark pascal
133 | > --nshot {1, 5}
134 | > --load "path_to_trained_model/best_model.pt"
135 | > ```
136 |
137 |
138 | > ### 2. COCO-20i
139 | > Pretrained models with tensorboard logs are available on our [[Google Drive](https://drive.google.com/drive/folders/1fB3_jUEw972lDZIs3_S7lj2F5rZVq4Nu?usp=sharing)].
140 | > ```bash
141 | > python test.py --backbone {vgg16, resnet50}
142 | > --fold {0, 1, 2, 3}
143 | > --benchmark coco
144 | > --nshot {1, 5}
145 | > --load "path_to_trained_model/best_model.pt"
146 | > ```
147 |
148 |
149 |
150 |
151 | ## BibTeX
152 | If you use this code for your research, please consider citing:
153 | ````BibTeX
154 | @inproceedings{ijcai2022p193,
155 | title = {Iterative Few-shot Semantic Segmentation from Image Label Text},
156 | author = {Wang, Haohan and Liu, Liang and Zhang, Wuhao and Zhang, Jiangning and Gan, Zhenye and Wang, Yabiao and Wang, Chengjie and Wang, Haoqian},
157 | booktitle = {Proceedings of the Thirty-First International Joint Conference on
158 | Artificial Intelligence, {IJCAI-22}},
159 | publisher = {International Joint Conferences on Artificial Intelligence Organization},
160 | editor = {Luc De Raedt},
161 | pages = {1385--1392},
162 | year = {2022},
163 | month = {7},
164 | note = {Main Track},
165 | doi = {10.24963/ijcai.2022/193},
166 | url = {https://doi.org/10.24963/ijcai.2022/193},
167 | }
168 | ````
169 |
--------------------------------------------------------------------------------
/__pycache__/clip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/__pycache__/clip.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/model_clip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/__pycache__/model_clip.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/simple_tokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/__pycache__/simple_tokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip.py:
--------------------------------------------------------------------------------
1 | # Slightly modified to extract self-attention scores
2 | # Original code: https://github.com/openai/CLIP
3 |
4 | import hashlib
5 | import os
6 | import urllib
7 | import warnings
8 | from typing import Union, List
9 |
10 | import torch
11 | from PIL import Image
12 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
13 | from tqdm import tqdm
14 |
15 | from model_clip import build_model
16 | from simple_tokenizer import SimpleTokenizer as _Tokenizer
17 | _tokenizer = _Tokenizer()
18 |
19 | __all__ = ["available_models", "load"]
20 |
21 | _MODELS = {
22 | "RN50": "https://openaipublic.azureedge.net/clip/models"
23 | "/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
24 | "RN101": "https://openaipublic.azureedge.net/clip/models"
25 | "/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
26 | "RN50x4": "https://openaipublic.azureedge.net/clip/models"
27 | "/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
28 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models"
29 | "/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
30 | }
31 |
32 |
33 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
34 | os.makedirs(root, exist_ok=True)
35 | filename = os.path.basename(url)
36 |
37 | expected_sha256 = url.split("/")[-2]
38 | download_target = os.path.join(root, filename)
39 |
40 | if os.path.exists(download_target) and not os.path.isfile(download_target):
41 | raise RuntimeError(f"{download_target} exists and is not a regular file")
42 |
43 | if os.path.isfile(download_target):
44 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
45 | return download_target
46 | else:
47 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
48 |
49 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
50 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
51 | while True:
52 | buffer = source.read(8192)
53 | if not buffer:
54 | break
55 |
56 | output.write(buffer)
57 | loop.update(len(buffer))
58 |
59 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
60 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
61 |
62 | return download_target
63 |
64 |
65 | def _transform(n_px):
66 | return Compose([
67 | Resize(n_px, interpolation=Image.BICUBIC),
68 | CenterCrop(n_px),
69 | lambda image: image.convert("RGB"),
70 | ToTensor(),
71 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
72 | ])
73 |
74 |
75 | def available_models() -> List[str]:
76 | """Returns the names of available CLIP models"""
77 | return list(_MODELS.keys())
78 |
79 |
80 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
81 | """Load a CLIP model
82 |
83 | Parameters
84 | ----------
85 | name : str
86 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
87 |
88 | device : Union[str, torch.device]
89 | The device to put the loaded model
90 |
91 | jit : bool
92 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
93 |
94 | Returns
95 | -------
96 | model : torch.nn.Module
97 | The CLIP model
98 |
99 | preprocess : Callable[[PIL.Image], torch.Tensor]
100 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
101 | """
102 |
103 | if name in _MODELS:
104 | model_path = _download(_MODELS[name])
105 | elif os.path.isfile(name):
106 | model_path = name
107 | else:
108 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
109 |
110 | try:
111 | # loading JIT archive
112 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
113 | state_dict = None
114 | except RuntimeError:
115 | # loading saved state dict
116 | if jit:
117 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
118 | jit = False
119 | state_dict = torch.load(model_path, map_location="cpu")
120 |
121 | if not jit:
122 | model = build_model(state_dict or model.state_dict()).to(device)
123 | if str(device) == "cpu":
124 | model.float()
125 | return model, _transform(model.visual.input_resolution)
126 |
127 | # patch the device names
128 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
129 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
130 |
131 | def patch_device(module):
132 | graphs = [module.graph] if hasattr(module, "graph") else []
133 | if hasattr(module, "forward1"):
134 | graphs.append(module.forward1.graph)
135 |
136 | for graph in graphs:
137 | for node in graph.findAllNodes("prim::Constant"):
138 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
139 | node.copyAttributes(device_node)
140 |
141 | model.apply(patch_device)
142 | patch_device(model.encode_image)
143 | patch_device(model.encode_text)
144 |
145 | # patch dtype to float32 on CPU
146 | if str(device) == "cpu":
147 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
148 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
149 | float_node = float_input.node()
150 |
151 | def patch_float(module):
152 | graphs = [module.graph] if hasattr(module, "graph") else []
153 | if hasattr(module, "forward1"):
154 | graphs.append(module.forward1.graph)
155 |
156 | for graph in graphs:
157 | for node in graph.findAllNodes("aten::to"):
158 | inputs = list(node.inputs())
159 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
160 | if inputs[i].node()["value"] == 5:
161 | inputs[i].node().copyAttributes(float_node)
162 |
163 | model.apply(patch_float)
164 | patch_float(model.encode_image)
165 | patch_float(model.encode_text)
166 |
167 | model.float()
168 |
169 | return model, _transform(model.input_resolution.item())
170 |
171 |
172 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
173 | """
174 | Returns the tokenized representation of given input string(s)
175 | Parameters
176 | ----------
177 | texts : Union[str, List[str]]
178 | An input string or a list of input strings to tokenize
179 | context_length : int
180 | The context length to use; all CLIP models use 77 as the context length
181 | truncate: bool
182 | Whether to truncate the text in case its encoding is longer than the context length
183 | Returns
184 | -------
185 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
186 | """
187 | if isinstance(texts, str):
188 | texts = [texts]
189 |
190 | sot_token = _tokenizer.encoder["<|startoftext|>"]
191 | eot_token = _tokenizer.encoder["<|endoftext|>"]
192 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
193 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
194 |
195 | for i, tokens in enumerate(all_tokens):
196 | if len(tokens) > context_length:
197 | if truncate:
198 | tokens = tokens[:context_length]
199 | tokens[-1] = eot_token
200 | else:
201 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
202 | result[i, :len(tokens)] = torch.tensor(tokens)
203 |
204 | return result
205 |
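206 | 
207 | if __name__ == "__main__":
208 |     # Smoke-test sketch (added for illustration, not part of the original CLIP
209 |     # code): downloads RN50 on first use and encodes two dummy captions.
210 |     dev = "cuda" if torch.cuda.is_available() else "cpu"
211 |     model, preprocess = load("RN50", device=dev, jit=False)
212 |     tokens = tokenize(["a photo of a dog", "a photo of a cat"]).to(dev)
213 |     with torch.no_grad():
214 |         text_features = model.encode_text(tokens)
215 |     print(text_features.shape)  # [2, embed_dim]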
--------------------------------------------------------------------------------
/common/__pycache__/evaluation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/common/__pycache__/evaluation.cpython-38.pyc
--------------------------------------------------------------------------------
/common/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/common/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/common/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/common/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/common/__pycache__/vis.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/common/__pycache__/vis.cpython-38.pyc
--------------------------------------------------------------------------------
/common/evaluation.py:
--------------------------------------------------------------------------------
1 | r""" Evaluate mask prediction """
2 | import torch
3 |
4 |
5 | class Evaluator:
6 | r""" Computes intersection and union between prediction and ground-truth """
7 | @classmethod
8 | def initialize(cls):
9 | cls.ignore_index = 255
10 |
11 | @classmethod
12 | def classify_prediction(cls, pred_mask, batch):
13 | gt_mask = batch.get('query_mask')
14 |
15 | # Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020))
16 | query_ignore_idx = batch.get('query_ignore_idx')
17 | if query_ignore_idx is not None:
18 | assert torch.logical_and(query_ignore_idx, gt_mask).sum() == 0
19 | query_ignore_idx *= cls.ignore_index
20 | gt_mask = gt_mask + query_ignore_idx
21 | pred_mask[gt_mask == cls.ignore_index] = cls.ignore_index
22 |
23 | # compute intersection and union of each episode in a batch
24 | area_inter, area_pred, area_gt = [], [], []
25 | for _pred_mask, _gt_mask in zip(pred_mask, gt_mask):
26 | _inter = _pred_mask[_pred_mask == _gt_mask]
27 | if _inter.size(0) == 0: # torch.histc raises an error on an empty tensor (PyTorch 1.5.1)
28 | _area_inter = torch.tensor([0, 0], device=_pred_mask.device)
29 | else:
30 | _area_inter = torch.histc(_inter, bins=2, min=0, max=1)
31 | area_inter.append(_area_inter)
32 | area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1))
33 | area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1))
34 | area_inter = torch.stack(area_inter).t()
35 | area_pred = torch.stack(area_pred).t()
36 | area_gt = torch.stack(area_gt).t()
37 | area_union = area_pred + area_gt - area_inter
38 |
39 | return area_inter, area_union
40 |
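41 | 
42 | # Usage sketch (illustrative): `pred_mask` and batch['query_mask'] are [B, H, W]
43 | # tensors with values in {0, 1}; row 1 of the returned [2, B] tensors is the
44 | # foreground bin.
45 | #
46 | #   Evaluator.initialize()
47 | #   area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)
48 | #   fg_iou = area_inter[1] / area_union[1].clamp(min=1)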
--------------------------------------------------------------------------------
/common/logger.py:
--------------------------------------------------------------------------------
1 | r""" Logging during training/testing """
2 | import datetime
3 | import logging
4 | import os
5 |
6 | from tensorboardX import SummaryWriter
7 | import torch
8 |
9 |
10 | class AverageMeter:
11 | r""" Stores loss, evaluation results """
12 | def __init__(self, dataset):
13 | self.benchmark = dataset.benchmark
14 | self.class_ids_interest = dataset.class_ids
15 | self.class_ids_interest = torch.tensor(self.class_ids_interest).cuda()
16 |
17 | if self.benchmark == 'pascal':
18 | self.nclass = 20
19 | elif self.benchmark == 'coco':
20 | self.nclass = 80
21 | elif self.benchmark == 'fss':
22 | self.nclass = 1000
23 |
24 | self.intersection_buf = torch.zeros([2, self.nclass]).float().cuda()
25 | self.union_buf = torch.zeros([2, self.nclass]).float().cuda()
26 | self.ones = torch.ones_like(self.union_buf)
27 | self.loss_buf = []
28 |
29 | def update(self, inter_b, union_b, class_id, loss):
30 | self.intersection_buf.index_add_(1, class_id, inter_b.float())
31 | self.union_buf.index_add_(1, class_id, union_b.float())
32 | if loss is None:
33 | loss = torch.tensor(0.0)
34 | self.loss_buf.append(loss)
35 |
36 | def compute_iou(self):
37 | iou = self.intersection_buf.float() / \
38 | torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0]
39 | iou = iou.index_select(1, self.class_ids_interest)
40 | miou = iou[1].mean() * 100
41 |
42 | fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) /
43 | self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100
44 |
45 | return miou, fb_iou
46 |
47 | def write_result(self, split, epoch):
48 | iou, fb_iou = self.compute_iou()
49 |
50 | loss_buf = torch.stack(self.loss_buf)
51 | msg = '\n*** %s ' % split
52 | msg += '[@Epoch %02d] ' % epoch
53 | msg += 'Avg L: %6.5f ' % loss_buf.mean()
54 | msg += 'mIoU: %5.2f ' % iou
55 | msg += 'FB-IoU: %5.2f ' % fb_iou
56 |
57 | msg += '***\n'
58 | Logger.info(msg)
59 |
60 | def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
61 | if batch_idx % write_batch_idx == 0:
62 | msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
63 | msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
64 | iou, fb_iou = self.compute_iou()
65 | if epoch != -1:
66 | loss_buf = torch.stack(self.loss_buf)
67 | msg += 'L: %6.5f ' % loss_buf[-1]
68 | msg += 'Avg L: %6.5f ' % loss_buf.mean()
69 | msg += 'mIoU: %5.2f | ' % iou
70 | msg += 'FB-IoU: %5.2f' % fb_iou
71 | Logger.info(msg)
72 |
73 |
74 | class Logger:
75 | r""" Writes evaluation results of training/testing """
76 | @classmethod
77 | def initialize(cls, args, training):
78 | logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
79 | logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-2].split('.')[0] + logtime
80 | if logpath == '':
81 | logpath = logtime
82 |
83 | cls.logpath = os.path.join('logs', logpath + '.log')
84 | cls.benchmark = args.benchmark
85 | os.makedirs(cls.logpath)
86 |
87 | logging.basicConfig(filemode='w',
88 | filename=os.path.join(cls.logpath, 'log.txt'),
89 | level=logging.INFO,
90 | format='%(message)s',
91 | datefmt='%m-%d %H:%M:%S')
92 |
93 | # Console log config
94 | console = logging.StreamHandler()
95 | console.setLevel(logging.INFO)
96 | formatter = logging.Formatter('%(message)s')
97 | console.setFormatter(formatter)
98 | logging.getLogger('').addHandler(console)
99 |
100 | # Tensorboard writer
101 | cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))
102 |
103 | # Log arguments
104 | logging.info('\n:=========== Few-shot Seg. with HSNet ===========')
105 | for arg_key in args.__dict__:
106 | logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
107 | logging.info(':================================================\n')
108 |
109 | @classmethod
110 | def info(cls, msg):
111 | r""" Writes log message to log.txt """
112 | logging.info(msg)
113 |
114 | @classmethod
115 | def save_model_miou(cls, model, epoch, val_miou):
116 | torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt'))
117 | cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))
118 |
119 | @classmethod
120 | def log_params(cls, model):
121 | backbone_param = 0
122 | learner_param = 0
123 | for k in model.state_dict().keys():
124 | n_param = model.state_dict()[k].view(-1).size(0)
125 | if k.split('.')[0] == 'backbone':
126 | if k.split('.')[1] in ['classifier', 'fc']: # as fc layers are not used in HSNet
127 | continue
128 | backbone_param += n_param
129 | else:
130 | learner_param += n_param
131 | Logger.info('Backbone # param.: %d' % backbone_param)
132 | Logger.info('Learnable # param.: %d' % learner_param)
133 | Logger.info('Total # param.: %d' % (backbone_param + learner_param))
134 |
135 |
136 |
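137 | # Typical flow (a sketch; `args` is assumed to come from the project's argparse
138 | # setup in train.py/test.py):
139 | #
140 | #   Logger.initialize(args, training=True)
141 | #   average_meter = AverageMeter(dataloader.dataset)
142 | #   average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
143 | #   average_meter.write_result('Validation', epoch)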
--------------------------------------------------------------------------------
/common/utils.py:
--------------------------------------------------------------------------------
1 | r""" Helper functions """
2 | import random
3 |
4 | import torch
5 | import numpy as np
6 |
7 |
8 | def fix_randseed(seed):
9 | r""" Set random seeds for reproducibility """
10 | if seed is None:
11 | seed = int(random.random() * 1e5)
12 | np.random.seed(seed)
13 | torch.manual_seed(seed)
14 | torch.cuda.manual_seed(seed)
15 | torch.cuda.manual_seed_all(seed)
16 | torch.backends.cudnn.benchmark = False
17 | torch.backends.cudnn.deterministic = True
18 |
19 |
20 | def mean(x):
21 | return sum(x) / len(x) if len(x) > 0 else 0.0
22 |
23 |
24 | def to_cuda(batch):
25 | for key, value in batch.items():
26 | if isinstance(value, torch.Tensor):
27 | batch[key] = value.cuda()
28 | return batch
29 |
30 |
31 | def to_cpu(tensor):
32 | return tensor.detach().clone().cpu()
33 |
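34 | 
35 | # Usage sketch (illustrative): fix seeds once before building dataloaders, then
36 | # move each batch to the GPU inside the training loop.
37 | #
38 | #   fix_randseed(0)
39 | #   batch = to_cuda(batch)               # dict of tensors -> CUDA
40 | #   mask = to_cpu(batch['query_mask'])   # detached CPU copy for visualization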
--------------------------------------------------------------------------------
/common/vis.py:
--------------------------------------------------------------------------------
1 | r""" Visualize model predictions """
2 | import os
3 |
4 | from PIL import Image
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 |
8 | from . import utils
9 |
10 |
11 | class Visualizer:
12 |
13 | @classmethod
14 | def initialize(cls, visualize):
15 | cls.visualize = visualize
16 | if not visualize:
17 | return
18 |
19 | cls.colors = {'red': (255, 50, 50), 'blue': (102, 140, 255)}
20 | for key, value in cls.colors.items():
21 | cls.colors[key] = tuple([c / 255 for c in cls.colors[key]])
22 |
23 | cls.mean_img = [0.485, 0.456, 0.406]
24 | cls.std_img = [0.229, 0.224, 0.225]
25 | cls.to_pil = transforms.ToPILImage()
26 | cls.vis_path = './vis/'
27 | if not os.path.exists(cls.vis_path):
28 | os.makedirs(cls.vis_path)
29 |
30 | @classmethod
31 | def visualize_prediction_batch(
32 | cls, spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b, batch_idx, iou_b=None):
33 | spt_img_b = utils.to_cpu(spt_img_b)
34 | spt_mask_b = utils.to_cpu(spt_mask_b)
35 | qry_img_b = utils.to_cpu(qry_img_b)
36 | qry_mask_b = utils.to_cpu(qry_mask_b)
37 | pred_mask_b = utils.to_cpu(pred_mask_b)
38 | cls_id_b = utils.to_cpu(cls_id_b)
39 |
40 | for sample_idx, (spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id) in \
41 | enumerate(zip(spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b)):
42 | iou = iou_b[sample_idx] if iou_b is not None else None
43 | cls.visualize_prediction(
44 | spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, True, iou)
45 |
46 | @classmethod
47 | def to_numpy(cls, tensor, type):
48 | if type == 'img':
49 | return np.array(cls.to_pil(cls.unnormalize(tensor))).astype(np.uint8)
50 | elif type == 'mask':
51 | return np.array(tensor).astype(np.uint8)
52 | else:
53 | raise Exception('Undefined tensor type: %s' % type)
54 |
55 | @classmethod
56 | def visualize_prediction(
57 | cls, spt_imgs, spt_masks, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, label, iou=None):
58 |
59 | spt_color = cls.colors['blue']
60 | qry_color = cls.colors['red']
61 | pred_color = cls.colors['red']
62 |
63 | spt_imgs = [cls.to_numpy(spt_img, 'img') for spt_img in spt_imgs]
64 | spt_pils = [cls.to_pil(spt_img) for spt_img in spt_imgs]
65 | spt_masks = [cls.to_numpy(spt_mask, 'mask') for spt_mask in spt_masks]
66 | spt_masked_pils = [Image.fromarray(
67 | cls.apply_mask(spt_img, spt_mask, spt_color)) for spt_img, spt_mask in zip(spt_imgs, spt_masks)]
68 |
69 | qry_img = cls.to_numpy(qry_img, 'img')
70 | qry_pil = cls.to_pil(qry_img)
71 | qry_mask = cls.to_numpy(qry_mask, 'mask')
72 | pred_mask = cls.to_numpy(pred_mask, 'mask')
73 | pred_masked_pil = \
74 | Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), pred_mask.astype(np.uint8), pred_color))
75 | qry_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), qry_mask.astype(np.uint8), qry_color))
76 |
77 | merged_pil = cls.merge_image_pair(spt_masked_pils + [pred_masked_pil, qry_masked_pil])
78 |
79 | iou = iou.item() if iou else 0.0
80 | merged_pil.save(cls.vis_path + '%d_%d_class-%d_iou-%.2f' % (batch_idx, sample_idx, cls_id, iou) + '.jpg')
81 |
82 | @classmethod
83 | def merge_image_pair(cls, pil_imgs):
84 | r""" Horizontally aligns a pair of pytorch tensor images (3, H, W) and returns PIL object """
85 |
86 | canvas_width = sum([pil.size[0] for pil in pil_imgs])
87 | canvas_height = max([pil.size[1] for pil in pil_imgs])
88 | canvas = Image.new('RGB', (canvas_width, canvas_height))
89 |
90 | xpos = 0
91 | for pil in pil_imgs:
92 | canvas.paste(pil, (xpos, 0))
93 | xpos += pil.size[0]
94 |
95 | return canvas
96 |
97 | @classmethod
98 | def apply_mask(cls, image, mask, color, alpha=0.5):
99 | r""" Apply mask to the given image. """
100 | for c in range(3):
101 | image[:, :, c] = np.where(mask == 1,
102 | image[:, :, c] *
103 | (1 - alpha) + alpha * color[c] * 255,
104 | image[:, :, c])
105 | return image
106 |
107 | @classmethod
108 | def unnormalize(cls, img):
109 | img = img.clone()
110 | for im_channel, mean, std in zip(img, cls.mean_img, cls.std_img):
111 | im_channel.mul_(std).add_(mean)
112 | return img
113 |
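114 | # Usage sketch (illustrative): saves side-by-side support/prediction/ground-truth
115 | # composites as .jpg files under ./vis/.
116 | #
117 | #   Visualizer.initialize(visualize=True)
118 | #   Visualizer.visualize_prediction_batch(
119 | #       batch['support_imgs'], batch['support_masks'],
120 | #       batch['query_img'], batch['query_mask'],
121 | #       pred_mask, batch['class_id'], batch_idx, iou_b=fg_iou)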
--------------------------------------------------------------------------------
/data/__pycache__/coco.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/__pycache__/coco.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/fss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/__pycache__/fss.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/pascal.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/__pycache__/pascal.cpython-38.pyc
--------------------------------------------------------------------------------
/data/assets/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/assets/architecture.png
--------------------------------------------------------------------------------
/data/assets/qualitative_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/assets/qualitative_results.png
--------------------------------------------------------------------------------
/data/coco.py:
--------------------------------------------------------------------------------
1 | r""" COCO-20i few-shot semantic segmentation dataset """
2 | import os
3 | import pickle
4 |
5 | from torch.utils.data import Dataset
6 | import torch.nn.functional as F
7 | import torch
8 | import PIL.Image as Image
9 | import numpy as np
10 |
11 |
12 | class DatasetCOCO(Dataset):
13 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize, cam_train_path, cam_val_path):
14 | self.split = 'val' if split in ['val', 'test'] else 'trn'
15 | self.fold = fold
16 | self.nfolds = 4
17 | self.nclass = 80
18 | self.benchmark = 'coco'
19 | self.shot = shot
20 | self.split_coco = split if split == 'val2014' else 'train2014'
21 | self.base_path = os.path.join(datapath, 'COCO2014')
22 | self.transform = transform
23 | self.use_original_imgsize = use_original_imgsize
24 |
25 | self.class_ids = self.build_class_ids()
26 | self.img_metadata_classwise = self.build_img_metadata_classwise()
27 | self.img_metadata = self.build_img_metadata()
28 |
29 | assert cam_train_path == cam_val_path
30 | # COCO query names already include the split prefix ("train2014/img" or "val2014/img"),
31 | # whereas VOC names only include "img", so a single CAM directory serves both splits.
32 | self.cam_path = cam_train_path
33 |
34 |
35 | def __len__(self):
36 | return len(self.img_metadata) if self.split == 'trn' else 1000
37 |
38 | def __getitem__(self, idx):
39 | # Ignore idx during training & testing and perform uniform sampling over object classes to form an episode
40 | # (due to the large size of the COCO dataset)
41 | query_img, query_mask, support_imgs, support_masks, \
42 | query_name, support_names, class_sample, org_qry_imsize = self.load_frame()
43 |
44 | query_img = self.transform(query_img)
45 | query_mask = query_mask.float()
46 | if not self.use_original_imgsize:
47 | query_mask = F.interpolate(
48 | query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()
49 |
50 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])
51 | for midx, smask in enumerate(support_masks):
52 | support_masks[midx] = F.interpolate(
53 | smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
54 | support_masks = torch.stack(support_masks)
55 | query_cam_path = self.cam_path + query_name + '--' + str(class_sample) + '.pt'
56 | query_cam = torch.load(query_cam_path) # 50 50
57 |
58 | nshot = len(support_names)
59 | support_cams = []
60 | for nn in range(nshot):
61 | support_cam_path = self.cam_path + support_names[nn] + '--' + str(class_sample) + '.pt'
62 | support_cam = torch.load(support_cam_path).unsqueeze(0) # 1 50 50
63 | support_cams.append(support_cam)
64 | support_cams = torch.cat(support_cams, dim=0) # nshot 50 50
65 |
66 | batch = {'query_img': query_img,
67 | 'query_mask': query_mask,
68 | 'query_name': query_name,
69 |
70 | 'org_query_imsize': org_qry_imsize,
71 |
72 | 'support_imgs': support_imgs,
73 | 'support_masks': support_masks,
74 | 'support_names': support_names,
75 | 'class_id': torch.tensor(class_sample),
76 | 'query_cam': query_cam,
77 | 'support_cams': support_cams
78 | }
79 |
80 | return batch
81 |
82 | def build_class_ids(self):
83 | nclass_trn = self.nclass // self.nfolds
84 | class_ids_val = [self.fold + self.nfolds * v for v in range(nclass_trn)] # e.g. fold=0 -> {0, 4, 8, ..., 76}
85 | class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val]
86 | class_ids = class_ids_trn if self.split == 'trn' else class_ids_val
87 |
88 | return class_ids
89 |
90 | def build_img_metadata_classwise(self):
91 | with open('./data/splits/coco/%s/fold%d.pkl' % (self.split, self.fold), 'rb') as f:
92 | img_metadata_classwise = pickle.load(f)
93 | return img_metadata_classwise
94 |
95 | def build_img_metadata(self):
96 | img_metadata = []
97 | for k in self.img_metadata_classwise.keys():
98 | img_metadata += self.img_metadata_classwise[k]
99 | return sorted(list(set(img_metadata)))
100 |
101 | def read_mask(self, name):
102 | mask_path = os.path.join(self.base_path, 'annotations', name)
103 | mask = torch.tensor(np.array(Image.open(mask_path[:mask_path.index('.jpg')] + '.png')))
104 | return mask
105 |
106 | def load_frame(self):
107 | class_sample = np.random.choice(self.class_ids, 1, replace=False)[0]
108 | query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
109 | query_img = Image.open(os.path.join(self.base_path, query_name)).convert('RGB')
110 | query_mask = self.read_mask(query_name)
111 |
112 | org_qry_imsize = query_img.size
113 |
114 | query_mask[query_mask != class_sample + 1] = 0
115 | query_mask[query_mask == class_sample + 1] = 1
116 |
117 | support_names = []
118 | while True: # keep sampling support set if query == support
119 | support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
120 | if query_name != support_name:
121 | support_names.append(support_name)
122 | if len(support_names) == self.shot:
123 | break
124 |
125 | support_imgs = []
126 | support_masks = []
127 | for support_name in support_names:
128 | support_imgs.append(Image.open(os.path.join(self.base_path, support_name)).convert('RGB'))
129 | support_mask = self.read_mask(support_name)
130 | support_mask[support_mask != class_sample + 1] = 0
131 | support_mask[support_mask == class_sample + 1] = 1
132 | support_masks.append(support_mask)
133 |
134 | return \
135 | query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize
136 |
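137 | # CAM file layout sketch (illustrative; the image id below is a placeholder):
138 | # an episode for class 0 with query 'train2014/COCO_train2014_000000000001.jpg'
139 | # loads the precomputed 50x50 CAM tensor from
140 | #   <cam_path>train2014/COCO_train2014_000000000001.jpg--0.pt
141 | # as written by generate_cam_coco.py.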
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | r""" Dataloader builder for few-shot semantic segmentation dataset """
2 | from torchvision import transforms
3 | from torch.utils.data import DataLoader
4 |
5 | from data.pascal import DatasetPASCAL
6 | from data.coco import DatasetCOCO
7 | from data.fss import DatasetFSS
8 |
9 |
10 | class FSSDataset:
11 |
12 | @classmethod
13 | def initialize(cls, img_size, datapath, use_original_imgsize):
14 |
15 | cls.datasets = {
16 | 'pascal': DatasetPASCAL,
17 | 'coco': DatasetCOCO,
18 | 'fss': DatasetFSS,
19 | }
20 |
21 | cls.img_mean = [0.485, 0.456, 0.406]
22 | cls.img_std = [0.229, 0.224, 0.225]
23 | cls.datapath = datapath
24 | cls.use_original_imgsize = use_original_imgsize
25 |
26 | cls.transform = transforms.Compose([transforms.Resize(size=(img_size, img_size)),
27 | transforms.ToTensor(),
28 | transforms.Normalize(cls.img_mean, cls.img_std)])
29 |
30 | @classmethod
31 | def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1, cam_train_path=None, cam_val_path=None):
32 | # Force randomness during training for diverse episode combinations
33 | # Freeze randomness during testing for reproducibility
34 | shuffle = split == 'trn'
35 | nworker = nworker if split == 'trn' else 0
36 |
37 | dataset = cls.datasets[benchmark](cls.datapath, fold=fold, transform=cls.transform, split=split,
38 | shot=shot, use_original_imgsize=cls.use_original_imgsize,
39 | cam_train_path=cam_train_path, cam_val_path=cam_val_path)
40 | dataloader = DataLoader(dataset, batch_size=bsz, shuffle=shuffle, num_workers=nworker)
41 |
42 | return dataloader
43 |
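44 | 
45 | # Usage sketch (illustrative; img_size and the paths are assumptions matching the
46 | # README defaults, not values fixed by this file):
47 | #
48 | #   FSSDataset.initialize(img_size=400, datapath='../Datasets_HSN', use_original_imgsize=False)
49 | #   loader = FSSDataset.build_dataloader('pascal', bsz=40, nworker=8, fold=0, split='trn',
50 | #                                        cam_train_path='../Datasets_HSN/CAM_VOC_Train/',
51 | #                                        cam_val_path='../Datasets_HSN/CAM_VOC_Val/')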
--------------------------------------------------------------------------------
/data/fss.py:
--------------------------------------------------------------------------------
1 | r""" FSS-1000 few-shot semantic segmentation dataset """
2 | import os
3 | import glob
4 |
5 | from torch.utils.data import Dataset
6 | import torch.nn.functional as F
7 | import torch
8 | import PIL.Image as Image
9 | import numpy as np
10 |
11 |
12 | class DatasetFSS(Dataset):
13 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize):
14 | self.split = split
15 | self.benchmark = 'fss'
16 | self.shot = shot
17 |
18 | self.base_path = os.path.join(datapath, 'FSS-1000')
19 |
20 | # Given predefined test split, load randomly generated training/val splits:
21 | # (reference regarding trn/val/test splits: https://github.com/HKUSTCV/FSS-1000/issues/7)
22 | with open('./data/splits/fss/%s.txt' % split, 'r') as f:
23 | self.categories = f.read().split('\n')[:-1]
24 | self.categories = sorted(self.categories)
25 |
26 | self.class_ids = self.build_class_ids()
27 | self.img_metadata = self.build_img_metadata()
28 |
29 | self.transform = transform
30 |
31 | def __len__(self):
32 | return len(self.img_metadata)
33 |
34 | def __getitem__(self, idx):
35 | query_name, support_names, class_sample = self.sample_episode(idx)
36 | query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)
37 |
38 | query_img = self.transform(query_img)
39 | query_mask = F.interpolate(
40 | query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()
41 |
42 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])
43 |
44 | support_masks_tmp = []
45 | for smask in support_masks:
46 | smask = F.interpolate(
47 | smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
48 | support_masks_tmp.append(smask)
49 | support_masks = torch.stack(support_masks_tmp)
50 |
51 | batch = {'query_img': query_img,
52 | 'query_mask': query_mask,
53 | 'query_name': query_name,
54 |
55 | 'support_imgs': support_imgs,
56 | 'support_masks': support_masks,
57 | 'support_names': support_names,
58 |
59 | 'class_id': torch.tensor(class_sample)}
60 |
61 | return batch
62 |
63 | def load_frame(self, query_name, support_names):
64 | query_img = Image.open(query_name).convert('RGB')
65 | support_imgs = [Image.open(name).convert('RGB') for name in support_names]
66 |
67 | query_id = query_name.split('/')[-1].split('.')[0]
68 | query_name = os.path.join(os.path.dirname(query_name), query_id) + '.png'
69 | support_ids = [name.split('/')[-1].split('.')[0] for name in support_names]
70 | support_names = [os.path.join(os.path.dirname(name), sid) + '.png'
71 | for name, sid in zip(support_names, support_ids)]
72 |
73 | query_mask = self.read_mask(query_name)
74 | support_masks = [self.read_mask(name) for name in support_names]
75 |
76 | return query_img, query_mask, support_imgs, support_masks
77 |
78 | def read_mask(self, img_name):
79 | mask = torch.tensor(np.array(Image.open(img_name).convert('L')))
80 | mask[mask < 128] = 0
81 | mask[mask >= 128] = 1
82 | return mask
83 |
84 | def sample_episode(self, idx):
85 | query_name = self.img_metadata[idx]
86 | class_sample = self.categories.index(query_name.split('/')[-2])
87 | if self.split == 'val':
88 | class_sample += 520
89 | elif self.split == 'test':
90 | class_sample += 760
91 |
92 | support_names = []
93 | while True: # keep sampling support set if query == support
94 | support_name = np.random.choice(range(1, 11), 1, replace=False)[0]
95 | support_name = os.path.join(os.path.dirname(query_name), str(support_name)) + '.jpg'
96 | if query_name != support_name:
97 | support_names.append(support_name)
98 | if len(support_names) == self.shot:
99 | break
100 |
101 | return query_name, support_names, class_sample
102 |
103 | def build_class_ids(self):
104 | if self.split == 'trn':
105 | class_ids = range(0, 520)
106 | elif self.split == 'val':
107 | class_ids = range(520, 760)
108 | elif self.split == 'test':
109 | class_ids = range(760, 1000)
110 | return class_ids
111 |
112 | def build_img_metadata(self):
113 | img_metadata = []
114 | for cat in self.categories:
115 | img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, cat))])
116 | for img_path in img_paths:
117 | if os.path.basename(img_path).split('.')[1] == 'jpg':
118 | img_metadata.append(img_path)
119 | return img_metadata
120 |
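121 | # Class-id layout (from build_class_ids above): the 1000 sorted FSS-1000
122 | # categories map to ids 0-519 (trn), 520-759 (val) and 760-999 (test);
123 | # sample_episode applies the 520/760 offsets for val/test queries.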
--------------------------------------------------------------------------------
/data/pascal.py:
--------------------------------------------------------------------------------
1 | r""" PASCAL-5i few-shot semantic segmentation dataset """
2 | import os
3 | import pdb
4 |
5 | from torch.utils.data import Dataset
6 | import torch.nn.functional as F
7 | import torch
8 | import PIL.Image as Image
9 | import numpy as np
10 |
11 |
12 | class DatasetPASCAL(Dataset):
13 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize, cam_train_path, cam_val_path):
14 | self.split = 'val' if split in ['val', 'test'] else 'trn'
15 | self.fold = fold
16 | self.nfolds = 4
17 | self.nclass = 20
18 | self.benchmark = 'pascal'
19 | self.shot = shot
20 | self.use_original_imgsize = use_original_imgsize
21 |
22 | self.img_path = os.path.join(datapath, 'VOC2012/JPEGImages/')
23 | self.ann_path = os.path.join(datapath, 'VOC2012/SegmentationClassAug/')
24 | self.transform = transform
25 |
26 | self.class_ids = self.build_class_ids()
27 | self.img_metadata = self.build_img_metadata()
28 | self.img_metadata_classwise = self.build_img_metadata_classwise()
29 |
30 | self.cam_train_path = cam_train_path
31 | self.cam_val_path = cam_val_path
32 |
33 | def __len__(self):
34 | return len(self.img_metadata) if self.split == 'trn' else 1000
35 |
36 | def __getitem__(self, idx):
37 | idx %= len(self.img_metadata) # for testing, as n_images < 1000
38 | query_name, support_names, class_sample = self.sample_episode(idx)
39 | query_img, query_cmask, support_imgs, support_cmasks, org_qry_imsize = \
40 | self.load_frame(query_name, support_names)
41 |
42 | query_img = self.transform(query_img)
43 | if not self.use_original_imgsize:
44 | query_cmask = F.interpolate(query_cmask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:],
45 | mode='nearest').squeeze()
46 | query_mask, query_ignore_idx = self.extract_ignore_idx(query_cmask.float(), class_sample)
47 |
48 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])
49 |
50 | support_masks = []
51 | support_ignore_idxs = []
52 | for scmask in support_cmasks:
53 | scmask = F.interpolate(scmask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:],
54 | mode='nearest').squeeze()
55 | support_mask, support_ignore_idx = self.extract_ignore_idx(scmask, class_sample)
56 | support_masks.append(support_mask)
57 | support_ignore_idxs.append(support_ignore_idx)
58 | support_masks = torch.stack(support_masks)
59 | support_ignore_idxs = torch.stack(support_ignore_idxs)
60 |
61 | if self.split == 'val':
62 | query_cam_path = self.cam_val_path + query_name + '--' + str(class_sample) + '.pt'
63 | query_cam = torch.load(query_cam_path)
64 | nshot = len(support_names)
65 | support_cams = []
66 | for nn in range(nshot):
67 | support_cam_path = self.cam_val_path + support_names[nn] + '--' + str(class_sample) + '.pt'
68 | support_cam = torch.load(support_cam_path).unsqueeze(0)
69 | support_cams.append(support_cam)
70 | support_cams = torch.cat(support_cams, dim=0)
71 | else:
72 | query_cam_path = self.cam_train_path + query_name + '--' + str(class_sample) + '.pt'
73 | query_cam = torch.load(query_cam_path)
74 | nshot = len(support_names)
75 | support_cams = []
76 | for nn in range(nshot):
77 | support_cam_path = self.cam_train_path + support_names[nn] + '--' + str(class_sample) + '.pt'
78 | support_cam = torch.load(support_cam_path).unsqueeze(0) # 1 50 50
79 | support_cams.append(support_cam)
80 | support_cams = torch.cat(support_cams, dim=0) # nshot 50 50
81 |
82 | batch = {'query_img': query_img,
83 | 'query_mask': query_mask,
84 | 'query_name': query_name,
85 | 'query_ignore_idx': query_ignore_idx,
86 | 'org_query_imsize': org_qry_imsize,
87 | 'support_imgs': support_imgs,
88 | 'support_masks': support_masks,
89 | 'support_names': support_names,
90 | 'support_ignore_idxs': support_ignore_idxs,
91 | 'class_id': torch.tensor(class_sample),
92 | 'query_cam': query_cam,
93 | 'support_cams': support_cams}
94 |
95 | return batch
96 |
97 | def extract_ignore_idx(self, mask, class_id):
98 | boundary = (mask / 255).floor() # pixels labeled 255 in PASCAL masks mark boundary/ignore regions
99 | mask[mask != class_id + 1] = 0
100 | mask[mask == class_id + 1] = 1
101 |
102 | return mask, boundary
103 |
104 | def load_frame(self, query_name, support_names):
105 | query_img = self.read_img(query_name)
106 | query_mask = self.read_mask(query_name)
107 | support_imgs = [self.read_img(name) for name in support_names]
108 | support_masks = [self.read_mask(name) for name in support_names]
109 |
110 | org_qry_imsize = query_img.size
111 |
112 | return query_img, query_mask, support_imgs, support_masks, org_qry_imsize
113 |
114 | def read_mask(self, img_name):
115 | r"""Return segmentation mask in PIL Image"""
116 | mask = torch.tensor(np.array(Image.open(os.path.join(self.ann_path, img_name) + '.png')))
117 | return mask
118 |
119 | def read_img(self, img_name):
120 | r"""Return RGB image in PIL Image"""
121 | return Image.open(os.path.join(self.img_path, img_name) + '.jpg')
122 |
123 | def sample_episode(self, idx):
124 | query_name, class_sample = self.img_metadata[idx]
125 |
126 | support_names = []
127 | while True: # keep sampling support set if query == support
128 | support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
129 | if query_name != support_name:
130 | support_names.append(support_name)
131 | if len(support_names) == self.shot:
132 | break
133 |
134 | return query_name, support_names, class_sample
135 |
136 | def build_class_ids(self):
137 | nclass_trn = self.nclass // self.nfolds
138 | class_ids_val = [self.fold * nclass_trn + i for i in range(nclass_trn)] # e.g. fold=0 -> {0, 1, 2, 3, 4}
139 | class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val]
140 |
141 | if self.split == 'trn':
142 | return class_ids_trn
143 | else:
144 | return class_ids_val
145 |
146 | def build_img_metadata(self):
147 |
148 | def read_metadata(split, fold_id):
149 | fold_n_metadata = os.path.join('data/splits/pascal/%s/fold%d.txt' % (split, fold_id))
150 | with open(fold_n_metadata, 'r') as f:
151 | fold_n_metadata = f.read().split('\n')[:-1]
152 | fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata]
153 | return fold_n_metadata
154 |
155 | img_metadata = []
156 | if self.split == 'trn': # For training, read image-metadata of "the other" folds
157 | for fold_id in range(self.nfolds):
158 | if fold_id == self.fold: # Skip validation fold
159 | continue
160 | img_metadata += read_metadata(self.split, fold_id)
161 | elif self.split == 'val': # For validation, read image-metadata of "current" fold
162 | img_metadata = read_metadata(self.split, self.fold)
163 | else:
164 | raise Exception('Undefined split %s: ' % self.split)
165 |
166 | print('Total (%s) images: %d' % (self.split, len(img_metadata)))
167 |
168 | return img_metadata
169 |
170 | def build_img_metadata_classwise(self):
171 | img_metadata_classwise = {}
172 | for class_id in range(self.nclass):
173 | img_metadata_classwise[class_id] = []
174 |
175 | for img_name, img_class in self.img_metadata:
176 | img_metadata_classwise[img_class] += [img_name]
177 | return img_metadata_classwise
178 |
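179 | # Episode sketch (illustrative): with fold=0, validation covers classes 0-4 and
180 | # training covers the remaining 15; each episode pairs a query with `shot` support
181 | # images of the sampled class and their precomputed 50x50 CAM tensors.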
--------------------------------------------------------------------------------
/data/splits/coco/trn/fold0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/trn/fold0.pkl
--------------------------------------------------------------------------------
/data/splits/coco/trn/fold1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/trn/fold1.pkl
--------------------------------------------------------------------------------
/data/splits/coco/trn/fold2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/trn/fold2.pkl
--------------------------------------------------------------------------------
/data/splits/coco/trn/fold3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/trn/fold3.pkl
--------------------------------------------------------------------------------
/data/splits/coco/val/fold0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/val/fold0.pkl
--------------------------------------------------------------------------------
/data/splits/coco/val/fold1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/val/fold1.pkl
--------------------------------------------------------------------------------
/data/splits/coco/val/fold2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/val/fold2.pkl
--------------------------------------------------------------------------------
/data/splits/coco/val/fold3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/val/fold3.pkl
--------------------------------------------------------------------------------
/data/splits/fss/test.txt:
--------------------------------------------------------------------------------
1 | bus
2 | hotel_slipper
3 | burj_al
4 | reflex_camera
5 | abe's_flyingfish
6 | oiltank_car
7 | doormat
8 | fish_eagle
9 | barber_shaver
10 | motorbike
11 | feather_clothes
12 | wandering_albatross
13 | rice_cooker
14 | delta_wing
15 | fish
16 | nintendo_switch
17 | bustard
18 | diver
19 | minicooper
20 | cathedrale_paris
21 | big_ben
22 | combination_lock
23 | villa_savoye
24 | american_alligator
25 | gym_ball
26 | andean_condor
27 | leggings
28 | pyramid_cube
29 | jet_aircraft
30 | meatloaf
31 | reel
32 | swan
33 | osprey
34 | crt_screen
35 | microscope
36 | rubber_eraser
37 | arrow
38 | monkey
39 | mitten
40 | spiderman
41 | parthenon
42 | bat
43 | chess_king
44 | sulphur_butterfly
45 | quail_egg
46 | oriole
47 | iron_man
48 | wooden_boat
49 | anise
50 | steering_wheel
51 | groenendael
52 | dwarf_beans
53 | pteropus
54 | chalk_brush
55 | bloodhound
56 | moon
57 | english_foxhound
58 | boxing_gloves
59 | peregine_falcon
60 | pyraminx
61 | cicada
62 | screw
63 | shower_curtain
64 | tredmill
65 | bulb
66 | bell_pepper
67 | lemur_catta
68 | doughnut
69 | twin_tower
70 | astronaut
71 | nintendo_3ds
72 | fennel_bulb
73 | indri
74 | captain_america_shield
75 | kunai
76 | broom
77 | iphone
78 | earphone1
79 | flying_squirrel
80 | onion
81 | vinyl
82 | sydney_opera_house
83 | oyster
84 | harmonica
85 | egg
86 | breast_pump
87 | guitar
88 | potato_chips
89 | tunnel
90 | cuckoo
91 | rubick_cube
92 | plastic_bag
93 | phonograph
94 | net_surface_shoes
95 | goldfinch
96 | ipad
97 | mite_predator
98 | coffee_mug
99 | golden_plover
100 | f1_racing
101 | lapwing
102 | nintendo_gba
103 | pizza
104 | rally_car
105 | drilling_platform
106 | cd
107 | fly
108 | magpie_bird
109 | leaf_fan
110 | little_blue_heron
111 | carriage
112 | moist_proof_pad
113 | flying_snakes
114 | dart_target
115 | warehouse_tray
116 | nintendo_wiiu
117 | chiffon_cake
118 | bath_ball
119 | manatee
120 | cloud
121 | marimba
122 | eagle
123 | ruler
124 | soymilk_machine
125 | sled
126 | seagull
127 | glider_flyingfish
128 | doublebus
129 | transport_helicopter
130 | window_screen
131 | truss_bridge
132 | wasp
133 | snowman
134 | poached_egg
135 | strawberry
136 | spinach
137 | earphone2
138 | downy_pitch
139 | taj_mahal
140 | rocking_chair
141 | cablestayed_bridge
142 | sealion
143 | banana_boat
144 | pheasant
145 | stone_lion
146 | electronic_stove
147 | fox
148 | iguana
149 | rugby_ball
150 | hang_glider
151 | water_buffalo
152 | lotus
153 | paper_plane
154 | missile
155 | flamingo
156 | american_chamelon
157 | kart
158 | chinese_knot
159 | cabbage_butterfly
160 | key
161 | church
162 | tiltrotor
163 | helicopter
164 | french_fries
165 | water_heater
166 | snow_leopard
167 | goblet
168 | fan
169 | snowplow
170 | leafhopper
171 | pspgo
172 | black_bear
173 | quail
174 | condor
175 | chandelier
176 | hair_razor
177 | white_wolf
178 | toaster
179 | pidan
180 | pyramid
181 | chicken_leg
182 | letter_opener
183 | apple_icon
184 | porcupine
185 | chicken
186 | stingray
187 | warplane
188 | windmill
189 | bamboo_slip
190 | wig
191 | flying_geckos
192 | stonechat
193 | haddock
194 | australian_terrier
195 | hover_board
196 | siamang
197 | canton_tower
198 | santa_sledge
199 | arch_bridge
200 | curlew
201 | sushi
202 | beet_root
203 | accordion
204 | leaf_egg
205 | stealth_aircraft
206 | stork
207 | bucket
208 | hawk
209 | chess_queen
210 | ocarina
211 | knife
212 | whippet
213 | cantilever_bridge
214 | may_bug
215 | wagtail
216 | leather_shoes
217 | wheelchair
218 | shumai
219 | speedboat
220 | vacuum_cup
221 | chess_knight
222 | pumpkin_pie
223 | wooden_spoon
224 | bamboo_dragonfly
225 | ganeva_chair
226 | soap
227 | clearwing_flyingfish
228 | pencil_sharpener1
229 | cricket
230 | photocopier
231 | nintendo_sp
232 | samarra_mosque
233 | clam
234 | charge_battery
235 | flying_frog
236 | ferrari911
237 | polo_shirt
238 | echidna
239 | coin
240 | tower_pisa
241 |
--------------------------------------------------------------------------------
/data/splits/fss/trn.txt:
--------------------------------------------------------------------------------
1 | fountain
2 | taxi
3 | assult_rifle
4 | radio
5 | comb
6 | box_turtle
7 | igloo
8 | head_cabbage
9 | cottontail
10 | coho
11 | ashtray
12 | joystick
13 | sleeping_bag
14 | jackfruit
15 | trailer_truck
16 | shower_cap
17 | ibex
18 | kinguin
19 | squirrel
20 | ac_wall
21 | sidewinder
22 | remote_control
23 | marshmallow
24 | bolotie
25 | polar_bear
26 | rock_beauty
27 | tokyo_tower
28 | wafer
29 | red_bayberry
30 | electronic_toothbrush
31 | hartebeest
32 | cassette
33 | oil_filter
34 | bomb
35 | walnut
36 | toilet_tissue
37 | memory_stick
38 | wild_boar
39 | cableways
40 | chihuahua
41 | envelope
42 | bison
43 | poker
44 | pubg_lvl3helmet
45 | indian_cobra
46 | staffordshire
47 | park_bench
48 | wombat
49 | black_grouse
50 | submarine
51 | washer
52 | agama
53 | coyote
54 | feeder
55 | sarong
56 | buckingham_palace
57 | frog
58 | steam_locomotive
59 | acorn
60 | german_pointer
61 | obelisk
62 | polecat
63 | black_swan
64 | butterfly
65 | mountain_tent
66 | gorilla
67 | sloth_bear
68 | aubergine
69 | stinkhorn
70 | stole
71 | owl
72 | mooli
73 | pool_table
74 | collar
75 | lhasa_apso
76 | ambulance
77 | spade
78 | pufferfish
79 | paint_brush
80 | lark
81 | golf_ball
82 | hock
83 | fork
84 | drake
85 | bee_house
86 | mooncake
87 | wok
88 | cocacola
89 | water_bike
90 | ladder
91 | psp
92 | bassoon
93 | bear
94 | border_terrier
95 | petri_dish
96 | pill_bottle
97 | aircraft_carrier
98 | panther
99 | canoe
100 | baseball_player
101 | turtle
102 | espresso
103 | throne
104 | cornet
105 | coucal
106 | eletrical_switch
107 | bra
108 | snail
109 | backpack
110 | jacamar
111 | scroll_brush
112 | gliding_lizard
113 | raft
114 | pinwheel
115 | grasshopper
116 | green_mamba
117 | eft_newt
118 | computer_mouse
119 | vine_snake
120 | recreational_vehicle
121 | llama
122 | meerkat
123 | chainsaw
124 | ferret
125 | garbage_can
126 | kangaroo
127 | litchi
128 | carbonara
129 | housefinch
130 | modem
131 | tebby_cat
132 | thatch
133 | face_powder
134 | tomb
135 | apple
136 | ladybug
137 | killer_whale
138 | rocket
139 | airship
140 | surfboard
141 | lesser_panda
142 | jordan_logo
143 | banana
144 | nail_scissor
145 | swab
146 | perfume
147 | punching_bag
148 | victor_icon
149 | waffle_iron
150 | trimaran
151 | garlic
152 | flute
153 | langur
154 | starfish
155 | parallel_bars
156 | dandie_dinmont
157 | cosmetic_brush
158 | screwdriver
159 | brick_card
160 | balance_weight
161 | hornet
162 | carton
163 | toothpaste
164 | bracelet
165 | egg_tart
166 | pencil_sharpener2
167 | swimming_glasses
168 | howler_monkey
169 | camel
170 | dragonfly
171 | lionfish
172 | convertible
173 | mule
174 | usb
175 | conch
176 | papaya
177 | garbage_truck
178 | dingo
179 | radiator
180 | solar_dish
181 | streetcar
182 | trilobite
183 | bouzouki
184 | ringlet_butterfly
185 | space_shuttle
186 | waffle
187 | american_staffordshire
188 | violin
189 | flowerpot
190 | forklift
191 | manx
192 | sundial
193 | snowmobile
194 | chickadee_bird
195 | ruffed_grouse
196 | brick_tea
197 | paddle
198 | stove
199 | carousel
200 | spatula
201 | beaker
202 | gas_pump
203 | lawn_mower
204 | speaker
205 | tank
206 | tresher
207 | kappa_logo
208 | hare
209 | tennis_racket
210 | shopping_cart
211 | thimble
212 | tractor
213 | anemone_fish
214 | trolleybus
215 | steak
216 | capuchin
217 | red_breasted_merganser
218 | golden_retriever
219 | light_tube
220 | flatworm
221 | melon_seed
222 | digital_watch
223 | jacko_lantern
224 | brown_bear
225 | cairn
226 | mushroom
227 | chalk
228 | skull
229 | stapler
230 | potato
231 | telescope
232 | proboscis
233 | microphone
234 | torii
235 | baseball_bat
236 | dhole
237 | excavator
238 | fig
239 | snake
240 | bradypod
241 | pepitas
242 | prairie_chicken
243 | scorpion
244 | shotgun
245 | bottle_cap
246 | file_cabinet
247 | grey_whale
248 | one-armed_bandit
249 | banded_gecko
250 | flying_disc
251 | croissant
252 | toothbrush
253 | miniskirt
254 | pokermon_ball
255 | gazelle
256 | grey_fox
257 | esport_chair
258 | necklace
259 | ptarmigan
260 | watermelon
261 | besom
262 | pomelo
263 | radio_telescope
264 | studio_couch
265 | black_stork
266 | vestment
267 | koala
268 | brambling
269 | muscle_car
270 | window_shade
271 | space_heater
272 | sunglasses
273 | motor_scooter
274 | ladyfinger
275 | pencil_box
276 | titi_monkey
277 | chicken_wings
278 | mount_fuji
279 | giant_panda
280 | dart
281 | fire_engine
282 | running_shoe
283 | dumbbell
284 | donkey
285 | loafer
286 | hard_disk
287 | globe
288 | lifeboat
289 | medical_kit
290 | brain_coral
291 | paper_towel
292 | dugong
293 | seatbelt
294 | skunk
295 | military_vest
296 | cocktail_shaker
297 | zucchini
298 | quad_drone
299 | ocicat
300 | shih-tzu
301 | teapot
302 | tile_roof
303 | cheese_burger
304 | handshower
305 | red_wolf
306 | stop_sign
307 | mouse
308 | battery
309 | adidas_logo2
310 | earplug
311 | hummingbird
312 | brush_pen
313 | pistachio
314 | hamster
315 | air_strip
316 | indian_elephant
317 | otter
318 | cucumber
319 | scabbard
320 | hawthorn
321 | bullet_train
322 | leopard
323 | whale
324 | cream
325 | chinese_date
326 | jellyfish
327 | lobster
328 | skua
329 | single_log
330 | chicory
331 | bagel
332 | beacon
333 | pingpong_racket
334 | spoon
335 | yurt
336 | wallaby
337 | egret
338 | christmas_stocking
339 | mcdonald_uncle
340 | wrench
341 | spark_plug
342 | triceratops
343 | wall_clock
344 | jinrikisha
345 | pickup
346 | rhinoceros
347 | swimming_trunk
348 | band-aid
349 | spotted_salamander
350 | leeks
351 | marmot
352 | warthog
353 | cello
354 | stool
355 | chest
356 | toilet_plunger
357 | wardrobe
358 | cannon
359 | adidas_logo1
360 | drumstick
361 | lady_slipper
362 | puma_logo
363 | great_wall
364 | white_shark
365 | witch_hat
366 | vending_machine
367 | wreck
368 | chopsticks
369 | garfish
370 | african_elephant
371 | children_slide
372 | hornbill
373 | zebra
374 | boa_constrictor
375 | armour
376 | pineapple
377 | angora
378 | brick
379 | car_wheel
380 | wallet
381 | boston_bull
382 | hyena
383 | lynx
384 | crash_helmet
385 | terrapin_turtle
386 | persian_cat
387 | shift_gear
388 | cactus_ball
389 | fur_coat
390 | plate
391 | pen
392 | okra
393 | mario
394 | airedale
395 | cowboy_hat
396 | celery
397 | macaque
398 | candle
399 | goose
400 | raccoon
401 | brasscica
402 | almond
403 | maotai_bottle
404 | soccer_ball
405 | sports_car
406 | tobacco_pipe
407 | water_polo
408 | eggnog
409 | hook
410 | ostrich
411 | patas
412 | table_lamp
413 | teddy
414 | mongoose
415 | spoonbill
416 | redheart
417 | crane
418 | dinosaur
419 | kitchen_knife
420 | seal
421 | baboon
422 | golfcart
423 | roller_coaster
424 | avocado
425 | birdhouse
426 | yorkshire_terrier
427 | saluki
428 | basketball
429 | buckler
430 | harvester
431 | afghan_hound
432 | beam_bridge
433 | guinea_pig
434 | lorikeet
435 | shakuhachi
436 | motarboard
437 | statue_liberty
438 | police_car
439 | sulphur_crested
440 | gourd
441 | sombrero
442 | mailbox
443 | adhensive_tape
444 | night_snake
445 | bushtit
446 | mouthpiece
447 | beaver
448 | bathtub
449 | printer
450 | cumquat
451 | orange
452 | cleaver
453 | quill_pen
454 | panpipe
455 | diamond
456 | gypsy_moth
457 | cauliflower
458 | lampshade
459 | cougar
460 | traffic_light
461 | briefcase
462 | ballpoint
463 | african_grey
464 | kremlin
465 | barometer
466 | peacock
467 | paper_crane
468 | sunscreen
469 | tofu
470 | bedlington_terrier
471 | snowball
472 | carrot
473 | tiger
474 | mink
475 | cristo_redentor
476 | ladle
477 | keyboard
478 | maraca
479 | monitor
480 | water_snake
481 | can_opener
482 | mud_turtle
483 | bald_eagle
484 | carp
485 | cn_tower
486 | egyptian_cat
487 | hen_of_the_woods
488 | measuring_cup
489 | roller_skate
490 | kite
491 | sandwich_cookies
492 | sandwich
493 | persimmon
494 | chess_bishop
495 | coffin
496 | ruddy_turnstone
497 | prayer_rug
498 | rain_barrel
499 | neck_brace
500 | nematode
501 | rosehip
502 | dutch_oven
503 | goldfish
504 | blossom_card
505 | dough
506 | trench_coat
507 | sponge
508 | stupa
509 | wash_basin
510 | electric_fan
511 | spring_scroll
512 | potted_plant
513 | sparrow
514 | car_mirror
515 | gecko
516 | diaper
517 | leatherback_turtle
518 | strainer
519 | guacamole
520 | microwave
521 |
--------------------------------------------------------------------------------
/data/splits/fss/val.txt:
--------------------------------------------------------------------------------
1 | handcuff
2 | mortar
3 | matchstick
4 | wine_bottle
5 | dowitcher
6 | triumphal_arch
7 | gyromitra
8 | hatchet
9 | airliner
10 | broccoli
11 | olive
12 | pubg_lvl3backpack
13 | calculator
14 | toucan
15 | shovel
16 | sewing_machine
17 | icecream
18 | woodpecker
19 | pig
20 | relay_stick
21 | mcdonald_sign
22 | cpu
23 | peanut
24 | pumpkin
25 | sturgeon
26 | hammer
27 | hami_melon
28 | squirrel_monkey
29 | shuriken
30 | power_drill
31 | pingpong_ball
32 | crocodile
33 | carambola
34 | monarch_butterfly
35 | drum
36 | water_tower
37 | panda
38 | toilet_brush
39 | pay_phone
40 | yonex_icon
41 | cricketball
42 | revolver
43 | chimpanzee
44 | crab
45 | corn
46 | baseball
47 | rabbit
48 | croquet_ball
49 | artichoke
50 | abacus
51 | harp
52 | bell
53 | gas_tank
54 | scissors
55 | vase
56 | upright_piano
57 | typewriter
58 | bittern
59 | impala
60 | tray
61 | fire_hydrant
62 | beer_bottle
63 | sock
64 | soup_bowl
65 | spider
66 | cherry
67 | macaw
68 | toilet_seat
69 | fire_balloon
70 | french_ball
71 | fox_squirrel
72 | volleyball
73 | cornmeal
74 | folding_chair
75 | pubg_airdrop
76 | beagle
77 | skateboard
78 | narcissus
79 | whiptail
80 | cup
81 | arabian_camel
82 | badger
83 | stopwatch
84 | ab_wheel
85 | ox
86 | lettuce
87 | monocycle
88 | redshank
89 | vulture
90 | whistle
91 | smoothing_iron
92 | mashed_potato
93 | conveyor
94 | yoga_pad
95 | tow_truck
96 | siamese_cat
97 | cigar
98 | white_stork
99 | sniper_rifle
100 | stretcher
101 | tulip
102 | handkerchief
103 | basset
104 | iceberg
105 | gibbon
106 | lacewing
107 | thrush
108 | cheetah
109 | bighorn_sheep
110 | espresso_maker
111 | pretzel
112 | english_setter
113 | sandbar
114 | cheese
115 | daisy
116 | arctic_fox
117 | briard
118 | colubus
119 | balance_beam
120 | coffeepot
121 | soap_dispenser
122 | yawl
123 | consomme
124 | parking_meter
125 | cactus
126 | turnstile
127 | taro
128 | fire_screen
129 | digital_clock
130 | rose
131 | pomegranate
132 | bee_eater
133 | schooner
134 | ski_mask
135 | jay_bird
136 | plaice
137 | red_fox
138 | syringe
139 | camomile
140 | pickelhaube
141 | blenheim_spaniel
142 | pear
143 | parachute
144 | common_newt
145 | bowtie
146 | cigarette
147 | oscilloscope
148 | laptop
149 | african_crocodile
150 | apron
151 | coconut
152 | sandal
153 | kwanyin
154 | lion
155 | eel
156 | balloon
157 | crepe
158 | armadillo
159 | kazoo
160 | lemon
161 | spider_monkey
162 | tape_player
163 | ipod
164 | bee
165 | sea_cucumber
166 | suitcase
167 | television
168 | pillow
169 | banjo
170 | rock_snake
171 | partridge
172 | platypus
173 | lycaenid_butterfly
174 | pinecone
175 | conversion_plug
176 | wolf
177 | frying_pan
178 | timber_wolf
179 | bluetick
180 | crayon
181 | giant_schnauzer
182 | orang
183 | scarerow
184 | kobe_logo
185 | loguat
186 | saxophone
187 | ceiling_fan
188 | cardoon
189 | equestrian_helmet
190 | louvre_pyramid
191 | hotdog
192 | ironing_board
193 | razor
194 | nagoya_castle
195 | loggerhead_turtle
196 | lipstick
197 | cradle
198 | strongbox
199 | raven
200 | kit_fox
201 | albatross
202 | flat-coated_retriever
203 | beer_glass
204 | ice_lolly
205 | sungnyemun
206 | totem_pole
207 | vacuum
208 | bolete
209 | mango
210 | ginger
211 | weasel
212 | cabbage
213 | refrigerator
214 | school_bus
215 | hippo
216 | tiger_cat
217 | saltshaker
218 | piano_keyboard
219 | windsor_tie
220 | sea_urchin
221 | microsd
222 | barbell
223 | swim_ring
224 | bulbul_bird
225 | water_ouzel
226 | ac_ground
227 | sweatshirt
228 | umbrella
229 | hair_drier
230 | hammerhead_shark
231 | tomato
232 | projector
233 | cushion
234 | dishwasher
235 | three-toed_sloth
236 | tiger_shark
237 | har_gow
238 | baby
239 | thor's_hammer
240 | nike_logo
241 |
--------------------------------------------------------------------------------
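The three FSS-1000 split lists above hold one class-folder name per line and end with a blank line: 520 training, 240 validation, and 240 test classes. A minimal reading sketch (the relative path is assumed to be resolved from the repository root):

```python
# Minimal sketch: read an FSS-1000 split list (one class name per line).
def read_fss_split(path):
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]  # drop the trailing blank line

test_classes = read_fss_split('data/splits/fss/test.txt')
print(len(test_classes), test_classes[:3])  # 240 ['bus', 'hotel_slipper', 'burj_al']
```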
/data/splits/pascal/val/fold0.txt:
--------------------------------------------------------------------------------
1 | 2007_000033__01
2 | 2007_000061__04
3 | 2007_000129__02
4 | 2007_000346__05
5 | 2007_000529__04
6 | 2007_000559__05
7 | 2007_000572__02
8 | 2007_000762__05
9 | 2007_001288__01
10 | 2007_001289__03
11 | 2007_001311__02
12 | 2007_001408__05
13 | 2007_001568__01
14 | 2007_001630__02
15 | 2007_001761__01
16 | 2007_001884__01
17 | 2007_002094__03
18 | 2007_002266__01
19 | 2007_002376__01
20 | 2007_002400__03
21 | 2007_002619__01
22 | 2007_002719__04
23 | 2007_003088__05
24 | 2007_003131__04
25 | 2007_003188__02
26 | 2007_003349__03
27 | 2007_003571__04
28 | 2007_003621__02
29 | 2007_003682__03
30 | 2007_003861__04
31 | 2007_004052__01
32 | 2007_004143__03
33 | 2007_004241__04
34 | 2007_004468__05
35 | 2007_005074__04
36 | 2007_005107__02
37 | 2007_005294__05
38 | 2007_005304__05
39 | 2007_005428__05
40 | 2007_005509__01
41 | 2007_005600__01
42 | 2007_005705__04
43 | 2007_005828__01
44 | 2007_006076__03
45 | 2007_006086__05
46 | 2007_006449__02
47 | 2007_006946__01
48 | 2007_007084__03
49 | 2007_007235__02
50 | 2007_007341__01
51 | 2007_007470__01
52 | 2007_007477__04
53 | 2007_007836__02
54 | 2007_008051__03
55 | 2007_008084__03
56 | 2007_008204__05
57 | 2007_008670__03
58 | 2007_009088__03
59 | 2007_009258__02
60 | 2007_009323__03
61 | 2007_009458__05
62 | 2007_009687__05
63 | 2007_009817__03
64 | 2007_009911__01
65 | 2008_000120__04
66 | 2008_000123__03
67 | 2008_000533__03
68 | 2008_000725__02
69 | 2008_000911__05
70 | 2008_001013__04
71 | 2008_001040__04
72 | 2008_001135__04
73 | 2008_001260__04
74 | 2008_001404__02
75 | 2008_001514__03
76 | 2008_001531__02
77 | 2008_001546__01
78 | 2008_001580__04
79 | 2008_001966__03
80 | 2008_001971__01
81 | 2008_002043__03
82 | 2008_002269__02
83 | 2008_002358__01
84 | 2008_002429__03
85 | 2008_002467__05
86 | 2008_002504__04
87 | 2008_002775__05
88 | 2008_002864__05
89 | 2008_003034__04
90 | 2008_003076__05
91 | 2008_003108__02
92 | 2008_003110__03
93 | 2008_003155__01
94 | 2008_003270__02
95 | 2008_003369__01
96 | 2008_003858__04
97 | 2008_003876__01
98 | 2008_003886__04
99 | 2008_003926__01
100 | 2008_003976__01
101 | 2008_004363__02
102 | 2008_004654__02
103 | 2008_004659__05
104 | 2008_004704__01
105 | 2008_004758__02
106 | 2008_004995__02
107 | 2008_005262__05
108 | 2008_005338__01
109 | 2008_005628__04
110 | 2008_005727__02
111 | 2008_005812__05
112 | 2008_005904__05
113 | 2008_006216__01
114 | 2008_006229__04
115 | 2008_006254__02
116 | 2008_006703__01
117 | 2008_007120__03
118 | 2008_007143__04
119 | 2008_007219__05
120 | 2008_007350__01
121 | 2008_007498__03
122 | 2008_007811__05
123 | 2008_007994__03
124 | 2008_008268__03
125 | 2008_008629__02
126 | 2008_008711__02
127 | 2008_008746__03
128 | 2009_000032__01
129 | 2009_000037__03
130 | 2009_000121__05
131 | 2009_000149__02
132 | 2009_000201__05
133 | 2009_000205__01
134 | 2009_000318__03
135 | 2009_000354__02
136 | 2009_000387__01
137 | 2009_000421__04
138 | 2009_000440__01
139 | 2009_000446__04
140 | 2009_000457__02
141 | 2009_000469__04
142 | 2009_000573__02
143 | 2009_000619__03
144 | 2009_000664__03
145 | 2009_000723__04
146 | 2009_000828__04
147 | 2009_000840__05
148 | 2009_000879__03
149 | 2009_000991__03
150 | 2009_000998__03
151 | 2009_001108__03
152 | 2009_001160__03
153 | 2009_001255__02
154 | 2009_001278__05
155 | 2009_001314__03
156 | 2009_001332__01
157 | 2009_001565__03
158 | 2009_001607__03
159 | 2009_001683__03
160 | 2009_001718__02
161 | 2009_001765__03
162 | 2009_001818__05
163 | 2009_001850__01
164 | 2009_001851__01
165 | 2009_001941__04
166 | 2009_002185__05
167 | 2009_002295__02
168 | 2009_002320__01
169 | 2009_002372__05
170 | 2009_002521__05
171 | 2009_002594__05
172 | 2009_002604__03
173 | 2009_002649__05
174 | 2009_002727__04
175 | 2009_002732__05
176 | 2009_002749__05
177 | 2009_002808__01
178 | 2009_002856__05
179 | 2009_002888__01
180 | 2009_002928__02
181 | 2009_003003__05
182 | 2009_003005__01
183 | 2009_003043__04
184 | 2009_003080__04
185 | 2009_003193__02
186 | 2009_003224__02
187 | 2009_003269__05
188 | 2009_003273__03
189 | 2009_003343__02
190 | 2009_003378__03
191 | 2009_003450__03
192 | 2009_003498__03
193 | 2009_003504__04
194 | 2009_003517__05
195 | 2009_003640__03
196 | 2009_003696__01
197 | 2009_003707__04
198 | 2009_003806__01
199 | 2009_003858__03
200 | 2009_003971__02
201 | 2009_004021__03
202 | 2009_004084__03
203 | 2009_004125__04
204 | 2009_004247__05
205 | 2009_004324__05
206 | 2009_004509__03
207 | 2009_004540__03
208 | 2009_004568__03
209 | 2009_004579__05
210 | 2009_004635__04
211 | 2009_004653__01
212 | 2009_004848__02
213 | 2009_004882__02
214 | 2009_004886__03
215 | 2009_004895__03
216 | 2009_004969__01
217 | 2009_005038__05
218 | 2009_005137__03
219 | 2009_005156__02
220 | 2009_005189__01
221 | 2009_005190__05
222 | 2009_005260__03
223 | 2009_005262__03
224 | 2009_005302__05
225 | 2010_000065__02
226 | 2010_000083__02
227 | 2010_000084__04
228 | 2010_000238__01
229 | 2010_000241__03
230 | 2010_000272__04
231 | 2010_000342__02
232 | 2010_000426__05
233 | 2010_000572__01
234 | 2010_000622__01
235 | 2010_000814__03
236 | 2010_000906__04
237 | 2010_000961__03
238 | 2010_001016__03
239 | 2010_001017__01
240 | 2010_001024__01
241 | 2010_001036__04
242 | 2010_001061__03
243 | 2010_001069__03
244 | 2010_001174__01
245 | 2010_001367__02
246 | 2010_001367__05
247 | 2010_001448__01
248 | 2010_001830__05
249 | 2010_001995__03
250 | 2010_002017__05
251 | 2010_002030__02
252 | 2010_002142__03
253 | 2010_002147__01
254 | 2010_002150__04
255 | 2010_002200__01
256 | 2010_002310__01
257 | 2010_002536__02
258 | 2010_002546__04
259 | 2010_002693__02
260 | 2010_002939__01
261 | 2010_003127__01
262 | 2010_003132__01
263 | 2010_003168__03
264 | 2010_003362__03
265 | 2010_003365__01
266 | 2010_003418__03
267 | 2010_003468__05
268 | 2010_003473__03
269 | 2010_003495__01
270 | 2010_003547__04
271 | 2010_003716__01
272 | 2010_003771__03
273 | 2010_003781__05
274 | 2010_003820__03
275 | 2010_003912__02
276 | 2010_003915__01
277 | 2010_004041__04
278 | 2010_004056__05
279 | 2010_004208__04
280 | 2010_004314__01
281 | 2010_004419__01
282 | 2010_004520__05
283 | 2010_004529__05
284 | 2010_004551__05
285 | 2010_004556__03
286 | 2010_004559__03
287 | 2010_004662__04
288 | 2010_004772__04
289 | 2010_004828__05
290 | 2010_004994__03
291 | 2010_005252__04
292 | 2010_005401__04
293 | 2010_005428__03
294 | 2010_005496__05
295 | 2010_005531__03
296 | 2010_005534__01
297 | 2010_005582__05
298 | 2010_005664__02
299 | 2010_005705__04
300 | 2010_005718__01
301 | 2010_005762__05
302 | 2010_005877__01
303 | 2010_005888__01
304 | 2010_006034__01
305 | 2010_006070__02
306 | 2011_000066__05
307 | 2011_000112__03
308 | 2011_000185__03
309 | 2011_000234__04
310 | 2011_000238__04
311 | 2011_000412__02
312 | 2011_000435__04
313 | 2011_000456__03
314 | 2011_000482__03
315 | 2011_000585__02
316 | 2011_000669__03
317 | 2011_000747__05
318 | 2011_000874__01
319 | 2011_001114__01
320 | 2011_001161__04
321 | 2011_001263__01
322 | 2011_001287__03
323 | 2011_001407__01
324 | 2011_001421__03
325 | 2011_001434__01
326 | 2011_001589__04
327 | 2011_001624__01
328 | 2011_001793__04
329 | 2011_001880__01
330 | 2011_001988__02
331 | 2011_002064__02
332 | 2011_002098__05
333 | 2011_002223__02
334 | 2011_002295__03
335 | 2011_002327__01
336 | 2011_002515__01
337 | 2011_002675__01
338 | 2011_002713__02
339 | 2011_002754__04
340 | 2011_002863__05
341 | 2011_002929__01
342 | 2011_002975__04
343 | 2011_003003__02
344 | 2011_003030__03
345 | 2011_003145__03
346 | 2011_003271__05
347 |
--------------------------------------------------------------------------------
/data/splits/pascal/val/fold1.txt:
--------------------------------------------------------------------------------
1 | 2007_000452__09
2 | 2007_000464__10
3 | 2007_000491__10
4 | 2007_000663__06
5 | 2007_000663__07
6 | 2007_000727__06
7 | 2007_000727__07
8 | 2007_000804__09
9 | 2007_000830__09
10 | 2007_001299__10
11 | 2007_001321__07
12 | 2007_001457__09
13 | 2007_001677__09
14 | 2007_001717__09
15 | 2007_001763__08
16 | 2007_001774__08
17 | 2007_001884__06
18 | 2007_002268__08
19 | 2007_002387__10
20 | 2007_002445__08
21 | 2007_002470__08
22 | 2007_002539__06
23 | 2007_002597__08
24 | 2007_002643__07
25 | 2007_002903__10
26 | 2007_003011__09
27 | 2007_003051__07
28 | 2007_003101__06
29 | 2007_003106__08
30 | 2007_003137__06
31 | 2007_003143__07
32 | 2007_003169__08
33 | 2007_003195__06
34 | 2007_003201__10
35 | 2007_003503__06
36 | 2007_003503__07
37 | 2007_003621__06
38 | 2007_003711__06
39 | 2007_003786__06
40 | 2007_003841__10
41 | 2007_003917__07
42 | 2007_003991__08
43 | 2007_004193__09
44 | 2007_004392__09
45 | 2007_004405__09
46 | 2007_004510__09
47 | 2007_004712__09
48 | 2007_004856__08
49 | 2007_004866__08
50 | 2007_005074__07
51 | 2007_005114__10
52 | 2007_005296__07
53 | 2007_005331__07
54 | 2007_005460__08
55 | 2007_005547__07
56 | 2007_005547__10
57 | 2007_005844__09
58 | 2007_005845__08
59 | 2007_005911__06
60 | 2007_005978__06
61 | 2007_006035__07
62 | 2007_006086__09
63 | 2007_006241__09
64 | 2007_006260__08
65 | 2007_006277__07
66 | 2007_006348__09
67 | 2007_006553__09
68 | 2007_006761__10
69 | 2007_006841__10
70 | 2007_007414__07
71 | 2007_007417__08
72 | 2007_007524__08
73 | 2007_007815__07
74 | 2007_007818__07
75 | 2007_007996__09
76 | 2007_008106__09
77 | 2007_008110__09
78 | 2007_008543__09
79 | 2007_008722__10
80 | 2007_008747__06
81 | 2007_008815__08
82 | 2007_008897__09
83 | 2007_008973__10
84 | 2007_009015__06
85 | 2007_009015__07
86 | 2007_009068__09
87 | 2007_009084__09
88 | 2007_009096__07
89 | 2007_009221__08
90 | 2007_009245__10
91 | 2007_009346__08
92 | 2007_009392__06
93 | 2007_009392__07
94 | 2007_009413__09
95 | 2007_009521__09
96 | 2007_009764__06
97 | 2007_009794__08
98 | 2007_009897__10
99 | 2007_009923__08
100 | 2007_009938__07
101 | 2008_000009__10
102 | 2008_000073__10
103 | 2008_000075__06
104 | 2008_000107__09
105 | 2008_000149__09
106 | 2008_000182__08
107 | 2008_000345__08
108 | 2008_000401__08
109 | 2008_000464__08
110 | 2008_000501__07
111 | 2008_000673__09
112 | 2008_000853__08
113 | 2008_000919__10
114 | 2008_001078__08
115 | 2008_001433__08
116 | 2008_001439__09
117 | 2008_001513__08
118 | 2008_001640__08
119 | 2008_001715__09
120 | 2008_001885__08
121 | 2008_002152__08
122 | 2008_002205__06
123 | 2008_002212__07
124 | 2008_002379__09
125 | 2008_002521__09
126 | 2008_002623__08
127 | 2008_002681__08
128 | 2008_002778__10
129 | 2008_002958__07
130 | 2008_003141__06
131 | 2008_003141__07
132 | 2008_003333__07
133 | 2008_003477__09
134 | 2008_003499__08
135 | 2008_003577__07
136 | 2008_003777__06
137 | 2008_003821__09
138 | 2008_003846__07
139 | 2008_004069__07
140 | 2008_004339__07
141 | 2008_004552__07
142 | 2008_004612__09
143 | 2008_004701__10
144 | 2008_005097__10
145 | 2008_005105__10
146 | 2008_005245__07
147 | 2008_005676__06
148 | 2008_006008__09
149 | 2008_006063__10
150 | 2008_006254__07
151 | 2008_006325__08
152 | 2008_006341__08
153 | 2008_006480__08
154 | 2008_006528__10
155 | 2008_006554__06
156 | 2008_006986__07
157 | 2008_007025__10
158 | 2008_007031__10
159 | 2008_007048__09
160 | 2008_007123__10
161 | 2008_007194__09
162 | 2008_007273__10
163 | 2008_007378__09
164 | 2008_007402__09
165 | 2008_007527__09
166 | 2008_007548__08
167 | 2008_007596__10
168 | 2008_007737__09
169 | 2008_007797__06
170 | 2008_007804__07
171 | 2008_007828__09
172 | 2008_008252__06
173 | 2008_008301__06
174 | 2008_008469__06
175 | 2008_008682__06
176 | 2009_000013__08
177 | 2009_000080__08
178 | 2009_000219__10
179 | 2009_000309__10
180 | 2009_000335__06
181 | 2009_000335__07
182 | 2009_000426__06
183 | 2009_000455__06
184 | 2009_000457__07
185 | 2009_000523__07
186 | 2009_000641__10
187 | 2009_000716__08
188 | 2009_000731__10
189 | 2009_000771__10
190 | 2009_000825__07
191 | 2009_000964__08
192 | 2009_001008__08
193 | 2009_001082__06
194 | 2009_001240__07
195 | 2009_001255__07
196 | 2009_001299__09
197 | 2009_001391__08
198 | 2009_001411__08
199 | 2009_001536__07
200 | 2009_001775__09
201 | 2009_001804__06
202 | 2009_001816__06
203 | 2009_001854__06
204 | 2009_002035__10
205 | 2009_002122__10
206 | 2009_002150__10
207 | 2009_002164__07
208 | 2009_002171__10
209 | 2009_002221__10
210 | 2009_002238__06
211 | 2009_002238__07
212 | 2009_002239__07
213 | 2009_002268__08
214 | 2009_002346__09
215 | 2009_002415__09
216 | 2009_002487__09
217 | 2009_002527__08
218 | 2009_002535__06
219 | 2009_002549__10
220 | 2009_002571__09
221 | 2009_002618__07
222 | 2009_002635__10
223 | 2009_002753__08
224 | 2009_002936__08
225 | 2009_002990__07
226 | 2009_003003__07
227 | 2009_003059__10
228 | 2009_003071__09
229 | 2009_003269__07
230 | 2009_003304__06
231 | 2009_003387__07
232 | 2009_003406__07
233 | 2009_003494__09
234 | 2009_003507__09
235 | 2009_003542__10
236 | 2009_003549__07
237 | 2009_003569__10
238 | 2009_003589__07
239 | 2009_003703__06
240 | 2009_003771__08
241 | 2009_003773__10
242 | 2009_003849__09
243 | 2009_003895__09
244 | 2009_003904__08
245 | 2009_004072__06
246 | 2009_004140__09
247 | 2009_004217__09
248 | 2009_004248__08
249 | 2009_004455__07
250 | 2009_004504__08
251 | 2009_004590__06
252 | 2009_004594__07
253 | 2009_004687__09
254 | 2009_004721__08
255 | 2009_004732__06
256 | 2009_004748__07
257 | 2009_004789__06
258 | 2009_004859__09
259 | 2009_004867__06
260 | 2009_005158__08
261 | 2009_005219__08
262 | 2009_005231__06
263 | 2010_000003__09
264 | 2010_000160__07
265 | 2010_000163__08
266 | 2010_000372__07
267 | 2010_000427__10
268 | 2010_000530__07
269 | 2010_000552__08
270 | 2010_000573__06
271 | 2010_000628__07
272 | 2010_000639__09
273 | 2010_000682__06
274 | 2010_000683__08
275 | 2010_000724__08
276 | 2010_000907__10
277 | 2010_000941__08
278 | 2010_000952__07
279 | 2010_001000__10
280 | 2010_001010__10
281 | 2010_001070__08
282 | 2010_001206__06
283 | 2010_001292__08
284 | 2010_001331__08
285 | 2010_001351__08
286 | 2010_001403__06
287 | 2010_001403__07
288 | 2010_001534__08
289 | 2010_001553__07
290 | 2010_001579__09
291 | 2010_001646__06
292 | 2010_001656__08
293 | 2010_001692__10
294 | 2010_001699__09
295 | 2010_001767__07
296 | 2010_001851__09
297 | 2010_001913__08
298 | 2010_002017__07
299 | 2010_002017__09
300 | 2010_002025__08
301 | 2010_002137__08
302 | 2010_002146__08
303 | 2010_002305__08
304 | 2010_002336__09
305 | 2010_002348__08
306 | 2010_002361__07
307 | 2010_002390__10
308 | 2010_002422__08
309 | 2010_002512__08
310 | 2010_002531__08
311 | 2010_002546__06
312 | 2010_002623__09
313 | 2010_002693__08
314 | 2010_002693__09
315 | 2010_002763__08
316 | 2010_002763__10
317 | 2010_002868__06
318 | 2010_002900__08
319 | 2010_002902__07
320 | 2010_002921__09
321 | 2010_002929__07
322 | 2010_002988__07
323 | 2010_003123__07
324 | 2010_003183__10
325 | 2010_003231__07
326 | 2010_003239__10
327 | 2010_003275__08
328 | 2010_003276__07
329 | 2010_003293__06
330 | 2010_003302__09
331 | 2010_003325__09
332 | 2010_003381__07
333 | 2010_003402__08
334 | 2010_003409__09
335 | 2010_003446__07
336 | 2010_003453__07
337 | 2010_003468__08
338 | 2010_003531__09
339 | 2010_003675__08
340 | 2010_003746__07
341 | 2010_003758__08
342 | 2010_003764__08
343 | 2010_003768__07
344 | 2010_003772__06
345 | 2010_003781__08
346 | 2010_003813__07
347 | 2010_003854__07
348 | 2010_003971__08
349 | 2010_003971__09
350 | 2010_004104__08
351 | 2010_004120__08
352 | 2010_004320__08
353 | 2010_004322__10
354 | 2010_004348__06
355 | 2010_004369__08
356 | 2010_004472__07
357 | 2010_004479__08
358 | 2010_004635__10
359 | 2010_004763__09
360 | 2010_004783__09
361 | 2010_004789__10
362 | 2010_004815__08
363 | 2010_004825__09
364 | 2010_004861__08
365 | 2010_004946__07
366 | 2010_005013__07
367 | 2010_005021__08
368 | 2010_005021__09
369 | 2010_005063__06
370 | 2010_005108__08
371 | 2010_005118__06
372 | 2010_005160__06
373 | 2010_005166__10
374 | 2010_005284__06
375 | 2010_005344__08
376 | 2010_005421__08
377 | 2010_005432__07
378 | 2010_005501__07
379 | 2010_005508__08
380 | 2010_005606__08
381 | 2010_005709__08
382 | 2010_005718__07
383 | 2010_005860__07
384 | 2010_005899__08
385 | 2010_006070__07
386 | 2011_000178__06
387 | 2011_000226__09
388 | 2011_000239__06
389 | 2011_000248__06
390 | 2011_000312__06
391 | 2011_000338__09
392 | 2011_000419__08
393 | 2011_000503__07
394 | 2011_000548__10
395 | 2011_000566__10
396 | 2011_000607__09
397 | 2011_000661__08
398 | 2011_000661__09
399 | 2011_000780__08
400 | 2011_000789__08
401 | 2011_000809__09
402 | 2011_000813__08
403 | 2011_000813__09
404 | 2011_000830__06
405 | 2011_000843__09
406 | 2011_000888__06
407 | 2011_000900__07
408 | 2011_000969__06
409 | 2011_001047__10
410 | 2011_001064__06
411 | 2011_001071__09
412 | 2011_001110__07
413 | 2011_001159__10
414 | 2011_001232__10
415 | 2011_001292__08
416 | 2011_001341__06
417 | 2011_001346__09
418 | 2011_001447__09
419 | 2011_001530__10
420 | 2011_001534__08
421 | 2011_001546__10
422 | 2011_001567__09
423 | 2011_001597__08
424 | 2011_001601__08
425 | 2011_001607__08
426 | 2011_001665__09
427 | 2011_001708__10
428 | 2011_001775__08
429 | 2011_001782__10
430 | 2011_001812__09
431 | 2011_002041__09
432 | 2011_002064__07
433 | 2011_002124__09
434 | 2011_002200__09
435 | 2011_002298__09
436 | 2011_002322__07
437 | 2011_002343__09
438 | 2011_002358__09
439 | 2011_002391__09
440 | 2011_002509__09
441 | 2011_002592__07
442 | 2011_002644__09
443 | 2011_002685__08
444 | 2011_002812__07
445 | 2011_002885__10
446 | 2011_003011__09
447 | 2011_003019__07
448 | 2011_003019__10
449 | 2011_003055__07
450 | 2011_003103__09
451 | 2011_003114__06
452 |
--------------------------------------------------------------------------------
/data/splits/pascal/val/fold2.txt:
--------------------------------------------------------------------------------
1 | 2007_000129__15
2 | 2007_000323__15
3 | 2007_000332__13
4 | 2007_000346__15
5 | 2007_000762__11
6 | 2007_000762__15
7 | 2007_000783__13
8 | 2007_000783__15
9 | 2007_000799__13
10 | 2007_000799__15
11 | 2007_000830__11
12 | 2007_000847__11
13 | 2007_000847__15
14 | 2007_000999__15
15 | 2007_001175__15
16 | 2007_001239__12
17 | 2007_001284__15
18 | 2007_001311__15
19 | 2007_001408__15
20 | 2007_001423__15
21 | 2007_001430__11
22 | 2007_001430__15
23 | 2007_001526__15
24 | 2007_001585__15
25 | 2007_001586__13
26 | 2007_001586__15
27 | 2007_001594__15
28 | 2007_001630__15
29 | 2007_001677__11
30 | 2007_001678__15
31 | 2007_001717__15
32 | 2007_001763__12
33 | 2007_001955__13
34 | 2007_002046__13
35 | 2007_002119__15
36 | 2007_002260__14
37 | 2007_002268__12
38 | 2007_002378__15
39 | 2007_002426__15
40 | 2007_002539__15
41 | 2007_002565__15
42 | 2007_002597__12
43 | 2007_002624__11
44 | 2007_002624__15
45 | 2007_002643__15
46 | 2007_002728__15
47 | 2007_002823__14
48 | 2007_002823__15
49 | 2007_002824__15
50 | 2007_002852__12
51 | 2007_003011__11
52 | 2007_003020__15
53 | 2007_003022__13
54 | 2007_003022__15
55 | 2007_003088__15
56 | 2007_003106__15
57 | 2007_003110__12
58 | 2007_003134__15
59 | 2007_003188__15
60 | 2007_003194__12
61 | 2007_003367__14
62 | 2007_003367__15
63 | 2007_003373__12
64 | 2007_003373__15
65 | 2007_003530__15
66 | 2007_003621__15
67 | 2007_003742__11
68 | 2007_003742__15
69 | 2007_003872__12
70 | 2007_004033__14
71 | 2007_004033__15
72 | 2007_004112__12
73 | 2007_004112__15
74 | 2007_004121__15
75 | 2007_004189__12
76 | 2007_004275__14
77 | 2007_004275__15
78 | 2007_004281__15
79 | 2007_004380__14
80 | 2007_004380__15
81 | 2007_004392__15
82 | 2007_004405__11
83 | 2007_004538__13
84 | 2007_004538__15
85 | 2007_004644__12
86 | 2007_004712__11
87 | 2007_004712__15
88 | 2007_004722__13
89 | 2007_004722__15
90 | 2007_004902__13
91 | 2007_004902__15
92 | 2007_005114__13
93 | 2007_005114__15
94 | 2007_005149__12
95 | 2007_005173__14
96 | 2007_005173__15
97 | 2007_005281__15
98 | 2007_005304__15
99 | 2007_005331__13
100 | 2007_005331__15
101 | 2007_005354__14
102 | 2007_005354__15
103 | 2007_005509__15
104 | 2007_005547__15
105 | 2007_005608__14
106 | 2007_005608__15
107 | 2007_005696__12
108 | 2007_005759__14
109 | 2007_005803__11
110 | 2007_005844__11
111 | 2007_005845__15
112 | 2007_006028__15
113 | 2007_006076__15
114 | 2007_006086__11
115 | 2007_006117__15
116 | 2007_006171__12
117 | 2007_006171__15
118 | 2007_006241__11
119 | 2007_006364__13
120 | 2007_006364__15
121 | 2007_006373__15
122 | 2007_006444__12
123 | 2007_006444__15
124 | 2007_006560__15
125 | 2007_006647__14
126 | 2007_006647__15
127 | 2007_006698__15
128 | 2007_006802__15
129 | 2007_006841__15
130 | 2007_006864__15
131 | 2007_006866__13
132 | 2007_006866__15
133 | 2007_007007__11
134 | 2007_007007__15
135 | 2007_007109__13
136 | 2007_007109__15
137 | 2007_007195__15
138 | 2007_007203__15
139 | 2007_007211__14
140 | 2007_007235__15
141 | 2007_007417__12
142 | 2007_007493__15
143 | 2007_007498__11
144 | 2007_007498__15
145 | 2007_007651__11
146 | 2007_007651__15
147 | 2007_007688__14
148 | 2007_007748__13
149 | 2007_007748__15
150 | 2007_007795__15
151 | 2007_007810__11
152 | 2007_007810__15
153 | 2007_007815__15
154 | 2007_007836__15
155 | 2007_007849__15
156 | 2007_007996__15
157 | 2007_008110__15
158 | 2007_008204__15
159 | 2007_008222__12
160 | 2007_008256__13
161 | 2007_008256__15
162 | 2007_008260__12
163 | 2007_008374__15
164 | 2007_008415__12
165 | 2007_008430__15
166 | 2007_008596__13
167 | 2007_008596__15
168 | 2007_008708__15
169 | 2007_008802__13
170 | 2007_008897__15
171 | 2007_008944__15
172 | 2007_008964__12
173 | 2007_008964__15
174 | 2007_008980__12
175 | 2007_009068__15
176 | 2007_009084__12
177 | 2007_009084__14
178 | 2007_009251__13
179 | 2007_009251__15
180 | 2007_009258__15
181 | 2007_009320__15
182 | 2007_009331__12
183 | 2007_009331__13
184 | 2007_009331__15
185 | 2007_009413__11
186 | 2007_009413__15
187 | 2007_009521__11
188 | 2007_009562__12
189 | 2007_009592__12
190 | 2007_009654__15
191 | 2007_009655__15
192 | 2007_009684__15
193 | 2007_009687__15
194 | 2007_009691__14
195 | 2007_009691__15
196 | 2007_009706__11
197 | 2007_009750__15
198 | 2007_009756__14
199 | 2007_009756__15
200 | 2007_009841__13
201 | 2007_009938__14
202 | 2008_000080__12
203 | 2008_000213__15
204 | 2008_000215__15
205 | 2008_000223__15
206 | 2008_000233__15
207 | 2008_000234__15
208 | 2008_000239__12
209 | 2008_000270__12
210 | 2008_000270__15
211 | 2008_000271__15
212 | 2008_000359__15
213 | 2008_000474__15
214 | 2008_000510__15
215 | 2008_000573__11
216 | 2008_000573__15
217 | 2008_000602__13
218 | 2008_000630__15
219 | 2008_000661__12
220 | 2008_000661__15
221 | 2008_000662__15
222 | 2008_000666__15
223 | 2008_000673__15
224 | 2008_000700__15
225 | 2008_000725__15
226 | 2008_000731__15
227 | 2008_000763__11
228 | 2008_000763__15
229 | 2008_000765__13
230 | 2008_000782__14
231 | 2008_000795__15
232 | 2008_000811__14
233 | 2008_000811__15
234 | 2008_000863__12
235 | 2008_000943__12
236 | 2008_000992__15
237 | 2008_001013__15
238 | 2008_001028__15
239 | 2008_001070__12
240 | 2008_001074__15
241 | 2008_001076__15
242 | 2008_001150__14
243 | 2008_001170__15
244 | 2008_001231__15
245 | 2008_001249__15
246 | 2008_001283__15
247 | 2008_001308__15
248 | 2008_001379__12
249 | 2008_001404__15
250 | 2008_001478__12
251 | 2008_001491__15
252 | 2008_001504__15
253 | 2008_001531__15
254 | 2008_001547__15
255 | 2008_001629__15
256 | 2008_001682__13
257 | 2008_001821__15
258 | 2008_001874__15
259 | 2008_001895__12
260 | 2008_001895__15
261 | 2008_001992__13
262 | 2008_001992__15
263 | 2008_002212__15
264 | 2008_002239__12
265 | 2008_002240__14
266 | 2008_002241__15
267 | 2008_002379__11
268 | 2008_002383__14
269 | 2008_002495__15
270 | 2008_002536__12
271 | 2008_002588__15
272 | 2008_002775__11
273 | 2008_002775__15
274 | 2008_002835__13
275 | 2008_002835__15
276 | 2008_002859__12
277 | 2008_002864__11
278 | 2008_002864__15
279 | 2008_002904__12
280 | 2008_002929__15
281 | 2008_002936__12
282 | 2008_002942__15
283 | 2008_002958__12
284 | 2008_003034__15
285 | 2008_003076__15
286 | 2008_003108__15
287 | 2008_003141__15
288 | 2008_003210__15
289 | 2008_003238__12
290 | 2008_003238__15
291 | 2008_003330__15
292 | 2008_003333__14
293 | 2008_003333__15
294 | 2008_003379__13
295 | 2008_003451__14
296 | 2008_003451__15
297 | 2008_003461__13
298 | 2008_003461__15
299 | 2008_003477__11
300 | 2008_003492__15
301 | 2008_003511__12
302 | 2008_003511__15
303 | 2008_003546__15
304 | 2008_003576__12
305 | 2008_003676__15
306 | 2008_003733__15
307 | 2008_003782__13
308 | 2008_003856__15
309 | 2008_003874__15
310 | 2008_004101__15
311 | 2008_004140__11
312 | 2008_004140__15
313 | 2008_004175__13
314 | 2008_004345__14
315 | 2008_004396__13
316 | 2008_004399__14
317 | 2008_004399__15
318 | 2008_004575__11
319 | 2008_004575__15
320 | 2008_004624__13
321 | 2008_004654__15
322 | 2008_004687__13
323 | 2008_004705__13
324 | 2008_005049__14
325 | 2008_005089__15
326 | 2008_005145__11
327 | 2008_005197__12
328 | 2008_005197__15
329 | 2008_005245__14
330 | 2008_005245__15
331 | 2008_005399__15
332 | 2008_005422__14
333 | 2008_005445__15
334 | 2008_005525__13
335 | 2008_005637__14
336 | 2008_005642__13
337 | 2008_005691__13
338 | 2008_005738__15
339 | 2008_005812__15
340 | 2008_005915__14
341 | 2008_006008__11
342 | 2008_006036__13
343 | 2008_006108__11
344 | 2008_006108__15
345 | 2008_006130__12
346 | 2008_006216__15
347 | 2008_006219__13
348 | 2008_006254__15
349 | 2008_006275__15
350 | 2008_006341__15
351 | 2008_006408__11
352 | 2008_006408__15
353 | 2008_006526__14
354 | 2008_006526__15
355 | 2008_006554__15
356 | 2008_006722__12
357 | 2008_006722__15
358 | 2008_006874__14
359 | 2008_006874__15
360 | 2008_006981__12
361 | 2008_007048__11
362 | 2008_007219__15
363 | 2008_007378__11
364 | 2008_007378__12
365 | 2008_007392__13
366 | 2008_007392__15
367 | 2008_007402__11
368 | 2008_007402__15
369 | 2008_007513__12
370 | 2008_007737__15
371 | 2008_007828__15
372 | 2008_007945__13
373 | 2008_007994__15
374 | 2008_008051__11
375 | 2008_008127__14
376 | 2008_008127__15
377 | 2008_008221__15
378 | 2008_008335__11
379 | 2008_008335__15
380 | 2008_008362__11
381 | 2008_008362__15
382 | 2008_008392__13
383 | 2008_008393__13
384 | 2008_008421__13
385 | 2008_008469__15
386 | 2009_000012__13
387 | 2009_000074__14
388 | 2009_000074__15
389 | 2009_000156__12
390 | 2009_000219__15
391 | 2009_000309__15
392 | 2009_000412__13
393 | 2009_000418__15
394 | 2009_000421__15
395 | 2009_000457__15
396 | 2009_000704__15
397 | 2009_000705__13
398 | 2009_000727__13
399 | 2009_000730__14
400 | 2009_000730__15
401 | 2009_000825__14
402 | 2009_000825__15
403 | 2009_000839__12
404 | 2009_000892__12
405 | 2009_000931__13
406 | 2009_000935__12
407 | 2009_001215__11
408 | 2009_001215__15
409 | 2009_001299__15
410 | 2009_001433__13
411 | 2009_001433__15
412 | 2009_001535__12
413 | 2009_001663__15
414 | 2009_001687__12
415 | 2009_001687__15
416 | 2009_001718__15
417 | 2009_001768__15
418 | 2009_001854__15
419 | 2009_002012__12
420 | 2009_002042__15
421 | 2009_002097__13
422 | 2009_002155__12
423 | 2009_002165__13
424 | 2009_002185__15
425 | 2009_002239__14
426 | 2009_002239__15
427 | 2009_002317__14
428 | 2009_002317__15
429 | 2009_002346__12
430 | 2009_002346__15
431 | 2009_002372__15
432 | 2009_002382__14
433 | 2009_002382__15
434 | 2009_002415__11
435 | 2009_002445__12
436 | 2009_002487__11
437 | 2009_002539__12
438 | 2009_002571__11
439 | 2009_002584__15
440 | 2009_002649__15
441 | 2009_002651__14
442 | 2009_002651__15
443 | 2009_002732__15
444 | 2009_002975__13
445 | 2009_003003__11
446 | 2009_003003__15
447 | 2009_003063__12
448 | 2009_003065__15
449 | 2009_003071__11
450 | 2009_003071__15
451 | 2009_003123__11
452 | 2009_003196__14
453 | 2009_003217__12
454 | 2009_003241__12
455 | 2009_003269__15
456 | 2009_003323__13
457 | 2009_003323__15
458 | 2009_003466__12
459 | 2009_003481__13
460 | 2009_003494__15
461 | 2009_003507__11
462 | 2009_003576__14
463 | 2009_003576__15
464 | 2009_003756__12
465 | 2009_003804__13
466 | 2009_003810__12
467 | 2009_003849__11
468 | 2009_003849__15
469 | 2009_003903__13
470 | 2009_003928__12
471 | 2009_003991__11
472 | 2009_003991__15
473 | 2009_004033__12
474 | 2009_004043__14
475 | 2009_004043__15
476 | 2009_004140__11
477 | 2009_004221__15
478 | 2009_004455__14
479 | 2009_004497__13
480 | 2009_004507__12
481 | 2009_004507__15
482 | 2009_004581__12
483 | 2009_004592__12
484 | 2009_004738__14
485 | 2009_004738__15
486 | 2009_004848__15
487 | 2009_004859__11
488 | 2009_004859__15
489 | 2009_004942__13
490 | 2009_004987__14
491 | 2009_004987__15
492 | 2009_004994__12
493 | 2009_004994__15
494 | 2009_005038__11
495 | 2009_005038__15
496 | 2009_005078__14
497 | 2009_005087__15
498 | 2009_005217__13
499 | 2009_005217__15
500 | 2010_000003__12
501 | 2010_000038__13
502 | 2010_000038__15
503 | 2010_000087__14
504 | 2010_000087__15
505 | 2010_000110__12
506 | 2010_000110__15
507 | 2010_000159__12
508 | 2010_000174__11
509 | 2010_000174__15
510 | 2010_000216__12
511 | 2010_000238__15
512 | 2010_000256__15
513 | 2010_000422__12
514 | 2010_000530__15
515 | 2010_000559__15
516 | 2010_000639__12
517 | 2010_000666__13
518 | 2010_000666__15
519 | 2010_000738__15
520 | 2010_000788__12
521 | 2010_000874__13
522 | 2010_000904__12
523 | 2010_001024__15
524 | 2010_001124__12
525 | 2010_001251__14
526 | 2010_001264__12
527 | 2010_001313__14
528 | 2010_001313__15
529 | 2010_001367__15
530 | 2010_001376__12
531 | 2010_001451__13
532 | 2010_001553__14
533 | 2010_001563__12
534 | 2010_001563__15
535 | 2010_001579__11
536 | 2010_001579__15
537 | 2010_001692__15
538 | 2010_001699__15
539 | 2010_001734__15
540 | 2010_001767__15
541 | 2010_001851__11
542 | 2010_001908__12
543 | 2010_001956__12
544 | 2010_002017__15
545 | 2010_002137__15
546 | 2010_002161__13
547 | 2010_002161__15
548 | 2010_002228__12
549 | 2010_002251__14
550 | 2010_002251__15
551 | 2010_002271__14
552 | 2010_002336__11
553 | 2010_002396__14
554 | 2010_002396__15
555 | 2010_002480__12
556 | 2010_002623__15
557 | 2010_002691__13
558 | 2010_002763__15
559 | 2010_002792__15
560 | 2010_002902__15
561 | 2010_002929__15
562 | 2010_003014__15
563 | 2010_003060__12
564 | 2010_003187__12
565 | 2010_003207__14
566 | 2010_003239__15
567 | 2010_003325__11
568 | 2010_003325__15
569 | 2010_003381__15
570 | 2010_003409__15
571 | 2010_003446__15
572 | 2010_003506__12
573 | 2010_003531__11
574 | 2010_003532__13
575 | 2010_003597__11
576 | 2010_003597__15
577 | 2010_003746__12
578 | 2010_003746__15
579 | 2010_003947__14
580 | 2010_003971__11
581 | 2010_004042__14
582 | 2010_004165__12
583 | 2010_004165__15
584 | 2010_004219__14
585 | 2010_004219__15
586 | 2010_004337__15
587 | 2010_004355__14
588 | 2010_004432__15
589 | 2010_004472__15
590 | 2010_004479__15
591 | 2010_004519__13
592 | 2010_004550__12
593 | 2010_004559__15
594 | 2010_004628__12
595 | 2010_004697__14
596 | 2010_004697__15
597 | 2010_004795__12
598 | 2010_004815__15
599 | 2010_004825__11
600 | 2010_004828__15
601 | 2010_004856__13
602 | 2010_004941__14
603 | 2010_004951__15
604 | 2010_005046__11
605 | 2010_005046__15
606 | 2010_005118__15
607 | 2010_005159__12
608 | 2010_005160__14
609 | 2010_005166__15
610 | 2010_005174__13
611 | 2010_005206__12
612 | 2010_005245__12
613 | 2010_005245__15
614 | 2010_005252__14
615 | 2010_005252__15
616 | 2010_005284__15
617 | 2010_005366__14
618 | 2010_005433__14
619 | 2010_005501__14
620 | 2010_005575__12
621 | 2010_005582__15
622 | 2010_005606__15
623 | 2010_005626__11
624 | 2010_005626__15
625 | 2010_005644__12
626 | 2010_005709__15
627 | 2010_005871__15
628 | 2010_005991__12
629 | 2010_005991__15
630 | 2010_005992__12
631 | 2011_000045__12
632 | 2011_000051__15
633 | 2011_000054__15
634 | 2011_000178__15
635 | 2011_000226__11
636 | 2011_000248__15
637 | 2011_000338__11
638 | 2011_000396__13
639 | 2011_000435__15
640 | 2011_000438__15
641 | 2011_000455__14
642 | 2011_000455__15
643 | 2011_000479__15
644 | 2011_000512__14
645 | 2011_000526__13
646 | 2011_000536__12
647 | 2011_000566__15
648 | 2011_000585__15
649 | 2011_000598__11
650 | 2011_000618__14
651 | 2011_000618__15
652 | 2011_000638__15
653 | 2011_000780__15
654 | 2011_000809__11
655 | 2011_000809__15
656 | 2011_000843__15
657 | 2011_000953__11
658 | 2011_000953__15
659 | 2011_001014__12
660 | 2011_001060__15
661 | 2011_001069__15
662 | 2011_001071__15
663 | 2011_001159__15
664 | 2011_001276__11
665 | 2011_001276__12
666 | 2011_001276__15
667 | 2011_001346__15
668 | 2011_001416__15
669 | 2011_001447__15
670 | 2011_001530__15
671 | 2011_001567__15
672 | 2011_001619__15
673 | 2011_001642__12
674 | 2011_001665__11
675 | 2011_001674__15
676 | 2011_001714__12
677 | 2011_001714__15
678 | 2011_001722__13
679 | 2011_001745__12
680 | 2011_001794__15
681 | 2011_001862__11
682 | 2011_001862__12
683 | 2011_001868__12
684 | 2011_001984__12
685 | 2011_001988__15
686 | 2011_002002__15
687 | 2011_002040__12
688 | 2011_002075__11
689 | 2011_002075__15
690 | 2011_002098__12
691 | 2011_002110__12
692 | 2011_002110__15
693 | 2011_002121__12
694 | 2011_002124__15
695 | 2011_002156__12
696 | 2011_002200__11
697 | 2011_002200__15
698 | 2011_002247__15
699 | 2011_002279__12
700 | 2011_002298__12
701 | 2011_002308__15
702 | 2011_002317__15
703 | 2011_002322__14
704 | 2011_002322__15
705 | 2011_002343__15
706 | 2011_002358__11
707 | 2011_002358__15
708 | 2011_002371__12
709 | 2011_002498__15
710 | 2011_002509__15
711 | 2011_002532__15
712 | 2011_002575__15
713 | 2011_002578__15
714 | 2011_002589__12
715 | 2011_002623__15
716 | 2011_002641__15
717 | 2011_002675__15
718 | 2011_002951__13
719 | 2011_002997__15
720 | 2011_003019__14
721 | 2011_003019__15
722 | 2011_003085__13
723 | 2011_003114__15
724 | 2011_003240__15
725 | 2011_003256__12
726 |
--------------------------------------------------------------------------------
/data/splits/pascal/val/fold3.txt:
--------------------------------------------------------------------------------
1 | 2007_000042__19
2 | 2007_000123__19
3 | 2007_000175__17
4 | 2007_000187__20
5 | 2007_000452__18
6 | 2007_000559__20
7 | 2007_000629__19
8 | 2007_000636__19
9 | 2007_000661__18
10 | 2007_000676__17
11 | 2007_000804__18
12 | 2007_000925__17
13 | 2007_001154__18
14 | 2007_001175__20
15 | 2007_001408__16
16 | 2007_001430__16
17 | 2007_001430__20
18 | 2007_001457__18
19 | 2007_001458__18
20 | 2007_001585__18
21 | 2007_001594__17
22 | 2007_001678__20
23 | 2007_001717__20
24 | 2007_001733__17
25 | 2007_001763__18
26 | 2007_001763__20
27 | 2007_002119__20
28 | 2007_002132__20
29 | 2007_002268__18
30 | 2007_002284__16
31 | 2007_002378__16
32 | 2007_002426__18
33 | 2007_002427__18
34 | 2007_002565__19
35 | 2007_002618__17
36 | 2007_002648__17
37 | 2007_002728__19
38 | 2007_003011__18
39 | 2007_003011__20
40 | 2007_003169__18
41 | 2007_003367__16
42 | 2007_003499__19
43 | 2007_003506__16
44 | 2007_003530__18
45 | 2007_003587__19
46 | 2007_003714__17
47 | 2007_003848__19
48 | 2007_003957__19
49 | 2007_004190__20
50 | 2007_004193__20
51 | 2007_004275__16
52 | 2007_004281__19
53 | 2007_004483__19
54 | 2007_004510__20
55 | 2007_004558__16
56 | 2007_004649__19
57 | 2007_004712__16
58 | 2007_004969__17
59 | 2007_005469__17
60 | 2007_005626__19
61 | 2007_005689__19
62 | 2007_005813__16
63 | 2007_005857__16
64 | 2007_005915__17
65 | 2007_006171__18
66 | 2007_006348__20
67 | 2007_006373__18
68 | 2007_006678__17
69 | 2007_006680__19
70 | 2007_006802__19
71 | 2007_007130__20
72 | 2007_007165__17
73 | 2007_007168__19
74 | 2007_007195__19
75 | 2007_007196__20
76 | 2007_007203__20
77 | 2007_007417__18
78 | 2007_007534__17
79 | 2007_007624__16
80 | 2007_007795__16
81 | 2007_007881__19
82 | 2007_007996__18
83 | 2007_008204__20
84 | 2007_008260__18
85 | 2007_008339__19
86 | 2007_008374__20
87 | 2007_008543__18
88 | 2007_008547__16
89 | 2007_009068__18
90 | 2007_009252__18
91 | 2007_009320__17
92 | 2007_009419__16
93 | 2007_009446__20
94 | 2007_009521__18
95 | 2007_009521__20
96 | 2007_009592__18
97 | 2007_009655__18
98 | 2007_009684__18
99 | 2007_009750__16
100 | 2008_000016__20
101 | 2008_000149__18
102 | 2008_000270__18
103 | 2008_000391__16
104 | 2008_000589__18
105 | 2008_000657__19
106 | 2008_001078__16
107 | 2008_001283__16
108 | 2008_001688__16
109 | 2008_001688__20
110 | 2008_001966__16
111 | 2008_002273__16
112 | 2008_002379__16
113 | 2008_002464__20
114 | 2008_002536__17
115 | 2008_002680__20
116 | 2008_002900__19
117 | 2008_002929__18
118 | 2008_003003__20
119 | 2008_003026__20
120 | 2008_003105__19
121 | 2008_003135__16
122 | 2008_003676__16
123 | 2008_003709__18
124 | 2008_003733__18
125 | 2008_003885__20
126 | 2008_004172__18
127 | 2008_004212__19
128 | 2008_004279__20
129 | 2008_004367__19
130 | 2008_004453__17
131 | 2008_004477__16
132 | 2008_004562__18
133 | 2008_004610__19
134 | 2008_004621__17
135 | 2008_004754__20
136 | 2008_004854__17
137 | 2008_004910__20
138 | 2008_005089__20
139 | 2008_005217__16
140 | 2008_005242__16
141 | 2008_005254__20
142 | 2008_005439__20
143 | 2008_005445__20
144 | 2008_005544__19
145 | 2008_005633__17
146 | 2008_005680__16
147 | 2008_006055__19
148 | 2008_006159__20
149 | 2008_006327__17
150 | 2008_006523__19
151 | 2008_006553__19
152 | 2008_006752__19
153 | 2008_006784__18
154 | 2008_006835__17
155 | 2008_007497__17
156 | 2008_007527__20
157 | 2008_007677__17
158 | 2008_007814__17
159 | 2008_007828__20
160 | 2008_008103__18
161 | 2008_008221__19
162 | 2008_008434__16
163 | 2009_000022__19
164 | 2009_000039__17
165 | 2009_000087__18
166 | 2009_000096__18
167 | 2009_000136__20
168 | 2009_000242__18
169 | 2009_000391__20
170 | 2009_000418__16
171 | 2009_000418__18
172 | 2009_000487__18
173 | 2009_000488__16
174 | 2009_000488__20
175 | 2009_000628__19
176 | 2009_000675__17
177 | 2009_000704__20
178 | 2009_000712__19
179 | 2009_000732__18
180 | 2009_000845__19
181 | 2009_000924__17
182 | 2009_001300__19
183 | 2009_001333__19
184 | 2009_001363__20
185 | 2009_001505__17
186 | 2009_001644__16
187 | 2009_001644__18
188 | 2009_001644__20
189 | 2009_001684__16
190 | 2009_001731__18
191 | 2009_001768__17
192 | 2009_001775__16
193 | 2009_001775__18
194 | 2009_001991__17
195 | 2009_002082__17
196 | 2009_002094__20
197 | 2009_002202__19
198 | 2009_002265__19
199 | 2009_002291__19
200 | 2009_002346__18
201 | 2009_002366__20
202 | 2009_002390__18
203 | 2009_002487__16
204 | 2009_002562__20
205 | 2009_002568__19
206 | 2009_002571__16
207 | 2009_002571__18
208 | 2009_002573__20
209 | 2009_002584__16
210 | 2009_002638__19
211 | 2009_002732__18
212 | 2009_002887__19
213 | 2009_002982__19
214 | 2009_003105__19
215 | 2009_003123__18
216 | 2009_003299__19
217 | 2009_003311__19
218 | 2009_003433__19
219 | 2009_003523__20
220 | 2009_003551__20
221 | 2009_003564__16
222 | 2009_003564__18
223 | 2009_003607__18
224 | 2009_003666__17
225 | 2009_003857__20
226 | 2009_003895__18
227 | 2009_003895__20
228 | 2009_003938__19
229 | 2009_004099__18
230 | 2009_004140__18
231 | 2009_004255__19
232 | 2009_004298__18
233 | 2009_004687__18
234 | 2009_004730__19
235 | 2009_004799__19
236 | 2009_004993__18
237 | 2009_004993__20
238 | 2009_005148__19
239 | 2009_005220__19
240 | 2010_000256__18
241 | 2010_000284__18
242 | 2010_000309__17
243 | 2010_000318__20
244 | 2010_000330__16
245 | 2010_000639__16
246 | 2010_000738__20
247 | 2010_000764__19
248 | 2010_001011__17
249 | 2010_001079__17
250 | 2010_001104__19
251 | 2010_001149__18
252 | 2010_001151__19
253 | 2010_001246__16
254 | 2010_001256__17
255 | 2010_001327__18
256 | 2010_001367__20
257 | 2010_001522__17
258 | 2010_001557__17
259 | 2010_001577__17
260 | 2010_001699__16
261 | 2010_001734__19
262 | 2010_001752__20
263 | 2010_001767__18
264 | 2010_001773__16
265 | 2010_001851__16
266 | 2010_001951__19
267 | 2010_001962__18
268 | 2010_002106__17
269 | 2010_002137__16
270 | 2010_002137__18
271 | 2010_002232__17
272 | 2010_002531__18
273 | 2010_002682__19
274 | 2010_002921__20
275 | 2010_003014__18
276 | 2010_003123__16
277 | 2010_003302__16
278 | 2010_003514__19
279 | 2010_003541__17
280 | 2010_003597__18
281 | 2010_003781__16
282 | 2010_003956__19
283 | 2010_004149__19
284 | 2010_004226__17
285 | 2010_004382__16
286 | 2010_004479__20
287 | 2010_004757__16
288 | 2010_004757__18
289 | 2010_004783__18
290 | 2010_004825__16
291 | 2010_004857__20
292 | 2010_004951__19
293 | 2010_004980__19
294 | 2010_005180__18
295 | 2010_005187__16
296 | 2010_005305__20
297 | 2010_005606__18
298 | 2010_005706__19
299 | 2010_005719__17
300 | 2010_005727__19
301 | 2010_005788__17
302 | 2010_005860__16
303 | 2010_005871__19
304 | 2010_005991__18
305 | 2010_006054__19
306 | 2011_000070__18
307 | 2011_000173__18
308 | 2011_000283__19
309 | 2011_000291__19
310 | 2011_000310__18
311 | 2011_000436__17
312 | 2011_000521__19
313 | 2011_000747__16
314 | 2011_001005__18
315 | 2011_001060__19
316 | 2011_001281__19
317 | 2011_001350__17
318 | 2011_001567__18
319 | 2011_001601__18
320 | 2011_001614__19
321 | 2011_001674__18
322 | 2011_001713__16
323 | 2011_001713__18
324 | 2011_001726__20
325 | 2011_001794__18
326 | 2011_001862__18
327 | 2011_001863__16
328 | 2011_001910__20
329 | 2011_002124__18
330 | 2011_002156__20
331 | 2011_002178__17
332 | 2011_002247__19
333 | 2011_002379__19
334 | 2011_002391__18
335 | 2011_002532__20
336 | 2011_002535__19
337 | 2011_002644__18
338 | 2011_002644__20
339 | 2011_002879__18
340 | 2011_002879__20
341 | 2011_003103__16
342 | 2011_003103__18
343 | 2011_003146__19
344 | 2011_003182__18
345 | 2011_003197__19
346 | 2011_003256__18
347 |
--------------------------------------------------------------------------------
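Each PASCAL-5i val entry above is a VOC image id and a two-digit class index joined by `__`; fold k lists class indices 5k+1 through 5k+5 (fold0: 01-05, fold1: 06-10, fold2: 11-15, fold3: 16-20), and a single image id can recur once per annotated class. A minimal parsing sketch:

```python
# Minimal sketch: split a PASCAL-5i entry into (image id, class index).
def parse_pascal_entry(entry):
    img_id, cls = entry.rsplit('__', 1)
    return img_id, int(cls)

print(parse_pascal_entry('2007_000042__19'))  # ('2007_000042', 19)
```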
/generate_cam_coco.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import clip
3 | from PIL import Image
4 | from pytorch_grad_cam import GradCAM
5 | import cv2
6 | import argparse
7 | from data.dataset import FSSDataset
8 | import pdb
9 |
10 | COCO_CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
11 |                 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
12 |                 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
13 |                 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
14 |                 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
15 |                 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
16 |                 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
17 |                 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
18 |                 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
19 |                 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
20 |                 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
21 |                 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
22 |                 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
23 |                 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
24 |
25 |
26 | def get_cam_from_alldata(clip_model, preprocess, split='train', d0=None, d1=None, d2=None, d3=None,
27 | datapath=None, campath=None):
28 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
29 |
30 | d0 = d0.dataset.img_metadata_classwise
31 | d1 = d1.dataset.img_metadata_classwise
32 | d2 = d2.dataset.img_metadata_classwise
33 | d3 = d3.dataset.img_metadata_classwise
34 | dd = [d0, d1, d2, d3]
35 | dataset_all = {}
36 |
37 | if split == 'train':
38 | for ii in range(80):
39 | index = ii % 4 + 1
40 | if ii % 4 == 3:
41 | index = 0
42 | dataset_all[ii] = dd[index][ii]
43 | else:
44 | for ii in range(80):
45 | index = ii % 4
46 | dataset_all[ii] = dd[index][ii]
47 | del d0, d1, d2, d3, dd
48 |
49 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in COCO_CLASSES]).to(device)
50 | for cls_id in range(80):
51 | L = len(dataset_all[cls_id])
52 | for ll in range(L):
53 | img_path = datapath + dataset_all[cls_id][ll]
54 | img = Image.open(img_path)
55 | img_input = preprocess(img).unsqueeze(0).to(device)
56 | class_name_id = cls_id
57 |
58 | # CAM
59 | clip_model.get_text_features(text_inputs)
60 | target_layers = [clip_model.visual.layer4[-1]]
61 | input_tensor = img_input
62 | cam = GradCAM(model=clip_model, target_layers=target_layers, use_cuda=True)
63 | target_category = class_name_id
64 | grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)
65 | grayscale_cam = grayscale_cam[0, :]
66 | grayscale_cam = cv2.resize(grayscale_cam, (50, 50))
67 | grayscale_cam = torch.from_numpy(grayscale_cam)
68 | save_path = campath + dataset_all[cls_id][ll] + '--' + str(class_name_id) + '.pt'
69 | torch.save(grayscale_cam, save_path)
70 | print('cam saved in', save_path)
71 |
72 |
73 |
74 | if __name__ == '__main__':
75 | parser = argparse.ArgumentParser(description='IMR')
76 | parser.add_argument('--imgpath', type=str, default='../Datasets_HSN/COCO2014/')
77 | parser.add_argument('--campath', type=str, default='../Datasets_HSN/CAM_Val_COCO/')
78 | args = parser.parse_args()
79 |
80 |
81 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
82 | model_clip, preprocess = clip.load('RN50', device, jit=False)
83 | FSSDataset.initialize(img_size=400, datapath='../Datasets_HSN', use_original_imgsize=False)
84 |
85 |
86 | # COCO-meta-train
87 | # train
88 | dataloader_test0 = FSSDataset.build_dataloader('coco', 1, 0, 0, 'train', 1)
89 | dataloader_test1 = FSSDataset.build_dataloader('coco', 1, 0, 1, 'train', 1)
90 | dataloader_test2 = FSSDataset.build_dataloader('coco', 1, 0, 2, 'train', 1)
91 | dataloader_test3 = FSSDataset.build_dataloader('coco', 1, 0, 3, 'train', 1)
92 | get_cam_from_alldata(model_clip, preprocess, split='train',
93 | d0=dataloader_test0, d1=dataloader_test1,
94 | d2=dataloader_test2, d3=dataloader_test3,
95 | datapath=args.imgpath, campath=args.campath)
96 |
97 | # val
98 | dataloader_test0 = FSSDataset.build_dataloader('coco', 1, 0, 0, 'val', 1)
99 | dataloader_test1 = FSSDataset.build_dataloader('coco', 1, 0, 1, 'val', 1)
100 | dataloader_test2 = FSSDataset.build_dataloader('coco', 1, 0, 2, 'val', 1)
101 | dataloader_test3 = FSSDataset.build_dataloader('coco', 1, 0, 3, 'val', 1)
102 | get_cam_from_alldata(model_clip, preprocess, split='val',
103 | d0=dataloader_test0, d1=dataloader_test1,
104 | d2=dataloader_test2, d3=dataloader_test3,
105 | datapath=args.imgpath, campath=args.campath)
106 |
--------------------------------------------------------------------------------
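
For reference, a minimal sketch of consuming the tensors this script writes. The file name below is hypothetical; the `<image>--<class_id>.pt` naming and the 50×50 CAM size follow the save logic in `get_cam_from_alldata` above:

```python
import torch

# hypothetical example of a file produced by generate_cam_coco.py
cam = torch.load('../Datasets_HSN/CAM_Val_COCO/COCO_val2014_000000000042.jpg--16.pt')
print(cam.shape)  # torch.Size([50, 50]); values lie in [0, 1] after Grad-CAM min-max scaling
```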
/generate_cam_voc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import clip
3 | from PIL import Image
4 | from pytorch_grad_cam import GradCAM
5 | import cv2
6 | import argparse
7 | from data.dataset import FSSDataset
8 | import pdb
9 |
10 | PASCAL_CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
11 | 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
12 | 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
13 | 'tvmonitor']
14 |
15 |
16 | def get_cam_from_alldata(clip_model, preprocess, d=None, datapath=None, campath=None):
17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18 | dataset_all = d.dataset.img_metadata
19 | L = len(dataset_all)
20 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in PASCAL_CLASSES]).to(device)
21 | for ll in range(L):
22 | img_path = datapath + dataset_all[ll][0] + '.jpg'
23 | img = Image.open(img_path)
24 | img_input = preprocess(img).unsqueeze(0).to(device)
25 | class_name_id = dataset_all[ll][1]
26 | clip_model.get_text_features(text_inputs)
27 | target_layers = [clip_model.visual.layer4[-1]]
28 | input_tensor = img_input
29 | cam = GradCAM(model=clip_model, target_layers=target_layers, use_cuda=True)
30 | target_category = class_name_id
31 | grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)
32 | grayscale_cam = grayscale_cam[0, :]
33 | grayscale_cam = cv2.resize(grayscale_cam, (50, 50))
34 | grayscale_cam = torch.from_numpy(grayscale_cam)
35 | save_path = campath + dataset_all[ll][0] + '--' + str(class_name_id) + '.pt'
36 | torch.save(grayscale_cam, save_path)
37 | print('cam saved in', save_path)
38 |
39 |
40 |
41 | if __name__ == '__main__':
42 | parser = argparse.ArgumentParser(description='IMR')
43 | parser.add_argument('--imgpath', type=str, default='../Datasets_HSN/VOC2012/JPEGImages/')
44 | parser.add_argument('--traincampath', type=str, default='../Datasets_HSN/CAM_Train/')
45 | parser.add_argument('--valcampath', type=str, default='../Datasets_HSN/CAM_Val/')
46 | args = parser.parse_args()
47 |
48 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
49 | model_clip, preprocess = clip.load('RN50', device, jit=False)
50 | FSSDataset.initialize(img_size=400, datapath='../Datasets_HSN', use_original_imgsize=False)
51 |
52 | # VOC
53 | # train
54 | dataloader_test0 = FSSDataset.build_dataloader('pascal', 1, 0, 0, 'train', 1)
55 | dataloader_test1 = FSSDataset.build_dataloader('pascal', 1, 0, 1, 'train', 1)
56 | dataloader_test2 = FSSDataset.build_dataloader('pascal', 1, 0, 2, 'train', 1)
57 | dataloader_test3 = FSSDataset.build_dataloader('pascal', 1, 0, 3, 'train', 1)
58 |
59 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test0, datapath=args.imgpath, campath=args.traincampath)
60 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test1, datapath=args.imgpath, campath=args.traincampath)
61 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test2, datapath=args.imgpath, campath=args.traincampath)
62 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test3, datapath=args.imgpath, campath=args.traincampath)
63 |
64 | # val
65 | dataloader_test0 = FSSDataset.build_dataloader('pascal', 1, 0, 0, 'val', 1)
66 | dataloader_test1 = FSSDataset.build_dataloader('pascal', 1, 0, 1, 'val', 1)
67 | dataloader_test2 = FSSDataset.build_dataloader('pascal', 1, 0, 2, 'val', 1)
68 | dataloader_test3 = FSSDataset.build_dataloader('pascal', 1, 0, 3, 'val', 1)
69 |
70 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test0, datapath=args.imgpath, campath=args.valcampath)
71 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test1, datapath=args.imgpath, campath=args.valcampath)
72 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test2, datapath=args.imgpath, campath=args.valcampath)
73 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test3, datapath=args.imgpath, campath=args.valcampath)
74 |
75 |
--------------------------------------------------------------------------------
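
Both CAM-generation scripts are plain CLI entry points; a sketch of invoking them with the default dataset layout (the paths are simply the argparse defaults above, adjust them to your setup):

```bash
python generate_cam_voc.py  --imgpath ../Datasets_HSN/VOC2012/JPEGImages/ \
                            --traincampath ../Datasets_HSN/CAM_Train/ \
                            --valcampath ../Datasets_HSN/CAM_Val/
python generate_cam_coco.py --imgpath ../Datasets_HSN/COCO2014/ \
                            --campath ../Datasets_HSN/CAM_Val_COCO/
```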
/model/__pycache__/hsnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/__pycache__/hsnet.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/hsnet_imr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/__pycache__/hsnet_imr.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/hsnet_raft_res_multi_group_xiaorong_grouponly.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/__pycache__/hsnet_raft_res_multi_group_xiaorong_grouponly.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/learner.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/__pycache__/learner.cpython-38.pyc
--------------------------------------------------------------------------------
/model/base/__pycache__/conv4d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/base/__pycache__/conv4d.cpython-38.pyc
--------------------------------------------------------------------------------
/model/base/__pycache__/correlation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/base/__pycache__/correlation.cpython-38.pyc
--------------------------------------------------------------------------------
/model/base/__pycache__/feature.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/base/__pycache__/feature.cpython-38.pyc
--------------------------------------------------------------------------------
/model/base/conv4d.py:
--------------------------------------------------------------------------------
1 | r""" Implementation of center-pivot 4D convolution """
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class CenterPivotConv4d(nn.Module):
8 | r""" CenterPivot 4D conv"""
9 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True):
10 | super(CenterPivotConv4d, self).__init__()
11 |
12 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size[:2], stride=stride[:2],
13 | bias=bias, padding=padding[:2])
14 | self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size[2:], stride=stride[2:],
15 | bias=bias, padding=padding[2:])
16 |
17 | self.stride34 = stride[2:]
18 | self.kernel_size = kernel_size
19 | self.stride = stride
20 | self.padding = padding
21 | self.idx_initialized = False
22 |
23 | def prune(self, ct):
24 | bsz, ch, ha, wa, hb, wb = ct.size()
25 | if not self.idx_initialized:
26 | idxh = torch.arange(start=0, end=hb, step=self.stride[2:][0], device=ct.device)
27 | idxw = torch.arange(start=0, end=wb, step=self.stride[2:][1], device=ct.device)
28 | self.len_h = len(idxh)
29 | self.len_w = len(idxw)
30 | self.idx = (idxw.repeat(self.len_h, 1) + idxh.repeat(self.len_w, 1).t() * wb).view(-1)
31 | self.idx_initialized = True
32 | ct_pruned = ct.view(bsz, ch, ha, wa, -1).index_select(4, self.idx).view(bsz, ch, ha, wa, self.len_h, self.len_w)
33 |
34 | return ct_pruned
35 |
36 | def forward(self, x):
37 | if self.stride[2:][-1] > 1:
38 | out1 = self.prune(x)
39 | else:
40 | out1 = x
41 | bsz, inch, ha, wa, hb, wb = out1.size()
42 | out1 = out1.permute(0, 4, 5, 1, 2, 3).contiguous().view(-1, inch, ha, wa)
43 | out1 = self.conv1(out1)
44 | outch, o_ha, o_wa = out1.size(-3), out1.size(-2), out1.size(-1)
45 | out1 = out1.view(bsz, hb, wb, outch, o_ha, o_wa).permute(0, 3, 4, 5, 1, 2).contiguous()
46 |
47 | bsz, inch, ha, wa, hb, wb = x.size()
48 | out2 = x.permute(0, 2, 3, 1, 4, 5).contiguous().view(-1, inch, hb, wb)
49 | out2 = self.conv2(out2)
50 | outch, o_hb, o_wb = out2.size(-3), out2.size(-2), out2.size(-1)
51 | out2 = out2.view(bsz, ha, wa, outch, o_hb, o_wb).permute(0, 3, 1, 2, 4, 5).contiguous()
52 |
53 | if out1.size()[-2:] != out2.size()[-2:] and self.padding[-2:] == (0, 0):
54 | out1 = out1.view(bsz, outch, o_ha, o_wa, -1).sum(dim=-1)
55 | out2 = out2.squeeze()
56 |
57 | y = out1 + out2
58 | return y
59 |
--------------------------------------------------------------------------------
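
A quick shape check for `CenterPivotConv4d`, assuming it is importable from the module above (the tensor sizes are illustrative): `conv1` convolves over the (ha, wa) dims and `conv2` over (hb, wb), so with stride `(1, 1, 2, 2)` only the last two dims shrink.

```python
import torch
from model.base.conv4d import CenterPivotConv4d

conv = CenterPivotConv4d(in_channels=3, out_channels=16,
                         kernel_size=(3, 3, 3, 3),
                         stride=(1, 1, 2, 2),
                         padding=(1, 1, 1, 1))
x = torch.randn(2, 3, 13, 13, 13, 13)  # (bsz, ch, ha, wa, hb, wb)
y = conv(x)
print(y.shape)                          # (2, 16, 13, 13, 7, 7): only (hb, wb) are strided
```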
/model/base/correlation.py:
--------------------------------------------------------------------------------
1 | r""" Provides functions that builds/manipulates correlation tensors """
2 | import torch
3 |
4 |
5 | class Correlation:
6 |
7 | @classmethod
8 | def multilayer_correlation(cls, query_feats, support_feats, stack_ids):
9 | eps = 1e-5
10 |
11 | corrs = []
12 | for idx, (query_feat, support_feat) in enumerate(zip(query_feats, support_feats)):
13 | bsz, ch, hb, wb = support_feat.size()
14 | support_feat = support_feat.view(bsz, ch, -1)
15 | support_feat = support_feat / (support_feat.norm(dim=1, p=2, keepdim=True) + eps)
16 |
17 | bsz, ch, ha, wa = query_feat.size()
18 | query_feat = query_feat.view(bsz, ch, -1)
19 | query_feat = query_feat / (query_feat.norm(dim=1, p=2, keepdim=True) + eps)
20 |
21 | corr = torch.bmm(query_feat.transpose(1, 2), support_feat).view(bsz, ha, wa, hb, wb)
22 | corr = corr.clamp(min=0)
23 | corrs.append(corr)
24 |
25 | corr_l4 = torch.stack(corrs[-stack_ids[0]:]).transpose(0, 1).contiguous()
26 | corr_l3 = torch.stack(corrs[-stack_ids[1]:-stack_ids[0]]).transpose(0, 1).contiguous()
27 | corr_l2 = torch.stack(corrs[-stack_ids[2]:-stack_ids[1]]).transpose(0, 1).contiguous()
28 |
29 | return [corr_l4, corr_l3, corr_l2]
30 |
--------------------------------------------------------------------------------
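
The per-layer correlation is a clamped cosine similarity between every query and support location; a minimal standalone illustration of one iteration of the loop above:

```python
import torch

eps = 1e-5
query_feat = torch.randn(1, 64, 10, 10)    # (bsz, ch, ha, wa)
support_feat = torch.randn(1, 64, 8, 8)    # (bsz, ch, hb, wb)

q = query_feat.view(1, 64, -1)
q = q / (q.norm(dim=1, p=2, keepdim=True) + eps)   # L2-normalize each location
s = support_feat.view(1, 64, -1)
s = s / (s.norm(dim=1, p=2, keepdim=True) + eps)

corr = torch.bmm(q.transpose(1, 2), s).view(1, 10, 10, 8, 8).clamp(min=0)
print(corr.shape)                          # (bsz, ha, wa, hb, wb)
```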
/model/base/feature.py:
--------------------------------------------------------------------------------
1 | r""" Extracts intermediate features from given backbone network & layer ids """
2 |
3 |
4 | def extract_feat_vgg(img, backbone, feat_ids, bottleneck_ids=None, lids=None):
5 | r""" Extract intermediate features from VGG """
6 | feats = []
7 | feat = img
8 | for lid, module in enumerate(backbone.features):
9 | feat = module(feat)
10 | if lid in feat_ids:
11 | feats.append(feat.clone())
12 | return feats
13 |
14 |
15 | def extract_feat_res(img, backbone, feat_ids, bottleneck_ids, lids):
16 | r""" Extract intermediate features from ResNet"""
17 | feats = []
18 |
19 | # Layer 0
20 | feat = backbone.conv1.forward(img)
21 | feat = backbone.bn1.forward(feat)
22 | feat = backbone.relu.forward(feat)
23 | feat = backbone.maxpool.forward(feat)
24 |
25 | # Layer 1-4
26 | for hid, (bid, lid) in enumerate(zip(bottleneck_ids, lids)):
27 | res = feat
28 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv1.forward(feat)
29 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn1.forward(feat)
30 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
31 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv2.forward(feat)
32 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn2.forward(feat)
33 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
34 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv3.forward(feat)
35 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn3.forward(feat)
36 |
37 | if bid == 0:
38 | res = backbone.__getattr__('layer%d' % lid)[bid].downsample.forward(res)
39 |
40 | feat += res
41 |
42 | if hid + 1 in feat_ids:
43 | feats.append(feat.clone())
44 |
45 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
46 |
47 | return feats
48 |
--------------------------------------------------------------------------------
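
The `(bottleneck_ids, lids)` pairs consumed by `extract_feat_res` are built in `model/hsnet_imr.py`; replicating that construction for ResNet-50 makes the iteration order explicit:

```python
from functools import reduce
from operator import add

nbottlenecks = [3, 4, 6, 3]  # ResNet-50 bottlenecks per layer
bottleneck_ids = reduce(add, [list(range(x)) for x in nbottlenecks])
lids = reduce(add, [[i + 1] * x for i, x in enumerate(nbottlenecks)])
print(bottleneck_ids)  # [0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1, 2]
print(lids)            # [1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4]
```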
/model/hsnet_imr.py:
--------------------------------------------------------------------------------
1 | r""" Hypercorrelation Squeeze Network """
2 | import pdb
3 | from functools import reduce
4 | from operator import add
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torchvision.models import resnet
10 | from torchvision.models import vgg
11 |
12 | from .base.feature import extract_feat_vgg, extract_feat_res
13 | from .base.correlation import Correlation
14 | from .learner import HPNLearner
15 |
16 |
17 | class HypercorrSqueezeNetwork_imr(nn.Module):
18 | # Compared with the non-residual variant, this version concatenates the logit mask with the CAM and runs the result through a small conv head, explicitly reusing the CAM.
19 | def __init__(self, backbone, use_original_imgsize):
20 | super(HypercorrSqueezeNetwork_imr, self).__init__()
21 |
22 | # 1. Backbone network initialization
23 | self.backbone_type = backbone
24 | self.use_original_imgsize = use_original_imgsize
25 | if backbone == 'vgg16':
26 | self.backbone = vgg.vgg16(pretrained=False)
27 | ckpt = torch.load('../Datasets_HSN/Pretrain/vgg16-397923af.pth')
28 | self.backbone.load_state_dict(ckpt)
29 | self.feat_ids = [17, 19, 21, 24, 26, 28, 30]
30 | self.extract_feats = extract_feat_vgg
31 | nbottlenecks = [2, 2, 3, 3, 3, 1]
32 | elif backbone == 'resnet50':
33 | self.backbone = resnet.resnet50(pretrained=False)
34 | ckpt = torch.load('../Datasets_HSN/Pretrain/resnet50-19c8e357.pth')
35 | self.backbone.load_state_dict(ckpt)
36 | self.feat_ids = list(range(4, 17))
37 | self.extract_feats = extract_feat_res
38 | nbottlenecks = [3, 4, 6, 3]
39 | self.conv1024_512 = nn.Conv2d(1024, 512, kernel_size=1)
40 |
41 | elif backbone == 'resnet101':
42 | self.backbone = resnet.resnet101(pretrained=True)
43 | self.feat_ids = list(range(4, 34))
44 | self.extract_feats = extract_feat_res
45 | nbottlenecks = [3, 4, 23, 3]
46 | else:
47 | raise Exception('Unavailable backbone: %s' % backbone)
48 |
49 | self.bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), nbottlenecks)))
50 | self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(nbottlenecks)])
51 | self.stack_ids = torch.tensor(self.lids).bincount().__reversed__().cumsum(dim=0)[:3]
52 | self.backbone.eval()
53 | self.hpn_learner = HPNLearner(list(reversed(nbottlenecks[-3:])))
54 | self.cross_entropy_loss = nn.CrossEntropyLoss()
55 |
56 | # IMR
57 | self.state = nn.Parameter(torch.zeros([1, 128, 50, 50]))
58 | self.convz0 = nn.Conv2d(769, 512, kernel_size=1, padding=0)
59 | self.convz1 = nn.Conv2d(256, 64, kernel_size=3, padding=1, groups=8, dilation=1)
60 | self.convz2 = nn.Conv2d(256, 64, kernel_size=3, padding=1, groups=8, dilation=1)
61 |
62 | self.convr0 = nn.Conv2d(769, 512, kernel_size=1, padding=0)
63 | self.convr1 = nn.Conv2d(256, 64, kernel_size=3, padding=1, groups=8, dilation=1)
64 | self.convr2 = nn.Conv2d(256, 64, kernel_size=3, padding=1, groups=8, dilation=1)
65 |
66 | self.convh0 = nn.Conv2d(769, 512, kernel_size=1, padding=0)
67 | self.convh1 = nn.Conv2d(256, 64, kernel_size=3, padding=1, groups=8, dilation=1)
68 | self.convh2 = nn.Conv2d(256, 64, kernel_size=3, padding=1, groups=8, dilation=1)
69 |
70 | # copied from hsnet-learner
71 | outch1, outch2, outch3 = 16, 64, 128
72 | self.decoder1 = nn.Sequential(nn.Conv2d(outch3, outch3, (3, 3), padding=(1, 1), bias=True),
73 | nn.ReLU(),
74 | nn.Conv2d(outch3, outch2, (3, 3), padding=(1, 1), bias=True),
75 | nn.ReLU())
76 |
77 | self.decoder2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True),
78 | nn.ReLU(),
79 | nn.Conv2d(outch2, 2, (3, 3), padding=(1, 1), bias=True))
80 |
81 | self.res = nn.Sequential(nn.Conv2d(3, 10, kernel_size=1),
82 | nn.GELU(),
83 | nn.Conv2d(10, 2, kernel_size=1))
84 |
85 | def forward(self, query_img, support_img, support_cam, query_cam,
86 | query_mask=None, support_mask=None, stage=2, w='same'):
87 | with torch.no_grad():
88 | query_feats = self.extract_feats(query_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
89 | support_feats = self.extract_feats(
90 | support_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
91 |
92 | # extracting feature
93 | if len(query_feats) == 7:
94 | isvgg = True # VGG
95 | q_mid_feat = F.interpolate(query_feats[3] + query_feats[4] + query_feats[5],
96 | (50, 50), mode='bilinear', align_corners=True)
97 | s_mid_feat = F.interpolate(support_feats[3] + support_feats[4] + support_feats[5],
98 | (50, 50), mode='bilinear', align_corners=True)
99 | else:
100 | isvgg = False # R50
101 | q_mid_feat = F.interpolate(
102 | query_feats[4] + query_feats[5] + query_feats[6] + query_feats[7] + query_feats[8] + query_feats[9],
103 | (50, 50), mode='bilinear', align_corners=True)
104 |
105 | s_mid_feat = F.interpolate(
106 | support_feats[4] + support_feats[5] + support_feats[6] + support_feats[7] + support_feats[8] +
107 | support_feats[9],
108 | (50, 50), mode='bilinear', align_corners=True)
109 |
110 | query_feats_masked = self.mask_feature(query_feats, support_cam.clone())
111 | support_feats_masked = self.mask_feature(support_feats, query_cam.clone())
112 |
113 | corr_query = Correlation.multilayer_correlation(query_feats, support_feats_masked, self.stack_ids)
114 | corr_support = Correlation.multilayer_correlation(support_feats, query_feats_masked, self.stack_ids)
115 |
116 | query_cam = query_cam.unsqueeze(1)
117 | support_cam = support_cam.unsqueeze(1)
118 |
119 | if not isvgg:
120 | # make feat dim in R50 same as VGG
121 | q_mid_feat = self.conv1024_512(q_mid_feat)
122 | s_mid_feat = self.conv1024_512(s_mid_feat)
123 |
124 | bsz = query_img.shape[0]
125 | state_query = self.state.expand(bsz, -1, -1, -1)
126 | state_support = self.state.expand(bsz, -1, -1, -1)
127 |
128 | losses = 0
129 | for ss in range(stage):
130 | # query
131 | after4d_query = self.hpn_learner.forward_conv4d(corr_query)
132 | imr_x_query = torch.cat([query_cam, after4d_query, q_mid_feat, state_query], dim=1)
133 |
134 | imr_x_query_z = self.convz0(imr_x_query)
135 | imr_z_query1 = self.convz1(imr_x_query_z[:, :256])
136 | imr_z_query2 = self.convz2(imr_x_query_z[:, 256:])
137 | imr_z_query = torch.sigmoid(torch.cat([imr_z_query1, imr_z_query2], dim=1))
138 |
139 | imr_x_query_r = self.convr0(imr_x_query)
140 | imr_r_query1 = self.convr1(imr_x_query_r[:, :256])
141 | imr_r_query2 = self.convr2(imr_x_query_r[:, 256:])
142 | imr_r_query = torch.sigmoid(torch.cat([imr_r_query1, imr_r_query2], dim=1))
143 |
144 | imr_x_query_h = self.convh0(
145 | torch.cat([query_cam, after4d_query, q_mid_feat, imr_r_query * state_query], dim=1))
146 | imr_h_query1 = self.convh1(imr_x_query_h[:, :256])
147 | imr_h_query2 = self.convh2(imr_x_query_h[:, 256:])
148 | imr_h_query = torch.cat([imr_h_query1, imr_h_query2], dim=1)
149 |
150 | state_new_query = torch.tanh(imr_h_query)
151 | state_query = (1 - imr_z_query) * state_query + imr_z_query * state_new_query
152 |
153 | # support
154 | after4d_support = self.hpn_learner.forward_conv4d(corr_support)
155 | imr_x_support = torch.cat([support_cam, after4d_support, s_mid_feat, state_support], dim=1)
156 |
157 | imr_x_support_z = self.convz0(imr_x_support)
158 | imr_z_support1 = self.convz1(imr_x_support_z[:, :256])
159 | imr_z_support2 = self.convz2(imr_x_support_z[:, 256:])
160 | imr_z_support = torch.sigmoid(torch.cat([imr_z_support1, imr_z_support2], dim=1))
161 |
162 | imr_x_support_r = self.convr0(imr_x_support)
163 | imr_r_support1 = self.convr1(imr_x_support_r[:, :256])
164 | imr_r_support2 = self.convr2(imr_x_support_r[:, 256:])
165 | imr_r_support = torch.sigmoid(torch.cat([imr_r_support1, imr_r_support2], dim=1))
166 |
167 | imr_x_support_h = self.convh0(
168 | torch.cat([support_cam, after4d_support, s_mid_feat, imr_r_support * state_support], dim=1))
169 | imr_h_support1 = self.convh1(imr_x_support_h[:, :256])
170 | imr_h_support2 = self.convh2(imr_x_support_h[:, 256:])
171 | imr_h_support = torch.cat([imr_h_support1, imr_h_support2], dim=1)
172 |
173 | state_new_support = torch.tanh(imr_h_support)
174 | state_support = (1 - imr_z_support) * state_support + imr_z_support * state_new_support
175 |
176 | # decoder
177 | hypercorr_decoded_s = self.decoder1(state_support + after4d_support)
178 | upsample_size = (hypercorr_decoded_s.size(-1) * 2,) * 2
179 | hypercorr_decoded_s = F.interpolate(hypercorr_decoded_s, upsample_size, mode='bilinear', align_corners=True)
180 | logit_mask_support = self.decoder2(hypercorr_decoded_s)
181 |
182 | hypercorr_decoded_q = self.decoder1(state_query + after4d_query)
183 | upsample_size = (hypercorr_decoded_q.size(-1) * 2,) * 2
184 | hypercorr_decoded_q = F.interpolate(hypercorr_decoded_q, upsample_size, mode='bilinear', align_corners=True)
185 | logit_mask_query = self.decoder2(hypercorr_decoded_q)
186 |
187 | logit_mask_support = self.res(
188 | torch.cat(
189 | [logit_mask_support, F.interpolate(support_cam, (100, 100), mode='bilinear', align_corners=True)],
190 | dim=1))
191 | logit_mask_query = self.res(
192 | torch.cat([logit_mask_query, F.interpolate(query_cam, (100, 100), mode='bilinear', align_corners=True)],
193 | dim=1))
194 |
195 | # loss
196 | if query_mask is not None: # for training
197 | if not self.use_original_imgsize:
198 | logit_mask_query_temp = F.interpolate(logit_mask_query, support_img.size()[2:], mode='bilinear',
199 | align_corners=True)
200 | logit_mask_support_temp = F.interpolate(logit_mask_support, support_img.size()[2:], mode='bilinear',
201 | align_corners=True)
202 | loss_q_stage = self.compute_objective(logit_mask_query_temp, query_mask)
203 | loss_s_stage = self.compute_objective(logit_mask_support_temp, support_mask)
204 | losses = losses + loss_q_stage + loss_s_stage
205 |
206 | if ss != stage - 1:
207 | support_cam = logit_mask_support.softmax(dim=1)[:, 1]
208 | query_cam = logit_mask_query.softmax(dim=1)[:, 1]
209 | query_feats_masked = self.mask_feature(query_feats, query_cam)
210 | support_feats_masked = self.mask_feature(support_feats, support_cam)
211 | corr_query = Correlation.multilayer_correlation(query_feats, support_feats_masked, self.stack_ids)
212 | corr_support = Correlation.multilayer_correlation(support_feats, query_feats_masked, self.stack_ids)
213 |
214 | query_cam = F.interpolate(query_cam.unsqueeze(1), (50, 50), mode='bilinear', align_corners=True)
215 | support_cam = F.interpolate(support_cam.unsqueeze(1), (50, 50), mode='bilinear', align_corners=True)
216 |
217 | if query_mask is not None:
218 | return logit_mask_query_temp, logit_mask_support_temp, losses
219 | else:
220 | # test
221 | if not self.use_original_imgsize:
222 | logit_mask_query = F.interpolate(
223 | logit_mask_query, support_img.size()[2:], mode='bilinear', align_corners=True)
224 | logit_mask_support = F.interpolate(
225 | logit_mask_support, support_img.size()[2:], mode='bilinear', align_corners=True)
226 | return logit_mask_query, logit_mask_support
227 |
228 | def mask_feature(self, features, support_mask):
229 | for idx, feature in enumerate(features):
230 | mask = F.interpolate(
231 | support_mask.unsqueeze(1).float(), feature.size()[2:], mode='bilinear', align_corners=True)
232 | features[idx] = features[idx] * mask
233 | return features
234 |
235 | def predict_mask_nshot(self, batch, nshot, stage):
236 | # Perform multiple prediction given (nshot) number of different support sets
237 | logit_mask_agg = 0
238 | for s_idx in range(nshot):
239 | logit_mask, logit_mask_s = self(query_img=batch['query_img'],
240 | support_img=batch['support_imgs'][:, s_idx],
241 | support_cam=batch['support_cams'][:, s_idx],
242 | query_cam=batch['query_cam'], stage=stage)
243 | if self.use_original_imgsize:
244 | org_qry_imsize = tuple([batch['org_query_imsize'][1].item(), batch['org_query_imsize'][0].item()])
245 | logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True)
246 |
247 | logit_mask_agg += logit_mask.argmax(dim=1).clone()
248 | if nshot == 1:
249 | return logit_mask_agg
250 |
251 | # Average & quantize predictions given threshold (=0.5)
252 | bsz = logit_mask_agg.size(0)
253 | max_vote = logit_mask_agg.view(bsz, -1).max(dim=1)[0]
254 | max_vote = torch.stack([max_vote, torch.ones_like(max_vote).long()])
255 | max_vote = max_vote.max(dim=0)[0].view(bsz, 1, 1)
256 | pred_mask = logit_mask_agg.float() / max_vote
257 | pred_mask[pred_mask < 0.5] = 0
258 | pred_mask[pred_mask >= 0.5] = 1
259 |
260 | return pred_mask
261 |
262 | def compute_objective(self, logit_mask, gt_mask):
263 | bsz = logit_mask.size(0)
264 | logit_mask = logit_mask.view(bsz, 2, -1)
265 | gt_mask = gt_mask.view(bsz, -1).long()
266 | return self.cross_entropy_loss(logit_mask, gt_mask)
267 |
268 | def train_mode(self):
269 | self.train()
270 | self.backbone.eval() # to prevent BN from learning data statistics with exponential averaging
271 |
--------------------------------------------------------------------------------
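
In compact form, the `convz*`/`convr*`/`convh*` branches above implement a convolutional GRU over the hidden state $s$ (the trailing `1`/`2` splits are just grouped convolutions over channel halves). Writing $x = [\,\text{cam},\,\text{corr},\,\text{feat},\,s\,]$ for the concatenated input:

$$
z = \sigma\big(\mathrm{Conv}_z(x)\big),\qquad
r = \sigma\big(\mathrm{Conv}_r(x)\big),\qquad
\tilde{s} = \tanh\big(\mathrm{Conv}_h([\,\text{cam},\,\text{corr},\,\text{feat},\,r \odot s\,])\big),\qquad
s \leftarrow (1-z)\odot s + z\odot \tilde{s}
$$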
/model/learner.py:
--------------------------------------------------------------------------------
1 | import pdb
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .base.conv4d import CenterPivotConv4d as Conv4d
7 |
8 |
9 | class HPNLearner(nn.Module):
10 | def __init__(self, inch):
11 | super(HPNLearner, self).__init__()
12 |
13 | def make_building_block(in_channel, out_channels, kernel_sizes, spt_strides, group=4):
14 | assert len(out_channels) == len(kernel_sizes) == len(spt_strides)
15 |
16 | building_block_layers = []
17 | for idx, (outch, ksz, stride) in enumerate(zip(out_channels, kernel_sizes, spt_strides)):
18 | inch = in_channel if idx == 0 else out_channels[idx - 1]
19 | ksz4d = (ksz,) * 4
20 | str4d = (1, 1) + (stride,) * 2
21 | pad4d = (ksz // 2,) * 4
22 |
23 | building_block_layers.append(Conv4d(inch, outch, ksz4d, str4d, pad4d))
24 | building_block_layers.append(nn.GroupNorm(group, outch))
25 | building_block_layers.append(nn.ReLU(inplace=True))
26 |
27 | return nn.Sequential(*building_block_layers)
28 |
29 | outch1, outch2, outch3 = 16, 64, 128
30 |
31 | # Squeezing building blocks
32 | self.encoder_layer4 = make_building_block(inch[0], [outch1, outch2, outch3], [3, 3, 3], [2, 2, 2])
33 | self.encoder_layer3 = make_building_block(inch[1], [outch1, outch2, outch3], [5, 3, 3], [4, 2, 2])
34 | self.encoder_layer2 = make_building_block(inch[2], [outch1, outch2, outch3], [5, 5, 3], [4, 4, 2])
35 |
36 | # Mixing building blocks
37 | self.encoder_layer4to3 = make_building_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1])
38 | self.encoder_layer3to2 = make_building_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1])
39 |
40 | # Decoder layers
41 | # self.decoder1 = nn.Sequential(nn.Conv2d(outch3, outch3, (3, 3), padding=(1, 1), bias=True),
42 | # nn.ReLU(),
43 | # nn.Conv2d(outch3, outch2, (3, 3), padding=(1, 1), bias=True),
44 | # nn.ReLU())
45 | #
46 | # self.decoder2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True),
47 | # nn.ReLU(),
48 | # nn.Conv2d(outch2, 2, (3, 3), padding=(1, 1), bias=True))
49 | self.decoder1 = None
50 | self.decoder2 = None
51 | # Note: in the original HSNet code the decoder is defined here in the learner, but our code defines it
52 | # in hsnet_imr, so this decoder is unused; it is commented out to satisfy an open-source code review.
53 |
54 |
55 |
56 |
57 |
58 | def interpolate_support_dims(self, hypercorr, spatial_size=None):
59 | bsz, ch, ha, wa, hb, wb = hypercorr.size()
60 | hypercorr = hypercorr.permute(0, 4, 5, 1, 2, 3).contiguous().view(bsz * hb * wb, ch, ha, wa)
61 | hypercorr = F.interpolate(hypercorr, spatial_size, mode='bilinear', align_corners=True)
62 | o_hb, o_wb = spatial_size
63 | hypercorr = hypercorr.view(bsz, hb, wb, ch, o_hb, o_wb).permute(0, 3, 4, 5, 1, 2).contiguous()
64 | return hypercorr
65 |
66 | def forward_conv4d(self, hypercorr_pyramid):
67 | # Encode hypercorrelations from each layer (Squeezing building blocks)
68 | hypercorr_sqz4 = self.encoder_layer4(hypercorr_pyramid[0])
69 | hypercorr_sqz3 = self.encoder_layer3(hypercorr_pyramid[1])
70 | hypercorr_sqz2 = self.encoder_layer2(hypercorr_pyramid[2])
71 |
72 | # Propagate encoded 4D-tensor (Mixing building blocks)
73 | hypercorr_sqz4 = self.interpolate_support_dims(hypercorr_sqz4, hypercorr_sqz3.size()[-4:-2])
74 | hypercorr_mix43 = hypercorr_sqz4 + hypercorr_sqz3
75 | hypercorr_mix43 = self.encoder_layer4to3(hypercorr_mix43)
76 |
77 | hypercorr_mix43 = self.interpolate_support_dims(hypercorr_mix43, hypercorr_sqz2.size()[-4:-2])
78 | hypercorr_mix432 = hypercorr_mix43 + hypercorr_sqz2
79 | hypercorr_mix432 = self.encoder_layer3to2(hypercorr_mix432)
80 |
81 | bsz, ch, ha, wa, hb, wb = hypercorr_mix432.size()
82 | hypercorr_encoded = hypercorr_mix432.view(bsz, ch, ha, wa, -1).mean(dim=-1) # torch.Size([1, 128, 50, 50])
83 | return hypercorr_encoded
84 |
85 | def forward_decode(self, hypercorr_encoded):
86 | hypercorr_decoded = self.decoder1(hypercorr_encoded)
87 | upsample_size = (hypercorr_decoded.size(-1) * 2,) * 2
88 | hypercorr_decoded = F.interpolate(hypercorr_decoded, upsample_size, mode='bilinear', align_corners=True)
89 | logit_mask = self.decoder2(hypercorr_decoded)
90 | return logit_mask
91 |
92 | def forward(self, hypercorr_pyramid):
93 | hypercorr_encoded = self.forward_conv4d(hypercorr_pyramid)
94 | # vgg16: torch.Size([1, 128, 50, 50])
95 | # r50: torch.Size([1, 128, 50, 50])
96 | # pdb.set_trace()
97 | logit_mask = self.forward_decode(hypercorr_encoded)
98 | return logit_mask
99 |
100 | # def forward(self, hypercorr_pyramid):
101 | #
102 | # # Encode hypercorrelations from each layer (Squeezing building blocks)
103 | # hypercorr_sqz4 = self.encoder_layer4(hypercorr_pyramid[0])
104 | # hypercorr_sqz3 = self.encoder_layer3(hypercorr_pyramid[1])
105 | # hypercorr_sqz2 = self.encoder_layer2(hypercorr_pyramid[2])
106 | # pdb.set_trace()
107 | #
108 | # # Propagate encoded 4D-tensor (Mixing building blocks)
109 | # hypercorr_sqz4 = self.interpolate_support_dims(hypercorr_sqz4, hypercorr_sqz3.size()[-4:-2])
110 | # hypercorr_mix43 = hypercorr_sqz4 + hypercorr_sqz3
111 | # hypercorr_mix43 = self.encoder_layer4to3(hypercorr_mix43)
112 | #
113 | # hypercorr_mix43 = self.interpolate_support_dims(hypercorr_mix43, hypercorr_sqz2.size()[-4:-2])
114 | # hypercorr_mix432 = hypercorr_mix43 + hypercorr_sqz2
115 | # hypercorr_mix432 = self.encoder_layer3to2(hypercorr_mix432)
116 | #
117 | #
118 | # bsz, ch, ha, wa, hb, wb = hypercorr_mix432.size()
119 | # hypercorr_encoded = hypercorr_mix432.view(bsz, ch, ha, wa, -1).mean(dim=-1) #torch.Size([1, 128, 50, 50])
120 | #
121 | # # Decode the encoded 4D-tensor
122 | # hypercorr_decoded = self.decoder1(hypercorr_encoded)
123 | # upsample_size = (hypercorr_decoded.size(-1) * 2,) * 2
124 | # hypercorr_decoded = F.interpolate(hypercorr_decoded, upsample_size, mode='bilinear', align_corners=True)
125 | # logit_mask = self.decoder2(hypercorr_decoded)
126 | #
127 | # return logit_mask
128 |
--------------------------------------------------------------------------------
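
A shape sketch for `forward_conv4d` with the ResNet-50 configuration (`inch = [3, 6, 4]`, i.e. `reversed(nbottlenecks[-3:])`); the spatial sizes mimic a 400×400 input:

```python
import torch
from model.learner import HPNLearner

learner = HPNLearner([3, 6, 4])                # ResNet-50: reversed(nbottlenecks[-3:])
pyramid = [torch.randn(1, 3, 13, 13, 13, 13),  # corr_l4
           torch.randn(1, 6, 25, 25, 25, 25),  # corr_l3
           torch.randn(1, 4, 50, 50, 50, 50)]  # corr_l2
feat = learner.forward_conv4d(pyramid)
print(feat.shape)                              # torch.Size([1, 128, 50, 50])
```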
/pytorch_grad_cam/__init__.py:
--------------------------------------------------------------------------------
1 | from pytorch_grad_cam.grad_cam import GradCAM
2 | from pytorch_grad_cam.ablation_cam import AblationCAM
3 | from pytorch_grad_cam.xgrad_cam import XGradCAM
4 | from pytorch_grad_cam.grad_cam_plusplus import GradCAMPlusPlus
5 | from pytorch_grad_cam.score_cam import ScoreCAM
6 | from pytorch_grad_cam.layer_cam import LayerCAM
7 | from pytorch_grad_cam.eigen_cam import EigenCAM
8 | from pytorch_grad_cam.eigen_grad_cam import EigenGradCAM
9 | from pytorch_grad_cam.fullgrad_cam import FullGrad
10 | from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel
11 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/ablation_cam.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/ablation_cam.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/base_cam.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/base_cam.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/eigen_cam.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/eigen_cam.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/grad_cam.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/grad_cam.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/guided_backprop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/guided_backprop.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/layer_cam.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/layer_cam.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/score_cam.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/score_cam.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/ablation_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import tqdm
4 | from pytorch_grad_cam.base_cam import BaseCAM
5 | from pytorch_grad_cam.utils.find_layers import replace_layer_recursive
6 |
7 |
8 | class AblationLayer(torch.nn.Module):
9 | def __init__(self, layer, reshape_transform, indices):
10 | super(AblationLayer, self).__init__()
11 |
12 | self.layer = layer
13 | self.reshape_transform = reshape_transform
14 | # The channels to zero out:
15 | self.indices = indices
16 |
17 | def forward(self, x):
18 | return self.__call__(x)
19 |
20 | def __call__(self, x):
21 | output = self.layer(x)
22 |
23 | # Hack to work with ViT,
24 | # Since the activation channels are last and not first like in CNNs
25 | # Probably should remove it?
26 | if self.reshape_transform is not None:
27 | output = output.transpose(1, 2)
28 |
29 | for i in range(output.size(0)):
30 |
31 | # Commonly the minimum activation will be 0,
32 | # And then it makes sense to zero it out.
33 | # However depending on the architecture,
34 | # If the values can be negative, we use very negative values
35 | # to perform the ablation, deviating from the paper.
36 | if torch.min(output) == 0:
37 | output[i, self.indices[i], :] = 0
38 | else:
39 | ABLATION_VALUE = 1e5
40 | output[i, self.indices[i], :] = torch.min(
41 | output) - ABLATION_VALUE
42 |
43 | if self.reshape_transform is not None:
44 | output = output.transpose(2, 1)
45 |
46 | return output
47 |
48 |
49 | class AblationCAM(BaseCAM):
50 | def __init__(self,
51 | model,
52 | target_layers,
53 | use_cuda=False,
54 | reshape_transform=None):
55 | super(AblationCAM, self).__init__(model, target_layers, use_cuda,
56 | reshape_transform)
57 |
58 | def get_cam_weights(self,
59 | input_tensor,
60 | target_layer,
61 | target_category,
62 | activations,
63 | grads):
64 | with torch.no_grad():
65 | outputs = self.model(input_tensor).cpu().numpy()
66 | original_scores = []
67 | for i in range(input_tensor.size(0)):
68 | original_scores.append(outputs[i, target_category[i]])
69 | original_scores = np.float32(original_scores)
70 |
71 | ablation_layer = AblationLayer(target_layer,
72 | self.reshape_transform,
73 | indices=[])
74 | replace_layer_recursive(self.model, target_layer, ablation_layer)
75 |
76 | if hasattr(self, "batch_size"):
77 | BATCH_SIZE = self.batch_size
78 | else:
79 | BATCH_SIZE = 32
80 |
81 | number_of_channels = activations.shape[1]
82 | weights = []
83 |
84 | with torch.no_grad():
85 | # Iterate over the input batch
86 | for tensor, category in zip(input_tensor, target_category):
87 | batch_tensor = tensor.repeat(BATCH_SIZE, 1, 1, 1)
88 | for i in tqdm.tqdm(range(0, number_of_channels, BATCH_SIZE)):
89 | ablation_layer.indices = list(range(i, i + BATCH_SIZE))
90 |
91 | if i + BATCH_SIZE > number_of_channels:
92 | keep = number_of_channels - i
93 | batch_tensor = batch_tensor[:keep]
94 | ablation_layer.indices = ablation_layer.indices[:keep]
95 | score = self.model(batch_tensor)[:, category].cpu().numpy()
96 | weights.extend(score)
97 |
98 | weights = np.float32(weights)
99 | weights = weights.reshape(activations.shape[:2])
100 | original_scores = original_scores[:, None]
101 | weights = (original_scores - weights) / original_scores
102 |
103 | # Replace the model back to the original state
104 | replace_layer_recursive(self.model, ablation_layer, target_layer)
105 | return weights
106 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/activations_and_gradients.py:
--------------------------------------------------------------------------------
1 | class ActivationsAndGradients:
2 | """ Class for extracting activations and
3 | registering gradients from targeted intermediate layers """
4 |
5 | def __init__(self, model, target_layers, reshape_transform):
6 | self.model = model
7 | self.gradients = []
8 | self.activations = []
9 | self.reshape_transform = reshape_transform
10 | self.handles = []
11 | for target_layer in target_layers:
12 | self.handles.append(
13 | target_layer.register_forward_hook(
14 | self.save_activation))
15 | # Backward compatibility with older pytorch versions:
16 | if hasattr(target_layer, 'register_full_backward_hook'):
17 | self.handles.append(
18 | target_layer.register_full_backward_hook(
19 | self.save_gradient))
20 | else:
21 | self.handles.append(
22 | target_layer.register_backward_hook(
23 | self.save_gradient))
24 |
25 | def save_activation(self, module, input, output):
26 | activation = output
27 | if self.reshape_transform is not None:
28 | activation = self.reshape_transform(activation)
29 | self.activations.append(activation.cpu().detach())
30 |
31 | def save_gradient(self, module, grad_input, grad_output):
32 | # Gradients are computed in reverse order
33 | grad = grad_output[0]
34 | if self.reshape_transform is not None:
35 | grad = self.reshape_transform(grad)
36 | self.gradients = [grad.cpu().detach()] + self.gradients
37 |
38 | def __call__(self, x):
39 | self.gradients = []
40 | self.activations = []
41 | return self.model(x)
42 |
43 | def release(self):
44 | for handle in self.handles:
45 | handle.remove()
46 |
--------------------------------------------------------------------------------
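
A standalone sketch of how the hook wrapper is used (any CNN works; a torchvision ResNet is assumed here purely for illustration):

```python
import torch
from torchvision.models import resnet50
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients

model = resnet50(pretrained=False).eval()
ag = ActivationsAndGradients(model, [model.layer4[-1]], reshape_transform=None)

out = ag(torch.randn(1, 3, 224, 224))  # forward pass populates ag.activations
out[:, 0].sum().backward()             # backward pass populates ag.gradients
print(ag.activations[0].shape, ag.gradients[0].shape)  # both (1, 2048, 7, 7)
ag.release()                           # remove the hooks when done
```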
/pytorch_grad_cam/base_cam.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import ttach as tta
5 | from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
6 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
7 |
8 |
9 | class BaseCAM:
10 | def __init__(self,
11 | model,
12 | target_layers,
13 | use_cuda=False,
14 | reshape_transform=None,
15 | compute_input_gradient=False,
16 | uses_gradients=True):
17 | self.model = model.eval()
18 | self.target_layers = target_layers
19 | self.cuda = use_cuda
20 | if self.cuda:
21 | self.model = model.cuda()
22 | self.reshape_transform = reshape_transform
23 | self.compute_input_gradient = compute_input_gradient
24 | self.uses_gradients = uses_gradients
25 | self.activations_and_grads = ActivationsAndGradients(
26 | self.model, target_layers, reshape_transform)
27 |
28 | """ Get a vector of weights for every channel in the target layer.
29 | Methods that compute per-channel weights
30 | typically only need to implement this function. """
31 |
32 | def get_cam_weights(self,
33 | input_tensor,
34 | target_layers,
35 | target_category,
36 | activations,
37 | grads):
38 | raise Exception("Not Implemented")
39 |
40 | def get_loss(self, output, target_category):
41 | loss = 0
42 | for i in range(len(target_category)):
43 | loss = loss + output[i, target_category[i]]
44 | return loss
45 |
46 | def get_cam_image(self,
47 | input_tensor,
48 | target_layer,
49 | target_category,
50 | activations,
51 | grads,
52 | eigen_smooth=False):
53 | weights = self.get_cam_weights(input_tensor, target_layer,
54 | target_category, activations, grads)
55 |
56 | weighted_activations = weights[:, :, None, None] * activations
57 |
58 | # activations are identical regardless of n
59 | if eigen_smooth:
60 | cam = get_2d_projection(weighted_activations)
61 | else:
62 | cam = weighted_activations.sum(axis=1)
63 | # pdb.set_trace()
64 | return cam
65 |
66 | def forward(self, input_tensor, target_category=None, eigen_smooth=False):
67 |
68 | if self.cuda:
69 | input_tensor = input_tensor.cuda()
70 |
71 | if self.compute_input_gradient: # False
72 | input_tensor = torch.autograd.Variable(input_tensor,
73 | requires_grad=True)
74 |
75 | output = self.activations_and_grads(input_tensor) # tensor.shape 1,N
76 |
77 | if isinstance(target_category, int):
78 | target_category = [target_category] * input_tensor.size(0)
79 |
80 | if target_category is None:
81 | target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
82 | else:
83 | assert (len(target_category) == input_tensor.size(0))
84 |
85 | if self.uses_gradients: # True
86 | self.model.zero_grad()
87 |
88 | loss = self.get_loss(output, target_category)
89 | loss.backward(retain_graph=True)
90 |
91 | # In most of the saliency attribution papers, the saliency is
92 | # computed with a single target layer.
93 | # Commonly it is the last convolutional layer.
94 | # Here we support passing a list with multiple target layers.
95 | # It will compute the saliency image for every image,
96 | # and then aggregate them (with a default mean aggregation).
97 | # This gives you more flexibility in case you just want to
98 | # use all conv layers for example, all Batchnorm layers,
99 | # or something else.
100 |
101 | cam_per_layer = self.compute_cam_per_layer(input_tensor,
102 | target_category,
103 | eigen_smooth)
104 | return self.aggregate_multi_layers(cam_per_layer)
105 |
106 | def get_target_width_height(self, input_tensor):
107 | width, height = input_tensor.size(-1), input_tensor.size(-2)
108 | return width, height
109 |
110 | def compute_cam_per_layer(
111 | self,
112 | input_tensor,
113 | target_category,
114 | eigen_smooth):
115 | activations_list = [a.cpu().data.numpy()
116 | for a in self.activations_and_grads.activations]
117 | grads_list = [g.cpu().data.numpy()
118 | for g in self.activations_and_grads.gradients]
119 | target_size = self.get_target_width_height(input_tensor)
120 |
121 | cam_per_target_layer = []
122 | # Loop over the saliency image from every layer
123 |
124 | for target_layer, layer_activations, layer_grads in \
125 | zip(self.target_layers, activations_list, grads_list):
126 | cam = self.get_cam_image(input_tensor,
127 | target_layer,
128 | target_category,
129 | layer_activations,
130 | layer_grads,
131 | eigen_smooth)
132 | # pdb.set_trace()
133 | cam[cam < 0] = 0  # clamp negatives to zero; this effectively neutralizes the min-shift in scale_cam_image
134 | scaled = self.scale_cam_image(cam, target_size)
135 | cam_per_target_layer.append(scaled[:, None, :])
136 |
137 | return cam_per_target_layer
138 |
139 | def aggregate_multi_layers(self, cam_per_target_layer):
140 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
141 | cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
142 | result = np.mean(cam_per_target_layer, axis=1)
143 | return self.scale_cam_image(result)
144 |
145 | def scale_cam_image(self, cam, target_size=None):
146 | result = []
147 | for img in cam:
148 |
149 | # WHH ADD
150 | # pdb.set_trace()
151 |
152 | img = img - np.min(img)
153 | img = img / (1e-7 + np.max(img))
154 | if target_size is not None:
155 | img = np.float32(img)  # cast to float32 before resizing, otherwise cv2.resize raises an error
156 | img = cv2.resize(img, target_size)
157 | result.append(img)
158 | result = np.float32(result)
159 |
160 | return result
161 |
162 | def forward_augmentation_smoothing(self,
163 | input_tensor,
164 | target_category=None,
165 | eigen_smooth=False):
166 | transforms = tta.Compose(
167 | [
168 | tta.HorizontalFlip(),
169 | tta.Multiply(factors=[0.9, 1, 1.1]),
170 | ]
171 | )
172 | cams = []
173 | for transform in transforms:
174 | augmented_tensor = transform.augment_image(input_tensor)
175 | cam = self.forward(augmented_tensor,
176 | target_category, eigen_smooth)
177 |
178 | # The ttach library expects a tensor of size BxCxHxW
179 | cam = cam[:, None, :, :]
180 | cam = torch.from_numpy(cam)
181 | cam = transform.deaugment_mask(cam)
182 |
183 | # Back to numpy float32, HxW
184 | cam = cam.numpy()
185 | cam = cam[:, 0, :, :]
186 | cams.append(cam)
187 |
188 | cam = np.mean(np.float32(cams), axis=0)
189 | return cam
190 |
191 | def __call__(self,
192 | input_tensor,
193 | target_category=None,
194 | aug_smooth=False,
195 | eigen_smooth=False):
196 |
197 | # Smooth the CAM result with test time augmentation
198 | # print('call, !=forward', aug_smooth) # aug smooth = False
199 | if aug_smooth is True:
200 | return self.forward_augmentation_smoothing(
201 | input_tensor, target_category, eigen_smooth)
202 |
203 | return self.forward(input_tensor,
204 | target_category, eigen_smooth)
205 |
206 | def __del__(self):
207 | self.activations_and_grads.release()
208 |
209 | def __enter__(self):
210 | return self
211 |
212 | def __exit__(self, exc_type, exc_value, exc_tb):
213 | self.activations_and_grads.release()
214 | if isinstance(exc_value, IndexError):
215 | # Handle IndexError here...
216 | print(
217 | f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
218 | return True
219 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/eigen_cam.py:
--------------------------------------------------------------------------------
1 | from pytorch_grad_cam.base_cam import BaseCAM
2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
3 |
4 | # https://arxiv.org/abs/2008.00299
5 |
6 |
7 | class EigenCAM(BaseCAM):
8 | def __init__(self, model, target_layers, use_cuda=False,
9 | reshape_transform=None):
10 | super(EigenCAM, self).__init__(model, target_layers, use_cuda,
11 | reshape_transform)
12 |
13 | def get_cam_image(self,
14 | input_tensor,
15 | target_layer,
16 | target_category,
17 | activations,
18 | grads,
19 | eigen_smooth):
20 | return get_2d_projection(activations)
21 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/eigen_grad_cam.py:
--------------------------------------------------------------------------------
1 | from pytorch_grad_cam.base_cam import BaseCAM
2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
3 |
4 | # Like Eigen CAM: https://arxiv.org/abs/2008.00299
5 | # But multiply the activations x gradients
6 |
7 |
8 | class EigenGradCAM(BaseCAM):
9 | def __init__(self, model, target_layers, use_cuda=False,
10 | reshape_transform=None):
11 | super(EigenGradCAM, self).__init__(model, target_layers, use_cuda,
12 | reshape_transform)
13 |
14 | def get_cam_image(self,
15 | input_tensor,
16 | target_layer,
17 | target_category,
18 | activations,
19 | grads,
20 | eigen_smooth):
21 | return get_2d_projection(grads * activations)
22 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/fullgrad_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from pytorch_grad_cam.base_cam import BaseCAM
4 | from pytorch_grad_cam.utils.find_layers import find_layer_predicate_recursive
5 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
6 |
7 | # https://arxiv.org/abs/1905.00780
8 |
9 |
10 | class FullGrad(BaseCAM):
11 | def __init__(self, model, target_layers, use_cuda=False,
12 | reshape_transform=None):
13 | if len(target_layers) > 0:
14 | print(
15 | "Warning: target_layers is ignored in FullGrad. All bias layers will be used instead")
16 |
17 | def layer_with_2D_bias(layer):
18 | bias_target_layers = [torch.nn.Conv2d, torch.nn.BatchNorm2d]
19 | if type(layer) in bias_target_layers and layer.bias is not None:
20 | return True
21 | return False
22 | target_layers = find_layer_predicate_recursive(
23 | model, layer_with_2D_bias)
24 | super(
25 | FullGrad,
26 | self).__init__(
27 | model,
28 | target_layers,
29 | use_cuda,
30 | reshape_transform,
31 | compute_input_gradient=True)
32 | self.bias_data = [self.get_bias_data(
33 | layer).cpu().numpy() for layer in target_layers]
34 |
35 | def get_bias_data(self, layer):
36 | # Borrowed from official paper impl:
37 | # https://github.com/idiap/fullgrad-saliency/blob/master/saliency/tensor_extractor.py#L47
38 | if isinstance(layer, torch.nn.BatchNorm2d):
39 | bias = - (layer.running_mean * layer.weight
40 | / torch.sqrt(layer.running_var + layer.eps)) + layer.bias
41 | return bias.data
42 | else:
43 | return layer.bias.data
44 |
45 | def scale_accross_batch_and_channels(self, tensor, target_size):
46 | batch_size, channel_size = tensor.shape[:2]
47 | reshaped_tensor = tensor.reshape(
48 | batch_size * channel_size, *tensor.shape[2:])
49 | result = self.scale_cam_image(reshaped_tensor, target_size)
50 | result = result.reshape(
51 | batch_size,
52 | channel_size,
53 | target_size[1],
54 | target_size[0])
55 | return result
56 |
57 | def compute_cam_per_layer(
58 | self,
59 | input_tensor,
60 | target_category,
61 | eigen_smooth):
62 | input_grad = input_tensor.grad.data.cpu().numpy()
63 | grads_list = [g.cpu().data.numpy() for g in
64 | self.activations_and_grads.gradients]
65 | cam_per_target_layer = []
66 | target_size = self.get_target_width_height(input_tensor)
67 |
68 | gradient_multiplied_input = input_grad * input_tensor.data.cpu().numpy()
69 | gradient_multiplied_input = np.abs(gradient_multiplied_input)
70 | gradient_multiplied_input = self.scale_accross_batch_and_channels(
71 | gradient_multiplied_input,
72 | target_size)
73 | cam_per_target_layer.append(gradient_multiplied_input)
74 |
75 | # Loop over the saliency image from every layer
76 | assert(len(self.bias_data) == len(grads_list))
77 | for bias, grads in zip(self.bias_data, grads_list):
78 | bias = bias[None, :, None, None]
79 | # In the paper they take the absolute value,
80 | # but possibly taking only the positive gradients will work
81 | # better.
82 | bias_grad = np.abs(bias * grads)
83 | result = self.scale_accross_batch_and_channels(
84 | bias_grad, target_size)
85 | result = np.sum(result, axis=1)
86 | cam_per_target_layer.append(result[:, None, :])
87 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
88 | if eigen_smooth:
89 | # Resize to a smaller image first, since this method typically produces a very large number of channels,
90 | # which would otherwise consume a lot of memory
91 | cam_per_target_layer = self.scale_accross_batch_and_channels(
92 | cam_per_target_layer, (target_size[0] // 8, target_size[1] // 8))
93 | cam_per_target_layer = get_2d_projection(cam_per_target_layer)
94 | cam_per_target_layer = cam_per_target_layer[:, None, :, :]
95 | cam_per_target_layer = self.scale_accross_batch_and_channels(
96 | cam_per_target_layer,
97 | target_size)
98 | else:
99 | cam_per_target_layer = np.sum(
100 | cam_per_target_layer, axis=1)[:, None, :]
101 |
102 | return cam_per_target_layer
103 |
104 | def aggregate_multi_layers(self, cam_per_target_layer):
105 | result = np.sum(cam_per_target_layer, axis=1)
106 | return self.scale_cam_image(result)
107 |
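A quick standalone sanity check (not part of the repo) of the BatchNorm bias folding in `get_bias_data`: in eval mode, a BatchNorm2d applied to a zero input returns exactly the folded bias β − γ·μ / √(σ² + ε):

```python
import torch

bn = torch.nn.BatchNorm2d(4).eval()
bn.running_mean.uniform_(-1, 1)   # pretend the BN has seen data
bn.running_var.uniform_(0.5, 2.0)

# Same expression as FullGrad.get_bias_data for BatchNorm2d layers.
bias_eff = -(bn.running_mean * bn.weight
             / torch.sqrt(bn.running_var + bn.eps)) + bn.bias

zero_out = bn(torch.zeros(1, 4, 2, 2))[0, :, 0, 0]
assert torch.allclose(bias_eff, zero_out, atol=1e-6)
```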
--------------------------------------------------------------------------------
/pytorch_grad_cam/grad_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_grad_cam.base_cam import BaseCAM
3 |
4 |
5 |
6 | class GradCAM(BaseCAM):
7 | def __init__(self, model, target_layers, use_cuda=False,
8 | reshape_transform=None):
9 | super(
10 | GradCAM,
11 | self).__init__(
12 | model,
13 | target_layers,
14 | use_cuda,
15 | reshape_transform)
16 |
17 | def get_cam_weights(self,
18 | input_tensor,
19 | target_layer,
20 | target_category,
21 | activations,
22 | grads):
23 | return np.mean(grads, axis=(2, 3))
24 |
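The weights returned above are the channel-wise global-average-pooled gradients; BaseCAM then combines them with the activations and applies a ReLU. A toy NumPy illustration of that combination step, with random stand-in tensors:

```python
import numpy as np

activations = np.random.rand(1, 8, 7, 7).astype(np.float32)  # (B, K, H, W)
grads = np.random.randn(1, 8, 7, 7).astype(np.float32)

weights = np.mean(grads, axis=(2, 3))  # get_cam_weights: w_k = GAP(dY/dA_k)
cam = np.maximum((weights[:, :, None, None] * activations).sum(axis=1), 0)
print(cam.shape)  # (1, 7, 7)
```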
--------------------------------------------------------------------------------
/pytorch_grad_cam/grad_cam_plusplus.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_grad_cam.base_cam import BaseCAM
3 |
4 | # https://arxiv.org/abs/1710.11063
5 |
6 |
7 | class GradCAMPlusPlus(BaseCAM):
8 | def __init__(self, model, target_layers, use_cuda=False,
9 | reshape_transform=None):
10 | super(GradCAMPlusPlus, self).__init__(model, target_layers, use_cuda,
11 | reshape_transform)
12 |
13 | def get_cam_weights(self,
14 | input_tensor,
15 | target_layer,
16 | target_category,
17 | activations,
18 | grads):
19 | grads_power_2 = grads**2
20 | grads_power_3 = grads_power_2 * grads
21 | # Equation 19 in https://arxiv.org/abs/1710.11063
22 | sum_activations = np.sum(activations, axis=(2, 3))
23 | eps = 0.000001
24 | aij = grads_power_2 / (2 * grads_power_2 +
25 | sum_activations[:, :, None, None] * grads_power_3 + eps)
26 | # Now bring back the ReLU from eq.7 in the paper,
27 | # And zero out aijs where the activations are 0
28 | aij = np.where(grads != 0, aij, 0)
29 |
30 | weights = np.maximum(grads, 0) * aij
31 | weights = np.sum(weights, axis=(2, 3))
32 | return weights
33 |
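In the notation of the Grad-CAM++ paper, the `aij` tensor above is the closed-form pixel weighting from Equation 19, and `weights` is its ReLU-gated spatial sum:

```latex
\alpha_{ij}^{kc} =
  \frac{\left(\frac{\partial Y^c}{\partial A_{ij}^k}\right)^{2}}
       {2\left(\frac{\partial Y^c}{\partial A_{ij}^k}\right)^{2}
        + \sum_{a,b} A_{ab}^k \left(\frac{\partial Y^c}{\partial A_{ij}^k}\right)^{3}},
\qquad
w_k^c = \sum_{i,j} \alpha_{ij}^{kc}\,
        \operatorname{ReLU}\!\left(\frac{\partial Y^c}{\partial A_{ij}^k}\right)
```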
--------------------------------------------------------------------------------
/pytorch_grad_cam/guided_backprop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Function
4 | from pytorch_grad_cam.utils.find_layers import replace_all_layer_type_recursive
5 |
6 |
7 | class GuidedBackpropReLU(Function):
8 | @staticmethod
9 | def forward(self, input_img):
10 | positive_mask = (input_img > 0).type_as(input_img)
11 | output = torch.addcmul(
12 | torch.zeros(
13 | input_img.size()).type_as(input_img),
14 | input_img,
15 | positive_mask)
16 | self.save_for_backward(input_img, output)
17 | return output
18 |
19 | @staticmethod
20 | def backward(self, grad_output):
21 | input_img, output = self.saved_tensors
22 | grad_input = None
23 |
24 | positive_mask_1 = (input_img > 0).type_as(grad_output)
25 | positive_mask_2 = (grad_output > 0).type_as(grad_output)
26 | grad_input = torch.addcmul(
27 | torch.zeros(
28 | input_img.size()).type_as(input_img),
29 | torch.addcmul(
30 | torch.zeros(
31 | input_img.size()).type_as(input_img),
32 | grad_output,
33 | positive_mask_1),
34 | positive_mask_2)
35 | return grad_input
36 |
37 |
38 | class GuidedBackpropReLUasModule(torch.nn.Module):
39 | def __init__(self):
40 | super(GuidedBackpropReLUasModule, self).__init__()
41 |
42 | def forward(self, input_img):
43 | return GuidedBackpropReLU.apply(input_img)
44 |
45 |
46 | class GuidedBackpropReLUModel:
47 | def __init__(self, model, use_cuda):
48 | self.model = model
49 | self.model.eval()
50 | self.cuda = use_cuda
51 | if self.cuda:
52 | self.model = self.model.cuda()
53 |
54 | def forward(self, input_img):
55 | return self.model(input_img)
56 |
57 | def recursive_replace_relu_with_guidedrelu(self, module_top):
58 |
59 | for idx, module in module_top._modules.items():
60 | self.recursive_replace_relu_with_guidedrelu(module)
61 | if module.__class__.__name__ == 'ReLU':
62 | module_top._modules[idx] = GuidedBackpropReLU.apply
63 |
64 |
65 | def recursive_replace_guidedrelu_with_relu(self, module_top):
66 | # noinspection PyBroadException
67 | try:
68 | for idx, module in module_top._modules.items():
69 | self.recursive_replace_guidedrelu_with_relu(module)
70 | if module == GuidedBackpropReLU.apply:
71 | module_top._modules[idx] = torch.nn.ReLU()
72 | except BaseException:
73 | pass
74 |
75 | def __call__(self, input_img, target_category=None):
76 | replace_all_layer_type_recursive(self.model,
77 | torch.nn.ReLU,
78 | GuidedBackpropReLUasModule())
79 |
80 | if self.cuda:
81 | input_img = input_img.cuda()
82 |
83 | input_img = input_img.requires_grad_(True)
84 |
85 | output = self.forward(input_img)
86 |
87 | if target_category is None:
88 | target_category = np.argmax(output.cpu().data.numpy())
89 |
90 | loss = output[0, target_category]
91 | loss.backward(retain_graph=True)
92 |
93 | output = input_img.grad.cpu().data.numpy()
94 | output = output[0, :, :, :]
95 | output = output.transpose((1, 2, 0))
96 |
97 | replace_all_layer_type_recursive(self.model,
98 | GuidedBackpropReLUasModule,
99 | torch.nn.ReLU())
100 |
101 | return output
102 |
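A minimal sketch of the model's call interface; the classifier, random input, and target class (281) are illustrative placeholders:

```python
import torch
from torchvision.models import resnet50
from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel

model = resnet50(pretrained=True)
gb_model = GuidedBackpropReLUModel(model=model, use_cuda=False)

input_tensor = torch.randn(1, 3, 224, 224)
gb = gb_model(input_tensor, target_category=281)
print(gb.shape)  # (224, 224, 3) gradient image; see deprocess_image in utils
```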
--------------------------------------------------------------------------------
/pytorch_grad_cam/layer_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_grad_cam.base_cam import BaseCAM
3 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
4 |
5 | # https://ieeexplore.ieee.org/document/9462463
6 |
7 |
8 | class LayerCAM(BaseCAM):
9 | def __init__(
10 | self,
11 | model,
12 | target_layers,
13 | use_cuda=False,
14 | reshape_transform=None):
15 | super(
16 | LayerCAM,
17 | self).__init__(
18 | model,
19 | target_layers,
20 | use_cuda,
21 | reshape_transform)
22 |
23 | def get_cam_image(self,
24 | input_tensor,
25 | target_layer,
26 | target_category,
27 | activations,
28 | grads,
29 | eigen_smooth):
30 | spatial_weighted_activations = np.maximum(grads, 0) * activations
31 |
32 | if eigen_smooth:
33 | cam = get_2d_projection(spatial_weighted_activations)
34 | else:
35 | cam = spatial_weighted_activations.sum(axis=1)
36 | return cam
37 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/score_cam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 | from pytorch_grad_cam.base_cam import BaseCAM
4 |
5 |
6 | class ScoreCAM(BaseCAM):
7 | def __init__(
8 | self,
9 | model,
10 | target_layers,
11 | use_cuda=False,
12 | reshape_transform=None):
13 | super(ScoreCAM, self).__init__(model, target_layers, use_cuda,
14 | reshape_transform=reshape_transform)
15 |
16 | if len(target_layers) > 0:
17 | print("Warning: You are using ScoreCAM with target layers, "
18 | "however ScoreCAM will ignore them.")
19 |
20 | def get_cam_weights(self,
21 | input_tensor,
22 | target_layer,
23 | target_category,
24 | activations,
25 | grads):
26 | with torch.no_grad():
27 | upsample = torch.nn.UpsamplingBilinear2d(
28 | size=input_tensor.shape[-2:])
29 | activation_tensor = torch.from_numpy(activations)
30 | if self.cuda:
31 | activation_tensor = activation_tensor.cuda()
32 |
33 | upsampled = upsample(activation_tensor)
34 |
35 | maxs = upsampled.view(upsampled.size(0),
36 | upsampled.size(1), -1).max(dim=-1)[0]
37 | mins = upsampled.view(upsampled.size(0),
38 | upsampled.size(1), -1).min(dim=-1)[0]
39 | maxs, mins = maxs[:, :, None, None], mins[:, :, None, None]
40 | upsampled = (upsampled - mins) / (maxs - mins)
41 |
42 | input_tensors = input_tensor[:, None,
43 | :, :] * upsampled[:, :, None, :, :]
44 |
45 | if hasattr(self, "batch_size"):
46 | BATCH_SIZE = self.batch_size
47 | else:
48 | BATCH_SIZE = 16
49 |
50 | scores = []
51 | for batch_index, tensor in enumerate(input_tensors):
52 | category = target_category[batch_index]
53 | for i in tqdm.tqdm(range(0, tensor.size(0), BATCH_SIZE)):
54 | batch = tensor[i: i + BATCH_SIZE, :]
55 | outputs = self.model(batch).cpu().numpy()[:, category]
56 | scores.extend(outputs)
57 | scores = torch.Tensor(scores)
58 | scores = scores.view(activations.shape[0], activations.shape[1])
59 |
60 | weights = torch.nn.Softmax(dim=-1)(scores).numpy()
61 | return weights
62 |
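ScoreCAM is forward-pass heavy: every activation channel is upsampled, used to mask the input, and scored with its own forward pass. The optional `batch_size` attribute (the `hasattr` check above) controls how many masked inputs go through the model at once. A sketch, assuming the usual upstream BaseCAM call interface; model and input are stand-ins:

```python
import torch
from torchvision.models import resnet50
from pytorch_grad_cam.score_cam import ScoreCAM

model = resnet50(pretrained=True).eval()
cam = ScoreCAM(model=model, target_layers=[model.layer4[-1]])
cam.batch_size = 32  # chunk size for the per-channel forward passes (default 16)

grayscale_cam = cam(input_tensor=torch.randn(1, 3, 224, 224))
```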
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from pytorch_grad_cam.utils.image import deprocess_image
2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
3 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/__pycache__/find_layers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/utils/__pycache__/find_layers.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/__pycache__/image.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/utils/__pycache__/image.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/__pycache__/svd_on_activations.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/utils/__pycache__/svd_on_activations.cpython-38.pyc
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/find_layers.py:
--------------------------------------------------------------------------------
1 | def replace_layer_recursive(model, old_layer, new_layer):
2 | for name, layer in model._modules.items():
3 | if layer == old_layer:
4 | model._modules[name] = new_layer
5 | return True
6 | elif replace_layer_recursive(layer, old_layer, new_layer):
7 | return True
8 | return False
9 |
10 |
11 | def replace_all_layer_type_recursive(model, old_layer_type, new_layer):
12 | for name, layer in model._modules.items():
13 | if isinstance(layer, old_layer_type):
14 | model._modules[name] = new_layer
15 | replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
16 |
17 |
18 | def find_layer_types_recursive(model, layer_types):
19 | def predicate(layer):
20 | return type(layer) in layer_types
21 | return find_layer_predicate_recursive(model, predicate)
22 |
23 |
24 | def find_layer_predicate_recursive(model, predicate):
25 | result = []
26 | for name, layer in model._modules.items():
27 | if predicate(layer):
28 | result.append(layer)
29 | result.extend(find_layer_predicate_recursive(layer, predicate))
30 | return result
31 |
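A small sketch of these helpers on a hypothetical torchvision network; `find_layer_predicate_recursive` is the same routine FullGrad uses to collect its bias layers:

```python
import torch
from torchvision.models import resnet18
from pytorch_grad_cam.utils.find_layers import (
    find_layer_types_recursive, replace_all_layer_type_recursive)

model = resnet18()

# Collect every BatchNorm2d, at any nesting depth.
bn_layers = find_layer_types_recursive(model, [torch.nn.BatchNorm2d])
print(len(bn_layers))

# Replace every ReLU in place; note that a single module instance is
# shared across all replacement sites.
replace_all_layer_type_recursive(model, torch.nn.ReLU, torch.nn.Identity())
```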
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/image.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from torchvision.transforms import Compose, Normalize, ToTensor
5 |
6 |
7 | def preprocess_image(img: np.ndarray, mean=None, std=None) -> torch.Tensor:
8 | if mean is None:
9 | mean = [0.5, 0.5, 0.5]
10 | if std is None:
11 | std = [0.5, 0.5, 0.5]
12 | preprocessing = Compose([
13 | ToTensor(),
14 | Normalize(mean=mean, std=std)
15 | ])
16 | return preprocessing(img.copy()).unsqueeze(0)
17 |
18 |
19 | def deprocess_image(img):
20 | """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """
21 | img = img - np.mean(img)
22 | img = img / (np.std(img) + 1e-5)
23 | img = img * 0.1
24 | img = img + 0.5
25 | img = np.clip(img, 0, 1)
26 | return np.uint8(img * 255)
27 |
28 |
29 | def show_cam_on_image(img: np.ndarray,
30 | mask: np.ndarray,
31 | use_rgb: bool = False,
32 | colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
33 | """ This function overlays the cam mask on the image as an heatmap.
34 | By default the heatmap is in BGR format.
35 |
36 | :param img: The base image in RGB or BGR format.
37 | :param mask: The cam mask.
38 | :param use_rgb: Whether to use an RGB or BGR heatmap; set this to True if 'img' is in RGB format.
39 | :param colormap: The OpenCV colormap to be used.
40 | :returns: The base image with the cam overlay.
41 | """
42 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
43 | if use_rgb:
44 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
45 | heatmap = np.float32(heatmap) / 255
46 |
47 | if np.max(img) > 1:
48 | raise Exception(
49 | "The input image should np.float32 in the range [0, 1]")
50 |
51 | cam = heatmap + img
52 | cam = cam / np.max(cam)
53 | return np.uint8(255 * cam)
54 |
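How these helpers fit together end to end; 'cat.jpg', the normalization constants, and the random stand-in CAM are illustrative only:

```python
import cv2
import numpy as np
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image

rgb_img = cv2.cvtColor(cv2.imread('cat.jpg'), cv2.COLOR_BGR2RGB)
rgb_img = np.float32(rgb_img) / 255  # show_cam_on_image expects floats in [0, 1]

input_tensor = preprocess_image(rgb_img,
                                mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])

# In real use this comes from a CAM object: grayscale_cam = cam(input_tensor)[0]
grayscale_cam = np.random.rand(*rgb_img.shape[:2]).astype(np.float32)
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
```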
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/svd_on_activations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def get_2d_projection(activation_batch):
5 | # TBD: use pytorch batch svd implementation
6 | activation_batch[np.isnan(activation_batch)] = 0
7 | projections = []
8 | for activations in activation_batch:
9 | reshaped_activations = (activations).reshape(
10 | activations.shape[0], -1).transpose()
11 | # Centering before the SVD seems to be important here,
12 | # Otherwise the image returned is negative
13 | reshaped_activations = reshaped_activations - \
14 | reshaped_activations.mean(axis=0)
15 | U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True)
16 | projection = reshaped_activations @ VT[0, :]
17 | projection = projection.reshape(activations.shape[1:])
18 | projections.append(projection)
19 | return np.float32(projections)
20 |
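`get_2d_projection` is per-sample PCA: each (C, H, W) activation block is flattened to an (H·W, C) matrix, centered, and projected onto its first right-singular vector, yielding one 2D map per sample. A toy call with random stand-in activations:

```python
import numpy as np
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection

acts = np.random.rand(2, 64, 7, 7).astype(np.float32)  # (B, C, H, W)
proj = get_2d_projection(acts)
print(proj.shape)  # (2, 7, 7): one saliency map per sample
```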
--------------------------------------------------------------------------------
/pytorch_grad_cam/xgrad_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_grad_cam.base_cam import BaseCAM
3 |
4 |
5 | class XGradCAM(BaseCAM):
6 | def __init__(
7 | self,
8 | model,
9 | target_layers,
10 | use_cuda=False,
11 | reshape_transform=None):
12 | super(
13 | XGradCAM,
14 | self).__init__(
15 | model,
16 | target_layers,
17 | use_cuda,
18 | reshape_transform)
19 |
20 | def get_cam_weights(self,
21 | input_tensor,
22 | target_layer,
23 | target_category,
24 | activations,
25 | grads):
26 | sum_activations = np.sum(activations, axis=(2, 3))
27 | eps = 1e-7
28 | weights = grads * activations / \
29 | (sum_activations[:, :, None, None] + eps)
30 | weights = weights.sum(axis=(2, 3))
31 | return weights
32 |
--------------------------------------------------------------------------------
/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | It also avoids mapping to whitespace/control characters, which the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(
79 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
80 | re.IGNORECASE)
81 |
82 | def bpe(self, token):
83 | if token in self.cache:
84 | return self.cache[token]
85 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
86 | pairs = get_pairs(word)
87 |
88 | if not pairs:
89 | return token+'</w>'
90 |
91 | while True:
92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
93 | if bigram not in self.bpe_ranks:
94 | break
95 | first, second = bigram
96 | new_word = []
97 | i = 0
98 | while i < len(word):
99 | # word.index raises ValueError when 'first' does not occur in word[i:]
100 | try:
101 | j = word.index(first, i)
102 | new_word.extend(word[i:j])
103 | i = j
104 | except ValueError:
105 | new_word.extend(word[i:])
106 | break
107 |
108 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
109 | new_word.append(first+second)
110 | i += 2
111 | else:
112 | new_word.append(word[i])
113 | i += 1
114 | new_word = tuple(new_word)
115 | word = new_word
116 | if len(word) == 1:
117 | break
118 | else:
119 | pairs = get_pairs(word)
120 | word = ' '.join(word)
121 | self.cache[token] = word
122 | return word
123 |
124 | def encode(self, text):
125 | bpe_tokens = []
126 | text = whitespace_clean(basic_clean(text)).lower()
127 | for token in re.findall(self.pat, text):
128 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
129 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
130 | return bpe_tokens
131 |
132 | def decode(self, tokens):
133 | text = ''.join([self.decoder[token] for token in tokens])
134 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
135 | return text
136 |
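A round-trip sketch (the prompt string is illustrative; `decode` re-expands the '</w>' end-of-word marker into spaces, so the output carries a trailing space):

```python
from simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()  # loads bpe_simple_vocab_16e6.txt.gz by default
token_ids = tokenizer.encode("a photo of a dog")
print(token_ids)
print(tokenizer.decode(token_ids))  # 'a photo of a dog '
```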
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | r""" Hypercorrelation Squeeze testing code """
2 | import argparse
3 |
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | import torch
7 |
8 | from model.hsnet_imr import HypercorrSqueezeNetwork_imr
9 | from common.logger import Logger, AverageMeter
10 | from common.vis import Visualizer
11 | from common.evaluation import Evaluator
12 | from common import utils
13 | from data.dataset import FSSDataset
14 |
15 |
16 |
17 | def test(model, dataloader, nshot, stage):
18 | r""" Test HSNet """
19 |
20 | # Freeze randomness during testing for reproducibility
21 | utils.fix_randseed(0)
22 | average_meter = AverageMeter(dataloader.dataset)
23 |
24 | for idx, batch in enumerate(dataloader):
25 |
26 | # 1. Hypercorrelation Squeeze Networks forward pass
27 | batch = utils.to_cuda(batch)
28 | pred_mask = model.module.predict_mask_nshot(batch, nshot=nshot, stage=stage)
29 |
30 | assert pred_mask.size() == batch['query_mask'].size()
31 |
32 | # 2. Evaluate prediction
33 | area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)
34 | average_meter.update(area_inter, area_union, batch['class_id'], loss=None)
35 | average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)
36 |
37 | # Visualize predictions
38 | if Visualizer.visualize:
39 | Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],
40 | batch['query_img'], batch['query_mask'],
41 | pred_mask, batch['class_id'], idx,
42 | area_inter[1].float() / area_union[1].float())
43 |
44 | average_meter.write_result('Test', 0)
45 | miou, fb_iou = average_meter.compute_iou()
46 | return miou, fb_iou
47 |
48 |
49 | if __name__ == '__main__':
50 |
51 | # Arguments parsing
52 | parser = argparse.ArgumentParser(description='Hypercorrelation Squeeze Pytorch Implementation')
53 | parser.add_argument('--datapath', type=str, default='../Datasets_HSN')
54 | parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss'])
55 | parser.add_argument('--logpath', type=str, default='')
56 | parser.add_argument('--bsz', type=int, default=10)
57 | parser.add_argument('--nworker', type=int, default=8)
58 | parser.add_argument('--load', type=str, default=None)
59 | parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3])
60 | parser.add_argument('--nshot', type=int, default=1)
61 | parser.add_argument('--backbone', type=str, default='resnet50', choices=['vgg16', 'resnet50', 'resnet101'])
62 | parser.add_argument('--visualize', action='store_true')
63 | parser.add_argument('--use_original_imgsize', action='store_true')
64 | parser.add_argument('--stage', type=int, default=2)
65 | parser.add_argument('--traincampath', type=str, default='../Datasets_HSN/CAM_Train/')
66 | parser.add_argument('--valcampath', type=str, default='../Datasets_HSN/CAM_Val/')
67 | args = parser.parse_args()
68 | Logger.initialize(args, training=False)
69 |
70 | # Model initialization
71 | model = HypercorrSqueezeNetwork_imr(args.backbone, args.use_original_imgsize)
72 | model.eval()
73 | Logger.log_params(model)
74 |
75 | # Device setup
76 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
77 | Logger.info('# available GPUs: %d' % torch.cuda.device_count())
78 | model = nn.DataParallel(model)
79 | model.to(device)
80 |
81 | # Load trained model
82 | if not args.load:
83 | raise Exception('Pretrained model not specified.')
84 | model.load_state_dict(torch.load(args.load))
85 |
86 | # Helper classes (for testing) initialization
87 | Evaluator.initialize()
88 | Visualizer.initialize(args.visualize)
89 |
90 | # Dataset initialization
91 | FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize)
92 | dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot,
93 | cam_train_path=args.traincampath, cam_val_path=args.valcampath)
94 |
95 | # Test HSNet
96 | with torch.no_grad():
97 | test_miou, test_fb_iou = test(model, dataloader_test, args.nshot, args.stage)
98 | Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item()))
99 | Logger.info('==================== Finished Testing ====================')
100 |
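A typical evaluation invocation (the checkpoint path is a placeholder):

```bash
python test.py --benchmark pascal --fold 0 --nshot 1 \
               --backbone resnet50 --load ./logs/best_model.pt \
               --datapath ../Datasets_HSN \
               --traincampath ../Datasets_HSN/CAM_Train/ \
               --valcampath ../Datasets_HSN/CAM_Val/
```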
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | r""" Hypercorrelation Squeeze training (validation) code """
2 | import argparse
3 |
4 |
5 | import torch.optim as optim
6 | import torch.nn as nn
7 | import torch
8 |
9 | from common.logger import Logger, AverageMeter
10 | from common.evaluation import Evaluator
11 | from common import utils
12 | from data.dataset import FSSDataset
13 | from model.hsnet_imr import HypercorrSqueezeNetwork_imr
14 |
15 |
16 | def train(epoch, model, dataloader, optimizer, training, stage):
17 | r""" Train HSNet """
18 |
19 | # Force randomness during training / freeze randomness during testing
20 | utils.fix_randseed(None) if training else utils.fix_randseed(0)
21 | model.module.train_mode() if training else model.module.eval()
22 | average_meter = AverageMeter(dataloader.dataset)
23 |
24 | for idx, batch in enumerate(dataloader):
25 |
26 | # 1. Hypercorrelation Squeeze Networks forward pass
27 | batch = utils.to_cuda(batch)
28 | logit_mask_q, logit_mask_s, losses = model(
29 | query_img=batch['query_img'], support_img=batch['support_imgs'].squeeze(1),
30 | support_cam=batch['support_cams'].squeeze(1), query_cam=batch['query_cam'], stage=stage,
31 | query_mask=batch['query_mask'], support_mask=batch['support_masks'].squeeze(1))
32 | pred_mask_q = logit_mask_q.argmax(dim=1)
33 |
34 | # 2. Compute loss & update model parameters
35 | loss = losses.mean()
36 | if training:
37 | optimizer.zero_grad()
38 | loss.backward()
39 | optimizer.step()
40 |
41 | # 3. Evaluate prediction
42 | area_inter, area_union = Evaluator.classify_prediction(pred_mask_q, batch)
43 | average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
44 | average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50)
45 |
46 | # Write evaluation results
47 | average_meter.write_result('Training' if training else 'Validation', epoch)
48 | avg_loss = utils.mean(average_meter.loss_buf)
49 | miou, fb_iou = average_meter.compute_iou()
50 |
51 | return avg_loss, miou, fb_iou
52 |
53 |
54 | if __name__ == '__main__':
55 | # Arguments parsing
56 | parser = argparse.ArgumentParser(description='Hypercorrelation Squeeze Pytorch Implementation')
57 | parser.add_argument('--datapath', type=str, default='../Datasets_HSN')
58 | parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss'])
59 | parser.add_argument('--logpath', type=str, default='')
60 | parser.add_argument('--bsz', type=int, default=40)
61 | parser.add_argument('--lr', type=float, default=4e-4)
62 | parser.add_argument('--niter', type=int, default=400)
63 | parser.add_argument('--nworker', type=int, default=8)
64 | parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3])
65 | parser.add_argument('--stage', type=int, default=2)
66 | parser.add_argument('--backbone', type=str, default='resnet50', choices=['vgg16', 'resnet50', 'resnet101'])
67 | parser.add_argument('--traincampath', type=str, default='../Datasets_HSN/CAM_Train/')
68 | parser.add_argument('--valcampath', type=str, default='../Datasets_HSN/CAM_Val/')
69 |
70 | args = parser.parse_args()
71 | Logger.initialize(args, training=True)
72 | assert args.bsz % torch.cuda.device_count() == 0
73 |
74 | # Model initialization
75 | model = HypercorrSqueezeNetwork_imr(args.backbone, False)
76 | Logger.log_params(model)
77 |
78 | # Device setup
79 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
80 | Logger.info('# available GPUs: %d' % torch.cuda.device_count())
81 | model = nn.DataParallel(model)
82 | model.to(device)
83 |
84 | # Helper classes (for training) initialization
85 | optimizer = optim.Adam([{"params": model.parameters(), "lr": args.lr}])
86 | Evaluator.initialize()
87 |
88 | # Dataset initialization
89 | FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=False)
90 | dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn',
91 | cam_train_path=args.traincampath, cam_val_path=args.valcampath)
92 | dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val',
93 | cam_train_path=args.traincampath, cam_val_path=args.valcampath)
94 |
95 | # Train HSNet
96 | best_val_miou = float('-inf')
97 | best_val_loss = float('inf')
98 | for epoch in range(args.niter):
99 | trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True,
100 | stage=args.stage)
101 | with torch.no_grad():
102 | val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False,
103 | stage=args.stage)
104 | # Save the best model
105 | if val_miou > best_val_miou:
106 | best_val_miou = val_miou
107 | Logger.save_model_miou(model, epoch, val_miou)
108 |
109 | Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
110 | Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
111 | Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
112 | Logger.tbd_writer.flush()
113 | Logger.tbd_writer.close()
114 | Logger.info('==================== Finished Training ====================')
115 |
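A typical training invocation with the defaults spelled out (paths are placeholders):

```bash
python train.py --benchmark pascal --fold 0 --bsz 40 --lr 4e-4 --niter 400 \
                --backbone resnet50 --datapath ../Datasets_HSN \
                --traincampath ../Datasets_HSN/CAM_Train/ \
                --valcampath ../Datasets_HSN/CAM_Val/
```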
--------------------------------------------------------------------------------