├── .gitignore ├── README.md ├── clip ├── __init__.py ├── adapters.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── config ├── __init__.py └── defaults.py ├── configs ├── coco.yml └── nus.yml ├── datasets ├── MultiLabel │ └── classification.py ├── __init__.py ├── bases.py ├── json │ ├── test_17_filtered.json │ ├── test_65_filtered.json │ └── train_48_filtered.json └── make_dataloader.py ├── loss ├── __init__.py ├── asymmetric_loss.py ├── mmc_loss.py └── seesawloss.py ├── model ├── __init__.py ├── base.py ├── model.py └── ot_solver.py ├── processor ├── __init__.py └── processor.py ├── requirements.txt ├── solver ├── __init__.py ├── make_optimizer.py └── make_scheduler.py ├── src └── method.png ├── train.py └── utils ├── __init__.py ├── logger.py ├── meter.py ├── metrics.py └── model_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | 173 | # Temp folders 174 | data/ 175 | wandb/ 176 | checkpoints/ 177 | .vscode/ 178 | 179 | # logging files 180 | runs/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RAM: Open-Vocabulary Multi-Label Recognition through Knowledge-Constrained Optimal Transport 2 | 3 | Official implementation of the [paper](https://arxiv.org/abs/2503.15337) in CVPR 2025: 4 | 5 | **Recover and Match: Open-Vocabulary Multi-Label Recognition through Knowledge-Constrained Optimal Transport** 6 | 7 | ## 📨 Introduction 8 | 9 | RAM is an efficient matching framework for OVMLR (Open-Vocabulary Multi-Label Recognition). To address the urgent problems in existing methods, RAM involves (1) LLA to recover regional semantics, and (2) KCOT to find precise region-to-label matching. 10 | 11 |

12 | ![RAM Framework](src/method.png) 13 |

14 | 15 | 16 | ## 🔧 Installation 17 | 18 | Installing the environment through conda and pip is recommended: 19 | 20 | ```shell 21 | conda create -n ram python=3.10 22 | conda activate ram 23 | 24 | # Install the dependencies 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | 29 | ## 🎯 Running the code 30 | - `model/model.py`: Implementation of RAM model 31 | - `model/ot_solver.py`: Implementation of Sinkhorn Algorithm 32 | - `clip/adapters.py`: Implementation of LLA (Local Adapter) 33 | - `loss/mmc_loss.py`: Implementation of MMC loss (Multi-Matching loss) 34 | 35 | Run the following command to start training: 36 | ```shell 37 | python train.py --config_file configs/coco.yml 38 | ``` 39 | Use wandb to log the training run: 40 | ```shell 41 | python train.py --config_file configs/coco.yml WANDB True 42 | ``` 43 | 44 | ## 💬 Discussion 45 | The core contribution is the OT-based matching pipeline, which we found beneficial to the OVMLR task while remaining highly efficient (a minimal Sinkhorn sketch is provided at the end of this README). 46 | The matching framework can be easily extended to dense prediction tasks (e.g., **semantic segmentation**), and we welcome transferring our approach to segmentation scenarios. 47 | 48 | If you find our work useful, please cite our paper: 49 | 50 | ``` 51 | @article{tan2025recoverandmatch, 52 | title={Recover and Match: Open-Vocabulary Multi-Label Recognition through Knowledge-Constrained Optimal Transport}, 53 | author={Hao Tan and Zichang Tan and Jun Li and Ajian Liu and Jun Wan and Zhen Lei}, 54 | journal={arXiv preprint arXiv:2503.15337}, 55 | year={2025} 56 | } 57 | ``` 58 | 59 | ## Acknowledgements 60 | 61 | This repo benefits from [MaPLe](https://github.com/muzairkhattak/multimodal-prompt-learning), [CLIP-Surgery](https://github.com/xmed-lab/CLIP_Surgery) and [POT](https://github.com/PythonOT/POT). Thanks for their wonderful works.
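For reference (see the Discussion above), here is a minimal, self-contained sketch of the entropic-regularized Sinkhorn iteration that the OT-based matching builds on. It is illustrative only: the tensor names, uniform marginals and fixed iteration count are assumptions, not the exact implementation in `model/ot_solver.py`.

```python
import torch
import torch.nn.functional as F

def sinkhorn(cost, a, b, eps=0.05, n_iters=100):
    # Entropic-regularized OT (Sinkhorn-Knopp): returns a transport plan whose
    # row sums approximate a and column sums approximate b.
    K = torch.exp(-cost / eps)                  # Gibbs kernel
    u, v = torch.ones_like(a), torch.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v + 1e-8)                  # row scaling
        v = b / (K.t() @ u + 1e-8)              # column scaling
    return u.unsqueeze(1) * K * v.unsqueeze(0)  # diag(u) @ K @ diag(v)

# Toy example: match 4 region features to 3 label embeddings.
regions, labels = torch.randn(4, 8), torch.randn(3, 8)
cost = 1 - F.normalize(regions, dim=-1) @ F.normalize(labels, dim=-1).t()
a, b = torch.full((4,), 1 / 4), torch.full((3,), 1 / 3)
plan = sinkhorn(cost, a, b)
print(plan.shape, plan.sum())                   # (4, 3), total mass ~1.0
```

In RAM the cost would come from region-to-label similarities; random features are used here purely to keep the snippet runnable.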
62 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | from .model import CLIP, convert_weights 3 | -------------------------------------------------------------------------------- /clip/adapters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | class h_sigmoid(nn.Module): 8 | def __init__(self, inplace=True): 9 | super(h_sigmoid, self).__init__() 10 | self.relu = nn.ReLU6(inplace=inplace) 11 | 12 | def forward(self, x): 13 | return self.relu(x + 3) / 6 14 | 15 | 16 | class h_swish(nn.Module): 17 | def __init__(self, inplace=True): 18 | super(h_swish, self).__init__() 19 | self.sigmoid = h_sigmoid(inplace=inplace) 20 | 21 | def forward(self, x): 22 | return x * self.sigmoid(x) 23 | 24 | 25 | class QuickGELU(nn.Module): 26 | def forward(self, x: torch.Tensor): 27 | return x * torch.sigmoid(1.702 * x) 28 | 29 | 30 | 31 | class PFCrossAttention(nn.Module): 32 | """ 33 | Parameter-free Cross-Attention 34 | https://arxiv.org/abs/2209.14169 35 | """ 36 | def __init__(self, dim, qk_scale=None, attn_drop=0., proj_drop=0.): 37 | super().__init__() 38 | self.scale = qk_scale or dim ** -0.5 39 | 40 | self.attn_drop = nn.Dropout(attn_drop) 41 | self.proj_drop = nn.Dropout(proj_drop) 42 | 43 | def forward(self, query, kv): 44 | k, v = kv, kv 45 | 46 | attn = (query @ k.transpose(1,2)) * self.scale 47 | attn = attn.softmax(dim=-1) 48 | attn = self.attn_drop(attn) 49 | 50 | query = attn @ v 51 | query = self.proj_drop(query) 52 | 53 | return query 54 | 55 | 56 | 57 | class LocalAdapter(nn.Module): 58 | def __init__(self, in_dim, out_dim, stride=1, hidden_dim=None, kernel_size=3, text_dim=512): 59 | super().__init__() 60 | hidden_dim = text_dim 61 | layers1 = [ 62 | nn.Conv2d(in_dim, 63 | hidden_dim, 64 | kernel_size=1, 65 | stride=stride, 66 | padding=1//2, 67 | bias=False), 68 | nn.BatchNorm2d(hidden_dim), 69 | nn.GELU() 70 | ] 71 | self.conv_adapter_layers1 = nn.Sequential(*layers1) 72 | layers2 = [ 73 | nn.Conv2d(in_dim, 74 | hidden_dim, 75 | kernel_size=3, 76 | stride=stride, 77 | padding=3//2, 78 | bias=False), 79 | nn.BatchNorm2d(hidden_dim), 80 | nn.GELU() 81 | ] 82 | self.conv_adapter_layers2 = nn.Sequential(*layers2) 83 | 84 | self.conv_adapter_layers = nn.Conv2d(2, 2, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, bias=False) 85 | self.conv_adapter_final = nn.Conv2d(hidden_dim, out_dim, kernel_size=1, stride=stride, padding=1//2, bias=False) 86 | 87 | self.adapter_norm_q1 = nn.LayerNorm(text_dim) 88 | self.adapter_norm_q2 = nn.LayerNorm(text_dim) 89 | self.adapter_norm_kv = nn.LayerNorm(text_dim) 90 | self.scale = text_dim ** -0.5 91 | self.adapter_cattn = PFCrossAttention(text_dim) 92 | 93 | def forward(self, x, text_fea=None): 94 | # x: channel first 95 | x_cls, x = x[:1], x[1:] 96 | tok, B, dim = x.shape 97 | H = int(tok**0.5) 98 | x_loc = x.permute(1, 2, 0).reshape(B, dim, H, H) 99 | 100 | # Dual convolutional streams 101 | x_loc1 = self.conv_adapter_layers1(x_loc) 102 | x_loc2 = self.conv_adapter_layers2(x_loc) 103 | x_loc1 = x_loc1.reshape(B, -1, tok) 104 | x_loc2 = x_loc2.reshape(B, -1, tok) 105 | 106 | # Cross attention with text features 107 | x_loc1 = x_loc1 + self.adapter_cattn( 108 | self.adapter_norm_q1(x_loc1.permute(0, 2, 1)), 109 | self.adapter_norm_kv(text_fea.permute(1, 0, 2)) 
110 | ).permute(0, 2, 1) 111 | x_loc2 = x_loc2 + self.adapter_cattn( 112 | self.adapter_norm_q2(x_loc2.permute(0, 2, 1)), 113 | self.adapter_norm_kv(text_fea.permute(1, 0, 2)) 114 | ).permute(0, 2, 1) 115 | 116 | # Reshape and concat 117 | x_loc1 = x_loc1.reshape(B, -1, H, H) 118 | x_loc2 = x_loc2.reshape(B, -1, H, H) 119 | x_loc = torch.cat([x_loc1, x_loc2], dim=1) 120 | 121 | # Max and Average pooling 122 | avg_x = torch.mean(x_loc, dim=1, keepdim=True) 123 | max_x, _ = torch.max(x_loc, dim=1, keepdim=True) 124 | 125 | # Aggregated convolution 126 | agg = torch.cat([avg_x, max_x], dim=1) 127 | y = self.conv_adapter_layers(agg) 128 | y = F.sigmoid(y) 129 | 130 | # Final multiplication and convolution 131 | x = x_loc1 * y[:, 0].unsqueeze(1) + x_loc2 * y[:, 1].unsqueeze(1) 132 | x = self.conv_adapter_final(x) 133 | x = x.reshape(B, -1, tok).permute(2, 0, 1) 134 | x = torch.cat([x_cls, x], dim=0) 135 | 136 | return x 137 | 138 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EricTan7/RAM/477f5051234819a11ca62ec441ecb3ee33abda65/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 36 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 37 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 38 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 39 | } 40 | 41 | 42 | def _download(url: str, root: str = 
os.path.expanduser("/mnt/nas/TrueNas1/pretrain/clip")): 43 | os.makedirs(root, exist_ok=True) 44 | filename = os.path.basename(url) 45 | 46 | expected_sha256 = url.split("/")[-2] 47 | download_target = os.path.join(root, filename) 48 | 49 | if os.path.exists(download_target) and not os.path.isfile(download_target): 50 | raise RuntimeError(f"{download_target} exists and is not a regular file") 51 | 52 | if os.path.isfile(download_target): 53 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 54 | return download_target 55 | else: 56 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 57 | 58 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 59 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 60 | while True: 61 | buffer = source.read(8192) 62 | if not buffer: 63 | break 64 | 65 | output.write(buffer) 66 | loop.update(len(buffer)) 67 | 68 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 69 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 70 | 71 | return download_target 72 | 73 | 74 | def _transform(n_px): 75 | return Compose([ 76 | Resize(n_px, interpolation=BICUBIC), 77 | CenterCrop(n_px), 78 | lambda image: image.convert("RGB"), 79 | ToTensor(), 80 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 81 | ]) 82 | 83 | 84 | def available_models() -> List[str]: 85 | """Returns the names of available CLIP models""" 86 | return list(_MODELS.keys()) 87 | 88 | 89 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False, download_root=None, 90 | input_size=None, design_details=None): 91 | """Load a CLIP model 92 | 93 | Parameters 94 | ---------- 95 | name : str 96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 97 | 98 | device : Union[str, torch.device] 99 | The device to put the loaded model 100 | 101 | jit : bool 102 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 103 | 104 | Returns 105 | ------- 106 | model : torch.nn.Module 107 | The CLIP model 108 | 109 | preprocess : Callable[[PIL.Image], torch.Tensor] 110 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 111 | """ 112 | if name in _MODELS: 113 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 114 | elif os.path.isfile(name): 115 | model_path = name 116 | else: 117 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 118 | 119 | try: 120 | # loading JIT archive 121 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 122 | state_dict = None 123 | except RuntimeError: 124 | # loading saved state dict 125 | if jit: 126 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 127 | jit = False 128 | state_dict = torch.load(model_path, map_location="cpu") 129 | 130 | if not jit: 131 | model = build_model(state_dict or model.state_dict(), input_size, design_details).to(device) 132 | if str(device) == "cpu": 133 | model.float() 134 | return model, _transform(model.visual.input_resolution) 135 | 136 | # patch the device names 137 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 138 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 139 | 140 | def patch_device(module): 141 | try: 142 | graphs = [module.graph] if hasattr(module, "graph") else [] 143 | except RuntimeError: 144 | graphs = [] 145 | 146 | if hasattr(module, "forward1"): 147 | graphs.append(module.forward1.graph) 148 | 149 | for graph in graphs: 150 | for node in graph.findAllNodes("prim::Constant"): 151 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 152 | node.copyAttributes(device_node) 153 | 154 | model.apply(patch_device) 155 | patch_device(model.encode_image) 156 | patch_device(model.encode_text) 157 | 158 | # patch dtype to float32 on CPU 159 | if str(device) == "cpu": 160 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 161 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 162 | float_node = float_input.node() 163 | 164 | def patch_float(module): 165 | try: 166 | graphs = [module.graph] if hasattr(module, "graph") else [] 167 | except RuntimeError: 168 | graphs = [] 169 | 170 | if hasattr(module, "forward1"): 171 | graphs.append(module.forward1.graph) 172 | 173 | for graph in graphs: 174 | for node in graph.findAllNodes("aten::to"): 175 | inputs = list(node.inputs()) 176 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 177 | if inputs[i].node()["value"] == 5: 178 | inputs[i].node().copyAttributes(float_node) 179 | 180 | model.apply(patch_float) 181 | patch_float(model.encode_image) 182 | patch_float(model.encode_text) 183 | 184 | model.float() 185 | 186 | return model, _transform(model.input_resolution.item()) 187 | 188 | 189 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 190 | """ 191 | Returns the tokenized representation of given input string(s) 192 | 193 | Parameters 194 | ---------- 195 | texts : Union[str, List[str]] 196 | An input string or a list of input strings to tokenize 197 | 198 | context_length : int 199 | The context length to use; all CLIP models use 77 as the context length 200 | 201 | truncate: bool 202 | Whether to truncate the text in case its encoding is longer than the context length 203 | 204 | Returns 205 | ------- 206 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 207 | """ 208 | if isinstance(texts, str): 209 | texts = [texts] 210 | 211 | sot_token = _tokenizer.encoder["<|startoftext|>"] 212 | eot_token = _tokenizer.encoder["<|endoftext|>"] 213 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 214 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 215 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 216 | else: 217 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 218 | 219 | for i, tokens in enumerate(all_tokens): 220 | if len(tokens) > context_length: 221 | if 
truncate: 222 | tokens = tokens[:context_length] 223 | tokens[-1] = eot_token 224 | else: 225 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 226 | result[i, :len(tokens)] = torch.tensor(tokens) 227 | 228 | return result 229 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | import math 9 | from torch.utils.checkpoint import checkpoint 10 | from .adapters import LocalAdapter 11 | 12 | 13 | 14 | class LayerNorm(nn.LayerNorm): 15 | """Subclass torch's LayerNorm to handle fp16.""" 16 | 17 | def forward(self, x: torch.Tensor): 18 | orig_type = x.dtype 19 | ret = super().forward(x.type(torch.float32)) 20 | return ret.type(orig_type) 21 | 22 | 23 | class QuickGELU(nn.Module): 24 | def forward(self, x: torch.Tensor): 25 | return x * torch.sigmoid(1.702 * x) 26 | 27 | 28 | 29 | class ModifiedAttention(nn.Module): 30 | def __init__(self, out_dim, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., settings=''): 31 | super().__init__() 32 | self.num_heads = num_heads 33 | head_dim = dim // num_heads 34 | self.scale = qk_scale or head_dim ** -0.5 35 | 36 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 37 | self.attn_drop = nn.Dropout(attn_drop) 38 | self.out_proj = nn.Linear(out_dim, dim) 39 | self.proj_drop = nn.Dropout(proj_drop) 40 | self.settings = settings 41 | 42 | def forward(self, x, n_adapt=-1): 43 | B, N, C = x.shape 44 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 45 | q, k, v = qkv[0], qkv[1], qkv[2] 46 | 47 | # original self-attention 48 | attn = (q @ k.transpose(-2, -1)) * self.scale 49 | attn = attn.softmax(dim=-1) 50 | attn = self.attn_drop(attn) 51 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 52 | x = self.proj_drop(self.out_proj(x)) 53 | 54 | v = v[:, :, :-n_adapt] if n_adapt > 0 else v 55 | attn_loc = (v @ v.transpose(-2, -1)) * self.scale 56 | attn_loc = (attn_loc).softmax(dim=-1) 57 | attn_loc = self.attn_drop(attn_loc) 58 | 59 | x_loc = (attn_loc @ v).transpose(1, 2).reshape(B, -1, C) 60 | x_loc = self.proj_drop(self.out_proj(x_loc)) 61 | 62 | return [x_loc, x] 63 | 64 | 65 | 66 | class Bottleneck(nn.Module): 67 | expansion = 4 68 | 69 | def __init__(self, inplanes, planes, stride=1): 70 | super().__init__() 71 | 72 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 73 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 74 | self.bn1 = nn.BatchNorm2d(planes) 75 | self.relu1 = nn.ReLU(inplace=True) 76 | 77 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 78 | self.bn2 = nn.BatchNorm2d(planes) 79 | self.relu2 = nn.ReLU(inplace=True) 80 | 81 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 82 | 83 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 84 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 85 | self.relu3 = nn.ReLU(inplace=True) 86 | 87 | self.downsample = None 88 | self.stride = stride 89 | 90 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 91 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 92 | self.downsample = nn.Sequential(OrderedDict([ 93 | ("-1", nn.AvgPool2d(stride)), 94 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 95 | ("1", nn.BatchNorm2d(planes * self.expansion)) 96 | ])) 97 | 98 | def forward(self, x: torch.Tensor): 99 | identity = x 100 | 101 | out = self.relu1(self.bn1(self.conv1(x))) 102 | out = self.relu2(self.bn2(self.conv2(out))) 103 | out = self.avgpool(out) 104 | out = self.bn3(self.conv3(out)) 105 | 106 | if self.downsample is not None: 107 | identity = self.downsample(x) 108 | 109 | out += identity 110 | out = self.relu3(out) 111 | return out 112 | 113 | 114 | 115 | class AttentionPool2d(nn.Module): 116 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 117 | super().__init__() 118 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 119 | self.k_proj = nn.Linear(embed_dim, embed_dim) 120 | self.q_proj = nn.Linear(embed_dim, embed_dim) 121 | self.v_proj = nn.Linear(embed_dim, embed_dim) 122 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 123 | self.num_heads = num_heads 124 | 125 | def forward(self, x): 126 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 127 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 128 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 129 | x, _ = F.multi_head_attention_forward( 130 | query=x[:1], key=x, value=x, 131 | embed_dim_to_check=x.shape[-1], 132 | num_heads=self.num_heads, 133 | q_proj_weight=self.q_proj.weight, 134 | k_proj_weight=self.k_proj.weight, 135 | v_proj_weight=self.v_proj.weight, 136 | in_proj_weight=None, 137 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 138 | bias_k=None, 139 | bias_v=None, 140 | add_zero_attn=False, 141 | dropout_p=0, 142 | out_proj_weight=self.c_proj.weight, 143 | out_proj_bias=self.c_proj.bias, 144 | use_separate_proj_weight=True, 145 | training=self.training, 146 | need_weights=False 147 | ) 148 | return x.squeeze(0) 149 | 150 | 151 | 152 | class ModifiedResNet(nn.Module): 153 | """ 154 | A ResNet class that is similar to torchvision's but contains the following changes: 155 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
156 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 157 | - The final pooling layer is a QKV attention instead of an average pool 158 | """ 159 | 160 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 161 | super().__init__() 162 | self.output_dim = output_dim 163 | self.input_resolution = input_resolution 164 | 165 | # the 3-layer stem 166 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 167 | self.bn1 = nn.BatchNorm2d(width // 2) 168 | self.relu1 = nn.ReLU(inplace=True) 169 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 170 | self.bn2 = nn.BatchNorm2d(width // 2) 171 | self.relu2 = nn.ReLU(inplace=True) 172 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 173 | self.bn3 = nn.BatchNorm2d(width) 174 | self.relu3 = nn.ReLU(inplace=True) 175 | self.avgpool = nn.AvgPool2d(2) 176 | 177 | # residual layers 178 | self._inplanes = width # this is a *mutable* variable used during construction 179 | self.layer1 = self._make_layer(width, layers[0]) 180 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 181 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 182 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 183 | 184 | embed_dim = width * 32 # the ResNet feature dimension 185 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 186 | 187 | def _make_layer(self, planes, blocks, stride=1): 188 | layers = [Bottleneck(self._inplanes, planes, stride)] 189 | 190 | self._inplanes = planes * Bottleneck.expansion 191 | for _ in range(1, blocks): 192 | layers.append(Bottleneck(self._inplanes, planes)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def forward(self, x): 197 | def stem(x): 198 | x = self.relu1(self.bn1(self.conv1(x))) 199 | x = self.relu2(self.bn2(self.conv2(x))) 200 | x = self.relu3(self.bn3(self.conv3(x))) 201 | x = self.avgpool(x) 202 | return x 203 | 204 | x = x.type(self.conv1.weight.dtype) 205 | x = stem(x) 206 | x = self.layer1(x) 207 | x = self.layer2(x) 208 | x = self.layer3(x) 209 | x = self.layer4(x) 210 | x = self.attnpool(x) 211 | 212 | return x 213 | 214 | 215 | 216 | class ResidualAttentionBlock(nn.Module): 217 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, text_layer=False, i=0, design_details=None): 218 | super().__init__() 219 | 220 | self.attn = nn.MultiheadAttention(d_model, n_head) 221 | self.ln_1 = LayerNorm(d_model) 222 | self.mlp = nn.Sequential(OrderedDict([ 223 | ("c_fc", nn.Linear(d_model, d_model * 4)), 224 | ("gelu", QuickGELU()), 225 | ("c_proj", nn.Linear(d_model * 4, d_model)) 226 | ])) 227 | self.ln_2 = LayerNorm(d_model) 228 | self.attn_mask = attn_mask 229 | self.text_layer = text_layer 230 | adapt_layer = design_details["depth_text"] if self.text_layer else design_details["depth_vision"] 231 | self.adapt = True if i in adapt_layer else False 232 | self.clip_text_len = 77 233 | 234 | if self.text_layer and self.adapt: 235 | self.n_ctx_text = design_details.get("text_ctx", 0) 236 | ctx_vectors = torch.empty(self.n_ctx_text, d_model) 237 | nn.init.normal_(ctx_vectors, std=0.02) 238 | self.deep_embeds = nn.Parameter(ctx_vectors) 239 | 240 | if not self.text_layer and self.adapt: 241 | kernel_size = design_details["kernel_size"] 242 | self.conv_adapter = LocalAdapter(d_model, d_model, kernel_size=kernel_size) 243 | self.n_adapt = 
design_details.get("vision_adapt", 0) 244 | ctx_vectors = torch.empty(self.n_adapt, d_model) 245 | nn.init.normal_(ctx_vectors, std=0.02) 246 | self.adapt_embeds = nn.Parameter(ctx_vectors) 247 | 248 | def attention(self, x: torch.Tensor, n_adapt=-1): 249 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 250 | if isinstance(self.attn, ModifiedAttention): 251 | x = x.transpose(0, 1) 252 | x_loc, x = self.attn(x, n_adapt) 253 | return [x_loc.transpose(0, 1), x.transpose(0, 1)] 254 | else: 255 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 256 | 257 | def forward(self, x: torch.Tensor, text_fea=None): 258 | if isinstance(self.attn, ModifiedAttention): 259 | if isinstance(x, list): 260 | x_loc, x = x 261 | else: 262 | x_loc, x = x, x 263 | if self.adapt: 264 | prefix = x[:, :, :] 265 | adapt_embeds = self.adapt_embeds.expand(x.shape[1], -1, -1).permute(1, 0, 2).half() 266 | x = torch.cat([prefix, adapt_embeds], dim=0) 267 | x_loc_res_attn, x_res = self.attention(self.ln_1(x), n_adapt=self.n_adapt) 268 | x = x + x_res 269 | x = x + self.mlp(self.ln_2(x)) 270 | x = x[:x.shape[0]-self.n_adapt] 271 | x_loc_res_conv = self.conv_adapter(x, text_fea=text_fea) 272 | x_loc_res = x_loc_res_attn + x_loc_res_conv 273 | else: 274 | x_loc_res_attn, x_res = self.attention(self.ln_1(x)) 275 | x = x + x_res 276 | x = x + self.mlp(self.ln_2(x)) 277 | x_loc_res = x_loc_res_attn 278 | 279 | x_loc = x_loc + x_loc_res 280 | return [x_loc, x] 281 | 282 | elif self.text_layer: 283 | if self.adapt: 284 | prefix = x[:1, :, :] 285 | suffix = x[1:, :, :] 286 | deep_embeds = self.deep_embeds.expand(x.shape[1], -1, -1).permute(1, 0, 2).half() 287 | x = torch.cat([prefix, deep_embeds, suffix], dim=0) 288 | x_pad = x[self.clip_text_len:] 289 | x = x[:self.clip_text_len] 290 | x = x + self.attention(self.ln_1(x)) 291 | x = x + self.mlp(self.ln_2(x)) 292 | prefix = x[:1, :, :] 293 | suffix = x[1 + self.n_ctx_text:, :, :] 294 | x = torch.cat([prefix, suffix, x_pad], dim=0) 295 | else: 296 | x = x + self.attention(self.ln_1(x)) 297 | x = x + self.mlp(self.ln_2(x)) 298 | return x 299 | 300 | 301 | class Transformer(nn.Module): 302 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, prompts_needed=None, 303 | text_layer=False, design_details=None): 304 | super().__init__() 305 | self.width = width 306 | self.layers = layers 307 | self.text_layer = text_layer 308 | 309 | self.resblocks = nn.Sequential(*[ 310 | ResidualAttentionBlock(width, heads, attn_mask, text_layer, i, design_details) 311 | for i in range(layers) 312 | ]) 313 | 314 | def CLIP_forward(self, x, out_layers): 315 | out_tokens = [] 316 | if out_layers is None: 317 | out_layers = [len(self.resblocks)-1] 318 | for idx, block in enumerate(self.resblocks): 319 | x = block(x) 320 | if idx in out_layers: 321 | out_tokens.append(x) 322 | return x, out_tokens 323 | 324 | def SAA_CLIP_forward(self, x, out_layers, text_fea=None): 325 | out_tokens = [] 326 | if out_layers is None: 327 | out_layers = [len(self.resblocks)-1] 328 | for idx, block in enumerate(self.resblocks): 329 | if idx == 0: 330 | x_loc, x = block(x, text_fea=text_fea) 331 | else: 332 | x_loc, x = block([x_loc, x], text_fea=text_fea) 333 | if idx in out_layers: 334 | out_tokens.append(x_loc) 335 | return x, out_tokens 336 | 337 | def forward(self, x, out_layers=None, use_SAA=True, use_checkpoint=False, text_fea=None): 338 | if self.text_layer: 339 | if use_checkpoint: 340 | x = 
checkpoint(self.resblocks[:6], x) 341 | x = checkpoint(self.resblocks[6:], x) 342 | return x 343 | else: 344 | return self.resblocks(x) 345 | 346 | else: 347 | if use_SAA: 348 | return self.SAA_CLIP_forward(x, out_layers, text_fea) 349 | else: 350 | return self.CLIP_forward(x, out_layers) 351 | 352 | 353 | 354 | class VisionTransformer(nn.Module): 355 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, 356 | output_dim: int, design_details): 357 | super().__init__() 358 | self.input_resolution = input_resolution 359 | self.output_dim = output_dim 360 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 361 | 362 | SAA_layer = design_details["SAA_layer"] # [-1] or [2,3,5,...] 363 | assert SAA_layer[-1] <= layers-1, "SAA depth {} should not exceed Max layers {}".format(SAA_layer, layers) 364 | self.SAA_layer = SAA_layer if SAA_layer[0]!=-1 else None 365 | 366 | scale = width ** -0.5 367 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 368 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 369 | self.ln_pre = LayerNorm(width) 370 | 371 | self.transformer = Transformer(width, layers, heads, design_details=design_details) 372 | 373 | self.ln_post = LayerNorm(width) 374 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 375 | 376 | self.attn = None 377 | self.embed_dim = width 378 | self.num_heads = heads 379 | 380 | @torch.no_grad() 381 | def SAA_replace(self): 382 | if self.SAA_layer is not None: 383 | for i in self.SAA_layer: 384 | self.attn = ModifiedAttention(self.embed_dim, self.embed_dim, self.num_heads, True) 385 | self.attn.qkv.weight.data = self.transformer.resblocks[i].attn.in_proj_weight.clone() 386 | self.attn.qkv.bias.data = self.transformer.resblocks[i].attn.in_proj_bias.clone() 387 | self.attn.out_proj.weight.data = self.transformer.resblocks[i].attn.out_proj.weight.clone() 388 | self.attn.out_proj.bias.data = self.transformer.resblocks[i].attn.out_proj.bias.clone() 389 | self.transformer.resblocks[i].attn = self.attn 390 | print(f"SAA replace Done") 391 | 392 | def forward(self, x: torch.Tensor, out_layers=[11], use_SAA=True, text_fea=None): 393 | x = self.conv1(x) 394 | x = x.reshape(x.shape[0], x.shape[1], -1) 395 | x = x.permute(0, 2, 1) 396 | x = torch.cat( 397 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 398 | x], dim=1) 399 | x = x + self.positional_embedding.to(x.dtype) 400 | 401 | # Normal code as before 402 | x = self.ln_pre(x) 403 | 404 | x = x.permute(1, 0, 2) # NLD -> LND 405 | x, mid_features = self.transformer(x, out_layers, use_SAA=use_SAA, text_fea=text_fea) 406 | x = self.ln_post(x.permute(1, 0, 2)) 407 | mid_features_proj = [] 408 | for fea in mid_features: 409 | fea = self.ln_post(fea.permute(1, 0, 2)) @ self.proj # LND -> NLD 410 | mid_features_proj.append(fea) 411 | 412 | return x[:, :1, :] @ self.proj, mid_features_proj 413 | 414 | 415 | 416 | 417 | class CLIP(nn.Module): 418 | def __init__(self, 419 | embed_dim: int, 420 | image_resolution: int, 421 | vision_layers: Union[Tuple[int, int, int, int], int], 422 | vision_width: int, 423 | vision_patch_size: int, 424 | context_length: int, 425 | vocab_size: int, 426 | transformer_width: int, 427 | transformer_heads: int, 428 | transformer_layers: int, 429 | design_details 430 | ): 431 | super().__init__() 432 | 433 | self.context_length = 
context_length 434 | 435 | if isinstance(vision_layers, (tuple, list)): 436 | vision_heads = vision_width * 32 // 64 437 | self.visual = ModifiedResNet( 438 | layers=vision_layers, 439 | output_dim=embed_dim, 440 | heads=vision_heads, 441 | input_resolution=image_resolution, 442 | width=vision_width, 443 | design_details=design_details 444 | ) 445 | else: 446 | 447 | vision_heads = vision_width // 64 448 | self.visual = VisionTransformer( 449 | input_resolution=image_resolution, 450 | patch_size=vision_patch_size, 451 | width=vision_width, 452 | layers=vision_layers, 453 | heads=vision_heads, 454 | output_dim=embed_dim, 455 | design_details=design_details 456 | ) 457 | 458 | self.transformer = Transformer( 459 | width=transformer_width, 460 | layers=transformer_layers, 461 | heads=transformer_heads, 462 | attn_mask=self.build_attention_mask(), 463 | text_layer=True, 464 | design_details=design_details 465 | ) 466 | 467 | self.vocab_size = vocab_size 468 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 469 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 470 | self.ln_final = LayerNorm(transformer_width) 471 | 472 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 473 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 474 | 475 | self.initialize_parameters() 476 | 477 | def initialize_parameters(self): 478 | nn.init.normal_(self.token_embedding.weight, std=0.02) 479 | nn.init.normal_(self.positional_embedding, std=0.01) 480 | 481 | if isinstance(self.visual, ModifiedResNet): 482 | if self.visual.attnpool is not None: 483 | std = self.visual.attnpool.c_proj.in_features ** -0.5 484 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 485 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 486 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 487 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 488 | 489 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 490 | for name, param in resnet_block.named_parameters(): 491 | if name.endswith("bn3.weight"): 492 | nn.init.zeros_(param) 493 | 494 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 495 | attn_std = self.transformer.width ** -0.5 496 | fc_std = (2 * self.transformer.width) ** -0.5 497 | for block in self.transformer.resblocks: 498 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 499 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 500 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 501 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 502 | 503 | if self.text_projection is not None: 504 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 505 | 506 | def build_attention_mask(self): 507 | # lazily create causal attention mask, with full attention between the vision tokens 508 | # pytorch uses additive attention mask; fill with -inf 509 | mask = torch.empty(self.context_length, self.context_length) 510 | mask.fill_(float("-inf")) 511 | mask.triu_(1) # zero out the lower diagonal 512 | return mask 513 | 514 | @property 515 | def dtype(self): 516 | return self.visual.conv1.weight.dtype 517 | 518 | def encode_image(self, image, features_layers=None, ffn=False): 519 | return self.visual(image.type(self.dtype), features_layers, ffn) 520 | 521 | def encode_text(self, text, use_checkpoint=False): 522 | x = 
self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 523 | 524 | x = x + self.positional_embedding.type(self.dtype) 525 | x = x.permute(1, 0, 2) # NLD -> LND 526 | x = self.transformer(x, use_checkpoint=use_checkpoint) 527 | x = x.permute(1, 0, 2) # LND -> NLD 528 | x = self.ln_final(x).type(self.dtype) 529 | 530 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 531 | 532 | return x 533 | 534 | def forward(self, image, text): 535 | image_features = self.encode_image(image)[:, 0, :] 536 | text_features = self.encode_text(text) 537 | 538 | # normalized features 539 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 540 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 541 | 542 | # cosine similarity as logits 543 | logit_scale = self.logit_scale.exp() 544 | logits_per_image = logit_scale * image_features @ text_features.t() 545 | logits_per_text = logit_scale * text_features @ image_features.t() 546 | 547 | return logits_per_image, logits_per_text 548 | 549 | 550 | 551 | def convert_weights(model: nn.Module): 552 | """Convert applicable model parameters to fp16""" 553 | 554 | def _convert_weights_to_fp16(l): 555 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 556 | l.weight.data = l.weight.data.half() 557 | if l.bias is not None: 558 | l.bias.data = l.bias.data.half() 559 | 560 | if isinstance(l, nn.MultiheadAttention): 561 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 562 | tensor = getattr(l, attr) 563 | if tensor is not None: 564 | tensor.data = tensor.data.half() 565 | 566 | for name in ["text_projection", "proj"]: 567 | if hasattr(l, name): 568 | attr = getattr(l, name) 569 | if attr is not None: 570 | attr.data = attr.data.half() 571 | 572 | model.apply(_convert_weights_to_fp16) 573 | 574 | 575 | def build_model(state_dict: dict, input_size: list, design_details): 576 | vit = "visual.proj" in state_dict 577 | 578 | if vit: 579 | vision_width = state_dict["visual.conv1.weight"].shape[0] 580 | vision_layers = len( 581 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 582 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 583 | 584 | # NOTE: resize pos_embed, so that input resolution can be customized 585 | if input_size[0] >= 224 and input_size[1] >= 224: 586 | num_x = input_size[1] // vision_patch_size 587 | num_y = input_size[0] // vision_patch_size 588 | num_patches = num_x * num_y 589 | if state_dict["visual.positional_embedding"].shape[0] != num_patches + 1: 590 | state_dict["visual.positional_embedding"] = resize_pos_embed(state_dict["visual.positional_embedding"], 591 | num_patches+1, num_y, num_x) 592 | 593 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 594 | image_resolution = vision_patch_size * grid_size 595 | else: 596 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in 597 | [1, 2, 3, 4]] 598 | vision_layers = tuple(counts) 599 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 600 | 601 | 602 | # assert input_size[0] == input_size[1] 603 | if input_size[0] >= 224 and input_size[1] >= 224: 604 | num_x = input_size[0] // 32 605 | num_fea = num_x ** 2 606 | if state_dict["visual.attnpool.positional_embedding"].shape[0] != num_fea + 1: 607 | state_dict["visual.attnpool.positional_embedding"] = 
resize_pos_embed(state_dict["visual.attnpool.positional_embedding"], 608 | num_fea+1, num_x, num_x) 609 | 610 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 611 | vision_patch_size = None 612 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 613 | image_resolution = output_width * 32 614 | 615 | embed_dim = state_dict["text_projection"].shape[1] 616 | context_length = state_dict["positional_embedding"].shape[0] 617 | vocab_size = state_dict["token_embedding.weight"].shape[0] 618 | transformer_width = state_dict["ln_final.weight"].shape[0] 619 | transformer_heads = transformer_width // 64 620 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 621 | 622 | model = CLIP( 623 | embed_dim, 624 | image_resolution, vision_layers, vision_width, vision_patch_size, 625 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, design_details 626 | ) 627 | 628 | for key in ["input_resolution", "context_length", "vocab_size"]: 629 | if key in state_dict: 630 | del state_dict[key] 631 | 632 | convert_weights(model) 633 | model.load_state_dict(state_dict, strict=False) 634 | return model.eval() 635 | 636 | 637 | def resize_pos_embed(posemb, num_tok, hight, width): 638 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from 639 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 640 | # posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:] 641 | num_tok -= 1 642 | posemb_token, posemb_grid = posemb[0, :], posemb[1:, :] 643 | 644 | gs_old = int(math.sqrt(len(posemb_grid))) 645 | print('Resized position embedding from:{} to: {} with height:{} width: {}'.format(posemb.shape[0], num_tok, hight, width)) 646 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 647 | posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear') 648 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1) 649 | posemb_grid = posemb_grid.squeeze(0) 650 | posemb_token = posemb_token.unsqueeze(0) 651 | posemb = torch.cat([posemb_token, posemb_grid], dim=0) 652 | return posemb 653 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | --------------------------------------------------------------------------------
/config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .defaults import _C as cfg 8 | -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Convention about Training / Test specific parameters 5 | # ----------------------------------------------------------------------------- 6 | # Whenever an argument can be either used for training or for testing, the 7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter, 8 | 9 | # ----------------------------------------------------------------------------- 10 | # Config definition 11 | # ----------------------------------------------------------------------------- 12 | 13 | _C = CN() 14 | # ----------------------------------------------------------------------------- 15 | # MODEL 16 | # ----------------------------------------------------------------------------- 17 | _C.MODEL = CN() 18 | # Using cuda or cpu for training 19 | _C.MODEL.DEVICE = "cuda" 20 | # ID number of GPU 21 | _C.MODEL.DEVICE_ID = '0' 22 | # Name of backbone 23 | _C.MODEL.NAME = 'resnet50' 24 | # Last stride of backbone 25 | _C.MODEL.LAST_STRIDE = 1 26 | # Path to pretrained model of backbone 27 | _C.MODEL.PRETRAIN_PATH = '' 28 | 29 | # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model 30 | # Options: 'imagenet' , 'self' , 'finetune' 31 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet' 32 | 33 | # If train with BNNeck, options: 'bnneck' or 'no' 34 | _C.MODEL.NECK = 'bnneck' 35 | # If train loss include center loss, options: 'yes' or 'no'. 
Loss with center loss has different optimizer configuration 36 | _C.MODEL.IF_WITH_CENTER = 'no' 37 | 38 | _C.MODEL.ID_LOSS_TYPE = 'softmax' 39 | _C.MODEL.ID_LOSS_WEIGHT = 1.0 40 | _C.MODEL.TRIPLET_LOSS_WEIGHT = 1.0 41 | 42 | _C.MODEL.METRIC_LOSS_TYPE = 'triplet' 43 | # If train with multi-gpu ddp mode, options: 'True', 'False' 44 | _C.MODEL.DIST_TRAIN = False 45 | # If train with soft triplet loss, options: 'True', 'False' 46 | _C.MODEL.NO_MARGIN = False 47 | # If train with label smooth, options: 'on', 'off' 48 | _C.MODEL.IF_LABELSMOOTH = 'on' 49 | # If train with arcface loss, options: 'True', 'False' 50 | _C.MODEL.COS_LAYER = False 51 | 52 | # Transformer setting 53 | _C.MODEL.DROP_PATH = 0.1 54 | _C.MODEL.DROP_OUT = 0.0 55 | _C.MODEL.ATT_DROP_RATE = 0.0 56 | _C.MODEL.TRANSFORMER_TYPE = 'None' 57 | _C.MODEL.STRIDE_SIZE = [16, 16] 58 | 59 | # JPM Parameter 60 | _C.MODEL.JPM = False 61 | _C.MODEL.SHIFT_NUM = 5 62 | _C.MODEL.SHUFFLE_GROUP = 2 63 | _C.MODEL.DEVIDE_LENGTH = 4 64 | _C.MODEL.RE_ARRANGE = True 65 | 66 | # SIE Parameter 67 | _C.MODEL.SIE_COE = 3.0 68 | _C.MODEL.SIE_CAMERA = False 69 | _C.MODEL.SIE_VIEW = False 70 | 71 | # CLIP Backbone 72 | _C.MODEL.BACKBONE = 'ViT-B/16' 73 | # Use BN 74 | _C.MODEL.BN = False 75 | # Number of head 76 | _C.MODEL.NUM_HEAD = 8 77 | # Loss type 78 | _C.MODEL.LOSS_TYPE = 'BCE' # 'ASL' 79 | # ema model 80 | _C.MODEL.USE_EMA = False 81 | _C.MODEL.EMA_DECAY = 0.9997 82 | # load pretrain 83 | _C.MODEL.LOAD = False 84 | # text encoder 85 | _C.MODEL.TEXT_ENCODER = 'CLIP' 86 | _C.MODEL.DEPTH_TEXT = [-1] 87 | _C.MODEL.TEXT_CTX = 4 # middle layers, 一般不用 88 | _C.MODEL.PROMPT_CSC = False 89 | 90 | # transfer type 91 | _C.MODEL.TRANSFER_TYPE = "freeze_all" 92 | 93 | # SAA 94 | _C.MODEL.SAA_LAYER = [-1] 95 | 96 | # loc region pooling 97 | _C.MODEL.LOC_STRIDE_SIZE = 4 98 | _C.MODEL.LOC_KERNEL_SIZE = 4 99 | 100 | # Adapter 101 | _C.MODEL.DEPTH_VISION = [-1] 102 | _C.MODEL.VISION_ADAPT = 8 103 | _C.MODEL.KERNEL_SIZE = 3 104 | 105 | # temperature 106 | _C.MODEL.TEMPERATURE = 0.002 107 | 108 | # new prototype 109 | _C.MODEL.NUM_NEW_PROTOTYPE = 10 110 | 111 | # OT reg 112 | _C.MODEL.OT_REG = 0.1 113 | _C.MODEL.OT_REGSC = 0.05 114 | 115 | # Adapter ratio 116 | _C.MODEL.ADAPTER_RATIO = 0.2 117 | 118 | 119 | # ----------------------------------------------------------------------------- 120 | # INPUT 121 | # ----------------------------------------------------------------------------- 122 | _C.INPUT = CN() 123 | # Size of the image during training 124 | _C.INPUT.SIZE_TRAIN = [224, 224] 125 | # Size of the image during test 126 | _C.INPUT.SIZE_TEST = [224, 224] 127 | # Random probability for image horizontal flip 128 | _C.INPUT.PROB = 0.5 129 | # Random probability for random erasing 130 | _C.INPUT.RE_PROB = 0.5 131 | # Values to be used for image normalization 132 | # _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] 133 | # # Values to be used for image normalization 134 | # _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] 135 | _C.INPUT.PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] 136 | # Values to be used for image normalization 137 | _C.INPUT.PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] 138 | # Value of padding size 139 | _C.INPUT.PADDING = 10 140 | 141 | # Image augmentation 142 | _C.INPUT.HIDESEEK = False 143 | _C.INPUT.AUGMIX = False 144 | 145 | # Vis: with path dataset 146 | _C.INPUT.WITH_PATH = False 147 | 148 | # Text template 149 | _C.INPUT.TEMPLATE = 'vanilla' # 'ensemble' 150 | _C.INPUT.ENSEMBLE_TYPE = 'embedding' # 'logit' 'score' 151 | _C.INPUT.NUM_GROUPS = 1 # 
'logit' 152 | 153 | # ----------------------------------------------------------------------------- 154 | # Dataset 155 | # ----------------------------------------------------------------------------- 156 | _C.DATASETS = CN() 157 | # List of the dataset names for training 158 | _C.DATASETS.NAMES = ('PETA') 159 | # Root directory of the datasets 160 | _C.DATASETS.ROOT_DIR = ('../data') 161 | # Number of attribute / label classes 162 | _C.DATASETS.NUMBERS = 35 163 | _C.DATASETS.TEMPLATE_NAMES = '' 164 | 165 | # label proportion 166 | _C.DATASETS.PARTIAL = -1.0 167 | 168 | # ----------------------------------------------------------------------------- 169 | # DataLoader 170 | # ----------------------------------------------------------------------------- 171 | _C.DATALOADER = CN() 172 | # Number of data loading threads 173 | _C.DATALOADER.NUM_WORKERS = 8 174 | # Sampler for data loading 175 | _C.DATALOADER.SAMPLER = 'softmax' 176 | # Number of instances per batch 177 | _C.DATALOADER.NUM_INSTANCE = 16 178 | 179 | # ---------------------------------------------------------------------------- # 180 | # Solver 181 | # ---------------------------------------------------------------------------- # 182 | _C.SOLVER = CN() 183 | # Save the model 184 | _C.SOLVER.SAVE_MODEL = False 185 | # Name of optimizer 186 | _C.SOLVER.OPTIMIZER_NAME = "Adam" 187 | # Maximum number of epochs 188 | _C.SOLVER.MAX_EPOCHS = 100 189 | # Maximum number of epochs for the scheduler 190 | _C.SOLVER.SCHEDULER_MAX_EPOCHS = 60 191 | # Maximum number of iterations for the scheduler 192 | _C.SOLVER.SCHEDULER_MAX_ITER = 1000000 193 | # Base learning rate 194 | _C.SOLVER.BASE_LR = 3e-4 195 | # Learning rate factor for bias parameters 196 | _C.SOLVER.BIAS_LR_FACTOR = 1 197 | # Random seed 198 | _C.SOLVER.SEED = 42 199 | # Momentum 200 | _C.SOLVER.MOMENTUM = 0.9 201 | # Margin of triplet loss 202 | _C.SOLVER.MARGIN = 0.3 203 | # Learning rate of SGD to learn the centers of center loss 204 | _C.SOLVER.CENTER_LR = 0.5 205 | # Balancing weight of center loss 206 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005 207 | _C.SOLVER.LARGE_FC_LR = False 208 | 209 | # Settings of weight decay 210 | _C.SOLVER.WEIGHT_DECAY = 0.0001 211 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.0001 212 | _C.SOLVER.WEIGHT_DECAY_SGD = 0.0001 213 | 214 | # decay rate of learning rate 215 | _C.SOLVER.GAMMA = 0.1 216 | # decay step of learning rate 217 | _C.SOLVER.STEPS = (40, 70) 218 | # warm up factor 219 | _C.SOLVER.WARMUP_FACTOR = 0.01 220 | # warm up epochs 221 | _C.SOLVER.WARMUP_EPOCHS = 3 222 | # method of warm up, options: 'constant', 'linear' 223 | _C.SOLVER.WARMUP_METHOD = "linear" 224 | 225 | _C.SOLVER.COSINE_MARGIN = 0.5 226 | _C.SOLVER.COSINE_SCALE = 30 227 | 228 | # epoch interval for saving checkpoints 229 | _C.SOLVER.CHECKPOINT_PERIOD = 10 230 | # iteration interval for displaying the training log 231 | _C.SOLVER.LOG_PERIOD = 100 232 | # epoch interval for validation 233 | _C.SOLVER.EVAL_PERIOD = 10 234 | # Number of images per batch 235 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU 236 | # contains 16 images per batch 237 | _C.SOLVER.IMS_PER_BATCH = 64 238 | 239 | # Classification threshold 240 | # scores above this value count as positive predictions 241 | _C.SOLVER.THRESH = 0.5 242 | 243 | # Label smoothing 244 | _C.SOLVER.LABEL_SMOOTHING = False 245 | 246 | # Iteration-based LR scheduler (TGPT implementation) 247 | _C.SOLVER.GAMMA = 0.1 248 | _C.SOLVER.LR_SCHEDULER = "cosine" 249 | _C.SOLVER.STEPSIZE = 1000 250 | 251 | # ASL loss parameters 252 | _C.SOLVER.GAMMA_NEG = 2
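These `_C.*` entries are plain yacs defaults; at launch they are overridden by one of the YAML files under `configs/` and, usually, by command-line options (the remaining ASL-loss parameters continue right after this aside). A minimal sketch of that merge, assuming a hypothetical `load_config` helper — `train.py` may wire the same logic up differently:

```python
# Minimal sketch of how the yacs defaults above are typically consumed
# (load_config is a hypothetical helper for illustration only).
from config.defaults import _C


def load_config(config_file=None, opts=None):
    cfg = _C.clone()                      # start from the defaults in this file
    if config_file:                       # e.g. "configs/coco.yml"
        cfg.merge_from_file(config_file)
    if opts:                              # e.g. ["SOLVER.IMS_PER_BATCH", "16"]
        cfg.merge_from_list(opts)
    cfg.freeze()                          # make the merged config read-only
    return cfg


cfg = load_config(opts=["SOLVER.BASE_LR", "1e-4"])
print(cfg.MODEL.BACKBONE, cfg.SOLVER.BASE_LR, cfg.SOLVER.GAMMA_NEG)
```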
253 | _C.SOLVER.GAMMA_POS = 0 254 | _C.SOLVER.CLIP = 0. 255 | 256 | # twloss param 257 | _C.SOLVER.TP = 4. 258 | _C.SOLVER.TN = 1. 259 | 260 | # save the middle output, for visualization 261 | _C.SOLVER.VERBOSE = False 262 | 263 | # iter training 264 | _C.SOLVER.MAX_ITER = 12800 265 | _C.SOLVER.WARMUP_ITER = 200 266 | _C.SOLVER.BASE_LR_SGD = 0.001 267 | 268 | # KD loss weight 269 | _C.SOLVER.KDLOSS_WEIGHT = 1. 270 | 271 | # Text batch 272 | _C.SOLVER.TEXT_BATCH_SIZE = 80 273 | 274 | # debug mode 275 | _C.SOLVER.DEBUG = False 276 | 277 | # zero-shot testing 278 | _C.SOLVER.ZS_TEST = False 279 | 280 | # sample text 281 | _C.SOLVER.SAMPLE_TEXT = False 282 | 283 | 284 | # ---------------------------------------------------------------------------- # 285 | # TEST 286 | # ---------------------------------------------------------------------------- # 287 | 288 | _C.TEST = CN() 289 | # Number of images per batch during test 290 | _C.TEST.IMS_PER_BATCH = 128 291 | # If test with re-ranking, options: 'True','False' 292 | _C.TEST.RE_RANKING = False 293 | # Path to trained model 294 | _C.TEST.WEIGHT = "" 295 | _C.TEST.WEIGHT_ITERS = 12800 296 | # Which feature of BNNeck to be used for test, before or after BNNneck, options: 'before' or 'after' 297 | _C.TEST.NECK_FEAT = 'after' 298 | # Whether feature is nomalized before test, if yes, it is equivalent to cosine distance 299 | _C.TEST.FEAT_NORM = 'yes' 300 | 301 | # Name for saving the distmat after testing. 302 | _C.TEST.DIST_MAT = "dist_mat.npy" 303 | # Whether calculate the eval score option: 'True', 'False' 304 | _C.TEST.EVAL = False 305 | _C.TEST.USE_FUSION = False 306 | _C.TEST.TTA = -1 307 | _C.TEST.TEN_CROP = False 308 | # ---------------------------------------------------------------------------- # 309 | # Misc options 310 | # ---------------------------------------------------------------------------- # 311 | # Path to checkpoint and saved log of trained model 312 | _C.OUTPUT_DIR = "" 313 | 314 | 315 | # ---------------------------------------------------------------------------- # 316 | # Train time configs 317 | # ---------------------------------------------------------------------------- # 318 | _C.LOCAL_RANK = 0 319 | _C.WANDB = False 320 | _C.WANDB_PROJ = "OVML-RAM" 321 | 322 | 323 | 324 | 325 | 326 | 327 | -------------------------------------------------------------------------------- /configs/coco.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | IF_LABELSMOOTH: 'off' 3 | IF_WITH_CENTER: 'no' 4 | NO_MARGIN: True 5 | STRIDE_SIZE: [16, 16] 6 | BACKBONE: 'ViT-B/16' 7 | LOSS_TYPE: 'MMC' 8 | DEPTH_VISION: [9,10,11] 9 | DEPTH_TEXT: [9,10,11] 10 | PROMPT_CSC: False 11 | TRANSFER_TYPE: "Adapter" 12 | SAA_LAYER: [12, -1] 13 | USE_EMA: True 14 | KERNEL_SIZE: 3 15 | TEMPERATURE: 2e-3 16 | OT_REGSC: 0.05 17 | 18 | INPUT: 19 | SIZE_TRAIN: [224, 224] 20 | SIZE_TEST: [224, 224] 21 | PROB: 0.5 22 | RE_PROB: 0.5 23 | PADDING: 10 24 | PIXEL_MEAN: [0.5, 0.5, 0.5] 25 | PIXEL_STD: [0.5, 0.5, 0.5] 26 | ENSEMBLE_TYPE: "embedding" 27 | NUM_GROUPS: 1 28 | 29 | DATASETS: 30 | NAMES: ('COCO_ZS') 31 | ROOT_DIR: ('/mnt/nas/TrueNas1/COCO') 32 | NUMBERS: 80 33 | 34 | DATALOADER: 35 | SAMPLER: 'sigmoid' 36 | NUM_INSTANCE: 4 37 | NUM_WORKERS: 8 38 | 39 | SOLVER: 40 | OPTIMIZER_NAME: 'AdamW' 41 | BASE_LR: 5e-5 42 | IMS_PER_BATCH: 32 43 | LARGE_FC_LR: False 44 | LOG_PERIOD: 500 45 | WEIGHT_DECAY: 1e-4 46 | MAX_EPOCHS: 6 47 | EVAL_PERIOD: 3 48 | SAVE_MODEL: False 49 | 50 | TEST: 51 | EVAL: True 52 | IMS_PER_BATCH: 512 
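Both `configs/coco.yml` and `configs/nus.yml` point `DATASETS.ROOT_DIR` at a folder whose `annotations/` subdirectory holds the filtered zero-shot splits (the remaining `TEST` keys of `coco.yml` continue right below). For reference, the `ZSMultiLabelClassification._load_dataset` method shown further down reads each split as a class list plus `(relative_image_path, label_names)` pairs; a minimal illustrative file, with made-up paths and labels, could look like this:

```python
# Illustrative annotation layout consumed by _load_dataset (see the dataset
# class further below). Paths and label names here are made up.
import json

annotation = {
    "classes": ["car", "dog", "person"],  # the loader sorts these before indexing
    "images": [
        # each entry is [relative_image_path, list_of_label_names]
        ["train2014/000000000009.jpg", ["person", "dog"]],
        ["train2014/000000000025.jpg", []],  # no labels -> kept and treated as all-negative
    ],
}

print(json.dumps(annotation, indent=2))
```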
53 | RE_RANKING: False 54 | WEIGHT: '' 55 | WEIGHT_ITERS: 12800 56 | NECK_FEAT: 'before' 57 | FEAT_NORM: 'yes' 58 | 59 | OUTPUT_DIR: './logs/RAM/Natural/COCO' 60 | -------------------------------------------------------------------------------- /configs/nus.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | IF_LABELSMOOTH: 'off' 3 | IF_WITH_CENTER: 'no' 4 | NO_MARGIN: True 5 | STRIDE_SIZE: [16, 16] 6 | BACKBONE: 'ViT-B/16' 7 | LOSS_TYPE: 'MMC' 8 | DEPTH_VISION: [9,10,11] 9 | DEPTH_TEXT: [6,7,8,9,10,11] 10 | PROMPT_CSC: False 11 | TRANSFER_TYPE: "Adapter" 12 | SAA_LAYER: [12, -1] 13 | USE_EMA: True 14 | KERNEL_SIZE: 3 15 | TEMPERATURE: 2e-4 16 | OT_REGSC: 0.05 17 | 18 | INPUT: 19 | SIZE_TRAIN: [224, 224] 20 | SIZE_TEST: [224, 224] 21 | PROB: 0.5 22 | RE_PROB: 0.5 23 | PADDING: 10 24 | PIXEL_MEAN: [0.5, 0.5, 0.5] 25 | PIXEL_STD: [0.5, 0.5, 0.5] 26 | ENSEMBLE_TYPE: "embedding" 27 | NUM_GROUPS: 1 28 | 29 | DATASETS: 30 | NAMES: ('NUS_ZS') 31 | ROOT_DIR: ('/mnt/nas/TrueNas1/NUS_WIDE') 32 | NUMBERS: 81 33 | 34 | DATALOADER: 35 | SAMPLER: 'sigmoid' 36 | NUM_INSTANCE: 4 37 | NUM_WORKERS: 8 38 | 39 | SOLVER: 40 | OPTIMIZER_NAME: 'AdamW' 41 | BASE_LR: 5e-5 42 | IMS_PER_BATCH: 32 43 | LARGE_FC_LR: False 44 | LOG_PERIOD: 1000 45 | WEIGHT_DECAY: 1e-4 46 | MAX_EPOCHS: 6 47 | EVAL_PERIOD: 3 48 | SAVE_MODEL: False 49 | 50 | TEST: 51 | EVAL: True 52 | IMS_PER_BATCH: 512 53 | RE_RANKING: False 54 | WEIGHT: '' 55 | WEIGHT_ITERS: 12800 56 | NECK_FEAT: 'before' 57 | FEAT_NORM: 'yes' 58 | 59 | OUTPUT_DIR: './logs/RAM/Natural/NUS' 60 | -------------------------------------------------------------------------------- /datasets/MultiLabel/classification.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import json 5 | 6 | from datasets.bases import BaseImageDataset 7 | 8 | 9 | 10 | class ZSMultiLabelClassification(BaseImageDataset): 11 | def __init__(self, root='', verbose=True, **kwargs): 12 | super(ZSMultiLabelClassification, self).__init__() 13 | self.dataset_dir = root 14 | if "coco" in root.lower(): 15 | dataset_name = "COCO" 16 | self.train_file = os.path.join(self.dataset_dir, "annotations", 'train_48_filtered.json') 17 | self.test_file = os.path.join(self.dataset_dir, "annotations", 'test_17_filtered.json') 18 | self.test_file_gzsl = os.path.join(self.dataset_dir, "annotations", 'test_65_filtered.json') 19 | elif "nus" in root.lower(): 20 | dataset_name = "NUS" 21 | self.train_file = os.path.join(self.dataset_dir, "annotations", 'train_925_filtered.json') 22 | self.test_file = os.path.join(self.dataset_dir, "annotations", 'test_81_filtered.json') 23 | self.test_file_gzsl = os.path.join(self.dataset_dir, "annotations", 'test_1006_filtered.json') 24 | else: 25 | raise NotImplementedError 26 | self._check_before_run() 27 | 28 | train, class2idx, name_train = self._load_dataset(self.dataset_dir, self.train_file, shuffle=True) 29 | test, _, name_test = self._load_dataset(self.dataset_dir, self.test_file, shuffle=False) 30 | test_gzsl, _, _ = self._load_dataset(self.dataset_dir, self.test_file_gzsl, shuffle=False, names=name_train+name_test) 31 | self.train = train 32 | self.test = test 33 | self.test_gzsl = test_gzsl 34 | self.class2idx = class2idx 35 | if verbose: 36 | print(f"=> {dataset_name} ZSL Dataset:") 37 | self.print_dataset_statistics(train, test) 38 | print(f"=> {dataset_name} GZSL Dataset:") 39 | self.print_dataset_statistics(train, test_gzsl) 40 | 
self.classnames_seen = name_train 41 | self.classnames_unseen = name_test 42 | self.classnames = name_train+name_test 43 | self.num_cls_train = len(name_train) 44 | self.num_cls_test = len(name_test) 45 | 46 | def _check_before_run(self): 47 | """Check if all files are available before going deeper""" 48 | if not os.path.exists(self.dataset_dir): 49 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 50 | if not os.path.exists(self.train_file): 51 | raise RuntimeError("'{}' is not available".format(self.train_file)) 52 | if not os.path.exists(self.test_file): 53 | raise RuntimeError("'{}' is not available".format(self.test_file)) 54 | 55 | def _load_dataset(self, data_dir, annot_path, shuffle=True, names=None): 56 | out_data = [] 57 | with open(annot_path) as f: 58 | annotation = json.load(f) 59 | classes = sorted(annotation['classes']) if names is None else names 60 | class_to_idx = {classes[i]: i for i in range(len(classes))} 61 | images_info = annotation['images'] 62 | img_wo_objects = 0 63 | for img_info in images_info: 64 | labels_idx = list() 65 | rel_image_path, img_labels = img_info 66 | full_image_path = os.path.join(data_dir, rel_image_path) 67 | labels_idx = [class_to_idx[lbl] for lbl in img_labels if lbl in class_to_idx] 68 | labels_idx = list(set(labels_idx)) 69 | # transform to one-hot 70 | onehot = np.zeros(len(classes), dtype=int) 71 | onehot[labels_idx] = 1 72 | assert full_image_path 73 | if not labels_idx: 74 | img_wo_objects += 1 75 | out_data.append((full_image_path, onehot)) 76 | if img_wo_objects: 77 | print(f'WARNING: there are {img_wo_objects} images without labels and will be treated as negatives') 78 | if shuffle: 79 | random.shuffle(out_data) 80 | return out_data, class_to_idx, classes 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.make_dataloader import make_dataloader -------------------------------------------------------------------------------- /datasets/bases.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | 5 | from PIL import Image, ImageFile 6 | from imgaug import augmenters as iaa 7 | from torch.utils.data import Dataset 8 | ImageFile.LOAD_TRUNCATED_IMAGES = True 9 | 10 | 11 | 12 | class BaseDataset(object): 13 | """ 14 | Base class of reid dataset 15 | """ 16 | def get_imagedata_info(self, data): 17 | imgs = [] 18 | labels = [] 19 | for data_img, data_label in data: 20 | imgs += [data_img] 21 | labels += [data_label] 22 | num_imgs = len(imgs) 23 | num_labels = len(labels) 24 | return num_imgs, num_labels 25 | 26 | def print_dataset_statistics(self): 27 | raise NotImplementedError 28 | 29 | 30 | 31 | class BaseImageDataset(BaseDataset): 32 | """ 33 | Base class of image reid dataset 34 | """ 35 | def print_dataset_statistics(self, train, test): 36 | num_train_imgs, num_train_labels = self.get_imagedata_info(train) 37 | num_test_imgs, num_test_labels = self.get_imagedata_info(test) 38 | 39 | print("Dataset statistics:") 40 | print(" ----------------------------------------") 41 | print(" subset | # images | # labels") 42 | print(" ----------------------------------------") 43 | print(" train | {:8d} | {:9d}".format(num_train_imgs, num_train_labels)) 44 | print(" test | {:8d} | {:9d}".format(num_test_imgs, num_test_labels)) 45 | print(" 
----------------------------------------") 46 | 47 | 48 | 49 | class ImageDataset(Dataset): 50 | def __init__(self, dataset, transform=None, shuffle = True, mirror=False, Aug = False): 51 | self.dataset = dataset 52 | self.transform = transform 53 | self.mirror = mirror 54 | self.Aug = Aug 55 | self.AugSeq = iaa.Sequential([ 56 | iaa.Sometimes(0.2, 57 | iaa.Crop(percent=(0, 0.1)), 58 | ), 59 | # Small gaussian blur with random sigma between 0 and 0.5. 60 | # But we only blur about 50% of all images. 61 | iaa.Sometimes(0.2, 62 | iaa.GaussianBlur(sigma=(0, 0.2)) 63 | ), 64 | 65 | # Strengthen or weaken the contrast in each image. 66 | iaa.Sometimes(0.2, iaa.ContrastNormalization((0.75, 1.25))), 67 | iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.015*255), per_channel=0.2), 68 | iaa.Multiply((0.9, 1.1), per_channel=0.2), 69 | # Apply affine transformations to each image. 70 | # Scale/zoom them, translate/move them, rotate them and shear them. 71 | iaa.Sometimes(0.2, 72 | iaa.Affine( 73 | scale={"x": (0.9, 1.1), "y": (0.9, 1.1)}, 74 | translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)}, 75 | rotate=(-10, 10) 76 | ) 77 | ), 78 | ], random_order=True) # apply augmenters in random order 79 | 80 | def __getitem__(self, item): 81 | img_path, atts_list = self.dataset[item] 82 | img = Image.open(img_path).convert("RGB") 83 | 84 | # flip 85 | flip = 0 86 | if self.mirror: 87 | flip = np.random.choice(2) 88 | if flip: 89 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 90 | 91 | if self.Aug: 92 | cv_img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) 93 | cv_img = cv_img.reshape(1, cv_img.shape[0], cv_img.shape[1], cv_img.shape[2]) 94 | cv_img = self.AugSeq.augment_images(cv_img) 95 | cv_img = cv_img.reshape(cv_img.shape[1], cv_img.shape[2], cv_img.shape[3]) 96 | img = Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)) 97 | 98 | if self.transform is not None: 99 | img = self.transform(img) 100 | return img, atts_list 101 | 102 | def __len__(self): 103 | return len(self.dataset) 104 | 105 | -------------------------------------------------------------------------------- /datasets/make_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | import torchvision.transforms as T 5 | from torch.utils.data import DataLoader 6 | from timm.data.random_erasing import RandomErasing 7 | 8 | from .bases import ImageDataset 9 | from .MultiLabel.classification import ZSMultiLabelClassification 10 | 11 | 12 | 13 | def train_collate_fn(batch): 14 | imgs, label = zip(*batch) 15 | label = torch.tensor(label, dtype=torch.int64) 16 | return torch.stack(imgs, dim=0), label 17 | 18 | 19 | def val_collate_fn(batch): 20 | imgs, label = zip(*batch) 21 | label = torch.tensor(label, dtype=torch.int64) 22 | return torch.stack(imgs, dim=0), label 23 | 24 | 25 | def make_dataloader(cfg): 26 | num_workers = cfg.DATALOADER.NUM_WORKERS 27 | dataset = ZSMultiLabelClassification(root=cfg.DATASETS.ROOT_DIR) 28 | 29 | train_transforms = T.Compose([ 30 | T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3), 31 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB), 32 | T.Pad(cfg.INPUT.PADDING), 33 | T.RandomCrop(cfg.INPUT.SIZE_TRAIN), 34 | T.ToTensor(), 35 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD), 36 | RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'), 37 | ]) 38 | train_set = ImageDataset(dataset.train, transform=train_transforms, mirror=True, Aug=True) 39 | 40 | val_transforms 
= T.Compose([ 41 | T.Resize(cfg.INPUT.SIZE_TEST), 42 | T.ToTensor(), 43 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 44 | ]) 45 | val_set = ImageDataset(dataset.test, transform=val_transforms, mirror=False, Aug=False) 46 | val_set_gzsl = ImageDataset(dataset.test_gzsl, transform=val_transforms, mirror=False, Aug=False) 47 | 48 | if cfg.MODEL.DIST_TRAIN: 49 | print('DIST_TRAIN START') 50 | mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // dist.get_world_size() 51 | print('===========================\n mini batch size:', mini_batch_size) 52 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) 53 | nw = min([os.cpu_count(), mini_batch_size if mini_batch_size > 1 else 0, 8]) 54 | train_loader = torch.utils.data.DataLoader( 55 | train_set, 56 | batch_size=mini_batch_size, 57 | pin_memory=True, 58 | num_workers=nw, 59 | shuffle=False, 60 | sampler=train_sampler, 61 | collate_fn=train_collate_fn, 62 | drop_last=True 63 | ) 64 | val_loader = DataLoader( 65 | val_set, 66 | batch_size=cfg.TEST.IMS_PER_BATCH, 67 | shuffle=False, 68 | num_workers=num_workers, 69 | collate_fn=val_collate_fn 70 | ) 71 | val_loader_gzsl = DataLoader( 72 | val_set_gzsl, 73 | batch_size=cfg.TEST.IMS_PER_BATCH, 74 | shuffle=False, 75 | num_workers=num_workers, 76 | collate_fn=val_collate_fn 77 | ) 78 | return train_loader, val_loader, val_loader_gzsl, train_sampler, dataset 79 | else: 80 | train_loader = DataLoader( 81 | train_set, 82 | batch_size=cfg.SOLVER.IMS_PER_BATCH, 83 | shuffle=False, 84 | num_workers=num_workers, 85 | collate_fn=train_collate_fn, 86 | drop_last=True, 87 | persistent_workers=True 88 | ) 89 | val_loader = DataLoader( 90 | val_set, 91 | batch_size=cfg.TEST.IMS_PER_BATCH, 92 | shuffle=False, 93 | num_workers=num_workers, 94 | collate_fn=val_collate_fn, 95 | persistent_workers=True, 96 | ) 97 | val_loader_gzsl = DataLoader( 98 | val_set_gzsl, 99 | batch_size=cfg.TEST.IMS_PER_BATCH, 100 | shuffle=False, 101 | num_workers=num_workers, 102 | collate_fn=val_collate_fn, 103 | persistent_workers=True, 104 | ) 105 | return train_loader, val_loader, val_loader_gzsl, None, dataset 106 | 107 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .asymmetric_loss import AsymmetricLossOptimized, AsymmetricLossOptimized_partial 2 | from .mmc_loss import mmc_loss 3 | -------------------------------------------------------------------------------- /loss/asymmetric_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrow from https://github.com/SlongLiu/query2labels, thanks 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | 9 | class AsymmetricLoss(nn.Module): 10 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): 11 | super(AsymmetricLoss, self).__init__() 12 | 13 | self.gamma_neg = gamma_neg 14 | self.gamma_pos = gamma_pos 15 | self.clip = clip 16 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 17 | self.eps = eps 18 | 19 | def forward(self, x, y): 20 | """" 21 | Parameters 22 | ---------- 23 | x: input logits 24 | y: targets (multi-label binarized vector) 25 | """ 26 | 27 | # Calculating Probabilities 28 | x_sigmoid = torch.sigmoid(x) 29 | xs_pos = x_sigmoid 30 | xs_neg = 1 - x_sigmoid 31 | 32 | # Asymmetric Clipping 33 | if self.clip is not None and self.clip > 0: 34 | xs_neg = (xs_neg + 
self.clip).clamp(max=1) 35 | 36 | # Basic CE calculation 37 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps, max=1-self.eps)) 38 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps, max=1-self.eps)) 39 | loss = los_pos + los_neg 40 | 41 | # Asymmetric Focusing 42 | if self.gamma_neg > 0 or self.gamma_pos > 0: 43 | if self.disable_torch_grad_focal_loss: 44 | torch._C.set_grad_enabled(False) 45 | pt0 = xs_pos * y 46 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p 47 | pt = pt0 + pt1 48 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) 49 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 50 | if self.disable_torch_grad_focal_loss: 51 | torch._C.set_grad_enabled(True) 52 | loss *= one_sided_w 53 | 54 | return -loss.sum() 55 | 56 | 57 | 58 | class AsymmetricLossOptimized(nn.Module): 59 | ''' Notice - optimized version, minimizes memory allocation and gpu uploading, 60 | favors inplace operations''' 61 | # Default values are those used in query2label training 62 | def __init__(self, gamma_neg=2, gamma_pos=0, clip=0., eps=1e-5, disable_torch_grad_focal_loss=True): 63 | super(AsymmetricLossOptimized, self).__init__() 64 | 65 | self.gamma_neg = gamma_neg 66 | self.gamma_pos = gamma_pos 67 | self.clip = clip 68 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 69 | self.eps = eps 70 | 71 | self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None 72 | 73 | def forward(self, x, y): 74 | """" 75 | Parameters 76 | ---------- 77 | x: input logits 78 | y: targets (multi-label binarized vector) 79 | """ 80 | 81 | self.targets = y 82 | self.anti_targets = 1 - y 83 | 84 | # Calculating Probabilities 85 | self.xs_pos = torch.sigmoid(x) 86 | self.xs_neg = 1.0 - self.xs_pos 87 | 88 | # Asymmetric Clipping 89 | if self.clip is not None and self.clip > 0: 90 | self.xs_neg.add_(self.clip).clamp_(max=1) 91 | 92 | # Basic CE calculation 93 | self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps)) 94 | self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps))) 95 | 96 | # Asymmetric Focusing 97 | if self.gamma_neg > 0 or self.gamma_pos > 0: 98 | if self.disable_torch_grad_focal_loss: 99 | with torch.no_grad(): 100 | # if self.disable_torch_grad_focal_loss: 101 | # torch._C.set_grad_enabled(False) 102 | self.xs_pos = self.xs_pos * self.targets 103 | self.xs_neg = self.xs_neg * self.anti_targets 104 | self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg, 105 | self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets) 106 | # if self.disable_torch_grad_focal_loss: 107 | # torch._C.set_grad_enabled(True) 108 | self.loss *= self.asymmetric_w 109 | else: 110 | self.xs_pos = self.xs_pos * self.targets 111 | self.xs_neg = self.xs_neg * self.anti_targets 112 | self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg, 113 | self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets) 114 | self.loss *= self.asymmetric_w 115 | _loss = - self.loss.sum() / x.size(0) 116 | _loss = _loss / y.size(1) * 1000 117 | 118 | return _loss 119 | 120 | 121 | 122 | 123 | class AsymmetricLossOptimized_partial(nn.Module): 124 | """ 125 | ASL loss used for partial label setting 126 | """ 127 | def __init__(self, gamma_neg=2, gamma_pos=0, clip=0., eps=1e-5, disable_torch_grad_focal_loss=True): 128 | super(AsymmetricLossOptimized, self).__init__() 129 | 130 | self.gamma_neg = gamma_neg 131 | self.gamma_pos = gamma_pos 132 | self.clip = clip 133 | self.disable_torch_grad_focal_loss = 
disable_torch_grad_focal_loss 134 | self.eps = eps 135 | 136 | self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None 137 | 138 | def forward(self, x, y): 139 | """" 140 | Parameters 141 | ---------- 142 | x: input logits 143 | y: targets (multi-label binarized vector) 144 | """ 145 | y = y.reshape(-1) 146 | x = x.reshape(-1) 147 | x = x[y!=-1] 148 | y = y[y!=-1] 149 | 150 | self.targets = y 151 | self.anti_targets = 1 - y 152 | 153 | # Calculating Probabilities 154 | # self.xs_pos = torch.sigmoid(x) 155 | self.xs_pos = x 156 | self.xs_neg = 1.0 - self.xs_pos 157 | 158 | # Asymmetric Clipping 159 | if self.clip is not None and self.clip > 0: 160 | self.xs_neg.add_(self.clip).clamp_(max=1) 161 | 162 | # Basic CE calculation 163 | self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps)) 164 | self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps))) 165 | 166 | # Asymmetric Focusing 167 | if self.gamma_neg > 0 or self.gamma_pos > 0: 168 | if self.disable_torch_grad_focal_loss: 169 | with torch.no_grad(): 170 | # if self.disable_torch_grad_focal_loss: 171 | # torch._C.set_grad_enabled(False) 172 | self.xs_pos = self.xs_pos * self.targets 173 | self.xs_neg = self.xs_neg * self.anti_targets 174 | self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg, 175 | self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets) 176 | # if self.disable_torch_grad_focal_loss: 177 | # torch._C.set_grad_enabled(True) 178 | self.loss *= self.asymmetric_w 179 | else: 180 | self.xs_pos = self.xs_pos * self.targets 181 | self.xs_neg = self.xs_neg * self.anti_targets 182 | self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg, 183 | self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets) 184 | self.loss *= self.asymmetric_w 185 | # print(self.loss.shape) 186 | 187 | _loss = - self.loss.sum() / x.size(0) * 1000 188 | # _loss = _loss / y.size(1) * 1000 189 | 190 | return _loss 191 | 192 | -------------------------------------------------------------------------------- /loss/mmc_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | 6 | def mmc_loss(logits, mask, logits_mask=None): 7 | """ 8 | MMC Loss (Multi-Matching Contrastive loss) 9 | Inspired by SupCon Loss: https://github.com/google-research/google-research/tree/master/supcon 10 | MMC extends it into (1) multi-modal setting and (2) batched contrastive process 11 | Args: 12 | logits: torch.Tensor[B, C], B - mini-batch size, C - number of classes 13 | mask: torch.Tensor[B, C], binary mask, 1 if the class is present in the image, 0 otherwise 14 | logits_mask: torch.Tensor[B, C], mask out self-matching logits, not applied in multi-modal setting 15 | Returns: 16 | loss_cl: torch.Tensor[1], mean cross-entropy loss over positive pairs 17 | """ 18 | # flatten the batch dimension 19 | logits = logits.reshape(-1) 20 | mask = mask.reshape(-1) 21 | 22 | # for numerical stability 23 | logits_max = torch.max(logits) 24 | logits = logits - logits_max.detach() 25 | exp_mixed_logits = torch.exp(logits) 26 | 27 | # mask out self-matching logits 28 | if logits_mask is not None: 29 | logits_mask = logits_mask.reshape(-1) 30 | exp_mixed_logits = exp_mixed_logits * logits_mask 31 | 32 | # cross entropy + softmax 33 | log_prob = logits - torch.log(exp_mixed_logits.sum()) 34 | num_pos_pairs = mask.sum() 35 | 36 | # sum over positive pairs, division is outside the log 37 | num_pos_pairs = 
torch.where(num_pos_pairs < 1e-6, 1, num_pos_pairs) 38 | mean_log_prob_pos = (mask * log_prob).sum() / num_pos_pairs 39 | 40 | # mean over batch samples 41 | loss = -mean_log_prob_pos 42 | loss = loss.mean() 43 | 44 | return loss 45 | 46 | -------------------------------------------------------------------------------- /loss/seesawloss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import Union 6 | 7 | 8 | 9 | class focal_loss(nn.Module): 10 | def __init__(self, alpha=None, gamma=2.0, num_classes=2): 11 | super(focal_loss, self).__init__() 12 | self.alpha = alpha 13 | self.gamma = gamma 14 | self.num_classes = num_classes 15 | 16 | def forward(self, y_pred, y_labels): 17 | y_pred = torch.softmax(y_pred, dim=1) 18 | class_mask = F.one_hot(y_labels, num_classes=self.num_classes) 19 | pt = (y_pred * class_mask).sum(dim=1) 20 | if self.alpha is None: 21 | loss = -((1 - pt) ** self.gamma) * pt.log() 22 | loss = loss.mean() 23 | else: 24 | alpha = self.alpha[y_labels] 25 | loss = -alpha * ((1 - pt) ** self.gamma) * pt.log() 26 | loss = loss.sum() / alpha.sum() 27 | return loss 28 | 29 | 30 | class SeesawLossWithLogits(nn.Module): 31 | """ 32 | This is unofficial implementation for Seesaw loss, 33 | which is proposed in the techinical report for LVIS workshop at ECCV 2020. 34 | For more detail, please refer https://arxiv.org/pdf/2008.10032.pdf. 35 | Args: 36 | class_counts: The list which has number of samples for each class. 37 | Should have same length as num_classes. 38 | p: Scale parameter which adjust the strength of panishment. 39 | Set to 0.8 as a default by following the original paper. 40 | """ 41 | 42 | def __init__(self, class_counts: Union[list, np.array], p: float = 0.8): 43 | super().__init__() 44 | 45 | class_counts = torch.FloatTensor(class_counts) 46 | conditions = class_counts[:, None] > class_counts[None, :] 47 | trues = (class_counts[None, :] / class_counts[:, None]) ** p 48 | print(trues.dtype) 49 | falses = torch.ones(len(class_counts), len(class_counts)) 50 | self.s = torch.where(conditions, trues, falses) 51 | self.num_labels = len(class_counts) 52 | self.eps = 1.0e-6 53 | 54 | def forward(self, logits, targets): 55 | targets = F.one_hot(targets, self.num_labels) 56 | # print(targets.shape) 57 | self.s = self.s.to(targets.device) 58 | max_element, _ = logits.max(axis=-1) 59 | logits = logits - max_element[:, None] # to prevent overflow 60 | 61 | numerator = torch.exp(logits) 62 | denominator = ( 63 | (1 - targets)[:, None, :] 64 | * self.s[None, :, :] 65 | * torch.exp(logits)[:, None, :]).sum(axis=-1) \ 66 | + torch.exp(logits) 67 | 68 | sigma = numerator / (denominator + self.eps) 69 | loss = (- targets * torch.log(sigma + self.eps)).sum(-1) 70 | return loss.mean() 71 | 72 | 73 | class DistibutionAgnosticSeesawLossWithLogits(nn.Module): 74 | """ 75 | This is unofficial implementation for Seesaw loss, 76 | which is proposed in the techinical report for LVIS workshop at ECCV 2020. 77 | For more detail, please refer https://arxiv.org/pdf/2008.10032.pdf. 78 | Args: 79 | p: Parameter for Mitigation Factor, 80 | Set to 0.8 for default following the paper. 81 | q: Parameter for Compensation Factor 82 | Set to 2 for default following the paper. 
83 | num_labels: Class nums 84 | """ 85 | 86 | def __init__(self, p: float = 0.8, q: float = 2, num_labels=2, data_name='PETA'): #num_labels=2 87 | super().__init__() 88 | self.eps = 1.0e-6 89 | self.p = p 90 | self.q = q 91 | self.class_counts = None 92 | self.num_labels = num_labels 93 | self.data_name = data_name 94 | 95 | def forward(self, logits, targets): 96 | # Mitigation Factor 97 | if self.class_counts is None: 98 | self.class_counts = (targets.sum(axis=0) + 1).float() # to prevent devided by zero. 99 | else: 100 | self.class_counts += targets.sum(axis=0) 101 | 102 | m_conditions = self.class_counts[:, None] > self.class_counts[None, :] 103 | m_trues = (self.class_counts[None, :] / self.class_counts[:, None]) ** self.p 104 | m_falses = torch.ones(len(self.class_counts), len(self.class_counts)).to(targets.device) 105 | m = torch.where(m_conditions, m_trues, m_falses) # [num_labels, num_labels] 106 | 107 | # Compensation Factor 108 | probility = logits 109 | 110 | c_condition = probility / (probility * targets).sum(dim=-1)[:, None] 111 | c_condition = torch.stack([c_condition] * targets.shape[-1], dim=1) 112 | c_condition = c_condition * targets[:, :, None] 113 | false = torch.ones(c_condition.shape).to(targets.device) 114 | c = torch.where(c_condition>1, c_condition ** self.q, false) 115 | 116 | # Sij = Mij * Cij 117 | s = m[None, :, :] * c 118 | # s = c 119 | 120 | # softmax trick to prevent overflow (like logsumexp trick) 121 | max_element, _ = logits.max(axis=-1) 122 | logits = logits - max_element[:, None] # to prevent overflow 123 | numerator = torch.exp(logits) 124 | denominator = ( 125 | (1 - targets)[:, None, :] 126 | * s[None, :, :] 127 | * torch.exp(logits)[:, None, :]).sum(axis=-1) \ 128 | + torch.exp(logits) 129 | 130 | sigma = numerator / (denominator + self.eps) 131 | 132 | #seesaw loss 133 | loss = (- targets * torch.log(sigma + self.eps)).sum(-1) 134 | return loss.mean() 135 | 136 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import RAM_Model 2 | 3 | 4 | 5 | def build_model(cfg, clip_model, dataset, clip_model_teacher=None): 6 | model = RAM_Model( 7 | cfg, 8 | clip_model, 9 | dataset.classnames_seen, 10 | dataset.classnames_unseen, 11 | clip_model_teacher 12 | ) 13 | print('Build RAM Model Done') 14 | return model 15 | -------------------------------------------------------------------------------- /model/base.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | from collections import OrderedDict 6 | 7 | from utils.model_utils import save_checkpoint, load_checkpoint, tolist_if_not 8 | from loss import AsymmetricLossOptimized, AsymmetricLossOptimized_partial, mmc_loss 9 | 10 | 11 | class BaseModel(nn.Module): 12 | """ 13 | Basic Model 14 | Implementation of some common functions 15 | """ 16 | def __init__(self): 17 | super().__init__() 18 | self._models = OrderedDict() 19 | self._optims = OrderedDict() 20 | self._scheds = OrderedDict() 21 | 22 | def get_model_names(self, names=None): 23 | names_real = list(self._models.keys()) 24 | if names is not None: 25 | names = tolist_if_not(names) 26 | for name in names: 27 | assert name in names_real 28 | return names 29 | else: 30 | return names_real 31 | 32 | def register_model(self, name="model", model=None, optim=None, sched=None): 33 | assert name not in self._models, "Found 
duplicate model names" 34 | 35 | self._models[name] = model 36 | self._optims[name] = optim 37 | self._scheds[name] = sched 38 | 39 | def get_current_lr(self, names=None): 40 | names = self.get_model_names(names) 41 | name = names[0] 42 | return self._optims[name].param_groups[0]["lr"] 43 | 44 | def get_specific_lr(self, names=None): 45 | if names is None: 46 | names = self.get_model_names(names) 47 | name = names[0] 48 | else: 49 | name = names 50 | return self._optims[name].param_groups[0]["lr"] 51 | 52 | def update_lr(self, names=None): 53 | names = self.get_model_names(names) 54 | 55 | for name in names: 56 | if self._scheds[name] is not None: 57 | self._scheds[name].step() 58 | 59 | def set_model_mode(self, mode="train", names=None): 60 | names = self.get_model_names(names) 61 | 62 | for name in names: 63 | if mode == "train": 64 | self._models[name].train() 65 | elif mode in ["test", "eval"]: 66 | self._models[name].eval() 67 | else: 68 | raise KeyError 69 | 70 | def detect_anomaly(self, loss): 71 | if not torch.isfinite(loss).all(): 72 | raise FloatingPointError("Loss is infinite or NaN!") 73 | 74 | def save_model(self, iters, directory, is_best=False): 75 | # save registered_module 76 | names = self.get_model_names() 77 | 78 | for name in names: 79 | model_dict = self._models[name].state_dict() 80 | save_dict = OrderedDict() 81 | for k, v in self._models[name].named_parameters(): 82 | if v.requires_grad: 83 | save_dict[k] = model_dict[k] 84 | 85 | sdir = os.path.join(directory, name) 86 | save_checkpoint( 87 | { 88 | "state_dict": save_dict, 89 | "iters": iters, 90 | }, 91 | sdir, 92 | is_best 93 | ) 94 | 95 | print(f"Checkpoint of {name} saved to {sdir}") 96 | 97 | def load_model(self, directory, iters): 98 | model_file = f"model-iters{iters}.pth" 99 | names = self.get_model_names() 100 | 101 | for name in names: 102 | model_path = os.path.join(directory, name, model_file) 103 | if not os.path.exists(model_path): 104 | raise FileNotFoundError('Model not found at "{}"'.format(model_path)) 105 | checkpoint = load_checkpoint(model_path) 106 | state_dict = checkpoint["state_dict"] 107 | iters = checkpoint["iters"] 108 | 109 | # Ignore fixed token vectors 110 | if "token_prefix" in state_dict: 111 | del state_dict["token_prefix"] 112 | if "token_suffix" in state_dict: 113 | del state_dict["token_suffix"] 114 | 115 | print("Loading weights to {} " 'from "{}"'.format(name, model_path)) 116 | self._models[name].load_state_dict(state_dict, strict=False) 117 | 118 | def make_criterion(self, cfg): 119 | """ 120 | Classification loss 121 | - Zero-shot setting: MMC loss 122 | - Partial-label setting: ASL partial loss 123 | """ 124 | if cfg.MODEL.LOSS_TYPE == 'MMC': 125 | criterion = mmc_loss 126 | elif cfg.MODEL.LOSS_TYPE == 'ASL': 127 | criterion = AsymmetricLossOptimized(cfg.SOLVER.GAMMA_NEG, cfg.SOLVER.GAMMA_POS, cfg.SOLVER.CLIP) 128 | elif cfg.MODEL.LOSS_TYPE == 'ASL-partial': 129 | criterion = AsymmetricLossOptimized_partial(cfg.SOLVER.GAMMA_NEG, cfg.SOLVER.GAMMA_POS, cfg.SOLVER.CLIP) 130 | else: 131 | raise NotImplementedError 132 | 133 | return criterion 134 | 135 | def freeze(self, transfer_type): 136 | if hasattr(self, "clip_model_teacher") and self.clip_model_teacher is not None: 137 | for name, param in self.clip_model_teacher.named_parameters(): 138 | param.requires_grad = False 139 | 140 | if transfer_type == "no_freeze": 141 | pass 142 | 143 | elif transfer_type == "freeze_all": 144 | for name, param in self.clip_model.named_parameters(): 145 | param.requires_grad = False 146 | 
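The `freeze()` method (its `freeze_text`, `Adapter`, and `partial-N` branches continue right below) decides which CLIP parameters stay trainable; with `TRANSFER_TYPE: "Adapter"`, as in both YAML configs, only parameter names containing `adapter` or `embeds` should survive. A small hedged helper to sanity-check what actually remains trainable after calling `freeze` (the `build_model` usage is illustrative):

```python
# Hedged sketch: report which parameters remain trainable after freeze().
# With TRANSFER_TYPE "Adapter" (see the branch below), only names containing
# "adapter" or "embeds" are expected to show up here.
def summarize_trainable(model):
    trainable = [(name, p.numel()) for name, p in model.named_parameters() if p.requires_grad]
    total = sum(p.numel() for p in model.parameters())
    kept = sum(n for _, n in trainable)
    print(f"trainable: {kept:,} / {total:,} params ({100.0 * kept / max(total, 1):.2f}%)")
    for name, numel in trainable[:10]:  # peek at the first few surviving names
        print("   ", name, numel)
    return trainable


# Illustrative usage (build_model is defined in model/__init__.py):
# model = build_model(cfg, clip_model, dataset, clip_model_teacher)
# model.freeze(cfg.MODEL.TRANSFER_TYPE)
# summarize_trainable(model)
```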
147 | elif transfer_type == "freeze_text": 148 | for name, param in self.clip_model.named_parameters(): 149 | if 'visual.' in name: 150 | continue 151 | else: 152 | param.requires_grad = False 153 | 154 | elif transfer_type == "Adapter": 155 | for name, param in self.clip_model.named_parameters(): 156 | if "adapter" in name or "embeds" in name: # embeds 157 | param.requires_grad = True 158 | else: 159 | param.requires_grad = False 160 | 161 | elif "partial" in transfer_type: 162 | total_layer = len(self.clip_model.visual.transformer.resblocks) 163 | partial_layer = int(transfer_type.split("-")[-1]) 164 | if partial_layer > total_layer: 165 | raise NotImplementedError 166 | for name, param in self.clip_model.named_parameters(): 167 | find = False 168 | for l in range(total_layer-partial_layer, total_layer): 169 | if "visual.transformer.resblocks.{}".format(l) in name: 170 | param.requires_grad = True 171 | find = True 172 | break 173 | if not find: 174 | param.requires_grad = False 175 | 176 | else: 177 | raise NotImplementedError 178 | 179 | def load_param(self, trained_path): 180 | param_dict = torch.load(trained_path) 181 | for i in param_dict: 182 | self.state_dict()[i.replace('module.', '')].copy_(param_dict[i]) 183 | print('Loading pretrained model from {}'.format(trained_path)) 184 | 185 | 186 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributed as dist 5 | 6 | import clip 7 | from clip import clip 8 | from .base import BaseModel 9 | from .ot_solver import Sinkhorn 10 | 11 | 12 | class RAM_Model(BaseModel): 13 | def __init__( 14 | self, 15 | cfg, 16 | clip_model, 17 | classnames_seen, 18 | classnames_unseen, 19 | clip_model_teacher 20 | ): 21 | super().__init__() 22 | self.cfg = cfg 23 | self.classnames_seen = classnames_seen 24 | self.classnames_unseen = classnames_unseen 25 | self.num_classes_seen = len(classnames_seen) 26 | self.num_classes_unseen = len(classnames_unseen) 27 | self.criterion = self.make_criterion(cfg) 28 | self.device = dist.get_rank() if cfg.MODEL.DIST_TRAIN else 'cuda' 29 | 30 | # create teacher model, preserve teacher text features 31 | self.clip_model = clip_model 32 | self.text_tokens = self.get_text_templates() 33 | self.text_fea_teacher = self.get_text_fea(clip_model_teacher, classnames_seen+classnames_unseen) 34 | self.clip_model_teacher = clip_model_teacher 35 | 36 | # freeze the model 37 | self.freeze(cfg.MODEL.TRANSFER_TYPE) 38 | 39 | # pre-trained logit scale 40 | self.logit_scale = self.clip_model.logit_scale.exp() 41 | # learnable temperature 42 | self.temperature_loc = nn.Parameter(torch.tensor(cfg.MODEL.TEMPERATURE)) 43 | self.temperature = nn.Parameter(1./self.logit_scale) 44 | 45 | # KCOT parameters 46 | self.reg = cfg.MODEL.OT_REG 47 | self.reg_sc = cfg.MODEL.OT_REGSC 48 | 49 | def get_text_fea(self, clip_model, classnames): 50 | text_templates = "A photo of a {}." 51 | text_templates = [text_templates.format(classnames[i]) for i in range(len(classnames))] 52 | text_tok = clip.tokenize(text_templates) 53 | with torch.no_grad(): 54 | text_fea = clip_model.encode_text(text_tok) 55 | return text_fea.unsqueeze(1).detach() 56 | 57 | def get_text_templates(self): 58 | templates = "A photo of a {}." 
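For reference, the two prompt helpers here differ only in what they return: `get_text_fea` pre-computes frozen teacher text embeddings and keeps an extra group axis, while `get_text_templates` (its body continues just below) only tokenizes the same "A photo of a {}." prompts for the trainable encoder. A shape sketch, assuming the vendored CLIP tokenizer and a ViT-B/16 text encoder (77-token context, 512-d embeddings) with an illustrative three-class label set:

```python
# Shape sketch for the prompt helpers, assuming CLIP ViT-B/16
# (77-token context, 512-d text embeddings); class names are illustrative.
from clip import clip

classnames = ["person", "bicycle", "car"]
texts = [f"A photo of a {name}." for name in classnames]
tokens = clip.tokenize(texts)
print(tokens.shape)  # torch.Size([3, 77]) -> what get_text_templates returns (before .cuda())

# get_text_fea additionally runs the frozen teacher text encoder:
#   encode_text(tokens) -> [3, 512], then unsqueeze(1) -> [3, 1, 512]
```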
59 | texts = [templates.format(name) for name in self.classnames_seen+self.classnames_unseen] 60 | text_tokens = clip.tokenize(texts) 61 | return text_tokens.cuda() 62 | 63 | def build_weights(self, sim, dim=-1, temperature=0.1): 64 | with torch.no_grad(): 65 | sim_max = sim.max(dim=dim)[0] 66 | weights = (sim_max / temperature).softmax(dim=-1) 67 | return weights 68 | 69 | def generate_teacher_distribution(self, img_teacher, zsl=False, gzsl=False): 70 | with torch.no_grad(): 71 | _, img_loc = self.clip_model_teacher.visual(img_teacher) 72 | img_loc = img_loc[0][:, 1:] 73 | text_fea = self.text_fea_teacher.clone().cuda() 74 | if zsl: 75 | text_fea = text_fea[self.num_classes_seen:] 76 | elif gzsl: 77 | pass 78 | else: 79 | text_fea = text_fea[:self.num_classes_seen] 80 | B, tok, dim=img_loc.shape 81 | C, gp, dim = text_fea.shape 82 | 83 | text_fea = F.normalize(text_fea, dim=-1) 84 | img_loc = F.normalize(img_loc, dim=-1) 85 | text_fea = text_fea.unsqueeze(0).expand(B, -1, -1, -1).reshape(B, -1, dim) 86 | 87 | # generate weight 88 | logit_scale = self.clip_model_teacher.logit_scale.exp() 89 | logits_loc = logit_scale * img_loc @ text_fea.transpose(-2, -1) 90 | logits_loc = logits_loc.reshape(B, -1, C, gp) 91 | local_similarity = logits_loc.softmax(dim=2) 92 | prob = (local_similarity*20.).softmax(dim=1) 93 | prob = prob.mean(dim=-1) 94 | return prob 95 | 96 | def forward(self, img, target=None, zsl=False, gzsl=False): 97 | seen = True if not zsl and not gzsl else False 98 | if seen: 99 | text_tokens = self.text_tokens[:self.num_classes_seen].clone() 100 | elif zsl: 101 | text_tokens = self.text_tokens[self.num_classes_seen:self.num_classes_seen+self.num_classes_unseen].clone() 102 | else: 103 | text_tokens = self.text_tokens[:self.num_classes_seen+self.num_classes_unseen].clone() 104 | prompt_fea_loc = self.clip_model.encode_text(text_tokens) 105 | prompt_fea_loc = prompt_fea_loc.unsqueeze(1) 106 | 107 | img_glb, img_loc = self.clip_model.visual(img, text_fea=prompt_fea_loc) 108 | img_loc = img_loc[0][:, 1:] 109 | 110 | B, tok, dim = img_loc.shape 111 | C, gp, dim = prompt_fea_loc.shape 112 | 113 | prompt_fea_loc = prompt_fea_loc.permute(1, 0, 2) 114 | 115 | img_glb = F.normalize(img_glb, dim=-1) 116 | img_loc = F.normalize(img_loc, dim=-1) 117 | prompt_fea_loc = F.normalize(prompt_fea_loc, dim=-1) 118 | 119 | logits_glb = img_glb @ prompt_fea_loc.transpose(1, 2) / self.temperature 120 | score_glb = logits_glb.squeeze(1).softmax(dim=-1) 121 | if self.training: 122 | mask = target 123 | loss_glb = self.criterion(logits_glb, mask) 124 | 125 | # Cost matrix 126 | sim = img_loc @ prompt_fea_loc.transpose(1, 2) 127 | cost = (sim * self.logit_scale).softmax(dim=-1) 128 | cost = 1.0 - cost 129 | 130 | if self.training: 131 | # Teacher is only applied in training 132 | frozen_mask = self.generate_teacher_distribution(img, zsl, gzsl) 133 | gt_mask = target.unsqueeze(1).expand(-1, tok, -1) 134 | frozen_mask[gt_mask==0] = frozen_mask.min() 135 | cost_tr = -torch.log(frozen_mask) * self.reg_sc 136 | cost = cost + cost_tr 137 | reg = self.reg + self.reg_sc 138 | else: 139 | reg = self.reg 140 | 141 | u = self.build_weights(sim.detach(), dim=2, temperature=0.1) 142 | v = torch.zeros((B, C), dtype=sim.dtype, device=sim.device).fill_(1./C) 143 | with torch.no_grad(): 144 | T = Sinkhorn(u, v, cost, reg=reg) 145 | if torch.isnan(T).any(): 146 | raise ValueError("Found nan in OT matrix!") 147 | 148 | sim_op = T * sim 149 | sim_op = sim_op.sum(dim=1) / self.temperature_loc 150 | score_loc = 
sim_op.softmax(dim=-1) 151 | score = (score_glb + score_loc) / 2. 152 | if self.training: 153 | mask = target 154 | loss_loc = self.criterion(sim_op, mask) 155 | loss = loss_glb + loss_loc 156 | return {"score": score, "loss": loss} 157 | else: 158 | return {"score": score} 159 | 160 | 161 | 162 | -------------------------------------------------------------------------------- /model/ot_solver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sinkhorn Algorithm for different settings of Optimal Transport 3 | Implementation inspired from: https://github.com/PythonOT/POT, thanks 4 | """ 5 | 6 | 7 | import torch 8 | 9 | 10 | 11 | def Sinkhorn(a, b, M, reg, max_iter=100, thresh=1e-3): 12 | """ 13 | Sinkhorn Iteration 14 | Solving Entropic Optimal Transport (EOT) 15 | Args: 16 | a: torch.Tensor[B, N], B - batch size, N - number of points in the source distribution 17 | b: torch.Tensor[B, M], B - batch size, M - number of points in the target distribution 18 | M: torch.Tensor[B, N, M], cost matrix 19 | reg: float, regularization strength 20 | max_iter: int, maximum number of iterations 21 | thresh: float, convergence threshold 22 | Returns: 23 | T: torch.Tensor[B, N, M], transport plan 24 | """ 25 | K = torch.exp(-M / reg) 26 | r = torch.ones_like(a) 27 | c = torch.ones_like(b) 28 | thresh = 1e-3 29 | 30 | for i in range(max_iter): 31 | r0 = r 32 | r = a / torch.matmul(K, c.unsqueeze(-1)).squeeze(-1) 33 | c = b / torch.matmul(K.permute(0, 2, 1).contiguous(), r.unsqueeze(-1)).squeeze(-1) 34 | err = (r - r0).abs().mean(dim=1) 35 | if torch.all(err < thresh): 36 | break 37 | T = torch.matmul(r.unsqueeze(-1), c.unsqueeze(-2)) * K 38 | return T 39 | 40 | 41 | 42 | def Sinkhorn_entropic_unbalanced(a, b, M, reg, reg_m, max_iter=100, thresh=1e-3): 43 | """ 44 | Sinkhorn Iteration 45 | Solving Entropic Unbalanced Optimal Transport (EUOT) 46 | Args: 47 | a: torch.Tensor[B, N], B - batch size, N - number of points in the source distribution 48 | b: torch.Tensor[B, M], B - batch size, M - number of points in the target distribution 49 | M: torch.Tensor[B, N, M], cost matrix 50 | reg: float, entropy regularization strength 51 | reg_m: float, marginal regularization strength 52 | max_iter: int, maximum number of iterations 53 | thresh: float, convergence threshold 54 | Returns: 55 | T: torch.Tensor[B, N, M], transport plan 56 | """ 57 | if isinstance(reg_m, float) or isinstance(reg_m, int): 58 | reg_m1, reg_m2 = reg_m, reg_m 59 | else: 60 | reg_m1, reg_m2 = reg_m[0], reg_m[1] 61 | 62 | u = torch.ones_like(a) 63 | v = torch.ones_like(b) 64 | 65 | # entropic reg 66 | K = torch.exp(-M / reg) 67 | # kl unbalanced 68 | fi_1 = reg_m1 / (reg_m1 + reg) 69 | fi_2 = reg_m2 / (reg_m2 + reg) 70 | 71 | thresh = 1e-3 72 | for i in range(max_iter): 73 | uprev = u 74 | vprev = v 75 | 76 | Kv = torch.matmul(K, v.unsqueeze(-1)).squeeze(-1) 77 | u = (a / Kv) ** fi_1 78 | Ktu = torch.matmul(K.permute(0, 2, 1).contiguous(), u.unsqueeze(-1)).squeeze(-1) 79 | v = (b / Ktu) ** fi_2 80 | 81 | max_u = torch.cat([torch.max(torch.abs(u), dim=1, keepdim=True)[0], torch.max(torch.abs(uprev), dim=1, keepdim=True)[0], torch.ones((u.shape[0], 1)).cuda()], dim=1) 82 | max_v = torch.cat([torch.max(torch.abs(v), dim=1, keepdim=True)[0], torch.max(torch.abs(vprev), dim=1, keepdim=True)[0], torch.ones((v.shape[0], 1)).cuda()], dim=1) 83 | 84 | err_u = torch.max(torch.abs(u - uprev), dim=1)[0] / torch.max(max_u, dim=1)[0] 85 | err_v = torch.max(torch.abs(v - vprev), dim=1)[0] / torch.max(max_v, 
dim=1)[0] 86 | 87 | err = 0.5 * (err_u.mean() + err_v.mean()) 88 | if err.item() < thresh: 89 | break 90 | 91 | T = torch.matmul(u.unsqueeze(-1), v.unsqueeze(-2)) * K 92 | return T 93 | 94 | 95 | 96 | def Sinkhorn_unbalanced(a, b, M, reg_m, div='kl', reg=0, max_iter=100, thresh=1e-3): 97 | """ 98 | Sinkhorn Iteration 99 | Solving Unbalanced Optimal Transport (UOT) 100 | Args: 101 | a: torch.Tensor[B, N], B - batch size, N - number of points in the source distribution 102 | b: torch.Tensor[B, M], B - batch size, M - number of points in the target distribution 103 | M: torch.Tensor[B, N, M], cost matrix 104 | reg_m: float, marginals regularization strength 105 | div: regularization method ("kl", "l2") 106 | max_iter: int, maximum number of iterations 107 | thresh: float, convergence threshold 108 | Returns: 109 | T: torch.Tensor[B, N, M], transport plan 110 | """ 111 | if isinstance(reg_m, float) or isinstance(reg_m, int): 112 | reg_m1, reg_m2 = reg_m, reg_m 113 | else: 114 | reg_m1, reg_m2 = reg_m[0], reg_m[1] 115 | 116 | G = torch.matmul(a.unsqueeze(-1), b.unsqueeze(-2)) 117 | c = torch.matmul(a.unsqueeze(-1), b.unsqueeze(-2)) 118 | assert div in ["kl", "l2"] 119 | 120 | if div == 'kl': 121 | sum_r = reg + reg_m1 + reg_m2 122 | r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r 123 | K = torch.matmul(a.unsqueeze(-1)**r1, b.unsqueeze(-2)**r2) * (c**r) * torch.exp(-M / sum_r) 124 | elif div == 'l2': 125 | K = reg_m1 * a.unsqueeze(-1) + reg_m2 * b.unsqueeze(-2) + reg * c - M 126 | K = torch.max(K, torch.zeros_like(M)) 127 | 128 | thresh = 1e-3 129 | for i in range(max_iter): 130 | Gprev = G 131 | 132 | if div == 'kl': 133 | Gd = torch.matmul(torch.sum(G, dim=-1, keepdim=True)**r1, torch.sum(G, dim=1, keepdim=True)**r2) + 1e-16 134 | G = K * G**(r1 + r2) / Gd 135 | elif div == 'l2': 136 | Gd = reg_m1 * torch.sum(G, dim=-1, keepdim=True) + \ 137 | reg_m2 * torch.sum(G, dim=1, keepdim=True) + reg * G + 1e-16 138 | G = K * G / Gd 139 | 140 | err = torch.sqrt(torch.sum((G - Gprev) ** 2, dim=(1,2)).mean()) 141 | if err < thresh: 142 | break 143 | 144 | return G 145 | 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /processor/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import do_train -------------------------------------------------------------------------------- /processor/processor.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import os 4 | import time 5 | import wandb 6 | import datetime 7 | import numpy as np 8 | from clip import convert_weights 9 | import torch 10 | import torch.nn as nn 11 | from torch.cuda import amp 12 | import torch.distributed as dist 13 | from torch.autograd import Variable 14 | 15 | from utils.model_utils import thread_flag 16 | from utils.model_utils import ModelEma 17 | from utils.meter import AverageMeter 18 | from utils.metrics import multilabel_evaluation 19 | 20 | import warnings 21 | warnings.filterwarnings("ignore") 22 | 23 | 24 | 25 | 26 | def do_train( 27 | cfg, 28 | model, 29 | train_loader, 30 | val_loader, 31 | val_loader_gzsl, 32 | optimizer, 33 | optimizer_sgd, 34 | scheduler, 35 | scheduler_sgd, 36 | output_dir, 37 | train_sampler=None, 38 | ): 39 | log_period = cfg.SOLVER.LOG_PERIOD 40 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 41 | eval_period = cfg.SOLVER.EVAL_PERIOD 42 | device = "cuda" 43 | epochs = cfg.SOLVER.MAX_EPOCHS 44 | logger = 
logging.getLogger("RAM.train") 45 | logger.info('start training') 46 | 47 | if device: 48 | model.to(cfg.LOCAL_RANK) 49 | if torch.cuda.device_count() > 1 and cfg.MODEL.DIST_TRAIN: 50 | print('Using {} GPUs for training'.format(torch.cuda.device_count())) 51 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device) 52 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[cfg.LOCAL_RANK], find_unused_parameters=True) 53 | 54 | loss_meter = AverageMeter() 55 | acc_meter = AverageMeter() 56 | batch_time = AverageMeter() 57 | scaler = amp.GradScaler() 58 | 59 | # Training 60 | label_smoothing = cfg.SOLVER.LABEL_SMOOTHING 61 | torch.cuda.empty_cache() 62 | thresh = cfg.SOLVER.THRESH 63 | ema_m = None 64 | tot_iters = len(train_loader) * epochs 65 | 66 | for epoch in range(1, epochs + 1): 67 | if cfg.MODEL.USE_EMA: 68 | if cfg.MODEL.DIST_TRAIN: 69 | ema_m = ModelEma(model.module, cfg.MODEL.EMA_DECAY, device=dist.get_rank()) 70 | else: 71 | ema_m = ModelEma(model, cfg.MODEL.EMA_DECAY, device=device) 72 | 73 | if train_sampler is not None: 74 | train_sampler.set_epoch(epoch) 75 | torch.cuda.empty_cache() 76 | 77 | loss_meter.reset() 78 | acc_meter.reset() 79 | 80 | scheduler.step(epoch) 81 | scheduler_sgd.step(epoch) 82 | 83 | model.train() 84 | for n_iter, (img, label) in enumerate(train_loader): 85 | 86 | start = time.time() 87 | if cfg.SOLVER.LR_SCHEDULER == 'onecycle': 88 | scheduler.step() 89 | scheduler_sgd.step(epoch) 90 | correct = 0 91 | total = 0 92 | optimizer.zero_grad() 93 | optimizer_sgd.zero_grad() 94 | img = img.to(device) 95 | 96 | # construct GT matrix 97 | if label_smoothing: 98 | label_f = label.float() 99 | label_soft = torch.where(label_f == 1, torch.tensor(0.9), label_f) 100 | label_soft = torch.where(label_soft == 0, torch.tensor(0.1), label_soft) 101 | target = label_soft.to(device) 102 | else: 103 | target = label.to(device) 104 | 105 | with amp.autocast(enabled=True): 106 | output = model(img, target=target) 107 | 108 | score = output["score"] 109 | loss = output["loss"] 110 | 111 | # score, loss 112 | scaler.scale(loss).backward() 113 | gpu_mem = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0) 114 | scaler.step(optimizer) 115 | scaler.step(optimizer_sgd) 116 | scaler.update() 117 | if ema_m is not None: 118 | if cfg.MODEL.DIST_TRAIN: 119 | ema_m.update(model.module) 120 | else: 121 | ema_m.update(model) 122 | 123 | targets = Variable(label) 124 | label = label.numpy() 125 | outputs_np = score.data.cpu().numpy() 126 | predicted = outputs_np > thresh 127 | correct += np.sum(predicted == label, axis=0) 128 | total += targets.size(0) 129 | acc = np.mean(correct / total) 130 | 131 | loss_meter.update(loss.item(), img.shape[0]) 132 | acc_meter.update(acc, 1) 133 | torch.cuda.synchronize() 134 | 135 | batch_time.update(time.time() - start) 136 | 137 | # Logging 138 | if (n_iter + 1) % log_period == 0: 139 | if thread_flag(cfg.MODEL.DIST_TRAIN): 140 | now_iter = (epoch-1) * len(train_loader) + n_iter 141 | nb_remain = tot_iters - now_iter 142 | eta_seconds = batch_time.avg * nb_remain 143 | eta = str(datetime.timedelta(seconds=int(eta_seconds))) 144 | cur_lr = optimizer.param_groups[0]['lr'] 145 | 146 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, lr: {:.2e}, mem: {:.2f}MB, speed:{:.2f}[img/s], ETA: {}" 147 | .format(epoch, n_iter+1, len(train_loader), loss_meter.avg, acc, cur_lr, gpu_mem, train_loader.batch_size/batch_time.avg, eta)) 148 | 149 | if cfg.WANDB: 150 | wandb.log({ 151 | "epoch": epoch, 152 | "lr": cur_lr, 
153 | "train loss": loss_meter.avg, 154 | "train acc": acc 155 | }) 156 | 157 | output_path = os.path.join(output_dir, f'epoch{epoch}.pth') 158 | if cfg.SOLVER.SAVE_MODEL and epoch % checkpoint_period == 0: 159 | if thread_flag(cfg.MODEL.DIST_TRAIN): 160 | torch.save(model.state_dict(), output_path) 161 | 162 | # Testing 163 | if epoch % eval_period == 0: 164 | if thread_flag(cfg.MODEL.DIST_TRAIN): 165 | Result_k, Result_k5 = validate(cfg, val_loader, model, device, zsl=True, gzsl=False) 166 | Result_k_gzsl, Result_k5_gzsl = validate(cfg, val_loader_gzsl, model, device, zsl=False, gzsl=True) 167 | now_metric = (Result_k["OF1"] + Result_k_gzsl["OF1"]) / 2. 168 | 169 | if ema_m is not None: 170 | Result_k_ema, Result_k5_ema = validate(cfg, val_loader, ema_m.module, device, zsl=True, gzsl=False) 171 | Result_k_gzsl_ema, Result_k5_gzsl_ema = validate(cfg, val_loader_gzsl, ema_m.module, device, zsl=False, gzsl=True) 172 | 173 | logger.info("Validation Results - Epoch: {}, F1_avg: {:.3%}".format(epoch, now_metric)) 174 | logger.info("ZSL:") 175 | logger.info("OP_3: {:.3%}, OR_3: {:.3%}, OF1_3: {:.3%}".format(Result_k['OP'], Result_k['OR'], Result_k['OF1'])) 176 | logger.info("OP_5: {:.3%}, OR_5: {:.3%}, OF1_5: {:.3%}".format(Result_k5['OP'], Result_k5['OR'], Result_k5['OF1'])) 177 | logger.info("GZSL:") 178 | logger.info("OP_3: {:.3%}, OR_3: {:.3%}, OF1_3: {:.3%}".format(Result_k_gzsl['OP'], Result_k_gzsl['OR'], Result_k_gzsl['OF1'])) 179 | logger.info("OP_5: {:.3%}, OR_5: {:.3%}, OF1_5: {:.3%}".format(Result_k5_gzsl['OP'], Result_k5_gzsl['OR'], Result_k5_gzsl['OF1'])) 180 | if ema_m is not None: 181 | logger.info("EMA Results:") 182 | logger.info("ZSL:") 183 | logger.info("OP_3: {:.3%}, OR_3: {:.3%}, OF1_3: {:.3%}".format(Result_k_ema['OP'], Result_k_ema['OR'], Result_k_ema['OF1'])) 184 | logger.info("OP_5: {:.3%}, OR_5: {:.3%}, OF1_5: {:.3%}".format(Result_k5_ema['OP'], Result_k5_ema['OR'], Result_k5_ema['OF1'])) 185 | logger.info("GZSL:") 186 | logger.info("OP_3: {:.3%}, OR_3: {:.3%}, OF1_3: {:.3%}".format(Result_k_gzsl_ema['OP'], Result_k_gzsl_ema['OR'], Result_k_gzsl_ema['OF1'])) 187 | logger.info("OP_5: {:.3%}, OR_5: {:.3%}, OF1_5: {:.3%}".format(Result_k5_gzsl_ema['OP'], Result_k5_gzsl_ema['OR'], Result_k5_gzsl_ema['OF1'])) 188 | 189 | # 3. 
log wandb 190 | if cfg.WANDB: 191 | wandb.log({ 192 | "F1_avg ZSL-GZSL": now_metric, 193 | "OP_3": Result_k['OP'], "OR_3": Result_k['OR'], "OF1_3": Result_k['OF1'], 194 | "OP_5": Result_k5['OP'], "OR_5": Result_k5['OR'], "OF1_5": Result_k5['OF1'], 195 | "OP_3 GZSL": Result_k_gzsl['OP'], "OR_3 GZSL": Result_k_gzsl['OR'], "OF1_3 GZSL": Result_k_gzsl['OF1'], 196 | "OP_5 GZSL": Result_k5_gzsl['OP'], "OR_5 GZSL": Result_k5_gzsl['OR'], "OF1_5 GZSL": Result_k5_gzsl['OF1'], 197 | }) 198 | if ema_m is not None: 199 | wandb.log({ 200 | "OP_3 ema": Result_k_ema['OP'], "OR_3 ema": Result_k_ema['OR'], "OF1_3 ema": Result_k_ema['OF1'], 201 | "OP_5 ema": Result_k5_ema['OP'], "OR_5 ema": Result_k5_ema['OR'], "OF1_5 ema": Result_k5_ema['OF1'], 202 | "OP_3 GZSL ema": Result_k_gzsl_ema['OP'], "OR_3 GZSL ema": Result_k_gzsl_ema['OR'], "OF1_3 GZSL ema": Result_k_gzsl_ema['OF1'], 203 | "OP_5 GZSL ema": Result_k5_gzsl_ema['OP'], "OR_5 GZSL ema": Result_k5_gzsl_ema['OR'], "OF1_5 GZSL ema": Result_k5_gzsl_ema['OF1'], 204 | }) 205 | 206 | torch.cuda.empty_cache() 207 | 208 | 209 | 210 | def validate(cfg, val_loader, model, device, zsl=True, gzsl=False): 211 | model.eval() 212 | total = 0 213 | batch_idx = 0 214 | batch_time = AverageMeter() 215 | 216 | for n_iter, (img, label) in enumerate(val_loader): 217 | with torch.no_grad(): 218 | img = img.to(device) 219 | target = label.to(device) 220 | start = time.time() 221 | output = model(img, target=target, zsl=zsl, gzsl=gzsl) 222 | score = output["score"] 223 | label = label.numpy() 224 | outputs_np = score.data.cpu().numpy() 225 | gpu_mem = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0) 226 | 227 | batch_time.update(time.time() - start) 228 | 229 | if total == 0: 230 | g_labels = label 231 | p_score = outputs_np 232 | else: 233 | g_labels = np.row_stack((g_labels, label)) 234 | p_score = np.row_stack((p_score, outputs_np)) 235 | 236 | total += label.shape[0] 237 | batch_idx += 1 238 | 239 | if (n_iter + 1) % cfg.SOLVER.LOG_PERIOD == 0: 240 | print("mem:{:.2f}, test speed:{:.2f}".format(gpu_mem, val_loader.batch_size/batch_time.avg)) 241 | 242 | Result_k = multilabel_evaluation(p_score, g_labels, k=3) 243 | Result_k5 = multilabel_evaluation(p_score, g_labels, k=5) 244 | torch.cuda.empty_cache() 245 | 246 | return Result_k, Result_k5 247 | 248 | 249 | 250 | def parse_batch(batch): 251 | input, label = batch 252 | return input, label 253 | 254 | 255 | 256 | 257 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy==6.3.1 2 | imgaug==0.4.0 3 | ipdb==0.13.13 4 | numpy==1.25.2 5 | opencv_contrib_python==4.8.0.76 6 | opencv_python==4.8.0.76 7 | Pillow==9.4.0 8 | 9 | regex==2023.8.8 10 | setuptools==68.0.0 11 | timm==0.5.4 12 | torch==1.13.1 13 | torchvision==0.14.1 14 | tqdm==4.66.1 15 | yacs==0.1.8 16 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_optimizer import make_optimizer 2 | from .make_scheduler import make_scheduler -------------------------------------------------------------------------------- /solver/make_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | def make_optimizer(cfg, model, lr_mult=0.1): 6 | clip_params, sgd_params, other_params = [], [], [] 7 | for pname, p in
model.named_parameters(): 8 | if not p.requires_grad: 9 | continue 10 | elif 'embeds' in pname: 11 | sgd_params.append(p) 12 | elif pname.startswith('clip'): 13 | clip_params.append(p) 14 | else: 15 | other_params.append(p) 16 | 17 | # Optimizer1 18 | param_groups = [ 19 | {'params': clip_params, 'lr': cfg.SOLVER.BASE_LR * lr_mult, 'weight_decay': cfg.SOLVER.WEIGHT_DECAY}, 20 | {'params': other_params, 'lr': cfg.SOLVER.BASE_LR, 'weight_decay': cfg.SOLVER.WEIGHT_DECAY}, 21 | ] 22 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 23 | optimizer = torch.optim.SGD(param_groups, momentum=cfg.SOLVER.MOMENTUM) 24 | elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW': 25 | optimizer = torch.optim.AdamW(param_groups) 26 | else: 27 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(param_groups, momentum=cfg.SOLVER.MOMENTUM) 28 | 29 | # Optimizer2 30 | param_groups_sgd = [{'params': sgd_params, 'lr': cfg.SOLVER.BASE_LR_SGD, 'weight_decay': cfg.SOLVER.WEIGHT_DECAY_SGD}] 31 | optimizer_sgd = torch.optim.SGD(param_groups_sgd, momentum=cfg.SOLVER.MOMENTUM) 32 | 33 | return optimizer, optimizer_sgd 34 | 35 | 36 | -------------------------------------------------------------------------------- /solver/make_scheduler.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import math 5 | import torch 6 | from torch.optim.lr_scheduler import OneCycleLR 7 | from typing import Dict, Any 8 | 9 | 10 | 11 | class Scheduler: 12 | """ Parameter Scheduler Base Class 13 | A scheduler base class that can be used to schedule any optimizer parameter groups. 14 | 15 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 16 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 17 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 18 | 19 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 20 | 21 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 22 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 23 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 
24 | 25 | Based on ideas from: 26 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 27 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 28 | """ 29 | 30 | def __init__(self, 31 | optimizer: torch.optim.Optimizer, 32 | param_group_field: str, 33 | noise_range_t=None, 34 | noise_type='normal', 35 | noise_pct=0.67, 36 | noise_std=1.0, 37 | noise_seed=None, 38 | initialize: bool = True) -> None: 39 | self.optimizer = optimizer 40 | self.param_group_field = param_group_field 41 | self._initial_param_group_field = f"initial_{param_group_field}" 42 | if initialize: 43 | for i, group in enumerate(self.optimizer.param_groups): 44 | if param_group_field not in group: 45 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 46 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 47 | else: 48 | for i, group in enumerate(self.optimizer.param_groups): 49 | if self._initial_param_group_field not in group: 50 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 51 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 52 | self.metric = None # any point to having this for all? 53 | self.noise_range_t = noise_range_t 54 | self.noise_pct = noise_pct 55 | self.noise_type = noise_type 56 | self.noise_std = noise_std 57 | self.noise_seed = noise_seed if noise_seed is not None else 42 58 | self.update_groups(self.base_values) 59 | 60 | def state_dict(self) -> Dict[str, Any]: 61 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 62 | 63 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 64 | self.__dict__.update(state_dict) 65 | 66 | def get_epoch_values(self, epoch: int): 67 | return None 68 | 69 | def get_update_values(self, num_updates: int): 70 | return None 71 | 72 | def step(self, epoch: int, metric: float = None) -> None: 73 | self.metric = metric 74 | values = self.get_epoch_values(epoch) 75 | if values is not None: 76 | values = self._add_noise(values, epoch) 77 | self.update_groups(values) 78 | 79 | def step_update(self, num_updates: int, metric: float = None): 80 | self.metric = metric 81 | values = self.get_update_values(num_updates) 82 | if values is not None: 83 | values = self._add_noise(values, num_updates) 84 | self.update_groups(values) 85 | 86 | def update_groups(self, values): 87 | if not isinstance(values, (list, tuple)): 88 | values = [values] * len(self.optimizer.param_groups) 89 | for param_group, value in zip(self.optimizer.param_groups, values): 90 | param_group[self.param_group_field] = value 91 | 92 | def _add_noise(self, lrs, t): 93 | if self.noise_range_t is not None: 94 | if isinstance(self.noise_range_t, (list, tuple)): 95 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 96 | else: 97 | apply_noise = t >= self.noise_range_t 98 | if apply_noise: 99 | g = torch.Generator() 100 | g.manual_seed(self.noise_seed + t) 101 | if self.noise_type == 'normal': 102 | while True: 103 | # resample if noise out of percent limit, brute force but shouldn't spin much 104 | noise = torch.randn(1, generator=g).item() 105 | if abs(noise) < self.noise_pct: 106 | break 107 | else: 108 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 109 | lrs = [v + v * noise for v in lrs] 110 | return lrs 111 | 112 | 113 | class CosineLRScheduler(Scheduler): 114 | """ 115 | Cosine decay with restarts. 
116 | This is described in the paper https://arxiv.org/abs/1608.03983. 117 | 118 | Inspiration from 119 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 120 | """ 121 | 122 | def __init__(self, 123 | optimizer: torch.optim.Optimizer, 124 | t_initial: int, 125 | t_mul: float = 1., 126 | lr_min: float = 0., 127 | decay_rate: float = 1., 128 | warmup_t=0, 129 | warmup_lr_init=0, 130 | warmup_prefix=False, 131 | cycle_limit=0, 132 | t_in_epochs=True, 133 | noise_range_t=None, 134 | noise_pct=0.67, 135 | noise_std=1.0, 136 | noise_seed=42, 137 | initialize=True) -> None: 138 | super().__init__( 139 | optimizer, param_group_field="lr", 140 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 141 | initialize=initialize) 142 | 143 | assert t_initial > 0 144 | assert lr_min >= 0 145 | self.t_initial = t_initial 146 | self.t_mul = t_mul 147 | self.lr_min = lr_min 148 | self.decay_rate = decay_rate 149 | self.cycle_limit = cycle_limit 150 | self.warmup_t = warmup_t 151 | self.warmup_lr_init = warmup_lr_init 152 | self.warmup_prefix = warmup_prefix 153 | self.t_in_epochs = t_in_epochs 154 | if self.warmup_t: 155 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 156 | super().update_groups(self.warmup_lr_init) 157 | else: 158 | self.warmup_steps = [1 for _ in self.base_values] 159 | 160 | def _get_lr(self, t): 161 | if t < self.warmup_t: 162 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 163 | else: 164 | if self.warmup_prefix: 165 | t = t - self.warmup_t 166 | 167 | if self.t_mul != 1: 168 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 169 | t_i = self.t_mul ** i * self.t_initial 170 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 171 | else: 172 | i = t // self.t_initial 173 | t_i = self.t_initial 174 | t_curr = t - (self.t_initial * i) 175 | 176 | gamma = self.decay_rate ** i 177 | lr_min = self.lr_min * gamma 178 | lr_max_values = [v * gamma for v in self.base_values] 179 | 180 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 181 | lrs = [ 182 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 183 | ] 184 | else: 185 | lrs = [self.lr_min for _ in self.base_values] 186 | 187 | return lrs 188 | 189 | def get_epoch_values(self, epoch: int): 190 | if self.t_in_epochs: 191 | return self._get_lr(epoch) 192 | else: 193 | return None 194 | 195 | def get_update_values(self, num_updates: int): 196 | if not self.t_in_epochs: 197 | return self._get_lr(num_updates) 198 | else: 199 | return None 200 | 201 | def get_cycle_length(self, cycles=0): 202 | if not cycles: 203 | cycles = self.cycle_limit 204 | cycles = max(1, cycles) 205 | if self.t_mul == 1.0: 206 | return self.t_initial * cycles 207 | else: 208 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 209 | 210 | 211 | def make_scheduler(cfg, optimizer, dataloader_len): 212 | lr_scheduler = cfg.SOLVER.LR_SCHEDULER 213 | lr_mult = cfg.SOLVER.IMS_PER_BATCH / 256 214 | num_epochs = cfg.SOLVER.SCHEDULER_MAX_EPOCHS 215 | lr_min = 0.002 * cfg.SOLVER.BASE_LR 216 | warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 217 | 218 | warmup_t = cfg.SOLVER.WARMUP_EPOCHS 219 | noise_range = None 220 | 221 | if lr_scheduler == 'onecycle': 222 | lr_scheduler = OneCycleLR( 223 | optimizer, 224 | max_lr=[cfg.SOLVER.BASE_LR_CLIP * lr_mult, cfg.SOLVER.BASE_LR 
* lr_mult], 225 | steps_per_epoch=dataloader_len, 226 | epochs=cfg.SOLVER.MAX_EPOCHS, 227 | pct_start=0.2 228 | ) 229 | else: 230 | lr_scheduler = CosineLRScheduler( 231 | optimizer, 232 | t_initial=num_epochs, 233 | lr_min=lr_min, 234 | t_mul= 1., 235 | decay_rate=0.1, 236 | warmup_lr_init=warmup_lr_init, 237 | warmup_t=warmup_t, 238 | cycle_limit=1, 239 | t_in_epochs=True, 240 | noise_range_t=noise_range, 241 | noise_pct= 0.67, 242 | noise_std= 1., 243 | noise_seed=42, 244 | ) 245 | 246 | return lr_scheduler 247 | 248 | 249 | -------------------------------------------------------------------------------- /src/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EricTan7/RAM/477f5051234819a11ca62ec441ecb3ee33abda65/src/method.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils.logger import setup_logger 3 | from datasets import make_dataloader 4 | from model import build_model 5 | from solver import make_optimizer, make_scheduler 6 | from processor import do_train 7 | from utils.model_utils import set_seed, load_clip_to_cpu, thread_flag 8 | import torch 9 | import torch.distributed as dist 10 | import argparse 11 | from config import cfg 12 | import warnings 13 | warnings.filterwarnings("ignore") 14 | import time 15 | import wandb 16 | 17 | 18 | 19 | def main(cfg): 20 | set_seed(cfg.SOLVER.SEED) 21 | 22 | if cfg.MODEL.DIST_TRAIN: 23 | torch.cuda.set_device(cfg.LOCAL_RANK) 24 | dist.init_process_group(backend='nccl', init_method='env://') 25 | dist.barrier() 26 | 27 | # 1. Logging 28 | if cfg.WANDB: 29 | run = wandb.init(project=cfg.WANDB_PROJ, config=cfg) 30 | run.name = f'{cfg.DATASETS.NAMES}-{cfg.SOLVER.OPTIMIZER_NAME}-lr{cfg.SOLVER.BASE_LR}' 31 | 32 | output_dir = os.path.join(cfg.OUTPUT_DIR, 'RAM-' + time.strftime('%Y-%m-%d-%H-%M-%S')) 33 | if thread_flag(cfg.MODEL.DIST_TRAIN): 34 | if output_dir and not os.path.exists(output_dir): 35 | os.makedirs(output_dir) 36 | logger = setup_logger("RAM", output_dir, if_train=True) 37 | logger.info("Saving model in the path :{}".format(cfg.OUTPUT_DIR)) 38 | logger.info("Running with config:\n{}".format(cfg)) 39 | 40 | # 2. Data 41 | train_loader, val_loader, val_loader_gzsl, train_sampler, dataset = make_dataloader(cfg) 42 | 43 | # 3. Model 44 | clip_model = load_clip_to_cpu(cfg) 45 | clip_model_teacher = load_clip_to_cpu(cfg, zero_shot=True).eval() 46 | model = build_model(cfg, clip_model, dataset, clip_model_teacher) 47 | 48 | 49 | if cfg.MODEL.LOAD: 50 | state_dict = torch.load(cfg.TEST.WEIGHT) 51 | model.load_state_dict(state_dict) 52 | print(f"Load weights from: {cfg.TEST.WEIGHT}") 53 | 54 | # 4. Optimizer 55 | optimizer, optimizer_sgd = make_optimizer(cfg, model) 56 | scheduler, scheduler_sgd = make_scheduler(cfg, optimizer, len(train_loader)), make_scheduler(cfg, optimizer_sgd, len(train_loader)) 57 | 58 | # 5. 
Start training 59 | do_train( 60 | cfg, 61 | model, 62 | train_loader, 63 | val_loader, 64 | val_loader_gzsl, 65 | optimizer, 66 | optimizer_sgd, 67 | scheduler, 68 | scheduler_sgd, 69 | output_dir, 70 | train_sampler, 71 | ) 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser(description="RAM Training") 76 | parser.add_argument( 77 | "--config_file", 78 | help="path to config file", type=str, default="configs/coco.yml" 79 | ) 80 | parser.add_argument( 81 | "--local_rank", 82 | type=int, default=0 83 | ) 84 | parser.add_argument( 85 | "opts", 86 | help="Modify config options from command-line", nargs=argparse.REMAINDER, default=None 87 | ) 88 | args = parser.parse_args() 89 | 90 | if args.config_file != "": 91 | cfg.merge_from_file(args.config_file) 92 | cfg.merge_from_list(args.opts) 93 | cfg.LOCAL_RANK = args.local_rank 94 | 95 | main(cfg) 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EricTan7/RAM/477f5051234819a11ca62ec441ecb3ee33abda65/utils/__init__.py -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import os 5 | 6 | 7 | def setup_logger(name, save_dir, if_train): 8 | logger = logging.getLogger(name) 9 | logger.setLevel(logging.DEBUG) 10 | 11 | ch = logging.StreamHandler(stream=sys.stdout) 12 | ch.setLevel(logging.DEBUG) 13 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 14 | ch.setFormatter(formatter) 15 | logger.addHandler(ch) 16 | 17 | if save_dir: 18 | if not os.path.exists(save_dir): 19 | os.makedirs(save_dir) 20 | if if_train: 21 | fh = logging.FileHandler(os.path.join(save_dir, "train_log.txt"), mode='a') 22 | else: 23 | fh = logging.FileHandler(os.path.join(save_dir, "test_log.txt"), mode='a') 24 | fh.setLevel(logging.DEBUG) 25 | fh.setFormatter(formatter) 26 | logger.addHandler(fh) 27 | 28 | return logger -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.val = 0 6 | self.avg = 0 7 | self.sum = 0 8 | self.count = 0 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.parallel 3 | import torch.optim 4 | import torch.utils.data.distributed 5 | import numpy as np 6 | 7 | 8 | 9 | 10 | def cal_metrics(p_labels, g_labels): 11 | tp, fp, fn, tn = 0, 0, 0, 0 12 | target = g_labels 13 | 14 | pred = p_labels 15 | 16 | tp += ((pred + target)==2).sum(axis=0) 17 | fp += ((pred - target)==1).sum(axis=0) 18 | fn += ((pred - target)==-1).sum(axis=0) 19 | tn += ((pred + target)==0).sum(axis=0) 20 | p_c = [float(tp[i] / (tp[i] + fp[i])) if tp[i] > 0 else 0.0 for i in range(len(tp))] 21 | r_c = [float(tp[i] / (tp[i] 
+ fn[i])) if tp[i] > 0 else 0.0 for i in range(len(tp))] 22 | 23 | mean_p_c = sum(p_c) / len(p_c) 24 | mean_r_c = sum(r_c) / len(r_c) 25 | if mean_p_c==0 and mean_r_c ==0: 26 | mean_f_c = 0. 27 | else: 28 | mean_f_c = 2 * mean_p_c * mean_r_c / (mean_p_c + mean_r_c) 29 | 30 | p_o = tp.sum() / (tp + fp).sum() 31 | r_o = tp.sum() / (tp + fn).sum() 32 | if p_o==0 and r_o ==0: 33 | f_o = 0. 34 | else: 35 | f_o = 2 * p_o * r_o / (p_o + r_o) 36 | 37 | Result = {} 38 | Result['CP'] = mean_p_c 39 | Result['CR'] = mean_r_c 40 | Result['CF1'] = mean_f_c 41 | Result['OP'] = p_o 42 | Result['OR'] = r_o 43 | Result['OF1'] = f_o 44 | 45 | return Result 46 | 47 | 48 | 49 | def multilabel_evaluation(scores, targets, k=1): 50 | scores = torch.tensor(scores) 51 | targets[targets == -1] = 0 52 | n, c = scores.size() 53 | pred = np.zeros((n, c)) 54 | index = scores.topk(k, 1, True, True)[1].numpy() 55 | for i in range(n): 56 | for ind in index[i]: 57 | pred[i, ind] = 1 58 | return cal_metrics(pred, targets) 59 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import torch.distributed as dist 5 | from clip import clip 6 | from functools import partial 7 | from collections import OrderedDict 8 | from copy import deepcopy 9 | import os 10 | import errno 11 | import pickle 12 | 13 | 14 | 15 | class ModelEma(torch.nn.Module): 16 | def __init__(self, model, decay=0.9997, device=None): 17 | super(ModelEma, self).__init__() 18 | # make a copy of the model for accumulating moving average of weights 19 | self.module = deepcopy(model) 20 | self.module.eval() 21 | 22 | self.decay = decay 23 | self.device = device 24 | if self.device is not None: 25 | self.module.to(device=device) 26 | 27 | def _update(self, model, update_fn): 28 | with torch.no_grad(): 29 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): 30 | if self.device is not None: 31 | model_v = model_v.to(device=self.device) 32 | ema_v.copy_(update_fn(ema_v, model_v)) 33 | self.module.temperature = deepcopy(model.temperature) 34 | if hasattr(model, "temperature_glb"): 35 | self.module.temperature_glb = deepcopy(model.temperature_glb) 36 | 37 | def update(self, model): 38 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. 
- self.decay) * m) 39 | 40 | def set(self, model): 41 | self._update(model, update_fn=lambda e, m: m) 42 | 43 | 44 | 45 | 46 | def set_seed(seed): 47 | torch.manual_seed(seed) 48 | torch.cuda.manual_seed(seed) 49 | torch.cuda.manual_seed_all(seed) 50 | np.random.seed(seed) 51 | random.seed(seed) 52 | torch.backends.cudnn.deterministic = True 53 | torch.backends.cudnn.benchmark = False 54 | 55 | 56 | def thread_flag(dist_train): 57 | if not dist_train: 58 | return True 59 | else: 60 | return dist.get_rank() == 0 61 | 62 | 63 | def getModelSize(model): 64 | param_size = 0 65 | param_sum = 0 66 | grad_param_size = 0 67 | grad_param_sum = 0 68 | for param in model.parameters(): 69 | param_size += param.nelement() * param.element_size() 70 | param_sum += param.nelement() 71 | if param.requires_grad == True: 72 | grad_param_size += param.nelement() * param.element_size() 73 | grad_param_sum += param.nelement() 74 | print('total number of params:{:.3f}M'.format(param_sum / 1000 / 1000)) 75 | print('trainable number of params:{:.3f}M ({:.5%})'.format(grad_param_sum / 1000 / 1000, grad_param_sum/param_sum)) 76 | 77 | return (param_size, param_sum, grad_param_size) 78 | 79 | 80 | 81 | def convert_params_to_value(params): 82 | if params[0] == -1: 83 | return [-1] # not using 84 | elif params[-1] == -1: 85 | return list(range(params[0])) # continuous N layers 86 | else: 87 | return params 88 | 89 | 90 | def load_clip_to_cpu(cfg, zero_shot=False): 91 | backbone_name = cfg.MODEL.BACKBONE 92 | url = clip._MODELS[backbone_name] 93 | model_path = clip._download(url) 94 | 95 | try: 96 | # loading JIT archive 97 | model = torch.jit.load(model_path, map_location="cpu").eval() 98 | state_dict = None 99 | 100 | except RuntimeError: 101 | state_dict = torch.load(model_path, map_location="cpu") 102 | 103 | if zero_shot: 104 | saa_layer = [12, -1] if "ViT-B" in backbone_name else [24, -1] 105 | saa_layer = convert_params_to_value(saa_layer) 106 | design_details = { 107 | "depth_vision": [-1], 108 | "depth_text": [-1], 109 | "SAA_layer": saa_layer 110 | } 111 | print("Build zero-shot CLIP Model") 112 | else: 113 | depth_vision = convert_params_to_value(cfg.MODEL.DEPTH_VISION) 114 | depth_text = convert_params_to_value(cfg.MODEL.DEPTH_TEXT) 115 | saa_layer = convert_params_to_value(cfg.MODEL.SAA_LAYER) 116 | design_details = { 117 | "depth_vision": depth_vision, 118 | "vision_adapt": cfg.MODEL.VISION_ADAPT, 119 | "depth_text": depth_text, 120 | "text_ctx": cfg.MODEL.TEXT_CTX, 121 | "SAA_layer": saa_layer, 122 | "kernel_size": cfg.MODEL.KERNEL_SIZE 123 | } 124 | print("Build CLIP Model") 125 | 126 | model = clip.build_model(state_dict or model.state_dict(), cfg.INPUT.SIZE_TRAIN, design_details) 127 | model.visual.SAA_replace() 128 | 129 | return model.float() 130 | 131 | 132 | 133 | def mkdir_if_missing(dirname): 134 | """Create dirname if it is missing.""" 135 | if not os.path.exists(dirname): 136 | try: 137 | os.makedirs(dirname) 138 | except OSError as e: 139 | if e.errno != errno.EEXIST: 140 | raise 141 | 142 | 143 | def tolist_if_not(x): 144 | """Convert to a list.""" 145 | if not isinstance(x, list): 146 | x = [x] 147 | return x 148 | 149 | 150 | def save_checkpoint( 151 | state, 152 | save_dir, 153 | is_best=False, 154 | remove_module_from_keys=True 155 | ): 156 | r"""Save checkpoint. 157 | 158 | Args: 159 | state (dict): dictionary. 160 | save_dir (str): directory to save checkpoint. 161 | is_best (bool, optional): if True, this checkpoint will be copied and named 162 | ``model-best.pth.tar``. 
Default is False. 163 | remove_module_from_keys (bool, optional): whether to remove "module." 164 | from layer names. Default is True. 165 | 166 | """ 167 | mkdir_if_missing(save_dir) 168 | 169 | if remove_module_from_keys: 170 | # remove 'module.' in state_dict's keys 171 | state_dict = state["state_dict"] 172 | new_state_dict = OrderedDict() 173 | for k, v in state_dict.items(): 174 | if k.startswith("module."): 175 | k = k[7:] 176 | new_state_dict[k] = v 177 | state["state_dict"] = new_state_dict 178 | 179 | # save model 180 | iters = state["iters"] 181 | if is_best: 182 | model_name = "model-best.pth" 183 | else: 184 | model_name = f"model-iters{iters}.pth" 185 | fpath = os.path.join(save_dir, model_name) 186 | 187 | torch.save(state, fpath) 188 | 189 | 190 | def load_checkpoint(fpath): 191 | r"""Load checkpoint. 192 | 193 | ``UnicodeDecodeError`` is handled gracefully, which means 194 | files saved with Python 2 can be read with Python 3. 195 | 196 | Args: 197 | fpath (str): path to checkpoint. 198 | 199 | Returns: 200 | dict 201 | 202 | Examples:: 203 | fpath = 'log/my_model/model.pth.tar-10' 204 | checkpoint = load_checkpoint(fpath) 205 | """ 206 | if fpath is None: 207 | raise ValueError("File path is None") 208 | 209 | if not os.path.exists(fpath): 210 | raise FileNotFoundError('File not found at "{}"'.format(fpath)) 211 | 212 | map_location = None if torch.cuda.is_available() else "cpu" 213 | 214 | try: 215 | checkpoint = torch.load(fpath, map_location=map_location) 216 | 217 | except UnicodeDecodeError: 218 | pickle.load = partial(pickle.load, encoding="latin1") 219 | pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1") 220 | checkpoint = torch.load( 221 | fpath, pickle_module=pickle, map_location=map_location 222 | ) 223 | 224 | except Exception: 225 | print('Unable to load checkpoint from "{}"'.format(fpath)) 226 | raise 227 | 228 | return checkpoint 229 | 230 | 231 | 232 | def load_pretrained_weights(model, weight_path): 233 | r"""Load pretrained weights into a model. 234 | 235 | Features:: 236 | - Incompatible layers (unmatched in name or size) will be ignored. 237 | - Keys prefixed with "module." are handled automatically. 238 | 239 | Args: 240 | model (nn.Module): network model. 241 | weight_path (str): path to pretrained weights. 242 | 243 | Examples:: 244 | # >>> weight_path = 'log/my_model/model-best.pth.tar' 245 | # >>> load_pretrained_weights(model, weight_path) 246 | """ 247 | checkpoint = load_checkpoint(weight_path) 248 | if "state_dict" in checkpoint: 249 | state_dict = checkpoint["state_dict"] 250 | else: 251 | state_dict = checkpoint 252 | 253 | model_dict = model.state_dict() 254 | new_state_dict = OrderedDict() 255 | matched_layers, discarded_layers = [], [] 256 | 257 | for k, v in state_dict.items(): 258 | if k.startswith("module."): 259 | k = k[7:] # discard module.
260 | 261 | if k in model_dict and model_dict[k].size() == v.size(): 262 | new_state_dict[k] = v 263 | matched_layers.append(k) 264 | else: 265 | discarded_layers.append(k) 266 | 267 | model_dict.update(new_state_dict) 268 | model.load_state_dict(model_dict) 269 | 270 | if len(matched_layers) == 0: 271 | print( 272 | f"Cannot load {weight_path} (check the key names manually)" 273 | ) 274 | else: 275 | print(f"Successfully loaded pretrained weights from {weight_path}") 276 | if len(discarded_layers) > 0: 277 | print( 278 | f"Layers discarded due to unmatched keys or size: {discarded_layers}" 279 | ) 280 | --------------------------------------------------------------------------------
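Usage note: the following is a minimal, illustrative sketch of how the entry points shown above (make_dataloader, build_model, load_clip_to_cpu, and the validate helper from the processor module) might be combined to score a checkpoint saved by do_train. The config file, the checkpoint path, the `processor.processor` import location, and single-GPU, non-distributed execution are assumptions for illustration, not part of the repository.

```python
# eval_sketch.py -- illustrative only, not shipped with the repository.
# The checkpoint path and the `processor.processor` import are assumptions.
import torch

from config import cfg
from datasets import make_dataloader
from model import build_model
from processor.processor import validate            # assumed location of validate()
from utils.model_utils import load_clip_to_cpu, set_seed

if __name__ == "__main__":
    cfg.merge_from_file("configs/coco.yml")                          # any provided config
    cfg.merge_from_list(["TEST.WEIGHT", "checkpoints/epoch10.pth"])  # hypothetical checkpoint
    set_seed(cfg.SOLVER.SEED)
    device = "cuda"   # validate() queries CUDA memory stats, so a GPU is assumed

    # Build data and model exactly as train.py does.
    _, val_loader, val_loader_gzsl, _, dataset = make_dataloader(cfg)
    clip_model = load_clip_to_cpu(cfg)
    clip_model_teacher = load_clip_to_cpu(cfg, zero_shot=True).eval()
    model = build_model(cfg, clip_model, dataset, clip_model_teacher)

    # do_train saves a raw state_dict (non-distributed case), so it restores directly.
    model.load_state_dict(torch.load(cfg.TEST.WEIGHT, map_location="cpu"))
    model.to(device)

    # Top-3 / top-5 metrics on unseen labels (ZSL) and on the full label set (GZSL).
    zsl_k3, zsl_k5 = validate(cfg, val_loader, model, device, zsl=True, gzsl=False)
    gzsl_k3, gzsl_k5 = validate(cfg, val_loader_gzsl, model, device, zsl=False, gzsl=True)
    print("ZSL   OF1@3: {:.3%}  OF1@5: {:.3%}".format(zsl_k3["OF1"], zsl_k5["OF1"]))
    print("GZSL  OF1@3: {:.3%}  OF1@5: {:.3%}".format(gzsl_k3["OF1"], gzsl_k5["OF1"]))
```

EMA weights and distributed evaluation are left out of the sketch for brevity; do_train above shows how both are handled during training.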