├── README.md ├── __pycache__ ├── clip.cpython-38.pyc ├── model_clip.cpython-38.pyc └── simple_tokenizer.cpython-38.pyc ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── common ├── __pycache__ │ ├── evaluation.cpython-38.pyc │ ├── logger.cpython-38.pyc │ ├── utils.cpython-38.pyc │ └── vis.cpython-38.pyc ├── evaluation.py ├── logger.py ├── utils.py └── vis.py ├── data ├── __pycache__ │ ├── coco.cpython-38.pyc │ ├── dataset.cpython-38.pyc │ ├── fss.cpython-38.pyc │ └── pascal.cpython-38.pyc ├── assets │ ├── architecture.png │ └── qualitative_results.png ├── coco.py ├── dataset.py ├── fss.py ├── pascal.py └── splits │ ├── coco │ ├── trn │ │ ├── fold0.pkl │ │ ├── fold1.pkl │ │ ├── fold2.pkl │ │ └── fold3.pkl │ └── val │ │ ├── fold0.pkl │ │ ├── fold1.pkl │ │ ├── fold2.pkl │ │ └── fold3.pkl │ ├── fss │ ├── test.txt │ ├── trn.txt │ └── val.txt │ └── pascal │ ├── trn │ ├── fold0.txt │ ├── fold1.txt │ ├── fold2.txt │ └── fold3.txt │ └── val │ ├── fold0.txt │ ├── fold1.txt │ ├── fold2.txt │ └── fold3.txt ├── generate_cam_coco.py ├── generate_cam_voc.py ├── model ├── __pycache__ │ ├── hsnet.cpython-38.pyc │ ├── hsnet_imr.cpython-38.pyc │ ├── hsnet_raft_res_multi_group_xiaorong_grouponly.cpython-38.pyc │ └── learner.cpython-38.pyc ├── base │ ├── __pycache__ │ │ ├── conv4d.cpython-38.pyc │ │ ├── correlation.cpython-38.pyc │ │ └── feature.cpython-38.pyc │ ├── conv4d.py │ ├── correlation.py │ └── feature.py ├── hsnet_imr.py └── learner.py ├── model_clip.py ├── pytorch_grad_cam ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── ablation_cam.cpython-38.pyc │ ├── activations_and_gradients.cpython-38.pyc │ ├── base_cam.cpython-38.pyc │ ├── eigen_cam.cpython-38.pyc │ ├── eigen_grad_cam.cpython-38.pyc │ ├── fullgrad_cam.cpython-38.pyc │ ├── grad_cam.cpython-38.pyc │ ├── grad_cam_plusplus.cpython-38.pyc │ ├── guided_backprop.cpython-38.pyc │ ├── layer_cam.cpython-38.pyc │ ├── score_cam.cpython-38.pyc │ └── xgrad_cam.cpython-38.pyc ├── ablation_cam.py ├── activations_and_gradients.py ├── base_cam.py ├── eigen_cam.py ├── eigen_grad_cam.py ├── fullgrad_cam.py ├── grad_cam.py ├── grad_cam_plusplus.py ├── guided_backprop.py ├── layer_cam.py ├── score_cam.py ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── find_layers.cpython-38.pyc │ │ ├── image.cpython-38.pyc │ │ └── svd_on_activations.cpython-38.pyc │ ├── find_layers.py │ ├── image.py │ └── svd_on_activations.py └── xgrad_cam.py ├── simple_tokenizer.py ├── test.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Iterative Few-shot Semantic Segmentation from Image Label Text 3 | This is the implementation of the paper "Iterative Few-shot Semantic Segmentation from Image Label Text" (IJCAI 2022). 4 | The code is implemented based on HSNet (https://github.com/juhongm999/hsnet), CLIP (https://github.com/openai/CLIP), and pytorch-grad-cam (https://github.com/jacobgil/pytorch-grad-cam). Thanks for their great work! 5 | 6 | ## Requirements 7 | Following HSNet: 8 | - Python 3.7 9 | - PyTorch 1.5.1 10 | - cuda 10.1 11 | - tensorboard 1.14 12 | 13 | Conda environment settings: 14 | ```bash 15 | conda create -n hsnet python=3.7 16 | conda activate hsnet 17 | 18 | conda install pytorch=1.5.1 torchvision cudatoolkit=10.1 -c pytorch 19 | conda install -c conda-forge tensorflow 20 | pip install tensorboardX 21 | ``` 22 | ## Preparing Few-Shot Segmentation Datasets 23 | Download the following datasets: 24 | 25 | > #### 1. 
PASCAL-5i 26 | > Download PASCAL VOC2012 devkit (train/val data): 27 | > ```bash 28 | > wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 29 | > ``` 30 | > Download PASCAL VOC2012 SDS extended mask annotations from HSNet [[Google Drive](https://drive.google.com/file/d/10zxG2VExoEZUeyQl_uXga2OWHjGeZaf2/view?usp=sharing)]. 31 | 32 | > #### 2. COCO-20i 33 | > Download COCO2014 train/val images and annotations: 34 | > ```bash 35 | > wget http://images.cocodataset.org/zips/train2014.zip 36 | > wget http://images.cocodataset.org/zips/val2014.zip 37 | > wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip 38 | > ``` 39 | > Download COCO2014 train/val annotations from HSNet Google Drive: [[train2014.zip](https://drive.google.com/file/d/1cwup51kcr4m7v9jO14ArpxKMA4O3-Uge/view?usp=sharing)], [[val2014.zip](https://drive.google.com/file/d/1PNw4U3T2MhzAEBWGGgceXvYU3cZ7mJL1/view?usp=sharing)]. 40 | > (and locate both train2014/ and val2014/ under the annotations/ directory). 41 | 42 | 43 | 44 | Create a directory '../Datasets_HSN' for the above two few-shot segmentation datasets and place each dataset appropriately to have the following directory structure: 45 | 46 | ../ # parent directory 47 | ├── ./ # current (project) directory 48 | │ ├── common/ # (dir.) helper functions 49 | │ ├── data/ # (dir.) dataloaders and splits for each FSSS dataset 50 | │ ├── model/ # (dir.) implementation of Hypercorrelation Squeeze Network model 51 | │ ├── README.md # instructions for reproduction 52 | │ ├── train.py # code for training HSNet 53 | │ └── test.py # code for testing HSNet 54 | └── Datasets_HSN/ 55 | ├── VOC2012/ # PASCAL VOC2012 devkit 56 | │ ├── Annotations/ 57 | │ ├── ImageSets/ 58 | │ ├── ... 59 | │ └── SegmentationClassAug/ 60 | ├── COCO2014/ 61 | │ ├── annotations/ 62 | │ │ ├── train2014/ # (dir.) training masks (from Google Drive) 63 | │ │ ├── val2014/ # (dir.) validation masks (from Google Drive) 64 | │ │ └── ..some json files.. 65 | │ ├── train2014/ 66 | │ └── val2014/ 67 | ├── CAM_VOC_Train/ 68 | ├── CAM_VOC_Val/ 69 | └── CAM_COCO/ 70 | 71 | 72 | ## Preparing CAM for Few-Shot Segmentation Datasets 73 | > ### 1. PASCAL-5i 74 | > * Generate Grad-CAM for images 75 | > ```bash 76 | > python generate_cam_voc.py --traincampath ../Datasets_HSN/CAM_VOC_Train/ 77 | > --valcampath ../Datasets_HSN/CAM_VOC_Val/ 78 | > ``` 79 | > ### 2. COCO-20i 80 | > ```bash 81 | > python generate_cam_coco.py --campath ../Datasets_HSN/CAM_COCO/ 82 | > ``` 83 | 84 | 85 | 86 | ## Training 87 | > ### 1. PASCAL-5i 88 | > ```bash 89 | > python train.py --backbone {vgg16, resnet50} 90 | > --fold {0, 1, 2, 3} 91 | > --benchmark pascal 92 | > --lr 4e-4 93 | > --bsz 40 94 | > --stage 2 95 | > --logpath "your_experiment_name" 96 | > --traincampath ../Datasets_HSN/CAM_VOC_Train/ 97 | > --valcampath ../Datasets_HSN/CAM_VOC_Val/ 98 | > ``` 99 | > * Training takes approx. 1 day until convergence (trained with four V100 GPUs). 100 | 101 | 102 | > ### 2. COCO-20i 103 | > ```bash 104 | > python train.py --backbone {vgg16, resnet50} 105 | > --fold {0, 1, 2, 3} 106 | > --benchmark coco 107 | > --lr 2e-4 108 | > --bsz 20 109 | > --stage 3 110 | > --logpath "your_experiment_name" 111 | > --traincampath ../Datasets_HSN/CAM_COCO/ 112 | > --valcampath ../Datasets_HSN/CAM_COCO/ 113 | > ``` 114 | > * Training takes approx. 1 week until convergence (trained with four V100 GPUs). 
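> ### Sanity-checking generated CAMs:
> Training reads the CAMs from `--traincampath`/`--valcampath`. The dataloaders below (`data/pascal.py`, `data/coco.py`) expect each CAM to be a 50x50 tensor saved as `<campath><image_name>--<class_id>.pt`. A minimal sketch to verify a generated file before launching training; the file name is a hypothetical example, not a file shipped with the repo:
> ```python
> import torch
>
> # hypothetical example file produced by generate_cam_voc.py
> cam = torch.load('../Datasets_HSN/CAM_VOC_Train/2007_000032--0.pt')
> assert cam.shape == (50, 50), f'unexpected CAM shape: {cam.shape}'
> print(cam.min().item(), cam.max().item())  # inspect the activation range
> ```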
115 | 116 | 117 | > ### Babysitting training: 118 | > Use tensorboard to babysit training progress: 119 | > - For each experiment, a directory that logs training progress will be automatically generated under the logs/ directory. 120 | > - From the terminal, run 'tensorboard --logdir logs/' to monitor the training progress. 121 | > - Choose the best model when the validation (mIoU) curve starts to saturate. 122 | 123 | 124 | 125 | ## Testing 126 | 127 | > ### 1. PASCAL-5i 128 | > Pretrained models with tensorboard logs are available on our [[Google Drive](https://drive.google.com/drive/folders/1fB3_jUEw972lDZIs3_S7lj2F5rZVq4Nu?usp=sharing)]. 129 | > ```bash 130 | > python test.py --backbone {vgg16, resnet50} 131 | > --fold {0, 1, 2, 3} 132 | > --benchmark pascal 133 | > --nshot {1, 5} 134 | > --load "path_to_trained_model/best_model.pt" 135 | > ``` 136 | 137 | 138 | > ### 2. COCO-20i 139 | > Pretrained models with tensorboard logs are available on our [[Google Drive](https://drive.google.com/drive/folders/1fB3_jUEw972lDZIs3_S7lj2F5rZVq4Nu?usp=sharing)]. 140 | > ```bash 141 | > python test.py --backbone {vgg16, resnet50} 142 | > --fold {0, 1, 2, 3} 143 | > --benchmark coco 144 | > --nshot {1, 5} 145 | > --load "path_to_trained_model/best_model.pt" 146 | > ``` 147 | 148 | 149 | 150 | 151 | ## BibTeX 152 | If you use this code for your research, please consider citing: 153 | ````BibTeX 154 | @inproceedings{ijcai2022p193, 155 | title = {Iterative Few-shot Semantic Segmentation from Image Label Text}, 156 | author = {Wang, Haohan and Liu, Liang and Zhang, Wuhao and Zhang, Jiangning and Gan, Zhenye and Wang, Yabiao and Wang, Chengjie and Wang, Haoqian}, 157 | booktitle = {Proceedings of the Thirty-First International Joint Conference on 158 | Artificial Intelligence, {IJCAI-22}}, 159 | publisher = {International Joint Conferences on Artificial Intelligence Organization}, 160 | editor = {Luc De Raedt}, 161 | pages = {1385--1392}, 162 | year = {2022}, 163 | month = {7}, 164 | note = {Main Track}, 165 | doi = {10.24963/ijcai.2022/193}, 166 | url = {https://doi.org/10.24963/ijcai.2022/193}, 167 | } 168 | ```` 169 | -------------------------------------------------------------------------------- /__pycache__/clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/__pycache__/clip.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/model_clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/__pycache__/model_clip.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/simple_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- 
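The next file, clip.py, wraps OpenAI's CLIP loader. A minimal usage sketch, assuming the repo's dependencies are installed; the prompt string is illustrative, and `jit=False` selects the non-JIT code path described in the docstring below:

```python
import torch

import clip  # the local clip.py defined below

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess = clip.load('RN50', device=device, jit=False)  # downloads weights on first use

tokens = clip.tokenize(['a photo of a dog']).to(device)
with torch.no_grad():
    text_features = model.encode_text(tokens)  # text embedding, one row per prompt
```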
/clip.py: -------------------------------------------------------------------------------- 1 | # Slightly modified to extract self-attention scores 2 | # Original code: https://github.com/openai/CLIP 3 | 4 | import hashlib 5 | import os 6 | import urllib 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | from PIL import Image 12 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 13 | from tqdm import tqdm 14 | 15 | from model_clip import build_model 16 | from simple_tokenizer import SimpleTokenizer as _Tokenizer 17 | _tokenizer = _Tokenizer() 18 | 19 | __all__ = ["available_models", "load"] 20 | 21 | _MODELS = { 22 | "RN50": "https://openaipublic.azureedge.net/clip/models\ 23 | /afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 24 | "RN101": "https://openaipublic.azureedge.net/clip/models\ 25 | /8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 26 | "RN50x4": "https://openaipublic.azureedge.net/clip/models\ 27 | /7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 28 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models\ 29 | /40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 30 | } 31 | 32 | 33 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 34 | os.makedirs(root, exist_ok=True) 35 | filename = os.path.basename(url) 36 | 37 | expected_sha256 = url.split("/")[-2] 38 | download_target = os.path.join(root, filename) 39 | 40 | if os.path.exists(download_target) and not os.path.isfile(download_target): 41 | raise RuntimeError(f"{download_target} exists and is not a regular file") 42 | 43 | if os.path.isfile(download_target): 44 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 45 | return download_target 46 | else: 47 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 48 | 49 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 50 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 51 | while True: 52 | buffer = source.read(8192) 53 | if not buffer: 54 | break 55 | 56 | output.write(buffer) 57 | loop.update(len(buffer)) 58 | 59 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 60 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match") 61 | 62 | return download_target 63 | 64 | 65 | def _transform(n_px): 66 | return Compose([ 67 | Resize(n_px, interpolation=Image.BICUBIC), 68 | CenterCrop(n_px), 69 | lambda image: image.convert("RGB"), 70 | ToTensor(), 71 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 72 | ]) 73 | 74 | 75 | def available_models() -> List[str]: 76 | """Returns the names of available CLIP models""" 77 | return list(_MODELS.keys()) 78 | 79 | 80 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): 81 | """Load a CLIP model 82 | 83 | Parameters 84 | ---------- 85 | name : str 86 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 87 | 88 | device : Union[str, torch.device] 89 | The device to put the loaded model 90 | 91 | jit : bool 92 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
93 | 94 | Returns 95 | ------- 96 | model : torch.nn.Module 97 | The CLIP model 98 | 99 | preprocess : Callable[[PIL.Image], torch.Tensor] 100 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 101 | """ 102 | 103 | if name in _MODELS: 104 | model_path = _download(_MODELS[name]) 105 | elif os.path.isfile(name): 106 | model_path = name 107 | else: 108 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 109 | 110 | try: 111 | # loading JIT archive 112 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 113 | state_dict = None 114 | except RuntimeError: 115 | # loading saved state dict 116 | if jit: 117 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 118 | jit = False 119 | state_dict = torch.load(model_path, map_location="cpu") 120 | 121 | if not jit: 122 | model = build_model(state_dict or model.state_dict()).to(device) 123 | if str(device) == "cpu": 124 | model.float() 125 | return model, _transform(model.visual.input_resolution) 126 | 127 | # patch the device names 128 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 129 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 130 | 131 | def patch_device(module): 132 | graphs = [module.graph] if hasattr(module, "graph") else [] 133 | if hasattr(module, "forward1"): 134 | graphs.append(module.forward1.graph) 135 | 136 | for graph in graphs: 137 | for node in graph.findAllNodes("prim::Constant"): 138 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 139 | node.copyAttributes(device_node) 140 | 141 | model.apply(patch_device) 142 | patch_device(model.encode_image) 143 | patch_device(model.encode_text) 144 | 145 | # patch dtype to float32 on CPU 146 | if str(device) == "cpu": 147 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 148 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 149 | float_node = float_input.node() 150 | 151 | def patch_float(module): 152 | graphs = [module.graph] if hasattr(module, "graph") else [] 153 | if hasattr(module, "forward1"): 154 | graphs.append(module.forward1.graph) 155 | 156 | for graph in graphs: 157 | for node in graph.findAllNodes("aten::to"): 158 | inputs = list(node.inputs()) 159 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 160 | if inputs[i].node()["value"] == 5: 161 | inputs[i].node().copyAttributes(float_node) 162 | 163 | model.apply(patch_float) 164 | patch_float(model.encode_image) 165 | patch_float(model.encode_text) 166 | 167 | model.float() 168 | 169 | return model, _transform(model.input_resolution.item()) 170 | 171 | 172 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 173 | """ 174 | Returns the tokenized representation of given input string(s) 175 | Parameters 176 | ---------- 177 | texts : Union[str, List[str]] 178 | An input string or a list of input strings to tokenize 179 | context_length : int 180 | The context length to use; all CLIP models use 77 as the context length 181 | truncate: bool 182 | Whether to truncate the text in case its encoding is longer than the context length 183 | Returns 184 | ------- 185 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 186 | """ 
187 | if isinstance(texts, str): 188 | texts = [texts] 189 | 190 | sot_token = _tokenizer.encoder["<|startoftext|>"] 191 | eot_token = _tokenizer.encoder["<|endoftext|>"] 192 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 193 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 194 | 195 | for i, tokens in enumerate(all_tokens): 196 | if len(tokens) > context_length: 197 | if truncate: 198 | tokens = tokens[:context_length] 199 | tokens[-1] = eot_token 200 | else: 201 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 202 | result[i, :len(tokens)] = torch.tensor(tokens) 203 | 204 | return result 205 | -------------------------------------------------------------------------------- /common/__pycache__/evaluation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/common/__pycache__/evaluation.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/common/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/common/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/vis.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/common/__pycache__/vis.cpython-38.pyc -------------------------------------------------------------------------------- /common/evaluation.py: -------------------------------------------------------------------------------- 1 | r""" Evaluate mask prediction """ 2 | import torch 3 | 4 | 5 | class Evaluator: 6 | r""" Computes intersection and union between prediction and ground-truth """ 7 | @classmethod 8 | def initialize(cls): 9 | cls.ignore_index = 255 10 | 11 | @classmethod 12 | def classify_prediction(cls, pred_mask, batch): 13 | gt_mask = batch.get('query_mask') 14 | 15 | # Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020)) 16 | query_ignore_idx = batch.get('query_ignore_idx') 17 | if query_ignore_idx is not None: 18 | assert torch.logical_and(query_ignore_idx, gt_mask).sum() == 0 19 | query_ignore_idx *= cls.ignore_index 20 | gt_mask = gt_mask + query_ignore_idx 21 | pred_mask[gt_mask == cls.ignore_index] = cls.ignore_index 22 | 23 | # compute intersection and union of each episode in a batch 24 | area_inter, area_pred, area_gt = [], [], [] 25 | for _pred_mask, _gt_mask in zip(pred_mask, gt_mask): 26 | _inter = _pred_mask[_pred_mask == _gt_mask] 27 | if _inter.size(0) == 0: # as torch.histc returns error if it gets empty tensor (pytorch 1.5.1) 28 | _area_inter = torch.tensor([0, 0], device=_pred_mask.device) 29 | else: 30 | _area_inter = torch.histc(_inter, bins=2, min=0, max=1) 31 | area_inter.append(_area_inter) 32 | 
area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1)) 33 | area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1)) 34 | area_inter = torch.stack(area_inter).t() 35 | area_pred = torch.stack(area_pred).t() 36 | area_gt = torch.stack(area_gt).t() 37 | area_union = area_pred + area_gt - area_inter 38 | 39 | return area_inter, area_union 40 | -------------------------------------------------------------------------------- /common/logger.py: -------------------------------------------------------------------------------- 1 | r""" Logging during training/testing """ 2 | import datetime 3 | import logging 4 | import os 5 | 6 | from tensorboardX import SummaryWriter 7 | import torch 8 | 9 | 10 | class AverageMeter: 11 | r""" Stores loss, evaluation results """ 12 | def __init__(self, dataset): 13 | self.benchmark = dataset.benchmark 14 | self.class_ids_interest = dataset.class_ids 15 | self.class_ids_interest = torch.tensor(self.class_ids_interest).cuda() 16 | 17 | if self.benchmark == 'pascal': 18 | self.nclass = 20 19 | elif self.benchmark == 'coco': 20 | self.nclass = 80 21 | elif self.benchmark == 'fss': 22 | self.nclass = 1000 23 | 24 | self.intersection_buf = torch.zeros([2, self.nclass]).float().cuda() 25 | self.union_buf = torch.zeros([2, self.nclass]).float().cuda() 26 | self.ones = torch.ones_like(self.union_buf) 27 | self.loss_buf = [] 28 | 29 | def update(self, inter_b, union_b, class_id, loss): 30 | self.intersection_buf.index_add_(1, class_id, inter_b.float()) 31 | self.union_buf.index_add_(1, class_id, union_b.float()) 32 | if loss is None: 33 | loss = torch.tensor(0.0) 34 | self.loss_buf.append(loss) 35 | 36 | def compute_iou(self): 37 | iou = self.intersection_buf.float() / \ 38 | torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0] 39 | iou = iou.index_select(1, self.class_ids_interest) 40 | miou = iou[1].mean() * 100 41 | 42 | fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) / 43 | self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100 44 | 45 | return miou, fb_iou 46 | 47 | def write_result(self, split, epoch): 48 | iou, fb_iou = self.compute_iou() 49 | 50 | loss_buf = torch.stack(self.loss_buf) 51 | msg = '\n*** %s ' % split 52 | msg += '[@Epoch %02d] ' % epoch 53 | msg += 'Avg L: %6.5f ' % loss_buf.mean() 54 | msg += 'mIoU: %5.2f ' % iou 55 | msg += 'FB-IoU: %5.2f ' % fb_iou 56 | 57 | msg += '***\n' 58 | Logger.info(msg) 59 | 60 | def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20): 61 | if batch_idx % write_batch_idx == 0: 62 | msg = '[Epoch: %02d] ' % epoch if epoch != -1 else '' 63 | msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen) 64 | iou, fb_iou = self.compute_iou() 65 | if epoch != -1: 66 | loss_buf = torch.stack(self.loss_buf) 67 | msg += 'L: %6.5f ' % loss_buf[-1] 68 | msg += 'Avg L: %6.5f ' % loss_buf.mean() 69 | msg += 'mIoU: %5.2f | ' % iou 70 | msg += 'FB-IoU: %5.2f' % fb_iou 71 | Logger.info(msg) 72 | 73 | 74 | class Logger: 75 | r""" Writes evaluation results of training/testing """ 76 | @classmethod 77 | def initialize(cls, args, training): 78 | logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S') 79 | logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-2].split('.')[0] + logtime 80 | if logpath == '': 81 | logpath = logtime 82 | 83 | cls.logpath = os.path.join('logs', logpath + '.log') 84 | cls.benchmark = args.benchmark 85 | os.makedirs(cls.logpath) 86 | 87 | logging.basicConfig(filemode='w', 88 | 
filename=os.path.join(cls.logpath, 'log.txt'), 89 | level=logging.INFO, 90 | format='%(message)s', 91 | datefmt='%m-%d %H:%M:%S') 92 | 93 | # Console log config 94 | console = logging.StreamHandler() 95 | console.setLevel(logging.INFO) 96 | formatter = logging.Formatter('%(message)s') 97 | console.setFormatter(formatter) 98 | logging.getLogger('').addHandler(console) 99 | 100 | # Tensorboard writer 101 | cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs')) 102 | 103 | # Log arguments 104 | logging.info('\n:=========== Few-shot Seg. with HSNet ===========') 105 | for arg_key in args.__dict__: 106 | logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key]))) 107 | logging.info(':================================================\n') 108 | 109 | @classmethod 110 | def info(cls, msg): 111 | r""" Writes log message to log.txt """ 112 | logging.info(msg) 113 | 114 | @classmethod 115 | def save_model_miou(cls, model, epoch, val_miou): 116 | torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt')) 117 | cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou)) 118 | 119 | @classmethod 120 | def log_params(cls, model): 121 | backbone_param = 0 122 | learner_param = 0 123 | for k in model.state_dict().keys(): 124 | n_param = model.state_dict()[k].view(-1).size(0) 125 | if k.split('.')[0] in 'backbone': 126 | if k.split('.')[1] in ['classifier', 'fc']: # as fc layers are not used in HSNet 127 | continue 128 | backbone_param += n_param 129 | else: 130 | learner_param += n_param 131 | Logger.info('Backbone # param.: %d' % backbone_param) 132 | Logger.info('Learnable # param.: %d' % learner_param) 133 | Logger.info('Total # param.: %d' % (backbone_param + learner_param)) 134 | 135 | 136 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | r""" Helper functions """ 2 | import random 3 | 4 | import torch 5 | import numpy as np 6 | 7 | 8 | def fix_randseed(seed): 9 | r""" Set random seeds for reproducibility """ 10 | if seed is None: 11 | seed = int(random.random() * 1e5) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | torch.backends.cudnn.benchmark = False 17 | torch.backends.cudnn.deterministic = True 18 | 19 | 20 | def mean(x): 21 | return sum(x) / len(x) if len(x) > 0 else 0.0 22 | 23 | 24 | def to_cuda(batch): 25 | for key, value in batch.items(): 26 | if isinstance(value, torch.Tensor): 27 | batch[key] = value.cuda() 28 | return batch 29 | 30 | 31 | def to_cpu(tensor): 32 | return tensor.detach().clone().cpu() 33 | -------------------------------------------------------------------------------- /common/vis.py: -------------------------------------------------------------------------------- 1 | r""" Visualize model predictions """ 2 | import os 3 | 4 | from PIL import Image 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | 8 | from . 
import utils 9 | 10 | 11 | class Visualizer: 12 | 13 | @classmethod 14 | def initialize(cls, visualize): 15 | cls.visualize = visualize 16 | if not visualize: 17 | return 18 | 19 | cls.colors = {'red': (255, 50, 50), 'blue': (102, 140, 255)} 20 | for key, value in cls.colors.items(): 21 | cls.colors[key] = tuple([c / 255 for c in cls.colors[key]]) 22 | 23 | cls.mean_img = [0.485, 0.456, 0.406] 24 | cls.std_img = [0.229, 0.224, 0.225] 25 | cls.to_pil = transforms.ToPILImage() 26 | cls.vis_path = './vis/' 27 | if not os.path.exists(cls.vis_path): 28 | os.makedirs(cls.vis_path) 29 | 30 | @classmethod 31 | def visualize_prediction_batch( 32 | cls, spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b, batch_idx, iou_b=None): 33 | spt_img_b = utils.to_cpu(spt_img_b) 34 | spt_mask_b = utils.to_cpu(spt_mask_b) 35 | qry_img_b = utils.to_cpu(qry_img_b) 36 | qry_mask_b = utils.to_cpu(qry_mask_b) 37 | pred_mask_b = utils.to_cpu(pred_mask_b) 38 | cls_id_b = utils.to_cpu(cls_id_b) 39 | 40 | for sample_idx, (spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id) in \ 41 | enumerate(zip(spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b)): 42 | iou = iou_b[sample_idx] if iou_b is not None else None 43 | cls.visualize_prediction( 44 | spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, True, iou) 45 | 46 | @classmethod 47 | def to_numpy(cls, tensor, type): 48 | if type == 'img': 49 | return np.array(cls.to_pil(cls.unnormalize(tensor))).astype(np.uint8) 50 | elif type == 'mask': 51 | return np.array(tensor).astype(np.uint8) 52 | else: 53 | raise Exception('Undefined tensor type: %s' % type) 54 | 55 | @classmethod 56 | def visualize_prediction( 57 | cls, spt_imgs, spt_masks, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, label, iou=None): 58 | 59 | spt_color = cls.colors['blue'] 60 | qry_color = cls.colors['red'] 61 | pred_color = cls.colors['red'] 62 | 63 | spt_imgs = [cls.to_numpy(spt_img, 'img') for spt_img in spt_imgs] 64 | spt_pils = [cls.to_pil(spt_img) for spt_img in spt_imgs] 65 | spt_masks = [cls.to_numpy(spt_mask, 'mask') for spt_mask in spt_masks] 66 | spt_masked_pils = [Image.fromarray( 67 | cls.apply_mask(spt_img, spt_mask, spt_color)) for spt_img, spt_mask in zip(spt_imgs, spt_masks)] 68 | 69 | qry_img = cls.to_numpy(qry_img, 'img') 70 | qry_pil = cls.to_pil(qry_img) 71 | qry_mask = cls.to_numpy(qry_mask, 'mask') 72 | pred_mask = cls.to_numpy(pred_mask, 'mask') 73 | pred_masked_pil = \ 74 | Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), pred_mask.astype(np.uint8), pred_color)) 75 | qry_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), qry_mask.astype(np.uint8), qry_color)) 76 | 77 | merged_pil = cls.merge_image_pair(spt_masked_pils + [pred_masked_pil, qry_masked_pil]) 78 | 79 | iou = iou.item() if iou else 0.0 80 | merged_pil.save(cls.vis_path + '%d_%d_class-%d_iou-%.2f' % (batch_idx, sample_idx, cls_id, iou) + '.jpg') 81 | 82 | @classmethod 83 | def merge_image_pair(cls, pil_imgs): 84 | r""" Horizontally aligns a pair of pytorch tensor images (3, H, W) and returns PIL object """ 85 | 86 | canvas_width = sum([pil.size[0] for pil in pil_imgs]) 87 | canvas_height = max([pil.size[1] for pil in pil_imgs]) 88 | canvas = Image.new('RGB', (canvas_width, canvas_height)) 89 | 90 | xpos = 0 91 | for pil in pil_imgs: 92 | canvas.paste(pil, (xpos, 0)) 93 | xpos += pil.size[0] 94 | 95 | return canvas 96 | 97 | @classmethod 98 | def apply_mask(cls, image, mask, color, alpha=0.5): 99 | r""" Apply mask to 
the given image. """ 100 | for c in range(3): 101 | image[:, :, c] = np.where(mask == 1, 102 | image[:, :, c] * 103 | (1 - alpha) + alpha * color[c] * 255, 104 | image[:, :, c]) 105 | return image 106 | 107 | @classmethod 108 | def unnormalize(cls, img): 109 | img = img.clone() 110 | for im_channel, mean, std in zip(img, cls.mean_img, cls.std_img): 111 | im_channel.mul_(std).add_(mean) 112 | return img 113 | -------------------------------------------------------------------------------- /data/__pycache__/coco.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/__pycache__/coco.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/fss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/__pycache__/fss.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/pascal.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/__pycache__/pascal.cpython-38.pyc -------------------------------------------------------------------------------- /data/assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/assets/architecture.png -------------------------------------------------------------------------------- /data/assets/qualitative_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/assets/qualitative_results.png -------------------------------------------------------------------------------- /data/coco.py: -------------------------------------------------------------------------------- 1 | r""" COCO-20i few-shot semantic segmentation dataset """ 2 | import os 3 | import pickle 4 | 5 | from torch.utils.data import Dataset 6 | import torch.nn.functional as F 7 | import torch 8 | import PIL.Image as Image 9 | import numpy as np 10 | 11 | 12 | class DatasetCOCO(Dataset): 13 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize, cam_train_path, cam_val_path): 14 | self.split = 'val' if split in ['val', 'test'] else 'trn' 15 | self.fold = fold 16 | self.nfolds = 4 17 | self.nclass = 80 18 | self.benchmark = 'coco' 19 | self.shot = shot 20 | self.split_coco = split if split == 'val2014' else 'train2014' 21 | self.base_path = os.path.join(datapath, 'COCO2014') 22 | self.transform = transform 23 | self.use_original_imgsize = use_original_imgsize 24 | 25 | self.class_ids = self.build_class_ids() 26 | self.img_metadata_classwise = self.build_img_metadata_classwise() 27 | self.img_metadata = self.build_img_metadata() 
28 | 29 | assert cam_train_path == cam_val_path 30 | # This is because the query name of COCO includes "train2014-img" or "val2014-img", 31 | # but VOC only includes "img", so a single CAM directory is shared here. 32 | self.cam_path = cam_train_path 33 | 34 | 35 | def __len__(self): 36 | return len(self.img_metadata) if self.split == 'trn' else 1000 37 | 38 | def __getitem__(self, idx): 39 | # ignores idx during training & testing and performs uniform sampling over object classes to form an episode 40 | # (due to the large size of the COCO dataset) 41 | query_img, query_mask, support_imgs, support_masks, \ 42 | query_name, support_names, class_sample, org_qry_imsize = self.load_frame() 43 | 44 | query_img = self.transform(query_img) 45 | query_mask = query_mask.float() 46 | if not self.use_original_imgsize: 47 | query_mask = F.interpolate( 48 | query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze() 49 | 50 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs]) 51 | for midx, smask in enumerate(support_masks): 52 | support_masks[midx] = F.interpolate( 53 | smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze() 54 | support_masks = torch.stack(support_masks) 55 | query_cam_path = self.cam_path + query_name + '--' + str(class_sample) + '.pt' 56 | query_cam = torch.load(query_cam_path) # 50 50 57 | 58 | nshot = len(support_names) 59 | support_cams = [] 60 | for nn in range(nshot): 61 | support_cam_path = self.cam_path + support_names[nn] + '--' + str(class_sample) + '.pt' 62 | support_cam = torch.load(support_cam_path).unsqueeze(0) # 1 50 50 63 | support_cams.append(support_cam) 64 | support_cams = torch.cat(support_cams, dim=0) # nshot 50 50 65 | 66 | batch = {'query_img': query_img, 67 | 'query_mask': query_mask, 68 | 'query_name': query_name, 69 | 70 | 'org_query_imsize': org_qry_imsize, 71 | 72 | 'support_imgs': support_imgs, 73 | 'support_masks': support_masks, 74 | 'support_names': support_names, 75 | 'class_id': torch.tensor(class_sample), 76 | 'query_cam': query_cam, 77 | 'support_cams': support_cams 78 | } 79 | 80 | return batch 81 | 82 | def build_class_ids(self): 83 | nclass_trn = self.nclass // self.nfolds 84 | class_ids_val = [self.fold + self.nfolds * v for v in range(nclass_trn)] 85 | class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val] 86 | class_ids = class_ids_trn if self.split == 'trn' else class_ids_val 87 | 88 | return class_ids 89 | 90 | def build_img_metadata_classwise(self): 91 | with open('./data/splits/coco/%s/fold%d.pkl' % (self.split, self.fold), 'rb') as f: 92 | img_metadata_classwise = pickle.load(f) 93 | return img_metadata_classwise 94 | 95 | def build_img_metadata(self): 96 | img_metadata = [] 97 | for k in self.img_metadata_classwise.keys(): 98 | img_metadata += self.img_metadata_classwise[k] 99 | return sorted(list(set(img_metadata))) 100 | 101 | def read_mask(self, name): 102 | mask_path = os.path.join(self.base_path, 'annotations', name) 103 | mask = torch.tensor(np.array(Image.open(mask_path[:mask_path.index('.jpg')] + '.png'))) 104 | return mask 105 | 106 | def load_frame(self): 107 | class_sample = np.random.choice(self.class_ids, 1, replace=False)[0] 108 | query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0] 109 | query_img = Image.open(os.path.join(self.base_path, query_name)).convert('RGB') 110 | query_mask = self.read_mask(query_name) 111 | 112 | org_qry_imsize = query_img.size 113 | 114 | query_mask[query_mask != 
class_sample + 1] = 0 115 | query_mask[query_mask == class_sample + 1] = 1 116 | 117 | support_names = [] 118 | while True: # keep sampling support set if query == support 119 | support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0] 120 | if query_name != support_name: 121 | support_names.append(support_name) 122 | if len(support_names) == self.shot: 123 | break 124 | 125 | support_imgs = [] 126 | support_masks = [] 127 | for support_name in support_names: 128 | support_imgs.append(Image.open(os.path.join(self.base_path, support_name)).convert('RGB')) 129 | support_mask = self.read_mask(support_name) 130 | support_mask[support_mask != class_sample + 1] = 0 131 | support_mask[support_mask == class_sample + 1] = 1 132 | support_masks.append(support_mask) 133 | 134 | return \ 135 | query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize 136 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | r""" Dataloader builder for few-shot semantic segmentation dataset """ 2 | from torchvision import transforms 3 | from torch.utils.data import DataLoader 4 | 5 | from data.pascal import DatasetPASCAL 6 | from data.coco import DatasetCOCO 7 | from data.fss import DatasetFSS 8 | 9 | 10 | class FSSDataset: 11 | 12 | @classmethod 13 | def initialize(cls, img_size, datapath, use_original_imgsize): 14 | 15 | cls.datasets = { 16 | 'pascal': DatasetPASCAL, 17 | 'coco': DatasetCOCO, 18 | 'fss': DatasetFSS, 19 | } 20 | 21 | cls.img_mean = [0.485, 0.456, 0.406] 22 | cls.img_std = [0.229, 0.224, 0.225] 23 | cls.datapath = datapath 24 | cls.use_original_imgsize = use_original_imgsize 25 | 26 | cls.transform = transforms.Compose([transforms.Resize(size=(img_size, img_size)), 27 | transforms.ToTensor(), 28 | transforms.Normalize(cls.img_mean, cls.img_std)]) 29 | 30 | @classmethod 31 | def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1, cam_train_path=None, cam_val_path=None): 32 | # Force randomness during training for diverse episode combinations 33 | # Freeze randomness during testing for reproducibility 34 | shuffle = split == 'trn' 35 | nworker = nworker if split == 'trn' else 0 36 | 37 | dataset = cls.datasets[benchmark](cls.datapath, fold=fold, transform=cls.transform, split=split, 38 | shot=shot, use_original_imgsize=cls.use_original_imgsize, 39 | cam_train_path=cam_train_path, cam_val_path=cam_val_path) 40 | dataloader = DataLoader(dataset, batch_size=bsz, shuffle=shuffle, num_workers=nworker) 41 | 42 | return dataloader 43 | -------------------------------------------------------------------------------- /data/fss.py: -------------------------------------------------------------------------------- 1 | r""" FSS-1000 few-shot semantic segmentation dataset """ 2 | import os 3 | import glob 4 | 5 | from torch.utils.data import Dataset 6 | import torch.nn.functional as F 7 | import torch 8 | import PIL.Image as Image 9 | import numpy as np 10 | 11 | 12 | class DatasetFSS(Dataset): 13 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize): 14 | self.split = split 15 | self.benchmark = 'fss' 16 | self.shot = shot 17 | 18 | self.base_path = os.path.join(datapath, 'FSS-1000') 19 | 20 | # Given predefined test split, load randomly generated training/val splits: 21 | # (reference regarding trn/val/test splits: 
https://github.com/HKUSTCV/FSS-1000/issues/7) 22 | with open('./data/splits/fss/%s.txt' % split, 'r') as f: 23 | self.categories = f.read().split('\n')[:-1] 24 | self.categories = sorted(self.categories) 25 | 26 | self.class_ids = self.build_class_ids() 27 | self.img_metadata = self.build_img_metadata() 28 | 29 | self.transform = transform 30 | 31 | def __len__(self): 32 | return len(self.img_metadata) 33 | 34 | def __getitem__(self, idx): 35 | query_name, support_names, class_sample = self.sample_episode(idx) 36 | query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names) 37 | 38 | query_img = self.transform(query_img) 39 | query_mask = F.interpolate( 40 | query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze() 41 | 42 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs]) 43 | 44 | support_masks_tmp = [] 45 | for smask in support_masks: 46 | smask = F.interpolate( 47 | smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze() 48 | support_masks_tmp.append(smask) 49 | support_masks = torch.stack(support_masks_tmp) 50 | 51 | batch = {'query_img': query_img, 52 | 'query_mask': query_mask, 53 | 'query_name': query_name, 54 | 55 | 'support_imgs': support_imgs, 56 | 'support_masks': support_masks, 57 | 'support_names': support_names, 58 | 59 | 'class_id': torch.tensor(class_sample)} 60 | 61 | return batch 62 | 63 | def load_frame(self, query_name, support_names): 64 | query_img = Image.open(query_name).convert('RGB') 65 | support_imgs = [Image.open(name).convert('RGB') for name in support_names] 66 | 67 | query_id = query_name.split('/')[-1].split('.')[0] 68 | query_name = os.path.join(os.path.dirname(query_name), query_id) + '.png' 69 | support_ids = [name.split('/')[-1].split('.')[0] for name in support_names] 70 | support_names = [os.path.join(os.path.dirname(name), sid) + '.png' 71 | for name, sid in zip(support_names, support_ids)] 72 | 73 | query_mask = self.read_mask(query_name) 74 | support_masks = [self.read_mask(name) for name in support_names] 75 | 76 | return query_img, query_mask, support_imgs, support_masks 77 | 78 | def read_mask(self, img_name): 79 | mask = torch.tensor(np.array(Image.open(img_name).convert('L'))) 80 | mask[mask < 128] = 0 81 | mask[mask >= 128] = 1 82 | return mask 83 | 84 | def sample_episode(self, idx): 85 | query_name = self.img_metadata[idx] 86 | class_sample = self.categories.index(query_name.split('/')[-2]) 87 | if self.split == 'val': 88 | class_sample += 520 89 | elif self.split == 'test': 90 | class_sample += 760 91 | 92 | support_names = [] 93 | while True: # keep sampling support set if query == support 94 | support_name = np.random.choice(range(1, 11), 1, replace=False)[0] 95 | support_name = os.path.join(os.path.dirname(query_name), str(support_name)) + '.jpg' 96 | if query_name != support_name: 97 | support_names.append(support_name) 98 | if len(support_names) == self.shot: 99 | break 100 | 101 | return query_name, support_names, class_sample 102 | 103 | def build_class_ids(self): 104 | if self.split == 'trn': 105 | class_ids = range(0, 520) 106 | elif self.split == 'val': 107 | class_ids = range(520, 760) 108 | elif self.split == 'test': 109 | class_ids = range(760, 1000) 110 | return class_ids 111 | 112 | def build_img_metadata(self): 113 | img_metadata = [] 114 | for cat in self.categories: 115 | img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, cat))]) 116 | for 
img_path in img_paths: 117 | if os.path.basename(img_path).split('.')[1] == 'jpg': 118 | img_metadata.append(img_path) 119 | return img_metadata 120 | -------------------------------------------------------------------------------- /data/pascal.py: -------------------------------------------------------------------------------- 1 | r""" PASCAL-5i few-shot semantic segmentation dataset """ 2 | import os 3 | import pdb 4 | 5 | from torch.utils.data import Dataset 6 | import torch.nn.functional as F 7 | import torch 8 | import PIL.Image as Image 9 | import numpy as np 10 | 11 | 12 | class DatasetPASCAL(Dataset): 13 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize, cam_train_path, cam_val_path): 14 | self.split = 'val' if split in ['val', 'test'] else 'trn' 15 | self.fold = fold 16 | self.nfolds = 4 17 | self.nclass = 20 18 | self.benchmark = 'pascal' 19 | self.shot = shot 20 | self.use_original_imgsize = use_original_imgsize 21 | 22 | self.img_path = os.path.join(datapath, 'VOC2012/JPEGImages/') 23 | self.ann_path = os.path.join(datapath, 'VOC2012/SegmentationClassAug/') 24 | self.transform = transform 25 | 26 | self.class_ids = self.build_class_ids() 27 | self.img_metadata = self.build_img_metadata() 28 | self.img_metadata_classwise = self.build_img_metadata_classwise() 29 | 30 | self.cam_train_path = cam_train_path 31 | self.cam_val_path = cam_val_path 32 | 33 | def __len__(self): 34 | return len(self.img_metadata) if self.split == 'trn' else 1000 35 | 36 | def __getitem__(self, idx): 37 | idx %= len(self.img_metadata) # for testing, as n_images < 1000 38 | query_name, support_names, class_sample = self.sample_episode(idx) 39 | query_img, query_cmask, support_imgs, support_cmasks, org_qry_imsize = \ 40 | self.load_frame(query_name, support_names) 41 | 42 | query_img = self.transform(query_img) 43 | if not self.use_original_imgsize: 44 | query_cmask = F.interpolate(query_cmask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], 45 | mode='nearest').squeeze() 46 | query_mask, query_ignore_idx = self.extract_ignore_idx(query_cmask.float(), class_sample) 47 | 48 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs]) 49 | 50 | support_masks = [] 51 | support_ignore_idxs = [] 52 | for scmask in support_cmasks: 53 | scmask = F.interpolate(scmask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], 54 | mode='nearest').squeeze() 55 | support_mask, support_ignore_idx = self.extract_ignore_idx(scmask, class_sample) 56 | support_masks.append(support_mask) 57 | support_ignore_idxs.append(support_ignore_idx) 58 | support_masks = torch.stack(support_masks) 59 | support_ignore_idxs = torch.stack(support_ignore_idxs) 60 | 61 | if self.split == 'val': 62 | query_cam_path = self.cam_val_path + query_name + '--' + str(class_sample) + '.pt' 63 | query_cam = torch.load(query_cam_path) 64 | nshot = len(support_names) 65 | support_cams = [] 66 | for nn in range(nshot): 67 | support_cam_path = self.cam_val_path + support_names[nn] + '--' + str(class_sample) + '.pt' 68 | support_cam = torch.load(support_cam_path).unsqueeze(0) 69 | support_cams.append(support_cam) 70 | support_cams = torch.cat(support_cams, dim=0) 71 | else: 72 | query_cam_path = self.cam_train_path + query_name + '--' + str(class_sample) + '.pt' 73 | query_cam = torch.load(query_cam_path) 74 | nshot = len(support_names) 75 | support_cams = [] 76 | for nn in range(nshot): 77 | support_cam_path = self.cam_train_path + support_names[nn] + '--' + str(class_sample) + 
'.pt' 78 | support_cam = torch.load(support_cam_path).unsqueeze(0) # 1 50 50 79 | support_cams.append(support_cam) 80 | support_cams = torch.cat(support_cams, dim=0) # nshot 50 50 81 | 82 | batch = {'query_img': query_img, 83 | 'query_mask': query_mask, 84 | 'query_name': query_name, 85 | 'query_ignore_idx': query_ignore_idx, 86 | 'org_query_imsize': org_qry_imsize, 87 | 'support_imgs': support_imgs, 88 | 'support_masks': support_masks, 89 | 'support_names': support_names, 90 | 'support_ignore_idxs': support_ignore_idxs, 91 | 'class_id': torch.tensor(class_sample), 92 | 'query_cam': query_cam, 93 | 'support_cams': support_cams} 94 | 95 | return batch 96 | 97 | def extract_ignore_idx(self, mask, class_id): 98 | boundary = (mask / 255).floor() 99 | mask[mask != class_id + 1] = 0 100 | mask[mask == class_id + 1] = 1 101 | 102 | return mask, boundary 103 | 104 | def load_frame(self, query_name, support_names): 105 | query_img = self.read_img(query_name) 106 | query_mask = self.read_mask(query_name) 107 | support_imgs = [self.read_img(name) for name in support_names] 108 | support_masks = [self.read_mask(name) for name in support_names] 109 | 110 | org_qry_imsize = query_img.size 111 | 112 | return query_img, query_mask, support_imgs, support_masks, org_qry_imsize 113 | 114 | def read_mask(self, img_name): 115 | r"""Return segmentation mask in PIL Image""" 116 | mask = torch.tensor(np.array(Image.open(os.path.join(self.ann_path, img_name) + '.png'))) 117 | return mask 118 | 119 | def read_img(self, img_name): 120 | r"""Return RGB image in PIL Image""" 121 | return Image.open(os.path.join(self.img_path, img_name) + '.jpg') 122 | 123 | def sample_episode(self, idx): 124 | query_name, class_sample = self.img_metadata[idx] 125 | 126 | support_names = [] 127 | while True: # keep sampling support set if query == support 128 | support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0] 129 | if query_name != support_name: 130 | support_names.append(support_name) 131 | if len(support_names) == self.shot: 132 | break 133 | 134 | return query_name, support_names, class_sample 135 | 136 | def build_class_ids(self): 137 | nclass_trn = self.nclass // self.nfolds 138 | class_ids_val = [self.fold * nclass_trn + i for i in range(nclass_trn)] 139 | class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val] 140 | 141 | if self.split == 'trn': 142 | return class_ids_trn 143 | else: 144 | return class_ids_val 145 | 146 | def build_img_metadata(self): 147 | 148 | def read_metadata(split, fold_id): 149 | fold_n_metadata = os.path.join('data/splits/pascal/%s/fold%d.txt' % (split, fold_id)) 150 | with open(fold_n_metadata, 'r') as f: 151 | fold_n_metadata = f.read().split('\n')[:-1] 152 | fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata] 153 | return fold_n_metadata 154 | 155 | img_metadata = [] 156 | if self.split == 'trn': # For training, read image-metadata of "the other" folds 157 | for fold_id in range(self.nfolds): 158 | if fold_id == self.fold: # Skip validation fold 159 | continue 160 | img_metadata += read_metadata(self.split, fold_id) 161 | elif self.split == 'val': # For validation, read image-metadata of "current" fold 162 | img_metadata = read_metadata(self.split, self.fold) 163 | else: 164 | raise Exception('Undefined split %s: ' % self.split) 165 | 166 | print('Total (%s) images are : %d' % (self.split, len(img_metadata))) 167 | 168 | return img_metadata 169 | 170 | def 
build_img_metadata_classwise(self): 171 | img_metadata_classwise = {} 172 | for class_id in range(self.nclass): 173 | img_metadata_classwise[class_id] = [] 174 | 175 | for img_name, img_class in self.img_metadata: 176 | img_metadata_classwise[img_class] += [img_name] 177 | return img_metadata_classwise 178 | -------------------------------------------------------------------------------- /data/splits/coco/trn/fold0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/trn/fold0.pkl -------------------------------------------------------------------------------- /data/splits/coco/trn/fold1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/trn/fold1.pkl -------------------------------------------------------------------------------- /data/splits/coco/trn/fold2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/trn/fold2.pkl -------------------------------------------------------------------------------- /data/splits/coco/trn/fold3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/trn/fold3.pkl -------------------------------------------------------------------------------- /data/splits/coco/val/fold0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/val/fold0.pkl -------------------------------------------------------------------------------- /data/splits/coco/val/fold1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/val/fold1.pkl -------------------------------------------------------------------------------- /data/splits/coco/val/fold2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/val/fold2.pkl -------------------------------------------------------------------------------- /data/splits/coco/val/fold3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/data/splits/coco/val/fold3.pkl -------------------------------------------------------------------------------- /data/splits/fss/test.txt: -------------------------------------------------------------------------------- 1 | bus 2 | hotel_slipper 3 | burj_al 4 | reflex_camera 5 | abe's_flyingfish 6 | oiltank_car 7 | doormat 8 | fish_eagle 9 | barber_shaver 10 | motorbike 11 | feather_clothes 12 | wandering_albatross 13 | rice_cooker 14 | delta_wing 15 | fish 16 | nintendo_switch 17 | bustard 18 | diver 19 | minicooper 20 | cathedrale_paris 21 | big_ben 22 | combination_lock 23 | villa_savoye 24 | american_alligator 25 | gym_ball 26 | andean_condor 27 | leggings 28 | pyramid_cube 29 | 
jet_aircraft 30 | meatloaf 31 | reel 32 | swan 33 | osprey 34 | crt_screen 35 | microscope 36 | rubber_eraser 37 | arrow 38 | monkey 39 | mitten 40 | spiderman 41 | parthenon 42 | bat 43 | chess_king 44 | sulphur_butterfly 45 | quail_egg 46 | oriole 47 | iron_man 48 | wooden_boat 49 | anise 50 | steering_wheel 51 | groenendael 52 | dwarf_beans 53 | pteropus 54 | chalk_brush 55 | bloodhound 56 | moon 57 | english_foxhound 58 | boxing_gloves 59 | peregine_falcon 60 | pyraminx 61 | cicada 62 | screw 63 | shower_curtain 64 | tredmill 65 | bulb 66 | bell_pepper 67 | lemur_catta 68 | doughnut 69 | twin_tower 70 | astronaut 71 | nintendo_3ds 72 | fennel_bulb 73 | indri 74 | captain_america_shield 75 | kunai 76 | broom 77 | iphone 78 | earphone1 79 | flying_squirrel 80 | onion 81 | vinyl 82 | sydney_opera_house 83 | oyster 84 | harmonica 85 | egg 86 | breast_pump 87 | guitar 88 | potato_chips 89 | tunnel 90 | cuckoo 91 | rubick_cube 92 | plastic_bag 93 | phonograph 94 | net_surface_shoes 95 | goldfinch 96 | ipad 97 | mite_predator 98 | coffee_mug 99 | golden_plover 100 | f1_racing 101 | lapwing 102 | nintendo_gba 103 | pizza 104 | rally_car 105 | drilling_platform 106 | cd 107 | fly 108 | magpie_bird 109 | leaf_fan 110 | little_blue_heron 111 | carriage 112 | moist_proof_pad 113 | flying_snakes 114 | dart_target 115 | warehouse_tray 116 | nintendo_wiiu 117 | chiffon_cake 118 | bath_ball 119 | manatee 120 | cloud 121 | marimba 122 | eagle 123 | ruler 124 | soymilk_machine 125 | sled 126 | seagull 127 | glider_flyingfish 128 | doublebus 129 | transport_helicopter 130 | window_screen 131 | truss_bridge 132 | wasp 133 | snowman 134 | poached_egg 135 | strawberry 136 | spinach 137 | earphone2 138 | downy_pitch 139 | taj_mahal 140 | rocking_chair 141 | cablestayed_bridge 142 | sealion 143 | banana_boat 144 | pheasant 145 | stone_lion 146 | electronic_stove 147 | fox 148 | iguana 149 | rugby_ball 150 | hang_glider 151 | water_buffalo 152 | lotus 153 | paper_plane 154 | missile 155 | flamingo 156 | american_chamelon 157 | kart 158 | chinese_knot 159 | cabbage_butterfly 160 | key 161 | church 162 | tiltrotor 163 | helicopter 164 | french_fries 165 | water_heater 166 | snow_leopard 167 | goblet 168 | fan 169 | snowplow 170 | leafhopper 171 | pspgo 172 | black_bear 173 | quail 174 | condor 175 | chandelier 176 | hair_razor 177 | white_wolf 178 | toaster 179 | pidan 180 | pyramid 181 | chicken_leg 182 | letter_opener 183 | apple_icon 184 | porcupine 185 | chicken 186 | stingray 187 | warplane 188 | windmill 189 | bamboo_slip 190 | wig 191 | flying_geckos 192 | stonechat 193 | haddock 194 | australian_terrier 195 | hover_board 196 | siamang 197 | canton_tower 198 | santa_sledge 199 | arch_bridge 200 | curlew 201 | sushi 202 | beet_root 203 | accordion 204 | leaf_egg 205 | stealth_aircraft 206 | stork 207 | bucket 208 | hawk 209 | chess_queen 210 | ocarina 211 | knife 212 | whippet 213 | cantilever_bridge 214 | may_bug 215 | wagtail 216 | leather_shoes 217 | wheelchair 218 | shumai 219 | speedboat 220 | vacuum_cup 221 | chess_knight 222 | pumpkin_pie 223 | wooden_spoon 224 | bamboo_dragonfly 225 | ganeva_chair 226 | soap 227 | clearwing_flyingfish 228 | pencil_sharpener1 229 | cricket 230 | photocopier 231 | nintendo_sp 232 | samarra_mosque 233 | clam 234 | charge_battery 235 | flying_frog 236 | ferrari911 237 | polo_shirt 238 | echidna 239 | coin 240 | tower_pisa 241 | -------------------------------------------------------------------------------- /data/splits/fss/trn.txt: 
-------------------------------------------------------------------------------- 1 | fountain 2 | taxi 3 | assult_rifle 4 | radio 5 | comb 6 | box_turtle 7 | igloo 8 | head_cabbage 9 | cottontail 10 | coho 11 | ashtray 12 | joystick 13 | sleeping_bag 14 | jackfruit 15 | trailer_truck 16 | shower_cap 17 | ibex 18 | kinguin 19 | squirrel 20 | ac_wall 21 | sidewinder 22 | remote_control 23 | marshmallow 24 | bolotie 25 | polar_bear 26 | rock_beauty 27 | tokyo_tower 28 | wafer 29 | red_bayberry 30 | electronic_toothbrush 31 | hartebeest 32 | cassette 33 | oil_filter 34 | bomb 35 | walnut 36 | toilet_tissue 37 | memory_stick 38 | wild_boar 39 | cableways 40 | chihuahua 41 | envelope 42 | bison 43 | poker 44 | pubg_lvl3helmet 45 | indian_cobra 46 | staffordshire 47 | park_bench 48 | wombat 49 | black_grouse 50 | submarine 51 | washer 52 | agama 53 | coyote 54 | feeder 55 | sarong 56 | buckingham_palace 57 | frog 58 | steam_locomotive 59 | acorn 60 | german_pointer 61 | obelisk 62 | polecat 63 | black_swan 64 | butterfly 65 | mountain_tent 66 | gorilla 67 | sloth_bear 68 | aubergine 69 | stinkhorn 70 | stole 71 | owl 72 | mooli 73 | pool_table 74 | collar 75 | lhasa_apso 76 | ambulance 77 | spade 78 | pufferfish 79 | paint_brush 80 | lark 81 | golf_ball 82 | hock 83 | fork 84 | drake 85 | bee_house 86 | mooncake 87 | wok 88 | cocacola 89 | water_bike 90 | ladder 91 | psp 92 | bassoon 93 | bear 94 | border_terrier 95 | petri_dish 96 | pill_bottle 97 | aircraft_carrier 98 | panther 99 | canoe 100 | baseball_player 101 | turtle 102 | espresso 103 | throne 104 | cornet 105 | coucal 106 | eletrical_switch 107 | bra 108 | snail 109 | backpack 110 | jacamar 111 | scroll_brush 112 | gliding_lizard 113 | raft 114 | pinwheel 115 | grasshopper 116 | green_mamba 117 | eft_newt 118 | computer_mouse 119 | vine_snake 120 | recreational_vehicle 121 | llama 122 | meerkat 123 | chainsaw 124 | ferret 125 | garbage_can 126 | kangaroo 127 | litchi 128 | carbonara 129 | housefinch 130 | modem 131 | tebby_cat 132 | thatch 133 | face_powder 134 | tomb 135 | apple 136 | ladybug 137 | killer_whale 138 | rocket 139 | airship 140 | surfboard 141 | lesser_panda 142 | jordan_logo 143 | banana 144 | nail_scissor 145 | swab 146 | perfume 147 | punching_bag 148 | victor_icon 149 | waffle_iron 150 | trimaran 151 | garlic 152 | flute 153 | langur 154 | starfish 155 | parallel_bars 156 | dandie_dinmont 157 | cosmetic_brush 158 | screwdriver 159 | brick_card 160 | balance_weight 161 | hornet 162 | carton 163 | toothpaste 164 | bracelet 165 | egg_tart 166 | pencil_sharpener2 167 | swimming_glasses 168 | howler_monkey 169 | camel 170 | dragonfly 171 | lionfish 172 | convertible 173 | mule 174 | usb 175 | conch 176 | papaya 177 | garbage_truck 178 | dingo 179 | radiator 180 | solar_dish 181 | streetcar 182 | trilobite 183 | bouzouki 184 | ringlet_butterfly 185 | space_shuttle 186 | waffle 187 | american_staffordshire 188 | violin 189 | flowerpot 190 | forklift 191 | manx 192 | sundial 193 | snowmobile 194 | chickadee_bird 195 | ruffed_grouse 196 | brick_tea 197 | paddle 198 | stove 199 | carousel 200 | spatula 201 | beaker 202 | gas_pump 203 | lawn_mower 204 | speaker 205 | tank 206 | tresher 207 | kappa_logo 208 | hare 209 | tennis_racket 210 | shopping_cart 211 | thimble 212 | tractor 213 | anemone_fish 214 | trolleybus 215 | steak 216 | capuchin 217 | red_breasted_merganser 218 | golden_retriever 219 | light_tube 220 | flatworm 221 | melon_seed 222 | digital_watch 223 | jacko_lantern 224 | brown_bear 225 | cairn 226 | mushroom 227 | 
chalk 228 | skull 229 | stapler 230 | potato 231 | telescope 232 | proboscis 233 | microphone 234 | torii 235 | baseball_bat 236 | dhole 237 | excavator 238 | fig 239 | snake 240 | bradypod 241 | pepitas 242 | prairie_chicken 243 | scorpion 244 | shotgun 245 | bottle_cap 246 | file_cabinet 247 | grey_whale 248 | one-armed_bandit 249 | banded_gecko 250 | flying_disc 251 | croissant 252 | toothbrush 253 | miniskirt 254 | pokermon_ball 255 | gazelle 256 | grey_fox 257 | esport_chair 258 | necklace 259 | ptarmigan 260 | watermelon 261 | besom 262 | pomelo 263 | radio_telescope 264 | studio_couch 265 | black_stork 266 | vestment 267 | koala 268 | brambling 269 | muscle_car 270 | window_shade 271 | space_heater 272 | sunglasses 273 | motor_scooter 274 | ladyfinger 275 | pencil_box 276 | titi_monkey 277 | chicken_wings 278 | mount_fuji 279 | giant_panda 280 | dart 281 | fire_engine 282 | running_shoe 283 | dumbbell 284 | donkey 285 | loafer 286 | hard_disk 287 | globe 288 | lifeboat 289 | medical_kit 290 | brain_coral 291 | paper_towel 292 | dugong 293 | seatbelt 294 | skunk 295 | military_vest 296 | cocktail_shaker 297 | zucchini 298 | quad_drone 299 | ocicat 300 | shih-tzu 301 | teapot 302 | tile_roof 303 | cheese_burger 304 | handshower 305 | red_wolf 306 | stop_sign 307 | mouse 308 | battery 309 | adidas_logo2 310 | earplug 311 | hummingbird 312 | brush_pen 313 | pistachio 314 | hamster 315 | air_strip 316 | indian_elephant 317 | otter 318 | cucumber 319 | scabbard 320 | hawthorn 321 | bullet_train 322 | leopard 323 | whale 324 | cream 325 | chinese_date 326 | jellyfish 327 | lobster 328 | skua 329 | single_log 330 | chicory 331 | bagel 332 | beacon 333 | pingpong_racket 334 | spoon 335 | yurt 336 | wallaby 337 | egret 338 | christmas_stocking 339 | mcdonald_uncle 340 | wrench 341 | spark_plug 342 | triceratops 343 | wall_clock 344 | jinrikisha 345 | pickup 346 | rhinoceros 347 | swimming_trunk 348 | band-aid 349 | spotted_salamander 350 | leeks 351 | marmot 352 | warthog 353 | cello 354 | stool 355 | chest 356 | toilet_plunger 357 | wardrobe 358 | cannon 359 | adidas_logo1 360 | drumstick 361 | lady_slipper 362 | puma_logo 363 | great_wall 364 | white_shark 365 | witch_hat 366 | vending_machine 367 | wreck 368 | chopsticks 369 | garfish 370 | african_elephant 371 | children_slide 372 | hornbill 373 | zebra 374 | boa_constrictor 375 | armour 376 | pineapple 377 | angora 378 | brick 379 | car_wheel 380 | wallet 381 | boston_bull 382 | hyena 383 | lynx 384 | crash_helmet 385 | terrapin_turtle 386 | persian_cat 387 | shift_gear 388 | cactus_ball 389 | fur_coat 390 | plate 391 | pen 392 | okra 393 | mario 394 | airedale 395 | cowboy_hat 396 | celery 397 | macaque 398 | candle 399 | goose 400 | raccoon 401 | brasscica 402 | almond 403 | maotai_bottle 404 | soccer_ball 405 | sports_car 406 | tobacco_pipe 407 | water_polo 408 | eggnog 409 | hook 410 | ostrich 411 | patas 412 | table_lamp 413 | teddy 414 | mongoose 415 | spoonbill 416 | redheart 417 | crane 418 | dinosaur 419 | kitchen_knife 420 | seal 421 | baboon 422 | golfcart 423 | roller_coaster 424 | avocado 425 | birdhouse 426 | yorkshire_terrier 427 | saluki 428 | basketball 429 | buckler 430 | harvester 431 | afghan_hound 432 | beam_bridge 433 | guinea_pig 434 | lorikeet 435 | shakuhachi 436 | motarboard 437 | statue_liberty 438 | police_car 439 | sulphur_crested 440 | gourd 441 | sombrero 442 | mailbox 443 | adhensive_tape 444 | night_snake 445 | bushtit 446 | mouthpiece 447 | beaver 448 | bathtub 449 | printer 450 | cumquat 451 | orange 
452 | cleaver 453 | quill_pen 454 | panpipe 455 | diamond 456 | gypsy_moth 457 | cauliflower 458 | lampshade 459 | cougar 460 | traffic_light 461 | briefcase 462 | ballpoint 463 | african_grey 464 | kremlin 465 | barometer 466 | peacock 467 | paper_crane 468 | sunscreen 469 | tofu 470 | bedlington_terrier 471 | snowball 472 | carrot 473 | tiger 474 | mink 475 | cristo_redentor 476 | ladle 477 | keyboard 478 | maraca 479 | monitor 480 | water_snake 481 | can_opener 482 | mud_turtle 483 | bald_eagle 484 | carp 485 | cn_tower 486 | egyptian_cat 487 | hen_of_the_woods 488 | measuring_cup 489 | roller_skate 490 | kite 491 | sandwich_cookies 492 | sandwich 493 | persimmon 494 | chess_bishop 495 | coffin 496 | ruddy_turnstone 497 | prayer_rug 498 | rain_barrel 499 | neck_brace 500 | nematode 501 | rosehip 502 | dutch_oven 503 | goldfish 504 | blossom_card 505 | dough 506 | trench_coat 507 | sponge 508 | stupa 509 | wash_basin 510 | electric_fan 511 | spring_scroll 512 | potted_plant 513 | sparrow 514 | car_mirror 515 | gecko 516 | diaper 517 | leatherback_turtle 518 | strainer 519 | guacamole 520 | microwave 521 | -------------------------------------------------------------------------------- /data/splits/fss/val.txt: -------------------------------------------------------------------------------- 1 | handcuff 2 | mortar 3 | matchstick 4 | wine_bottle 5 | dowitcher 6 | triumphal_arch 7 | gyromitra 8 | hatchet 9 | airliner 10 | broccoli 11 | olive 12 | pubg_lvl3backpack 13 | calculator 14 | toucan 15 | shovel 16 | sewing_machine 17 | icecream 18 | woodpecker 19 | pig 20 | relay_stick 21 | mcdonald_sign 22 | cpu 23 | peanut 24 | pumpkin 25 | sturgeon 26 | hammer 27 | hami_melon 28 | squirrel_monkey 29 | shuriken 30 | power_drill 31 | pingpong_ball 32 | crocodile 33 | carambola 34 | monarch_butterfly 35 | drum 36 | water_tower 37 | panda 38 | toilet_brush 39 | pay_phone 40 | yonex_icon 41 | cricketball 42 | revolver 43 | chimpanzee 44 | crab 45 | corn 46 | baseball 47 | rabbit 48 | croquet_ball 49 | artichoke 50 | abacus 51 | harp 52 | bell 53 | gas_tank 54 | scissors 55 | vase 56 | upright_piano 57 | typewriter 58 | bittern 59 | impala 60 | tray 61 | fire_hydrant 62 | beer_bottle 63 | sock 64 | soup_bowl 65 | spider 66 | cherry 67 | macaw 68 | toilet_seat 69 | fire_balloon 70 | french_ball 71 | fox_squirrel 72 | volleyball 73 | cornmeal 74 | folding_chair 75 | pubg_airdrop 76 | beagle 77 | skateboard 78 | narcissus 79 | whiptail 80 | cup 81 | arabian_camel 82 | badger 83 | stopwatch 84 | ab_wheel 85 | ox 86 | lettuce 87 | monocycle 88 | redshank 89 | vulture 90 | whistle 91 | smoothing_iron 92 | mashed_potato 93 | conveyor 94 | yoga_pad 95 | tow_truck 96 | siamese_cat 97 | cigar 98 | white_stork 99 | sniper_rifle 100 | stretcher 101 | tulip 102 | handkerchief 103 | basset 104 | iceberg 105 | gibbon 106 | lacewing 107 | thrush 108 | cheetah 109 | bighorn_sheep 110 | espresso_maker 111 | pretzel 112 | english_setter 113 | sandbar 114 | cheese 115 | daisy 116 | arctic_fox 117 | briard 118 | colubus 119 | balance_beam 120 | coffeepot 121 | soap_dispenser 122 | yawl 123 | consomme 124 | parking_meter 125 | cactus 126 | turnstile 127 | taro 128 | fire_screen 129 | digital_clock 130 | rose 131 | pomegranate 132 | bee_eater 133 | schooner 134 | ski_mask 135 | jay_bird 136 | plaice 137 | red_fox 138 | syringe 139 | camomile 140 | pickelhaube 141 | blenheim_spaniel 142 | pear 143 | parachute 144 | common_newt 145 | bowtie 146 | cigarette 147 | oscilloscope 148 | laptop 149 | african_crocodile 150 | apron 
151 | coconut 152 | sandal 153 | kwanyin 154 | lion 155 | eel 156 | balloon 157 | crepe 158 | armadillo 159 | kazoo 160 | lemon 161 | spider_monkey 162 | tape_player 163 | ipod 164 | bee 165 | sea_cucumber 166 | suitcase 167 | television 168 | pillow 169 | banjo 170 | rock_snake 171 | partridge 172 | platypus 173 | lycaenid_butterfly 174 | pinecone 175 | conversion_plug 176 | wolf 177 | frying_pan 178 | timber_wolf 179 | bluetick 180 | crayon 181 | giant_schnauzer 182 | orang 183 | scarerow 184 | kobe_logo 185 | loguat 186 | saxophone 187 | ceiling_fan 188 | cardoon 189 | equestrian_helmet 190 | louvre_pyramid 191 | hotdog 192 | ironing_board 193 | razor 194 | nagoya_castle 195 | loggerhead_turtle 196 | lipstick 197 | cradle 198 | strongbox 199 | raven 200 | kit_fox 201 | albatross 202 | flat-coated_retriever 203 | beer_glass 204 | ice_lolly 205 | sungnyemun 206 | totem_pole 207 | vacuum 208 | bolete 209 | mango 210 | ginger 211 | weasel 212 | cabbage 213 | refrigerator 214 | school_bus 215 | hippo 216 | tiger_cat 217 | saltshaker 218 | piano_keyboard 219 | windsor_tie 220 | sea_urchin 221 | microsd 222 | barbell 223 | swim_ring 224 | bulbul_bird 225 | water_ouzel 226 | ac_ground 227 | sweatshirt 228 | umbrella 229 | hair_drier 230 | hammerhead_shark 231 | tomato 232 | projector 233 | cushion 234 | dishwasher 235 | three-toed_sloth 236 | tiger_shark 237 | har_gow 238 | baby 239 | thor's_hammer 240 | nike_logo 241 | -------------------------------------------------------------------------------- /data/splits/pascal/val/fold0.txt: -------------------------------------------------------------------------------- 1 | 2007_000033__01 2 | 2007_000061__04 3 | 2007_000129__02 4 | 2007_000346__05 5 | 2007_000529__04 6 | 2007_000559__05 7 | 2007_000572__02 8 | 2007_000762__05 9 | 2007_001288__01 10 | 2007_001289__03 11 | 2007_001311__02 12 | 2007_001408__05 13 | 2007_001568__01 14 | 2007_001630__02 15 | 2007_001761__01 16 | 2007_001884__01 17 | 2007_002094__03 18 | 2007_002266__01 19 | 2007_002376__01 20 | 2007_002400__03 21 | 2007_002619__01 22 | 2007_002719__04 23 | 2007_003088__05 24 | 2007_003131__04 25 | 2007_003188__02 26 | 2007_003349__03 27 | 2007_003571__04 28 | 2007_003621__02 29 | 2007_003682__03 30 | 2007_003861__04 31 | 2007_004052__01 32 | 2007_004143__03 33 | 2007_004241__04 34 | 2007_004468__05 35 | 2007_005074__04 36 | 2007_005107__02 37 | 2007_005294__05 38 | 2007_005304__05 39 | 2007_005428__05 40 | 2007_005509__01 41 | 2007_005600__01 42 | 2007_005705__04 43 | 2007_005828__01 44 | 2007_006076__03 45 | 2007_006086__05 46 | 2007_006449__02 47 | 2007_006946__01 48 | 2007_007084__03 49 | 2007_007235__02 50 | 2007_007341__01 51 | 2007_007470__01 52 | 2007_007477__04 53 | 2007_007836__02 54 | 2007_008051__03 55 | 2007_008084__03 56 | 2007_008204__05 57 | 2007_008670__03 58 | 2007_009088__03 59 | 2007_009258__02 60 | 2007_009323__03 61 | 2007_009458__05 62 | 2007_009687__05 63 | 2007_009817__03 64 | 2007_009911__01 65 | 2008_000120__04 66 | 2008_000123__03 67 | 2008_000533__03 68 | 2008_000725__02 69 | 2008_000911__05 70 | 2008_001013__04 71 | 2008_001040__04 72 | 2008_001135__04 73 | 2008_001260__04 74 | 2008_001404__02 75 | 2008_001514__03 76 | 2008_001531__02 77 | 2008_001546__01 78 | 2008_001580__04 79 | 2008_001966__03 80 | 2008_001971__01 81 | 2008_002043__03 82 | 2008_002269__02 83 | 2008_002358__01 84 | 2008_002429__03 85 | 2008_002467__05 86 | 2008_002504__04 87 | 2008_002775__05 88 | 2008_002864__05 89 | 2008_003034__04 90 | 2008_003076__05 91 | 2008_003108__02 92 | 
2008_003110__03 93 | 2008_003155__01 94 | 2008_003270__02 95 | 2008_003369__01 96 | 2008_003858__04 97 | 2008_003876__01 98 | 2008_003886__04 99 | 2008_003926__01 100 | 2008_003976__01 101 | 2008_004363__02 102 | 2008_004654__02 103 | 2008_004659__05 104 | 2008_004704__01 105 | 2008_004758__02 106 | 2008_004995__02 107 | 2008_005262__05 108 | 2008_005338__01 109 | 2008_005628__04 110 | 2008_005727__02 111 | 2008_005812__05 112 | 2008_005904__05 113 | 2008_006216__01 114 | 2008_006229__04 115 | 2008_006254__02 116 | 2008_006703__01 117 | 2008_007120__03 118 | 2008_007143__04 119 | 2008_007219__05 120 | 2008_007350__01 121 | 2008_007498__03 122 | 2008_007811__05 123 | 2008_007994__03 124 | 2008_008268__03 125 | 2008_008629__02 126 | 2008_008711__02 127 | 2008_008746__03 128 | 2009_000032__01 129 | 2009_000037__03 130 | 2009_000121__05 131 | 2009_000149__02 132 | 2009_000201__05 133 | 2009_000205__01 134 | 2009_000318__03 135 | 2009_000354__02 136 | 2009_000387__01 137 | 2009_000421__04 138 | 2009_000440__01 139 | 2009_000446__04 140 | 2009_000457__02 141 | 2009_000469__04 142 | 2009_000573__02 143 | 2009_000619__03 144 | 2009_000664__03 145 | 2009_000723__04 146 | 2009_000828__04 147 | 2009_000840__05 148 | 2009_000879__03 149 | 2009_000991__03 150 | 2009_000998__03 151 | 2009_001108__03 152 | 2009_001160__03 153 | 2009_001255__02 154 | 2009_001278__05 155 | 2009_001314__03 156 | 2009_001332__01 157 | 2009_001565__03 158 | 2009_001607__03 159 | 2009_001683__03 160 | 2009_001718__02 161 | 2009_001765__03 162 | 2009_001818__05 163 | 2009_001850__01 164 | 2009_001851__01 165 | 2009_001941__04 166 | 2009_002185__05 167 | 2009_002295__02 168 | 2009_002320__01 169 | 2009_002372__05 170 | 2009_002521__05 171 | 2009_002594__05 172 | 2009_002604__03 173 | 2009_002649__05 174 | 2009_002727__04 175 | 2009_002732__05 176 | 2009_002749__05 177 | 2009_002808__01 178 | 2009_002856__05 179 | 2009_002888__01 180 | 2009_002928__02 181 | 2009_003003__05 182 | 2009_003005__01 183 | 2009_003043__04 184 | 2009_003080__04 185 | 2009_003193__02 186 | 2009_003224__02 187 | 2009_003269__05 188 | 2009_003273__03 189 | 2009_003343__02 190 | 2009_003378__03 191 | 2009_003450__03 192 | 2009_003498__03 193 | 2009_003504__04 194 | 2009_003517__05 195 | 2009_003640__03 196 | 2009_003696__01 197 | 2009_003707__04 198 | 2009_003806__01 199 | 2009_003858__03 200 | 2009_003971__02 201 | 2009_004021__03 202 | 2009_004084__03 203 | 2009_004125__04 204 | 2009_004247__05 205 | 2009_004324__05 206 | 2009_004509__03 207 | 2009_004540__03 208 | 2009_004568__03 209 | 2009_004579__05 210 | 2009_004635__04 211 | 2009_004653__01 212 | 2009_004848__02 213 | 2009_004882__02 214 | 2009_004886__03 215 | 2009_004895__03 216 | 2009_004969__01 217 | 2009_005038__05 218 | 2009_005137__03 219 | 2009_005156__02 220 | 2009_005189__01 221 | 2009_005190__05 222 | 2009_005260__03 223 | 2009_005262__03 224 | 2009_005302__05 225 | 2010_000065__02 226 | 2010_000083__02 227 | 2010_000084__04 228 | 2010_000238__01 229 | 2010_000241__03 230 | 2010_000272__04 231 | 2010_000342__02 232 | 2010_000426__05 233 | 2010_000572__01 234 | 2010_000622__01 235 | 2010_000814__03 236 | 2010_000906__04 237 | 2010_000961__03 238 | 2010_001016__03 239 | 2010_001017__01 240 | 2010_001024__01 241 | 2010_001036__04 242 | 2010_001061__03 243 | 2010_001069__03 244 | 2010_001174__01 245 | 2010_001367__02 246 | 2010_001367__05 247 | 2010_001448__01 248 | 2010_001830__05 249 | 2010_001995__03 250 | 2010_002017__05 251 | 2010_002030__02 252 | 2010_002142__03 253 | 2010_002147__01 254 
| 2010_002150__04 255 | 2010_002200__01 256 | 2010_002310__01 257 | 2010_002536__02 258 | 2010_002546__04 259 | 2010_002693__02 260 | 2010_002939__01 261 | 2010_003127__01 262 | 2010_003132__01 263 | 2010_003168__03 264 | 2010_003362__03 265 | 2010_003365__01 266 | 2010_003418__03 267 | 2010_003468__05 268 | 2010_003473__03 269 | 2010_003495__01 270 | 2010_003547__04 271 | 2010_003716__01 272 | 2010_003771__03 273 | 2010_003781__05 274 | 2010_003820__03 275 | 2010_003912__02 276 | 2010_003915__01 277 | 2010_004041__04 278 | 2010_004056__05 279 | 2010_004208__04 280 | 2010_004314__01 281 | 2010_004419__01 282 | 2010_004520__05 283 | 2010_004529__05 284 | 2010_004551__05 285 | 2010_004556__03 286 | 2010_004559__03 287 | 2010_004662__04 288 | 2010_004772__04 289 | 2010_004828__05 290 | 2010_004994__03 291 | 2010_005252__04 292 | 2010_005401__04 293 | 2010_005428__03 294 | 2010_005496__05 295 | 2010_005531__03 296 | 2010_005534__01 297 | 2010_005582__05 298 | 2010_005664__02 299 | 2010_005705__04 300 | 2010_005718__01 301 | 2010_005762__05 302 | 2010_005877__01 303 | 2010_005888__01 304 | 2010_006034__01 305 | 2010_006070__02 306 | 2011_000066__05 307 | 2011_000112__03 308 | 2011_000185__03 309 | 2011_000234__04 310 | 2011_000238__04 311 | 2011_000412__02 312 | 2011_000435__04 313 | 2011_000456__03 314 | 2011_000482__03 315 | 2011_000585__02 316 | 2011_000669__03 317 | 2011_000747__05 318 | 2011_000874__01 319 | 2011_001114__01 320 | 2011_001161__04 321 | 2011_001263__01 322 | 2011_001287__03 323 | 2011_001407__01 324 | 2011_001421__03 325 | 2011_001434__01 326 | 2011_001589__04 327 | 2011_001624__01 328 | 2011_001793__04 329 | 2011_001880__01 330 | 2011_001988__02 331 | 2011_002064__02 332 | 2011_002098__05 333 | 2011_002223__02 334 | 2011_002295__03 335 | 2011_002327__01 336 | 2011_002515__01 337 | 2011_002675__01 338 | 2011_002713__02 339 | 2011_002754__04 340 | 2011_002863__05 341 | 2011_002929__01 342 | 2011_002975__04 343 | 2011_003003__02 344 | 2011_003030__03 345 | 2011_003145__03 346 | 2011_003271__05 347 | -------------------------------------------------------------------------------- /data/splits/pascal/val/fold1.txt: -------------------------------------------------------------------------------- 1 | 2007_000452__09 2 | 2007_000464__10 3 | 2007_000491__10 4 | 2007_000663__06 5 | 2007_000663__07 6 | 2007_000727__06 7 | 2007_000727__07 8 | 2007_000804__09 9 | 2007_000830__09 10 | 2007_001299__10 11 | 2007_001321__07 12 | 2007_001457__09 13 | 2007_001677__09 14 | 2007_001717__09 15 | 2007_001763__08 16 | 2007_001774__08 17 | 2007_001884__06 18 | 2007_002268__08 19 | 2007_002387__10 20 | 2007_002445__08 21 | 2007_002470__08 22 | 2007_002539__06 23 | 2007_002597__08 24 | 2007_002643__07 25 | 2007_002903__10 26 | 2007_003011__09 27 | 2007_003051__07 28 | 2007_003101__06 29 | 2007_003106__08 30 | 2007_003137__06 31 | 2007_003143__07 32 | 2007_003169__08 33 | 2007_003195__06 34 | 2007_003201__10 35 | 2007_003503__06 36 | 2007_003503__07 37 | 2007_003621__06 38 | 2007_003711__06 39 | 2007_003786__06 40 | 2007_003841__10 41 | 2007_003917__07 42 | 2007_003991__08 43 | 2007_004193__09 44 | 2007_004392__09 45 | 2007_004405__09 46 | 2007_004510__09 47 | 2007_004712__09 48 | 2007_004856__08 49 | 2007_004866__08 50 | 2007_005074__07 51 | 2007_005114__10 52 | 2007_005296__07 53 | 2007_005331__07 54 | 2007_005460__08 55 | 2007_005547__07 56 | 2007_005547__10 57 | 2007_005844__09 58 | 2007_005845__08 59 | 2007_005911__06 60 | 2007_005978__06 61 | 2007_006035__07 62 | 2007_006086__09 63 | 
2007_006241__09 64 | 2007_006260__08 65 | 2007_006277__07 66 | 2007_006348__09 67 | 2007_006553__09 68 | 2007_006761__10 69 | 2007_006841__10 70 | 2007_007414__07 71 | 2007_007417__08 72 | 2007_007524__08 73 | 2007_007815__07 74 | 2007_007818__07 75 | 2007_007996__09 76 | 2007_008106__09 77 | 2007_008110__09 78 | 2007_008543__09 79 | 2007_008722__10 80 | 2007_008747__06 81 | 2007_008815__08 82 | 2007_008897__09 83 | 2007_008973__10 84 | 2007_009015__06 85 | 2007_009015__07 86 | 2007_009068__09 87 | 2007_009084__09 88 | 2007_009096__07 89 | 2007_009221__08 90 | 2007_009245__10 91 | 2007_009346__08 92 | 2007_009392__06 93 | 2007_009392__07 94 | 2007_009413__09 95 | 2007_009521__09 96 | 2007_009764__06 97 | 2007_009794__08 98 | 2007_009897__10 99 | 2007_009923__08 100 | 2007_009938__07 101 | 2008_000009__10 102 | 2008_000073__10 103 | 2008_000075__06 104 | 2008_000107__09 105 | 2008_000149__09 106 | 2008_000182__08 107 | 2008_000345__08 108 | 2008_000401__08 109 | 2008_000464__08 110 | 2008_000501__07 111 | 2008_000673__09 112 | 2008_000853__08 113 | 2008_000919__10 114 | 2008_001078__08 115 | 2008_001433__08 116 | 2008_001439__09 117 | 2008_001513__08 118 | 2008_001640__08 119 | 2008_001715__09 120 | 2008_001885__08 121 | 2008_002152__08 122 | 2008_002205__06 123 | 2008_002212__07 124 | 2008_002379__09 125 | 2008_002521__09 126 | 2008_002623__08 127 | 2008_002681__08 128 | 2008_002778__10 129 | 2008_002958__07 130 | 2008_003141__06 131 | 2008_003141__07 132 | 2008_003333__07 133 | 2008_003477__09 134 | 2008_003499__08 135 | 2008_003577__07 136 | 2008_003777__06 137 | 2008_003821__09 138 | 2008_003846__07 139 | 2008_004069__07 140 | 2008_004339__07 141 | 2008_004552__07 142 | 2008_004612__09 143 | 2008_004701__10 144 | 2008_005097__10 145 | 2008_005105__10 146 | 2008_005245__07 147 | 2008_005676__06 148 | 2008_006008__09 149 | 2008_006063__10 150 | 2008_006254__07 151 | 2008_006325__08 152 | 2008_006341__08 153 | 2008_006480__08 154 | 2008_006528__10 155 | 2008_006554__06 156 | 2008_006986__07 157 | 2008_007025__10 158 | 2008_007031__10 159 | 2008_007048__09 160 | 2008_007123__10 161 | 2008_007194__09 162 | 2008_007273__10 163 | 2008_007378__09 164 | 2008_007402__09 165 | 2008_007527__09 166 | 2008_007548__08 167 | 2008_007596__10 168 | 2008_007737__09 169 | 2008_007797__06 170 | 2008_007804__07 171 | 2008_007828__09 172 | 2008_008252__06 173 | 2008_008301__06 174 | 2008_008469__06 175 | 2008_008682__06 176 | 2009_000013__08 177 | 2009_000080__08 178 | 2009_000219__10 179 | 2009_000309__10 180 | 2009_000335__06 181 | 2009_000335__07 182 | 2009_000426__06 183 | 2009_000455__06 184 | 2009_000457__07 185 | 2009_000523__07 186 | 2009_000641__10 187 | 2009_000716__08 188 | 2009_000731__10 189 | 2009_000771__10 190 | 2009_000825__07 191 | 2009_000964__08 192 | 2009_001008__08 193 | 2009_001082__06 194 | 2009_001240__07 195 | 2009_001255__07 196 | 2009_001299__09 197 | 2009_001391__08 198 | 2009_001411__08 199 | 2009_001536__07 200 | 2009_001775__09 201 | 2009_001804__06 202 | 2009_001816__06 203 | 2009_001854__06 204 | 2009_002035__10 205 | 2009_002122__10 206 | 2009_002150__10 207 | 2009_002164__07 208 | 2009_002171__10 209 | 2009_002221__10 210 | 2009_002238__06 211 | 2009_002238__07 212 | 2009_002239__07 213 | 2009_002268__08 214 | 2009_002346__09 215 | 2009_002415__09 216 | 2009_002487__09 217 | 2009_002527__08 218 | 2009_002535__06 219 | 2009_002549__10 220 | 2009_002571__09 221 | 2009_002618__07 222 | 2009_002635__10 223 | 2009_002753__08 224 | 2009_002936__08 225 | 2009_002990__07 226 | 
2009_003003__07 227 | 2009_003059__10 228 | 2009_003071__09 229 | 2009_003269__07 230 | 2009_003304__06 231 | 2009_003387__07 232 | 2009_003406__07 233 | 2009_003494__09 234 | 2009_003507__09 235 | 2009_003542__10 236 | 2009_003549__07 237 | 2009_003569__10 238 | 2009_003589__07 239 | 2009_003703__06 240 | 2009_003771__08 241 | 2009_003773__10 242 | 2009_003849__09 243 | 2009_003895__09 244 | 2009_003904__08 245 | 2009_004072__06 246 | 2009_004140__09 247 | 2009_004217__09 248 | 2009_004248__08 249 | 2009_004455__07 250 | 2009_004504__08 251 | 2009_004590__06 252 | 2009_004594__07 253 | 2009_004687__09 254 | 2009_004721__08 255 | 2009_004732__06 256 | 2009_004748__07 257 | 2009_004789__06 258 | 2009_004859__09 259 | 2009_004867__06 260 | 2009_005158__08 261 | 2009_005219__08 262 | 2009_005231__06 263 | 2010_000003__09 264 | 2010_000160__07 265 | 2010_000163__08 266 | 2010_000372__07 267 | 2010_000427__10 268 | 2010_000530__07 269 | 2010_000552__08 270 | 2010_000573__06 271 | 2010_000628__07 272 | 2010_000639__09 273 | 2010_000682__06 274 | 2010_000683__08 275 | 2010_000724__08 276 | 2010_000907__10 277 | 2010_000941__08 278 | 2010_000952__07 279 | 2010_001000__10 280 | 2010_001010__10 281 | 2010_001070__08 282 | 2010_001206__06 283 | 2010_001292__08 284 | 2010_001331__08 285 | 2010_001351__08 286 | 2010_001403__06 287 | 2010_001403__07 288 | 2010_001534__08 289 | 2010_001553__07 290 | 2010_001579__09 291 | 2010_001646__06 292 | 2010_001656__08 293 | 2010_001692__10 294 | 2010_001699__09 295 | 2010_001767__07 296 | 2010_001851__09 297 | 2010_001913__08 298 | 2010_002017__07 299 | 2010_002017__09 300 | 2010_002025__08 301 | 2010_002137__08 302 | 2010_002146__08 303 | 2010_002305__08 304 | 2010_002336__09 305 | 2010_002348__08 306 | 2010_002361__07 307 | 2010_002390__10 308 | 2010_002422__08 309 | 2010_002512__08 310 | 2010_002531__08 311 | 2010_002546__06 312 | 2010_002623__09 313 | 2010_002693__08 314 | 2010_002693__09 315 | 2010_002763__08 316 | 2010_002763__10 317 | 2010_002868__06 318 | 2010_002900__08 319 | 2010_002902__07 320 | 2010_002921__09 321 | 2010_002929__07 322 | 2010_002988__07 323 | 2010_003123__07 324 | 2010_003183__10 325 | 2010_003231__07 326 | 2010_003239__10 327 | 2010_003275__08 328 | 2010_003276__07 329 | 2010_003293__06 330 | 2010_003302__09 331 | 2010_003325__09 332 | 2010_003381__07 333 | 2010_003402__08 334 | 2010_003409__09 335 | 2010_003446__07 336 | 2010_003453__07 337 | 2010_003468__08 338 | 2010_003531__09 339 | 2010_003675__08 340 | 2010_003746__07 341 | 2010_003758__08 342 | 2010_003764__08 343 | 2010_003768__07 344 | 2010_003772__06 345 | 2010_003781__08 346 | 2010_003813__07 347 | 2010_003854__07 348 | 2010_003971__08 349 | 2010_003971__09 350 | 2010_004104__08 351 | 2010_004120__08 352 | 2010_004320__08 353 | 2010_004322__10 354 | 2010_004348__06 355 | 2010_004369__08 356 | 2010_004472__07 357 | 2010_004479__08 358 | 2010_004635__10 359 | 2010_004763__09 360 | 2010_004783__09 361 | 2010_004789__10 362 | 2010_004815__08 363 | 2010_004825__09 364 | 2010_004861__08 365 | 2010_004946__07 366 | 2010_005013__07 367 | 2010_005021__08 368 | 2010_005021__09 369 | 2010_005063__06 370 | 2010_005108__08 371 | 2010_005118__06 372 | 2010_005160__06 373 | 2010_005166__10 374 | 2010_005284__06 375 | 2010_005344__08 376 | 2010_005421__08 377 | 2010_005432__07 378 | 2010_005501__07 379 | 2010_005508__08 380 | 2010_005606__08 381 | 2010_005709__08 382 | 2010_005718__07 383 | 2010_005860__07 384 | 2010_005899__08 385 | 2010_006070__07 386 | 2011_000178__06 387 | 
2011_000226__09 388 | 2011_000239__06 389 | 2011_000248__06 390 | 2011_000312__06 391 | 2011_000338__09 392 | 2011_000419__08 393 | 2011_000503__07 394 | 2011_000548__10 395 | 2011_000566__10 396 | 2011_000607__09 397 | 2011_000661__08 398 | 2011_000661__09 399 | 2011_000780__08 400 | 2011_000789__08 401 | 2011_000809__09 402 | 2011_000813__08 403 | 2011_000813__09 404 | 2011_000830__06 405 | 2011_000843__09 406 | 2011_000888__06 407 | 2011_000900__07 408 | 2011_000969__06 409 | 2011_001047__10 410 | 2011_001064__06 411 | 2011_001071__09 412 | 2011_001110__07 413 | 2011_001159__10 414 | 2011_001232__10 415 | 2011_001292__08 416 | 2011_001341__06 417 | 2011_001346__09 418 | 2011_001447__09 419 | 2011_001530__10 420 | 2011_001534__08 421 | 2011_001546__10 422 | 2011_001567__09 423 | 2011_001597__08 424 | 2011_001601__08 425 | 2011_001607__08 426 | 2011_001665__09 427 | 2011_001708__10 428 | 2011_001775__08 429 | 2011_001782__10 430 | 2011_001812__09 431 | 2011_002041__09 432 | 2011_002064__07 433 | 2011_002124__09 434 | 2011_002200__09 435 | 2011_002298__09 436 | 2011_002322__07 437 | 2011_002343__09 438 | 2011_002358__09 439 | 2011_002391__09 440 | 2011_002509__09 441 | 2011_002592__07 442 | 2011_002644__09 443 | 2011_002685__08 444 | 2011_002812__07 445 | 2011_002885__10 446 | 2011_003011__09 447 | 2011_003019__07 448 | 2011_003019__10 449 | 2011_003055__07 450 | 2011_003103__09 451 | 2011_003114__06 452 | -------------------------------------------------------------------------------- /data/splits/pascal/val/fold2.txt: -------------------------------------------------------------------------------- 1 | 2007_000129__15 2 | 2007_000323__15 3 | 2007_000332__13 4 | 2007_000346__15 5 | 2007_000762__11 6 | 2007_000762__15 7 | 2007_000783__13 8 | 2007_000783__15 9 | 2007_000799__13 10 | 2007_000799__15 11 | 2007_000830__11 12 | 2007_000847__11 13 | 2007_000847__15 14 | 2007_000999__15 15 | 2007_001175__15 16 | 2007_001239__12 17 | 2007_001284__15 18 | 2007_001311__15 19 | 2007_001408__15 20 | 2007_001423__15 21 | 2007_001430__11 22 | 2007_001430__15 23 | 2007_001526__15 24 | 2007_001585__15 25 | 2007_001586__13 26 | 2007_001586__15 27 | 2007_001594__15 28 | 2007_001630__15 29 | 2007_001677__11 30 | 2007_001678__15 31 | 2007_001717__15 32 | 2007_001763__12 33 | 2007_001955__13 34 | 2007_002046__13 35 | 2007_002119__15 36 | 2007_002260__14 37 | 2007_002268__12 38 | 2007_002378__15 39 | 2007_002426__15 40 | 2007_002539__15 41 | 2007_002565__15 42 | 2007_002597__12 43 | 2007_002624__11 44 | 2007_002624__15 45 | 2007_002643__15 46 | 2007_002728__15 47 | 2007_002823__14 48 | 2007_002823__15 49 | 2007_002824__15 50 | 2007_002852__12 51 | 2007_003011__11 52 | 2007_003020__15 53 | 2007_003022__13 54 | 2007_003022__15 55 | 2007_003088__15 56 | 2007_003106__15 57 | 2007_003110__12 58 | 2007_003134__15 59 | 2007_003188__15 60 | 2007_003194__12 61 | 2007_003367__14 62 | 2007_003367__15 63 | 2007_003373__12 64 | 2007_003373__15 65 | 2007_003530__15 66 | 2007_003621__15 67 | 2007_003742__11 68 | 2007_003742__15 69 | 2007_003872__12 70 | 2007_004033__14 71 | 2007_004033__15 72 | 2007_004112__12 73 | 2007_004112__15 74 | 2007_004121__15 75 | 2007_004189__12 76 | 2007_004275__14 77 | 2007_004275__15 78 | 2007_004281__15 79 | 2007_004380__14 80 | 2007_004380__15 81 | 2007_004392__15 82 | 2007_004405__11 83 | 2007_004538__13 84 | 2007_004538__15 85 | 2007_004644__12 86 | 2007_004712__11 87 | 2007_004712__15 88 | 2007_004722__13 89 | 2007_004722__15 90 | 2007_004902__13 91 | 2007_004902__15 92 | 2007_005114__13 93 | 
2007_005114__15 94 | 2007_005149__12 95 | 2007_005173__14 96 | 2007_005173__15 97 | 2007_005281__15 98 | 2007_005304__15 99 | 2007_005331__13 100 | 2007_005331__15 101 | 2007_005354__14 102 | 2007_005354__15 103 | 2007_005509__15 104 | 2007_005547__15 105 | 2007_005608__14 106 | 2007_005608__15 107 | 2007_005696__12 108 | 2007_005759__14 109 | 2007_005803__11 110 | 2007_005844__11 111 | 2007_005845__15 112 | 2007_006028__15 113 | 2007_006076__15 114 | 2007_006086__11 115 | 2007_006117__15 116 | 2007_006171__12 117 | 2007_006171__15 118 | 2007_006241__11 119 | 2007_006364__13 120 | 2007_006364__15 121 | 2007_006373__15 122 | 2007_006444__12 123 | 2007_006444__15 124 | 2007_006560__15 125 | 2007_006647__14 126 | 2007_006647__15 127 | 2007_006698__15 128 | 2007_006802__15 129 | 2007_006841__15 130 | 2007_006864__15 131 | 2007_006866__13 132 | 2007_006866__15 133 | 2007_007007__11 134 | 2007_007007__15 135 | 2007_007109__13 136 | 2007_007109__15 137 | 2007_007195__15 138 | 2007_007203__15 139 | 2007_007211__14 140 | 2007_007235__15 141 | 2007_007417__12 142 | 2007_007493__15 143 | 2007_007498__11 144 | 2007_007498__15 145 | 2007_007651__11 146 | 2007_007651__15 147 | 2007_007688__14 148 | 2007_007748__13 149 | 2007_007748__15 150 | 2007_007795__15 151 | 2007_007810__11 152 | 2007_007810__15 153 | 2007_007815__15 154 | 2007_007836__15 155 | 2007_007849__15 156 | 2007_007996__15 157 | 2007_008110__15 158 | 2007_008204__15 159 | 2007_008222__12 160 | 2007_008256__13 161 | 2007_008256__15 162 | 2007_008260__12 163 | 2007_008374__15 164 | 2007_008415__12 165 | 2007_008430__15 166 | 2007_008596__13 167 | 2007_008596__15 168 | 2007_008708__15 169 | 2007_008802__13 170 | 2007_008897__15 171 | 2007_008944__15 172 | 2007_008964__12 173 | 2007_008964__15 174 | 2007_008980__12 175 | 2007_009068__15 176 | 2007_009084__12 177 | 2007_009084__14 178 | 2007_009251__13 179 | 2007_009251__15 180 | 2007_009258__15 181 | 2007_009320__15 182 | 2007_009331__12 183 | 2007_009331__13 184 | 2007_009331__15 185 | 2007_009413__11 186 | 2007_009413__15 187 | 2007_009521__11 188 | 2007_009562__12 189 | 2007_009592__12 190 | 2007_009654__15 191 | 2007_009655__15 192 | 2007_009684__15 193 | 2007_009687__15 194 | 2007_009691__14 195 | 2007_009691__15 196 | 2007_009706__11 197 | 2007_009750__15 198 | 2007_009756__14 199 | 2007_009756__15 200 | 2007_009841__13 201 | 2007_009938__14 202 | 2008_000080__12 203 | 2008_000213__15 204 | 2008_000215__15 205 | 2008_000223__15 206 | 2008_000233__15 207 | 2008_000234__15 208 | 2008_000239__12 209 | 2008_000270__12 210 | 2008_000270__15 211 | 2008_000271__15 212 | 2008_000359__15 213 | 2008_000474__15 214 | 2008_000510__15 215 | 2008_000573__11 216 | 2008_000573__15 217 | 2008_000602__13 218 | 2008_000630__15 219 | 2008_000661__12 220 | 2008_000661__15 221 | 2008_000662__15 222 | 2008_000666__15 223 | 2008_000673__15 224 | 2008_000700__15 225 | 2008_000725__15 226 | 2008_000731__15 227 | 2008_000763__11 228 | 2008_000763__15 229 | 2008_000765__13 230 | 2008_000782__14 231 | 2008_000795__15 232 | 2008_000811__14 233 | 2008_000811__15 234 | 2008_000863__12 235 | 2008_000943__12 236 | 2008_000992__15 237 | 2008_001013__15 238 | 2008_001028__15 239 | 2008_001070__12 240 | 2008_001074__15 241 | 2008_001076__15 242 | 2008_001150__14 243 | 2008_001170__15 244 | 2008_001231__15 245 | 2008_001249__15 246 | 2008_001283__15 247 | 2008_001308__15 248 | 2008_001379__12 249 | 2008_001404__15 250 | 2008_001478__12 251 | 2008_001491__15 252 | 2008_001504__15 253 | 2008_001531__15 254 | 2008_001547__15 
255 | 2008_001629__15 256 | 2008_001682__13 257 | 2008_001821__15 258 | 2008_001874__15 259 | 2008_001895__12 260 | 2008_001895__15 261 | 2008_001992__13 262 | 2008_001992__15 263 | 2008_002212__15 264 | 2008_002239__12 265 | 2008_002240__14 266 | 2008_002241__15 267 | 2008_002379__11 268 | 2008_002383__14 269 | 2008_002495__15 270 | 2008_002536__12 271 | 2008_002588__15 272 | 2008_002775__11 273 | 2008_002775__15 274 | 2008_002835__13 275 | 2008_002835__15 276 | 2008_002859__12 277 | 2008_002864__11 278 | 2008_002864__15 279 | 2008_002904__12 280 | 2008_002929__15 281 | 2008_002936__12 282 | 2008_002942__15 283 | 2008_002958__12 284 | 2008_003034__15 285 | 2008_003076__15 286 | 2008_003108__15 287 | 2008_003141__15 288 | 2008_003210__15 289 | 2008_003238__12 290 | 2008_003238__15 291 | 2008_003330__15 292 | 2008_003333__14 293 | 2008_003333__15 294 | 2008_003379__13 295 | 2008_003451__14 296 | 2008_003451__15 297 | 2008_003461__13 298 | 2008_003461__15 299 | 2008_003477__11 300 | 2008_003492__15 301 | 2008_003511__12 302 | 2008_003511__15 303 | 2008_003546__15 304 | 2008_003576__12 305 | 2008_003676__15 306 | 2008_003733__15 307 | 2008_003782__13 308 | 2008_003856__15 309 | 2008_003874__15 310 | 2008_004101__15 311 | 2008_004140__11 312 | 2008_004140__15 313 | 2008_004175__13 314 | 2008_004345__14 315 | 2008_004396__13 316 | 2008_004399__14 317 | 2008_004399__15 318 | 2008_004575__11 319 | 2008_004575__15 320 | 2008_004624__13 321 | 2008_004654__15 322 | 2008_004687__13 323 | 2008_004705__13 324 | 2008_005049__14 325 | 2008_005089__15 326 | 2008_005145__11 327 | 2008_005197__12 328 | 2008_005197__15 329 | 2008_005245__14 330 | 2008_005245__15 331 | 2008_005399__15 332 | 2008_005422__14 333 | 2008_005445__15 334 | 2008_005525__13 335 | 2008_005637__14 336 | 2008_005642__13 337 | 2008_005691__13 338 | 2008_005738__15 339 | 2008_005812__15 340 | 2008_005915__14 341 | 2008_006008__11 342 | 2008_006036__13 343 | 2008_006108__11 344 | 2008_006108__15 345 | 2008_006130__12 346 | 2008_006216__15 347 | 2008_006219__13 348 | 2008_006254__15 349 | 2008_006275__15 350 | 2008_006341__15 351 | 2008_006408__11 352 | 2008_006408__15 353 | 2008_006526__14 354 | 2008_006526__15 355 | 2008_006554__15 356 | 2008_006722__12 357 | 2008_006722__15 358 | 2008_006874__14 359 | 2008_006874__15 360 | 2008_006981__12 361 | 2008_007048__11 362 | 2008_007219__15 363 | 2008_007378__11 364 | 2008_007378__12 365 | 2008_007392__13 366 | 2008_007392__15 367 | 2008_007402__11 368 | 2008_007402__15 369 | 2008_007513__12 370 | 2008_007737__15 371 | 2008_007828__15 372 | 2008_007945__13 373 | 2008_007994__15 374 | 2008_008051__11 375 | 2008_008127__14 376 | 2008_008127__15 377 | 2008_008221__15 378 | 2008_008335__11 379 | 2008_008335__15 380 | 2008_008362__11 381 | 2008_008362__15 382 | 2008_008392__13 383 | 2008_008393__13 384 | 2008_008421__13 385 | 2008_008469__15 386 | 2009_000012__13 387 | 2009_000074__14 388 | 2009_000074__15 389 | 2009_000156__12 390 | 2009_000219__15 391 | 2009_000309__15 392 | 2009_000412__13 393 | 2009_000418__15 394 | 2009_000421__15 395 | 2009_000457__15 396 | 2009_000704__15 397 | 2009_000705__13 398 | 2009_000727__13 399 | 2009_000730__14 400 | 2009_000730__15 401 | 2009_000825__14 402 | 2009_000825__15 403 | 2009_000839__12 404 | 2009_000892__12 405 | 2009_000931__13 406 | 2009_000935__12 407 | 2009_001215__11 408 | 2009_001215__15 409 | 2009_001299__15 410 | 2009_001433__13 411 | 2009_001433__15 412 | 2009_001535__12 413 | 2009_001663__15 414 | 2009_001687__12 415 | 2009_001687__15 416 | 
2009_001718__15 417 | 2009_001768__15 418 | 2009_001854__15 419 | 2009_002012__12 420 | 2009_002042__15 421 | 2009_002097__13 422 | 2009_002155__12 423 | 2009_002165__13 424 | 2009_002185__15 425 | 2009_002239__14 426 | 2009_002239__15 427 | 2009_002317__14 428 | 2009_002317__15 429 | 2009_002346__12 430 | 2009_002346__15 431 | 2009_002372__15 432 | 2009_002382__14 433 | 2009_002382__15 434 | 2009_002415__11 435 | 2009_002445__12 436 | 2009_002487__11 437 | 2009_002539__12 438 | 2009_002571__11 439 | 2009_002584__15 440 | 2009_002649__15 441 | 2009_002651__14 442 | 2009_002651__15 443 | 2009_002732__15 444 | 2009_002975__13 445 | 2009_003003__11 446 | 2009_003003__15 447 | 2009_003063__12 448 | 2009_003065__15 449 | 2009_003071__11 450 | 2009_003071__15 451 | 2009_003123__11 452 | 2009_003196__14 453 | 2009_003217__12 454 | 2009_003241__12 455 | 2009_003269__15 456 | 2009_003323__13 457 | 2009_003323__15 458 | 2009_003466__12 459 | 2009_003481__13 460 | 2009_003494__15 461 | 2009_003507__11 462 | 2009_003576__14 463 | 2009_003576__15 464 | 2009_003756__12 465 | 2009_003804__13 466 | 2009_003810__12 467 | 2009_003849__11 468 | 2009_003849__15 469 | 2009_003903__13 470 | 2009_003928__12 471 | 2009_003991__11 472 | 2009_003991__15 473 | 2009_004033__12 474 | 2009_004043__14 475 | 2009_004043__15 476 | 2009_004140__11 477 | 2009_004221__15 478 | 2009_004455__14 479 | 2009_004497__13 480 | 2009_004507__12 481 | 2009_004507__15 482 | 2009_004581__12 483 | 2009_004592__12 484 | 2009_004738__14 485 | 2009_004738__15 486 | 2009_004848__15 487 | 2009_004859__11 488 | 2009_004859__15 489 | 2009_004942__13 490 | 2009_004987__14 491 | 2009_004987__15 492 | 2009_004994__12 493 | 2009_004994__15 494 | 2009_005038__11 495 | 2009_005038__15 496 | 2009_005078__14 497 | 2009_005087__15 498 | 2009_005217__13 499 | 2009_005217__15 500 | 2010_000003__12 501 | 2010_000038__13 502 | 2010_000038__15 503 | 2010_000087__14 504 | 2010_000087__15 505 | 2010_000110__12 506 | 2010_000110__15 507 | 2010_000159__12 508 | 2010_000174__11 509 | 2010_000174__15 510 | 2010_000216__12 511 | 2010_000238__15 512 | 2010_000256__15 513 | 2010_000422__12 514 | 2010_000530__15 515 | 2010_000559__15 516 | 2010_000639__12 517 | 2010_000666__13 518 | 2010_000666__15 519 | 2010_000738__15 520 | 2010_000788__12 521 | 2010_000874__13 522 | 2010_000904__12 523 | 2010_001024__15 524 | 2010_001124__12 525 | 2010_001251__14 526 | 2010_001264__12 527 | 2010_001313__14 528 | 2010_001313__15 529 | 2010_001367__15 530 | 2010_001376__12 531 | 2010_001451__13 532 | 2010_001553__14 533 | 2010_001563__12 534 | 2010_001563__15 535 | 2010_001579__11 536 | 2010_001579__15 537 | 2010_001692__15 538 | 2010_001699__15 539 | 2010_001734__15 540 | 2010_001767__15 541 | 2010_001851__11 542 | 2010_001908__12 543 | 2010_001956__12 544 | 2010_002017__15 545 | 2010_002137__15 546 | 2010_002161__13 547 | 2010_002161__15 548 | 2010_002228__12 549 | 2010_002251__14 550 | 2010_002251__15 551 | 2010_002271__14 552 | 2010_002336__11 553 | 2010_002396__14 554 | 2010_002396__15 555 | 2010_002480__12 556 | 2010_002623__15 557 | 2010_002691__13 558 | 2010_002763__15 559 | 2010_002792__15 560 | 2010_002902__15 561 | 2010_002929__15 562 | 2010_003014__15 563 | 2010_003060__12 564 | 2010_003187__12 565 | 2010_003207__14 566 | 2010_003239__15 567 | 2010_003325__11 568 | 2010_003325__15 569 | 2010_003381__15 570 | 2010_003409__15 571 | 2010_003446__15 572 | 2010_003506__12 573 | 2010_003531__11 574 | 2010_003532__13 575 | 2010_003597__11 576 | 2010_003597__15 577 | 
2010_003746__12 578 | 2010_003746__15 579 | 2010_003947__14 580 | 2010_003971__11 581 | 2010_004042__14 582 | 2010_004165__12 583 | 2010_004165__15 584 | 2010_004219__14 585 | 2010_004219__15 586 | 2010_004337__15 587 | 2010_004355__14 588 | 2010_004432__15 589 | 2010_004472__15 590 | 2010_004479__15 591 | 2010_004519__13 592 | 2010_004550__12 593 | 2010_004559__15 594 | 2010_004628__12 595 | 2010_004697__14 596 | 2010_004697__15 597 | 2010_004795__12 598 | 2010_004815__15 599 | 2010_004825__11 600 | 2010_004828__15 601 | 2010_004856__13 602 | 2010_004941__14 603 | 2010_004951__15 604 | 2010_005046__11 605 | 2010_005046__15 606 | 2010_005118__15 607 | 2010_005159__12 608 | 2010_005160__14 609 | 2010_005166__15 610 | 2010_005174__13 611 | 2010_005206__12 612 | 2010_005245__12 613 | 2010_005245__15 614 | 2010_005252__14 615 | 2010_005252__15 616 | 2010_005284__15 617 | 2010_005366__14 618 | 2010_005433__14 619 | 2010_005501__14 620 | 2010_005575__12 621 | 2010_005582__15 622 | 2010_005606__15 623 | 2010_005626__11 624 | 2010_005626__15 625 | 2010_005644__12 626 | 2010_005709__15 627 | 2010_005871__15 628 | 2010_005991__12 629 | 2010_005991__15 630 | 2010_005992__12 631 | 2011_000045__12 632 | 2011_000051__15 633 | 2011_000054__15 634 | 2011_000178__15 635 | 2011_000226__11 636 | 2011_000248__15 637 | 2011_000338__11 638 | 2011_000396__13 639 | 2011_000435__15 640 | 2011_000438__15 641 | 2011_000455__14 642 | 2011_000455__15 643 | 2011_000479__15 644 | 2011_000512__14 645 | 2011_000526__13 646 | 2011_000536__12 647 | 2011_000566__15 648 | 2011_000585__15 649 | 2011_000598__11 650 | 2011_000618__14 651 | 2011_000618__15 652 | 2011_000638__15 653 | 2011_000780__15 654 | 2011_000809__11 655 | 2011_000809__15 656 | 2011_000843__15 657 | 2011_000953__11 658 | 2011_000953__15 659 | 2011_001014__12 660 | 2011_001060__15 661 | 2011_001069__15 662 | 2011_001071__15 663 | 2011_001159__15 664 | 2011_001276__11 665 | 2011_001276__12 666 | 2011_001276__15 667 | 2011_001346__15 668 | 2011_001416__15 669 | 2011_001447__15 670 | 2011_001530__15 671 | 2011_001567__15 672 | 2011_001619__15 673 | 2011_001642__12 674 | 2011_001665__11 675 | 2011_001674__15 676 | 2011_001714__12 677 | 2011_001714__15 678 | 2011_001722__13 679 | 2011_001745__12 680 | 2011_001794__15 681 | 2011_001862__11 682 | 2011_001862__12 683 | 2011_001868__12 684 | 2011_001984__12 685 | 2011_001988__15 686 | 2011_002002__15 687 | 2011_002040__12 688 | 2011_002075__11 689 | 2011_002075__15 690 | 2011_002098__12 691 | 2011_002110__12 692 | 2011_002110__15 693 | 2011_002121__12 694 | 2011_002124__15 695 | 2011_002156__12 696 | 2011_002200__11 697 | 2011_002200__15 698 | 2011_002247__15 699 | 2011_002279__12 700 | 2011_002298__12 701 | 2011_002308__15 702 | 2011_002317__15 703 | 2011_002322__14 704 | 2011_002322__15 705 | 2011_002343__15 706 | 2011_002358__11 707 | 2011_002358__15 708 | 2011_002371__12 709 | 2011_002498__15 710 | 2011_002509__15 711 | 2011_002532__15 712 | 2011_002575__15 713 | 2011_002578__15 714 | 2011_002589__12 715 | 2011_002623__15 716 | 2011_002641__15 717 | 2011_002675__15 718 | 2011_002951__13 719 | 2011_002997__15 720 | 2011_003019__14 721 | 2011_003019__15 722 | 2011_003085__13 723 | 2011_003114__15 724 | 2011_003240__15 725 | 2011_003256__12 726 | -------------------------------------------------------------------------------- /data/splits/pascal/val/fold3.txt: -------------------------------------------------------------------------------- 1 | 2007_000042__19 2 | 2007_000123__19 3 | 2007_000175__17 4 | 2007_000187__20 
5 | 2007_000452__18 6 | 2007_000559__20 7 | 2007_000629__19 8 | 2007_000636__19 9 | 2007_000661__18 10 | 2007_000676__17 11 | 2007_000804__18 12 | 2007_000925__17 13 | 2007_001154__18 14 | 2007_001175__20 15 | 2007_001408__16 16 | 2007_001430__16 17 | 2007_001430__20 18 | 2007_001457__18 19 | 2007_001458__18 20 | 2007_001585__18 21 | 2007_001594__17 22 | 2007_001678__20 23 | 2007_001717__20 24 | 2007_001733__17 25 | 2007_001763__18 26 | 2007_001763__20 27 | 2007_002119__20 28 | 2007_002132__20 29 | 2007_002268__18 30 | 2007_002284__16 31 | 2007_002378__16 32 | 2007_002426__18 33 | 2007_002427__18 34 | 2007_002565__19 35 | 2007_002618__17 36 | 2007_002648__17 37 | 2007_002728__19 38 | 2007_003011__18 39 | 2007_003011__20 40 | 2007_003169__18 41 | 2007_003367__16 42 | 2007_003499__19 43 | 2007_003506__16 44 | 2007_003530__18 45 | 2007_003587__19 46 | 2007_003714__17 47 | 2007_003848__19 48 | 2007_003957__19 49 | 2007_004190__20 50 | 2007_004193__20 51 | 2007_004275__16 52 | 2007_004281__19 53 | 2007_004483__19 54 | 2007_004510__20 55 | 2007_004558__16 56 | 2007_004649__19 57 | 2007_004712__16 58 | 2007_004969__17 59 | 2007_005469__17 60 | 2007_005626__19 61 | 2007_005689__19 62 | 2007_005813__16 63 | 2007_005857__16 64 | 2007_005915__17 65 | 2007_006171__18 66 | 2007_006348__20 67 | 2007_006373__18 68 | 2007_006678__17 69 | 2007_006680__19 70 | 2007_006802__19 71 | 2007_007130__20 72 | 2007_007165__17 73 | 2007_007168__19 74 | 2007_007195__19 75 | 2007_007196__20 76 | 2007_007203__20 77 | 2007_007417__18 78 | 2007_007534__17 79 | 2007_007624__16 80 | 2007_007795__16 81 | 2007_007881__19 82 | 2007_007996__18 83 | 2007_008204__20 84 | 2007_008260__18 85 | 2007_008339__19 86 | 2007_008374__20 87 | 2007_008543__18 88 | 2007_008547__16 89 | 2007_009068__18 90 | 2007_009252__18 91 | 2007_009320__17 92 | 2007_009419__16 93 | 2007_009446__20 94 | 2007_009521__18 95 | 2007_009521__20 96 | 2007_009592__18 97 | 2007_009655__18 98 | 2007_009684__18 99 | 2007_009750__16 100 | 2008_000016__20 101 | 2008_000149__18 102 | 2008_000270__18 103 | 2008_000391__16 104 | 2008_000589__18 105 | 2008_000657__19 106 | 2008_001078__16 107 | 2008_001283__16 108 | 2008_001688__16 109 | 2008_001688__20 110 | 2008_001966__16 111 | 2008_002273__16 112 | 2008_002379__16 113 | 2008_002464__20 114 | 2008_002536__17 115 | 2008_002680__20 116 | 2008_002900__19 117 | 2008_002929__18 118 | 2008_003003__20 119 | 2008_003026__20 120 | 2008_003105__19 121 | 2008_003135__16 122 | 2008_003676__16 123 | 2008_003709__18 124 | 2008_003733__18 125 | 2008_003885__20 126 | 2008_004172__18 127 | 2008_004212__19 128 | 2008_004279__20 129 | 2008_004367__19 130 | 2008_004453__17 131 | 2008_004477__16 132 | 2008_004562__18 133 | 2008_004610__19 134 | 2008_004621__17 135 | 2008_004754__20 136 | 2008_004854__17 137 | 2008_004910__20 138 | 2008_005089__20 139 | 2008_005217__16 140 | 2008_005242__16 141 | 2008_005254__20 142 | 2008_005439__20 143 | 2008_005445__20 144 | 2008_005544__19 145 | 2008_005633__17 146 | 2008_005680__16 147 | 2008_006055__19 148 | 2008_006159__20 149 | 2008_006327__17 150 | 2008_006523__19 151 | 2008_006553__19 152 | 2008_006752__19 153 | 2008_006784__18 154 | 2008_006835__17 155 | 2008_007497__17 156 | 2008_007527__20 157 | 2008_007677__17 158 | 2008_007814__17 159 | 2008_007828__20 160 | 2008_008103__18 161 | 2008_008221__19 162 | 2008_008434__16 163 | 2009_000022__19 164 | 2009_000039__17 165 | 2009_000087__18 166 | 2009_000096__18 167 | 2009_000136__20 168 | 2009_000242__18 169 | 2009_000391__20 170 | 2009_000418__16 
171 | 2009_000418__18 172 | 2009_000487__18 173 | 2009_000488__16 174 | 2009_000488__20 175 | 2009_000628__19 176 | 2009_000675__17 177 | 2009_000704__20 178 | 2009_000712__19 179 | 2009_000732__18 180 | 2009_000845__19 181 | 2009_000924__17 182 | 2009_001300__19 183 | 2009_001333__19 184 | 2009_001363__20 185 | 2009_001505__17 186 | 2009_001644__16 187 | 2009_001644__18 188 | 2009_001644__20 189 | 2009_001684__16 190 | 2009_001731__18 191 | 2009_001768__17 192 | 2009_001775__16 193 | 2009_001775__18 194 | 2009_001991__17 195 | 2009_002082__17 196 | 2009_002094__20 197 | 2009_002202__19 198 | 2009_002265__19 199 | 2009_002291__19 200 | 2009_002346__18 201 | 2009_002366__20 202 | 2009_002390__18 203 | 2009_002487__16 204 | 2009_002562__20 205 | 2009_002568__19 206 | 2009_002571__16 207 | 2009_002571__18 208 | 2009_002573__20 209 | 2009_002584__16 210 | 2009_002638__19 211 | 2009_002732__18 212 | 2009_002887__19 213 | 2009_002982__19 214 | 2009_003105__19 215 | 2009_003123__18 216 | 2009_003299__19 217 | 2009_003311__19 218 | 2009_003433__19 219 | 2009_003523__20 220 | 2009_003551__20 221 | 2009_003564__16 222 | 2009_003564__18 223 | 2009_003607__18 224 | 2009_003666__17 225 | 2009_003857__20 226 | 2009_003895__18 227 | 2009_003895__20 228 | 2009_003938__19 229 | 2009_004099__18 230 | 2009_004140__18 231 | 2009_004255__19 232 | 2009_004298__18 233 | 2009_004687__18 234 | 2009_004730__19 235 | 2009_004799__19 236 | 2009_004993__18 237 | 2009_004993__20 238 | 2009_005148__19 239 | 2009_005220__19 240 | 2010_000256__18 241 | 2010_000284__18 242 | 2010_000309__17 243 | 2010_000318__20 244 | 2010_000330__16 245 | 2010_000639__16 246 | 2010_000738__20 247 | 2010_000764__19 248 | 2010_001011__17 249 | 2010_001079__17 250 | 2010_001104__19 251 | 2010_001149__18 252 | 2010_001151__19 253 | 2010_001246__16 254 | 2010_001256__17 255 | 2010_001327__18 256 | 2010_001367__20 257 | 2010_001522__17 258 | 2010_001557__17 259 | 2010_001577__17 260 | 2010_001699__16 261 | 2010_001734__19 262 | 2010_001752__20 263 | 2010_001767__18 264 | 2010_001773__16 265 | 2010_001851__16 266 | 2010_001951__19 267 | 2010_001962__18 268 | 2010_002106__17 269 | 2010_002137__16 270 | 2010_002137__18 271 | 2010_002232__17 272 | 2010_002531__18 273 | 2010_002682__19 274 | 2010_002921__20 275 | 2010_003014__18 276 | 2010_003123__16 277 | 2010_003302__16 278 | 2010_003514__19 279 | 2010_003541__17 280 | 2010_003597__18 281 | 2010_003781__16 282 | 2010_003956__19 283 | 2010_004149__19 284 | 2010_004226__17 285 | 2010_004382__16 286 | 2010_004479__20 287 | 2010_004757__16 288 | 2010_004757__18 289 | 2010_004783__18 290 | 2010_004825__16 291 | 2010_004857__20 292 | 2010_004951__19 293 | 2010_004980__19 294 | 2010_005180__18 295 | 2010_005187__16 296 | 2010_005305__20 297 | 2010_005606__18 298 | 2010_005706__19 299 | 2010_005719__17 300 | 2010_005727__19 301 | 2010_005788__17 302 | 2010_005860__16 303 | 2010_005871__19 304 | 2010_005991__18 305 | 2010_006054__19 306 | 2011_000070__18 307 | 2011_000173__18 308 | 2011_000283__19 309 | 2011_000291__19 310 | 2011_000310__18 311 | 2011_000436__17 312 | 2011_000521__19 313 | 2011_000747__16 314 | 2011_001005__18 315 | 2011_001060__19 316 | 2011_001281__19 317 | 2011_001350__17 318 | 2011_001567__18 319 | 2011_001601__18 320 | 2011_001614__19 321 | 2011_001674__18 322 | 2011_001713__16 323 | 2011_001713__18 324 | 2011_001726__20 325 | 2011_001794__18 326 | 2011_001862__18 327 | 2011_001863__16 328 | 2011_001910__20 329 | 2011_002124__18 330 | 2011_002156__20 331 | 2011_002178__17 332 | 
2011_002247__19 333 | 2011_002379__19 334 | 2011_002391__18 335 | 2011_002532__20 336 | 2011_002535__19 337 | 2011_002644__18 338 | 2011_002644__20 339 | 2011_002879__18 340 | 2011_002879__20 341 | 2011_003103__16 342 | 2011_003103__18 343 | 2011_003146__19 344 | 2011_003182__18 345 | 2011_003197__19 346 | 2011_003256__18 347 | -------------------------------------------------------------------------------- /generate_cam_coco.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | from PIL import Image 4 | from pytorch_grad_cam import GradCAM 5 | import cv2 6 | import argparse 7 | from data.dataset import FSSDataset 8 | import pdb 9 | 10 | COCO_CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 11 | 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 12 | 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 13 | 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 14 | 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 15 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 16 | 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 17 | 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 18 | 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 19 | 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 20 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 21 | 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 22 | 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 23 | 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] 24 | 25 | 26 | def get_cam_from_alldata(clip_model, preprocess, split='train', d0=None, d1=None, d2=None, d3=None, 27 | datapath=None, campath=None): 28 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 29 | 30 | d0 = d0.dataset.img_metadata_classwise 31 | d1 = d1.dataset.img_metadata_classwise 32 | d2 = d2.dataset.img_metadata_classwise 33 | d3 = d3.dataset.img_metadata_classwise 34 | dd = [d0, d1, d2, d3] 35 | dataset_all = {} 36 | 37 | if split == 'train':  # each fold's train loader excludes its own val classes (class id % 4 == fold) 38 | for ii in range(80): 39 | index = ii % 4 + 1  # so read class ii from the next fold, whose train split still contains it 40 | if ii % 4 == 3: 41 | index = 0  # wrap around to fold 0 42 | dataset_all[ii] = dd[index][ii] 43 | else: 44 | for ii in range(80): 45 | index = ii % 4  # fold (ii % 4) holds class ii in its val split 46 | dataset_all[ii] = dd[index][ii] 47 | del d0, d1, d2, d3, dd 48 | 49 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in COCO_CLASSES]).to(device) 50 | for cls_id in range(80): 51 | L = len(dataset_all[cls_id]) 52 | for ll in range(L): 53 | img_path = datapath + dataset_all[cls_id][ll] 54 | img = Image.open(img_path) 55 | img_input = preprocess(img).unsqueeze(0).to(device) 56 | class_name_id = cls_id 57 | 58 | # CAM 59 | clip_model.get_text_features(text_inputs) 60 | target_layers = [clip_model.visual.layer4[-1]]  # last stage of CLIP RN50's visual encoder 61 | input_tensor = img_input 62 | cam = GradCAM(model=clip_model, target_layers=target_layers, use_cuda=True) 63 | target_category = class_name_id 64 | grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category) 65 | grayscale_cam = grayscale_cam[0, :] 66 | grayscale_cam = cv2.resize(grayscale_cam, (50, 50)) 67 | grayscale_cam = torch.from_numpy(grayscale_cam) 68 | save_path = campath + dataset_all[cls_id][ll] + '--' + str(class_name_id) + '.pt' 69 | torch.save(grayscale_cam, save_path) 70 | print('cam saved in ', save_path) 71 | 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser(description='IMR') 76 | 
parser.add_argument('--imgpath', type=str, default='../Datasets_HSN/COCO2014/') 77 | parser.add_argument('--campath', type=str, default='../Datasets_HSN/CAM_Val_COCO/') 78 | args = parser.parse_args() 79 | 80 | 81 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 82 | model_clip, preprocess = clip.load('RN50', device, jit=False) 83 | FSSDataset.initialize(img_size=400, datapath='../Datasets_HSN', use_original_imgsize=False) 84 | 85 | 86 | # COCO-meta-train 87 | # train 88 | dataloader_test0 = FSSDataset.build_dataloader('coco', 1, 0, 0, 'train', 1) 89 | dataloader_test1 = FSSDataset.build_dataloader('coco', 1, 0, 1, 'train', 1) 90 | dataloader_test2 = FSSDataset.build_dataloader('coco', 1, 0, 2, 'train', 1) 91 | dataloader_test3 = FSSDataset.build_dataloader('coco', 1, 0, 3, 'train', 1) 92 | get_cam_from_alldata(model_clip, preprocess, split='train', 93 | d0=dataloader_test0, d1=dataloader_test1, 94 | d2=dataloader_test2, d3=dataloader_test3, 95 | datapath=args.imgpath, campath=args.campath) 96 | 97 | # val 98 | dataloader_test0 = FSSDataset.build_dataloader('coco', 1, 0, 0, 'val', 1) 99 | dataloader_test1 = FSSDataset.build_dataloader('coco', 1, 0, 1, 'val', 1) 100 | dataloader_test2 = FSSDataset.build_dataloader('coco', 1, 0, 2, 'val', 1) 101 | dataloader_test3 = FSSDataset.build_dataloader('coco', 1, 0, 3, 'val', 1) 102 | get_cam_from_alldata(model_clip, preprocess, split='val', 103 | d0=dataloader_test0, d1=dataloader_test1, 104 | d2=dataloader_test2, d3=dataloader_test3, 105 | datapath=args.imgpath, campath=args.campath) 106 | -------------------------------------------------------------------------------- /generate_cam_voc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | from PIL import Image 4 | from pytorch_grad_cam import GradCAM 5 | import cv2 6 | import argparse 7 | from data.dataset import FSSDataset 8 | import pdb 9 | 10 | PASCAL_CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 11 | 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 12 | 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 13 | 'tvmonitor'] 14 | 15 | 16 | def get_cam_from_alldata(clip_model, preprocess, d=None, datapath=None, campath=None): 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | dataset_all = d.dataset.img_metadata  # list of (image name, class id) pairs 19 | L = len(dataset_all) 20 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in PASCAL_CLASSES]).to(device) 21 | for ll in range(L): 22 | img_path = datapath + dataset_all[ll][0] + '.jpg' 23 | img = Image.open(img_path) 24 | img_input = preprocess(img).unsqueeze(0).to(device) 25 | class_name_id = dataset_all[ll][1] 26 | clip_model.get_text_features(text_inputs) 27 | target_layers = [clip_model.visual.layer4[-1]]  # last stage of CLIP RN50's visual encoder 28 | input_tensor = img_input 29 | cam = GradCAM(model=clip_model, target_layers=target_layers, use_cuda=True) 30 | target_category = class_name_id 31 | grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category) 32 | grayscale_cam = grayscale_cam[0, :] 33 | grayscale_cam = cv2.resize(grayscale_cam, (50, 50)) 34 | grayscale_cam = torch.from_numpy(grayscale_cam) 35 | save_path = campath + dataset_all[ll][0] + '--' + str(class_name_id) + '.pt' 36 | torch.save(grayscale_cam, save_path) 37 | print('cam saved in ', save_path) 38 | 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser(description='IMR') 43 | parser.add_argument('--imgpath', type=str,
default='../Datasets_HSN/VOC2012/JPEGImages/') 44 | parser.add_argument('--traincampath', type=str, default='../Datasets_HSN/CAM_Train/') 45 | parser.add_argument('--valcampath', type=str, default='../Datasets_HSN/CAM_Val/') 46 | args = parser.parse_args() 47 | 48 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 49 | model_clip, preprocess = clip.load('RN50', device, jit=False) 50 | FSSDataset.initialize(img_size=400, datapath='../Datasets_HSN', use_original_imgsize=False) 51 | 52 | # VOC 53 | # train 54 | dataloader_test0 = FSSDataset.build_dataloader('pascal', 1, 0, 0, 'train', 1) 55 | dataloader_test1 = FSSDataset.build_dataloader('pascal', 1, 0, 1, 'train', 1) 56 | dataloader_test2 = FSSDataset.build_dataloader('pascal', 1, 0, 2, 'train', 1) 57 | dataloader_test3 = FSSDataset.build_dataloader('pascal', 1, 0, 3, 'train', 1) 58 | 59 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test0, datapath=args.imgpath, campath=args.traincampath) 60 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test1, datapath=args.imgpath, campath=args.traincampath) 61 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test2, datapath=args.imgpath, campath=args.traincampath) 62 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test3, datapath=args.imgpath, campath=args.traincampath) 63 | 64 | # val 65 | dataloader_test0 = FSSDataset.build_dataloader('pascal', 1, 0, 0, 'val', 1) 66 | dataloader_test1 = FSSDataset.build_dataloader('pascal', 1, 0, 1, 'val', 1) 67 | dataloader_test2 = FSSDataset.build_dataloader('pascal', 1, 0, 2, 'val', 1) 68 | dataloader_test3 = FSSDataset.build_dataloader('pascal', 1, 0, 3, 'val', 1) 69 | 70 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test0, datapath=args.imgpath, campath=args.valcampath) 71 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test1, datapath=args.imgpath, campath=args.valcampath) 72 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test2, datapath=args.imgpath, campath=args.valcampath) 73 | get_cam_from_alldata(model_clip, preprocess, d=dataloader_test3, datapath=args.imgpath, campath=args.valcampath) 74 | 75 | -------------------------------------------------------------------------------- /model/__pycache__/hsnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/__pycache__/hsnet.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/hsnet_imr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/__pycache__/hsnet_imr.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/hsnet_raft_res_multi_group_xiaorong_grouponly.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/__pycache__/hsnet_raft_res_multi_group_xiaorong_grouponly.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/learner.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/__pycache__/learner.cpython-38.pyc -------------------------------------------------------------------------------- /model/base/__pycache__/conv4d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/base/__pycache__/conv4d.cpython-38.pyc -------------------------------------------------------------------------------- /model/base/__pycache__/correlation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/base/__pycache__/correlation.cpython-38.pyc -------------------------------------------------------------------------------- /model/base/__pycache__/feature.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/model/base/__pycache__/feature.cpython-38.pyc -------------------------------------------------------------------------------- /model/base/conv4d.py: -------------------------------------------------------------------------------- 1 | r""" Implementation of center-pivot 4D convolution """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class CenterPivotConv4d(nn.Module): 8 | r""" CenterPivot 4D conv""" 9 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True): 10 | super(CenterPivotConv4d, self).__init__() 11 | 12 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size[:2], stride=stride[:2], 13 | bias=bias, padding=padding[:2]) 14 | self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size[2:], stride=stride[2:], 15 | bias=bias, padding=padding[2:]) 16 | 17 | self.stride34 = stride[2:] 18 | self.kernel_size = kernel_size 19 | self.stride = stride 20 | self.padding = padding 21 | self.idx_initialized = False 22 | 23 | def prune(self, ct): 24 | bsz, ch, ha, wa, hb, wb = ct.size() 25 | if not self.idx_initialized: 26 | idxh = torch.arange(start=0, end=hb, step=self.stride[2:][0], device=ct.device) 27 | idxw = torch.arange(start=0, end=wb, step=self.stride[2:][1], device=ct.device) 28 | self.len_h = len(idxh) 29 | self.len_w = len(idxw) 30 | self.idx = (idxw.repeat(self.len_h, 1) + idxh.repeat(self.len_w, 1).t() * wb).view(-1) 31 | self.idx_initialized = True 32 | ct_pruned = ct.view(bsz, ch, ha, wa, -1).index_select(4, self.idx).view(bsz, ch, ha, wa, self.len_h, self.len_w) 33 | 34 | return ct_pruned 35 | 36 | def forward(self, x): 37 | if self.stride[2:][-1] > 1: 38 | out1 = self.prune(x) 39 | else: 40 | out1 = x 41 | bsz, inch, ha, wa, hb, wb = out1.size() 42 | out1 = out1.permute(0, 4, 5, 1, 2, 3).contiguous().view(-1, inch, ha, wa) 43 | out1 = self.conv1(out1) 44 | outch, o_ha, o_wa = out1.size(-3), out1.size(-2), out1.size(-1) 45 | out1 = out1.view(bsz, hb, wb, outch, o_ha, o_wa).permute(0, 3, 4, 5, 1, 2).contiguous() 46 | 47 | bsz, inch, ha, wa, hb, wb = x.size() 48 | out2 = x.permute(0, 2, 3, 1, 4, 5).contiguous().view(-1, inch, hb, wb) 49 | out2 = self.conv2(out2) 50 | outch, o_hb, o_wb = out2.size(-3), out2.size(-2), out2.size(-1) 51 | out2 = out2.view(bsz, ha, wa, outch, o_hb, o_wb).permute(0, 3, 1, 2, 4, 5).contiguous() 52 | 53 | if out1.size()[-2:] != out2.size()[-2:] and self.padding[-2:] 
== (0, 0): 54 | out1 = out1.view(bsz, outch, o_ha, o_wa, -1).sum(dim=-1) 55 | out2 = out2.squeeze() 56 | 57 | y = out1 + out2 58 | return y 59 | -------------------------------------------------------------------------------- /model/base/correlation.py: -------------------------------------------------------------------------------- 1 | r""" Provides functions that builds/manipulates correlation tensors """ 2 | import torch 3 | 4 | 5 | class Correlation: 6 | 7 | @classmethod 8 | def multilayer_correlation(cls, query_feats, support_feats, stack_ids): 9 | eps = 1e-5 10 | 11 | corrs = [] 12 | for idx, (query_feat, support_feat) in enumerate(zip(query_feats, support_feats)): 13 | bsz, ch, hb, wb = support_feat.size() 14 | support_feat = support_feat.view(bsz, ch, -1) 15 | support_feat = support_feat / (support_feat.norm(dim=1, p=2, keepdim=True) + eps) 16 | 17 | bsz, ch, ha, wa = query_feat.size() 18 | query_feat = query_feat.view(bsz, ch, -1) 19 | query_feat = query_feat / (query_feat.norm(dim=1, p=2, keepdim=True) + eps) 20 | 21 | corr = torch.bmm(query_feat.transpose(1, 2), support_feat).view(bsz, ha, wa, hb, wb) 22 | corr = corr.clamp(min=0) 23 | corrs.append(corr) 24 | 25 | corr_l4 = torch.stack(corrs[-stack_ids[0]:]).transpose(0, 1).contiguous() 26 | corr_l3 = torch.stack(corrs[-stack_ids[1]:-stack_ids[0]]).transpose(0, 1).contiguous() 27 | corr_l2 = torch.stack(corrs[-stack_ids[2]:-stack_ids[1]]).transpose(0, 1).contiguous() 28 | 29 | return [corr_l4, corr_l3, corr_l2] 30 | -------------------------------------------------------------------------------- /model/base/feature.py: -------------------------------------------------------------------------------- 1 | r""" Extracts intermediate features from given backbone network & layer ids """ 2 | 3 | 4 | def extract_feat_vgg(img, backbone, feat_ids, bottleneck_ids=None, lids=None): 5 | r""" Extract intermediate features from VGG """ 6 | feats = [] 7 | feat = img 8 | for lid, module in enumerate(backbone.features): 9 | feat = module(feat) 10 | if lid in feat_ids: 11 | feats.append(feat.clone()) 12 | return feats 13 | 14 | 15 | def extract_feat_res(img, backbone, feat_ids, bottleneck_ids, lids): 16 | r""" Extract intermediate features from ResNet""" 17 | feats = [] 18 | 19 | # Layer 0 20 | feat = backbone.conv1.forward(img) 21 | feat = backbone.bn1.forward(feat) 22 | feat = backbone.relu.forward(feat) 23 | feat = backbone.maxpool.forward(feat) 24 | 25 | # Layer 1-4 26 | for hid, (bid, lid) in enumerate(zip(bottleneck_ids, lids)): 27 | res = feat 28 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv1.forward(feat) 29 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn1.forward(feat) 30 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 31 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv2.forward(feat) 32 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn2.forward(feat) 33 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 34 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv3.forward(feat) 35 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn3.forward(feat) 36 | 37 | if bid == 0: 38 | res = backbone.__getattr__('layer%d' % lid)[bid].downsample.forward(res) 39 | 40 | feat += res 41 | 42 | if hid + 1 in feat_ids: 43 | feats.append(feat.clone()) 44 | 45 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 46 | 47 | return feats 48 | -------------------------------------------------------------------------------- /model/hsnet_imr.py: 
-------------------------------------------------------------------------------- 1 | r""" Hypercorrelation Squeeze Network """ 2 | import pdb 3 | from functools import reduce 4 | from operator import add 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchvision.models import resnet 10 | from torchvision.models import vgg 11 | 12 | from .base.feature import extract_feat_vgg, extract_feat_res 13 | from .base.correlation import Correlation 14 | from .learner import HPNLearner 15 | 16 | 17 | class HypercorrSqueezeNetwork_imr(nn.Module): 18 | # Compared with the non-residual variant, this model concatenates the logit_mask with the CAM and passes the result through a convolution, so the CAM is explicitly reused. 19 | def __init__(self, backbone, use_original_imgsize): 20 | super(HypercorrSqueezeNetwork_imr, self).__init__() 21 | 22 | # 1. Backbone network initialization 23 | self.backbone_type = backbone 24 | self.use_original_imgsize = use_original_imgsize 25 | if backbone == 'vgg16': 26 | self.backbone = vgg.vgg16(pretrained=False) 27 | ckpt = torch.load('../Datasets_HSN/Pretrain/vgg16-397923af.pth') 28 | self.backbone.load_state_dict(ckpt) 29 | self.feat_ids = [17, 19, 21, 24, 26, 28, 30] 30 | self.extract_feats = extract_feat_vgg 31 | nbottlenecks = [2, 2, 3, 3, 3, 1] 32 | elif backbone == 'resnet50': 33 | self.backbone = resnet.resnet50(pretrained=False) 34 | ckpt = torch.load('../Datasets_HSN/Pretrain/resnet50-19c8e357.pth') 35 | self.backbone.load_state_dict(ckpt) 36 | self.feat_ids = list(range(4, 17)) 37 | self.extract_feats = extract_feat_res 38 | nbottlenecks = [3, 4, 6, 3] 39 | self.conv1024_512 = nn.Conv2d(1024, 512, kernel_size=1) 40 | 41 | elif backbone == 'resnet101': 42 | self.backbone = resnet.resnet101(pretrained=True) 43 | self.feat_ids = list(range(4, 34)) 44 | self.extract_feats = extract_feat_res 45 | nbottlenecks = [3, 4, 23, 3] 46 | else: 47 | raise Exception('Unavailable backbone: %s' % backbone) 48 | 49 | self.bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), nbottlenecks))) 50 | self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(nbottlenecks)]) 51 | self.stack_ids = torch.tensor(self.lids).bincount().__reversed__().cumsum(dim=0)[:3] 52 | self.backbone.eval() 53 | self.hpn_learner = HPNLearner(list(reversed(nbottlenecks[-3:]))) 54 | self.cross_entropy_loss = nn.CrossEntropyLoss() 55 | 56 | # IMR 57 | self.state = nn.Parameter(torch.zeros([1, 128, 50, 50])) 58 | self.convz0 = nn.Conv2d(769, 512, kernel_size=1, padding=0) 59 | self.convz1 = nn.Conv2d(256, 64, kernel_size=3, padding=1, groups=8, dilation=1) 60 | self.convz2 = nn.Conv2d(256, 64, kernel_size=3, padding=1, groups=8, dilation=1) 61 | 62 | self.convr0 = nn.Conv2d(769, 512, kernel_size=1, padding=0) 63 | self.convr1 = nn.Conv2d(256, 64, kernel_size=3, padding=1, groups=8, dilation=1) 64 | self.convr2 = nn.Conv2d(256, 64, kernel_size=3, padding=1, groups=8, dilation=1) 65 | 66 | self.convh0 = nn.Conv2d(769, 512, kernel_size=1, padding=0) 67 | self.convh1 = nn.Conv2d(256, 64, kernel_size=3, padding=1, groups=8, dilation=1) 68 | self.convh2 = nn.Conv2d(256, 64, kernel_size=3, padding=1, groups=8, dilation=1) 69 | 70 | # copied from hsnet-learner 71 | outch1, outch2, outch3 = 16, 64, 128 72 | self.decoder1 = nn.Sequential(nn.Conv2d(outch3, outch3, (3, 3), padding=(1, 1), bias=True), 73 | nn.ReLU(), 74 | nn.Conv2d(outch3, outch2, (3, 3), padding=(1, 1), bias=True), 75 | nn.ReLU()) 76 | 77 | self.decoder2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True), 78 | nn.ReLU(), 79 | nn.Conv2d(outch2, 2, (3, 3), padding=(1, 1),
bias=True)) 80 | 81 | self.res = nn.Sequential(nn.Conv2d(3, 10, kernel_size=1), 82 | nn.GELU(), 83 | nn.Conv2d(10, 2, kernel_size=1)) 84 | 85 | def forward(self, query_img, support_img, support_cam, query_cam, 86 | query_mask=None, support_mask=None, stage=2, w='same'): 87 | with torch.no_grad(): 88 | query_feats = self.extract_feats(query_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids) 89 | support_feats = self.extract_feats( 90 | support_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids) 91 | 92 | # extracting feature 93 | if len(query_feats) == 7: 94 | isvgg = True # VGG 95 | q_mid_feat = F.interpolate(query_feats[3] + query_feats[4] + query_feats[5], 96 | (50, 50), mode='bilinear', align_corners=True) 97 | s_mid_feat = F.interpolate(support_feats[3] + support_feats[4] + support_feats[5], 98 | (50, 50), mode='bilinear', align_corners=True) 99 | else: 100 | isvgg = False # R50 101 | q_mid_feat = F.interpolate( 102 | query_feats[4] + query_feats[5] + query_feats[6] + query_feats[7] + query_feats[8] + query_feats[9], 103 | (50, 50), mode='bilinear', align_corners=True) 104 | 105 | s_mid_feat = F.interpolate( 106 | support_feats[4] + support_feats[5] + support_feats[6] + support_feats[7] + support_feats[8] + 107 | support_feats[9], 108 | (50, 50), mode='bilinear', align_corners=True) 109 | 110 | query_feats_masked = self.mask_feature(query_feats, support_cam.clone()) 111 | support_feats_masked = self.mask_feature(support_feats, query_cam.clone()) 112 | 113 | corr_query = Correlation.multilayer_correlation(query_feats, support_feats_masked, self.stack_ids) 114 | corr_support = Correlation.multilayer_correlation(support_feats, query_feats_masked, self.stack_ids) 115 | 116 | query_cam = query_cam.unsqueeze(1) 117 | support_cam = support_cam.unsqueeze(1) 118 | 119 | if not isvgg: 120 | # make feat dim in R50 same as VGG 121 | q_mid_feat = self.conv1024_512(q_mid_feat) 122 | s_mid_feat = self.conv1024_512(s_mid_feat) 123 | 124 | bsz = query_img.shape[0] 125 | state_query = self.state.expand(bsz, -1, -1, -1) 126 | state_support = self.state.expand(bsz, -1, -1, -1) 127 | 128 | losses = 0 129 | for ss in range(stage): 130 | # query 131 | after4d_query = self.hpn_learner.forward_conv4d(corr_query) 132 | imr_x_query = torch.cat([query_cam, after4d_query, q_mid_feat, state_query], dim=1) 133 | 134 | imr_x_query_z = self.convz0(imr_x_query) 135 | imr_z_query1 = self.convz1(imr_x_query_z[:, :256]) 136 | imr_z_query2 = self.convz2(imr_x_query_z[:, 256:]) 137 | imr_z_query = torch.sigmoid(torch.cat([imr_z_query1, imr_z_query2], dim=1)) 138 | 139 | imr_x_query_r = self.convr0(imr_x_query) 140 | imr_r_query1 = self.convr1(imr_x_query_r[:, :256]) 141 | imr_r_query2 = self.convr2(imr_x_query_r[:, 256:]) 142 | imr_r_query = torch.sigmoid(torch.cat([imr_r_query1, imr_r_query2], dim=1)) 143 | 144 | imr_x_query_h = self.convh0( 145 | torch.cat([query_cam, after4d_query, q_mid_feat, imr_r_query * state_query], dim=1)) 146 | imr_h_query1 = self.convh1(imr_x_query_h[:, :256]) 147 | imr_h_query2 = self.convh2(imr_x_query_h[:, 256:]) 148 | imr_h_query = torch.cat([imr_h_query1, imr_h_query2], dim=1) 149 | 150 | state_new_query = torch.tanh(imr_h_query) 151 | state_query = (1 - imr_z_query) * state_query + imr_z_query * state_new_query 152 | 153 | # support 154 | after4d_support = self.hpn_learner.forward_conv4d(corr_support) 155 | imr_x_support = torch.cat([support_cam, after4d_support, s_mid_feat, state_support], dim=1) 156 | 157 | imr_x_support_z = self.convz0(imr_x_support) 158 | 
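# Note: the convz*/convr*/convh* stacks below follow a ConvGRU-style gated update, mirroring the
# query branch above: z (sigmoid) acts as an update gate, r (sigmoid) as a reset gate, h (tanh) as
# the candidate state, and the hidden state is blended as (1 - z) * state + z * h.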
imr_z_support1 = self.convz1(imr_x_support_z[:, :256]) 159 | imr_z_support2 = self.convz2(imr_x_support_z[:, 256:]) 160 | imr_z_support = torch.sigmoid(torch.cat([imr_z_support1, imr_z_support2], dim=1)) 161 | 162 | imr_x_support_r = self.convr0(imr_x_support) 163 | imr_r_support1 = self.convr1(imr_x_support_r[:, :256]) 164 | imr_r_support2 = self.convr2(imr_x_support_r[:, 256:]) 165 | imr_r_support = torch.sigmoid(torch.cat([imr_r_support1, imr_r_support2], dim=1)) 166 | 167 | imr_x_support_h = self.convh0( 168 | torch.cat([support_cam, after4d_support, s_mid_feat, imr_r_support * state_support], dim=1)) 169 | imr_h_support1 = self.convh1(imr_x_support_h[:, :256]) 170 | imr_h_support2 = self.convh2(imr_x_support_h[:, 256:]) 171 | imr_h_support = torch.cat([imr_h_support1, imr_h_support2], dim=1) 172 | 173 | state_new_support = torch.tanh(imr_h_support) 174 | state_support = (1 - imr_z_support) * state_support + imr_z_support * state_new_support 175 | 176 | # decoder 177 | hypercorr_decoded_s = self.decoder1(state_support + after4d_support) 178 | upsample_size = (hypercorr_decoded_s.size(-1) * 2,) * 2 179 | hypercorr_decoded_s = F.interpolate(hypercorr_decoded_s, upsample_size, mode='bilinear', align_corners=True) 180 | logit_mask_support = self.decoder2(hypercorr_decoded_s) 181 | 182 | hypercorr_decoded_q = self.decoder1(state_query + after4d_query) 183 | upsample_size = (hypercorr_decoded_q.size(-1) * 2,) * 2 184 | hypercorr_decoded_q = F.interpolate(hypercorr_decoded_q, upsample_size, mode='bilinear', align_corners=True) 185 | logit_mask_query = self.decoder2(hypercorr_decoded_q) 186 | 187 | logit_mask_support = self.res( 188 | torch.cat( 189 | [logit_mask_support, F.interpolate(support_cam, (100, 100), mode='bilinear', align_corners=True)], 190 | dim=1)) 191 | logit_mask_query = self.res( 192 | torch.cat([logit_mask_query, F.interpolate(query_cam, (100, 100), mode='bilinear', align_corners=True)], 193 | dim=1)) 194 | 195 | # loss 196 | if query_mask is not None: # for training 197 | if not self.use_original_imgsize: 198 | logit_mask_query_temp = F.interpolate(logit_mask_query, support_img.size()[2:], mode='bilinear', 199 | align_corners=True) 200 | logit_mask_support_temp = F.interpolate(logit_mask_support, support_img.size()[2:], mode='bilinear', 201 | align_corners=True) 202 | loss_q_stage = self.compute_objective(logit_mask_query_temp, query_mask) 203 | loss_s_stage = self.compute_objective(logit_mask_support_temp, support_mask) 204 | losses = losses + loss_q_stage + loss_s_stage 205 | 206 | if ss != stage - 1: 207 | support_cam = logit_mask_support.softmax(dim=1)[:, 1] 208 | query_cam = logit_mask_query.softmax(dim=1)[:, 1] 209 | query_feats_masked = self.mask_feature(query_feats, query_cam) 210 | support_feats_masked = self.mask_feature(support_feats, support_cam) 211 | corr_query = Correlation.multilayer_correlation(query_feats, support_feats_masked, self.stack_ids) 212 | corr_support = Correlation.multilayer_correlation(support_feats, query_feats_masked, self.stack_ids) 213 | 214 | query_cam = F.interpolate(query_cam.unsqueeze(1), (50, 50), mode='bilinear', align_corners=True) 215 | support_cam = F.interpolate(support_cam.unsqueeze(1), (50, 50), mode='bilinear', align_corners=True) 216 | 217 | if query_mask is not None: 218 | return logit_mask_query_temp, logit_mask_support_temp, losses 219 | else: 220 | # test 221 | if not self.use_original_imgsize: 222 | logit_mask_query = F.interpolate( 223 | logit_mask_query, support_img.size()[2:], mode='bilinear', align_corners=True) 224 | 
logit_mask_support = F.interpolate( 225 | logit_mask_support, support_img.size()[2:], mode='bilinear', align_corners=True) 226 | return logit_mask_query, logit_mask_support 227 | 228 | def mask_feature(self, features, support_mask): 229 | for idx, feature in enumerate(features): 230 | mask = F.interpolate( 231 | support_mask.unsqueeze(1).float(), feature.size()[2:], mode='bilinear', align_corners=True) 232 | features[idx] = features[idx] * mask 233 | return features 234 | 235 | def predict_mask_nshot(self, batch, nshot, stage): 236 | # Perform multiple prediction given (nshot) number of different support sets 237 | logit_mask_agg = 0 238 | for s_idx in range(nshot): 239 | logit_mask, logit_mask_s = self(query_img=batch['query_img'], 240 | support_img=batch['support_imgs'][:, s_idx], 241 | support_cam=batch['support_cams'][:, s_idx], 242 | query_cam=batch['query_cam'], stage=stage) 243 | if self.use_original_imgsize: 244 | org_qry_imsize = tuple([batch['org_query_imsize'][1].item(), batch['org_query_imsize'][0].item()]) 245 | logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True) 246 | 247 | logit_mask_agg += logit_mask.argmax(dim=1).clone() 248 | if nshot == 1: 249 | return logit_mask_agg 250 | 251 | # Average & quantize predictions given threshold (=0.5) 252 | bsz = logit_mask_agg.size(0) 253 | max_vote = logit_mask_agg.view(bsz, -1).max(dim=1)[0] 254 | max_vote = torch.stack([max_vote, torch.ones_like(max_vote).long()]) 255 | max_vote = max_vote.max(dim=0)[0].view(bsz, 1, 1) 256 | pred_mask = logit_mask_agg.float() / max_vote 257 | pred_mask[pred_mask < 0.5] = 0 258 | pred_mask[pred_mask >= 0.5] = 1 259 | 260 | return pred_mask 261 | 262 | def compute_objective(self, logit_mask, gt_mask): 263 | bsz = logit_mask.size(0) 264 | logit_mask = logit_mask.view(bsz, 2, -1) 265 | gt_mask = gt_mask.view(bsz, -1).long() 266 | return self.cross_entropy_loss(logit_mask, gt_mask) 267 | 268 | def train_mode(self): 269 | self.train() 270 | self.backbone.eval() # to prevent BN from learning data statistics with exponential averaging 271 | -------------------------------------------------------------------------------- /model/learner.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .base.conv4d import CenterPivotConv4d as Conv4d 7 | 8 | 9 | class HPNLearner(nn.Module): 10 | def __init__(self, inch): 11 | super(HPNLearner, self).__init__() 12 | 13 | def make_building_block(in_channel, out_channels, kernel_sizes, spt_strides, group=4): 14 | assert len(out_channels) == len(kernel_sizes) == len(spt_strides) 15 | 16 | building_block_layers = [] 17 | for idx, (outch, ksz, stride) in enumerate(zip(out_channels, kernel_sizes, spt_strides)): 18 | inch = in_channel if idx == 0 else out_channels[idx - 1] 19 | ksz4d = (ksz,) * 4 20 | str4d = (1, 1) + (stride,) * 2 21 | pad4d = (ksz // 2,) * 4 22 | 23 | building_block_layers.append(Conv4d(inch, outch, ksz4d, str4d, pad4d)) 24 | building_block_layers.append(nn.GroupNorm(group, outch)) 25 | building_block_layers.append(nn.ReLU(inplace=True)) 26 | 27 | return nn.Sequential(*building_block_layers) 28 | 29 | outch1, outch2, outch3 = 16, 64, 128 30 | 31 | # Squeezing building blocks 32 | self.encoder_layer4 = make_building_block(inch[0], [outch1, outch2, outch3], [3, 3, 3], [2, 2, 2]) 33 | self.encoder_layer3 = make_building_block(inch[1], [outch1, outch2, outch3], [5, 3, 3], [4, 2, 2]) 34 | 
self.encoder_layer2 = make_building_block(inch[2], [outch1, outch2, outch3], [5, 5, 3], [4, 4, 2]) 35 | 36 | # Mixing building blocks 37 | self.encoder_layer4to3 = make_building_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1]) 38 | self.encoder_layer3to2 = make_building_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1]) 39 | 40 | # Decoder layers 41 | # self.decoder1 = nn.Sequential(nn.Conv2d(outch3, outch3, (3, 3), padding=(1, 1), bias=True), 42 | # nn.ReLU(), 43 | # nn.Conv2d(outch3, outch2, (3, 3), padding=(1, 1), bias=True), 44 | # nn.ReLU()) 45 | # 46 | # self.decoder2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True), 47 | # nn.ReLU(), 48 | # nn.Conv2d(outch2, 2, (3, 3), padding=(1, 1), bias=True)) 49 | self.decoder1 = None 50 | self.decoder2 = None 51 | # Note: in the original HSNet code the learner is supposed to define the decoder here, but our code defines the decoder in hsnet_imr, so the decoder here 52 | # is actually unused. The open-source release had to pass a company review that kept failing with "duplication rate too high" unless this block was commented out, so it is left commented out here. 53 | 54 | 55 | 56 | 57 | 58 | def interpolate_support_dims(self, hypercorr, spatial_size=None): 59 | bsz, ch, ha, wa, hb, wb = hypercorr.size() 60 | hypercorr = hypercorr.permute(0, 4, 5, 1, 2, 3).contiguous().view(bsz * hb * wb, ch, ha, wa) 61 | hypercorr = F.interpolate(hypercorr, spatial_size, mode='bilinear', align_corners=True) 62 | o_hb, o_wb = spatial_size 63 | hypercorr = hypercorr.view(bsz, hb, wb, ch, o_hb, o_wb).permute(0, 3, 4, 5, 1, 2).contiguous() 64 | return hypercorr 65 | 66 | def forward_conv4d(self, hypercorr_pyramid): 67 | # Encode hypercorrelations from each layer (Squeezing building blocks) 68 | hypercorr_sqz4 = self.encoder_layer4(hypercorr_pyramid[0]) 69 | hypercorr_sqz3 = self.encoder_layer3(hypercorr_pyramid[1]) 70 | hypercorr_sqz2 = self.encoder_layer2(hypercorr_pyramid[2]) 71 | 72 | # Propagate encoded 4D-tensor (Mixing building blocks) 73 | hypercorr_sqz4 = self.interpolate_support_dims(hypercorr_sqz4, hypercorr_sqz3.size()[-4:-2]) 74 | hypercorr_mix43 = hypercorr_sqz4 + hypercorr_sqz3 75 | hypercorr_mix43 = self.encoder_layer4to3(hypercorr_mix43) 76 | 77 | hypercorr_mix43 = self.interpolate_support_dims(hypercorr_mix43, hypercorr_sqz2.size()[-4:-2]) 78 | hypercorr_mix432 = hypercorr_mix43 + hypercorr_sqz2 79 | hypercorr_mix432 = self.encoder_layer3to2(hypercorr_mix432) 80 | 81 | bsz, ch, ha, wa, hb, wb = hypercorr_mix432.size() 82 | hypercorr_encoded = hypercorr_mix432.view(bsz, ch, ha, wa, -1).mean(dim=-1) # torch.Size([1, 128, 50, 50]) 83 | return hypercorr_encoded 84 | 85 | def forward_decode(self, hypercorr_encoded): 86 | hypercorr_decoded = self.decoder1(hypercorr_encoded) 87 | upsample_size = (hypercorr_decoded.size(-1) * 2,) * 2 88 | hypercorr_decoded = F.interpolate(hypercorr_decoded, upsample_size, mode='bilinear', align_corners=True) 89 | logit_mask = self.decoder2(hypercorr_decoded) 90 | return logit_mask 91 | 92 | def forward(self, hypercorr_pyramid): 93 | hypercorr_encoded = self.forward_conv4d(hypercorr_pyramid) 94 | # vgg16: torch.Size([1, 128, 50, 50]) 95 | # r50; torch.Size([1, 128, 50, 50]) 96 | # pdb.set_trace() 97 | logit_mask = self.forward_decode(hypercorr_encoded) 98 | return logit_mask 99 | 100 | # def forward(self, hypercorr_pyramid): 101 | # 102 | # # Encode hypercorrelations from each layer (Squeezing building blocks) 103 | # hypercorr_sqz4 = self.encoder_layer4(hypercorr_pyramid[0]) 104 | # hypercorr_sqz3 = self.encoder_layer3(hypercorr_pyramid[1]) 105 | # hypercorr_sqz2 = self.encoder_layer2(hypercorr_pyramid[2]) 106 | # pdb.set_trace() 107 | # 108 | # # Propagate encoded 4D-tensor (Mixing
building blocks) 109 | # hypercorr_sqz4 = self.interpolate_support_dims(hypercorr_sqz4, hypercorr_sqz3.size()[-4:-2]) 110 | # hypercorr_mix43 = hypercorr_sqz4 + hypercorr_sqz3 111 | # hypercorr_mix43 = self.encoder_layer4to3(hypercorr_mix43) 112 | # 113 | # hypercorr_mix43 = self.interpolate_support_dims(hypercorr_mix43, hypercorr_sqz2.size()[-4:-2]) 114 | # hypercorr_mix432 = hypercorr_mix43 + hypercorr_sqz2 115 | # hypercorr_mix432 = self.encoder_layer3to2(hypercorr_mix432) 116 | # 117 | # 118 | # bsz, ch, ha, wa, hb, wb = hypercorr_mix432.size() 119 | # hypercorr_encoded = hypercorr_mix432.view(bsz, ch, ha, wa, -1).mean(dim=-1) #torch.Size([1, 128, 50, 50]) 120 | # 121 | # # Decode the encoded 4D-tensor 122 | # hypercorr_decoded = self.decoder1(hypercorr_encoded) 123 | # upsample_size = (hypercorr_decoded.size(-1) * 2,) * 2 124 | # hypercorr_decoded = F.interpolate(hypercorr_decoded, upsample_size, mode='bilinear', align_corners=True) 125 | # logit_mask = self.decoder2(hypercorr_decoded) 126 | # 127 | # return logit_mask 128 | -------------------------------------------------------------------------------- /pytorch_grad_cam/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.grad_cam import GradCAM 2 | from pytorch_grad_cam.ablation_cam import AblationCAM 3 | from pytorch_grad_cam.xgrad_cam import XGradCAM 4 | from pytorch_grad_cam.grad_cam_plusplus import GradCAMPlusPlus 5 | from pytorch_grad_cam.score_cam import ScoreCAM 6 | from pytorch_grad_cam.layer_cam import LayerCAM 7 | from pytorch_grad_cam.eigen_cam import EigenCAM 8 | from pytorch_grad_cam.eigen_grad_cam import EigenGradCAM 9 | from pytorch_grad_cam.fullgrad_cam import FullGrad 10 | from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel 11 | -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/ablation_cam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/ablation_cam.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/base_cam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/base_cam.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/eigen_cam.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/eigen_cam.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/grad_cam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/grad_cam.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/guided_backprop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/guided_backprop.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/layer_cam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/layer_cam.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/score_cam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/score_cam.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/__pycache__/xgrad_cam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/ablation_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import tqdm 4 | from pytorch_grad_cam.base_cam import BaseCAM 5 | from pytorch_grad_cam.utils.find_layers import replace_layer_recursive 6 | 7 | 8 | class 
AblationLayer(torch.nn.Module): 9 | def __init__(self, layer, reshape_transform, indices): 10 | super(AblationLayer, self).__init__() 11 | 12 | self.layer = layer 13 | self.reshape_transform = reshape_transform 14 | # The channels to zero out: 15 | self.indices = indices 16 | 17 | def forward(self, x): 18 | return self.__call__(x) 19 | 20 | def __call__(self, x): 21 | output = self.layer(x) 22 | 23 | # Hack to work with ViT, 24 | # Since the activation channels are last and not first like in CNNs 25 | # Probably should remove it? 26 | if self.reshape_transform is not None: 27 | output = output.transpose(1, 2) 28 | 29 | for i in range(output.size(0)): 30 | 31 | # Commonly the minimum activation will be 0, 32 | # And then it makes sense to zero it out. 33 | # However depending on the architecture, 34 | # If the values can be negative, we use very negative values 35 | # to perform the ablation, deviating from the paper. 36 | if torch.min(output) == 0: 37 | output[i, self.indices[i], :] = 0 38 | else: 39 | ABLATION_VALUE = 1e5 40 | output[i, self.indices[i], :] = torch.min( 41 | output) - ABLATION_VALUE 42 | 43 | if self.reshape_transform is not None: 44 | output = output.transpose(2, 1) 45 | 46 | return output 47 | 48 | 49 | class AblationCAM(BaseCAM): 50 | def __init__(self, 51 | model, 52 | target_layers, 53 | use_cuda=False, 54 | reshape_transform=None): 55 | super(AblationCAM, self).__init__(model, target_layers, use_cuda, 56 | reshape_transform) 57 | 58 | def get_cam_weights(self, 59 | input_tensor, 60 | target_layer, 61 | target_category, 62 | activations, 63 | grads): 64 | with torch.no_grad(): 65 | outputs = self.model(input_tensor).cpu().numpy() 66 | original_scores = [] 67 | for i in range(input_tensor.size(0)): 68 | original_scores.append(outputs[i, target_category[i]]) 69 | original_scores = np.float32(original_scores) 70 | 71 | ablation_layer = AblationLayer(target_layer, 72 | self.reshape_transform, 73 | indices=[]) 74 | replace_layer_recursive(self.model, target_layer, ablation_layer) 75 | 76 | if hasattr(self, "batch_size"): 77 | BATCH_SIZE = self.batch_size 78 | else: 79 | BATCH_SIZE = 32 80 | 81 | number_of_channels = activations.shape[1] 82 | weights = [] 83 | 84 | with torch.no_grad(): 85 | # Iterate over the input batch 86 | for tensor, category in zip(input_tensor, target_category): 87 | batch_tensor = tensor.repeat(BATCH_SIZE, 1, 1, 1) 88 | for i in tqdm.tqdm(range(0, number_of_channels, BATCH_SIZE)): 89 | ablation_layer.indices = list(range(i, i + BATCH_SIZE)) 90 | 91 | if i + BATCH_SIZE > number_of_channels: 92 | keep = number_of_channels - i 93 | batch_tensor = batch_tensor[:keep] 94 | ablation_layer.indices = ablation_layer.indices[:keep] 95 | score = self.model(batch_tensor)[:, category].cpu().numpy() 96 | weights.extend(score) 97 | 98 | weights = np.float32(weights) 99 | weights = weights.reshape(activations.shape[:2]) 100 | original_scores = original_scores[:, None] 101 | weights = (original_scores - weights) / original_scores 102 | 103 | # Replace the model back to the original state 104 | replace_layer_recursive(self.model, ablation_layer, target_layer) 105 | return weights 106 | -------------------------------------------------------------------------------- /pytorch_grad_cam/activations_and_gradients.py: -------------------------------------------------------------------------------- 1 | class ActivationsAndGradients: 2 | """ Class for extracting activations and 3 | registering gradients from targeted intermediate layers """ 4 | 5 | def __init__(self, model,
target_layers, reshape_transform): 6 | self.model = model 7 | self.gradients = [] 8 | self.activations = [] 9 | self.reshape_transform = reshape_transform 10 | self.handles = [] 11 | for target_layer in target_layers: 12 | self.handles.append( 13 | target_layer.register_forward_hook( 14 | self.save_activation)) 15 | # Backward compatibility with older pytorch versions: 16 | if hasattr(target_layer, 'register_full_backward_hook'): 17 | self.handles.append( 18 | target_layer.register_full_backward_hook( 19 | self.save_gradient)) 20 | else: 21 | self.handles.append( 22 | target_layer.register_backward_hook( 23 | self.save_gradient)) 24 | 25 | def save_activation(self, module, input, output): 26 | activation = output 27 | if self.reshape_transform is not None: 28 | activation = self.reshape_transform(activation) 29 | self.activations.append(activation.cpu().detach()) 30 | 31 | def save_gradient(self, module, grad_input, grad_output): 32 | # Gradients are computed in reverse order 33 | grad = grad_output[0] 34 | if self.reshape_transform is not None: 35 | grad = self.reshape_transform(grad) 36 | self.gradients = [grad.cpu().detach()] + self.gradients 37 | 38 | def __call__(self, x): 39 | self.gradients = [] 40 | self.activations = [] 41 | return self.model(x) 42 | 43 | def release(self): 44 | for handle in self.handles: 45 | handle.remove() 46 | -------------------------------------------------------------------------------- /pytorch_grad_cam/base_cam.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import ttach as tta 5 | from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients 6 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 7 | 8 | 9 | class BaseCAM: 10 | def __init__(self, 11 | model, 12 | target_layers, 13 | use_cuda=False, 14 | reshape_transform=None, 15 | compute_input_gradient=False, 16 | uses_gradients=True): 17 | self.model = model.eval() 18 | self.target_layers = target_layers 19 | self.cuda = use_cuda 20 | if self.cuda: 21 | self.model = model.cuda() 22 | self.reshape_transform = reshape_transform 23 | self.compute_input_gradient = compute_input_gradient 24 | self.uses_gradients = uses_gradients 25 | self.activations_and_grads = ActivationsAndGradients( 26 | self.model, target_layers, reshape_transform) 27 | 28 | """ Get a vector of weights for every channel in the target layer. 29 | Methods that return weights channels, 30 | will typically need to only implement this function. """ 31 | 32 | def get_cam_weights(self, 33 | input_tensor, 34 | target_layers, 35 | target_category, 36 | activations, 37 | grads): 38 | raise Exception("Not Implemented") 39 | 40 | def get_loss(self, output, target_category): 41 | loss = 0 42 | for i in range(len(target_category)): 43 | loss = loss + output[i, target_category[i]] 44 | return loss 45 | 46 | def get_cam_image(self, 47 | input_tensor, 48 | target_layer, 49 | target_category, 50 | activations, 51 | grads, 52 | eigen_smooth=False): 53 | weights = self.get_cam_weights(input_tensor, target_layer, 54 | target_category, activations, grads) 55 | 56 | weighted_activations = weights[:, :, None, None] * activations 57 | 58 | # the activations here are the same regardless of n
59 | if eigen_smooth: 60 | cam = get_2d_projection(weighted_activations) 61 | else: 62 | cam = weighted_activations.sum(axis=1) 63 | # pdb.set_trace() 64 | return cam 65 | 66 | def forward(self, input_tensor, target_category=None, eigen_smooth=False): 67 | 68 | if self.cuda: 69 | input_tensor = input_tensor.cuda() 70 | 71 | if self.compute_input_gradient: # False 72 | input_tensor = torch.autograd.Variable(input_tensor, 73 | requires_grad=True) 74 | 75 | output = self.activations_and_grads(input_tensor) # tensor.shape 1,N 76 | 77 | if isinstance(target_category, int): 78 | target_category = [target_category] * input_tensor.size(0) 79 | 80 | if target_category is None: 81 | target_category = np.argmax(output.cpu().data.numpy(), axis=-1) 82 | else: 83 | assert (len(target_category) == input_tensor.size(0)) 84 | 85 | if self.uses_gradients: # True 86 | self.model.zero_grad() 87 | 88 | loss = self.get_loss(output, target_category) 89 | loss.backward(retain_graph=True) 90 | 91 | # In most of the saliency attribution papers, the saliency is 92 | # computed with a single target layer. 93 | # Commonly it is the last convolutional layer. 94 | # Here we support passing a list with multiple target layers. 95 | # It will compute the saliency image for every image, 96 | # and then aggregate them (with a default mean aggregation). 97 | # This gives you more flexibility in case you just want to 98 | # use all conv layers for example, all Batchnorm layers, 99 | # or something else. 100 | 101 | cam_per_layer = self.compute_cam_per_layer(input_tensor, 102 | target_category, 103 | eigen_smooth) 104 | return self.aggregate_multi_layers(cam_per_layer) 105 | 106 | def get_target_width_height(self, input_tensor): 107 | width, height = input_tensor.size(-1), input_tensor.size(-2) 108 | return width, height 109 | 110 | def compute_cam_per_layer( 111 | self, 112 | input_tensor, 113 | target_category, 114 | eigen_smooth): 115 | activations_list = [a.cpu().data.numpy() 116 | for a in self.activations_and_grads.activations] 117 | grads_list = [g.cpu().data.numpy() 118 | for g in self.activations_and_grads.gradients] 119 | target_size = self.get_target_width_height(input_tensor) 120 | 121 | cam_per_target_layer = [] 122 | # Loop over the saliency image from every layer 123 | 124 | for target_layer, layer_activations, layer_grads in \ 125 | zip(self.target_layers, activations_list, grads_list): 126 | cam = self.get_cam_image(input_tensor, 127 | target_layer, 128 | target_category, 129 | layer_activations, 130 | layer_grads, 131 | eigen_smooth) 132 | # pdb.set_trace() 133 | cam[cam < 0] = 0 # effectively mutes the min-max scaling done in scale_cam_image 134 | scaled = self.scale_cam_image(cam, target_size) 135 | cam_per_target_layer.append(scaled[:, None, :]) 136 | 137 | return cam_per_target_layer 138 | 139 | def aggregate_multi_layers(self, cam_per_target_layer): 140 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) 141 | cam_per_target_layer = np.maximum(cam_per_target_layer, 0) 142 | result = np.mean(cam_per_target_layer, axis=1) 143 | return self.scale_cam_image(result) 144 | 145 | def scale_cam_image(self, cam, target_size=None): 146 | result = [] 147 | for img in cam: 148 | 149 | # WHH ADD 150 | # pdb.set_trace() 151 | 152 | img = img - np.min(img) 153 | img = img / (1e-7 + np.max(img)) 154 | if target_size is not None: 155 | img = np.float32(img) # whh: added here, otherwise cv2.resize raises an error 156 | img = cv2.resize(img, target_size) 157 | result.append(img) 158 | result = np.float32(result) 159 | 160
| return result 161 | 162 | def forward_augmentation_smoothing(self, 163 | input_tensor, 164 | target_category=None, 165 | eigen_smooth=False): 166 | transforms = tta.Compose( 167 | [ 168 | tta.HorizontalFlip(), 169 | tta.Multiply(factors=[0.9, 1, 1.1]), 170 | ] 171 | ) 172 | cams = [] 173 | for transform in transforms: 174 | augmented_tensor = transform.augment_image(input_tensor) 175 | cam = self.forward(augmented_tensor, 176 | target_category, eigen_smooth) 177 | 178 | # The ttach library expects a tensor of size BxCxHxW 179 | cam = cam[:, None, :, :] 180 | cam = torch.from_numpy(cam) 181 | cam = transform.deaugment_mask(cam) 182 | 183 | # Back to numpy float32, HxW 184 | cam = cam.numpy() 185 | cam = cam[:, 0, :, :] 186 | cams.append(cam) 187 | 188 | cam = np.mean(np.float32(cams), axis=0) 189 | return cam 190 | 191 | def __call__(self, 192 | input_tensor, 193 | target_category=None, 194 | aug_smooth=False, 195 | eigen_smooth=False): 196 | 197 | # Smooth the CAM result with test time augmentation 198 | # print('call, !=forward', aug_smooth) # aug smooth = False 199 | if aug_smooth is True: 200 | return self.forward_augmentation_smoothing( 201 | input_tensor, target_category, eigen_smooth) 202 | 203 | return self.forward(input_tensor, 204 | target_category, eigen_smooth) 205 | 206 | def __del__(self): 207 | self.activations_and_grads.release() 208 | 209 | def __enter__(self): 210 | return self 211 | 212 | def __exit__(self, exc_type, exc_value, exc_tb): 213 | self.activations_and_grads.release() 214 | if isinstance(exc_value, IndexError): 215 | # Handle IndexError here... 216 | print( 217 | f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}") 218 | return True 219 | -------------------------------------------------------------------------------- /pytorch_grad_cam/eigen_cam.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.base_cam import BaseCAM 2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 3 | 4 | # https://arxiv.org/abs/2008.00299 5 | 6 | 7 | class EigenCAM(BaseCAM): 8 | def __init__(self, model, target_layers, use_cuda=False, 9 | reshape_transform=None): 10 | super(EigenCAM, self).__init__(model, target_layers, use_cuda, 11 | reshape_transform) 12 | 13 | def get_cam_image(self, 14 | input_tensor, 15 | target_layer, 16 | target_category, 17 | activations, 18 | grads, 19 | eigen_smooth): 20 | return get_2d_projection(activations) 21 | -------------------------------------------------------------------------------- /pytorch_grad_cam/eigen_grad_cam.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.base_cam import BaseCAM 2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 3 | 4 | # Like Eigen CAM: https://arxiv.org/abs/2008.00299 5 | # But multiply the activations x gradients 6 | 7 | 8 | class EigenGradCAM(BaseCAM): 9 | def __init__(self, model, target_layers, use_cuda=False, 10 | reshape_transform=None): 11 | super(EigenGradCAM, self).__init__(model, target_layers, use_cuda, 12 | reshape_transform) 13 | 14 | def get_cam_image(self, 15 | input_tensor, 16 | target_layer, 17 | target_category, 18 | activations, 19 | grads, 20 | eigen_smooth): 21 | return get_2d_projection(grads * activations) 22 | -------------------------------------------------------------------------------- /pytorch_grad_cam/fullgrad_cam.py: -------------------------------------------------------------------------------- 1 
| import numpy as np 2 | import torch 3 | from pytorch_grad_cam.base_cam import BaseCAM 4 | from pytorch_grad_cam.utils.find_layers import find_layer_predicate_recursive 5 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 6 | 7 | # https://arxiv.org/abs/1905.00780 8 | 9 | 10 | class FullGrad(BaseCAM): 11 | def __init__(self, model, target_layers, use_cuda=False, 12 | reshape_transform=None): 13 | if len(target_layers) > 0: 14 | print( 15 | "Warning: target_layers is ignored in FullGrad. All bias layers will be used instead") 16 | 17 | def layer_with_2D_bias(layer): 18 | bias_target_layers = [torch.nn.Conv2d, torch.nn.BatchNorm2d] 19 | if type(layer) in bias_target_layers and layer.bias is not None: 20 | return True 21 | return False 22 | target_layers = find_layer_predicate_recursive( 23 | model, layer_with_2D_bias) 24 | super( 25 | FullGrad, 26 | self).__init__( 27 | model, 28 | target_layers, 29 | use_cuda, 30 | reshape_transform, 31 | compute_input_gradient=True) 32 | self.bias_data = [self.get_bias_data( 33 | layer).cpu().numpy() for layer in target_layers] 34 | 35 | def get_bias_data(self, layer): 36 | # Borrowed from official paper impl: 37 | # https://github.com/idiap/fullgrad-saliency/blob/master/saliency/tensor_extractor.py#L47 38 | if isinstance(layer, torch.nn.BatchNorm2d): 39 | bias = - (layer.running_mean * layer.weight 40 | / torch.sqrt(layer.running_var + layer.eps)) + layer.bias 41 | return bias.data 42 | else: 43 | return layer.bias.data 44 | 45 | def scale_accross_batch_and_channels(self, tensor, target_size): 46 | batch_size, channel_size = tensor.shape[:2] 47 | reshaped_tensor = tensor.reshape( 48 | batch_size * channel_size, *tensor.shape[2:]) 49 | result = self.scale_cam_image(reshaped_tensor, target_size) 50 | result = result.reshape( 51 | batch_size, 52 | channel_size, 53 | target_size[1], 54 | target_size[0]) 55 | return result 56 | 57 | def compute_cam_per_layer( 58 | self, 59 | input_tensor, 60 | target_category, 61 | eigen_smooth): 62 | input_grad = input_tensor.grad.data.cpu().numpy() 63 | grads_list = [g.cpu().data.numpy() for g in 64 | self.activations_and_grads.gradients] 65 | cam_per_target_layer = [] 66 | target_size = self.get_target_width_height(input_tensor) 67 | 68 | gradient_multiplied_input = input_grad * input_tensor.data.cpu().numpy() 69 | gradient_multiplied_input = np.abs(gradient_multiplied_input) 70 | gradient_multiplied_input = self.scale_accross_batch_and_channels( 71 | gradient_multiplied_input, 72 | target_size) 73 | cam_per_target_layer.append(gradient_multiplied_input) 74 | 75 | # Loop over the saliency image from every layer 76 | assert(len(self.bias_data) == len(grads_list)) 77 | for bias, grads in zip(self.bias_data, grads_list): 78 | bias = bias[None, :, None, None] 79 | # In the paper they take the absolute value, 80 | # but possibly taking only the positive gradients will work 81 | # better.
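# One reading of that positive-gradient alternative, as a sketch (not what this code uses):
# bias_grad = np.abs(bias * np.maximum(grads, 0))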
82 | bias_grad = np.abs(bias * grads) 83 | result = self.scale_accross_batch_and_channels( 84 | bias_grad, target_size) 85 | result = np.sum(result, axis=1) 86 | cam_per_target_layer.append(result[:, None, :]) 87 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) 88 | if eigen_smooth: 89 | # Resize to a smaller image, since this method typically has a very large number of channels, 90 | # and then consumes a lot of memory 91 | cam_per_target_layer = self.scale_accross_batch_and_channels( 92 | cam_per_target_layer, (target_size[0] // 8, target_size[1] // 8)) 93 | cam_per_target_layer = get_2d_projection(cam_per_target_layer) 94 | cam_per_target_layer = cam_per_target_layer[:, None, :, :] 95 | cam_per_target_layer = self.scale_accross_batch_and_channels( 96 | cam_per_target_layer, 97 | target_size) 98 | else: 99 | cam_per_target_layer = np.sum( 100 | cam_per_target_layer, axis=1)[:, None, :] 101 | 102 | return cam_per_target_layer 103 | 104 | def aggregate_multi_layers(self, cam_per_target_layer): 105 | result = np.sum(cam_per_target_layer, axis=1) 106 | return self.scale_cam_image(result) 107 | -------------------------------------------------------------------------------- /pytorch_grad_cam/grad_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | import pdb 4 | 5 | 6 | class GradCAM(BaseCAM): 7 | def __init__(self, model, target_layers, use_cuda=False, 8 | reshape_transform=None): 9 | super( 10 | GradCAM, 11 | self).__init__( 12 | model, 13 | target_layers, 14 | use_cuda, 15 | reshape_transform) 16 | 17 | def get_cam_weights(self, 18 | input_tensor, 19 | target_layer, 20 | target_category, 21 | activations, 22 | grads): 23 | return np.mean(grads, axis=(2, 3)) 24 | -------------------------------------------------------------------------------- /pytorch_grad_cam/grad_cam_plusplus.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | 4 | # https://arxiv.org/abs/1710.11063 5 | 6 | 7 | class GradCAMPlusPlus(BaseCAM): 8 | def __init__(self, model, target_layers, use_cuda=False, 9 | reshape_transform=None): 10 | super(GradCAMPlusPlus, self).__init__(model, target_layers, use_cuda, 11 | reshape_transform) 12 | 13 | def get_cam_weights(self, 14 | input_tensor, 15 | target_layers, 16 | target_category, 17 | activations, 18 | grads): 19 | grads_power_2 = grads**2 20 | grads_power_3 = grads_power_2 * grads 21 | # Equation 19 in https://arxiv.org/abs/1710.11063 22 | sum_activations = np.sum(activations, axis=(2, 3)) 23 | eps = 0.000001 24 | aij = grads_power_2 / (2 * grads_power_2 + 25 | sum_activations[:, :, None, None] * grads_power_3 + eps) 26 | # Now bring back the ReLU from eq.7 in the paper, 27 | # And zero out aijs where the activations are 0 28 | aij = np.where(grads != 0, aij, 0) 29 | 30 | weights = np.maximum(grads, 0) * aij 31 | weights = np.sum(weights, axis=(2, 3)) 32 | return weights 33 | -------------------------------------------------------------------------------- /pytorch_grad_cam/guided_backprop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Function 4 | from pytorch_grad_cam.utils.find_layers import replace_all_layer_type_recursive 5 | 6 | 7 | class GuidedBackpropReLU(Function): 8 | @staticmethod 9 | def forward(self, input_img): 10 | 
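# Forward below acts as a plain ReLU: torch.addcmul(zeros, input, mask) multiplies the input by
# its positive mask; the saved tensors let backward gate the incoming gradient by both
# input > 0 and grad_output > 0, which is the guided-backprop rule.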
positive_mask = (input_img > 0).type_as(input_img) 11 | output = torch.addcmul( 12 | torch.zeros( 13 | input_img.size()).type_as(input_img), 14 | input_img, 15 | positive_mask) 16 | self.save_for_backward(input_img, output) 17 | return output 18 | 19 | @staticmethod 20 | def backward(self, grad_output): 21 | input_img, output = self.saved_tensors 22 | grad_input = None 23 | 24 | positive_mask_1 = (input_img > 0).type_as(grad_output) 25 | positive_mask_2 = (grad_output > 0).type_as(grad_output) 26 | grad_input = torch.addcmul( 27 | torch.zeros( 28 | input_img.size()).type_as(input_img), 29 | torch.addcmul( 30 | torch.zeros( 31 | input_img.size()).type_as(input_img), 32 | grad_output, 33 | positive_mask_1), 34 | positive_mask_2) 35 | return grad_input 36 | 37 | 38 | class GuidedBackpropReLUasModule(torch.nn.Module): 39 | def __init__(self): 40 | super(GuidedBackpropReLUasModule, self).__init__() 41 | 42 | def forward(self, input_img): 43 | return GuidedBackpropReLU.apply(input_img) 44 | 45 | 46 | class GuidedBackpropReLUModel: 47 | def __init__(self, model, use_cuda): 48 | self.model = model 49 | self.model.eval() 50 | self.cuda = use_cuda 51 | if self.cuda: 52 | self.model = self.model.cuda() 53 | 54 | def forward(self, input_img): 55 | return self.model(input_img) 56 | 57 | def recursive_replace_relu_with_guidedrelu(self, module_top): 58 | 59 | for idx, module in module_top._modules.items(): 60 | self.recursive_replace_relu_with_guidedrelu(module) 61 | if module.__class__.__name__ == 'ReLU': 62 | module_top._modules[idx] = GuidedBackpropReLU.apply 63 | print("b") 64 | 65 | def recursive_replace_guidedrelu_with_relu(self, module_top): 66 | # noinspection PyBroadException 67 | try: 68 | for idx, module in module_top._modules.items(): 69 | self.recursive_replace_guidedrelu_with_relu(module) 70 | if module == GuidedBackpropReLU.apply: 71 | module_top._modules[idx] = torch.nn.ReLU() 72 | except BaseException: 73 | pass 74 | 75 | def __call__(self, input_img, target_category=None): 76 | replace_all_layer_type_recursive(self.model, 77 | torch.nn.ReLU, 78 | GuidedBackpropReLUasModule()) 79 | 80 | if self.cuda: 81 | input_img = input_img.cuda() 82 | 83 | input_img = input_img.requires_grad_(True) 84 | 85 | output = self.forward(input_img) 86 | 87 | if target_category is None: 88 | target_category = np.argmax(output.cpu().data.numpy()) 89 | 90 | loss = output[0, target_category] 91 | loss.backward(retain_graph=True) 92 | 93 | output = input_img.grad.cpu().data.numpy() 94 | output = output[0, :, :, :] 95 | output = output.transpose((1, 2, 0)) 96 | 97 | replace_all_layer_type_recursive(self.model, 98 | GuidedBackpropReLUasModule, 99 | torch.nn.ReLU()) 100 | 101 | return output 102 | -------------------------------------------------------------------------------- /pytorch_grad_cam/layer_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 4 | 5 | # https://ieeexplore.ieee.org/document/9462463 6 | 7 | 8 | class LayerCAM(BaseCAM): 9 | def __init__( 10 | self, 11 | model, 12 | target_layers, 13 | use_cuda=False, 14 | reshape_transform=None): 15 | super( 16 | LayerCAM, 17 | self).__init__( 18 | model, 19 | target_layers, 20 | use_cuda, 21 | reshape_transform) 22 | 23 | def get_cam_image(self, 24 | input_tensor, 25 | target_layer, 26 | target_category, 27 | activations, 28 | grads, 29 | eigen_smooth): 30 | 
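# Unlike GradCAM, which pools gradients over H and W into per-channel weights, LayerCAM keeps
# spatial detail: ReLU(grads) weights each activation element-wise before summing over channels.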
spatial_weighted_activations = np.maximum(grads, 0) * activations 31 | 32 | if eigen_smooth: 33 | cam = get_2d_projection(spatial_weighted_activations) 34 | else: 35 | cam = spatial_weighted_activations.sum(axis=1) 36 | return cam 37 | -------------------------------------------------------------------------------- /pytorch_grad_cam/score_cam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from pytorch_grad_cam.base_cam import BaseCAM 4 | 5 | 6 | class ScoreCAM(BaseCAM): 7 | def __init__( 8 | self, 9 | model, 10 | target_layers, 11 | use_cuda=False, 12 | reshape_transform=None): 13 | super(ScoreCAM, self).__init__(model, target_layers, use_cuda, 14 | reshape_transform=reshape_transform) 15 | 16 | if len(target_layers) > 0: 17 | print("Warning: You are using ScoreCAM with target layers, " 18 | "however ScoreCAM will ignore them.") 19 | 20 | def get_cam_weights(self, 21 | input_tensor, 22 | target_layer, 23 | target_category, 24 | activations, 25 | grads): 26 | with torch.no_grad(): 27 | upsample = torch.nn.UpsamplingBilinear2d( 28 | size=input_tensor.shape[-2:]) 29 | activation_tensor = torch.from_numpy(activations) 30 | if self.cuda: 31 | activation_tensor = activation_tensor.cuda() 32 | 33 | upsampled = upsample(activation_tensor) 34 | 35 | maxs = upsampled.view(upsampled.size(0), 36 | upsampled.size(1), -1).max(dim=-1)[0] 37 | mins = upsampled.view(upsampled.size(0), 38 | upsampled.size(1), -1).min(dim=-1)[0] 39 | maxs, mins = maxs[:, :, None, None], mins[:, :, None, None] 40 | upsampled = (upsampled - mins) / (maxs - mins) 41 | 42 | input_tensors = input_tensor[:, None, 43 | :, :] * upsampled[:, :, None, :, :] 44 | 45 | if hasattr(self, "batch_size"): 46 | BATCH_SIZE = self.batch_size 47 | else: 48 | BATCH_SIZE = 16 49 | 50 | scores = [] 51 | for batch_index, tensor in enumerate(input_tensors): 52 | category = target_category[batch_index] 53 | for i in tqdm.tqdm(range(0, tensor.size(0), BATCH_SIZE)): 54 | batch = tensor[i: i + BATCH_SIZE, :] 55 | outputs = self.model(batch).cpu().numpy()[:, category] 56 | scores.extend(outputs) 57 | scores = torch.Tensor(scores) 58 | scores = scores.view(activations.shape[0], activations.shape[1]) 59 | 60 | weights = torch.nn.Softmax(dim=-1)(scores).numpy() 61 | return weights 62 | -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.utils.image import deprocess_image 2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 3 | -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/__pycache__/find_layers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/utils/__pycache__/find_layers.cpython-38.pyc -------------------------------------------------------------------------------- 
/pytorch_grad_cam/utils/__pycache__/image.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/utils/__pycache__/image.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/__pycache__/svd_on_activations.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whileherham/IMR-HSNet/57572272f493c3b16540a650ff0398d605e46812/pytorch_grad_cam/utils/__pycache__/svd_on_activations.cpython-38.pyc -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/find_layers.py: -------------------------------------------------------------------------------- 1 | def replace_layer_recursive(model, old_layer, new_layer): 2 |     for name, layer in model._modules.items(): 3 |         if layer == old_layer: 4 |             model._modules[name] = new_layer 5 |             return True 6 |         elif replace_layer_recursive(layer, old_layer, new_layer): 7 |             return True 8 |     return False 9 | 10 | 11 | def replace_all_layer_type_recursive(model, old_layer_type, new_layer): 12 |     for name, layer in model._modules.items(): 13 |         if isinstance(layer, old_layer_type): 14 |             model._modules[name] = new_layer 15 |         replace_all_layer_type_recursive(layer, old_layer_type, new_layer) 16 | 17 | 18 | def find_layer_types_recursive(model, layer_types): 19 |     def predicate(layer): 20 |         return type(layer) in layer_types 21 |     return find_layer_predicate_recursive(model, predicate) 22 | 23 | 24 | def find_layer_predicate_recursive(model, predicate): 25 |     result = [] 26 |     for name, layer in model._modules.items(): 27 |         if predicate(layer): 28 |             result.append(layer) 29 |         result.extend(find_layer_predicate_recursive(layer, predicate)) 30 |     return result 31 | -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/image.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torchvision.transforms import Compose, Normalize, ToTensor 5 | 6 | 7 | def preprocess_image(img: np.ndarray, mean=None, std=None) -> torch.Tensor: 8 |     if mean is None: 9 |         mean = [0.5, 0.5, 0.5] 10 |     if std is None: 11 |         std = [0.5, 0.5, 0.5] 12 |     preprocessing = Compose([ 13 |         ToTensor(), 14 |         Normalize(mean=mean, std=std) 15 |     ]) 16 |     return preprocessing(img.copy()).unsqueeze(0) 17 | 18 | 19 | def deprocess_image(img): 20 |     """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """ 21 |     img = img - np.mean(img) 22 |     img = img / (np.std(img) + 1e-5) 23 |     img = img * 0.1 24 |     img = img + 0.5 25 |     img = np.clip(img, 0, 1) 26 |     return np.uint8(img * 255) 27 | 28 | 29 | def show_cam_on_image(img: np.ndarray, 30 |                       mask: np.ndarray, 31 |                       use_rgb: bool = False, 32 |                       colormap: int = cv2.COLORMAP_JET) -> np.ndarray: 33 |     """ This function overlays the cam mask on the image as a heatmap. 34 |     By default the heatmap is in BGR format. 35 | 36 |     :param img: The base image in RGB or BGR format. 37 |     :param mask: The cam mask. 38 |     :param use_rgb: Whether to use an RGB or BGR heatmap; this should be set to True if 'img' is in RGB format. 39 |     :param colormap: The OpenCV colormap to be used. 40 |     :returns: The input image with the cam overlay. 
41 | """ 42 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap) 43 | if use_rgb: 44 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) 45 | heatmap = np.float32(heatmap) / 255 46 | 47 | if np.max(img) > 1: 48 | raise Exception( 49 | "The input image should np.float32 in the range [0, 1]") 50 | 51 | cam = heatmap + img 52 | cam = cam / np.max(cam) 53 | return np.uint8(255 * cam) 54 | -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/svd_on_activations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_2d_projection(activation_batch): 5 | # TBD: use pytorch batch svd implementation 6 | activation_batch[np.isnan(activation_batch)] = 0 7 | projections = [] 8 | for activations in activation_batch: 9 | reshaped_activations = (activations).reshape( 10 | activations.shape[0], -1).transpose() 11 | # Centering before the SVD seems to be important here, 12 | # Otherwise the image returned is negative 13 | reshaped_activations = reshaped_activations - \ 14 | reshaped_activations.mean(axis=0) 15 | U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True) 16 | projection = reshaped_activations @ VT[0, :] 17 | projection = projection.reshape(activations.shape[1:]) 18 | projections.append(projection) 19 | return np.float32(projections) 20 | -------------------------------------------------------------------------------- /pytorch_grad_cam/xgrad_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | 4 | 5 | class XGradCAM(BaseCAM): 6 | def __init__( 7 | self, 8 | model, 9 | target_layers, 10 | use_cuda=False, 11 | reshape_transform=None): 12 | super( 13 | XGradCAM, 14 | self).__init__( 15 | model, 16 | target_layers, 17 | use_cuda, 18 | reshape_transform) 19 | 20 | def get_cam_weights(self, 21 | input_tensor, 22 | target_layer, 23 | target_category, 24 | activations, 25 | grads): 26 | sum_activations = np.sum(activations, axis=(2, 3)) 27 | eps = 1e-7 28 | weights = grads * activations / \ 29 | (sum_activations[:, :, None, None] + eps) 30 | weights = weights.sum(axis=(2, 3)) 31 | return weights 32 | -------------------------------------------------------------------------------- /simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile( 79 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 80 | re.IGNORECASE) 81 | 82 | def bpe(self, token): 83 | if token in self.cache: 84 | return self.cache[token] 85 | word = tuple(token[:-1]) + ( token[-1] + '',) 86 | pairs = get_pairs(word) 87 | 88 | if not pairs: 89 | return token+'' 90 | 91 | while True: 92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 93 | if bigram not in self.bpe_ranks: 94 | break 95 | first, second = bigram 96 | new_word = [] 97 | i = 0 98 | while i < len(word): 99 | # noinspection PyBroadException 100 | try: 101 | j = word.index(first, i) 102 | new_word.extend(word[i:j]) 103 | i = j 104 | except: 105 | new_word.extend(word[i:]) 106 | break 107 | 108 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 109 | new_word.append(first+second) 110 | i += 2 111 | else: 112 | new_word.append(word[i]) 113 | i += 1 114 | new_word = tuple(new_word) 115 | word = new_word 116 | if len(word) == 1: 117 | break 118 | else: 119 | pairs = get_pairs(word) 120 | word = ' '.join(word) 121 | self.cache[token] = word 122 | return word 123 | 124 | def encode(self, text): 125 | bpe_tokens = [] 126 | text = whitespace_clean(basic_clean(text)).lower() 127 | for token in re.findall(self.pat, text): 128 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 129 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 130 | return bpe_tokens 131 | 132 | def decode(self, tokens): 133 | text = ''.join([self.decoder[token] for token in tokens]) 134 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 135 | return text 136 | 
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | r""" Hypercorrelation Squeeze testing code """ 2 | import argparse 3 | 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | import torch 7 | 8 | from model.hsnet_imr import HypercorrSqueezeNetwork_imr 9 | from common.logger import Logger, AverageMeter 10 | from common.vis import Visualizer 11 | from common.evaluation import Evaluator 12 | from common import utils 13 | from data.dataset import FSSDataset 14 | 15 | 16 | 17 | def test(model, dataloader, nshot, stage): 18 |     r""" Test HSNet """ 19 | 20 |     # Freeze randomness during testing for reproducibility 21 |     utils.fix_randseed(0) 22 |     average_meter = AverageMeter(dataloader.dataset) 23 | 24 |     for idx, batch in enumerate(dataloader): 25 | 26 |         # 1. Hypercorrelation Squeeze Networks forward pass 27 |         batch = utils.to_cuda(batch) 28 |         pred_mask = model.module.predict_mask_nshot(batch, nshot=nshot, stage=stage) 29 | 30 |         assert pred_mask.size() == batch['query_mask'].size() 31 | 32 |         # 2. Evaluate prediction 33 |         area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch) 34 |         average_meter.update(area_inter, area_union, batch['class_id'], loss=None) 35 |         average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1) 36 | 37 |         # Visualize predictions 38 |         if Visualizer.visualize: 39 |             Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'], 40 |                                                   batch['query_img'], batch['query_mask'], 41 |                                                   pred_mask, batch['class_id'], idx, 42 |                                                   area_inter[1].float() / area_union[1].float()) 43 | 44 |     average_meter.write_result('Test', 0) 45 |     miou, fb_iou = average_meter.compute_iou() 46 |     return miou, fb_iou 47 | 48 | 49 | if __name__ == '__main__': 50 | 51 |     # Arguments parsing 52 |     parser = argparse.ArgumentParser(description='Hypercorrelation Squeeze Pytorch Implementation') 53 |     parser.add_argument('--datapath', type=str, default='../Datasets_HSN') 54 |     parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss']) 55 |     parser.add_argument('--logpath', type=str, default='') 56 |     parser.add_argument('--bsz', type=int, default=10) 57 |     parser.add_argument('--nworker', type=int, default=8) 58 |     parser.add_argument('--load', type=str, default=None) 59 |     parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3]) 60 |     parser.add_argument('--nshot', type=int, default=1) 61 |     parser.add_argument('--backbone', type=str, default='resnet50', choices=['vgg16', 'resnet50', 'resnet101']) 62 |     parser.add_argument('--visualize', action='store_true') 63 |     parser.add_argument('--use_original_imgsize', action='store_true') 64 |     parser.add_argument('--stage', type=int, default=2) 65 |     parser.add_argument('--traincampath', type=str, default='../Datasets_HSN/CAM_Train/') 66 |     parser.add_argument('--valcampath', type=str, default='../Datasets_HSN/CAM_Val/') 67 |     args = parser.parse_args() 68 |     Logger.initialize(args, training=False) 69 | 70 |     # Model initialization 71 |     model = HypercorrSqueezeNetwork_imr(args.backbone, args.use_original_imgsize) 72 |     model.eval() 73 |     Logger.log_params(model) 74 | 75 |     # Device setup 76 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 77 |     Logger.info('# available GPUs: %d' % torch.cuda.device_count()) 78 |     model = nn.DataParallel(model) 79 |     model.to(device) 80 | 81 |     # Load trained model 82 |     if not args.load:  # --load defaults to None, so catch both None and '' 83 | 
raise Exception('Pretrained model not specified.') 84 | model.load_state_dict(torch.load(args.load)) 85 | 86 | # Helper classes (for testing) initialization 87 | Evaluator.initialize() 88 | Visualizer.initialize(args.visualize) 89 | 90 | # Dataset initialization 91 | FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize) 92 | dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot, 93 | cam_train_path=args.traincampath, cam_val_path=args.valcampath) 94 | 95 | # Test HSNet 96 | with torch.no_grad(): 97 | test_miou, test_fb_iou = test(model, dataloader_test, args.nshot, args.stage) 98 | Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item())) 99 | Logger.info('==================== Finished Testing ====================') 100 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | r""" Hypercorrelation Squeeze training (validation) code """ 2 | import argparse 3 | import pdb 4 | 5 | import torch.optim as optim 6 | import torch.nn as nn 7 | import torch 8 | 9 | from common.logger import Logger, AverageMeter 10 | from common.evaluation import Evaluator 11 | from common import utils 12 | from data.dataset import FSSDataset 13 | from model.hsnet_imr import HypercorrSqueezeNetwork_imr 14 | 15 | 16 | def train(epoch, model, dataloader, optimizer, training, stage): 17 | r""" Train HSNet """ 18 | 19 | # Force randomness during training / freeze randomness during testing 20 | utils.fix_randseed(None) if training else utils.fix_randseed(0) 21 | model.module.train_mode() if training else model.module.eval() 22 | average_meter = AverageMeter(dataloader.dataset) 23 | 24 | for idx, batch in enumerate(dataloader): 25 | 26 | # 1. Hypercorrelation Squeeze Networks forward pass 27 | batch = utils.to_cuda(batch) 28 | logit_mask_q, logit_mask_s, losses = model( 29 | query_img=batch['query_img'], support_img=batch['support_imgs'].squeeze(1), 30 | support_cam=batch['support_cams'].squeeze(1), query_cam=batch['query_cam'], stage=stage, 31 | query_mask=batch['query_mask'], support_mask=batch['support_masks'].squeeze(1)) 32 | pred_mask_q = logit_mask_q.argmax(dim=1) 33 | 34 | # 2. Compute loss & update model parameters 35 | loss = losses.mean() 36 | if training: 37 | optimizer.zero_grad() 38 | loss.backward() 39 | optimizer.step() 40 | 41 | # 3. 
Evaluate prediction 42 | area_inter, area_union = Evaluator.classify_prediction(pred_mask_q, batch) 43 | average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone()) 44 | average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50) 45 | 46 | # Write evaluation results 47 | average_meter.write_result('Training' if training else 'Validation', epoch) 48 | avg_loss = utils.mean(average_meter.loss_buf) 49 | miou, fb_iou = average_meter.compute_iou() 50 | 51 | return avg_loss, miou, fb_iou 52 | 53 | 54 | if __name__ == '__main__': 55 | # Arguments parsing 56 | parser = argparse.ArgumentParser(description='Hypercorrelation Squeeze Pytorch Implementation') 57 | parser.add_argument('--datapath', type=str, default='../Datasets_HSN') 58 | parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss']) 59 | parser.add_argument('--logpath', type=str, default='') 60 | parser.add_argument('--bsz', type=int, default=40) 61 | parser.add_argument('--lr', type=float, default=4e-4) 62 | parser.add_argument('--niter', type=int, default=400) 63 | parser.add_argument('--nworker', type=int, default=8) 64 | parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3]) 65 | parser.add_argument('--stage', type=int, default=2) 66 | parser.add_argument('--backbone', type=str, default='resnet50', choices=['vgg16', 'resnet50', 'resnet101']) 67 | parser.add_argument('--traincampath', type=str, default='../Datasets_HSN/CAM_Train/') 68 | parser.add_argument('--valcampath', type=str, default='../Datasets_HSN/CAM_Val/') 69 | 70 | args = parser.parse_args() 71 | Logger.initialize(args, training=True) 72 | assert args.bsz % torch.cuda.device_count() == 0 73 | 74 | # Model initialization 75 | model = HypercorrSqueezeNetwork_imr(args.backbone, False) 76 | Logger.log_params(model) 77 | 78 | # Device setup 79 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 80 | Logger.info('# available GPUs: %d' % torch.cuda.device_count()) 81 | model = nn.DataParallel(model) 82 | model.to(device) 83 | 84 | # Helper classes (for training) initialization 85 | optimizer = optim.Adam([{"params": model.parameters(), "lr": args.lr}]) 86 | Evaluator.initialize() 87 | 88 | # Dataset initialization 89 | FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=False) 90 | dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn', 91 | cam_train_path=args.traincampath, cam_val_path=args.valcampath) 92 | dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val', 93 | cam_train_path=args.traincampath, cam_val_path=args.valcampath) 94 | 95 | # Train HSNet 96 | best_val_miou = float('-inf') 97 | best_val_loss = float('inf') 98 | for epoch in range(args.niter): 99 | trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True, 100 | stage=args.stage) 101 | with torch.no_grad(): 102 | val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False, 103 | stage=args.stage) 104 | # Save the best model 105 | if val_miou > best_val_miou: 106 | best_val_miou = val_miou 107 | Logger.save_model_miou(model, epoch, val_miou) 108 | 109 | Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch) 110 | Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch) 111 | Logger.tbd_writer.add_scalars('data/fb_iou', 
{'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch) 112 | Logger.tbd_writer.flush() 113 | Logger.tbd_writer.close() 114 | Logger.info('==================== Finished Training ====================') 115 | --------------------------------------------------------------------------------
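For orientation, a minimal sketch of driving the bundled `pytorch_grad_cam` package end to end. The torchvision backbone, the `layer4` target layer, and the random inputs below are illustrative assumptions only (the repo's actual CAM generation lives in `generate_cam_voc.py` / `generate_cam_coco.py`), and the call interface is assumed to match the upstream pytorch-grad-cam release this code is based on:

```python
# Illustrative sketch: GradCAM + show_cam_on_image on dummy data (assumptions, not the repo pipeline).
import numpy as np
import torch
from torchvision.models import resnet50
from pytorch_grad_cam.grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

model = resnet50(pretrained=True).eval()              # assumed backbone, for demonstration only
cam = GradCAM(model=model, target_layers=[model.layer4[-1]], use_cuda=False)

input_tensor = torch.randn(1, 3, 224, 224)            # stand-in for a preprocessed image batch
grayscale_cam = cam(input_tensor)                     # (1, 224, 224) float32 heatmaps scaled to [0, 1]

rgb_img = np.float32(np.random.rand(224, 224, 3))     # stand-in RGB image with values in [0, 1]
overlay = show_cam_on_image(rgb_img, grayscale_cam[0], use_rgb=True)
print(overlay.shape, overlay.dtype)                   # (224, 224, 3) uint8
```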