├── .DS_Store ├── .gitignore ├── .gitmodules ├── README.md ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── maskcut ├── crf.py ├── dino.py ├── muskcut.py └── ovcamo_cato.txt ├── pytorch_grad_cam ├── __init__.py ├── ablation_cam.py ├── ablation_cam_multilayer.py ├── ablation_layer.py ├── activations_and_gradients.py ├── base_cam.py ├── eigen_cam.py ├── eigen_grad_cam.py ├── fullgrad_cam.py ├── grad_cam.py ├── grad_cam_plusplus.py ├── guided_backprop.py ├── layer_cam.py ├── score_cam.py ├── utils │ ├── __init__.py │ ├── find_layers.py │ ├── image.py │ ├── model_targets.py │ ├── reshape_transforms.py │ └── svd_on_activations.py └── xgrad_cam.py ├── requirements.txt └── tpnet ├── __init__.py ├── config ├── __init__.py └── cutler_config.py ├── data ├── __init__.py ├── build.py ├── dataset_mapper.py ├── datasets │ ├── __init__.py │ ├── builtin.py │ ├── builtin_meta.py │ ├── coco.py │ └── register_cis.py ├── detection_utils.py └── transforms │ ├── __init__.py │ ├── augmentation_impl.py │ └── transform.py ├── engine ├── __init__.py ├── defaults.py └── train_loop.py ├── evaluation ├── __init__.py └── coco_evaluation.py ├── model_zoo └── configs │ ├── Base-RCNN-FPN.yaml │ ├── COCO-Semisupervised │ ├── cascade_mask_rcnn_R_50_FPN_100perc.yaml │ ├── cascade_mask_rcnn_R_50_FPN_10perc.yaml │ ├── cascade_mask_rcnn_R_50_FPN_1perc.yaml │ ├── cascade_mask_rcnn_R_50_FPN_20perc.yaml │ ├── cascade_mask_rcnn_R_50_FPN_2perc.yaml │ ├── cascade_mask_rcnn_R_50_FPN_30perc.yaml │ ├── cascade_mask_rcnn_R_50_FPN_40perc.yaml │ ├── cascade_mask_rcnn_R_50_FPN_50perc.yaml │ ├── cascade_mask_rcnn_R_50_FPN_5perc.yaml │ ├── cascade_mask_rcnn_R_50_FPN_60perc.yaml │ └── cascade_mask_rcnn_R_50_FPN_80perc.yaml │ └── CutLER-ImageNet │ ├── cascade_mask_GT.yaml │ ├── cascade_mask_rcnn_R_50_FPN.yaml │ ├── cascade_mask_rcnn_R_50_FPN_demo.yaml │ ├── cascade_mask_rcnn_R_50_FPN_self_train.yaml │ ├── mask_rcnn_R_50_FPN.yaml │ └── test.yaml ├── modeling ├── __init__.py ├── meta_arch │ ├── __init__.py │ ├── build.py │ └── rcnn.py └── roi_heads │ ├── __init__.py │ ├── custom_cascade_rcnn.py │ ├── fast_rcnn.py │ └── roi_heads.py ├── solver ├── __init__.py └── build.py ├── structures ├── __init__.py └── boxes.py ├── tools ├── eval.sh ├── get_self_training_ann.py ├── run_with_submitit.sh ├── run_with_submitit_ssl.sh ├── single-node_run.sh └── train-1node.sh └── train_net.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zertow/TPNet/e2977653d35ac91db8c9c4201e1fb3c62e221be1/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | __MACOSX/ 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pretrained models 82 | videocutler/pretrain 83 | *.pth 84 | 85 | # demo results 86 | demos/ 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/TokenCut"] 2 | path = third_party/TokenCut 3 | url = https://github.com/YangtaoWANG95/TokenCut.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TPNet 2 | 3 | > Text-prompt Camouflaged Instance Segmentation with Graduated Camouflage Learning, 4 | > *ACM MM 2024* 5 | 6 | ## Abstract 7 | Camouflaged instance segmentation (CIS) aims to seamlessly detect and segment objects blending with their surroundings. Existing CIS methods rely heavily on fully-supervised training with massive, precisely annotated data, consuming considerable annotation effort yet still struggling to segment highly camouflaged objects accurately. Despite their visual similarity to the background, camouflaged objects differ semantically. Since text associated with images offers explicit semantic cues to underscore this difference, in this paper we propose a novel approach: the first **T**ext-**P**rompt based weakly-supervised camouflaged instance segmentation method, named TPNet, which leverages semantic distinctions for effective segmentation. Specifically, TPNet operates in two stages: the generation of pseudo masks, followed by a self-training process. In the pseudo mask generation stage, we innovatively align text prompts with images using a pre-trained language-image model to obtain region proposals containing camouflaged instances and their specific text prompts.
Additionally, a Semantic-Spatial Iterative Fusion module is designed to assimilate spatial information with semantic insights, iteratively refining the pseudo masks. In the following stage, we employ Graduated Camouflage Learning, a straightforward self-training optimization strategy that evaluates camouflage levels to sequence training from simple to complex images, facilitating an effective learning gradient. Through the collaboration of the two stages, comprehensive experiments on two common benchmarks demonstrate a significant advancement, delivering a novel solution that bridges the gap between weak supervision and highly camouflaged instance segmentation. 8 | 9 | 10 | ## Usage 11 | ### Install 12 | ```bash 13 | conda create --name tpnet python=3.8 -y 14 | conda activate tpnet 15 | conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 -c pytorch 16 | pip install git+https://github.com/lucasb-eyer/pydensecrf.git 17 | 18 | # under your working directory 19 | git clone git@github.com:facebookresearch/detectron2.git 20 | cd detectron2 21 | pip install -e . 22 | pip install git+https://github.com/cocodataset/panopticapi.git 23 | pip install git+https://github.com/mcordts/cityscapesScripts.git 24 | pip install -r requirements.txt 25 | 26 | ``` 27 | - [DETReg](https://github.com/amirbar/DETReg/): Follow the instructions at DETReg for installation. 28 | - [OSFormer](https://github.com/PJLallen/OSFormer): Follow the instructions at OSFormer to obtain the dataset and generate the dataset format. 29 | - [CLIP](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt): Download the CLIP pre-trained ViT-B/16 checkpoint. 30 | ### Directory 31 | The directory should look like this: 32 | 33 | ```` 34 | -- TPNet 35 | -- data (train dataset and test dataset) 36 | -- model (saved model) 37 | -- pre (pretrained model) 38 | -- data (train dataset and test dataset) 39 | |-- CIS 40 | | |-- Train_Image_CAM 41 | | |-- COD10K 42 | | |-- NC4K 43 | ... 44 | 45 | ```` 46 | 47 | ### Train 48 | #### stage1 49 | ```bash 50 | cd maskcut 51 | python maskcut.py --vit-arch base --patch-size 8 --tau 0.15 --fixed_size 480 --N 3 --dataset-path path/to/Train/data --out-dir output/path 52 | ``` 53 | * We adopt pre-trained DINO and CLIP as the pretrained models; a minimal sketch of the CLIP text-prompt scoring step is shown below, after the Acknowledgement. 54 | 55 | #### stage2 56 | ```bash 57 | cd tpnet 58 | python train_net.py --num-gpus 4 --config-file model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml 59 | ``` 60 | 61 | ## Acknowledgement 62 | We borrowed the code from [CutLER](https://github.com/facebookresearch/CutLER/tree/main) and [pytorch_grad_cam](https://github.com/jacobgil/pytorch-grad-cam/tree/61e9babae8600351b02b6e90864e4807f44f2d4a/). Thanks for their wonderful work.
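For reference, the snippet below is a minimal sketch of the stage-1 text-prompt alignment using the bundled `clip` package and the category prompts in `maskcut/ovcamo_cato.txt`: it scores every prompt against a single image crop with the ViT-B/16 weights linked above. It is an illustration only, not the exact TPNet pipeline; the image path is a placeholder, and how region proposals are cropped is left to the stage-1 script.

```python
import torch
from PIL import Image

import clip  # the local clip/ package shipped in this repo

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)  # downloads ViT-B/16 if not cached

# Category prompts for camouflaged objects (one prompt per line).
with open("maskcut/ovcamo_cato.txt") as f:
    prompts = [line.strip() for line in f if line.strip()]
text_tokens = clip.tokenize(prompts).to(device)

# A region proposal cropped from a training image (placeholder path).
image = preprocess(Image.open("path/to/region_crop.jpg")).unsqueeze(0).to(device)

with torch.no_grad():
    image_feat = model.encode_image(image)
    text_feat = model.encode_text(text_tokens)
    image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)
    text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
    # Cosine similarity between the crop and every text prompt, softmax-normalized.
    sims = (100.0 * image_feat @ text_feat.T).softmax(dim=-1)

best = sims[0].argmax().item()
print(f"best prompt: {prompts[best]} (score {sims[0, best].item():.3f})")
```

In the full stage-1 pipeline this scoring is combined with DINO features and the Semantic-Spatial Iterative Fusion module described in the abstract to produce the pseudo masks.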
63 | ``` 64 | ## Citation 65 | @inproceedings{xia2024text, 66 | title={Text-prompt Camouflaged Instance Segmentation with Graduated Camouflage Learning}, 67 | author={Xia, Changqun and Qiao, Shengye and Li, Jia and others}, 68 | booktitle={ACM Multimedia 2024}, 69 | year={2024} 70 | } 71 | ``` 72 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zertow/TPNet/e2977653d35ac91db8c9c4201e1fb3c62e221be1/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | from collections import OrderedDict 16 | 17 | try: 18 | from torchvision.transforms import InterpolationMode 19 | BICUBIC = InterpolationMode.BICUBIC 20 | except ImportError: 21 | BICUBIC = Image.BICUBIC 22 | 23 | 24 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 25 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 26 | 27 | 28 | __all__ = ["available_models", "load", "tokenize"] 29 | _tokenizer = _Tokenizer() 30 | 31 | _MODELS = { 32 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 33 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 34 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 35 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 36 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 37 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 38 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 39 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 40 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 41 | } 42 | 43 | 44 | def _download(url: str, root: str): 45 | os.makedirs(root, exist_ok=True) 46 | filename = os.path.basename(url) 47 | 48 | expected_sha256 = url.split("/")[-2] 49 | download_target = os.path.join(root, filename) 50 | 51 | if os.path.exists(download_target) and not os.path.isfile(download_target): 52 | raise 
RuntimeError(f"{download_target} exists and is not a regular file") 53 | 54 | if os.path.isfile(download_target): 55 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 56 | return download_target 57 | else: 58 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 59 | 60 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 61 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 62 | while True: 63 | buffer = source.read(8192) 64 | if not buffer: 65 | break 66 | 67 | output.write(buffer) 68 | loop.update(len(buffer)) 69 | 70 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 71 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 72 | 73 | return download_target 74 | 75 | 76 | def _convert_image_to_rgb(image): 77 | return image.convert("RGB") 78 | 79 | 80 | def _transform(n_px): 81 | return Compose([ 82 | Resize(n_px, interpolation=BICUBIC), 83 | CenterCrop(n_px), 84 | _convert_image_to_rgb, 85 | ToTensor(), 86 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 87 | ]) 88 | 89 | 90 | def available_models() -> List[str]: 91 | """Returns the names of available CLIP models""" 92 | return list(_MODELS.keys()) 93 | 94 | 95 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 96 | """Load a CLIP model 97 | 98 | Parameters 99 | ---------- 100 | name : str 101 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 102 | 103 | device : Union[str, torch.device] 104 | The device to put the loaded model 105 | 106 | jit : bool 107 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 108 | 109 | download_root: str 110 | path to download the model files; by default, it uses "~/.cache/clip" 111 | 112 | Returns 113 | ------- 114 | model : torch.nn.Module 115 | The CLIP model 116 | 117 | preprocess : Callable[[PIL.Image], torch.Tensor] 118 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 119 | """ 120 | if name in _MODELS: 121 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 122 | elif os.path.isfile(name): 123 | model_path = name 124 | else: 125 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 126 | 127 | with open(model_path, 'rb') as opened_file: 128 | try: 129 | # loading JIT archive 130 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 131 | state_dict = None 132 | except RuntimeError: 133 | # loading saved state dict 134 | if jit: 135 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 136 | jit = False 137 | if 'RN50' in model_path: 138 | state_dict = torch.load(opened_file, map_location="cpu") 139 | else: 140 | state_dict0 = torch.load(model_path, map_location="cpu") 141 | state_dict = OrderedDict() 142 | for k in state_dict0.keys(): 143 | state_dict[k.replace('module.', '')] = state_dict0[k] 144 | 145 | 146 | if not jit: 147 | model = build_model(state_dict or model.state_dict()).to(device) 148 | if str(device) == "cpu": 149 | model.float() 150 | return model, _transform(model.visual.input_resolution) 151 | 152 | # patch the device names 153 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 154 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 155 | 156 | def patch_device(module): 157 | try: 158 | graphs = [module.graph] if hasattr(module, "graph") else [] 159 | except RuntimeError: 160 | graphs = [] 161 | 162 | if hasattr(module, "forward1"): 163 | graphs.append(module.forward1.graph) 164 | 165 | for graph in graphs: 166 | for node in graph.findAllNodes("prim::Constant"): 167 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 168 | node.copyAttributes(device_node) 169 | 170 | model.apply(patch_device) 171 | patch_device(model.encode_image) 172 | patch_device(model.encode_text) 173 | 174 | # patch dtype to float32 on CPU 175 | if str(device) == "cpu": 176 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 177 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 178 | float_node = float_input.node() 179 | 180 | def patch_float(module): 181 | try: 182 | graphs = [module.graph] if hasattr(module, "graph") else [] 183 | except RuntimeError: 184 | graphs = [] 185 | 186 | if hasattr(module, "forward1"): 187 | graphs.append(module.forward1.graph) 188 | 189 | for graph in graphs: 190 | for node in graph.findAllNodes("aten::to"): 191 | inputs = list(node.inputs()) 192 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 193 | if inputs[i].node()["value"] == 5: 194 | inputs[i].node().copyAttributes(float_node) 195 | 196 | model.apply(patch_float) 197 | patch_float(model.encode_image) 198 | patch_float(model.encode_text) 199 | 200 | model.float() 201 | 202 | return model, _transform(model.input_resolution.item()) 203 | 204 | 205 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 206 | """ 207 | Returns the tokenized representation of given input string(s) 208 | 209 | Parameters 210 | ---------- 211 | texts : Union[str, List[str]] 212 | An input string or a list of input strings to tokenize 213 | 214 | context_length : int 215 | The context length to use; all CLIP models use 77 as the context length 216 | 217 | truncate: bool 218 | Whether to truncate the text in case its encoding is longer than the context length 219 | 220 | Returns 221 | ------- 222 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 223 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
224 | """ 225 | if isinstance(texts, str): 226 | texts = [texts] 227 | 228 | sot_token = _tokenizer.encoder["<|startoftext|>"] 229 | eot_token = _tokenizer.encoder["<|endoftext|>"] 230 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 231 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 232 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 233 | else: 234 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 235 | 236 | for i, tokens in enumerate(all_tokens): 237 | if len(tokens) > context_length: 238 | if truncate: 239 | tokens = tokens[:context_length] 240 | tokens[-1] = eot_token 241 | else: 242 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 243 | result[i, :len(tokens)] = torch.tensor(tokens) 244 | 245 | return result 246 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /maskcut/crf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # modfied by Xudong Wang based on https://github.com/lucasb-eyer/pydensecrf/blob/master/pydensecrf/tests/test_dcrf.py and third_party/TokenCut 3 | 4 | import numpy as np 5 | import pydensecrf.densecrf as dcrf 6 | import pydensecrf.utils as utils 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision.transforms.functional as VF 10 | 11 | MAX_ITER = 10 12 | POS_W = 7 13 | POS_XY_STD = 3 14 | Bi_W = 10 15 | Bi_XY_STD = 50 16 | Bi_RGB_STD = 5 17 | 18 | def densecrf(image, mask): 19 | h, w = mask.shape 20 | mask = mask.reshape(1, h, w) 21 | fg = mask.astype(float) 22 | bg = 1 - fg 23 | output_logits = torch.from_numpy(np.concatenate((bg,fg), axis=0)) 24 | 25 | H, W = image.shape[:2] 26 | image = np.ascontiguousarray(image) 27 | 28 | output_logits = F.interpolate(output_logits.unsqueeze(0), size=(H, W), mode="bilinear").squeeze() 29 | output_probs = F.softmax(output_logits, dim=0).cpu().numpy() 30 | 31 | c = output_probs.shape[0] 32 | h = output_probs.shape[1] 33 | w = output_probs.shape[2] 34 | 35 | U = utils.unary_from_softmax(output_probs) 36 | U = np.ascontiguousarray(U) 37 | 38 | d = dcrf.DenseCRF2D(w, h, c) 39 | d.setUnaryEnergy(U) 40 | d.addPairwiseGaussian(sxy=POS_XY_STD, compat=POS_W) 41 | d.addPairwiseBilateral(sxy=Bi_XY_STD, srgb=Bi_RGB_STD, rgbim=image, compat=Bi_W) 42 | 43 | Q = d.inference(MAX_ITER) 44 | Q = np.array(Q).reshape((c, h, w)) 45 | MAP = np.argmax(Q, axis=0).reshape((h,w)).astype(np.float32) 46 | return MAP 47 | -------------------------------------------------------------------------------- /maskcut/dino.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | """ 4 | Copied from Dino repo. https://github.com/facebookresearch/dino 5 | Mostly copy-paste from timm library. 6 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 7 | """ 8 | import math 9 | from functools import partial 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 15 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 16 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 17 | def norm_cdf(x): 18 | # Computes standard normal cumulative distribution function 19 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 20 | 21 | if (mean < a - 2 * std) or (mean > b + 2 * std): 22 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 23 | "The distribution of values may be incorrect.", 24 | stacklevel=2) 25 | 26 | with torch.no_grad(): 27 | # Values are generated by using a truncated uniform distribution and 28 | # then using the inverse CDF for the normal distribution. 29 | # Get upper and lower cdf values 30 | l = norm_cdf((a - mean) / std) 31 | u = norm_cdf((b - mean) / std) 32 | 33 | # Uniformly fill tensor with values from [l, u], then translate to 34 | # [2l-1, 2u-1]. 
35 | tensor.uniform_(2 * l - 1, 2 * u - 1) 36 | 37 | # Use inverse cdf transform for normal distribution to get truncated 38 | # standard normal 39 | tensor.erfinv_() 40 | 41 | # Transform to proper mean, std 42 | tensor.mul_(std * math.sqrt(2.)) 43 | tensor.add_(mean) 44 | 45 | # Clamp to ensure it's in the proper range 46 | tensor.clamp_(min=a, max=b) 47 | return tensor 48 | 49 | 50 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 51 | # type: (Tensor, float, float, float, float) -> Tensor 52 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 53 | 54 | 55 | def drop_path(x, drop_prob: float = 0., training: bool = False): 56 | if drop_prob == 0. or not training: 57 | return x 58 | keep_prob = 1 - drop_prob 59 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 60 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 61 | random_tensor.floor_() # binarize 62 | output = x.div(keep_prob) * random_tensor 63 | return output 64 | 65 | 66 | class DropPath(nn.Module): 67 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 68 | """ 69 | def __init__(self, drop_prob=None): 70 | super(DropPath, self).__init__() 71 | self.drop_prob = drop_prob 72 | 73 | def forward(self, x): 74 | return drop_path(x, self.drop_prob, self.training) 75 | 76 | 77 | class Mlp(nn.Module): 78 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 79 | super().__init__() 80 | out_features = out_features or in_features 81 | hidden_features = hidden_features or in_features 82 | self.fc1 = nn.Linear(in_features, hidden_features) 83 | self.act = act_layer() 84 | self.fc2 = nn.Linear(hidden_features, out_features) 85 | self.drop = nn.Dropout(drop) 86 | 87 | def forward(self, x): 88 | x = self.fc1(x) 89 | x = self.act(x) 90 | x = self.drop(x) 91 | x = self.fc2(x) 92 | x = self.drop(x) 93 | return x 94 | 95 | 96 | class Attention(nn.Module): 97 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 98 | super().__init__() 99 | self.num_heads = num_heads 100 | head_dim = dim // num_heads 101 | self.scale = qk_scale or head_dim ** -0.5 102 | 103 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 104 | self.attn_drop = nn.Dropout(attn_drop) 105 | self.proj = nn.Linear(dim, dim) 106 | self.proj_drop = nn.Dropout(proj_drop) 107 | 108 | def forward(self, x): 109 | B, N, C = x.shape 110 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 111 | q, k, v = qkv[0], qkv[1], qkv[2] 112 | 113 | attn = (q @ k.transpose(-2, -1)) * self.scale 114 | attn = attn.softmax(dim=-1) 115 | attn = self.attn_drop(attn) 116 | 117 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 118 | x = self.proj(x) 119 | x = self.proj_drop(x) 120 | return x, attn 121 | 122 | 123 | class Block(nn.Module): 124 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 125 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 126 | super().__init__() 127 | self.norm1 = norm_layer(dim) 128 | self.attn = Attention( 129 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 130 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 131 | self.norm2 = norm_layer(dim) 132 | mlp_hidden_dim = int(dim * mlp_ratio) 133 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 134 | 135 | def forward(self, x, return_attention=False): 136 | y, attn = self.attn(self.norm1(x)) 137 | if return_attention: 138 | return attn 139 | x = x + self.drop_path(y) 140 | x = x + self.drop_path(self.mlp(self.norm2(x))) 141 | return x 142 | 143 | 144 | class PatchEmbed(nn.Module): 145 | """ Image to Patch Embedding 146 | """ 147 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 148 | super().__init__() 149 | num_patches = (img_size // patch_size) * (img_size // patch_size) 150 | self.img_size = img_size 151 | self.patch_size = patch_size 152 | self.num_patches = num_patches 153 | 154 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 155 | 156 | def forward(self, x): 157 | B, C, H, W = x.shape 158 | x = self.proj(x).flatten(2).transpose(1, 2) 159 | return x 160 | 161 | 162 | class VisionTransformer(nn.Module): 163 | """ Vision Transformer """ 164 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 165 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 166 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 167 | super().__init__() 168 | self.num_features = self.embed_dim = embed_dim 169 | 170 | self.patch_embed = PatchEmbed( 171 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 172 | num_patches = self.patch_embed.num_patches 173 | 174 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 175 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 176 | self.pos_drop = nn.Dropout(p=drop_rate) 177 | 178 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 179 | self.blocks = nn.ModuleList([ 180 | Block( 181 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 182 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 183 | for i in range(depth)]) 184 | self.norm = norm_layer(embed_dim) 185 | 186 | # Classifier head 187 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 188 | 189 | trunc_normal_(self.pos_embed, std=.02) 190 | trunc_normal_(self.cls_token, std=.02) 191 | self.apply(self._init_weights) 192 | 193 | def _init_weights(self, m): 194 | if isinstance(m, nn.Linear): 195 | trunc_normal_(m.weight, std=.02) 196 | if isinstance(m, nn.Linear) and m.bias is not None: 197 | nn.init.constant_(m.bias, 0) 198 | elif isinstance(m, nn.LayerNorm): 199 | nn.init.constant_(m.bias, 0) 200 | nn.init.constant_(m.weight, 1.0) 201 | 202 | def interpolate_pos_encoding(self, x, w, h): 203 | npatch = x.shape[1] - 1 204 | N = self.pos_embed.shape[1] - 1 205 | if npatch == N and w == h: 206 | return self.pos_embed 207 | class_pos_embed = self.pos_embed[:, 0] 208 | patch_pos_embed = self.pos_embed[:, 1:] 209 | dim = x.shape[-1] 210 | w0 = w // self.patch_embed.patch_size 211 | h0 = h // self.patch_embed.patch_size 212 | # we add a small number to avoid floating point error in the interpolation 213 | # see discussion at https://github.com/facebookresearch/dino/issues/8 214 | w0, h0 = w0 + 0.1, h0 + 0.1 215 | patch_pos_embed = nn.functional.interpolate( 216 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), 
dim).permute(0, 3, 1, 2), 217 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 218 | mode='bicubic', 219 | ) 220 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 221 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 222 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 223 | 224 | def prepare_tokens(self, x): 225 | B, nc, w, h = x.shape 226 | x = self.patch_embed(x) # patch linear embedding 227 | 228 | # add the [CLS] token to the embed patch tokens 229 | cls_tokens = self.cls_token.expand(B, -1, -1) 230 | x = torch.cat((cls_tokens, x), dim=1) 231 | 232 | # add positional encoding to each token 233 | x = x + self.interpolate_pos_encoding(x, w, h) 234 | 235 | return self.pos_drop(x) 236 | 237 | def forward(self, x): 238 | x = self.prepare_tokens(x) 239 | for blk in self.blocks: 240 | x = blk(x) 241 | x = self.norm(x) 242 | return x[:, 0] 243 | 244 | def get_last_selfattention(self, x): 245 | x = self.prepare_tokens(x) 246 | for i, blk in enumerate(self.blocks): 247 | if i < len(self.blocks) - 1: 248 | x = blk(x) 249 | else: 250 | # return attention of the last block 251 | return blk(x, return_attention=True) 252 | 253 | def get_intermediate_layers(self, x, n=1): 254 | x = self.prepare_tokens(x) 255 | # we return the output tokens from the `n` last blocks 256 | output = [] 257 | for i, blk in enumerate(self.blocks): 258 | x = blk(x) 259 | if len(self.blocks) - i <= n: 260 | output.append(self.norm(x)) 261 | return output 262 | 263 | 264 | def vit_small(patch_size=16, **kwargs): 265 | model = VisionTransformer( 266 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 267 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 268 | return model 269 | 270 | 271 | def vit_base(patch_size=16, **kwargs): 272 | model = VisionTransformer( 273 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 274 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 275 | return model 276 | 277 | class ViTFeat(nn.Module): 278 | """ Vision Transformer """ 279 | def __init__(self, pretrained_pth, feat_dim, vit_arch = 'base', vit_feat = 'k', patch_size=16): 280 | super().__init__() 281 | if vit_arch == 'base' : 282 | self.model = vit_base(patch_size=patch_size, num_classes=0) 283 | 284 | else : 285 | self.model = vit_small(patch_size=patch_size, num_classes=0) 286 | 287 | self.feat_dim = feat_dim 288 | self.vit_feat = vit_feat 289 | self.patch_size = patch_size 290 | 291 | # state_dict = torch.load(pretrained_pth, map_location="cpu") 292 | state_dict = torch.hub.load_state_dict_from_url(pretrained_pth) 293 | self.model.load_state_dict(state_dict, strict=True) 294 | print('Loading weight from {}'.format(pretrained_pth)) 295 | 296 | 297 | def forward(self, img) : 298 | feat_out = {} 299 | def hook_fn_forward_qkv(module, input, output): 300 | feat_out["qkv"] = output 301 | 302 | # self.model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) 303 | self.model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) 304 | 305 | 306 | # Forward pass in the model 307 | with torch.no_grad() : 308 | h, w = img.shape[2], img.shape[3] 309 | feat_h, feat_w = h // self.patch_size, w // self.patch_size 310 | attentions = self.model.get_last_selfattention(img) 311 | # attentions = self.model.get_intermediate_layers(img,2) 312 | bs, nb_head, nb_token = attentions.shape[0], 
attentions.shape[1], attentions.shape[2] 313 | qkv = ( 314 | feat_out["qkv"] 315 | .reshape(bs, nb_token, 3, nb_head, -1) 316 | .permute(2, 0, 3, 1, 4) 317 | ) 318 | q, k, v = qkv[0], qkv[1], qkv[2] 319 | 320 | k = k.transpose(1, 2).reshape(bs, nb_token, -1) 321 | q = q.transpose(1, 2).reshape(bs, nb_token, -1) 322 | v = v.transpose(1, 2).reshape(bs, nb_token, -1) 323 | 324 | # Modality selection 325 | if self.vit_feat == "k": 326 | feats = k[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) 327 | elif self.vit_feat == "q": 328 | feats = q[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) 329 | elif self.vit_feat == "v": 330 | feats = v[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) 331 | elif self.vit_feat == "kqv": 332 | k = k[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) 333 | q = q[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) 334 | v = v[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) 335 | feats = torch.cat([k, q, v], dim=1) 336 | return feats 337 | 338 | 339 | if __name__ == "__main__": 340 | vit_arch = 'base' 341 | vit_feat = 'k' 342 | 343 | model = ViTFeat(vit_arch, vit_feat) 344 | img = torch.cuda.FloatTensor(4, 3, 224, 224) 345 | model.cuda() 346 | # Forward pass in the model 347 | feat = model(img) 348 | print (feat.shape) 349 | -------------------------------------------------------------------------------- /maskcut/ovcamo_cato.txt: -------------------------------------------------------------------------------- 1 | A camouflaged phote of the Deer 2 | A camouflaged phote of the Sciuridae 3 | A camouflaged phote of the Pipefish 4 | A camouflaged phote of the Monkey 5 | A camouflaged phote of the Beetle 6 | A camouflaged phote of the Bat 7 | A camouflaged phote of the Duck 8 | A camouflaged phote of the Leopard 9 | A camouflaged phote of the Giraffe 10 | A camouflaged phote of the Wolf 11 | A camouflaged phote of the Tiger 12 | A camouflaged phote of the Cat 13 | A camouflaged phote of the Human 14 | A camouflaged phote of the Frogmouth 15 | A camouflaged phote of the ClownFish 16 | A camouflaged phote of the Toad 17 | A camouflaged phote of the Cicada 18 | A camouflaged phote of the Other 19 | A camouflaged phote of the Heron 20 | A camouflaged phote of the Caterpillar 21 | A camouflaged phote of the Reccoon 22 | A camouflaged phote of the Butterfly 23 | A camouflaged phote of the Bittern 24 | A camouflaged phote of the BatFish 25 | A camouflaged phote of the FrogFish 26 | A camouflaged phote of the Frog 27 | A camouflaged phote of the Slug 28 | A camouflaged phote of the Chameleon 29 | A camouflaged phote of the Spider 30 | A camouflaged phote of the Cheetah 31 | A camouflaged phote of the Bird 32 | A camouflaged phote of the Snake 33 | A camouflaged phote of the Lizard 34 | A camouflaged phote of the Fish 35 | A camouflaged phote of the Bug 36 | A camouflaged phote of the Ant 37 | A camouflaged phote of the SeaHorse 38 | A camouflaged phote of the Crab 39 | A camouflaged phote of the Dog 40 | A camouflaged phote of the Stingaree 41 | A camouflaged phote of the Shrimp 42 | A camouflaged phote of the Rabbit 43 | A camouflaged phote of the Crocodile 44 | A camouflaged phote of the Moth 45 | A camouflaged phote of the ScorpionFish 46 | A camouflaged phote of the Pagurian 47 | A camouflaged phote of the Grouse 48 | A camouflaged phote of the Owl 49 | A camouflaged phote of the Katydid 50 | A camouflaged phote of the Gecko 51 | A camouflaged phote of the Flounder 52 | 
A camouflaged phote of the Owlfly 53 | A camouflaged phote of the Sheep 54 | A camouflaged phote of the Mantis 55 | A camouflaged phote of the Dragonfly 56 | A camouflaged phote of the Centipede 57 | A camouflaged phote of the Kangaroo 58 | A camouflaged phote of the CrocodileFish 59 | A camouflaged phote of the Lion 60 | A camouflaged phote of the Bee 61 | A camouflaged phote of the StarFish 62 | A camouflaged phote of the Grasshopper 63 | A camouflaged phote of the Mockingbird 64 | A camouflaged phote of the Turtle 65 | A camouflaged phote of the StickInsect 66 | A camouflaged phote of the Octopus 67 | A camouflaged phote of the Worm 68 | A camouflaged phote of the LeafySeaDragon -------------------------------------------------------------------------------- /pytorch_grad_cam/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.grad_cam import GradCAM 2 | from pytorch_grad_cam.ablation_layer import AblationLayer, AblationLayerVit, AblationLayerFasterRCNN 3 | from pytorch_grad_cam.ablation_cam import AblationCAM 4 | from pytorch_grad_cam.xgrad_cam import XGradCAM 5 | from pytorch_grad_cam.grad_cam_plusplus import GradCAMPlusPlus 6 | from pytorch_grad_cam.score_cam import ScoreCAM 7 | from pytorch_grad_cam.layer_cam import LayerCAM 8 | from pytorch_grad_cam.eigen_cam import EigenCAM 9 | from pytorch_grad_cam.eigen_grad_cam import EigenGradCAM 10 | from pytorch_grad_cam.fullgrad_cam import FullGrad 11 | from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel 12 | from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients 13 | import pytorch_grad_cam.utils.model_targets 14 | import pytorch_grad_cam.utils.reshape_transforms -------------------------------------------------------------------------------- /pytorch_grad_cam/ablation_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import tqdm 4 | from typing import Callable, List 5 | from pytorch_grad_cam.base_cam import BaseCAM 6 | from pytorch_grad_cam.utils.find_layers import replace_layer_recursive 7 | from pytorch_grad_cam.ablation_layer import AblationLayer 8 | 9 | 10 | """ Implementation of AblationCAM 11 | https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf 12 | 13 | Ablate individual activations, and then measure the drop in the target score. 14 | 15 | In the current implementation, the target layer activations is cached, so it won't be re-computed. 16 | However layers before it, if any, will not be cached. 17 | This means that if the target layer is a large block, for example model.featuers (in vgg), there will 18 | be a large save in run time. 19 | 20 | Since we have to go over many channels and ablate them, and every channel ablation requires a forward pass, 21 | it would be nice if we could avoid doing that for channels that won't contribute anwyay, making it much faster. 22 | The parameter ratio_channels_to_ablate controls how many channels should be ablated, using an experimental method 23 | (to be improved). The default 1.0 value means that all channels will be ablated. 
24 | """ 25 | 26 | 27 | class AblationCAM(BaseCAM): 28 | def __init__(self, 29 | model: torch.nn.Module, 30 | target_layers: List[torch.nn.Module], 31 | use_cuda: bool = False, 32 | reshape_transform: Callable = None, 33 | ablation_layer: torch.nn.Module = AblationLayer(), 34 | batch_size: int = 32, 35 | ratio_channels_to_ablate: float = 1.0) -> None: 36 | 37 | super(AblationCAM, self).__init__(model, 38 | target_layers, 39 | use_cuda, 40 | reshape_transform, 41 | uses_gradients=False) 42 | self.batch_size = batch_size 43 | self.ablation_layer = ablation_layer 44 | self.ratio_channels_to_ablate = ratio_channels_to_ablate 45 | 46 | def save_activation(self, module, input, output) -> None: 47 | """ Helper function to save the raw activations from the target layer """ 48 | self.activations = output 49 | 50 | def assemble_ablation_scores(self, 51 | new_scores: list, 52 | original_score: float , 53 | ablated_channels: np.ndarray, 54 | number_of_channels: int) -> np.ndarray: 55 | """ Take the value from the channels that were ablated, 56 | and just set the original score for the channels that were skipped """ 57 | 58 | index = 0 59 | result = [] 60 | sorted_indices = np.argsort(ablated_channels) 61 | ablated_channels = ablated_channels[sorted_indices] 62 | new_scores = np.float32(new_scores)[sorted_indices] 63 | 64 | for i in range(number_of_channels): 65 | if index < len(ablated_channels) and ablated_channels[index] == i: 66 | weight = new_scores[index] 67 | index = index + 1 68 | else: 69 | weight = original_score 70 | result.append(weight) 71 | 72 | return result 73 | 74 | def get_cam_weights(self, 75 | input_tensor: torch.Tensor, 76 | target_layer: torch.nn.Module, 77 | targets: List[Callable], 78 | activations: torch.Tensor, 79 | grads: torch.Tensor) -> np.ndarray: 80 | 81 | # Do a forward pass, compute the target scores, and cache the activations 82 | handle = target_layer.register_forward_hook(self.save_activation) 83 | with torch.no_grad(): 84 | outputs = self.model(input_tensor) 85 | handle.remove() 86 | original_scores = np.float32([target(output).cpu().item() for target, output in zip(targets, outputs)]) 87 | 88 | # Replace the layer with the ablation layer. 89 | # When we finish, we will replace it back, so the original model is unchanged. 90 | ablation_layer = self.ablation_layer 91 | replace_layer_recursive(self.model, target_layer, ablation_layer) 92 | 93 | number_of_channels = activations.shape[1] 94 | weights = [] 95 | # This is a "gradient free" method, so we don't need gradients here. 96 | with torch.no_grad(): 97 | # Loop over each of the batch images and ablate activations for it. 98 | for batch_index, (target, tensor) in enumerate(zip(targets, input_tensor)): 99 | new_scores = [] 100 | batch_tensor = tensor.repeat(self.batch_size, 1, 1, 1) 101 | 102 | # Check which channels should be ablated. Normally this will be all channels, 103 | # But we can also try to speed this up by using a low ratio_channels_to_ablate. 104 | channels_to_ablate = ablation_layer.activations_to_be_ablated(activations[batch_index, :], 105 | self.ratio_channels_to_ablate) 106 | number_channels_to_ablate = len(channels_to_ablate) 107 | 108 | for i in tqdm.tqdm(range(0, number_channels_to_ablate, self.batch_size)): 109 | if i + self.batch_size > number_channels_to_ablate: 110 | batch_tensor = batch_tensor[:(number_channels_to_ablate - i)] 111 | 112 | # Change the state of the ablation layer so it ablates the next channels. 113 | # TBD: Move this into the ablation layer forward pass. 
114 | ablation_layer.set_next_batch(input_batch_index=batch_index, 115 | activations=self.activations, 116 | num_channels_to_ablate=batch_tensor.size(0)) 117 | score = [target(o).cpu().item() for o in self.model(batch_tensor)] 118 | new_scores.extend(score) 119 | ablation_layer.indices = ablation_layer.indices[batch_tensor.size(0):] 120 | 121 | new_scores = self.assemble_ablation_scores(new_scores, 122 | original_scores[batch_index], 123 | channels_to_ablate, 124 | number_of_channels) 125 | weights.extend(new_scores) 126 | 127 | weights = np.float32(weights) 128 | weights = weights.reshape(activations.shape[:2]) 129 | original_scores = original_scores[:, None] 130 | weights = (original_scores - weights) / original_scores 131 | 132 | # Replace the model back to the original state 133 | replace_layer_recursive(self.model, ablation_layer, target_layer) 134 | return weights 135 | -------------------------------------------------------------------------------- /pytorch_grad_cam/ablation_cam_multilayer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import tqdm 5 | from pytorch_grad_cam.base_cam import BaseCAM 6 | 7 | 8 | class AblationLayer(torch.nn.Module): 9 | def __init__(self, layer, reshape_transform, indices): 10 | super(AblationLayer, self).__init__() 11 | 12 | self.layer = layer 13 | self.reshape_transform = reshape_transform 14 | # The channels to zero out: 15 | self.indices = indices 16 | 17 | def forward(self, x): 18 | self.__call__(x) 19 | 20 | def __call__(self, x): 21 | output = self.layer(x) 22 | 23 | # Hack to work with ViT, 24 | # Since the activation channels are last and not first like in CNNs 25 | # Probably should remove it? 26 | if self.reshape_transform is not None: 27 | output = output.transpose(1, 2) 28 | 29 | for i in range(output.size(0)): 30 | 31 | # Commonly the minimum activation will be 0, 32 | # And then it makes sense to zero it out. 33 | # However depending on the architecture, 34 | # If the values can be negative, we use very negative values 35 | # to perform the ablation, deviating from the paper. 36 | if torch.min(output) == 0: 37 | output[i, self.indices[i], :] = 0 38 | else: 39 | ABLATION_VALUE = 1e5 40 | output[i, self.indices[i], :] = torch.min( 41 | output) - ABLATION_VALUE 42 | 43 | if self.reshape_transform is not None: 44 | output = output.transpose(2, 1) 45 | 46 | return output 47 | 48 | 49 | def replace_layer_recursive(model, old_layer, new_layer): 50 | for name, layer in model._modules.items(): 51 | if layer == old_layer: 52 | model._modules[name] = new_layer 53 | return True 54 | elif replace_layer_recursive(layer, old_layer, new_layer): 55 | return True 56 | return False 57 | 58 | 59 | class AblationCAM(BaseCAM): 60 | def __init__(self, model, target_layers, use_cuda=False, 61 | reshape_transform=None): 62 | super(AblationCAM, self).__init__(model, target_layers, use_cuda, 63 | reshape_transform) 64 | 65 | if len(target_layers) > 1: 66 | print( 67 | "Warning. You are usign Ablation CAM with more than 1 layers. 
" 68 | "This is supported only if all layers have the same output shape") 69 | 70 | def set_ablation_layers(self): 71 | self.ablation_layers = [] 72 | for target_layer in self.target_layers: 73 | ablation_layer = AblationLayer(target_layer, 74 | self.reshape_transform, indices=[]) 75 | self.ablation_layers.append(ablation_layer) 76 | replace_layer_recursive(self.model, target_layer, ablation_layer) 77 | 78 | def unset_ablation_layers(self): 79 | # replace the model back to the original state 80 | for ablation_layer, target_layer in zip( 81 | self.ablation_layers, self.target_layers): 82 | replace_layer_recursive(self.model, ablation_layer, target_layer) 83 | 84 | def set_ablation_layer_batch_indices(self, indices): 85 | for ablation_layer in self.ablation_layers: 86 | ablation_layer.indices = indices 87 | 88 | def trim_ablation_layer_batch_indices(self, keep): 89 | for ablation_layer in self.ablation_layers: 90 | ablation_layer.indices = ablation_layer.indices[:keep] 91 | 92 | def get_cam_weights(self, 93 | input_tensor, 94 | target_category, 95 | activations, 96 | grads): 97 | with torch.no_grad(): 98 | outputs = self.model(input_tensor).cpu().numpy() 99 | original_scores = [] 100 | for i in range(input_tensor.size(0)): 101 | original_scores.append(outputs[i, target_category[i]]) 102 | original_scores = np.float32(original_scores) 103 | 104 | self.set_ablation_layers() 105 | 106 | if hasattr(self, "batch_size"): 107 | BATCH_SIZE = self.batch_size 108 | else: 109 | BATCH_SIZE = 32 110 | 111 | number_of_channels = activations.shape[1] 112 | weights = [] 113 | 114 | with torch.no_grad(): 115 | # Iterate over the input batch 116 | for tensor, category in zip(input_tensor, target_category): 117 | batch_tensor = tensor.repeat(BATCH_SIZE, 1, 1, 1) 118 | for i in tqdm.tqdm(range(0, number_of_channels, BATCH_SIZE)): 119 | self.set_ablation_layer_batch_indices( 120 | list(range(i, i + BATCH_SIZE))) 121 | 122 | if i + BATCH_SIZE > number_of_channels: 123 | keep = number_of_channels - i 124 | batch_tensor = batch_tensor[:keep] 125 | self.trim_ablation_layer_batch_indices(self, keep) 126 | score = self.model(batch_tensor)[:, category].cpu().numpy() 127 | weights.extend(score) 128 | 129 | weights = np.float32(weights) 130 | weights = weights.reshape(activations.shape[:2]) 131 | original_scores = original_scores[:, None] 132 | weights = (original_scores - weights) / original_scores 133 | 134 | # replace the model back to the original state 135 | self.unset_ablation_layers() 136 | return weights 137 | -------------------------------------------------------------------------------- /pytorch_grad_cam/ablation_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | import numpy as np 4 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 5 | 6 | 7 | class AblationLayer(torch.nn.Module): 8 | def __init__(self): 9 | super(AblationLayer, self).__init__() 10 | 11 | def objectiveness_mask_from_svd(self, activations, threshold=0.01): 12 | """ Experimental method to get a binary mask to compare if the activation is worth ablating. 13 | The idea is to apply the EigenCAM method by doing PCA on the activations. 14 | Then we create a binary mask by comparing to a low threshold. 15 | Areas that are masked out, are probably not interesting anyway. 
16 | """ 17 | 18 | projection = get_2d_projection(activations[None, :])[0, :] 19 | projection = np.abs(projection) 20 | projection = projection - projection.min() 21 | projection = projection / projection.max() 22 | projection = projection > threshold 23 | return projection 24 | 25 | def activations_to_be_ablated(self, activations, ratio_channels_to_ablate=1.0): 26 | """ Experimental method to get a binary mask to compare if the activation is worth ablating. 27 | Create a binary CAM mask with objectiveness_mask_from_svd. 28 | Score each Activation channel, by seeing how much of its values are inside the mask. 29 | Then keep the top channels. 30 | 31 | """ 32 | if ratio_channels_to_ablate == 1.0: 33 | self.indices = np.int32(range(activations.shape[0])) 34 | return self.indices 35 | 36 | projection = self.objectiveness_mask_from_svd(activations) 37 | 38 | scores = [] 39 | for channel in activations: 40 | normalized = np.abs(channel) 41 | normalized = normalized - normalized.min() 42 | normalized = normalized / np.max(normalized) 43 | score = (projection*normalized).sum() / normalized.sum() 44 | scores.append(score) 45 | scores = np.float32(scores) 46 | 47 | indices = list(np.argsort(scores)) 48 | high_score_indices = indices[::-1][: int(len(indices) * ratio_channels_to_ablate)] 49 | low_score_indices = indices[: int(len(indices) * ratio_channels_to_ablate)] 50 | self.indices = np.int32(high_score_indices + low_score_indices) 51 | return self.indices 52 | 53 | def set_next_batch(self, input_batch_index, activations, num_channels_to_ablate): 54 | """ This creates the next batch of activations from the layer. 55 | Just take corresponding batch member from activations, and repeat it num_channels_to_ablate times. 56 | """ 57 | self.activations = activations[input_batch_index, :, :, :].clone().unsqueeze(0).repeat(num_channels_to_ablate, 1, 1, 1) 58 | 59 | def __call__(self, x): 60 | output = self.activations 61 | for i in range(output.size(0)): 62 | # Commonly the minimum activation will be 0, 63 | # And then it makes sense to zero it out. 64 | # However depending on the architecture, 65 | # If the values can be negative, we use very negative values 66 | # to perform the ablation, deviating from the paper. 67 | if torch.min(output) == 0: 68 | output[i, self.indices[i], :] = 0 69 | else: 70 | ABLATION_VALUE = 1e7 71 | output[i, self.indices[i], :] = torch.min( 72 | output) - ABLATION_VALUE 73 | 74 | return output 75 | 76 | 77 | class AblationLayerVit(AblationLayer): 78 | def __init__(self): 79 | super(AblationLayerVit, self).__init__() 80 | 81 | def __call__(self, x): 82 | output = self.activations 83 | output = output.transpose(1, 2) 84 | for i in range(output.size(0)): 85 | 86 | # Commonly the minimum activation will be 0, 87 | # And then it makes sense to zero it out. 88 | # However depending on the architecture, 89 | # If the values can be negative, we use very negative values 90 | # to perform the ablation, deviating from the paper. 91 | if torch.min(output) == 0: 92 | output[i, self.indices[i], :] = 0 93 | else: 94 | ABLATION_VALUE = 1e7 95 | output[i, self.indices[i], :] = torch.min( 96 | output) - ABLATION_VALUE 97 | 98 | output = output.transpose(2, 1) 99 | 100 | return output 101 | 102 | def set_next_batch(self, input_batch_index, activations, num_channels_to_ablate): 103 | """ This creates the next batch of activations from the layer. 104 | Just take corresponding batch member from activations, and repeat it num_channels_to_ablate times. 
105 | """ 106 | self.activations = activations[input_batch_index, :, :].clone().unsqueeze(0).repeat(num_channels_to_ablate, 1, 1) 107 | 108 | 109 | 110 | class AblationLayerFasterRCNN(AblationLayer): 111 | def __init__(self): 112 | super(AblationLayerFasterRCNN, self).__init__() 113 | 114 | def set_next_batch(self, input_batch_index, activations, num_channels_to_ablate): 115 | """ Extract the next batch member from activations, 116 | and repeat it num_channels_to_ablate times. 117 | """ 118 | self.activations = OrderedDict() 119 | for key, value in activations.items(): 120 | fpn_activation = value[input_batch_index, :, :, :].clone().unsqueeze(0) 121 | self.activations[key] = fpn_activation.repeat(num_channels_to_ablate, 1, 1, 1) 122 | 123 | def __call__(self, x): 124 | result = self.activations 125 | layers = {0: '0', 1: '1', 2: '2', 3: '3', 4: 'pool'} 126 | num_channels_to_ablate = result['pool'].size(0) 127 | for i in range(num_channels_to_ablate): 128 | pyramid_layer = int(self.indices[i]/256) 129 | index_in_pyramid_layer = int(self.indices[i] % 256) 130 | result[layers[pyramid_layer]][i, index_in_pyramid_layer, :, :] = -1000 131 | return result 132 | -------------------------------------------------------------------------------- /pytorch_grad_cam/activations_and_gradients.py: -------------------------------------------------------------------------------- 1 | class ActivationsAndGradients: 2 | """ Class for extracting activations and 3 | registering gradients from targetted intermediate layers """ 4 | 5 | def __init__(self, model, target_layers, reshape_transform): 6 | self.model = model 7 | self.gradients = [] 8 | self.activations = [] 9 | self.reshape_transform = reshape_transform 10 | self.handles = [] 11 | for target_layer in target_layers: 12 | self.handles.append( 13 | target_layer.register_forward_hook(self.save_activation)) 14 | # Because of https://github.com/pytorch/pytorch/issues/61519, 15 | # we don't use backward hook to record gradients. 16 | self.handles.append( 17 | target_layer.register_forward_hook(self.save_gradient)) 18 | 19 | def save_activation(self, module, input, output): 20 | activation = output 21 | 22 | if self.reshape_transform is not None: 23 | activation = self.reshape_transform(activation, self.height, self.width) 24 | self.activations.append(activation.cpu().detach()) 25 | 26 | def save_gradient(self, module, input, output): 27 | if not hasattr(output, "requires_grad") or not output.requires_grad: 28 | # You can only register hooks on tensor requires grad. 
29 | return 30 | 31 | # Gradients are computed in reverse order 32 | def _store_grad(grad): 33 | if self.reshape_transform is not None: 34 | grad = self.reshape_transform(grad, self.height, self.width) 35 | self.gradients = [grad.cpu().detach()] + self.gradients 36 | 37 | output.register_hook(_store_grad) 38 | 39 | def __call__(self, x, H, W): 40 | self.height = H // 16 41 | self.width = W // 16 42 | self.gradients = [] 43 | self.activations = [] 44 | if isinstance(x, list): 45 | return self.model.forward_last_layer(x[0], x[1]) 46 | else: 47 | return self.model(x) 48 | 49 | def release(self): 50 | for handle in self.handles: 51 | handle.remove() 52 | -------------------------------------------------------------------------------- /pytorch_grad_cam/base_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import ttach as tta 4 | from typing import Callable, List, Tuple 5 | from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients 6 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 7 | from pytorch_grad_cam.utils.image import scale_cam_image 8 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget 9 | 10 | 11 | class BaseCAM: 12 | def __init__(self, 13 | model: torch.nn.Module, 14 | target_layers: List[torch.nn.Module], 15 | use_cuda: bool = False, 16 | reshape_transform: Callable = None, 17 | compute_input_gradient: bool = False, 18 | uses_gradients: bool = True) -> None: 19 | self.model = model.eval() 20 | self.target_layers = target_layers 21 | self.cuda = use_cuda 22 | if self.cuda: 23 | self.model = model.cuda() 24 | self.reshape_transform = reshape_transform 25 | self.compute_input_gradient = compute_input_gradient 26 | self.uses_gradients = uses_gradients 27 | self.activations_and_grads = ActivationsAndGradients( 28 | self.model, target_layers, reshape_transform) 29 | 30 | """ Get a vector of weights for every channel in the target layer. 31 | Methods that return weights channels, 32 | will typically need to only implement this function. 
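As that note says, a new CAM variant usually only needs to supply per-channel weights. A minimal sketch of such a subclass, using the same global-average-pooled-gradient rule as this repo's GradCAM:

import numpy as np
from pytorch_grad_cam.base_cam import BaseCAM

class MeanGradCAM(BaseCAM):
    """Toy CAM variant: weight each channel by the spatial mean of its gradient."""

    def get_cam_weights(self, input_tensor, target_layer, targets, activations, grads):
        # grads: (batch, channels, H, W) numpy array -> (batch, channels) weights
        return np.mean(grads, axis=(2, 3))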
""" 33 | 34 | def get_cam_weights(self, 35 | input_tensor: torch.Tensor, 36 | target_layers: List[torch.nn.Module], 37 | targets: List[torch.nn.Module], 38 | activations: torch.Tensor, 39 | grads: torch.Tensor) -> np.ndarray: 40 | raise Exception("Not Implemented") 41 | 42 | def get_cam_image(self, 43 | input_tensor: torch.Tensor, 44 | target_layer: torch.nn.Module, 45 | targets: List[torch.nn.Module], 46 | activations: torch.Tensor, 47 | grads: torch.Tensor, 48 | eigen_smooth: bool = False) -> np.ndarray: 49 | 50 | weights = self.get_cam_weights(input_tensor, 51 | target_layer, 52 | targets, 53 | activations, 54 | grads) 55 | weighted_activations = weights[:, :, None, None] * activations 56 | if eigen_smooth: 57 | cam = get_2d_projection(weighted_activations) 58 | else: 59 | cam = weighted_activations.sum(axis=1) 60 | return cam 61 | 62 | def forward(self, 63 | input_tensor: torch.Tensor, 64 | targets: List[torch.nn.Module], 65 | target_size, 66 | eigen_smooth: bool = False) -> np.ndarray: 67 | 68 | if self.cuda: 69 | input_tensor = input_tensor.cuda() 70 | 71 | if self.compute_input_gradient: 72 | input_tensor = torch.autograd.Variable(input_tensor, 73 | requires_grad=True) 74 | W,H = self.get_target_width_height(input_tensor) 75 | outputs = self.activations_and_grads(input_tensor,H,W) 76 | if targets is None: 77 | if isinstance(input_tensor, list): 78 | target_categories = np.argmax(outputs[0].cpu().data.numpy(), axis=-1) 79 | else: 80 | target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1) 81 | targets = [ClassifierOutputTarget(category) for category in target_categories] 82 | 83 | if self.uses_gradients: 84 | self.model.zero_grad() 85 | if isinstance(input_tensor, list): 86 | loss = sum([target(output[0]) for target, output in zip(targets, outputs)]) 87 | else: 88 | loss = sum([target(output) for target, output in zip(targets, outputs)]) 89 | loss.backward(retain_graph=True) 90 | cam_per_layer = self.compute_cam_per_layer(input_tensor, 91 | targets, 92 | target_size, 93 | eigen_smooth) 94 | if isinstance(input_tensor, list): 95 | return self.aggregate_multi_layers(cam_per_layer), outputs[0], outputs[1] 96 | else: 97 | return self.aggregate_multi_layers(cam_per_layer), outputs 98 | 99 | def get_target_width_height(self, 100 | input_tensor: torch.Tensor) -> Tuple[int, int]: 101 | if isinstance(input_tensor, list): 102 | width, height = input_tensor[-1], input_tensor[-2] 103 | return width, height 104 | 105 | def compute_cam_per_layer( 106 | self, 107 | input_tensor: torch.Tensor, 108 | targets: List[torch.nn.Module], 109 | target_size, 110 | eigen_smooth: bool) -> np.ndarray: 111 | activations_list = [a.cpu().data.numpy() 112 | for a in self.activations_and_grads.activations] 113 | grads_list = [g.cpu().data.numpy() 114 | for g in self.activations_and_grads.gradients] 115 | 116 | cam_per_target_layer = [] 117 | # Loop over the saliency image from every layer 118 | for i in range(len(self.target_layers)): 119 | target_layer = self.target_layers[i] 120 | layer_activations = None 121 | layer_grads = None 122 | if i < len(activations_list): 123 | layer_activations = activations_list[i] 124 | if i < len(grads_list): 125 | layer_grads = grads_list[i] 126 | 127 | cam = self.get_cam_image(input_tensor, 128 | target_layer, 129 | targets, 130 | layer_activations, 131 | layer_grads, 132 | eigen_smooth) 133 | cam = np.maximum(cam, 0).astype(np.float32)#float16->32 134 | scaled = scale_cam_image(cam, target_size) 135 | cam_per_target_layer.append(scaled[:, None, :]) 136 | 137 | return 
cam_per_target_layer 138 | 139 | def aggregate_multi_layers(self, cam_per_target_layer: np.ndarray) -> np.ndarray: 140 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) 141 | cam_per_target_layer = np.maximum(cam_per_target_layer, 0) 142 | result = np.mean(cam_per_target_layer, axis=1) 143 | return scale_cam_image(result) 144 | 145 | def forward_augmentation_smoothing(self, 146 | input_tensor: torch.Tensor, 147 | targets: List[torch.nn.Module], 148 | eigen_smooth: bool = False) -> np.ndarray: 149 | transforms = tta.Compose( 150 | [ 151 | tta.HorizontalFlip(), 152 | tta.Multiply(factors=[0.9, 1, 1.1]), 153 | ] 154 | ) 155 | cams = [] 156 | for transform in transforms: 157 | augmented_tensor = transform.augment_image(input_tensor) 158 | cam = self.forward(augmented_tensor, 159 | targets, 160 | eigen_smooth) 161 | 162 | # The ttach library expects a tensor of size BxCxHxW 163 | cam = cam[:, None, :, :] 164 | cam = torch.from_numpy(cam) 165 | cam = transform.deaugment_mask(cam) 166 | 167 | # Back to numpy float32, HxW 168 | cam = cam.numpy() 169 | cam = cam[:, 0, :, :] 170 | cams.append(cam) 171 | 172 | cam = np.mean(np.float32(cams), axis=0) 173 | return cam 174 | 175 | def __call__(self, 176 | input_tensor: torch.Tensor, 177 | targets: List[torch.nn.Module] = None, 178 | target_size=None, 179 | aug_smooth: bool = False, 180 | eigen_smooth: bool = False) -> np.ndarray: 181 | 182 | # Smooth the CAM result with test time augmentation 183 | if aug_smooth is True: 184 | return self.forward_augmentation_smoothing( 185 | input_tensor, targets, eigen_smooth) 186 | 187 | return self.forward(input_tensor, 188 | targets, target_size,eigen_smooth) 189 | 190 | def __del__(self): 191 | self.activations_and_grads.release() 192 | 193 | def __enter__(self): 194 | return self 195 | 196 | def __exit__(self, exc_type, exc_value, exc_tb): 197 | self.activations_and_grads.release() 198 | if isinstance(exc_value, IndexError): 199 | # Handle IndexError here... 200 | print( 201 | f"An exception occurred in CAM with block: {exc_type}. 
Message: {exc_value}") 202 | return True 203 | -------------------------------------------------------------------------------- /pytorch_grad_cam/eigen_cam.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.base_cam import BaseCAM 2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 3 | 4 | # https://arxiv.org/abs/2008.00299 5 | 6 | 7 | class EigenCAM(BaseCAM): 8 | def __init__(self, model, target_layers, use_cuda=False, 9 | reshape_transform=None): 10 | super(EigenCAM, self).__init__(model, 11 | target_layers, 12 | use_cuda, 13 | reshape_transform, 14 | uses_gradients=False) 15 | 16 | def get_cam_image(self, 17 | input_tensor, 18 | target_layer, 19 | target_category, 20 | activations, 21 | grads, 22 | eigen_smooth): 23 | return get_2d_projection(activations) 24 | -------------------------------------------------------------------------------- /pytorch_grad_cam/eigen_grad_cam.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.base_cam import BaseCAM 2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 3 | 4 | # Like Eigen CAM: https://arxiv.org/abs/2008.00299 5 | # But multiply the activations x gradients 6 | 7 | 8 | class EigenGradCAM(BaseCAM): 9 | def __init__(self, model, target_layers, use_cuda=False, 10 | reshape_transform=None): 11 | super(EigenGradCAM, self).__init__(model, target_layers, use_cuda, 12 | reshape_transform) 13 | 14 | def get_cam_image(self, 15 | input_tensor, 16 | target_layer, 17 | target_category, 18 | activations, 19 | grads, 20 | eigen_smooth): 21 | return get_2d_projection(grads * activations) 22 | -------------------------------------------------------------------------------- /pytorch_grad_cam/fullgrad_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pytorch_grad_cam.base_cam import BaseCAM 4 | from pytorch_grad_cam.utils.find_layers import find_layer_predicate_recursive 5 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 6 | from pytorch_grad_cam.utils.image import scale_accross_batch_and_channels, scale_cam_image 7 | 8 | # https://arxiv.org/abs/1905.00780 9 | 10 | 11 | class FullGrad(BaseCAM): 12 | def __init__(self, model, target_layers, use_cuda=False, 13 | reshape_transform=None): 14 | if len(target_layers) > 0: 15 | print( 16 | "Warning: target_layers is ignored in FullGrad. 
All bias layers will be used instead") 17 | 18 | def layer_with_2D_bias(layer): 19 | bias_target_layers = [torch.nn.Conv2d, torch.nn.BatchNorm2d] 20 | if type(layer) in bias_target_layers and layer.bias is not None: 21 | return True 22 | return False 23 | target_layers = find_layer_predicate_recursive( 24 | model, layer_with_2D_bias) 25 | super( 26 | FullGrad, 27 | self).__init__( 28 | model, 29 | target_layers, 30 | use_cuda, 31 | reshape_transform, 32 | compute_input_gradient=True) 33 | self.bias_data = [self.get_bias_data( 34 | layer).cpu().numpy() for layer in target_layers] 35 | 36 | def get_bias_data(self, layer): 37 | # Borrowed from official paper impl: 38 | # https://github.com/idiap/fullgrad-saliency/blob/master/saliency/tensor_extractor.py#L47 39 | if isinstance(layer, torch.nn.BatchNorm2d): 40 | bias = - (layer.running_mean * layer.weight 41 | / torch.sqrt(layer.running_var + layer.eps)) + layer.bias 42 | return bias.data 43 | else: 44 | return layer.bias.data 45 | 46 | def compute_cam_per_layer( 47 | self, 48 | input_tensor, 49 | target_category, 50 | eigen_smooth): 51 | input_grad = input_tensor.grad.data.cpu().numpy() 52 | grads_list = [g.cpu().data.numpy() for g in 53 | self.activations_and_grads.gradients] 54 | cam_per_target_layer = [] 55 | target_size = self.get_target_width_height(input_tensor) 56 | 57 | gradient_multiplied_input = input_grad * input_tensor.data.cpu().numpy() 58 | gradient_multiplied_input = np.abs(gradient_multiplied_input) 59 | gradient_multiplied_input = scale_accross_batch_and_channels( 60 | gradient_multiplied_input, 61 | target_size) 62 | cam_per_target_layer.append(gradient_multiplied_input) 63 | 64 | # Loop over the saliency image from every layer 65 | assert(len(self.bias_data) == len(grads_list)) 66 | for bias, grads in zip(self.bias_data, grads_list): 67 | bias = bias[None, :, None, None] 68 | # In the paper they take the absolute value, 69 | # but possibily taking only the positive gradients will work 70 | # better. 
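`get_bias_data` folds a BatchNorm2d layer into an effective per-channel bias. The short check below (random statistics, eval mode) confirms that the expression equals the layer's response to an all-zero input, which is exactly the implicit bias FullGrad needs:

import torch

bn = torch.nn.BatchNorm2d(4).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)

effective_bias = -(bn.running_mean * bn.weight
                   / torch.sqrt(bn.running_var + bn.eps)) + bn.bias

with torch.no_grad():
    zero_response = bn(torch.zeros(1, 4, 2, 2))[0, :, 0, 0]

print(torch.allclose(effective_bias, zero_response, atol=1e-6))  # True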
71 | bias_grad = np.abs(bias * grads) 72 | result = scale_accross_batch_and_channels( 73 | bias_grad, target_size) 74 | result = np.sum(result, axis=1) 75 | cam_per_target_layer.append(result[:, None, :]) 76 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) 77 | if eigen_smooth: 78 | # Resize to a smaller image, since this method typically has a very large number of channels, 79 | # and then consumes a lot of memory 80 | cam_per_target_layer = scale_accross_batch_and_channels( 81 | cam_per_target_layer, (target_size[0] // 8, target_size[1] // 8)) 82 | cam_per_target_layer = get_2d_projection(cam_per_target_layer) 83 | cam_per_target_layer = cam_per_target_layer[:, None, :, :] 84 | cam_per_target_layer = scale_accross_batch_and_channels( 85 | cam_per_target_layer, 86 | target_size) 87 | else: 88 | cam_per_target_layer = np.sum( 89 | cam_per_target_layer, axis=1)[:, None, :] 90 | 91 | return cam_per_target_layer 92 | 93 | def aggregate_multi_layers(self, cam_per_target_layer): 94 | result = np.sum(cam_per_target_layer, axis=1) 95 | return scale_cam_image(result) 96 | -------------------------------------------------------------------------------- /pytorch_grad_cam/grad_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | 4 | 5 | class GradCAM(BaseCAM): 6 | def __init__(self, model, target_layers, use_cuda=False, 7 | reshape_transform=None): 8 | super( 9 | GradCAM, 10 | self).__init__( 11 | model, 12 | target_layers, 13 | use_cuda, 14 | reshape_transform) 15 | 16 | def get_cam_weights(self, 17 | input_tensor, 18 | target_layer, 19 | target_category, 20 | activations, 21 | grads): 22 | 23 | return np.mean(grads, axis=(2, 3)) 24 | -------------------------------------------------------------------------------- /pytorch_grad_cam/grad_cam_plusplus.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | 4 | # https://arxiv.org/abs/1710.11063 5 | 6 | 7 | class GradCAMPlusPlus(BaseCAM): 8 | def __init__(self, model, target_layers, use_cuda=False, 9 | reshape_transform=None): 10 | super(GradCAMPlusPlus, self).__init__(model, target_layers, use_cuda, 11 | reshape_transform) 12 | 13 | def get_cam_weights(self, 14 | input_tensor, 15 | target_layers, 16 | target_category, 17 | activations, 18 | grads): 19 | grads_power_2 = grads**2 20 | grads_power_3 = grads_power_2 * grads 21 | # Equation 19 in https://arxiv.org/abs/1710.11063 22 | sum_activations = np.sum(activations, axis=(2, 3)) 23 | eps = 0.000001 24 | aij = grads_power_2 / (2 * grads_power_2 + 25 | sum_activations[:, :, None, None] * grads_power_3 + eps) 26 | # Now bring back the ReLU from eq.7 in the paper, 27 | # And zero out aijs where the activations are 0 28 | aij = np.where(grads != 0, aij, 0) 29 | 30 | weights = np.maximum(grads, 0) * aij 31 | weights = np.sum(weights, axis=(2, 3)) 32 | return weights 33 | -------------------------------------------------------------------------------- /pytorch_grad_cam/guided_backprop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Function 4 | from pytorch_grad_cam.utils.find_layers import replace_all_layer_type_recursive 5 | 6 | 7 | class GuidedBackpropReLU(Function): 8 | @staticmethod 9 | def forward(self, input_img): 10 | positive_mask = (input_img > 
0).type_as(input_img) 11 | output = torch.addcmul( 12 | torch.zeros( 13 | input_img.size()).type_as(input_img), 14 | input_img, 15 | positive_mask) 16 | self.save_for_backward(input_img, output) 17 | return output 18 | 19 | @staticmethod 20 | def backward(self, grad_output): 21 | input_img, output = self.saved_tensors 22 | grad_input = None 23 | 24 | positive_mask_1 = (input_img > 0).type_as(grad_output) 25 | positive_mask_2 = (grad_output > 0).type_as(grad_output) 26 | grad_input = torch.addcmul( 27 | torch.zeros( 28 | input_img.size()).type_as(input_img), 29 | torch.addcmul( 30 | torch.zeros( 31 | input_img.size()).type_as(input_img), 32 | grad_output, 33 | positive_mask_1), 34 | positive_mask_2) 35 | return grad_input 36 | 37 | 38 | class GuidedBackpropReLUasModule(torch.nn.Module): 39 | def __init__(self): 40 | super(GuidedBackpropReLUasModule, self).__init__() 41 | 42 | def forward(self, input_img): 43 | return GuidedBackpropReLU.apply(input_img) 44 | 45 | 46 | class GuidedBackpropReLUModel: 47 | def __init__(self, model, use_cuda): 48 | self.model = model 49 | self.model.eval() 50 | self.cuda = use_cuda 51 | if self.cuda: 52 | self.model = self.model.cuda() 53 | 54 | def forward(self, input_img): 55 | return self.model(input_img) 56 | 57 | def recursive_replace_relu_with_guidedrelu(self, module_top): 58 | 59 | for idx, module in module_top._modules.items(): 60 | self.recursive_replace_relu_with_guidedrelu(module) 61 | if module.__class__.__name__ == 'ReLU': 62 | module_top._modules[idx] = GuidedBackpropReLU.apply 63 | print("b") 64 | 65 | def recursive_replace_guidedrelu_with_relu(self, module_top): 66 | try: 67 | for idx, module in module_top._modules.items(): 68 | self.recursive_replace_guidedrelu_with_relu(module) 69 | if module == GuidedBackpropReLU.apply: 70 | module_top._modules[idx] = torch.nn.ReLU() 71 | except BaseException: 72 | pass 73 | 74 | def __call__(self, input_img, target_category=None): 75 | replace_all_layer_type_recursive(self.model, 76 | torch.nn.ReLU, 77 | GuidedBackpropReLUasModule()) 78 | 79 | if self.cuda: 80 | input_img = input_img.cuda() 81 | 82 | input_img = input_img.requires_grad_(True) 83 | 84 | output = self.forward(input_img) 85 | 86 | if target_category is None: 87 | target_category = np.argmax(output.cpu().data.numpy()) 88 | 89 | loss = output[0, target_category] 90 | loss.backward(retain_graph=True) 91 | 92 | output = input_img.grad.cpu().data.numpy() 93 | output = output[0, :, :, :] 94 | output = output.transpose((1, 2, 0)) 95 | 96 | replace_all_layer_type_recursive(self.model, 97 | GuidedBackpropReLUasModule, 98 | torch.nn.ReLU()) 99 | 100 | return output 101 | -------------------------------------------------------------------------------- /pytorch_grad_cam/layer_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 4 | 5 | # https://ieeexplore.ieee.org/document/9462463 6 | 7 | 8 | class LayerCAM(BaseCAM): 9 | def __init__( 10 | self, 11 | model, 12 | target_layers, 13 | use_cuda=False, 14 | reshape_transform=None): 15 | super( 16 | LayerCAM, 17 | self).__init__( 18 | model, 19 | target_layers, 20 | use_cuda, 21 | reshape_transform) 22 | 23 | def get_cam_image(self, 24 | input_tensor, 25 | target_layer, 26 | target_category, 27 | activations, 28 | grads, 29 | eigen_smooth): 30 | spatial_weighted_activations = np.maximum(grads, 0) * activations 31 | 32 | if 
eigen_smooth: 33 | cam = get_2d_projection(spatial_weighted_activations) 34 | else: 35 | cam = spatial_weighted_activations.sum(axis=1) 36 | return cam 37 | -------------------------------------------------------------------------------- /pytorch_grad_cam/score_cam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from pytorch_grad_cam.base_cam import BaseCAM 4 | 5 | 6 | class ScoreCAM(BaseCAM): 7 | def __init__( 8 | self, 9 | model, 10 | target_layers, 11 | use_cuda=False, 12 | reshape_transform=None): 13 | super(ScoreCAM, self).__init__(model, 14 | target_layers, 15 | use_cuda, 16 | reshape_transform=reshape_transform, 17 | uses_gradients=False) 18 | 19 | if len(target_layers) > 0: 20 | print("Warning: You are using ScoreCAM with target layers, " 21 | "however ScoreCAM will ignore them.") 22 | 23 | def get_cam_weights(self, 24 | input_tensor, 25 | target_layer, 26 | targets, 27 | activations, 28 | grads): 29 | with torch.no_grad(): 30 | upsample = torch.nn.UpsamplingBilinear2d( 31 | size=input_tensor.shape[-2:]) 32 | activation_tensor = torch.from_numpy(activations) 33 | if self.cuda: 34 | activation_tensor = activation_tensor.cuda() 35 | 36 | upsampled = upsample(activation_tensor) 37 | 38 | maxs = upsampled.view(upsampled.size(0), 39 | upsampled.size(1), -1).max(dim=-1)[0] 40 | mins = upsampled.view(upsampled.size(0), 41 | upsampled.size(1), -1).min(dim=-1)[0] 42 | 43 | maxs, mins = maxs[:, :, None, None], mins[:, :, None, None] 44 | upsampled = (upsampled - mins) / (maxs - mins) 45 | 46 | input_tensors = input_tensor[:, None, 47 | :, :] * upsampled[:, :, None, :, :] 48 | 49 | if hasattr(self, "batch_size"): 50 | BATCH_SIZE = self.batch_size 51 | else: 52 | BATCH_SIZE = 16 53 | 54 | scores = [] 55 | for target, tensor in zip(targets, input_tensors): 56 | for i in tqdm.tqdm(range(0, tensor.size(0), BATCH_SIZE)): 57 | batch = tensor[i: i + BATCH_SIZE, :] 58 | outputs = [target(o).cpu().item() for o in self.model(batch)] 59 | scores.extend(outputs) 60 | scores = torch.Tensor(scores) 61 | scores = scores.view(activations.shape[0], activations.shape[1]) 62 | weights = torch.nn.Softmax(dim=-1)(scores).numpy() 63 | return weights 64 | -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.utils.image import deprocess_image 2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 3 | from pytorch_grad_cam.utils import model_targets 4 | from pytorch_grad_cam.utils import reshape_transforms -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/find_layers.py: -------------------------------------------------------------------------------- 1 | def replace_layer_recursive(model, old_layer, new_layer): 2 | for name, layer in model._modules.items(): 3 | if layer == old_layer: 4 | model._modules[name] = new_layer 5 | return True 6 | elif replace_layer_recursive(layer, old_layer, new_layer): 7 | return True 8 | return False 9 | 10 | 11 | def replace_all_layer_type_recursive(model, old_layer_type, new_layer): 12 | for name, layer in model._modules.items(): 13 | if isinstance(layer, old_layer_type): 14 | model._modules[name] = new_layer 15 | replace_all_layer_type_recursive(layer, old_layer_type, new_layer) 16 | 17 | 18 | def find_layer_types_recursive(model, layer_types): 19 | def 
predicate(layer): 20 | return type(layer) in layer_types 21 | return find_layer_predicate_recursive(model, predicate) 22 | 23 | 24 | def find_layer_predicate_recursive(model, predicate): 25 | result = [] 26 | for name, layer in model._modules.items(): 27 | if predicate(layer): 28 | result.append(layer) 29 | result.extend(find_layer_predicate_recursive(layer, predicate)) 30 | return result -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/image.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torchvision.transforms import Compose, Normalize, ToTensor 5 | 6 | 7 | def preprocess_image(img: np.ndarray, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> torch.Tensor: 8 | preprocessing = Compose([ 9 | ToTensor(), 10 | Normalize(mean=mean, std=std) 11 | ]) 12 | return preprocessing(img.copy()).unsqueeze(0) 13 | 14 | 15 | def deprocess_image(img): 16 | """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """ 17 | img = img - np.mean(img) 18 | img = img / (np.std(img) + 1e-5) 19 | img = img * 0.1 20 | img = img + 0.5 21 | img = np.clip(img, 0, 1) 22 | return np.uint8(img * 255) 23 | 24 | 25 | def show_cam_on_image(img: np.ndarray, 26 | mask: np.ndarray, 27 | use_rgb: bool = False, 28 | colormap: int = cv2.COLORMAP_JET) -> np.ndarray: 29 | """ This function overlays the cam mask on the image as an heatmap. 30 | By default the heatmap is in BGR format. 31 | 32 | :param img: The base image in RGB or BGR format. 33 | :param mask: The cam mask. 34 | :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format. 35 | :param colormap: The OpenCV colormap to be used. 36 | :returns: The default image with the cam overlay. 
37 | """ 38 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap) 39 | if use_rgb: 40 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) 41 | heatmap = np.float32(heatmap) / 255 42 | 43 | if np.max(img) > 1: 44 | raise Exception( 45 | "The input image should np.float32 in the range [0, 1]") 46 | 47 | cam = heatmap + img 48 | cam = cam / np.max(cam) 49 | return np.uint8(255 * cam) 50 | 51 | def scale_cam_image(cam, target_size=None): 52 | result = [] 53 | for img in cam: 54 | img = img - np.min(img) 55 | img = img / (1e-7 + np.max(img)) 56 | if target_size is not None: 57 | img = cv2.resize(img, target_size) 58 | result.append(img) 59 | result = np.float32(result) 60 | 61 | return result 62 | 63 | def scale_accross_batch_and_channels(tensor, target_size): 64 | batch_size, channel_size = tensor.shape[:2] 65 | reshaped_tensor = tensor.reshape( 66 | batch_size * channel_size, *tensor.shape[2:]) 67 | result = scale_cam_image(reshaped_tensor, target_size) 68 | result = result.reshape( 69 | batch_size, 70 | channel_size, 71 | target_size[1], 72 | target_size[0]) 73 | return result 74 | -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/model_targets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | 5 | class ClassifierOutputTarget: 6 | def __init__(self, category): 7 | self.category = category 8 | def __call__(self, model_output): 9 | if len(model_output.shape) == 1: 10 | return model_output[self.category] 11 | return model_output[:, self.category] 12 | 13 | class SemanticSegmentationTarget: 14 | """ Gets a binary spatial mask and a category, 15 | And return the sum of the category scores, 16 | of the pixels in the mask. """ 17 | def __init__(self, category, mask): 18 | self.category = category 19 | self.mask = torch.from_numpy(mask) 20 | if torch.cuda.is_available(): 21 | self.mask = self.mask.cuda() 22 | 23 | def __call__(self, model_output): 24 | return (model_output[self.category, :, : ] * self.mask).sum() 25 | 26 | 27 | class FasterRCNNBoxScoreTarget: 28 | """ For every original detected bounding box specified in "bounding boxes", 29 | assign a score on how the current bounding boxes match it, 30 | 1. In IOU 31 | 2. In the classification score. 32 | If there is not a large enough overlap, or the category changed, 33 | assign a score of 0. 34 | 35 | The total score is the sum of all the box scores. 
36 | """ 37 | 38 | def __init__(self, labels, bounding_boxes, iou_threshold=0.5): 39 | self.labels = labels 40 | self.bounding_boxes = bounding_boxes 41 | self.iou_threshold = iou_threshold 42 | 43 | def __call__(self, model_outputs): 44 | output = torch.Tensor([0]) 45 | if torch.cuda.is_available(): 46 | output = output.cuda() 47 | 48 | if len(model_outputs["boxes"]) == 0: 49 | return output 50 | 51 | for box, label in zip(self.bounding_boxes, self.labels): 52 | box = torch.Tensor(box[None, :]) 53 | if torch.cuda.is_available(): 54 | box = box.cuda() 55 | 56 | ious = torchvision.ops.box_iou(box, model_outputs["boxes"]) 57 | index = ious.argmax() 58 | if ious[0, index] > self.iou_threshold and model_outputs["labels"][index] == label: 59 | score = ious[0, index] + model_outputs["scores"][index] 60 | output = output + score 61 | return output -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/reshape_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def fasterrcnn_reshape_transform(x): 4 | target_size = x['pool'].size()[-2 : ] 5 | activations = [] 6 | for key, value in x.items(): 7 | activations.append(torch.nn.functional.interpolate(torch.abs(value), target_size, mode='bilinear')) 8 | activations = torch.cat(activations, axis=1) 9 | return activations 10 | 11 | def swinT_reshape_transform(tensor, height=7, width=7): 12 | result = tensor.reshape(tensor.size(0), 13 | height, width, tensor.size(2)) 14 | 15 | # Bring the channels to the first dimension, 16 | # like in CNNs. 17 | result = result.transpose(2, 3).transpose(1, 2) 18 | return result 19 | 20 | def vit_reshape_transform(tensor, height=14, width=14): 21 | result = tensor[:, 1:, :].reshape(tensor.size(0), 22 | height, width, tensor.size(2)) 23 | 24 | # Bring the channels to the first dimension, 25 | # like in CNNs. 
26 | result = result.transpose(2, 3).transpose(1, 2) 27 | return result 28 | -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/svd_on_activations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_2d_projection(activation_batch): 5 | # TBD: use pytorch batch svd implementation 6 | activation_batch[np.isnan(activation_batch)] = 0 7 | projections = [] 8 | for activations in activation_batch: 9 | reshaped_activations = (activations).reshape( 10 | activations.shape[0], -1).transpose() 11 | # Centering before the SVD seems to be important here, 12 | # Otherwise the image returned is negative 13 | reshaped_activations = reshaped_activations - \ 14 | reshaped_activations.mean(axis=0) 15 | U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True) 16 | projection = reshaped_activations @ VT[0, :] 17 | projection = projection.reshape(activations.shape[1:]) 18 | projections.append(projection) 19 | return np.float32(projections) 20 | -------------------------------------------------------------------------------- /pytorch_grad_cam/xgrad_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | 4 | 5 | class XGradCAM(BaseCAM): 6 | def __init__( 7 | self, 8 | model, 9 | target_layers, 10 | use_cuda=False, 11 | reshape_transform=None): 12 | super( 13 | XGradCAM, 14 | self).__init__( 15 | model, 16 | target_layers, 17 | use_cuda, 18 | reshape_transform) 19 | 20 | def get_cam_weights(self, 21 | input_tensor, 22 | target_layer, 23 | target_category, 24 | activations, 25 | grads): 26 | sum_activations = np.sum(activations, axis=(2, 3)) 27 | eps = 1e-7 28 | weights = grads * activations / \ 29 | (sum_activations[:, :, None, None] + eps) 30 | weights = weights.sum(axis=(2, 3)) 31 | return weights 32 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | submitit 2 | pycocotools 3 | torch==1.8.1 4 | torchvision==0.9.1 5 | faiss-gpu==1.7.2 6 | opencv-python==4.6.0.66 7 | scikit-image==0.19.2 8 | scikit-learn==1.1.1 9 | shapely==1.8.2 10 | timm==0.5.4 11 | pyyaml==6.0 12 | colored 13 | fvcore==0.1.5.post20220512 14 | gdown==4.5.4 -------------------------------------------------------------------------------- /tpnet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import config 4 | import engine 5 | import modeling 6 | import structures 7 | import tools 8 | import demo 9 | 10 | # dataset loading 11 | from . import data # register all new datasets 12 | from data import datasets # register all new datasets 13 | from solver import * 14 | 15 | # from .data import register_all_imagenet -------------------------------------------------------------------------------- /tpnet/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .cutler_config import add_cutler_config -------------------------------------------------------------------------------- /tpnet/config/cutler_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
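`get_2d_projection` above is the core of EigenCAM: the centered (H*W, C) activation matrix of each image is projected onto its first principal component, yielding one saliency map per image without using any gradients. A small shape check on random activations:

import numpy as np
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection

activations = np.random.rand(2, 64, 7, 7).astype(np.float32)  # (batch, channels, H, W)
projection = get_2d_projection(activations)
print(projection.shape)                                       # (2, 7, 7): one 2D map per image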
2 | 3 | from detectron2.config import CfgNode as CN 4 | 5 | def add_cutler_config(cfg): 6 | cfg.DATALOADER.COPY_PASTE = False 7 | cfg.DATALOADER.COPY_PASTE_RATE = 0.0 8 | cfg.DATALOADER.COPY_PASTE_MIN_RATIO = 0.5 9 | cfg.DATALOADER.COPY_PASTE_MAX_RATIO = 1.0 10 | cfg.DATALOADER.COPY_PASTE_RANDOM_NUM = True 11 | cfg.DATALOADER.VISUALIZE_COPY_PASTE = False 12 | 13 | cfg.MODEL.ROI_HEADS.USE_DROPLOSS = False 14 | cfg.MODEL.ROI_HEADS.DROPLOSS_IOU_THRESH = 0.0 15 | 16 | cfg.SOLVER.BASE_LR_MULTIPLIER = 1 17 | cfg.SOLVER.BASE_LR_MULTIPLIER_NAMES = [] 18 | 19 | cfg.TEST.NO_SEGM = False -------------------------------------------------------------------------------- /tpnet/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from . import datasets # ensure the builtin datasets are registered 4 | from .detection_utils import * # isort:skip 5 | from .build import ( 6 | build_batch_data_loader, 7 | build_detection_train_loader, 8 | build_detection_test_loader, 9 | get_detection_dataset_dicts, 10 | load_proposals_into_dataset, 11 | print_instances_class_histogram, 12 | ) 13 | from detectron2.data.common import * 14 | 15 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /tpnet/data/dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/dataset_mapper.py 3 | 4 | import copy 5 | import logging 6 | import numpy as np 7 | from typing import List, Optional, Union 8 | import torch 9 | 10 | from detectron2.config import configurable 11 | 12 | import data.detection_utils as utils 13 | import data.transforms as T 14 | 15 | """ 16 | This file contains the default mapping that's applied to "dataset dicts". 17 | """ 18 | 19 | __all__ = ["DatasetMapper"] 20 | 21 | 22 | class DatasetMapper: 23 | """ 24 | A callable which takes a dataset dict in Detectron2 Dataset format, 25 | and map it into a format used by the model. 26 | 27 | This is the default callable to be used to map your dataset dict into training data. 28 | You may need to follow it to implement your own one for customized logic, 29 | such as a different way to read or transform images. 30 | See :doc:`/tutorials/data_loading` for details. 31 | 32 | The callable currently does the following: 33 | 34 | 1. Read the image from "file_name" 35 | 2. Applies cropping/geometric transforms to the image and annotations 36 | 3. Prepare data and annotations to Tensor and :class:`Instances` 37 | """ 38 | 39 | @configurable 40 | def __init__( 41 | self, 42 | is_train: bool, 43 | *, 44 | augmentations: List[Union[T.Augmentation, T.Transform]], 45 | image_format: str, 46 | use_instance_mask: bool = False, 47 | use_keypoint: bool = False, 48 | instance_mask_format: str = "polygon", 49 | keypoint_hflip_indices: Optional[np.ndarray] = None, 50 | precomputed_proposal_topk: Optional[int] = None, 51 | recompute_boxes: bool = False, 52 | ): 53 | """ 54 | NOTE: this interface is experimental. 55 | 56 | Args: 57 | is_train: whether it's used in training or inference 58 | augmentations: a list of augmentations or deterministic transforms to apply 59 | image_format: an image format supported by :func:`detection_utils.read_image`. 
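`add_cutler_config` only registers extra keys on a detectron2 CfgNode, so it must run before any YAML that references those keys is merged. A sketch of the usual order (assumes detectron2 is installed and the tpnet directory is on PYTHONPATH; the YAML path is one of the configs shipped under tpnet/model_zoo/configs):

from detectron2.config import get_cfg
from config import add_cutler_config   # imported the same way tpnet/__init__.py does

cfg = get_cfg()
add_cutler_config(cfg)                 # adds DATALOADER.COPY_PASTE, ROI_HEADS.USE_DROPLOSS, ...
cfg.merge_from_file("model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml")
print(cfg.MODEL.ROI_HEADS.USE_DROPLOSS, cfg.DATALOADER.COPY_PASTE_RATE)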
60 | use_instance_mask: whether to process instance segmentation annotations, if available 61 | use_keypoint: whether to process keypoint annotations if available 62 | instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation 63 | masks into this format. 64 | keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices` 65 | precomputed_proposal_topk: if given, will load pre-computed 66 | proposals from dataset_dict and keep the top k proposals for each image. 67 | recompute_boxes: whether to overwrite bounding box annotations 68 | by computing tight bounding boxes from instance mask annotations. 69 | """ 70 | if recompute_boxes: 71 | assert use_instance_mask, "recompute_boxes requires instance masks" 72 | # fmt: off 73 | self.is_train = is_train 74 | self.augmentations = T.AugmentationList(augmentations) 75 | self.image_format = image_format 76 | self.use_instance_mask = use_instance_mask 77 | self.instance_mask_format = instance_mask_format 78 | self.use_keypoint = use_keypoint 79 | self.keypoint_hflip_indices = keypoint_hflip_indices 80 | self.proposal_topk = precomputed_proposal_topk 81 | self.recompute_boxes = recompute_boxes 82 | # fmt: on 83 | logger = logging.getLogger(__name__) 84 | mode = "training" if is_train else "inference" 85 | logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}") 86 | 87 | @classmethod 88 | def from_config(cls, cfg, is_train: bool = True): 89 | augs = utils.build_augmentation(cfg, is_train) 90 | if cfg.INPUT.CROP.ENABLED and is_train: 91 | augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)) 92 | recompute_boxes = cfg.MODEL.MASK_ON 93 | else: 94 | recompute_boxes = False 95 | 96 | ret = { 97 | "is_train": is_train, 98 | "augmentations": augs, 99 | "image_format": cfg.INPUT.FORMAT, 100 | "use_instance_mask": cfg.MODEL.MASK_ON, 101 | "instance_mask_format": cfg.INPUT.MASK_FORMAT, 102 | "use_keypoint": cfg.MODEL.KEYPOINT_ON, 103 | "recompute_boxes": recompute_boxes, 104 | } 105 | 106 | if cfg.MODEL.KEYPOINT_ON: 107 | ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN) 108 | 109 | if cfg.MODEL.LOAD_PROPOSALS: 110 | ret["precomputed_proposal_topk"] = ( 111 | cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN 112 | if is_train 113 | else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST 114 | ) 115 | return ret 116 | 117 | def _transform_annotations(self, dataset_dict, transforms, image_shape): 118 | # USER: Modify this if you want to keep them for some reason. 119 | for anno in dataset_dict["annotations"]: 120 | if not self.use_instance_mask: 121 | anno.pop("segmentation", None) 122 | if not self.use_keypoint: 123 | anno.pop("keypoints", None) 124 | 125 | # USER: Implement additional transformations if you have other types of data 126 | annos = [ 127 | utils.transform_instance_annotations( 128 | obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices 129 | ) 130 | for obj in dataset_dict.pop("annotations") 131 | if obj.get("iscrowd", 0) == 0 132 | ] 133 | instances = utils.annotations_to_instances( 134 | annos, image_shape, mask_format=self.instance_mask_format 135 | ) 136 | 137 | # After transforms such as cropping are applied, the bounding box may no longer 138 | # tightly bound the object. As an example, imagine a triangle object 139 | # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). 
The tight 140 | # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to 141 | # the intersection of original bounding box and the cropping box. 142 | if self.recompute_boxes: 143 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes() 144 | dataset_dict["instances"] = utils.filter_empty_instances(instances) 145 | 146 | def __call__(self, dataset_dict): 147 | """ 148 | Args: 149 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 150 | 151 | Returns: 152 | dict: a format that builtin models in detectron2 accept 153 | """ 154 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 155 | # USER: Write your own image loading if it's not from a file 156 | image = utils.read_image(dataset_dict["file_name"], format=self.image_format) 157 | utils.check_image_size(dataset_dict, image) 158 | 159 | # USER: Remove if you don't do semantic/panoptic segmentation. 160 | if "sem_seg_file_name" in dataset_dict: 161 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2) 162 | else: 163 | sem_seg_gt = None 164 | 165 | aug_input = T.AugInput(image, sem_seg=sem_seg_gt) 166 | transforms = self.augmentations(aug_input) 167 | image, sem_seg_gt = aug_input.image, aug_input.sem_seg 168 | 169 | image_shape = image.shape[:2] # h, w 170 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 171 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 172 | # Therefore it's important to use torch.Tensor. 173 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 174 | if sem_seg_gt is not None: 175 | dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long")) 176 | 177 | # USER: Remove if you don't use pre-computed proposals. 178 | # Most users would not need this feature. 179 | if self.proposal_topk is not None: 180 | utils.transform_proposals( 181 | dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk 182 | ) 183 | 184 | if not self.is_train: 185 | # USER: Modify this if you want to keep them for some reason. 186 | dataset_dict.pop("annotations", None) 187 | dataset_dict.pop("sem_seg_file_name", None) 188 | return dataset_dict 189 | 190 | if "annotations" in dataset_dict: 191 | self._transform_annotations(dataset_dict, transforms, image_shape) 192 | 193 | return dataset_dict 194 | -------------------------------------------------------------------------------- /tpnet/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | from .coco import load_coco_json, load_sem_seg, register_coco_instances, convert_to_coco_json 3 | from .builtin import ( 4 | register_all_imagenet, 5 | register_all_uvo, 6 | register_all_coco_ca, 7 | register_all_coco_semi, 8 | register_all_lvis, 9 | register_all_voc, 10 | register_all_cross_domain, 11 | register_all_kitti, 12 | register_all_objects365, 13 | register_all_openimages, 14 | ) 15 | from . import register_cis 16 | 17 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /tpnet/data/datasets/builtin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
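`DatasetMapper.__call__` above turns one raw dataset record (a file path plus COCO-style annotations) into the dict the model consumes, with "image" as a CHW tensor and an Instances object. Thanks to the @configurable decorator it can also be built directly, bypassing the cfg; the values below are illustrative, with the augmentations taken from detectron2:

from detectron2.data import transforms as T
from data.dataset_mapper import DatasetMapper

mapper = DatasetMapper(
    is_train=True,
    augmentations=[T.ResizeShortestEdge(short_edge_length=800, max_size=1333),
                   T.RandomFlip(horizontal=True)],
    image_format="BGR",
    use_instance_mask=True,           # keep "segmentation" annotations
    instance_mask_format="bitmask",
)
# mapper(dataset_dict) then returns {"image": CHW tensor, "instances": Instances, ...}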
2 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/builtin.py 3 | 4 | """ 5 | This file registers pre-defined datasets at hard-coded paths, and their metadata. 6 | 7 | We hard-code metadata for common datasets. This will enable: 8 | 1. Consistency check when loading the datasets 9 | 2. Use models on these standard datasets directly and run demos, 10 | without having to download the dataset annotations 11 | 12 | We hard-code some paths to the dataset that's assumed to 13 | exist in "./datasets/". 14 | 15 | Users SHOULD NOT use this file to create new dataset / metadata for new dataset. 16 | To add new dataset, refer to the tutorial "docs/DATASETS.md". 17 | """ 18 | 19 | import os 20 | 21 | from .builtin_meta import _get_builtin_metadata 22 | from .coco import register_coco_instances 23 | 24 | # ==== Predefined datasets and splits for COCO ========== 25 | 26 | _PREDEFINED_SPLITS_COCO_SEMI = {} 27 | _PREDEFINED_SPLITS_COCO_SEMI["coco_semi"] = { 28 | # we use seed 42 to be consistent with previous works on SSL detection and segmentation 29 | "coco_semi_1perc": ("coco/train2017", "coco/annotations/1perc_instances_train2017.json"), 30 | "coco_semi_2perc": ("coco/train2017", "coco/annotations/2perc_instances_train2017.json"), 31 | "coco_semi_5perc": ("coco/train2017", "coco/annotations/5perc_instances_train2017.json"), 32 | "coco_semi_10perc": ("coco/train2017", "coco/annotations/10perc_instances_train2017.json"), 33 | "coco_semi_20perc": ("coco/train2017", "coco/annotations/20perc_instances_train2017.json"), 34 | "coco_semi_30perc": ("coco/train2017", "coco/annotations/30perc_instances_train2017.json"), 35 | "coco_semi_40perc": ("coco/train2017", "coco/annotations/40perc_instances_train2017.json"), 36 | "coco_semi_50perc": ("coco/train2017", "coco/annotations/50perc_instances_train2017.json"), 37 | "coco_semi_60perc": ("coco/train2017", "coco/annotations/60perc_instances_train2017.json"), 38 | "coco_semi_80perc": ("coco/train2017", "coco/annotations/80perc_instances_train2017.json"), 39 | } 40 | 41 | _PREDEFINED_SPLITS_COCO_CA = {} 42 | _PREDEFINED_SPLITS_COCO_CA["coco_cls_agnostic"] = { 43 | "cls_agnostic_coco": ("coco/val2017", "coco/annotations/coco_cls_agnostic_instances_val2017.json"), 44 | "cls_agnostic_coco20k": ("coco/train2014", "coco/annotations/coco20k_trainval_gt.json"), 45 | } 46 | 47 | _PREDEFINED_SPLITS_IMAGENET = {} 48 | _PREDEFINED_SPLITS_IMAGENET["imagenet"] = { 49 | # maskcut annotations 50 | "imagenet_train": ("imagenet/train", "imagenet/annotations/imagenet_train_fixsize480_tau0.15_N3.json"), 51 | # self-training round 1 52 | "imagenet_train_r1": ("imagenet/train", "imagenet/annotations/cutler_imagenet1k_train_r1.json"), 53 | # self-training round 2 54 | "imagenet_train_r2": ("imagenet/train", "imagenet/annotations/cutler_imagenet1k_train_r2.json"), 55 | # self-training round 3 56 | "imagenet_train_r3": ("imagenet/train", "imagenet/annotations/cutler_imagenet1k_train_r3.json"), 57 | } 58 | 59 | _PREDEFINED_SPLITS_VOC = {} 60 | _PREDEFINED_SPLITS_VOC["voc"] = { 61 | 'cls_agnostic_voc': ("voc/", "voc/annotations/trainvaltest_2007_cls_agnostic.json"), 62 | } 63 | 64 | _PREDEFINED_SPLITS_CROSSDOMAIN = {} 65 | _PREDEFINED_SPLITS_CROSSDOMAIN["cross_domain"] = { 66 | 'cls_agnostic_clipart': ("clipart/", "clipart/annotations/traintest_cls_agnostic.json"), 67 | 'cls_agnostic_watercolor': ("watercolor/", "watercolor/annotations/traintest_cls_agnostic.json"), 68 | 'cls_agnostic_comic': ("comic/", 
"comic/annotations/traintest_cls_agnostic.json"), 69 | } 70 | 71 | _PREDEFINED_SPLITS_KITTI = {} 72 | _PREDEFINED_SPLITS_KITTI["kitti"] = { 73 | 'cls_agnostic_kitti': ("kitti/", "kitti/annotations/trainval_cls_agnostic.json"), 74 | } 75 | 76 | _PREDEFINED_SPLITS_LVIS = {} 77 | _PREDEFINED_SPLITS_LVIS["lvis"] = { 78 | "cls_agnostic_lvis": ("coco/", "coco/annotations/lvis1.0_cocofied_val_cls_agnostic.json"), 79 | } 80 | 81 | _PREDEFINED_SPLITS_OBJECTS365 = {} 82 | _PREDEFINED_SPLITS_OBJECTS365["objects365"] = { 83 | 'cls_agnostic_objects365': ("objects365/val", "objects365/annotations/zhiyuan_objv2_val_cls_agnostic.json"), 84 | } 85 | 86 | _PREDEFINED_SPLITS_OpenImages = {} 87 | _PREDEFINED_SPLITS_OpenImages["openimages"] = { 88 | 'cls_agnostic_openimages': ("openImages/validation", "openImages/annotations/openimages_val_cls_agnostic.json"), 89 | } 90 | 91 | _PREDEFINED_SPLITS_UVO = {} 92 | _PREDEFINED_SPLITS_UVO["uvo"] = { 93 | "cls_agnostic_uvo": ("uvo/all_UVO_frames", "uvo/annotations/val_sparse_cleaned_cls_agnostic.json"), 94 | } 95 | 96 | def register_all_imagenet(root): 97 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_IMAGENET.items(): 98 | for key, (image_root, json_file) in splits_per_dataset.items(): 99 | # Assume pre-defined datasets live in `./datasets`. 100 | register_coco_instances( 101 | key, 102 | _get_builtin_metadata(dataset_name), 103 | os.path.join(root, json_file) if "://" not in json_file else json_file, 104 | os.path.join(root, image_root), 105 | ) 106 | 107 | def register_all_voc(root): 108 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_VOC.items(): 109 | for key, (image_root, json_file) in splits_per_dataset.items(): 110 | # Assume pre-defined datasets live in `./datasets`. 111 | register_coco_instances( 112 | key, 113 | _get_builtin_metadata(dataset_name), 114 | os.path.join(root, json_file) if "://" not in json_file else json_file, 115 | os.path.join(root, image_root), 116 | ) 117 | 118 | def register_all_cross_domain(root): 119 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_CROSSDOMAIN.items(): 120 | for key, (image_root, json_file) in splits_per_dataset.items(): 121 | # Assume pre-defined datasets live in `./datasets`. 122 | register_coco_instances( 123 | key, 124 | _get_builtin_metadata(dataset_name), 125 | os.path.join(root, json_file) if "://" not in json_file else json_file, 126 | os.path.join(root, image_root), 127 | ) 128 | 129 | def register_all_kitti(root): 130 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_KITTI.items(): 131 | for key, (image_root, json_file) in splits_per_dataset.items(): 132 | # Assume pre-defined datasets live in `./datasets`. 133 | register_coco_instances( 134 | key, 135 | _get_builtin_metadata(dataset_name), 136 | os.path.join(root, json_file) if "://" not in json_file else json_file, 137 | os.path.join(root, image_root), 138 | ) 139 | 140 | def register_all_objects365(root): 141 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_OBJECTS365.items(): 142 | for key, (image_root, json_file) in splits_per_dataset.items(): 143 | # Assume pre-defined datasets live in `./datasets`. 
144 | register_coco_instances( 145 | key, 146 | _get_builtin_metadata(dataset_name), 147 | os.path.join(root, json_file) if "://" not in json_file else json_file, 148 | os.path.join(root, image_root), 149 | ) 150 | 151 | def register_all_openimages(root): 152 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_OpenImages.items(): 153 | for key, (image_root, json_file) in splits_per_dataset.items(): 154 | # Assume pre-defined datasets live in `./datasets`. 155 | register_coco_instances( 156 | key, 157 | _get_builtin_metadata(dataset_name), 158 | os.path.join(root, json_file) if "://" not in json_file else json_file, 159 | os.path.join(root, image_root), 160 | ) 161 | 162 | def register_all_lvis(root): 163 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_LVIS.items(): 164 | for key, (image_root, json_file) in splits_per_dataset.items(): 165 | # Assume pre-defined datasets live in `./datasets`. 166 | register_coco_instances( 167 | key, 168 | _get_builtin_metadata(dataset_name), 169 | os.path.join(root, json_file) if "://" not in json_file else json_file, 170 | os.path.join(root, image_root), 171 | ) 172 | 173 | def register_all_uvo(root): 174 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_UVO.items(): 175 | for key, (image_root, json_file) in splits_per_dataset.items(): 176 | # Assume pre-defined datasets live in `./datasets`. 177 | register_coco_instances( 178 | key, 179 | _get_builtin_metadata(dataset_name), 180 | os.path.join(root, json_file) if "://" not in json_file else json_file, 181 | os.path.join(root, image_root), 182 | ) 183 | 184 | def register_all_coco_semi(root): 185 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_COCO_SEMI.items(): 186 | for key, (image_root, json_file) in splits_per_dataset.items(): 187 | # Assume pre-defined datasets live in `./datasets`. 188 | register_coco_instances( 189 | key, 190 | _get_builtin_metadata(dataset_name), 191 | os.path.join(root, json_file) if "://" not in json_file else json_file, 192 | os.path.join(root, image_root), 193 | ) 194 | 195 | def register_all_coco_ca(root): 196 | for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_COCO_CA.items(): 197 | for key, (image_root, json_file) in splits_per_dataset.items(): 198 | # Assume pre-defined datasets live in `./datasets`. 
199 | register_coco_instances( 200 | key, 201 | _get_builtin_metadata(dataset_name), 202 | os.path.join(root, json_file) if "://" not in json_file else json_file, 203 | os.path.join(root, image_root), 204 | ) 205 | 206 | _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets")) 207 | register_all_coco_semi(_root) 208 | register_all_coco_ca(_root) 209 | register_all_imagenet(_root) 210 | register_all_uvo(_root) 211 | register_all_voc(_root) 212 | register_all_cross_domain(_root) 213 | register_all_kitti(_root) 214 | register_all_openimages(_root) 215 | register_all_objects365(_root) 216 | register_all_lvis(_root) -------------------------------------------------------------------------------- /tpnet/data/datasets/register_cis.py: -------------------------------------------------------------------------------- 1 | import os 2 | from detectron2.data.datasets.coco import load_coco_json 3 | from detectron2.data import MetadataCatalog, DatasetCatalog 4 | 5 | DATASET_ROOT = '/media/data2/HZT/data/CIS/COD10K' 6 | ANN_ROOT = os.path.join(DATASET_ROOT, 'annotations') 7 | TRAIN_PATH = os.path.join(DATASET_ROOT, 'Train_Image_CAM') 8 | TEST_PATH = os.path.join(DATASET_ROOT, 'Test_Image_CAM') 9 | TRAIN_JSON = os.path.join(ANN_ROOT, 'train_instance.json') 10 | TEST_JSON = os.path.join(ANN_ROOT, 'class_5_annotations.json') 11 | 12 | 13 | NC4K_ROOT = '/media/data2/HZT/data/CIS/NC4K' 14 | NC4K_PATH = os.path.join(NC4K_ROOT, 'test/image') 15 | NC4K_JSON = os.path.join(NC4K_ROOT, 'nc4k_test.json') 16 | 17 | CLASS_NAMES = ["foreground"] 18 | 19 | PREDEFINED_SPLITS_DATASET = { 20 | "my_data_train_coco_cod_style": (TRAIN_PATH, TRAIN_JSON), 21 | "my_data_test_coco_cod_style": (TEST_PATH, TEST_JSON), 22 | "my_data_test_coco_nc4k_style": (NC4K_PATH, NC4K_JSON), 23 | } 24 | 25 | 26 | def register_dataset(): 27 | """ 28 | purpose: register all splits of dataset with PREDEFINED_SPLITS_DATASET 29 | """ 30 | for key, (image_root, json_file) in PREDEFINED_SPLITS_DATASET.items(): 31 | register_dataset_instances(name=key, 32 | json_file=json_file, 33 | image_root=image_root) 34 | 35 | 36 | def register_dataset_instances(name, json_file, image_root): 37 | """ 38 | purpose: register dataset to DatasetCatalog, 39 | register metadata to MetadataCatalog and set attribute 40 | """ 41 | DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name)) 42 | MetadataCatalog.get(name).set(json_file=json_file, 43 | image_root=image_root, 44 | evaluator_type="coco") 45 | 46 | #_root = os.getenv("DETECTRON2_DATASETS", "datasets") 47 | register_dataset() -------------------------------------------------------------------------------- /tpnet/data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
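Because `register_dataset()` runs at import time above, the COD10K and NC4K splits become available by name through detectron2's catalogs. A sketch of fetching one of them (assumes the tpnet directory is on PYTHONPATH and the hard-coded annotation paths exist on disk):

from detectron2.data import DatasetCatalog, MetadataCatalog
import data.datasets  # noqa: F401 -- importing the package runs register_cis.register_dataset()

dicts = DatasetCatalog.get("my_data_test_coco_nc4k_style")  # parses nc4k_test.json at this point
meta = MetadataCatalog.get("my_data_test_coco_nc4k_style")
print(len(dicts), meta.json_file, meta.evaluator_type)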
2 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/__init__.py 3 | 4 | from fvcore.transforms.transform import * 5 | from .transform import * 6 | from detectron2.data.transforms.augmentation import * 7 | from .augmentation_impl import * 8 | 9 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 10 | 11 | 12 | from detectron2.utils.env import fixup_module_metadata 13 | 14 | fixup_module_metadata(__name__, globals(), __all__) 15 | del fixup_module_metadata -------------------------------------------------------------------------------- /tpnet/data/transforms/transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/transform.py 3 | 4 | """ 5 | See "Data Augmentation" tutorial for an overview of the system: 6 | https://detectron2.readthedocs.io/tutorials/augmentation.html 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from fvcore.transforms.transform import ( 13 | CropTransform, 14 | HFlipTransform, 15 | NoOpTransform, 16 | Transform, 17 | TransformList, 18 | ) 19 | from PIL import Image 20 | 21 | try: 22 | import cv2 # noqa 23 | except ImportError: 24 | # OpenCV is an optional dependency at the moment 25 | pass 26 | 27 | __all__ = [ 28 | "ExtentTransform", 29 | "ResizeTransform", 30 | "RotationTransform", 31 | "ColorTransform", 32 | "PILColorTransform", 33 | ] 34 | 35 | 36 | class ExtentTransform(Transform): 37 | """ 38 | Extracts a subregion from the source image and scales it to the output size. 39 | 40 | The fill color is used to map pixels from the source rect that fall outside 41 | the source image. 42 | 43 | See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform 44 | """ 45 | 46 | def __init__(self, src_rect, output_size, interp=Image.LINEAR, fill=0): 47 | """ 48 | Args: 49 | src_rect (x0, y0, x1, y1): src coordinates 50 | output_size (h, w): dst image size 51 | interp: PIL interpolation methods 52 | fill: Fill color used when src_rect extends outside image 53 | """ 54 | super().__init__() 55 | self._set_attributes(locals()) 56 | 57 | def apply_image(self, img, interp=None): 58 | h, w = self.output_size 59 | if len(img.shape) > 2 and img.shape[2] == 1: 60 | pil_image = Image.fromarray(img[:, :, 0], mode="L") 61 | else: 62 | pil_image = Image.fromarray(img) 63 | pil_image = pil_image.transform( 64 | size=(w, h), 65 | method=Image.EXTENT, 66 | data=self.src_rect, 67 | resample=interp if interp else self.interp, 68 | fill=self.fill, 69 | ) 70 | ret = np.asarray(pil_image) 71 | if len(img.shape) > 2 and img.shape[2] == 1: 72 | ret = np.expand_dims(ret, -1) 73 | return ret 74 | 75 | def apply_coords(self, coords): 76 | # Transform image center from source coordinates into output coordinates 77 | # and then map the new origin to the corner of the output image. 
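# For example (purely illustrative numbers): with src_rect=(10, 20, 110, 220) and output_size=(100, 50), the rect centre (60, 120) maps to (25, 50), the centre of the 50x100 output, while the rect corners (10, 20) and (110, 220) map to the output corners (0, 0) and (50, 100).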
78 | h, w = self.output_size 79 | x0, y0, x1, y1 = self.src_rect 80 | new_coords = coords.astype(np.float32) 81 | new_coords[:, 0] -= 0.5 * (x0 + x1) 82 | new_coords[:, 1] -= 0.5 * (y0 + y1) 83 | new_coords[:, 0] *= w / (x1 - x0) 84 | new_coords[:, 1] *= h / (y1 - y0) 85 | new_coords[:, 0] += 0.5 * w 86 | new_coords[:, 1] += 0.5 * h 87 | return new_coords 88 | 89 | def apply_segmentation(self, segmentation): 90 | segmentation = self.apply_image(segmentation, interp=Image.NEAREST) 91 | return segmentation 92 | 93 | 94 | class ResizeTransform(Transform): 95 | """ 96 | Resize the image to a target size. 97 | """ 98 | 99 | def __init__(self, h, w, new_h, new_w, interp=None): 100 | """ 101 | Args: 102 | h, w (int): original image size 103 | new_h, new_w (int): new image size 104 | interp: PIL interpolation methods, defaults to bilinear. 105 | """ 106 | # TODO decide on PIL vs opencv 107 | super().__init__() 108 | if interp is None: 109 | interp = Image.BILINEAR 110 | self._set_attributes(locals()) 111 | 112 | def apply_image(self, img, interp=None): 113 | # Tolerate inputs whose recorded (h, w) are transposed: swap the 114 | # target size so it matches the actual image before asserting. 115 | if img.shape[:2] != (self.h, self.w): 116 | (self.h, self.w) = (self.w, self.h) 117 | assert img.shape[:2] == (self.h, self.w) 118 | assert len(img.shape) <= 4 119 | interp_method = interp if interp is not None else self.interp 120 | 121 | if img.dtype == np.uint8: 122 | if len(img.shape) > 2 and img.shape[2] == 1: 123 | pil_image = Image.fromarray(img[:, :, 0], mode="L") 124 | else: 125 | pil_image = Image.fromarray(img) 126 | pil_image = pil_image.resize((self.new_w, self.new_h), interp_method) 127 | ret = np.asarray(pil_image) 128 | if len(img.shape) > 2 and img.shape[2] == 1: 129 | ret = np.expand_dims(ret, -1) 130 | else: 131 | # PIL only supports uint8 132 | if any(x < 0 for x in img.strides): 133 | img = np.ascontiguousarray(img) 134 | img = torch.from_numpy(img) 135 | shape = list(img.shape) 136 | shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:] 137 | img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw 138 | _PIL_RESIZE_TO_INTERPOLATE_MODE = { 139 | Image.NEAREST: "nearest", 140 | Image.BILINEAR: "bilinear", 141 | Image.BICUBIC: "bicubic", 142 | } 143 | mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[interp_method] 144 | align_corners = None if mode == "nearest" else False 145 | img = F.interpolate( 146 | img, (self.new_h, self.new_w), mode=mode, align_corners=align_corners 147 | ) 148 | shape[:2] = (self.new_h, self.new_w) 149 | ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c) 150 | 151 | return ret 152 | 153 | def apply_coords(self, coords): 154 | coords[:, 0] = coords[:, 0] * (self.new_w * 1.0 / self.w) 155 | coords[:, 1] = coords[:, 1] * (self.new_h * 1.0 / self.h) 156 | return coords 157 | 158 | def apply_segmentation(self, segmentation): 159 | segmentation = self.apply_image(segmentation, interp=Image.NEAREST) 160 | return segmentation 161 | 162 | def inverse(self): 163 | return ResizeTransform(self.new_h, self.new_w, self.h, self.w, self.interp) 164 | 165 | 166 | class RotationTransform(Transform): 167 | """ 168 | This method returns a copy of this image, rotated the given 169 | number of degrees counter clockwise around its center.
170 | """ 171 | 172 | def __init__(self, h, w, angle, expand=True, center=None, interp=None): 173 | """ 174 | Args: 175 | h, w (int): original image size 176 | angle (float): degrees for rotation 177 | expand (bool): choose if the image should be resized to fit the whole 178 | rotated image (default), or simply cropped 179 | center (tuple (width, height)): coordinates of the rotation center 180 | if left to None, the center will be fit to the center of each image 181 | center has no effect if expand=True because it only affects shifting 182 | interp: cv2 interpolation method, default cv2.INTER_LINEAR 183 | """ 184 | super().__init__() 185 | image_center = np.array((w / 2, h / 2)) 186 | if center is None: 187 | center = image_center 188 | if interp is None: 189 | interp = cv2.INTER_LINEAR 190 | abs_cos, abs_sin = (abs(np.cos(np.deg2rad(angle))), abs(np.sin(np.deg2rad(angle)))) 191 | if expand: 192 | # find the new width and height bounds 193 | bound_w, bound_h = np.rint( 194 | [h * abs_sin + w * abs_cos, h * abs_cos + w * abs_sin] 195 | ).astype(int) 196 | else: 197 | bound_w, bound_h = w, h 198 | 199 | self._set_attributes(locals()) 200 | self.rm_coords = self.create_rotation_matrix() 201 | # Needed because of this problem https://github.com/opencv/opencv/issues/11784 202 | self.rm_image = self.create_rotation_matrix(offset=-0.5) 203 | 204 | def apply_image(self, img, interp=None): 205 | """ 206 | img should be a numpy array, formatted as Height * Width * Nchannels 207 | """ 208 | if len(img) == 0 or self.angle % 360 == 0: 209 | return img 210 | assert img.shape[:2] == (self.h, self.w) 211 | interp = interp if interp is not None else self.interp 212 | return cv2.warpAffine(img, self.rm_image, (self.bound_w, self.bound_h), flags=interp) 213 | 214 | def apply_coords(self, coords): 215 | """ 216 | coords should be a N * 2 array-like, containing N couples of (x, y) points 217 | """ 218 | coords = np.asarray(coords, dtype=float) 219 | if len(coords) == 0 or self.angle % 360 == 0: 220 | return coords 221 | return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :] 222 | 223 | def apply_segmentation(self, segmentation): 224 | segmentation = self.apply_image(segmentation, interp=cv2.INTER_NEAREST) 225 | return segmentation 226 | 227 | def create_rotation_matrix(self, offset=0): 228 | center = (self.center[0] + offset, self.center[1] + offset) 229 | rm = cv2.getRotationMatrix2D(tuple(center), self.angle, 1) 230 | if self.expand: 231 | # Find the coordinates of the center of rotation in the new image 232 | # The only point for which we know the future coordinates is the center of the image 233 | rot_im_center = cv2.transform(self.image_center[None, None, :] + offset, rm)[0, 0, :] 234 | new_center = np.array([self.bound_w / 2, self.bound_h / 2]) + offset - rot_im_center 235 | # shift the rotation center to the new coordinates 236 | rm[:, 2] += new_center 237 | return rm 238 | 239 | def inverse(self): 240 | """ 241 | The inverse is to rotate it back with expand, and crop to get the original shape. 
242 | """ 243 | if not self.expand: # Not possible to inverse if a part of the image is lost 244 | raise NotImplementedError() 245 | rotation = RotationTransform( 246 | self.bound_h, self.bound_w, -self.angle, True, None, self.interp 247 | ) 248 | crop = CropTransform( 249 | (rotation.bound_w - self.w) // 2, (rotation.bound_h - self.h) // 2, self.w, self.h 250 | ) 251 | return TransformList([rotation, crop]) 252 | 253 | 254 | class ColorTransform(Transform): 255 | """ 256 | Generic wrapper for any photometric transforms. 257 | These transformations should only affect the color space and 258 | not the coordinate space of the image (e.g. annotation 259 | coordinates such as bounding boxes should not be changed) 260 | """ 261 | 262 | def __init__(self, op): 263 | """ 264 | Args: 265 | op (Callable): operation to be applied to the image, 266 | which takes in an ndarray and returns an ndarray. 267 | """ 268 | if not callable(op): 269 | raise ValueError("op parameter should be callable") 270 | super().__init__() 271 | self._set_attributes(locals()) 272 | 273 | def apply_image(self, img): 274 | return self.op(img) 275 | 276 | def apply_coords(self, coords): 277 | return coords 278 | 279 | def inverse(self): 280 | return NoOpTransform() 281 | 282 | def apply_segmentation(self, segmentation): 283 | return segmentation 284 | 285 | 286 | class PILColorTransform(ColorTransform): 287 | """ 288 | Generic wrapper for PIL Photometric image transforms, 289 | which affect the color space and not the coordinate 290 | space of the image 291 | """ 292 | 293 | def __init__(self, op): 294 | """ 295 | Args: 296 | op (Callable): operation to be applied to the image, 297 | which takes in a PIL Image and returns a transformed 298 | PIL Image. 299 | For reference on possible operations see: 300 | - https://pillow.readthedocs.io/en/stable/ 301 | """ 302 | if not callable(op): 303 | raise ValueError("op parameter should be callable") 304 | super().__init__(op) 305 | 306 | def apply_image(self, img): 307 | img = Image.fromarray(img) 308 | return np.asarray(super().apply_image(img)) 309 | 310 | 311 | def HFlip_rotated_box(transform, rotated_boxes): 312 | """ 313 | Apply the horizontal flip transform on rotated boxes. 314 | 315 | Args: 316 | rotated_boxes (ndarray): Nx5 floating point array of 317 | (x_center, y_center, width, height, angle_degrees) format 318 | in absolute coordinates. 319 | """ 320 | # Transform x_center 321 | rotated_boxes[:, 0] = transform.width - rotated_boxes[:, 0] 322 | # Transform angle 323 | rotated_boxes[:, 4] = -rotated_boxes[:, 4] 324 | return rotated_boxes 325 | 326 | 327 | def Resize_rotated_box(transform, rotated_boxes): 328 | """ 329 | Apply the resizing transform on rotated boxes. For details of how these (approximation) 330 | formulas are derived, please refer to :meth:`RotatedBoxes.scale`. 331 | 332 | Args: 333 | rotated_boxes (ndarray): Nx5 floating point array of 334 | (x_center, y_center, width, height, angle_degrees) format 335 | in absolute coordinates. 
336 | """ 337 | scale_factor_x = transform.new_w * 1.0 / transform.w 338 | scale_factor_y = transform.new_h * 1.0 / transform.h 339 | rotated_boxes[:, 0] *= scale_factor_x 340 | rotated_boxes[:, 1] *= scale_factor_y 341 | theta = rotated_boxes[:, 4] * np.pi / 180.0 342 | c = np.cos(theta) 343 | s = np.sin(theta) 344 | rotated_boxes[:, 2] *= np.sqrt(np.square(scale_factor_x * c) + np.square(scale_factor_y * s)) 345 | rotated_boxes[:, 3] *= np.sqrt(np.square(scale_factor_x * s) + np.square(scale_factor_y * c)) 346 | rotated_boxes[:, 4] = np.arctan2(scale_factor_x * s, scale_factor_y * c) * 180 / np.pi 347 | 348 | return rotated_boxes 349 | 350 | 351 | HFlipTransform.register_type("rotated_box", HFlip_rotated_box) 352 | ResizeTransform.register_type("rotated_box", Resize_rotated_box) 353 | 354 | # not necessary any more with latest fvcore 355 | NoOpTransform.register_type("rotated_box", lambda t, x: x) 356 | -------------------------------------------------------------------------------- /tpnet/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .train_loop import * 4 | 5 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 6 | 7 | from .defaults import * -------------------------------------------------------------------------------- /tpnet/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .coco_evaluation import COCOEvaluator -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 
19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 42 | VERSION: 2 43 | -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_100perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_2017_train",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.02 27 | STEPS: (60000, 80000) 28 | MAX_ITER: 90000 29 | BASE_LR_MULTIPLIER: 2 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/100perc" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_10perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_10perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (6000, 8000) 28 | MAX_ITER: 9000 29 | BASE_LR_MULTIPLIER: 4 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 
38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/10perc" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_1perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_1perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (2400, 3200) 28 | MAX_ITER: 3600 29 | WARMUP_FACTOR: 0.001 30 | WARMUP_ITERS: 1000 31 | BASE_LR_MULTIPLIER: 4 32 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 33 | INPUT: 34 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 35 | MAX_SIZE_TRAIN: 1333 36 | MASK_FORMAT: "bitmask" 37 | FORMAT: "RGB" 38 | TEST: 39 | PRECISE_BN: 40 | ENABLED: True 41 | EVAL_PERIOD: 5000 42 | OUTPUT_DIR: "output/1perc" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_20perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_20perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (12000, 16000) 28 | MAX_ITER: 18000 29 | BASE_LR_MULTIPLIER: 4 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/20perc" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_2perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: 
"http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_2perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (2400, 3200) 28 | MAX_ITER: 3600 29 | WARMUP_FACTOR: 0.001 30 | WARMUP_ITERS: 1000 31 | BASE_LR_MULTIPLIER: 4 32 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 33 | INPUT: 34 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 35 | MAX_SIZE_TRAIN: 1333 36 | MASK_FORMAT: "bitmask" 37 | FORMAT: "RGB" 38 | TEST: 39 | PRECISE_BN: 40 | ENABLED: True 41 | EVAL_PERIOD: 5000 42 | OUTPUT_DIR: "output/2perc" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_30perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_30perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (18000, 24000) 28 | MAX_ITER: 27000 29 | BASE_LR_MULTIPLIER: 4 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/30perc" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_40perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_40perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 
| BASE_LR: 0.04 27 | STEPS: (24000, 32000) 28 | MAX_ITER: 36000 29 | BASE_LR_MULTIPLIER: 4 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/40perc" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_50perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_50perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.02 27 | STEPS: (30000, 40000) 28 | MAX_ITER: 45000 29 | BASE_LR_MULTIPLIER: 2 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/50perc" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_5perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_5perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.04 27 | STEPS: (3000, 4000) 28 | MAX_ITER: 4500 29 | WARMUP_FACTOR: 0.001 30 | WARMUP_ITERS: 1000 31 | BASE_LR_MULTIPLIER: 4 32 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 33 | INPUT: 34 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 35 | MAX_SIZE_TRAIN: 1333 36 | 
MASK_FORMAT: "bitmask" 37 | FORMAT: "RGB" 38 | TEST: 39 | PRECISE_BN: 40 | ENABLED: True 41 | EVAL_PERIOD: 5000 42 | OUTPUT_DIR: "output/5perc" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_60perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_60perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.02 27 | STEPS: (36000, 48000) 28 | MAX_ITER: 54000 29 | BASE_LR_MULTIPLIER: 2 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/60perc" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_80perc.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | MASK_ON: True 7 | BACKBONE: 8 | FREEZE_AT: 0 9 | RESNETS: 10 | DEPTH: 50 11 | NORM: "SyncBN" 12 | STRIDE_IN_1X1: False 13 | FPN: 14 | NORM: "SyncBN" 15 | ROI_BOX_HEAD: 16 | CLS_AGNOSTIC_BBOX_REG: True 17 | ROI_HEADS: 18 | NAME: CustomCascadeROIHeads 19 | RPN: 20 | POST_NMS_TOPK_TRAIN: 2000 21 | DATASETS: 22 | TRAIN: ("coco_semi_80perc",) 23 | TEST: ("coco_2017_val",) 24 | SOLVER: 25 | IMS_PER_BATCH: 16 26 | BASE_LR: 0.02 27 | STEPS: (48000, 64000) 28 | MAX_ITER: 72000 29 | BASE_LR_MULTIPLIER: 2 30 | BASE_LR_MULTIPLIER_NAMES: ['roi_heads.mask_head.predictor', 'roi_heads.box_predictor.0.cls_score', 'roi_heads.box_predictor.0.bbox_pred', 'roi_heads.box_predictor.1.cls_score', 'roi_heads.box_predictor.1.bbox_pred', 'roi_heads.box_predictor.2.cls_score', 'roi_heads.box_predictor.2.bbox_pred'] 31 | INPUT: 32 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 33 | MAX_SIZE_TRAIN: 1333 34 | MASK_FORMAT: "bitmask" 35 | FORMAT: "RGB" 36 | TEST: 37 | PRECISE_BN: 38 | ENABLED: True 39 | EVAL_PERIOD: 5000 40 | OUTPUT_DIR: "output/80perc" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/CutLER-ImageNet/cascade_mask_GT.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | DATALOADER: 3 | COPY_PASTE: True 4 | COPY_PASTE_RATE: 1.0 5 | VISUALIZE_COPY_PASTE: False 6 | COPY_PASTE_RANDOM_NUM: True 7 | 
COPY_PASTE_MIN_RATIO: 0.3 8 | COPY_PASTE_MAX_RATIO: 1.0 9 | NUM_WORKERS: 0 10 | MODEL: 11 | PIXEL_MEAN: [123.675, 116.280, 103.530] 12 | PIXEL_STD: [58.395, 57.120, 57.375] 13 | WEIGHTS: 'http://dl.fbaipublicfiles.com/cutler/checkpoints/dino_RN50_pretrain_d2_format.pkl' 14 | MASK_ON: True 15 | BACKBONE: 16 | FREEZE_AT: 0 17 | RESNETS: 18 | DEPTH: 50 19 | NORM: "SyncBN" 20 | STRIDE_IN_1X1: False 21 | FPN: 22 | NORM: "SyncBN" 23 | ROI_BOX_HEAD: 24 | CLS_AGNOSTIC_BBOX_REG: True 25 | ROI_HEADS: 26 | NAME: CustomCascadeROIHeads 27 | NUM_CLASSES: 1 28 | SCORE_THRESH_TEST: 0.0 29 | POSITIVE_FRACTION: 0.25 30 | USE_DROPLOSS: True 31 | DROPLOSS_IOU_THRESH: 0.01 32 | RPN: 33 | POST_NMS_TOPK_TRAIN: 4000 34 | NMS_THRESH: 0.65 35 | DATASETS: 36 | TRAIN: ("my_data_train_coco_cod_style",) 37 | TEST: ("my_data_test_coco_cod_style", "my_data_test_coco_nc4k_style") 38 | SOLVER: 39 | IMS_PER_BATCH: 8 40 | BASE_LR: 0.005 41 | WEIGHT_DECAY: 0.00005 42 | STEPS: (20000,) 43 | MAX_ITER: 25000 44 | GAMMA: 0.02 45 | CLIP_GRADIENTS: 46 | CLIP_TYPE: norm 47 | CLIP_VALUE: 1.0 48 | ENABLED: true 49 | NORM_TYPE: 2.0 50 | AMP: 51 | ENABLED: True 52 | INPUT: 53 | MIN_SIZE_TRAIN: (240, 320, 480, 640, 672, 704, 736, 768, 800, 1024) 54 | MAX_SIZE_TRAIN: 1333 55 | MASK_FORMAT: "bitmask" 56 | FORMAT: "RGB" 57 | TEST: 58 | PRECISE_BN: 59 | ENABLED: True 60 | NUM_ITER: 200 61 | DETECTIONS_PER_IMAGE: 100 62 | OUTPUT_DIR: "output/" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | #DATALOADER: 3 | # COPY_PASTE: True 4 | # COPY_PASTE_RATE: 1.0 5 | # VISUALIZE_COPY_PASTE: False 6 | # COPY_PASTE_RANDOM_NUM: True 7 | # COPY_PASTE_MIN_RATIO: 0.3 8 | # COPY_PASTE_MAX_RATIO: 1.0 9 | # NUM_WORKERS: 0 10 | MODEL: 11 | PIXEL_MEAN: [123.675, 116.280, 103.530] 12 | PIXEL_STD: [58.395, 57.120, 57.375] 13 | WEIGHTS: 'http://dl.fbaipublicfiles.com/cutler/checkpoints/dino_RN50_pretrain_d2_format.pkl' 14 | MASK_ON: True 15 | BACKBONE: 16 | FREEZE_AT: 0 17 | RESNETS: 18 | DEPTH: 50 19 | NORM: "SyncBN" 20 | STRIDE_IN_1X1: False 21 | FPN: 22 | NORM: "SyncBN" 23 | ROI_BOX_HEAD: 24 | CLS_AGNOSTIC_BBOX_REG: True 25 | ROI_HEADS: 26 | NAME: CustomCascadeROIHeads 27 | NUM_CLASSES: 1 28 | SCORE_THRESH_TEST: 0.0 29 | POSITIVE_FRACTION: 0.25 30 | USE_DROPLOSS: True 31 | DROPLOSS_IOU_THRESH: 0.01 32 | RPN: 33 | POST_NMS_TOPK_TRAIN: 4000 34 | NMS_THRESH: 0.65 35 | DATASETS: 36 | TRAIN: ("my_data_train_coco_cod_style",) 37 | TEST: ("my_data_test_coco_cod_style", "my_data_test_coco_nc4k_style") 38 | SOLVER: 39 | IMS_PER_BATCH: 12 40 | BASE_LR: 0.005 41 | WEIGHT_DECAY: 0.00005 42 | STEPS: (20000,) 43 | MAX_ITER: 25000 44 | GAMMA: 0.2 45 | CLIP_GRADIENTS: 46 | CLIP_TYPE: norm 47 | CLIP_VALUE: 1.0 48 | ENABLED: true 49 | NORM_TYPE: 2.0 50 | AMP: 51 | ENABLED: True 52 | INPUT: 53 | MIN_SIZE_TRAIN: (240, 320, 480, 640, 672, 704, 736, 768, 800, 1024) 54 | MAX_SIZE_TRAIN: 1333 55 | MASK_FORMAT: "bitmask" 56 | FORMAT: "RGB" 57 | TEST: 58 | EVAL_PERIOD: 2000 59 | PRECISE_BN: 60 | ENABLED: True 61 | NUM_ITER: 200 62 | DETECTIONS_PER_IMAGE: 100 63 | OUTPUT_DIR: "output/tpnet" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN_demo.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | 
DATALOADER: 3 | COPY_PASTE: True 4 | COPY_PASTE_RATE: 1.0 5 | VISUALIZE_COPY_PASTE: False 6 | COPY_PASTE_RANDOM_NUM: True 7 | COPY_PASTE_MIN_RATIO: 0.3 8 | COPY_PASTE_MAX_RATIO: 1.0 9 | NUM_WORKERS: 0 10 | MODEL: 11 | PIXEL_MEAN: [123.675, 116.280, 103.530] 12 | PIXEL_STD: [58.395, 57.120, 57.375] 13 | WEIGHTS: 'http://dl.fbaipublicfiles.com/cutler/checkpoints/dino_RN50_pretrain_d2_format.pkl' 14 | MASK_ON: True 15 | BACKBONE: 16 | FREEZE_AT: 0 17 | RESNETS: 18 | DEPTH: 50 19 | NORM: "SyncBN" 20 | STRIDE_IN_1X1: False 21 | FPN: 22 | NORM: "SyncBN" 23 | ROI_BOX_HEAD: 24 | CLS_AGNOSTIC_BBOX_REG: True 25 | ROI_HEADS: 26 | NAME: CustomCascadeROIHeads 27 | NUM_CLASSES: 1 28 | SCORE_THRESH_TEST: 0.0 29 | POSITIVE_FRACTION: 0.25 30 | USE_DROPLOSS: True 31 | DROPLOSS_IOU_THRESH: 0.01 32 | RPN: 33 | POST_NMS_TOPK_TRAIN: 4000 34 | NMS_THRESH: 0.65 35 | DATASETS: 36 | TRAIN: ("imagenet_train",) 37 | TEST: ("imagenet_train",) 38 | SOLVER: 39 | IMS_PER_BATCH: 16 40 | BASE_LR: 0.01 41 | WEIGHT_DECAY: 0.00005 42 | STEPS: (80000,) 43 | MAX_ITER: 160000 44 | GAMMA: 0.02 45 | CLIP_GRADIENTS: 46 | CLIP_TYPE: norm 47 | CLIP_VALUE: 1.0 48 | ENABLED: true 49 | NORM_TYPE: 2.0 50 | AMP: 51 | ENABLED: True 52 | INPUT: 53 | MIN_SIZE_TRAIN: (240, 320, 480, 640, 672, 704, 736, 768, 800, 1024) 54 | MAX_SIZE_TRAIN: 1333 55 | MASK_FORMAT: "bitmask" 56 | FORMAT: "RGB" 57 | TEST: 58 | PRECISE_BN: 59 | ENABLED: True 60 | NUM_ITER: 200 61 | DETECTIONS_PER_IMAGE: 100 62 | OUTPUT_DIR: "output/" 63 | -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN_self_train.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | DATALOADER: 3 | COPY_PASTE: True 4 | COPY_PASTE_RATE: 1.0 5 | VISUALIZE_COPY_PASTE: False 6 | COPY_PASTE_RANDOM_NUM: True 7 | COPY_PASTE_MIN_RATIO: 0.5 8 | COPY_PASTE_MAX_RATIO: 1.0 9 | NUM_WORKERS: 2 10 | MODEL: 11 | PIXEL_MEAN: [123.675, 116.280, 103.530] 12 | PIXEL_STD: [58.395, 57.120, 57.375] 13 | WEIGHTS: 'http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_r1.pth' # round 1 14 | # WEIGHTS: 'http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_r2.pth' # round 2 15 | MASK_ON: True 16 | BACKBONE: 17 | FREEZE_AT: 0 18 | RESNETS: 19 | DEPTH: 50 20 | NORM: "SyncBN" 21 | STRIDE_IN_1X1: False 22 | FPN: 23 | NORM: "SyncBN" 24 | ROI_BOX_HEAD: 25 | CLS_AGNOSTIC_BBOX_REG: True 26 | ROI_HEADS: 27 | NAME: CustomCascadeROIHeads 28 | NUM_CLASSES: 1 29 | SCORE_THRESH_TEST: 0.0 30 | POSITIVE_FRACTION: 0.25 31 | USE_DROPLOSS: False 32 | DROPLOSS_IOU_THRESH: 0.01 33 | DATASETS: 34 | TRAIN: ("my_data_train_coco_cod_style_r1",) 35 | TEST: ("my_data_test_coco_cod_style", "my_data_test_coco_nc4k_style") 36 | SOLVER: 37 | IMS_PER_BATCH: 8 38 | BASE_LR: 0.0026 39 | STEPS: (5999,) 40 | MAX_ITER: 6000 41 | GAMMA: 1.0 42 | CLIP_GRADIENTS: 43 | CLIP_TYPE: norm 44 | CLIP_VALUE: 1.0 45 | ENABLED: true 46 | NORM_TYPE: 2.0 47 | AMP: 48 | ENABLED: True 49 | INPUT: 50 | MIN_SIZE_TRAIN: (240, 320, 480, 640, 672, 704, 736, 768, 800, 1024) 51 | MAX_SIZE_TRAIN: 1333 52 | MASK_FORMAT: "bitmask" 53 | FORMAT: "RGB" 54 | TEST: 55 | PRECISE_BN: 56 | ENABLED: True 57 | NUM_ITER: 200 58 | DETECTIONS_PER_IMAGE: 100 59 | OUTPUT_DIR: "output/self-train-r1/" # round 1 60 | # OUTPUT_DIR: "output/self-train-r2/" # round 2 -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/CutLER-ImageNet/mask_rcnn_R_50_FPN.yaml: 
-------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | DATALOADER: 3 | COPY_PASTE: True 4 | COPY_PASTE_RATE: 1.0 5 | VISUALIZE_COPY_PASTE: False 6 | COPY_PASTE_RANDOM_NUM: True 7 | COPY_PASTE_MIN_RATIO: 0.3 8 | COPY_PASTE_MAX_RATIO: 1.0 9 | MODEL: 10 | PIXEL_MEAN: [123.675, 116.280, 103.530] 11 | PIXEL_STD: [58.395, 57.120, 57.375] 12 | WEIGHTS: 'http://dl.fbaipublicfiles.com/cutler/checkpoints/dino_RN50_pretrain_d2_format.pkl' 13 | MASK_ON: True 14 | BACKBONE: 15 | FREEZE_AT: 0 16 | RESNETS: 17 | DEPTH: 50 18 | NORM: "SyncBN" 19 | STRIDE_IN_1X1: False 20 | FPN: 21 | NORM: "SyncBN" 22 | ROI_HEADS: 23 | NAME: "CustomStandardROIHeads" 24 | NUM_CLASSES: 1 25 | SCORE_THRESH_TEST: 0.0 26 | USE_DROPLOSS: True 27 | DROPLOSS_IOU_THRESH: 0.01 28 | RPN: 29 | POST_NMS_TOPK_TRAIN: 4000 30 | NMS_THRESH: 0.65 31 | DATASETS: 32 | 33 | TRAIN: ("my_data_train_coco_cod_style",) 34 | TEST: ("my_data_test_coco_nc4k_style0","my_data_test_coco_nc4k_style1","my_data_test_coco_nc4k_style2") 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.01 38 | WEIGHT_DECAY: 0.00005 39 | STEPS: (80000,) 40 | MAX_ITER: 160000 41 | CLIP_GRADIENTS: 42 | CLIP_TYPE: norm 43 | CLIP_VALUE: 1.0 44 | ENABLED: true 45 | NORM_TYPE: 2.0 46 | INPUT: 47 | MIN_SIZE_TRAIN: (240, 320, 480, 640, 672, 704, 736, 768, 800, 1024) 48 | MAX_SIZE_TRAIN: 1333 49 | MASK_FORMAT: "bitmask" 50 | FORMAT: "RGB" 51 | TEST: 52 | PRECISE_BN: 53 | ENABLED: True 54 | OUTPUT_DIR: "output/" -------------------------------------------------------------------------------- /tpnet/model_zoo/configs/CutLER-ImageNet/test.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | DATALOADER: 3 | COPY_PASTE: True 4 | COPY_PASTE_RATE: 1.0 5 | VISUALIZE_COPY_PASTE: False 6 | COPY_PASTE_RANDOM_NUM: True 7 | COPY_PASTE_MIN_RATIO: 0.3 8 | COPY_PASTE_MAX_RATIO: 1.0 9 | NUM_WORKERS: 0 10 | MODEL: 11 | PIXEL_MEAN: [123.675, 116.280, 103.530] 12 | PIXEL_STD: [58.395, 57.120, 57.375] 13 | WEIGHTS: 'http://dl.fbaipublicfiles.com/cutler/checkpoints/dino_RN50_pretrain_d2_format.pkl' 14 | MASK_ON: True 15 | BACKBONE: 16 | FREEZE_AT: 0 17 | RESNETS: 18 | DEPTH: 50 19 | NORM: "SyncBN" 20 | STRIDE_IN_1X1: False 21 | FPN: 22 | NORM: "SyncBN" 23 | ROI_BOX_HEAD: 24 | CLS_AGNOSTIC_BBOX_REG: True 25 | ROI_HEADS: 26 | NAME: CustomCascadeROIHeads 27 | NUM_CLASSES: 1 28 | SCORE_THRESH_TEST: 0.0 29 | POSITIVE_FRACTION: 0.25 30 | USE_DROPLOSS: True 31 | DROPLOSS_IOU_THRESH: 0.01 32 | RPN: 33 | POST_NMS_TOPK_TRAIN: 4000 34 | NMS_THRESH: 0.65 35 | DATASETS: 36 | TEST: ("my_data_train_coco_cod_style",) 37 | SOLVER: 38 | IMS_PER_BATCH: 8 39 | BASE_LR: 0.005 40 | WEIGHT_DECAY: 0.00005 41 | STEPS: (8000,) 42 | MAX_ITER: 12000 43 | GAMMA: 0.02 44 | CLIP_GRADIENTS: 45 | CLIP_TYPE: norm 46 | CLIP_VALUE: 1.0 47 | ENABLED: true 48 | NORM_TYPE: 2.0 49 | AMP: 50 | ENABLED: True 51 | INPUT: 52 | MIN_SIZE_TRAIN: (240, 320, 480, 640, 672, 704, 736, 768, 800, 1024) 53 | MAX_SIZE_TRAIN: 1333 54 | MASK_FORMAT: "bitmask" 55 | FORMAT: "RGB" 56 | TEST: 57 | PRECISE_BN: 58 | ENABLED: True 59 | NUM_ITER: 200 60 | DETECTIONS_PER_IMAGE: 100 61 | OUTPUT_DIR: "output/" -------------------------------------------------------------------------------- /tpnet/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | from .roi_heads import ( 4 | ROI_HEADS_REGISTRY, 5 | ROIHeads, 6 | CustomStandardROIHeads, 7 | FastRCNNOutputLayers, 8 | build_roi_heads, 9 | ) 10 | from .roi_heads.custom_cascade_rcnn import CustomCascadeROIHeads 11 | from .roi_heads.fast_rcnn import FastRCNNOutputLayers 12 | from .meta_arch.rcnn import GeneralizedRCNN, ProposalNetwork 13 | from .meta_arch.build import build_model 14 | 15 | _EXCLUDE = {"ShapeSpec"} 16 | __all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")] -------------------------------------------------------------------------------- /tpnet/modeling/meta_arch/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/meta_arch/__init__.py 4 | 5 | from .build import META_ARCH_REGISTRY, build_model # isort:skip 6 | 7 | __all__ = list(globals().keys()) 8 | -------------------------------------------------------------------------------- /tpnet/modeling/meta_arch/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/meta_arch/build.py 3 | 4 | import torch 5 | 6 | from detectron2.utils.logger import _log_api_usage 7 | from detectron2.utils.registry import Registry 8 | 9 | META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip 10 | META_ARCH_REGISTRY.__doc__ = """ 11 | Registry for meta-architectures, i.e. the whole model. 12 | 13 | The registered object will be called with `obj(cfg)` 14 | and expected to return a `nn.Module` object. 15 | """ 16 | 17 | 18 | def build_model(cfg): 19 | """ 20 | Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``. 21 | Note that it does not load any weights from ``cfg``. 22 | """ 23 | meta_arch = cfg.MODEL.META_ARCHITECTURE 24 | model = META_ARCH_REGISTRY.get(meta_arch)(cfg) 25 | model.to(torch.device(cfg.MODEL.DEVICE)) 26 | _log_api_usage("modeling.meta_arch." + meta_arch) 27 | return model 28 | -------------------------------------------------------------------------------- /tpnet/modeling/meta_arch/rcnn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/meta_arch/rcnn.py 3 | 4 | import logging 5 | import numpy as np 6 | from typing import Dict, List, Optional, Tuple 7 | import torch 8 | from torch import nn 9 | 10 | from detectron2.config import configurable 11 | from detectron2.data.detection_utils import convert_image_to_rgb 12 | from detectron2.layers import move_device_like 13 | from detectron2.structures import ImageList, Instances 14 | from detectron2.utils.events import get_event_storage 15 | from detectron2.utils.logger import log_first_n 16 | 17 | from detectron2.modeling.backbone import Backbone, build_backbone 18 | from detectron2.modeling.postprocessing import detector_postprocess 19 | from detectron2.modeling.proposal_generator import build_proposal_generator 20 | from ..roi_heads import build_roi_heads 21 | from .build import META_ARCH_REGISTRY 22 | 23 | __all__ = ["GeneralizedRCNN", "ProposalNetwork"] 24 | 25 | 26 | @META_ARCH_REGISTRY.register() 27 | class GeneralizedRCNN(nn.Module): 28 | """ 29 | Generalized R-CNN. Any models that contains the following three components: 30 | 1. Per-image feature extraction (aka backbone) 31 | 2. Region proposal generation 32 | 3. Per-region feature extraction and prediction 33 | """ 34 | 35 | @configurable 36 | def __init__( 37 | self, 38 | *, 39 | backbone: Backbone, 40 | proposal_generator: nn.Module, 41 | roi_heads: nn.Module, 42 | pixel_mean: Tuple[float], 43 | pixel_std: Tuple[float], 44 | input_format: Optional[str] = None, 45 | vis_period: int = 0, 46 | ): 47 | """ 48 | Args: 49 | backbone: a backbone module, must follow detectron2's backbone interface 50 | proposal_generator: a module that generates proposals using backbone features 51 | roi_heads: a ROI head that performs per-region computation 52 | pixel_mean, pixel_std: list or tuple with #channels element, representing 53 | the per-channel mean and std to be used to normalize the input image 54 | input_format: describe the meaning of channels of input. Needed by visualization 55 | vis_period: the period to run visualization. Set to 0 to disable. 56 | """ 57 | super().__init__() 58 | self.backbone = backbone 59 | self.proposal_generator = proposal_generator 60 | self.roi_heads = roi_heads 61 | 62 | self.input_format = input_format 63 | self.vis_period = vis_period 64 | if vis_period > 0: 65 | assert input_format is not None, "input_format is required for visualization!" 66 | 67 | self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False) 68 | self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False) 69 | assert ( 70 | self.pixel_mean.shape == self.pixel_std.shape 71 | ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!" 72 | 73 | @classmethod 74 | def from_config(cls, cfg): 75 | backbone = build_backbone(cfg) 76 | return { 77 | "backbone": backbone, 78 | "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()), 79 | "roi_heads": build_roi_heads(cfg, backbone.output_shape()), 80 | "input_format": cfg.INPUT.FORMAT, 81 | "vis_period": cfg.VIS_PERIOD, 82 | "pixel_mean": cfg.MODEL.PIXEL_MEAN, 83 | "pixel_std": cfg.MODEL.PIXEL_STD, 84 | } 85 | 86 | @property 87 | def device(self): 88 | return self.pixel_mean.device 89 | 90 | def _move_to_current_device(self, x): 91 | return move_device_like(x, self.pixel_mean) 92 | 93 | def visualize_training(self, batched_inputs, proposals): 94 | """ 95 | A function used to visualize images and proposals. 
It shows ground truth 96 | bounding boxes on the original image and up to 20 top-scoring predicted 97 | object proposals on the original image. Users can implement different 98 | visualization functions for different models. 99 | 100 | Args: 101 | batched_inputs (list): a list that contains input to the model. 102 | proposals (list): a list that contains predicted proposals. Both 103 | batched_inputs and proposals should have the same length. 104 | """ 105 | from detectron2.utils.visualizer import Visualizer 106 | 107 | storage = get_event_storage() 108 | max_vis_prop = 20 109 | 110 | for input, prop in zip(batched_inputs, proposals): 111 | img = input["image"] 112 | img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format) 113 | v_gt = Visualizer(img, None) 114 | v_gt = v_gt.overlay_instances(boxes=input["instances"].gt_boxes) 115 | anno_img = v_gt.get_image() 116 | box_size = min(len(prop.proposal_boxes), max_vis_prop) 117 | v_pred = Visualizer(img, None) 118 | v_pred = v_pred.overlay_instances( 119 | boxes=prop.proposal_boxes[0:box_size].tensor.cpu().numpy() 120 | ) 121 | prop_img = v_pred.get_image() 122 | vis_img = np.concatenate((anno_img, prop_img), axis=1) 123 | vis_img = vis_img.transpose(2, 0, 1) 124 | vis_name = "Left: GT bounding boxes; Right: Predicted proposals" 125 | storage.put_image(vis_name, vis_img) 126 | break # only visualize one image in a batch 127 | 128 | def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): 129 | """ 130 | Args: 131 | batched_inputs: a list, batched outputs of :class:`DatasetMapper` . 132 | Each item in the list contains the inputs for one image. 133 | For now, each item in the list is a dict that contains: 134 | 135 | * image: Tensor, image in (C, H, W) format. 136 | * instances (optional): groundtruth :class:`Instances` 137 | * proposals (optional): :class:`Instances`, precomputed proposals. 138 | 139 | Other information that's included in the original dicts, such as: 140 | 141 | * "height", "width" (int): the output resolution of the model, used in inference. 142 | See :meth:`postprocess` for details. 143 | 144 | Returns: 145 | list[dict]: 146 | Each dict is the output for one input image. 147 | The dict contains one key "instances" whose value is a :class:`Instances`. 
148 | The :class:`Instances` object has the following keys: 149 | "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints" 150 | """ 151 | if not self.training: 152 | return self.inference(batched_inputs) 153 | 154 | images = self.preprocess_image(batched_inputs) 155 | if "instances" in batched_inputs[0]: 156 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 157 | else: 158 | gt_instances = None 159 | 160 | features = self.backbone(images.tensor) 161 | 162 | if self.proposal_generator is not None: 163 | proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 164 | else: 165 | assert "proposals" in batched_inputs[0] 166 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 167 | proposal_losses = {} 168 | 169 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances) 170 | if self.vis_period > 0: 171 | storage = get_event_storage() 172 | if storage.iter % self.vis_period == 0: 173 | self.visualize_training(batched_inputs, proposals) 174 | 175 | losses = {} 176 | losses.update(detector_losses) 177 | losses.update(proposal_losses) 178 | return losses 179 | 180 | def inference( 181 | self, 182 | batched_inputs: List[Dict[str, torch.Tensor]], 183 | detected_instances: Optional[List[Instances]] = None, 184 | do_postprocess: bool = True, 185 | ): 186 | """ 187 | Run inference on the given inputs. 188 | 189 | Args: 190 | batched_inputs (list[dict]): same as in :meth:`forward` 191 | detected_instances (None or list[Instances]): if not None, it 192 | contains an `Instances` object per image. The `Instances` 193 | object contains "pred_boxes" and "pred_classes" which are 194 | known boxes in the image. 195 | The inference will then skip the detection of bounding boxes, 196 | and only predict other per-ROI outputs. 197 | do_postprocess (bool): whether to apply post-processing on the outputs. 198 | 199 | Returns: 200 | When do_postprocess=True, same as in :meth:`forward`. 201 | Otherwise, a list[Instances] containing raw network outputs. 202 | """ 203 | assert not self.training 204 | 205 | images = self.preprocess_image(batched_inputs) 206 | features = self.backbone(images.tensor) 207 | 208 | if detected_instances is None: 209 | if self.proposal_generator is not None: 210 | proposals, _ = self.proposal_generator(images, features, None) 211 | else: 212 | assert "proposals" in batched_inputs[0] 213 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 214 | 215 | results, _ = self.roi_heads(images, features, proposals, None) 216 | else: 217 | detected_instances = [x.to(self.device) for x in detected_instances] 218 | results = self.roi_heads.forward_with_given_boxes(features, detected_instances) 219 | 220 | if do_postprocess: 221 | assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess." 222 | return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes) 223 | else: 224 | return results 225 | 226 | def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]): 227 | """ 228 | Normalize, pad and batch the input images. 
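Each image is standardized with ``pixel_mean`` / ``pixel_std``, and the batch is padded so that its spatial sides are divisible by ``self.backbone.size_divisibility``.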
229 | """ 230 | #print(batched_inputs) 231 | images = [self._move_to_current_device(x["image"]) for x in batched_inputs] 232 | #print(batched_inputs) 233 | #print(type(batched_inputs)) 234 | #images = [self._move_to_current_device(batched_inputs["image"]) ] 235 | images = [(x - self.pixel_mean) / self.pixel_std for x in images] 236 | images = ImageList.from_tensors( 237 | images, 238 | self.backbone.size_divisibility, 239 | padding_constraints=self.backbone.padding_constraints, 240 | ) 241 | return images 242 | 243 | @staticmethod 244 | def _postprocess(instances, batched_inputs: List[Dict[str, torch.Tensor]], image_sizes): 245 | """ 246 | Rescale the output instances to the target size. 247 | """ 248 | # note: private function; subject to changes 249 | processed_results = [] 250 | #batched_inputs = (batched_inputs,) 251 | for results_per_image, input_per_image, image_size in zip( 252 | instances, batched_inputs, image_sizes 253 | ): 254 | height = input_per_image.get("height", image_size[0]) 255 | width = input_per_image.get("width", image_size[1]) 256 | r = detector_postprocess(results_per_image, height, width) 257 | processed_results.append({"instances": r}) 258 | return processed_results 259 | 260 | 261 | @META_ARCH_REGISTRY.register() 262 | class ProposalNetwork(nn.Module): 263 | """ 264 | A meta architecture that only predicts object proposals. 265 | """ 266 | 267 | @configurable 268 | def __init__( 269 | self, 270 | *, 271 | backbone: Backbone, 272 | proposal_generator: nn.Module, 273 | pixel_mean: Tuple[float], 274 | pixel_std: Tuple[float], 275 | ): 276 | """ 277 | Args: 278 | backbone: a backbone module, must follow detectron2's backbone interface 279 | proposal_generator: a module that generates proposals using backbone features 280 | pixel_mean, pixel_std: list or tuple with #channels element, representing 281 | the per-channel mean and std to be used to normalize the input image 282 | """ 283 | super().__init__() 284 | self.backbone = backbone 285 | self.proposal_generator = proposal_generator 286 | self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False) 287 | self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False) 288 | 289 | @classmethod 290 | def from_config(cls, cfg): 291 | backbone = build_backbone(cfg) 292 | return { 293 | "backbone": backbone, 294 | "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()), 295 | "pixel_mean": cfg.MODEL.PIXEL_MEAN, 296 | "pixel_std": cfg.MODEL.PIXEL_STD, 297 | } 298 | 299 | @property 300 | def device(self): 301 | return self.pixel_mean.device 302 | 303 | def _move_to_current_device(self, x): 304 | return move_device_like(x, self.pixel_mean) 305 | 306 | def forward(self, batched_inputs): 307 | """ 308 | Args: 309 | Same as in :class:`GeneralizedRCNN.forward` 310 | 311 | Returns: 312 | list[dict]: 313 | Each dict is the output for one input image. 314 | The dict contains one key "proposals" whose value is a 315 | :class:`Instances` with keys "proposal_boxes" and "objectness_logits". 
316 | """ 317 | images = [self._move_to_current_device(x["image"]) for x in batched_inputs] 318 | images = [(x - self.pixel_mean) / self.pixel_std for x in images] 319 | images = ImageList.from_tensors( 320 | images, 321 | self.backbone.size_divisibility, 322 | padding_constraints=self.backbone.padding_constraints, 323 | ) 324 | features = self.backbone(images.tensor) 325 | 326 | if "instances" in batched_inputs[0]: 327 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 328 | elif "targets" in batched_inputs[0]: 329 | log_first_n( 330 | logging.WARN, "'targets' in the model inputs is now renamed to 'instances'!", n=10 331 | ) 332 | gt_instances = [x["targets"].to(self.device) for x in batched_inputs] 333 | else: 334 | gt_instances = None 335 | proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 336 | # In training, the proposals are not useful at all but we generate them anyway. 337 | # This makes RPN-only models about 5% slower. 338 | if self.training: 339 | return proposal_losses 340 | 341 | processed_results = [] 342 | for results_per_image, input_per_image, image_size in zip( 343 | proposals, batched_inputs, images.image_sizes 344 | ): 345 | height = input_per_image.get("height", image_size[0]) 346 | width = input_per_image.get("width", image_size[1]) 347 | r = detector_postprocess(results_per_image, height, width) 348 | processed_results.append({"proposals": r}) 349 | return processed_results 350 | -------------------------------------------------------------------------------- /tpnet/modeling/roi_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .roi_heads import ( 4 | ROI_HEADS_REGISTRY, 5 | ROIHeads, 6 | Res5ROIHeads, 7 | CustomStandardROIHeads, 8 | build_roi_heads, 9 | select_foreground_proposals, 10 | ) 11 | from .custom_cascade_rcnn import CustomCascadeROIHeads 12 | from .fast_rcnn import FastRCNNOutputLayers 13 | 14 | from . import custom_cascade_rcnn # isort:skip 15 | 16 | __all__ = list(globals().keys()) 17 | -------------------------------------------------------------------------------- /tpnet/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .build import build_lr_scheduler, build_optimizer, get_default_optimizer_params 4 | 5 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 6 | -------------------------------------------------------------------------------- /tpnet/solver/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/solver/build.py 3 | 4 | import copy 5 | import itertools 6 | import logging 7 | from collections import defaultdict 8 | from enum import Enum 9 | from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union 10 | import torch 11 | from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler 12 | 13 | from detectron2.config import CfgNode 14 | 15 | from detectron2.solver.lr_scheduler import LRMultiplier, WarmupParamScheduler 16 | 17 | _GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]] 18 | _GradientClipper = Callable[[_GradientClipperInput], None] 19 | 20 | 21 | class GradientClipType(Enum): 22 | VALUE = "value" 23 | NORM = "norm" 24 | 25 | 26 | def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper: 27 | """ 28 | Creates gradient clipping closure to clip by value or by norm, 29 | according to the provided config. 30 | """ 31 | cfg = copy.deepcopy(cfg) 32 | 33 | def clip_grad_norm(p: _GradientClipperInput): 34 | torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE) 35 | 36 | def clip_grad_value(p: _GradientClipperInput): 37 | torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE) 38 | 39 | _GRADIENT_CLIP_TYPE_TO_CLIPPER = { 40 | GradientClipType.VALUE: clip_grad_value, 41 | GradientClipType.NORM: clip_grad_norm, 42 | } 43 | return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)] 44 | 45 | 46 | def _generate_optimizer_class_with_gradient_clipping( 47 | optimizer: Type[torch.optim.Optimizer], 48 | *, 49 | per_param_clipper: Optional[_GradientClipper] = None, 50 | global_clipper: Optional[_GradientClipper] = None, 51 | ) -> Type[torch.optim.Optimizer]: 52 | """ 53 | Dynamically creates a new type that inherits the type of a given instance 54 | and overrides the `step` method to add gradient clipping 55 | """ 56 | assert ( 57 | per_param_clipper is None or global_clipper is None 58 | ), "Not allowed to use both per-parameter clipping and global clipping" 59 | 60 | def optimizer_wgc_step(self, closure=None): 61 | if per_param_clipper is not None: 62 | for group in self.param_groups: 63 | for p in group["params"]: 64 | per_param_clipper(p) 65 | else: 66 | # global clipper for future use with detr 67 | # (https://github.com/facebookresearch/detr/pull/287) 68 | all_params = itertools.chain(*[g["params"] for g in self.param_groups]) 69 | global_clipper(all_params) 70 | super(type(self), self).step(closure) 71 | 72 | OptimizerWithGradientClip = type( 73 | optimizer.__name__ + "WithGradientClip", 74 | (optimizer,), 75 | {"step": optimizer_wgc_step}, 76 | ) 77 | return OptimizerWithGradientClip 78 | 79 | 80 | def maybe_add_gradient_clipping( 81 | cfg: CfgNode, optimizer: Type[torch.optim.Optimizer] 82 | ) -> Type[torch.optim.Optimizer]: 83 | """ 84 | If gradient clipping is enabled through config options, wraps the existing 85 | optimizer type to become a new dynamically created class OptimizerWithGradientClip 86 | that inherits the given optimizer and overrides the `step` method to 87 | include gradient clipping. 88 | 89 | Args: 90 | cfg: CfgNode, configuration options 91 | optimizer: type. A subclass of torch.optim.Optimizer 92 | 93 | Return: 94 | type: either the input `optimizer` (if gradient clipping is disabled), or 95 | a subclass of it with gradient clipping included in the `step` method. 
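
    Example:
    ::
        # illustrative only; this mirrors how build_optimizer below uses the helper
        optimizer_cls = maybe_add_gradient_clipping(cfg, torch.optim.SGD)
        optimizer = optimizer_cls(model.parameters(), lr=cfg.SOLVER.BASE_LR, momentum=0.9)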
96 | """ 97 | if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED: 98 | return optimizer 99 | if isinstance(optimizer, torch.optim.Optimizer): 100 | optimizer_type = type(optimizer) 101 | else: 102 | assert issubclass(optimizer, torch.optim.Optimizer), optimizer 103 | optimizer_type = optimizer 104 | 105 | grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS) 106 | OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping( 107 | optimizer_type, per_param_clipper=grad_clipper 108 | ) 109 | if isinstance(optimizer, torch.optim.Optimizer): 110 | optimizer.__class__ = OptimizerWithGradientClip # a bit hacky, not recommended 111 | return optimizer 112 | else: 113 | return OptimizerWithGradientClip 114 | 115 | 116 | def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer: 117 | """ 118 | Build an optimizer from config. 119 | """ 120 | params = get_default_optimizer_params( 121 | model, 122 | base_lr=cfg.SOLVER.BASE_LR, 123 | base_lr_multiplier=cfg.SOLVER.BASE_LR_MULTIPLIER, 124 | base_lr_multiplier_names=cfg.SOLVER.BASE_LR_MULTIPLIER_NAMES, 125 | weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM, 126 | bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR, 127 | weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS, 128 | ) 129 | return maybe_add_gradient_clipping(cfg, torch.optim.SGD)( 130 | params, 131 | lr=cfg.SOLVER.BASE_LR, 132 | momentum=cfg.SOLVER.MOMENTUM, 133 | nesterov=cfg.SOLVER.NESTEROV, 134 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 135 | ) 136 | 137 | 138 | def get_default_optimizer_params( 139 | model: torch.nn.Module, 140 | base_lr: Optional[float] = None, 141 | base_lr_multiplier: Optional[float] = 1.0, 142 | base_lr_multiplier_names: Optional[List[str]] = [], 143 | weight_decay: Optional[float] = None, 144 | weight_decay_norm: Optional[float] = None, 145 | bias_lr_factor: Optional[float] = 1.0, 146 | weight_decay_bias: Optional[float] = None, 147 | lr_factor_func: Optional[Callable] = None, 148 | overrides: Optional[Dict[str, Dict[str, float]]] = None, 149 | ) -> List[Dict[str, Any]]: 150 | """ 151 | Get default param list for optimizer, with support for a few types of 152 | overrides. If no overrides needed, this is equivalent to `model.parameters()`. 153 | 154 | Args: 155 | base_lr: lr for every group by default. Can be omitted to use the one in optimizer. 156 | weight_decay: weight decay for every group by default. Can be omitted to use the one 157 | in optimizer. 158 | weight_decay_norm: override weight decay for params in normalization layers 159 | bias_lr_factor: multiplier of lr for bias parameters. 160 | weight_decay_bias: override weight decay for bias parameters. 161 | lr_factor_func: function to calculate lr decay rate by mapping the parameter names to 162 | corresponding lr decay rate. Note that setting this option requires 163 | also setting ``base_lr``. 164 | overrides: if not `None`, provides values for optimizer hyperparameters 165 | (LR, weight decay) for module parameters with a given name; e.g. 166 | ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and 167 | weight decay values for all module parameters named `embedding`. 168 | 169 | For common detection models, ``weight_decay_norm`` is the only option 170 | needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings 171 | from Detectron1 that are not found useful. 
172 | 173 | Example: 174 | :: 175 | torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0), 176 | lr=0.01, weight_decay=1e-4, momentum=0.9) 177 | """ 178 | if overrides is None: 179 | overrides = {} 180 | defaults = {} 181 | if base_lr is not None: 182 | defaults["lr"] = base_lr 183 | if weight_decay is not None: 184 | defaults["weight_decay"] = weight_decay 185 | bias_overrides = {} 186 | if bias_lr_factor is not None and bias_lr_factor != 1.0: 187 | # NOTE: unlike Detectron v1, we now by default make bias hyperparameters 188 | # exactly the same as regular weights. 189 | if base_lr is None: 190 | raise ValueError("bias_lr_factor requires base_lr") 191 | bias_overrides["lr"] = base_lr * bias_lr_factor 192 | if weight_decay_bias is not None: 193 | bias_overrides["weight_decay"] = weight_decay_bias 194 | if len(bias_overrides): 195 | if "bias" in overrides: 196 | raise ValueError("Conflicting overrides for 'bias'") 197 | overrides["bias"] = bias_overrides 198 | if lr_factor_func is not None: 199 | if base_lr is None: 200 | raise ValueError("lr_factor_func requires base_lr") 201 | norm_module_types = ( 202 | torch.nn.BatchNorm1d, 203 | torch.nn.BatchNorm2d, 204 | torch.nn.BatchNorm3d, 205 | torch.nn.SyncBatchNorm, 206 | # NaiveSyncBatchNorm inherits from BatchNorm2d 207 | torch.nn.GroupNorm, 208 | torch.nn.InstanceNorm1d, 209 | torch.nn.InstanceNorm2d, 210 | torch.nn.InstanceNorm3d, 211 | torch.nn.LayerNorm, 212 | torch.nn.LocalResponseNorm, 213 | ) 214 | params: List[Dict[str, Any]] = [] 215 | memo: Set[torch.nn.parameter.Parameter] = set() 216 | for module_name, module in model.named_modules(): 217 | for module_param_name, value in module.named_parameters(recurse=False): 218 | if not value.requires_grad: 219 | continue 220 | # Avoid duplicating parameters 221 | if value in memo: 222 | continue 223 | memo.add(value) 224 | 225 | hyperparams = copy.copy(defaults) 226 | if isinstance(module, norm_module_types) and weight_decay_norm is not None: 227 | hyperparams["weight_decay"] = weight_decay_norm 228 | if lr_factor_func is not None: 229 | hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}") 230 | hyperparams.update(overrides.get(module_param_name, {})) 231 | if module_name in base_lr_multiplier_names: 232 | hyperparams["lr"] *= base_lr_multiplier 233 | # print(" Checked: ", module_name, hyperparams["lr"]) 234 | 235 | params.append({"params": [value], **hyperparams}) 236 | return reduce_param_groups(params) 237 | 238 | 239 | def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 240 | # Transform parameter groups into per-parameter structure. 241 | # Later items in `params` can overwrite parameters set in previous items. 242 | ret = defaultdict(dict) 243 | for item in params: 244 | assert "params" in item 245 | cur_params = {x: y for x, y in item.items() if x != "params"} 246 | for param in item["params"]: 247 | ret[param].update({"params": [param], **cur_params}) 248 | return list(ret.values()) 249 | 250 | 251 | def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 252 | # Reorganize the parameter groups and merge duplicated groups. 253 | # The number of parameter groups needs to be as small as possible in order 254 | # to efficiently use the PyTorch multi-tensor optimizer. Therefore instead 255 | # of using a parameter_group per single parameter, we reorganize the 256 | # parameter groups and merge duplicated groups. This approach speeds 257 | # up multi-tensor optimizer significantly. 
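    # For illustration (hypothetical parameters w1, w2): after _expand_param_groups,
    # two groups such as
    #   [{"params": [w1], "lr": 0.1, "weight_decay": 1e-4},
    #    {"params": [w2], "lr": 0.1, "weight_decay": 1e-4}]
    # share identical hyperparameters, so they are merged into the single group
    #   {"params": [w1, w2], "lr": 0.1, "weight_decay": 1e-4}.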
258 | params = _expand_param_groups(params) 259 | groups = defaultdict(list) # re-group all parameter groups by their hyperparams 260 | for item in params: 261 | cur_params = tuple((x, y) for x, y in item.items() if x != "params") 262 | groups[cur_params].extend(item["params"]) 263 | ret = [] 264 | for param_keys, param_values in groups.items(): 265 | cur = {kv[0]: kv[1] for kv in param_keys} 266 | cur["params"] = param_values 267 | ret.append(cur) 268 | return ret 269 | 270 | 271 | def build_lr_scheduler( 272 | cfg: CfgNode, optimizer: torch.optim.Optimizer 273 | ) -> torch.optim.lr_scheduler._LRScheduler: 274 | """ 275 | Build a LR scheduler from config. 276 | """ 277 | name = cfg.SOLVER.LR_SCHEDULER_NAME 278 | 279 | if name == "WarmupMultiStepLR": 280 | steps = [x for x in cfg.SOLVER.STEPS if x <= cfg.SOLVER.MAX_ITER] 281 | if len(steps) != len(cfg.SOLVER.STEPS): 282 | logger = logging.getLogger(__name__) 283 | logger.warning( 284 | "SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. " 285 | "These values will be ignored." 286 | ) 287 | sched = MultiStepParamScheduler( 288 | values=[cfg.SOLVER.GAMMA**k for k in range(len(steps) + 1)], 289 | milestones=steps, 290 | num_updates=cfg.SOLVER.MAX_ITER, 291 | ) 292 | elif name == "WarmupCosineLR": 293 | end_value = cfg.SOLVER.BASE_LR_END / cfg.SOLVER.BASE_LR 294 | assert end_value >= 0.0 and end_value <= 1.0, end_value 295 | sched = CosineParamScheduler(1, end_value) 296 | else: 297 | raise ValueError("Unknown LR scheduler: {}".format(name)) 298 | 299 | sched = WarmupParamScheduler( 300 | sched, 301 | cfg.SOLVER.WARMUP_FACTOR, 302 | min(cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, 1.0), 303 | cfg.SOLVER.WARMUP_METHOD, 304 | ) 305 | return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER) 306 | -------------------------------------------------------------------------------- /tpnet/structures/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .boxes import pairwise_iou_max_scores 4 | 5 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 6 | 7 | 8 | from detectron2.utils.env import fixup_module_metadata 9 | 10 | fixup_module_metadata(__name__, globals(), __all__) 11 | del fixup_module_metadata 12 | -------------------------------------------------------------------------------- /tpnet/structures/boxes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/structures/boxes.py 3 | 4 | import torch 5 | 6 | def pairwise_iou_max_scores(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor: 7 | """ 8 | Given two lists of boxes of size N and M, compute the IoU 9 | (intersection over union) between **all** N x M pairs of boxes. 10 | The box order must be (xmin, ymin, xmax, ymax). 11 | 12 | Args: 13 | boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively. 14 | 15 | Returns: 16 | Tensor: IoU, sized [N,M]. 
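
    Note: the full [N, M] IoU matrix is computed internally and then reduced over
    ``boxes2`` (``dim=1``), so the value actually returned has shape [N]: for each
    box in ``boxes1``, its maximum IoU with any box in ``boxes2``.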
17 | """ 18 | area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) # [N] 19 | area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) # [M] 20 | 21 | width_height = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) - torch.max( 22 | boxes1[:, None, :2], boxes2[:, :2] 23 | ) # [N,M,2] 24 | 25 | width_height.clamp_(min=0) # [N,M,2] 26 | inter = width_height.prod(dim=2) # [N,M] 27 | 28 | # handle empty boxes 29 | iou = torch.where( 30 | inter > 0, 31 | inter / (area1[:, None] + area2 - inter), 32 | torch.zeros(1, dtype=inter.dtype, device=inter.device), 33 | ) 34 | iou_max, _ = torch.max(iou, dim=1) 35 | return iou_max -------------------------------------------------------------------------------- /tpnet/tools/eval.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # link to the dataset folder, model weights and the config file. 4 | export DETECTRON2_DATASETS=/path/to/DETECTRON2_DATASETS/ 5 | model_weights="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth" 6 | config_file="model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml" 7 | num_gpus=2 8 | 9 | echo "========== start evaluating the model on all 11 datasets ==========" 10 | 11 | test_dataset='cls_agnostic_clipart' 12 | echo "========== evaluating ${test_dataset} ==========" 13 | python train_net.py --num-gpus ${num_gpus} \ 14 | --config-file ${config_file} \ 15 | --test-dataset ${test_dataset} --no-segm \ 16 | --eval-only MODEL.WEIGHTS ${model_weights} 17 | 18 | test_dataset='cls_agnostic_watercolor' 19 | echo "========== evaluating ${test_dataset} ==========" 20 | python train_net.py --num-gpus ${num_gpus} \ 21 | --config-file ${config_file} \ 22 | --test-dataset ${test_dataset} --no-segm \ 23 | --eval-only MODEL.WEIGHTS ${model_weights} 24 | 25 | test_dataset='cls_agnostic_comic' 26 | echo "========== evaluating ${test_dataset} ==========" 27 | python train_net.py --num-gpus ${num_gpus} \ 28 | --config-file ${config_file} \ 29 | --test-dataset ${test_dataset} --no-segm \ 30 | --eval-only MODEL.WEIGHTS ${model_weights} 31 | 32 | test_dataset='cls_agnostic_voc' 33 | echo "========== evaluating ${test_dataset} ==========" 34 | python train_net.py --num-gpus ${num_gpus} \ 35 | --config-file ${config_file} \ 36 | --test-dataset ${test_dataset} --no-segm \ 37 | --eval-only MODEL.WEIGHTS ${model_weights} 38 | 39 | test_dataset='cls_agnostic_objects365' 40 | echo "========== evaluating ${test_dataset} ==========" 41 | python train_net.py --num-gpus ${num_gpus} \ 42 | --config-file ${config_file} \ 43 | --test-dataset ${test_dataset} --no-segm \ 44 | --eval-only MODEL.WEIGHTS ${model_weights} 45 | 46 | test_dataset='cls_agnostic_openimages' 47 | echo "========== evaluating ${test_dataset} ==========" 48 | python train_net.py --num-gpus ${num_gpus} \ 49 | --config-file ${config_file} \ 50 | --test-dataset ${test_dataset} --no-segm \ 51 | --eval-only MODEL.WEIGHTS ${model_weights} 52 | 53 | test_dataset='cls_agnostic_kitti' 54 | echo "========== evaluating ${test_dataset} ==========" 55 | python train_net.py --num-gpus ${num_gpus} \ 56 | --config-file ${config_file} \ 57 | --test-dataset ${test_dataset} --no-segm \ 58 | --eval-only MODEL.WEIGHTS ${model_weights} 59 | 60 | test_dataset='cls_agnostic_coco' 61 | echo "========== evaluating ${test_dataset} ==========" 62 | python train_net.py --num-gpus ${num_gpus} \ 63 | --config-file ${config_file} \ 64 | --test-dataset ${test_dataset} \ 
65 | --eval-only MODEL.WEIGHTS ${model_weights} 66 | 67 | test_dataset='cls_agnostic_coco20k' 68 | echo "========== evaluating ${test_dataset} ==========" 69 | python train_net.py --num-gpus ${num_gpus} \ 70 | --config-file ${config_file} \ 71 | --test-dataset ${test_dataset} \ 72 | --eval-only MODEL.WEIGHTS ${model_weights} 73 | 74 | test_dataset='cls_agnostic_lvis' 75 | echo "========== evaluating ${test_dataset} ==========" 76 | # LVIS should set TEST.DETECTIONS_PER_IMAGE=300 77 | python train_net.py --num-gpus ${num_gpus} \ 78 | --config-file ${config_file} \ 79 | --test-dataset ${test_dataset} \ 80 | --eval-only MODEL.WEIGHTS ${model_weights} TEST.DETECTIONS_PER_IMAGE 300 81 | 82 | test_dataset='cls_agnostic_uvo' 83 | echo "========== evaluating ${test_dataset} ==========" 84 | python train_net.py --num-gpus ${num_gpus} \ 85 | --config-file ${config_file} \ 86 | --test-dataset ${test_dataset} \ 87 | --eval-only MODEL.WEIGHTS ${model_weights} 88 | 89 | echo "========== evaluation is completed ==========" -------------------------------------------------------------------------------- /tpnet/tools/get_self_training_ann.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | import os 5 | import json 6 | import tqdm 7 | import torch 8 | import datetime 9 | import argparse 10 | import pycocotools.mask as cocomask 11 | from detectron2.utils.file_io import PathManager 12 | 13 | INFO = { 14 | "description": "ImageNet-1K: Self-train", 15 | "url": "", 16 | "version": "1.0", 17 | "year": 2022, 18 | "contributor": "Xudong Wang", 19 | "date_created": datetime.datetime.utcnow().isoformat(' ') 20 | } 21 | 22 | LICENSES = [ 23 | { 24 | "id": 1, 25 | "name": "Apache License", 26 | "url": "https://github.com/facebookresearch/CutLER/blob/main/LICENSE" 27 | } 28 | ] 29 | 30 | CATEGORIES = [ 31 | { 32 | 'id': 1, 33 | 'name': 'fg', 34 | 'supercategory': 'fg', 35 | }, 36 | ] 37 | 38 | new_dict_filtered = { 39 | "info": INFO, 40 | "licenses": LICENSES, 41 | "categories": CATEGORIES, 42 | "images": [], 43 | "annotations": [] 44 | } 45 | 46 | category_info = { 47 | "is_crowd": 0, 48 | "id": 1 49 | } 50 | 51 | 52 | def segmToRLE(segm, h, w): 53 | if isinstance(segm, list): 54 | # polygon -- a single object might consist of multiple parts 55 | # we merge all parts into one mask rle code 56 | rles = cocomask.frPyObjects(segm, h, w) 57 | rle = cocomask.merge(rles) 58 | elif isinstance(segm["counts"], list): 59 | # uncompressed RLE 60 | rle = cocomask.frPyObjects(segm, h, w) 61 | else: 62 | # rle 63 | rle = segm 64 | return rle 65 | 66 | def rle2mask(rle, height, width): 67 | if "counts" in rle and isinstance(rle["counts"], list): 68 | # if compact RLE, ignore this conversion 69 | # Magic RLE format handling painfully discovered by looking at the 70 | # COCO API showAnns function. 
71 | rle = cocomask.frPyObjects(rle, height, width) 72 | mask = cocomask.decode(rle) 73 | return mask 74 | 75 | def cocosegm2mask(segm, h, w): 76 | rle = segmToRLE(segm, h, w) 77 | mask = rle2mask(rle, h, w) 78 | return mask 79 | 80 | def BatchIoU(masks1, masks2): 81 | n1, n2 = masks1.size()[0], masks2.size()[0] 82 | masks1, masks2 = (masks1>0.5).to(torch.bool), (masks2>0.5).to(torch.bool) 83 | masks1_ = masks1[:,None,:,:,].expand(-1, n2, -1, -1) 84 | masks2_ = masks2[None,:,:,:,].expand(n1, -1, -1, -1) 85 | 86 | intersection = torch.sum(masks1_ * (masks1_ == masks2_), dim=[-1, -2]) 87 | union = torch.sum(masks1_ + masks2_, dim=[-1, -2]) 88 | ious = intersection.to(torch.float) / union 89 | return ious 90 | 91 | if __name__ == "__main__": 92 | # load model arguments 93 | parser = argparse.ArgumentParser(description='Generate json files for the self-training') 94 | parser.add_argument('--new-pred', type=str, 95 | default='output/inference/coco_instances_results.json', 96 | help='Path to model predictions') 97 | parser.add_argument('--prev-ann', type=str, 98 | default='DETECTRON2_DATASETS/imagenet/annotations/cutler_imagenet1k_train.json', 99 | help='Path to annotations in the previous round') 100 | parser.add_argument('--save-path', type=str, 101 | default='DETECTRON2_DATASETS/imagenet/annotations/cutler_imagenet1k_train_r1.json', 102 | help='Path to save the generated annotation file') 103 | # parser.add_argument('--n-rounds', type=int, default=1, 104 | # help='N-th round of self-training') 105 | parser.add_argument('--threshold', type=float, default=0.7, 106 | help='Confidence score thresholds') 107 | args = parser.parse_args() 108 | 109 | # load model predictions 110 | new_pred = args.new_pred 111 | with PathManager.open(new_pred, "r") as f: 112 | predictions = json.load(f) 113 | 114 | # filter out low-confidence model predictions 115 | THRESHOLD = args.threshold 116 | pred_image_to_anns = {} 117 | for id, ann in enumerate(predictions): 118 | confidence_score = ann['score'] 119 | if confidence_score >= THRESHOLD: 120 | if ann['image_id'] in pred_image_to_anns: 121 | pred_image_to_anns[ann['image_id']].append(ann) 122 | else: 123 | pred_image_to_anns[ann['image_id']] = [ann] 124 | 125 | # load psedu-masks used by the previous round 126 | pseudo_ann_dict = json.load(open(args.prev_ann)) 127 | pseudo_image_list = pseudo_ann_dict['images'] 128 | pseudo_annotations = pseudo_ann_dict['annotations'] 129 | 130 | pseudo_image_to_anns = {} 131 | for id, ann in enumerate(pseudo_annotations): 132 | if ann['image_id'] in pseudo_image_to_anns: 133 | pseudo_image_to_anns[ann['image_id']].append(ann) 134 | else: 135 | pseudo_image_to_anns[ann['image_id']] = [ann] 136 | 137 | # merge model predictions and the json file used by the previous round. 
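    # Merging rule implemented below: every confident prediction is kept, and a
    # previous-round pseudo annotation survives only if its best IoU with all new
    # predictions for that image stays below 0.5; images without any confident
    # prediction keep their pseudo annotations unchanged.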
138 | merged_anns = [] 139 | num_preds, num_pseudo = 0, 0 140 | for k, anns_pseudo in tqdm.tqdm(pseudo_image_to_anns.items()): 141 | masks = [] 142 | for ann in anns_pseudo: 143 | segm = ann['segmentation'] 144 | mask = cocosegm2mask(segm, segm['size'][0], segm['size'][1]) 145 | masks.append(torch.from_numpy(mask)) 146 | pseudo_masks = torch.stack(masks, dim=0).cuda() 147 | del masks 148 | num_pseudo += len(anns_pseudo) 149 | try: 150 | anns_pred = pred_image_to_anns[k] 151 | except: 152 | merged_anns += anns_pseudo 153 | continue 154 | masks = [] 155 | for ann in anns_pred: 156 | segm = ann['segmentation'] 157 | mask = cocosegm2mask(segm, segm['size'][0], segm['size'][1]) 158 | masks.append(torch.from_numpy(mask)) 159 | pred_masks = torch.stack(masks, dim=0).cuda() 160 | num_preds += len(anns_pred) 161 | try: 162 | ious = BatchIoU(pseudo_masks, pred_masks) 163 | iou_max, _ = ious.max(dim=1) 164 | selected_index = (iou_max < 0.5).nonzero() 165 | selected_pseudo = [anns_pseudo[i] for i in selected_index] 166 | merged_anns += anns_pred + selected_pseudo 167 | # if num_preds % 200000 == 0: 168 | # print(len(merged_anns), num_preds, num_pseudo) 169 | except: 170 | merged_anns += anns_pseudo 171 | 172 | for key in pred_image_to_anns: 173 | if key in pseudo_image_to_anns: 174 | continue 175 | else: 176 | merged_anns += pred_image_to_anns[key] 177 | 178 | # re-generate annotation id 179 | ann_id = 1 180 | for ann in merged_anns: 181 | ann['id'] = ann_id 182 | ann['area'] = ann['bbox'][-1] * ann['bbox'][-2] 183 | ann['iscrowd'] = 0 184 | ann['width'] = ann['segmentation']['size'][0] 185 | ann['height'] = ann['segmentation']['size'][1] 186 | ann_id += 1 187 | 188 | new_dict_filtered['images'] = pseudo_image_list 189 | new_dict_filtered['annotations'] = merged_anns 190 | 191 | # save annotation file 192 | # save_path = os.path.join(args.save_path, "cutler_imagenet1k_train_r{}.json".format(args.n_rounds)) 193 | json.dump(new_dict_filtered, open(args.save_path, 'w')) 194 | print("Done: {} images; {} anns.".format(len(new_dict_filtered['images']), len(new_dict_filtered['annotations']))) -------------------------------------------------------------------------------- /tpnet/tools/run_with_submitit.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | sbatch tools/train-1node.sh \ 3 | --config-file model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml \ 4 | OUTPUT_DIR /path/to/output -------------------------------------------------------------------------------- /tpnet/tools/run_with_submitit_ssl.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | sbatch tools/train-1node.sh \ 3 | --config-file /private/home/xudongw/cutler-code-release/CutLER/cutler/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_50perc.yaml \ -------------------------------------------------------------------------------- /tpnet/tools/single-node_run.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | #!/bin/bash 3 | export DETECTRON2_DATASETS=/path/to/DETECTRON2_DATASETS/ 4 | MASTER_NODE=$(scontrol show hostname "$SLURM_NODELIST" | head -n1) 5 | DIST_URL="tcp://$MASTER_NODE:12399" 6 | SOCKET_NAME=$(ip r | grep default | awk '{print $5}') 7 | export GLOO_SOCKET_IFNAME=$SOCKET_NAME 8 | 9 | python -u train_net.py --num-gpus 8 --num-machines 1 --machine-rank "$SLURM_NODEID" --dist-url "$DIST_URL" "$@" -------------------------------------------------------------------------------- /tpnet/tools/train-1node.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | #!/bin/bash 3 | #SBATCH -p devlab 4 | #SBATCH --nodes=1 5 | #SBATCH --gres=gpu:8 6 | #SBATCH --gpus-per-node=8 7 | #SBATCH --cpus-per-task=80 8 | #SBATCH --mem=512G 9 | #SBATCH --time 2000 10 | #SBATCH -o "submitit/slurm-%j.out" 11 | 12 | srun tools/single-node_run.sh $@ -------------------------------------------------------------------------------- /tpnet/train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/tools/train_net.py 4 | 5 | """ 6 | A main training script. 7 | 8 | This scripts reads a given config file and runs the training or evaluation. 9 | It is an entry point that is made to train standard models in detectron2. 10 | 11 | In order to let one script support training of many models, 12 | this script contains logic that are specific to these built-in models and therefore 13 | may not be suitable for your own project. 14 | For example, your research project perhaps only needs a single "evaluator". 15 | 16 | Therefore, we recommend you to use detectron2 as an library and take 17 | this file as an example of how to use the library. 18 | You may want to write your own script with your datasets and other customizations. 19 | """ 20 | 21 | import logging 22 | import os 23 | from collections import OrderedDict 24 | import torch 25 | import detectron2.utils.comm as comm 26 | import detectron2.utils.analysis as ana 27 | from detectron2.checkpoint import DetectionCheckpointer 28 | from detectron2.config import get_cfg 29 | from config import add_cutler_config 30 | from detectron2.data import MetadataCatalog 31 | from engine import DefaultTrainer, default_argument_parser, default_setup 32 | from detectron2.engine import hooks, launch 33 | from detectron2.evaluation import ( 34 | CityscapesInstanceEvaluator, 35 | CityscapesSemSegEvaluator, 36 | # COCOEvaluator, 37 | COCOPanopticEvaluator, 38 | DatasetEvaluators, 39 | LVISEvaluator, 40 | PascalVOCDetectionEvaluator, 41 | SemSegEvaluator, 42 | verify_results, 43 | ) 44 | from evaluation import COCOEvaluator 45 | from detectron2.modeling import GeneralizedRCNNWithTTA 46 | import data # register new datasets 47 | import modeling.roi_heads 48 | 49 | def build_evaluator(cfg, dataset_name, output_folder=None): 50 | """ 51 | Create evaluator(s) for a given dataset. 52 | This uses the special metadata "evaluator_type" associated with each builtin dataset. 53 | For your own dataset, you can simply create an evaluator manually in your 54 | script and do not have to worry about the hacky if-else logic here. 
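    For example, a COCO-style dataset can usually be handled with just
    ``COCOEvaluator(dataset_name, output_dir=output_folder)`` (the wrapper imported
    from ``evaluation`` above).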
55 | """ 56 | if output_folder is None: 57 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 58 | evaluator_list = [] 59 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 60 | if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: 61 | evaluator_list.append( 62 | SemSegEvaluator( 63 | dataset_name, 64 | distributed=True, 65 | output_dir=output_folder, 66 | ) 67 | ) 68 | if evaluator_type in ["coco", "coco_panoptic_seg"]: 69 | evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder, no_segm=cfg.TEST.NO_SEGM)) 70 | if evaluator_type == "coco_panoptic_seg": 71 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) 72 | if evaluator_type == "cityscapes_instance": 73 | return CityscapesInstanceEvaluator(dataset_name) 74 | if evaluator_type == "cityscapes_sem_seg": 75 | return CityscapesSemSegEvaluator(dataset_name) 76 | elif evaluator_type == "pascal_voc": 77 | return PascalVOCDetectionEvaluator(dataset_name) 78 | elif evaluator_type == "lvis": 79 | return LVISEvaluator(dataset_name, output_dir=output_folder) 80 | if len(evaluator_list) == 0: 81 | raise NotImplementedError( 82 | "no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type) 83 | ) 84 | elif len(evaluator_list) == 1: 85 | return evaluator_list[0] 86 | return DatasetEvaluators(evaluator_list) 87 | 88 | class Trainer(DefaultTrainer): 89 | """ 90 | We use the "DefaultTrainer" which contains pre-defined default logic for 91 | standard training workflow. They may not work for you, especially if you 92 | are working on a new research project. In that case you can write your 93 | own training loop. You can use "tools/plain_train_net.py" as an example. 94 | """ 95 | 96 | @classmethod 97 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 98 | return build_evaluator(cfg, dataset_name, output_folder) 99 | 100 | @classmethod 101 | def test_with_TTA(cls, cfg, model): 102 | logger = logging.getLogger("detectron2.trainer") 103 | # In the end of training, run an evaluation with TTA 104 | # Only support some R-CNN models. 105 | logger.info("Running inference with test-time augmentation ...") 106 | model = GeneralizedRCNNWithTTA(cfg, model) 107 | evaluators = [ 108 | cls.build_evaluator( 109 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") 110 | ) 111 | for name in cfg.DATASETS.TEST 112 | ] 113 | res = cls.test(cfg, model, evaluators) 114 | res = OrderedDict({k + "_TTA": v for k, v in res.items()}) 115 | return res 116 | 117 | 118 | def setup(args): 119 | """ 120 | Create configs and perform basic setups. 
121 | """ 122 | cfg = get_cfg() 123 | add_cutler_config(cfg) 124 | cfg.merge_from_file(args.config_file) 125 | cfg.merge_from_list(args.opts) 126 | # FIXME: brute force changes to test datasets and evaluation tasks 127 | if args.test_dataset != "": cfg.DATASETS.TEST = ((args.test_dataset),) 128 | if args.train_dataset != "": cfg.DATASETS.TRAIN = ((args.train_dataset),) 129 | cfg.TEST.NO_SEGM = args.no_segm 130 | cfg.freeze() 131 | default_setup(cfg, args) 132 | return cfg 133 | 134 | 135 | def main(args): 136 | cfg = setup(args) 137 | 138 | if args.eval_only: 139 | model = Trainer.build_model(cfg) 140 | #print(type(model)) 141 | #tensor = torch.rand(3, 1024, 1024) 142 | #x = ({"image":tensor},) 143 | # 144 | model.eval() 145 | #ret = ana.FlopCountAnalysis(model, x) 146 | ##ret = ana.flop_count_operators(model,x) 147 | # 148 | #print("FLOPs: ", ret.total()) 149 | #print(ret) 150 | #1/0 151 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 152 | cfg.MODEL.WEIGHTS, resume=args.resume 153 | ) 154 | res = Trainer.test(cfg, model) 155 | if cfg.TEST.AUG.ENABLED: 156 | res.update(Trainer.test_with_TTA(cfg, model)) 157 | if comm.is_main_process(): 158 | verify_results(cfg, res) 159 | return res 160 | 161 | """ 162 | If you'd like to do anything fancier than the standard training logic, 163 | consider writing your own training loop (see plain_train_net.py) or 164 | subclassing the trainer. 165 | """ 166 | trainer = Trainer(cfg) 167 | trainer.resume_or_load(resume=args.resume) 168 | if cfg.TEST.AUG.ENABLED: 169 | trainer.register_hooks( 170 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))] 171 | ) 172 | return trainer.train() 173 | 174 | 175 | if __name__ == "__main__": 176 | args = default_argument_parser().parse_args() 177 | # print(args) 178 | # args.opts = postprocess_args(args.opts) 179 | # rint = random.randint(0, 10000) 180 | # args.dist_url = args.dist_url.replace('12399', str(12399 + rint)) 181 | print("Command Line Args:", args) 182 | launch( 183 | main, 184 | args.num_gpus, 185 | num_machines=args.num_machines, 186 | machine_rank=args.machine_rank, 187 | dist_url=args.dist_url, 188 | args=(args,), 189 | ) 190 | --------------------------------------------------------------------------------