├── .gitignore
├── README.md
├── clip
│   ├── __init__.py
│   ├── adapters.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── config
│   ├── __init__.py
│   └── defaults.py
├── configs
│   ├── coco.yml
│   └── nus.yml
├── datasets
│   ├── MultiLabel
│   │   └── classification.py
│   ├── __init__.py
│   ├── bases.py
│   ├── json
│   │   ├── test_17_filtered.json
│   │   ├── test_65_filtered.json
│   │   └── train_48_filtered.json
│   └── make_dataloader.py
├── loss
│   ├── __init__.py
│   ├── asymmetric_loss.py
│   ├── mmc_loss.py
│   └── seesawloss.py
├── model
│   ├── __init__.py
│   ├── base.py
│   ├── model.py
│   └── ot_solver.py
├── processor
│   ├── __init__.py
│   └── processor.py
├── requirements.txt
├── solver
│   ├── __init__.py
│   ├── make_optimizer.py
│   └── make_scheduler.py
├── src
│   └── method.png
├── train.py
└── utils
    ├── __init__.py
    ├── logger.py
    ├── meter.py
    ├── metrics.py
    └── model_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
173 | # Temp folders
174 | data/
175 | wandb/
176 | checkpoints/
177 | .vscode/
178 |
179 | # logging files
180 | runs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RAM: Open-Vocabulary Multi-Label Recognition through Knowledge-Constrained Optimal Transport
2 |
3 | Official implementation of the CVPR 2025 [paper](https://arxiv.org/abs/2503.15337):
4 |
5 | **Recover and Match: Open-Vocabulary Multi-Label Recognition through Knowledge-Constrained Optimal Transport**
6 |
7 | ## 📨 Introduction
8 |
9 | RAM is an efficient matching framework for OVMLR (Open-Vocabulary Multi-Label Recognition). To address the key limitations of existing methods, RAM introduces (1) LLA (Local Adapter) to recover regional semantics, and (2) KCOT (Knowledge-Constrained Optimal Transport) to find precise region-to-label matching.
10 |
11 |
12 |
13 | ![Overview of the RAM framework](src/method.png)
14 |
15 |
16 | ## 🔧 Installation
17 |
18 | Installing the environment with conda and pip is recommended:
19 |
20 | ```shell
21 | conda create -n ram python=3.10
22 | conda activate ram
23 |
24 | # Install the dependencies
25 | pip install -r requirements.txt
26 | ```
27 |
28 |
29 | ## 🎯 Running the code
30 | - `model/model.py`: Implementation of RAM model
31 | - `model/ot_solver.py`: Implementation of Sinkhorn Algorithm
32 | - `clip/adapters.py`: Implementation of LLA (Local Adapter)
33 | - `loss/mmc_loss.py`: Implementation of MMC loss (Multi-Matching loss)
34 |
35 | Run the following code to start training:
36 | ```shell
37 | python train.py --config_file configs/coco.yml
38 | ```
39 | Use wandb to log the running:
40 | ```shell
41 | python train.py --config_file configs/coco.yml WANDB True
42 | ```
43 |
44 | ## 💬 Discussion
45 | The core contribution is the OT-based matching pipeline, which we found beneficial to the OVMLR task while remaining highly efficient.
46 | The matching framework can be easily extended to dense prediction tasks (e.g., **semantic segmentation**), and we welcome efforts to transfer our approach to segmentation scenarios.
47 |
48 | If you find our work useful, please cite our paper:
49 |
50 | ```
51 | @article{tan2025recoverandmatch,
52 | title={Recover and Match: Open-Vocabulary Multi-Label Recognition through Knowledge-Constrained Optimal Transport},
53 | author={Hao Tan and Zichang Tan and Jun Li and Ajian Liu and Jun Wan and Zhen Lei},
54 | journal={arXiv preprint arXiv:2503.15337},
55 | year={2025}
56 | }
57 | ```
58 |
59 | ## Acknowledgements
60 |
61 | This repo benefits from [MaPLe](https://github.com/muzairkhattak/multimodal-prompt-learning), [CLIP-Surgery](https://github.com/xmed-lab/CLIP_Surgery) and [POT](https://github.com/PythonOT/POT). Thanks for their wonderful works.
62 |
--------------------------------------------------------------------------------
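The README above credits the OT-based matching pipeline, with the Sinkhorn iterations implemented in `model/ot_solver.py` (not included in this excerpt). For orientation, a generic entropic-regularized Sinkhorn sketch is shown below; it is not the repo's solver, and the cost matrix, marginals, and `reg` value are purely illustrative:

```python
import torch

def sinkhorn(cost, a, b, reg=0.1, n_iters=50):
    """Plain entropic Sinkhorn: returns a transport plan whose marginals approximate a and b."""
    K = torch.exp(-cost / reg)            # Gibbs kernel
    u = torch.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.t() @ u)               # scale columns toward marginal b
        u = a / (K @ v)                   # scale rows toward marginal a
    return u[:, None] * K * v[None, :]    # plan = diag(u) @ K @ diag(v)

# Toy usage: match 4 "regions" to 3 "labels" under uniform marginals.
cost = torch.rand(4, 3)
plan = sinkhorn(cost, torch.full((4,), 1.0 / 4), torch.full((3,), 1.0 / 3))
print(plan.sum(dim=1), plan.sum(dim=0))   # both sums should be close to the marginals
```

In the paper's matching view, rows would correspond to image regions and columns to label embeddings; KCOT additionally constrains the plan with knowledge priors, which this generic sketch does not model.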
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 | from .model import CLIP, convert_weights
3 |
--------------------------------------------------------------------------------
/clip/adapters.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
6 |
7 | class h_sigmoid(nn.Module):
8 | def __init__(self, inplace=True):
9 | super(h_sigmoid, self).__init__()
10 | self.relu = nn.ReLU6(inplace=inplace)
11 |
12 | def forward(self, x):
13 | return self.relu(x + 3) / 6
14 |
15 |
16 | class h_swish(nn.Module):
17 | def __init__(self, inplace=True):
18 | super(h_swish, self).__init__()
19 | self.sigmoid = h_sigmoid(inplace=inplace)
20 |
21 | def forward(self, x):
22 | return x * self.sigmoid(x)
23 |
24 |
25 | class QuickGELU(nn.Module):
26 | def forward(self, x: torch.Tensor):
27 | return x * torch.sigmoid(1.702 * x)
28 |
29 |
30 |
31 | class PFCrossAttention(nn.Module):
32 | """
33 | Parameter-free Cross-Attention
34 | https://arxiv.org/abs/2209.14169
35 | """
36 | def __init__(self, dim, qk_scale=None, attn_drop=0., proj_drop=0.):
37 | super().__init__()
38 | self.scale = qk_scale or dim ** -0.5
39 |
40 | self.attn_drop = nn.Dropout(attn_drop)
41 | self.proj_drop = nn.Dropout(proj_drop)
42 |
43 | def forward(self, query, kv):
44 | k, v = kv, kv
45 |
46 | attn = (query @ k.transpose(1,2)) * self.scale
47 | attn = attn.softmax(dim=-1)
48 | attn = self.attn_drop(attn)
49 |
50 | query = attn @ v
51 | query = self.proj_drop(query)
52 |
53 | return query
54 |
55 |
56 |
57 | class LocalAdapter(nn.Module):
58 | def __init__(self, in_dim, out_dim, stride=1, hidden_dim=None, kernel_size=3, text_dim=512):
59 | super().__init__()
60 | hidden_dim = text_dim
61 | layers1 = [
62 | nn.Conv2d(in_dim,
63 | hidden_dim,
64 | kernel_size=1,
65 | stride=stride,
66 | padding=1//2,
67 | bias=False),
68 | nn.BatchNorm2d(hidden_dim),
69 | nn.GELU()
70 | ]
71 | self.conv_adapter_layers1 = nn.Sequential(*layers1)
72 | layers2 = [
73 | nn.Conv2d(in_dim,
74 | hidden_dim,
75 | kernel_size=3,
76 | stride=stride,
77 | padding=3//2,
78 | bias=False),
79 | nn.BatchNorm2d(hidden_dim),
80 | nn.GELU()
81 | ]
82 | self.conv_adapter_layers2 = nn.Sequential(*layers2)
83 |
84 | self.conv_adapter_layers = nn.Conv2d(2, 2, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, bias=False)
85 | self.conv_adapter_final = nn.Conv2d(hidden_dim, out_dim, kernel_size=1, stride=stride, padding=1//2, bias=False)
86 |
87 | self.adapter_norm_q1 = nn.LayerNorm(text_dim)
88 | self.adapter_norm_q2 = nn.LayerNorm(text_dim)
89 | self.adapter_norm_kv = nn.LayerNorm(text_dim)
90 | self.scale = text_dim ** -0.5
91 | self.adapter_cattn = PFCrossAttention(text_dim)
92 |
93 | def forward(self, x, text_fea=None):
94 | # x: [num_tokens, batch, dim] (token-first layout from the transformer)
95 | x_cls, x = x[:1], x[1:]
96 | tok, B, dim = x.shape
97 | H = int(tok**0.5)
98 | x_loc = x.permute(1, 2, 0).reshape(B, dim, H, H)
99 |
100 | # Dual convolutional streams
101 | x_loc1 = self.conv_adapter_layers1(x_loc)
102 | x_loc2 = self.conv_adapter_layers2(x_loc)
103 | x_loc1 = x_loc1.reshape(B, -1, tok)
104 | x_loc2 = x_loc2.reshape(B, -1, tok)
105 |
106 | # Cross attention with text features
107 | x_loc1 = x_loc1 + self.adapter_cattn(
108 | self.adapter_norm_q1(x_loc1.permute(0, 2, 1)),
109 | self.adapter_norm_kv(text_fea.permute(1, 0, 2))
110 | ).permute(0, 2, 1)
111 | x_loc2 = x_loc2 + self.adapter_cattn(
112 | self.adapter_norm_q2(x_loc2.permute(0, 2, 1)),
113 | self.adapter_norm_kv(text_fea.permute(1, 0, 2))
114 | ).permute(0, 2, 1)
115 |
116 | # Reshape and concat
117 | x_loc1 = x_loc1.reshape(B, -1, H, H)
118 | x_loc2 = x_loc2.reshape(B, -1, H, H)
119 | x_loc = torch.cat([x_loc1, x_loc2], dim=1)
120 |
121 | # Max and Average pooling
122 | avg_x = torch.mean(x_loc, dim=1, keepdim=True)
123 | max_x, _ = torch.max(x_loc, dim=1, keepdim=True)
124 |
125 | # Aggregated convolution
126 | agg = torch.cat([avg_x, max_x], dim=1)
127 | y = self.conv_adapter_layers(agg)
128 | y = torch.sigmoid(y)
129 |
130 | # Final multiplication and convolution
131 | x = x_loc1 * y[:, 0].unsqueeze(1) + x_loc2 * y[:, 1].unsqueeze(1)
132 | x = self.conv_adapter_final(x)
133 | x = x.reshape(B, -1, tok).permute(2, 0, 1)
134 | x = torch.cat([x_cls, x], dim=0)
135 |
136 | return x
137 |
138 |
--------------------------------------------------------------------------------
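A quick way to sanity-check `LocalAdapter` above is a dummy forward pass. The sizes below are illustrative assumptions (ViT-B/16 width 768, a 14×14 patch grid, batch of 2, 80 text features of dimension 512), not values fixed by this file:

```python
import torch
from clip.adapters import LocalAdapter

adapter = LocalAdapter(in_dim=768, out_dim=768, kernel_size=3, text_dim=512)
x = torch.randn(1 + 14 * 14, 2, 768)   # [CLS + patch tokens, batch, width], token-first layout
text_fea = torch.randn(80, 2, 512)     # [num_text_embeddings, batch, text_dim]
out = adapter(x, text_fea=text_fea)
print(out.shape)                       # torch.Size([197, 2, 768]) -- same layout as the input
```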
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EricTan7/RAM/477f5051234819a11ca62ec441ecb3ee33abda65/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
36 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
37 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
38 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
39 | }
40 |
41 |
42 | def _download(url: str, root: str = os.path.expanduser("/mnt/nas/TrueNas1/pretrain/clip")):
43 | os.makedirs(root, exist_ok=True)
44 | filename = os.path.basename(url)
45 |
46 | expected_sha256 = url.split("/")[-2]
47 | download_target = os.path.join(root, filename)
48 |
49 | if os.path.exists(download_target) and not os.path.isfile(download_target):
50 | raise RuntimeError(f"{download_target} exists and is not a regular file")
51 |
52 | if os.path.isfile(download_target):
53 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
54 | return download_target
55 | else:
56 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
57 |
58 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
59 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
60 | while True:
61 | buffer = source.read(8192)
62 | if not buffer:
63 | break
64 |
65 | output.write(buffer)
66 | loop.update(len(buffer))
67 |
68 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
69 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
70 |
71 | return download_target
72 |
73 |
74 | def _transform(n_px):
75 | return Compose([
76 | Resize(n_px, interpolation=BICUBIC),
77 | CenterCrop(n_px),
78 | lambda image: image.convert("RGB"),
79 | ToTensor(),
80 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
81 | ])
82 |
83 |
84 | def available_models() -> List[str]:
85 | """Returns the names of available CLIP models"""
86 | return list(_MODELS.keys())
87 |
88 |
89 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False, download_root=None,
90 | input_size=None, design_details=None):
91 | """Load a CLIP model
92 |
93 | Parameters
94 | ----------
95 | name : str
96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
97 |
98 | device : Union[str, torch.device]
99 | The device to put the loaded model
100 |
101 | jit : bool
102 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
103 |
104 | Returns
105 | -------
106 | model : torch.nn.Module
107 | The CLIP model
108 |
109 | preprocess : Callable[[PIL.Image], torch.Tensor]
110 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
111 | """
112 | if name in _MODELS:
113 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
114 | elif os.path.isfile(name):
115 | model_path = name
116 | else:
117 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
118 |
119 | try:
120 | # loading JIT archive
121 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
122 | state_dict = None
123 | except RuntimeError:
124 | # loading saved state dict
125 | if jit:
126 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
127 | jit = False
128 | state_dict = torch.load(model_path, map_location="cpu")
129 |
130 | if not jit:
131 | model = build_model(state_dict or model.state_dict(), input_size, design_details).to(device)
132 | if str(device) == "cpu":
133 | model.float()
134 | return model, _transform(model.visual.input_resolution)
135 |
136 | # patch the device names
137 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
138 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
139 |
140 | def patch_device(module):
141 | try:
142 | graphs = [module.graph] if hasattr(module, "graph") else []
143 | except RuntimeError:
144 | graphs = []
145 |
146 | if hasattr(module, "forward1"):
147 | graphs.append(module.forward1.graph)
148 |
149 | for graph in graphs:
150 | for node in graph.findAllNodes("prim::Constant"):
151 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
152 | node.copyAttributes(device_node)
153 |
154 | model.apply(patch_device)
155 | patch_device(model.encode_image)
156 | patch_device(model.encode_text)
157 |
158 | # patch dtype to float32 on CPU
159 | if str(device) == "cpu":
160 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
161 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
162 | float_node = float_input.node()
163 |
164 | def patch_float(module):
165 | try:
166 | graphs = [module.graph] if hasattr(module, "graph") else []
167 | except RuntimeError:
168 | graphs = []
169 |
170 | if hasattr(module, "forward1"):
171 | graphs.append(module.forward1.graph)
172 |
173 | for graph in graphs:
174 | for node in graph.findAllNodes("aten::to"):
175 | inputs = list(node.inputs())
176 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
177 | if inputs[i].node()["value"] == 5:
178 | inputs[i].node().copyAttributes(float_node)
179 |
180 | model.apply(patch_float)
181 | patch_float(model.encode_image)
182 | patch_float(model.encode_text)
183 |
184 | model.float()
185 |
186 | return model, _transform(model.input_resolution.item())
187 |
188 |
189 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
190 | """
191 | Returns the tokenized representation of given input string(s)
192 |
193 | Parameters
194 | ----------
195 | texts : Union[str, List[str]]
196 | An input string or a list of input strings to tokenize
197 |
198 | context_length : int
199 | The context length to use; all CLIP models use 77 as the context length
200 |
201 | truncate: bool
202 | Whether to truncate the text in case its encoding is longer than the context length
203 |
204 | Returns
205 | -------
206 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
207 | """
208 | if isinstance(texts, str):
209 | texts = [texts]
210 |
211 | sot_token = _tokenizer.encoder["<|startoftext|>"]
212 | eot_token = _tokenizer.encoder["<|endoftext|>"]
213 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
214 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
215 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
216 | else:
217 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
218 |
219 | for i, tokens in enumerate(all_tokens):
220 | if len(tokens) > context_length:
221 | if truncate:
222 | tokens = tokens[:context_length]
223 | tokens[-1] = eot_token
224 | else:
225 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
226 | result[i, :len(tokens)] = torch.tensor(tokens)
227 |
228 | return result
229 |
--------------------------------------------------------------------------------
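`tokenize` above only needs the bundled BPE vocabulary, so it can be exercised without downloading any weights; a minimal check (run from the repository root) might look like this:

```python
from clip import tokenize

# Two captions padded to CLIP's fixed context length of 77.
tokens = tokenize(["a photo of a dog", "a photo of a person"])
print(tokens.shape)    # torch.Size([2, 77])
print(tokens[0, :8])   # roughly: start-of-text id, BPE ids, end-of-text id, then zero padding
```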
/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 | import math
9 | from torch.utils.checkpoint import checkpoint
10 | from .adapters import LocalAdapter
11 |
12 |
13 |
14 | class LayerNorm(nn.LayerNorm):
15 | """Subclass torch's LayerNorm to handle fp16."""
16 |
17 | def forward(self, x: torch.Tensor):
18 | orig_type = x.dtype
19 | ret = super().forward(x.type(torch.float32))
20 | return ret.type(orig_type)
21 |
22 |
23 | class QuickGELU(nn.Module):
24 | def forward(self, x: torch.Tensor):
25 | return x * torch.sigmoid(1.702 * x)
26 |
27 |
28 |
29 | class ModifiedAttention(nn.Module):
30 | def __init__(self, out_dim, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., settings=''):
31 | super().__init__()
32 | self.num_heads = num_heads
33 | head_dim = dim // num_heads
34 | self.scale = qk_scale or head_dim ** -0.5
35 |
36 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
37 | self.attn_drop = nn.Dropout(attn_drop)
38 | self.out_proj = nn.Linear(out_dim, dim)
39 | self.proj_drop = nn.Dropout(proj_drop)
40 | self.settings = settings
41 |
42 | def forward(self, x, n_adapt=-1):
43 | B, N, C = x.shape
44 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
45 | q, k, v = qkv[0], qkv[1], qkv[2]
46 |
47 | # original self-attention
48 | attn = (q @ k.transpose(-2, -1)) * self.scale
49 | attn = attn.softmax(dim=-1)
50 | attn = self.attn_drop(attn)
51 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
52 | x = self.proj_drop(self.out_proj(x))
53 |
54 | v = v[:, :, :-n_adapt] if n_adapt > 0 else v
55 | attn_loc = (v @ v.transpose(-2, -1)) * self.scale
56 | attn_loc = (attn_loc).softmax(dim=-1)
57 | attn_loc = self.attn_drop(attn_loc)
58 |
59 | x_loc = (attn_loc @ v).transpose(1, 2).reshape(B, -1, C)
60 | x_loc = self.proj_drop(self.out_proj(x_loc))
61 |
62 | return [x_loc, x]
63 |
64 |
65 |
66 | class Bottleneck(nn.Module):
67 | expansion = 4
68 |
69 | def __init__(self, inplanes, planes, stride=1):
70 | super().__init__()
71 |
72 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
73 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
74 | self.bn1 = nn.BatchNorm2d(planes)
75 | self.relu1 = nn.ReLU(inplace=True)
76 |
77 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
78 | self.bn2 = nn.BatchNorm2d(planes)
79 | self.relu2 = nn.ReLU(inplace=True)
80 |
81 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
82 |
83 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
84 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
85 | self.relu3 = nn.ReLU(inplace=True)
86 |
87 | self.downsample = None
88 | self.stride = stride
89 |
90 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
91 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
92 | self.downsample = nn.Sequential(OrderedDict([
93 | ("-1", nn.AvgPool2d(stride)),
94 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
95 | ("1", nn.BatchNorm2d(planes * self.expansion))
96 | ]))
97 |
98 | def forward(self, x: torch.Tensor):
99 | identity = x
100 |
101 | out = self.relu1(self.bn1(self.conv1(x)))
102 | out = self.relu2(self.bn2(self.conv2(out)))
103 | out = self.avgpool(out)
104 | out = self.bn3(self.conv3(out))
105 |
106 | if self.downsample is not None:
107 | identity = self.downsample(x)
108 |
109 | out += identity
110 | out = self.relu3(out)
111 | return out
112 |
113 |
114 |
115 | class AttentionPool2d(nn.Module):
116 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
117 | super().__init__()
118 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
119 | self.k_proj = nn.Linear(embed_dim, embed_dim)
120 | self.q_proj = nn.Linear(embed_dim, embed_dim)
121 | self.v_proj = nn.Linear(embed_dim, embed_dim)
122 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
123 | self.num_heads = num_heads
124 |
125 | def forward(self, x):
126 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
127 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
128 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
129 | x, _ = F.multi_head_attention_forward(
130 | query=x[:1], key=x, value=x,
131 | embed_dim_to_check=x.shape[-1],
132 | num_heads=self.num_heads,
133 | q_proj_weight=self.q_proj.weight,
134 | k_proj_weight=self.k_proj.weight,
135 | v_proj_weight=self.v_proj.weight,
136 | in_proj_weight=None,
137 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
138 | bias_k=None,
139 | bias_v=None,
140 | add_zero_attn=False,
141 | dropout_p=0,
142 | out_proj_weight=self.c_proj.weight,
143 | out_proj_bias=self.c_proj.bias,
144 | use_separate_proj_weight=True,
145 | training=self.training,
146 | need_weights=False
147 | )
148 | return x.squeeze(0)
149 |
150 |
151 |
152 | class ModifiedResNet(nn.Module):
153 | """
154 | A ResNet class that is similar to torchvision's but contains the following changes:
155 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
156 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
157 | - The final pooling layer is a QKV attention instead of an average pool
158 | """
159 |
160 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64, design_details=None):  # design_details accepted for parity with VisionTransformer (unused here)
161 | super().__init__()
162 | self.output_dim = output_dim
163 | self.input_resolution = input_resolution
164 |
165 | # the 3-layer stem
166 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
167 | self.bn1 = nn.BatchNorm2d(width // 2)
168 | self.relu1 = nn.ReLU(inplace=True)
169 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
170 | self.bn2 = nn.BatchNorm2d(width // 2)
171 | self.relu2 = nn.ReLU(inplace=True)
172 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
173 | self.bn3 = nn.BatchNorm2d(width)
174 | self.relu3 = nn.ReLU(inplace=True)
175 | self.avgpool = nn.AvgPool2d(2)
176 |
177 | # residual layers
178 | self._inplanes = width # this is a *mutable* variable used during construction
179 | self.layer1 = self._make_layer(width, layers[0])
180 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
181 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
182 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
183 |
184 | embed_dim = width * 32 # the ResNet feature dimension
185 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
186 |
187 | def _make_layer(self, planes, blocks, stride=1):
188 | layers = [Bottleneck(self._inplanes, planes, stride)]
189 |
190 | self._inplanes = planes * Bottleneck.expansion
191 | for _ in range(1, blocks):
192 | layers.append(Bottleneck(self._inplanes, planes))
193 |
194 | return nn.Sequential(*layers)
195 |
196 | def forward(self, x):
197 | def stem(x):
198 | x = self.relu1(self.bn1(self.conv1(x)))
199 | x = self.relu2(self.bn2(self.conv2(x)))
200 | x = self.relu3(self.bn3(self.conv3(x)))
201 | x = self.avgpool(x)
202 | return x
203 |
204 | x = x.type(self.conv1.weight.dtype)
205 | x = stem(x)
206 | x = self.layer1(x)
207 | x = self.layer2(x)
208 | x = self.layer3(x)
209 | x = self.layer4(x)
210 | x = self.attnpool(x)
211 |
212 | return x
213 |
214 |
215 |
216 | class ResidualAttentionBlock(nn.Module):
217 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, text_layer=False, i=0, design_details=None):
218 | super().__init__()
219 |
220 | self.attn = nn.MultiheadAttention(d_model, n_head)
221 | self.ln_1 = LayerNorm(d_model)
222 | self.mlp = nn.Sequential(OrderedDict([
223 | ("c_fc", nn.Linear(d_model, d_model * 4)),
224 | ("gelu", QuickGELU()),
225 | ("c_proj", nn.Linear(d_model * 4, d_model))
226 | ]))
227 | self.ln_2 = LayerNorm(d_model)
228 | self.attn_mask = attn_mask
229 | self.text_layer = text_layer
230 | adapt_layer = design_details["depth_text"] if self.text_layer else design_details["depth_vision"]
231 | self.adapt = True if i in adapt_layer else False
232 | self.clip_text_len = 77
233 |
234 | if self.text_layer and self.adapt:
235 | self.n_ctx_text = design_details.get("text_ctx", 0)
236 | ctx_vectors = torch.empty(self.n_ctx_text, d_model)
237 | nn.init.normal_(ctx_vectors, std=0.02)
238 | self.deep_embeds = nn.Parameter(ctx_vectors)
239 |
240 | if not self.text_layer and self.adapt:
241 | kernel_size = design_details["kernel_size"]
242 | self.conv_adapter = LocalAdapter(d_model, d_model, kernel_size=kernel_size)
243 | self.n_adapt = design_details.get("vision_adapt", 0)
244 | ctx_vectors = torch.empty(self.n_adapt, d_model)
245 | nn.init.normal_(ctx_vectors, std=0.02)
246 | self.adapt_embeds = nn.Parameter(ctx_vectors)
247 |
248 | def attention(self, x: torch.Tensor, n_adapt=-1):
249 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
250 | if isinstance(self.attn, ModifiedAttention):
251 | x = x.transpose(0, 1)
252 | x_loc, x = self.attn(x, n_adapt)
253 | return [x_loc.transpose(0, 1), x.transpose(0, 1)]
254 | else:
255 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
256 |
257 | def forward(self, x: torch.Tensor, text_fea=None):
258 | if isinstance(self.attn, ModifiedAttention):
259 | if isinstance(x, list):
260 | x_loc, x = x
261 | else:
262 | x_loc, x = x, x
263 | if self.adapt:
264 | prefix = x[:, :, :]
265 | adapt_embeds = self.adapt_embeds.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
266 | x = torch.cat([prefix, adapt_embeds], dim=0)
267 | x_loc_res_attn, x_res = self.attention(self.ln_1(x), n_adapt=self.n_adapt)
268 | x = x + x_res
269 | x = x + self.mlp(self.ln_2(x))
270 | x = x[:x.shape[0]-self.n_adapt]
271 | x_loc_res_conv = self.conv_adapter(x, text_fea=text_fea)
272 | x_loc_res = x_loc_res_attn + x_loc_res_conv
273 | else:
274 | x_loc_res_attn, x_res = self.attention(self.ln_1(x))
275 | x = x + x_res
276 | x = x + self.mlp(self.ln_2(x))
277 | x_loc_res = x_loc_res_attn
278 |
279 | x_loc = x_loc + x_loc_res
280 | return [x_loc, x]
281 |
282 | elif self.text_layer:
283 | if self.adapt:
284 | prefix = x[:1, :, :]
285 | suffix = x[1:, :, :]
286 | deep_embeds = self.deep_embeds.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
287 | x = torch.cat([prefix, deep_embeds, suffix], dim=0)
288 | x_pad = x[self.clip_text_len:]
289 | x = x[:self.clip_text_len]
290 | x = x + self.attention(self.ln_1(x))
291 | x = x + self.mlp(self.ln_2(x))
292 | prefix = x[:1, :, :]
293 | suffix = x[1 + self.n_ctx_text:, :, :]
294 | x = torch.cat([prefix, suffix, x_pad], dim=0)
295 | else:
296 | x = x + self.attention(self.ln_1(x))
297 | x = x + self.mlp(self.ln_2(x))
298 | return x
299 |
300 |
301 | class Transformer(nn.Module):
302 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, prompts_needed=None,
303 | text_layer=False, design_details=None):
304 | super().__init__()
305 | self.width = width
306 | self.layers = layers
307 | self.text_layer = text_layer
308 |
309 | self.resblocks = nn.Sequential(*[
310 | ResidualAttentionBlock(width, heads, attn_mask, text_layer, i, design_details)
311 | for i in range(layers)
312 | ])
313 |
314 | def CLIP_forward(self, x, out_layers):
315 | out_tokens = []
316 | if out_layers is None:
317 | out_layers = [len(self.resblocks)-1]
318 | for idx, block in enumerate(self.resblocks):
319 | x = block(x)
320 | if idx in out_layers:
321 | out_tokens.append(x)
322 | return x, out_tokens
323 |
324 | def SAA_CLIP_forward(self, x, out_layers, text_fea=None):
325 | out_tokens = []
326 | if out_layers is None:
327 | out_layers = [len(self.resblocks)-1]
328 | for idx, block in enumerate(self.resblocks):
329 | if idx == 0:
330 | x_loc, x = block(x, text_fea=text_fea)
331 | else:
332 | x_loc, x = block([x_loc, x], text_fea=text_fea)
333 | if idx in out_layers:
334 | out_tokens.append(x_loc)
335 | return x, out_tokens
336 |
337 | def forward(self, x, out_layers=None, use_SAA=True, use_checkpoint=False, text_fea=None):
338 | if self.text_layer:
339 | if use_checkpoint:
340 | x = checkpoint(self.resblocks[:6], x)
341 | x = checkpoint(self.resblocks[6:], x)
342 | return x
343 | else:
344 | return self.resblocks(x)
345 |
346 | else:
347 | if use_SAA:
348 | return self.SAA_CLIP_forward(x, out_layers, text_fea)
349 | else:
350 | return self.CLIP_forward(x, out_layers)
351 |
352 |
353 |
354 | class VisionTransformer(nn.Module):
355 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int,
356 | output_dim: int, design_details):
357 | super().__init__()
358 | self.input_resolution = input_resolution
359 | self.output_dim = output_dim
360 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
361 |
362 | SAA_layer = design_details["SAA_layer"] # [-1] or [2,3,5,...]
363 | assert SAA_layer[-1] <= layers-1, "SAA depth {} should not exceed Max layers {}".format(SAA_layer, layers)
364 | self.SAA_layer = SAA_layer if SAA_layer[0]!=-1 else None
365 |
366 | scale = width ** -0.5
367 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
368 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
369 | self.ln_pre = LayerNorm(width)
370 |
371 | self.transformer = Transformer(width, layers, heads, design_details=design_details)
372 |
373 | self.ln_post = LayerNorm(width)
374 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
375 |
376 | self.attn = None
377 | self.embed_dim = width
378 | self.num_heads = heads
379 |
380 | @torch.no_grad()
381 | def SAA_replace(self):
382 | if self.SAA_layer is not None:
383 | for i in self.SAA_layer:
384 | self.attn = ModifiedAttention(self.embed_dim, self.embed_dim, self.num_heads, True)
385 | self.attn.qkv.weight.data = self.transformer.resblocks[i].attn.in_proj_weight.clone()
386 | self.attn.qkv.bias.data = self.transformer.resblocks[i].attn.in_proj_bias.clone()
387 | self.attn.out_proj.weight.data = self.transformer.resblocks[i].attn.out_proj.weight.clone()
388 | self.attn.out_proj.bias.data = self.transformer.resblocks[i].attn.out_proj.bias.clone()
389 | self.transformer.resblocks[i].attn = self.attn
390 | print(f"SAA replace Done")
391 |
392 | def forward(self, x: torch.Tensor, out_layers=[11], use_SAA=True, text_fea=None):
393 | x = self.conv1(x)
394 | x = x.reshape(x.shape[0], x.shape[1], -1)
395 | x = x.permute(0, 2, 1)
396 | x = torch.cat(
397 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
398 | x], dim=1)
399 | x = x + self.positional_embedding.to(x.dtype)
400 |
401 | # Normal code as before
402 | x = self.ln_pre(x)
403 |
404 | x = x.permute(1, 0, 2) # NLD -> LND
405 | x, mid_features = self.transformer(x, out_layers, use_SAA=use_SAA, text_fea=text_fea)
406 | x = self.ln_post(x.permute(1, 0, 2))
407 | mid_features_proj = []
408 | for fea in mid_features:
409 | fea = self.ln_post(fea.permute(1, 0, 2)) @ self.proj # LND -> NLD
410 | mid_features_proj.append(fea)
411 |
412 | return x[:, :1, :] @ self.proj, mid_features_proj
413 |
414 |
415 |
416 |
417 | class CLIP(nn.Module):
418 | def __init__(self,
419 | embed_dim: int,
420 | image_resolution: int,
421 | vision_layers: Union[Tuple[int, int, int, int], int],
422 | vision_width: int,
423 | vision_patch_size: int,
424 | context_length: int,
425 | vocab_size: int,
426 | transformer_width: int,
427 | transformer_heads: int,
428 | transformer_layers: int,
429 | design_details
430 | ):
431 | super().__init__()
432 |
433 | self.context_length = context_length
434 |
435 | if isinstance(vision_layers, (tuple, list)):
436 | vision_heads = vision_width * 32 // 64
437 | self.visual = ModifiedResNet(
438 | layers=vision_layers,
439 | output_dim=embed_dim,
440 | heads=vision_heads,
441 | input_resolution=image_resolution,
442 | width=vision_width,
443 | design_details=design_details
444 | )
445 | else:
446 |
447 | vision_heads = vision_width // 64
448 | self.visual = VisionTransformer(
449 | input_resolution=image_resolution,
450 | patch_size=vision_patch_size,
451 | width=vision_width,
452 | layers=vision_layers,
453 | heads=vision_heads,
454 | output_dim=embed_dim,
455 | design_details=design_details
456 | )
457 |
458 | self.transformer = Transformer(
459 | width=transformer_width,
460 | layers=transformer_layers,
461 | heads=transformer_heads,
462 | attn_mask=self.build_attention_mask(),
463 | text_layer=True,
464 | design_details=design_details
465 | )
466 |
467 | self.vocab_size = vocab_size
468 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
469 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
470 | self.ln_final = LayerNorm(transformer_width)
471 |
472 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
473 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
474 |
475 | self.initialize_parameters()
476 |
477 | def initialize_parameters(self):
478 | nn.init.normal_(self.token_embedding.weight, std=0.02)
479 | nn.init.normal_(self.positional_embedding, std=0.01)
480 |
481 | if isinstance(self.visual, ModifiedResNet):
482 | if self.visual.attnpool is not None:
483 | std = self.visual.attnpool.c_proj.in_features ** -0.5
484 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
485 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
486 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
487 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
488 |
489 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
490 | for name, param in resnet_block.named_parameters():
491 | if name.endswith("bn3.weight"):
492 | nn.init.zeros_(param)
493 |
494 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
495 | attn_std = self.transformer.width ** -0.5
496 | fc_std = (2 * self.transformer.width) ** -0.5
497 | for block in self.transformer.resblocks:
498 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
499 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
500 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
501 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
502 |
503 | if self.text_projection is not None:
504 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
505 |
506 | def build_attention_mask(self):
507 | # lazily create causal attention mask, with full attention between the vision tokens
508 | # pytorch uses additive attention mask; fill with -inf
509 | mask = torch.empty(self.context_length, self.context_length)
510 | mask.fill_(float("-inf"))
511 | mask.triu_(1) # zero out the lower diagonal
512 | return mask
513 |
514 | @property
515 | def dtype(self):
516 | return self.visual.conv1.weight.dtype
517 |
518 | def encode_image(self, image, features_layers=None, ffn=False):
519 | return self.visual(image.type(self.dtype), features_layers, ffn)
520 |
521 | def encode_text(self, text, use_checkpoint=False):
522 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
523 |
524 | x = x + self.positional_embedding.type(self.dtype)
525 | x = x.permute(1, 0, 2) # NLD -> LND
526 | x = self.transformer(x, use_checkpoint=use_checkpoint)
527 | x = x.permute(1, 0, 2) # LND -> NLD
528 | x = self.ln_final(x).type(self.dtype)
529 |
530 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
531 |
532 | return x
533 |
534 | def forward(self, image, text):
535 | image_features = self.encode_image(image)[:, 0, :]
536 | text_features = self.encode_text(text)
537 |
538 | # normalized features
539 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
540 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
541 |
542 | # cosine similarity as logits
543 | logit_scale = self.logit_scale.exp()
544 | logits_per_image = logit_scale * image_features @ text_features.t()
545 | logits_per_text = logit_scale * text_features @ image_features.t()
546 |
547 | return logits_per_image, logits_per_text
548 |
549 |
550 |
551 | def convert_weights(model: nn.Module):
552 | """Convert applicable model parameters to fp16"""
553 |
554 | def _convert_weights_to_fp16(l):
555 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
556 | l.weight.data = l.weight.data.half()
557 | if l.bias is not None:
558 | l.bias.data = l.bias.data.half()
559 |
560 | if isinstance(l, nn.MultiheadAttention):
561 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
562 | tensor = getattr(l, attr)
563 | if tensor is not None:
564 | tensor.data = tensor.data.half()
565 |
566 | for name in ["text_projection", "proj"]:
567 | if hasattr(l, name):
568 | attr = getattr(l, name)
569 | if attr is not None:
570 | attr.data = attr.data.half()
571 |
572 | model.apply(_convert_weights_to_fp16)
573 |
574 |
575 | def build_model(state_dict: dict, input_size: list, design_details):
576 | vit = "visual.proj" in state_dict
577 |
578 | if vit:
579 | vision_width = state_dict["visual.conv1.weight"].shape[0]
580 | vision_layers = len(
581 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
582 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
583 |
584 | # NOTE: resize pos_embed, so that input resolution can be customized
585 | if input_size[0] >= 224 and input_size[1] >= 224:
586 | num_x = input_size[1] // vision_patch_size
587 | num_y = input_size[0] // vision_patch_size
588 | num_patches = num_x * num_y
589 | if state_dict["visual.positional_embedding"].shape[0] != num_patches + 1:
590 | state_dict["visual.positional_embedding"] = resize_pos_embed(state_dict["visual.positional_embedding"],
591 | num_patches+1, num_y, num_x)
592 |
593 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
594 | image_resolution = vision_patch_size * grid_size
595 | else:
596 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in
597 | [1, 2, 3, 4]]
598 | vision_layers = tuple(counts)
599 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
600 |
601 |
602 | # assert input_size[0] == input_size[1]
603 | if input_size[0] >= 224 and input_size[1] >= 224:
604 | num_x = input_size[0] // 32
605 | num_fea = num_x ** 2
606 | if state_dict["visual.attnpool.positional_embedding"].shape[0] != num_fea + 1:
607 | state_dict["visual.attnpool.positional_embedding"] = resize_pos_embed(state_dict["visual.attnpool.positional_embedding"],
608 | num_fea+1, num_x, num_x)
609 |
610 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
611 | vision_patch_size = None
612 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
613 | image_resolution = output_width * 32
614 |
615 | embed_dim = state_dict["text_projection"].shape[1]
616 | context_length = state_dict["positional_embedding"].shape[0]
617 | vocab_size = state_dict["token_embedding.weight"].shape[0]
618 | transformer_width = state_dict["ln_final.weight"].shape[0]
619 | transformer_heads = transformer_width // 64
620 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
621 |
622 | model = CLIP(
623 | embed_dim,
624 | image_resolution, vision_layers, vision_width, vision_patch_size,
625 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, design_details
626 | )
627 |
628 | for key in ["input_resolution", "context_length", "vocab_size"]:
629 | if key in state_dict:
630 | del state_dict[key]
631 |
632 | convert_weights(model)
633 | model.load_state_dict(state_dict, strict=False)
634 | return model.eval()
635 |
636 |
637 | def resize_pos_embed(posemb, num_tok, height, width):
638 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
639 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
640 | # posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
641 | num_tok -= 1
642 | posemb_token, posemb_grid = posemb[0, :], posemb[1:, :]
643 |
644 | gs_old = int(math.sqrt(len(posemb_grid)))
645 | print('Resized position embedding from: {} to: {} with height: {} width: {}'.format(posemb.shape[0], num_tok, height, width))
646 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
647 | posemb_grid = F.interpolate(posemb_grid, size=(height, width), mode='bilinear')
648 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, height * width, -1)
649 | posemb_grid = posemb_grid.squeeze(0)
650 | posemb_token = posemb_token.unsqueeze(0)
651 | posemb = torch.cat([posemb_token, posemb_grid], dim=0)
652 | return posemb
653 |
--------------------------------------------------------------------------------
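The `design_details` dict read throughout `clip/model.py` above is assembled outside this excerpt (in `model/model.py`). Below is a hypothetical sketch of that wiring, using the key names consumed above and the corresponding `MODEL.*` defaults from `config/defaults.py` further below; the repo's actual assembly may differ, and calling `load` this way will download the CLIP checkpoint:

```python
from clip import load
from config import cfg

# Hypothetical mapping from config defaults to the design_details keys used in clip/model.py.
design_details = {
    "depth_text":   cfg.MODEL.DEPTH_TEXT,    # text layers that get deep prompt tokens ([-1] = none)
    "depth_vision": cfg.MODEL.DEPTH_VISION,  # vision layers that get a LocalAdapter ([-1] = none)
    "text_ctx":     cfg.MODEL.TEXT_CTX,      # number of deep text prompt tokens per adapted layer
    "vision_adapt": cfg.MODEL.VISION_ADAPT,  # number of appended vision adapter tokens
    "kernel_size":  cfg.MODEL.KERNEL_SIZE,   # LocalAdapter aggregation kernel size
    "SAA_layer":    cfg.MODEL.SAA_LAYER,     # layers whose attention is swapped for ModifiedAttention
}
model, preprocess = load(cfg.MODEL.BACKBONE, input_size=cfg.INPUT.SIZE_TRAIN,
                         design_details=design_details)
```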
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
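A round-trip through `SimpleTokenizer` (it relies only on the `bpe_simple_vocab_16e6.txt.gz` shipped next to the module) can be sketched as follows; the example sentence is arbitrary:

```python
from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("A photo of a traffic light.")
print(ids)                    # list of BPE token ids
print(tokenizer.decode(ids))  # roughly "a photo of a traffic light ." (lower-cased, re-spaced)
```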
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .defaults import _C as cfg
8 |
--------------------------------------------------------------------------------
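The exported `cfg` is a standard yacs `CfgNode`; `train.py` (outside this excerpt) drives it via `--config_file` plus trailing `KEY VALUE` overrides, as in the README's `WANDB True` example. A minimal sketch of the same workflow:

```python
from config import cfg

cfg.merge_from_file("configs/coco.yml")            # dataset-specific overrides
cfg.merge_from_list(["SOLVER.BASE_LR", "0.0001"])  # command-line style KEY VALUE overrides
cfg.freeze()
print(cfg.MODEL.BACKBONE, cfg.SOLVER.BASE_LR)
```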
/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # -----------------------------------------------------------------------------
4 | # Convention about Training / Test specific parameters
5 | # -----------------------------------------------------------------------------
6 | # Whenever an argument can be either used for training or for testing, the
7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter, or _TEST for a test parameter.
8 |
9 | # -----------------------------------------------------------------------------
10 | # Config definition
11 | # -----------------------------------------------------------------------------
12 |
13 | _C = CN()
14 | # -----------------------------------------------------------------------------
15 | # MODEL
16 | # -----------------------------------------------------------------------------
17 | _C.MODEL = CN()
18 | # Using cuda or cpu for training
19 | _C.MODEL.DEVICE = "cuda"
20 | # ID number of GPU
21 | _C.MODEL.DEVICE_ID = '0'
22 | # Name of backbone
23 | _C.MODEL.NAME = 'resnet50'
24 | # Last stride of backbone
25 | _C.MODEL.LAST_STRIDE = 1
26 | # Path to pretrained model of backbone
27 | _C.MODEL.PRETRAIN_PATH = ''
28 |
29 | # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model
30 | # Options: 'imagenet' , 'self' , 'finetune'
31 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet'
32 |
33 | # If train with BNNeck, options: 'bnneck' or 'no'
34 | _C.MODEL.NECK = 'bnneck'
35 | # If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration
36 | _C.MODEL.IF_WITH_CENTER = 'no'
37 |
38 | _C.MODEL.ID_LOSS_TYPE = 'softmax'
39 | _C.MODEL.ID_LOSS_WEIGHT = 1.0
40 | _C.MODEL.TRIPLET_LOSS_WEIGHT = 1.0
41 |
42 | _C.MODEL.METRIC_LOSS_TYPE = 'triplet'
43 | # If train with multi-gpu ddp mode, options: 'True', 'False'
44 | _C.MODEL.DIST_TRAIN = False
45 | # If train with soft triplet loss, options: 'True', 'False'
46 | _C.MODEL.NO_MARGIN = False
47 | # If train with label smooth, options: 'on', 'off'
48 | _C.MODEL.IF_LABELSMOOTH = 'on'
49 | # If train with arcface loss, options: 'True', 'False'
50 | _C.MODEL.COS_LAYER = False
51 |
52 | # Transformer setting
53 | _C.MODEL.DROP_PATH = 0.1
54 | _C.MODEL.DROP_OUT = 0.0
55 | _C.MODEL.ATT_DROP_RATE = 0.0
56 | _C.MODEL.TRANSFORMER_TYPE = 'None'
57 | _C.MODEL.STRIDE_SIZE = [16, 16]
58 |
59 | # JPM Parameter
60 | _C.MODEL.JPM = False
61 | _C.MODEL.SHIFT_NUM = 5
62 | _C.MODEL.SHUFFLE_GROUP = 2
63 | _C.MODEL.DEVIDE_LENGTH = 4
64 | _C.MODEL.RE_ARRANGE = True
65 |
66 | # SIE Parameter
67 | _C.MODEL.SIE_COE = 3.0
68 | _C.MODEL.SIE_CAMERA = False
69 | _C.MODEL.SIE_VIEW = False
70 |
71 | # CLIP Backbone
72 | _C.MODEL.BACKBONE = 'ViT-B/16'
73 | # Use BN
74 | _C.MODEL.BN = False
75 | # Number of head
76 | _C.MODEL.NUM_HEAD = 8
77 | # Loss type
78 | _C.MODEL.LOSS_TYPE = 'BCE' # 'ASL'
79 | # ema model
80 | _C.MODEL.USE_EMA = False
81 | _C.MODEL.EMA_DECAY = 0.9997
82 | # load pretrain
83 | _C.MODEL.LOAD = False
84 | # text encoder
85 | _C.MODEL.TEXT_ENCODER = 'CLIP'
86 | _C.MODEL.DEPTH_TEXT = [-1]
87 | _C.MODEL.TEXT_CTX = 4 # middle layers, generally not used
88 | _C.MODEL.PROMPT_CSC = False
89 |
90 | # transfer type
91 | _C.MODEL.TRANSFER_TYPE = "freeze_all"
92 |
93 | # SAA
94 | _C.MODEL.SAA_LAYER = [-1]
95 |
96 | # loc region pooling
97 | _C.MODEL.LOC_STRIDE_SIZE = 4
98 | _C.MODEL.LOC_KERNEL_SIZE = 4
99 |
100 | # Adapter
101 | _C.MODEL.DEPTH_VISION = [-1]
102 | _C.MODEL.VISION_ADAPT = 8
103 | _C.MODEL.KERNEL_SIZE = 3
104 |
105 | # temperature
106 | _C.MODEL.TEMPERATURE = 0.002
107 |
108 | # new prototype
109 | _C.MODEL.NUM_NEW_PROTOTYPE = 10
110 |
111 | # OT reg
112 | _C.MODEL.OT_REG = 0.1
113 | _C.MODEL.OT_REGSC = 0.05
114 |
115 | # Adapter ratio
116 | _C.MODEL.ADAPTER_RATIO = 0.2
117 |
118 |
119 | # -----------------------------------------------------------------------------
120 | # INPUT
121 | # -----------------------------------------------------------------------------
122 | _C.INPUT = CN()
123 | # Size of the image during training
124 | _C.INPUT.SIZE_TRAIN = [224, 224]
125 | # Size of the image during test
126 | _C.INPUT.SIZE_TEST = [224, 224]
127 | # Random probability for image horizontal flip
128 | _C.INPUT.PROB = 0.5
129 | # Random probability for random erasing
130 | _C.INPUT.RE_PROB = 0.5
131 | # Values to be used for image normalization
132 | # _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
133 | # # Values to be used for image normalization
134 | # _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
135 | _C.INPUT.PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
136 | # Values to be used for image normalization
137 | _C.INPUT.PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
138 | # Value of padding size
139 | _C.INPUT.PADDING = 10
140 |
141 | # Image augmentation
142 | _C.INPUT.HIDESEEK = False
143 | _C.INPUT.AUGMIX = False
144 |
145 | # Vis: with path dataset
146 | _C.INPUT.WITH_PATH = False
147 |
148 | # Text template
149 | _C.INPUT.TEMPLATE = 'vanilla' # 'ensemble'
150 | _C.INPUT.ENSEMBLE_TYPE = 'embedding' # 'logit' 'score'
151 | _C.INPUT.NUM_GROUPS = 1 # 'logit'
152 |
153 | # -----------------------------------------------------------------------------
154 | # Dataset
155 | # -----------------------------------------------------------------------------
156 | _C.DATASETS = CN()
157 | # List of the dataset names for training
158 | _C.DATASETS.NAMES = ('PETA')
159 | # Root directory where datasets are stored (and downloaded if not found)
160 | _C.DATASETS.ROOT_DIR = ('../data')
161 | # Number of pedestrian attributes / classes
162 | _C.DATASETS.NUMBERS = 35
163 | _C.DATASETS.TEMPLATE_NAMES = ''
164 |
165 | # label proportion
166 | _C.DATASETS.PARTIAL = -1.0
167 |
168 | # -----------------------------------------------------------------------------
169 | # DataLoader
170 | # -----------------------------------------------------------------------------
171 | _C.DATALOADER = CN()
172 | # Number of data loading threads
173 | _C.DATALOADER.NUM_WORKERS = 8
174 | # Sampler for data loading
175 | _C.DATALOADER.SAMPLER = 'softmax'
176 | # Number of instance for one batch
177 | _C.DATALOADER.NUM_INSTANCE = 16
178 |
179 | # ---------------------------------------------------------------------------- #
180 | # Solver
181 | # ---------------------------------------------------------------------------- #
182 | _C.SOLVER = CN()
183 | # Save the model
184 | _C.SOLVER.SAVE_MODEL = False
185 | # Name of optimizer
186 | _C.SOLVER.OPTIMIZER_NAME = "Adam"
188 | # Maximum number of epochs
189 | _C.SOLVER.MAX_EPOCHS = 100
190 | # Maximum number of epochs for the scheduler
191 | _C.SOLVER.SCHEDULER_MAX_EPOCHS = 60
192 | # Maximum number of iterations for the scheduler
193 | _C.SOLVER.SCHEDULER_MAX_ITER = 1000000
193 | # Base learning rate
194 | _C.SOLVER.BASE_LR = 3e-4
196 | # Learning rate factor for bias parameters
197 | _C.SOLVER.BIAS_LR_FACTOR = 1
198 | # Random seed
199 | _C.SOLVER.SEED = 42
199 | # Momentum
200 | _C.SOLVER.MOMENTUM = 0.9
201 | # Margin of triplet loss
202 | _C.SOLVER.MARGIN = 0.3
203 | # Learning rate of SGD to learn the centers of center loss
204 | _C.SOLVER.CENTER_LR = 0.5
205 | # Balanced weight of center loss
206 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005
207 | _C.SOLVER.LARGE_FC_LR = False
208 |
209 | # Settings of weight decay
210 | _C.SOLVER.WEIGHT_DECAY = 0.0001
211 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.0001
212 | _C.SOLVER.WEIGHT_DECAY_SGD = 0.0001
213 |
214 | # decay rate of learning rate
215 | _C.SOLVER.GAMMA = 0.1
216 | # decay step of learning rate
217 | _C.SOLVER.STEPS = (40, 70)
218 | # warm up factor
219 | _C.SOLVER.WARMUP_FACTOR = 0.01
220 | # warm up epochs
221 | _C.SOLVER.WARMUP_EPOCHS = 3
222 | # method of warm up, option: 'constant','linear'
223 | _C.SOLVER.WARMUP_METHOD = "linear"
224 |
225 | _C.SOLVER.COSINE_MARGIN = 0.5
226 | _C.SOLVER.COSINE_SCALE = 30
227 |
228 | # epoch number of saving checkpoints
229 | _C.SOLVER.CHECKPOINT_PERIOD = 10
230 | # iteration of display training log
231 | _C.SOLVER.LOG_PERIOD = 100
232 | # epoch number of validation
233 | _C.SOLVER.EVAL_PERIOD = 10
234 | # Number of images per batch
235 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU
236 | # will contain 16 images per batch
237 | _C.SOLVER.IMS_PER_BATCH = 64
238 |
239 | # Classification threshold
240 | # (scores above this value count as positive predictions)
241 | _C.SOLVER.THRESH = 0.5
242 |
243 | # Label smoothing
244 | _C.SOLVER.LABEL_SMOOTHING = False
245 |
246 | # LR scheduler iter (TGPT implementation)
247 | _C.SOLVER.GAMMA = 0.1
248 | _C.SOLVER.LR_SCHEDULER = "cosine"
249 | _C.SOLVER.STEPSIZE = 1000
250 |
251 | # aslloss param
252 | _C.SOLVER.GAMMA_NEG = 2
253 | _C.SOLVER.GAMMA_POS = 0
254 | _C.SOLVER.CLIP = 0.
255 |
256 | # twloss param
257 | _C.SOLVER.TP = 4.
258 | _C.SOLVER.TN = 1.
259 |
260 | # save the middle output, for visualization
261 | _C.SOLVER.VERBOSE = False
262 |
263 | # iter training
264 | _C.SOLVER.MAX_ITER = 12800
265 | _C.SOLVER.WARMUP_ITER = 200
266 | _C.SOLVER.BASE_LR_SGD = 0.001
267 |
268 | # KD loss weight
269 | _C.SOLVER.KDLOSS_WEIGHT = 1.
270 |
271 | # Text batch
272 | _C.SOLVER.TEXT_BATCH_SIZE = 80
273 |
274 | # debug mode
275 | _C.SOLVER.DEBUG = False
276 |
277 | # zero-shot testing
278 | _C.SOLVER.ZS_TEST = False
279 |
280 | # sample text
281 | _C.SOLVER.SAMPLE_TEXT = False
282 |
283 |
284 | # ---------------------------------------------------------------------------- #
285 | # TEST
286 | # ---------------------------------------------------------------------------- #
287 |
288 | _C.TEST = CN()
289 | # Number of images per batch during test
290 | _C.TEST.IMS_PER_BATCH = 128
291 | # If test with re-ranking, options: 'True','False'
292 | _C.TEST.RE_RANKING = False
293 | # Path to trained model
294 | _C.TEST.WEIGHT = ""
295 | _C.TEST.WEIGHT_ITERS = 12800
296 | # Which feature of BNNeck to use for testing, before or after BNNeck, options: 'before' or 'after'
297 | _C.TEST.NECK_FEAT = 'after'
298 | # Whether the feature is normalized before testing; if yes, it is equivalent to cosine distance
299 | _C.TEST.FEAT_NORM = 'yes'
300 |
301 | # Name for saving the distmat after testing.
302 | _C.TEST.DIST_MAT = "dist_mat.npy"
303 | # Whether to calculate the eval score, options: 'True', 'False'
304 | _C.TEST.EVAL = False
305 | _C.TEST.USE_FUSION = False
306 | _C.TEST.TTA = -1
307 | _C.TEST.TEN_CROP = False
308 | # ---------------------------------------------------------------------------- #
309 | # Misc options
310 | # ---------------------------------------------------------------------------- #
311 | # Path to checkpoint and saved log of trained model
312 | _C.OUTPUT_DIR = ""
313 |
314 |
315 | # ---------------------------------------------------------------------------- #
316 | # Train time configs
317 | # ---------------------------------------------------------------------------- #
318 | _C.LOCAL_RANK = 0
319 | _C.WANDB = False
320 | _C.WANDB_PROJ = "OVML-RAM"
321 |
322 |
323 |
324 |
325 |
326 |
327 |
--------------------------------------------------------------------------------
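A minimal sketch of how these yacs defaults are typically combined with one of the YAML files below. The `from config.defaults import _C` path is assumed from the file above, and the CLI-style override is only an illustration.

# Sketch only: merge the defaults above with a YAML override (standard yacs API).
from config.defaults import _C

cfg = _C.clone()                                   # keep the module-level defaults untouched
cfg.merge_from_file("configs/coco.yml")            # YAML values override the defaults
cfg.merge_from_list(["SOLVER.IMS_PER_BATCH", 16])  # optional key/value overrides
cfg.freeze()

print(cfg.MODEL.LOSS_TYPE)   # 'MMC' from coco.yml (the default above is 'BCE')
print(cfg.SOLVER.BASE_LR)    # 5e-05 from coco.yml
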
/configs/coco.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | IF_LABELSMOOTH: 'off'
3 | IF_WITH_CENTER: 'no'
4 | NO_MARGIN: True
5 | STRIDE_SIZE: [16, 16]
6 | BACKBONE: 'ViT-B/16'
7 | LOSS_TYPE: 'MMC'
8 | DEPTH_VISION: [9,10,11]
9 | DEPTH_TEXT: [9,10,11]
10 | PROMPT_CSC: False
11 | TRANSFER_TYPE: "Adapter"
12 | SAA_LAYER: [12, -1]
13 | USE_EMA: True
14 | KERNEL_SIZE: 3
15 | TEMPERATURE: 2e-3
16 | OT_REGSC: 0.05
17 |
18 | INPUT:
19 | SIZE_TRAIN: [224, 224]
20 | SIZE_TEST: [224, 224]
21 | PROB: 0.5
22 | RE_PROB: 0.5
23 | PADDING: 10
24 | PIXEL_MEAN: [0.5, 0.5, 0.5]
25 | PIXEL_STD: [0.5, 0.5, 0.5]
26 | ENSEMBLE_TYPE: "embedding"
27 | NUM_GROUPS: 1
28 |
29 | DATASETS:
30 | NAMES: ('COCO_ZS')
31 | ROOT_DIR: ('/mnt/nas/TrueNas1/COCO')
32 | NUMBERS: 80
33 |
34 | DATALOADER:
35 | SAMPLER: 'sigmoid'
36 | NUM_INSTANCE: 4
37 | NUM_WORKERS: 8
38 |
39 | SOLVER:
40 | OPTIMIZER_NAME: 'AdamW'
41 | BASE_LR: 5e-5
42 | IMS_PER_BATCH: 32
43 | LARGE_FC_LR: False
44 | LOG_PERIOD: 500
45 | WEIGHT_DECAY: 1e-4
46 | MAX_EPOCHS: 6
47 | EVAL_PERIOD: 3
48 | SAVE_MODEL: False
49 |
50 | TEST:
51 | EVAL: True
52 | IMS_PER_BATCH: 512
53 | RE_RANKING: False
54 | WEIGHT: ''
55 | WEIGHT_ITERS: 12800
56 | NECK_FEAT: 'before'
57 | FEAT_NORM: 'yes'
58 |
59 | OUTPUT_DIR: './logs/RAM/Natural/COCO'
60 |
--------------------------------------------------------------------------------
/configs/nus.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | IF_LABELSMOOTH: 'off'
3 | IF_WITH_CENTER: 'no'
4 | NO_MARGIN: True
5 | STRIDE_SIZE: [16, 16]
6 | BACKBONE: 'ViT-B/16'
7 | LOSS_TYPE: 'MMC'
8 | DEPTH_VISION: [9,10,11]
9 | DEPTH_TEXT: [6,7,8,9,10,11]
10 | PROMPT_CSC: False
11 | TRANSFER_TYPE: "Adapter"
12 | SAA_LAYER: [12, -1]
13 | USE_EMA: True
14 | KERNEL_SIZE: 3
15 | TEMPERATURE: 2e-4
16 | OT_REGSC: 0.05
17 |
18 | INPUT:
19 | SIZE_TRAIN: [224, 224]
20 | SIZE_TEST: [224, 224]
21 | PROB: 0.5
22 | RE_PROB: 0.5
23 | PADDING: 10
24 | PIXEL_MEAN: [0.5, 0.5, 0.5]
25 | PIXEL_STD: [0.5, 0.5, 0.5]
26 | ENSEMBLE_TYPE: "embedding"
27 | NUM_GROUPS: 1
28 |
29 | DATASETS:
30 | NAMES: ('NUS_ZS')
31 | ROOT_DIR: ('/mnt/nas/TrueNas1/NUS_WIDE')
32 | NUMBERS: 81
33 |
34 | DATALOADER:
35 | SAMPLER: 'sigmoid'
36 | NUM_INSTANCE: 4
37 | NUM_WORKERS: 8
38 |
39 | SOLVER:
40 | OPTIMIZER_NAME: 'AdamW'
41 | BASE_LR: 5e-5
42 | IMS_PER_BATCH: 32
43 | LARGE_FC_LR: False
44 | LOG_PERIOD: 1000
45 | WEIGHT_DECAY: 1e-4
46 | MAX_EPOCHS: 6
47 | EVAL_PERIOD: 3
48 | SAVE_MODEL: False
49 |
50 | TEST:
51 | EVAL: True
52 | IMS_PER_BATCH: 512
53 | RE_RANKING: False
54 | WEIGHT: ''
55 | WEIGHT_ITERS: 12800
56 | NECK_FEAT: 'before'
57 | FEAT_NORM: 'yes'
58 |
59 | OUTPUT_DIR: './logs/RAM/Natural/NUS'
60 |
--------------------------------------------------------------------------------
/datasets/MultiLabel/classification.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import random
4 | import json
5 |
6 | from datasets.bases import BaseImageDataset
7 |
8 |
9 |
10 | class ZSMultiLabelClassification(BaseImageDataset):
11 | def __init__(self, root='', verbose=True, **kwargs):
12 | super(ZSMultiLabelClassification, self).__init__()
13 | self.dataset_dir = root
14 | if "coco" in root.lower():
15 | dataset_name = "COCO"
16 | self.train_file = os.path.join(self.dataset_dir, "annotations", 'train_48_filtered.json')
17 | self.test_file = os.path.join(self.dataset_dir, "annotations", 'test_17_filtered.json')
18 | self.test_file_gzsl = os.path.join(self.dataset_dir, "annotations", 'test_65_filtered.json')
19 | elif "nus" in root.lower():
20 | dataset_name = "NUS"
21 | self.train_file = os.path.join(self.dataset_dir, "annotations", 'train_925_filtered.json')
22 | self.test_file = os.path.join(self.dataset_dir, "annotations", 'test_81_filtered.json')
23 | self.test_file_gzsl = os.path.join(self.dataset_dir, "annotations", 'test_1006_filtered.json')
24 | else:
25 | raise NotImplementedError
26 | self._check_before_run()
27 |
28 | train, class2idx, name_train = self._load_dataset(self.dataset_dir, self.train_file, shuffle=True)
29 | test, _, name_test = self._load_dataset(self.dataset_dir, self.test_file, shuffle=False)
30 | test_gzsl, _, _ = self._load_dataset(self.dataset_dir, self.test_file_gzsl, shuffle=False, names=name_train+name_test)
31 | self.train = train
32 | self.test = test
33 | self.test_gzsl = test_gzsl
34 | self.class2idx = class2idx
35 | if verbose:
36 | print(f"=> {dataset_name} ZSL Dataset:")
37 | self.print_dataset_statistics(train, test)
38 | print(f"=> {dataset_name} GZSL Dataset:")
39 | self.print_dataset_statistics(train, test_gzsl)
40 | self.classnames_seen = name_train
41 | self.classnames_unseen = name_test
42 | self.classnames = name_train+name_test
43 | self.num_cls_train = len(name_train)
44 | self.num_cls_test = len(name_test)
45 |
46 | def _check_before_run(self):
47 | """Check if all files are available before going deeper"""
48 | if not os.path.exists(self.dataset_dir):
49 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
50 | if not os.path.exists(self.train_file):
51 | raise RuntimeError("'{}' is not available".format(self.train_file))
52 | if not os.path.exists(self.test_file):
53 | raise RuntimeError("'{}' is not available".format(self.test_file))
54 |
55 | def _load_dataset(self, data_dir, annot_path, shuffle=True, names=None):
56 | out_data = []
57 | with open(annot_path) as f:
58 | annotation = json.load(f)
59 | classes = sorted(annotation['classes']) if names is None else names
60 | class_to_idx = {classes[i]: i for i in range(len(classes))}
61 | images_info = annotation['images']
62 | img_wo_objects = 0
63 | for img_info in images_info:
64 | labels_idx = list()
65 | rel_image_path, img_labels = img_info
66 | full_image_path = os.path.join(data_dir, rel_image_path)
67 | labels_idx = [class_to_idx[lbl] for lbl in img_labels if lbl in class_to_idx]
68 | labels_idx = list(set(labels_idx))
69 | # transform to one-hot
70 | onehot = np.zeros(len(classes), dtype=int)
71 | onehot[labels_idx] = 1
72 | assert full_image_path
73 | if not labels_idx:
74 | img_wo_objects += 1
75 | out_data.append((full_image_path, onehot))
76 | if img_wo_objects:
77 |             print(f'WARNING: {img_wo_objects} images have no labels and will be treated as negatives')
78 | if shuffle:
79 | random.shuffle(out_data)
80 | return out_data, class_to_idx, classes
81 |
82 |
83 |
84 |
--------------------------------------------------------------------------------
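The JSON files read by `_load_dataset` pair a `classes` list with `images` entries of the form `[relative_path, label_names]`. A small illustrative snippet (made-up paths and class names) of that layout and the one-hot conversion the loader performs:

# Illustrative only: the annotation layout expected by _load_dataset
# and the one-hot encoding it produces (paths and classes are made up).
import numpy as np

annotation = {
    "classes": ["car", "dog", "person"],
    "images": [
        ["images/000001.jpg", ["person", "dog"]],
        ["images/000002.jpg", ["car"]],
    ],
}

classes = sorted(annotation["classes"])
class_to_idx = {c: i for i, c in enumerate(classes)}
for rel_path, names in annotation["images"]:
    onehot = np.zeros(len(classes), dtype=int)
    onehot[[class_to_idx[n] for n in names if n in class_to_idx]] = 1
    print(rel_path, onehot)   # e.g. images/000001.jpg [0 1 1]
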
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from datasets.make_dataloader import make_dataloader
--------------------------------------------------------------------------------
/datasets/bases.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 |
5 | from PIL import Image, ImageFile
6 | from imgaug import augmenters as iaa
7 | from torch.utils.data import Dataset
8 | ImageFile.LOAD_TRUNCATED_IMAGES = True
9 |
10 |
11 |
12 | class BaseDataset(object):
13 | """
14 | Base class of reid dataset
15 | """
16 | def get_imagedata_info(self, data):
17 | imgs = []
18 | labels = []
19 | for data_img, data_label in data:
20 | imgs += [data_img]
21 | labels += [data_label]
22 | num_imgs = len(imgs)
23 | num_labels = len(labels)
24 | return num_imgs, num_labels
25 |
26 | def print_dataset_statistics(self):
27 | raise NotImplementedError
28 |
29 |
30 |
31 | class BaseImageDataset(BaseDataset):
32 | """
33 | Base class of image reid dataset
34 | """
35 | def print_dataset_statistics(self, train, test):
36 | num_train_imgs, num_train_labels = self.get_imagedata_info(train)
37 | num_test_imgs, num_test_labels = self.get_imagedata_info(test)
38 |
39 | print("Dataset statistics:")
40 | print(" ----------------------------------------")
41 | print(" subset | # images | # labels")
42 | print(" ----------------------------------------")
43 | print(" train | {:8d} | {:9d}".format(num_train_imgs, num_train_labels))
44 | print(" test | {:8d} | {:9d}".format(num_test_imgs, num_test_labels))
45 | print(" ----------------------------------------")
46 |
47 |
48 |
49 | class ImageDataset(Dataset):
50 | def __init__(self, dataset, transform=None, shuffle = True, mirror=False, Aug = False):
51 | self.dataset = dataset
52 | self.transform = transform
53 | self.mirror = mirror
54 | self.Aug = Aug
55 | self.AugSeq = iaa.Sequential([
56 | iaa.Sometimes(0.2,
57 | iaa.Crop(percent=(0, 0.1)),
58 | ),
59 | # Small gaussian blur with random sigma between 0 and 0.5.
60 | # But we only blur about 50% of all images.
61 | iaa.Sometimes(0.2,
62 | iaa.GaussianBlur(sigma=(0, 0.2))
63 | ),
64 |
65 | # Strengthen or weaken the contrast in each image.
66 | iaa.Sometimes(0.2, iaa.ContrastNormalization((0.75, 1.25))),
67 | iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.015*255), per_channel=0.2),
68 | iaa.Multiply((0.9, 1.1), per_channel=0.2),
69 | # Apply affine transformations to each image.
70 | # Scale/zoom them, translate/move them, rotate them and shear them.
71 | iaa.Sometimes(0.2,
72 | iaa.Affine(
73 | scale={"x": (0.9, 1.1), "y": (0.9, 1.1)},
74 | translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
75 | rotate=(-10, 10)
76 | )
77 | ),
78 | ], random_order=True) # apply augmenters in random order
79 |
80 | def __getitem__(self, item):
81 | img_path, atts_list = self.dataset[item]
82 | img = Image.open(img_path).convert("RGB")
83 |
84 | # flip
85 | flip = 0
86 | if self.mirror:
87 | flip = np.random.choice(2)
88 | if flip:
89 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
90 |
91 | if self.Aug:
92 | cv_img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
93 | cv_img = cv_img.reshape(1, cv_img.shape[0], cv_img.shape[1], cv_img.shape[2])
94 | cv_img = self.AugSeq.augment_images(cv_img)
95 | cv_img = cv_img.reshape(cv_img.shape[1], cv_img.shape[2], cv_img.shape[3])
96 | img = Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
97 |
98 | if self.transform is not None:
99 | img = self.transform(img)
100 | return img, atts_list
101 |
102 | def __len__(self):
103 | return len(self.dataset)
104 |
105 |
--------------------------------------------------------------------------------
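A small usage sketch for `ImageDataset`: items are `(image_path, one-hot label)` pairs, with optional mirroring and imgaug augmentation applied before the torchvision transform. The throwaway image and the transform below are illustrative, not the project's training pipeline.

# Illustrative only: push one (path, one-hot) item through ImageDataset.
import os, tempfile
import numpy as np
from PIL import Image
import torchvision.transforms as T
from datasets.bases import ImageDataset

tmp_path = os.path.join(tempfile.mkdtemp(), "img.jpg")
Image.new("RGB", (320, 240), color=(120, 80, 30)).save(tmp_path)

items = [(tmp_path, np.array([0, 1, 0], dtype=int))]         # (path, one-hot label)
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
dataset = ImageDataset(items, transform=transform, mirror=True, Aug=False)
img, label = dataset[0]    # PIL load -> optional horizontal flip -> transform
print(img.shape, label)    # torch.Size([3, 224, 224]) [0 1 0]
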
/datasets/make_dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | import torchvision.transforms as T
5 | from torch.utils.data import DataLoader
6 | from timm.data.random_erasing import RandomErasing
7 |
8 | from .bases import ImageDataset
9 | from .MultiLabel.classification import ZSMultiLabelClassification
10 |
11 |
12 |
13 | def train_collate_fn(batch):
14 | imgs, label = zip(*batch)
15 | label = torch.tensor(label, dtype=torch.int64)
16 | return torch.stack(imgs, dim=0), label
17 |
18 |
19 | def val_collate_fn(batch):
20 | imgs, label = zip(*batch)
21 | label = torch.tensor(label, dtype=torch.int64)
22 | return torch.stack(imgs, dim=0), label
23 |
24 |
25 | def make_dataloader(cfg):
26 | num_workers = cfg.DATALOADER.NUM_WORKERS
27 | dataset = ZSMultiLabelClassification(root=cfg.DATASETS.ROOT_DIR)
28 |
29 | train_transforms = T.Compose([
30 | T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3),
31 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
32 | T.Pad(cfg.INPUT.PADDING),
33 | T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
34 | T.ToTensor(),
35 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
36 | RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'),
37 | ])
38 | train_set = ImageDataset(dataset.train, transform=train_transforms, mirror=True, Aug=True)
39 |
40 | val_transforms = T.Compose([
41 | T.Resize(cfg.INPUT.SIZE_TEST),
42 | T.ToTensor(),
43 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
44 | ])
45 | val_set = ImageDataset(dataset.test, transform=val_transforms, mirror=False, Aug=False)
46 | val_set_gzsl = ImageDataset(dataset.test_gzsl, transform=val_transforms, mirror=False, Aug=False)
47 |
48 | if cfg.MODEL.DIST_TRAIN:
49 | print('DIST_TRAIN START')
50 | mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // dist.get_world_size()
51 | print('===========================\n mini batch size:', mini_batch_size)
52 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
53 | nw = min([os.cpu_count(), mini_batch_size if mini_batch_size > 1 else 0, 8])
54 | train_loader = torch.utils.data.DataLoader(
55 | train_set,
56 | batch_size=mini_batch_size,
57 | pin_memory=True,
58 | num_workers=nw,
59 | shuffle=False,
60 | sampler=train_sampler,
61 | collate_fn=train_collate_fn,
62 | drop_last=True
63 | )
64 | val_loader = DataLoader(
65 | val_set,
66 | batch_size=cfg.TEST.IMS_PER_BATCH,
67 | shuffle=False,
68 | num_workers=num_workers,
69 | collate_fn=val_collate_fn
70 | )
71 | val_loader_gzsl = DataLoader(
72 | val_set_gzsl,
73 | batch_size=cfg.TEST.IMS_PER_BATCH,
74 | shuffle=False,
75 | num_workers=num_workers,
76 | collate_fn=val_collate_fn
77 | )
78 | return train_loader, val_loader, val_loader_gzsl, train_sampler, dataset
79 | else:
80 | train_loader = DataLoader(
81 | train_set,
82 | batch_size=cfg.SOLVER.IMS_PER_BATCH,
83 | shuffle=False,
84 | num_workers=num_workers,
85 | collate_fn=train_collate_fn,
86 | drop_last=True,
87 | persistent_workers=True
88 | )
89 | val_loader = DataLoader(
90 | val_set,
91 | batch_size=cfg.TEST.IMS_PER_BATCH,
92 | shuffle=False,
93 | num_workers=num_workers,
94 | collate_fn=val_collate_fn,
95 | persistent_workers=True,
96 | )
97 | val_loader_gzsl = DataLoader(
98 | val_set_gzsl,
99 | batch_size=cfg.TEST.IMS_PER_BATCH,
100 | shuffle=False,
101 | num_workers=num_workers,
102 | collate_fn=val_collate_fn,
103 | persistent_workers=True,
104 | )
105 | return train_loader, val_loader, val_loader_gzsl, None, dataset
106 |
107 |
--------------------------------------------------------------------------------
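A minimal sketch of building the loaders from the COCO zero-shot config; the dataset root is a placeholder for a local copy containing the filtered annotation JSONs, and the shapes in the comment assume the 48-class seen split and IMS_PER_BATCH=32 from coco.yml.

# Sketch only: build train/ZSL/GZSL loaders from the COCO config.
from config.defaults import _C
from datasets import make_dataloader

cfg = _C.clone()
cfg.merge_from_file("configs/coco.yml")
cfg.DATASETS.ROOT_DIR = "/path/to/COCO"   # placeholder: point at a local dataset root

train_loader, val_loader, val_loader_gzsl, train_sampler, dataset = make_dataloader(cfg)
imgs, labels = next(iter(train_loader))
print(imgs.shape, labels.shape)   # e.g. [32, 3, 224, 224] and [32, 48]
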
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .asymmetric_loss import AsymmetricLossOptimized, AsymmetricLossOptimized_partial
2 | from .mmc_loss import mmc_loss
3 |
--------------------------------------------------------------------------------
/loss/asymmetric_loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Borrowed from https://github.com/SlongLiu/query2labels, thanks
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 |
9 | class AsymmetricLoss(nn.Module):
10 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
11 | super(AsymmetricLoss, self).__init__()
12 |
13 | self.gamma_neg = gamma_neg
14 | self.gamma_pos = gamma_pos
15 | self.clip = clip
16 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
17 | self.eps = eps
18 |
19 | def forward(self, x, y):
20 |         """
21 | Parameters
22 | ----------
23 | x: input logits
24 | y: targets (multi-label binarized vector)
25 | """
26 |
27 | # Calculating Probabilities
28 | x_sigmoid = torch.sigmoid(x)
29 | xs_pos = x_sigmoid
30 | xs_neg = 1 - x_sigmoid
31 |
32 | # Asymmetric Clipping
33 | if self.clip is not None and self.clip > 0:
34 | xs_neg = (xs_neg + self.clip).clamp(max=1)
35 |
36 | # Basic CE calculation
37 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps, max=1-self.eps))
38 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps, max=1-self.eps))
39 | loss = los_pos + los_neg
40 |
41 | # Asymmetric Focusing
42 | if self.gamma_neg > 0 or self.gamma_pos > 0:
43 | if self.disable_torch_grad_focal_loss:
44 | torch._C.set_grad_enabled(False)
45 | pt0 = xs_pos * y
46 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
47 | pt = pt0 + pt1
48 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
49 | one_sided_w = torch.pow(1 - pt, one_sided_gamma)
50 | if self.disable_torch_grad_focal_loss:
51 | torch._C.set_grad_enabled(True)
52 | loss *= one_sided_w
53 |
54 | return -loss.sum()
55 |
56 |
57 |
58 | class AsymmetricLossOptimized(nn.Module):
59 | ''' Notice - optimized version, minimizes memory allocation and gpu uploading,
60 | favors inplace operations'''
61 | # Default values are those used in query2label training
62 | def __init__(self, gamma_neg=2, gamma_pos=0, clip=0., eps=1e-5, disable_torch_grad_focal_loss=True):
63 | super(AsymmetricLossOptimized, self).__init__()
64 |
65 | self.gamma_neg = gamma_neg
66 | self.gamma_pos = gamma_pos
67 | self.clip = clip
68 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
69 | self.eps = eps
70 |
71 | self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None
72 |
73 | def forward(self, x, y):
74 |         """
75 | Parameters
76 | ----------
77 | x: input logits
78 | y: targets (multi-label binarized vector)
79 | """
80 |
81 | self.targets = y
82 | self.anti_targets = 1 - y
83 |
84 | # Calculating Probabilities
85 | self.xs_pos = torch.sigmoid(x)
86 | self.xs_neg = 1.0 - self.xs_pos
87 |
88 | # Asymmetric Clipping
89 | if self.clip is not None and self.clip > 0:
90 | self.xs_neg.add_(self.clip).clamp_(max=1)
91 |
92 | # Basic CE calculation
93 | self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
94 | self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))
95 |
96 | # Asymmetric Focusing
97 | if self.gamma_neg > 0 or self.gamma_pos > 0:
98 | if self.disable_torch_grad_focal_loss:
99 | with torch.no_grad():
100 | # if self.disable_torch_grad_focal_loss:
101 | # torch._C.set_grad_enabled(False)
102 | self.xs_pos = self.xs_pos * self.targets
103 | self.xs_neg = self.xs_neg * self.anti_targets
104 | self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
105 | self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
106 | # if self.disable_torch_grad_focal_loss:
107 | # torch._C.set_grad_enabled(True)
108 | self.loss *= self.asymmetric_w
109 | else:
110 | self.xs_pos = self.xs_pos * self.targets
111 | self.xs_neg = self.xs_neg * self.anti_targets
112 | self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
113 | self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
114 | self.loss *= self.asymmetric_w
115 | _loss = - self.loss.sum() / x.size(0)
116 | _loss = _loss / y.size(1) * 1000
117 |
118 | return _loss
119 |
120 |
121 |
122 |
123 | class AsymmetricLossOptimized_partial(nn.Module):
124 | """
125 | ASL loss used for partial label setting
126 | """
127 | def __init__(self, gamma_neg=2, gamma_pos=0, clip=0., eps=1e-5, disable_torch_grad_focal_loss=True):
128 |         super(AsymmetricLossOptimized_partial, self).__init__()
129 |
130 | self.gamma_neg = gamma_neg
131 | self.gamma_pos = gamma_pos
132 | self.clip = clip
133 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
134 | self.eps = eps
135 |
136 | self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None
137 |
138 | def forward(self, x, y):
139 |         """
140 | Parameters
141 | ----------
142 | x: input logits
143 | y: targets (multi-label binarized vector)
144 | """
145 | y = y.reshape(-1)
146 | x = x.reshape(-1)
147 | x = x[y!=-1]
148 | y = y[y!=-1]
149 |
150 | self.targets = y
151 | self.anti_targets = 1 - y
152 |
153 | # Calculating Probabilities
154 | # self.xs_pos = torch.sigmoid(x)
155 | self.xs_pos = x
156 | self.xs_neg = 1.0 - self.xs_pos
157 |
158 | # Asymmetric Clipping
159 | if self.clip is not None and self.clip > 0:
160 | self.xs_neg.add_(self.clip).clamp_(max=1)
161 |
162 | # Basic CE calculation
163 | self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
164 | self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))
165 |
166 | # Asymmetric Focusing
167 | if self.gamma_neg > 0 or self.gamma_pos > 0:
168 | if self.disable_torch_grad_focal_loss:
169 | with torch.no_grad():
170 | # if self.disable_torch_grad_focal_loss:
171 | # torch._C.set_grad_enabled(False)
172 | self.xs_pos = self.xs_pos * self.targets
173 | self.xs_neg = self.xs_neg * self.anti_targets
174 | self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
175 | self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
176 | # if self.disable_torch_grad_focal_loss:
177 | # torch._C.set_grad_enabled(True)
178 | self.loss *= self.asymmetric_w
179 | else:
180 | self.xs_pos = self.xs_pos * self.targets
181 | self.xs_neg = self.xs_neg * self.anti_targets
182 | self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
183 | self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
184 | self.loss *= self.asymmetric_w
185 | # print(self.loss.shape)
186 |
187 | _loss = - self.loss.sum() / x.size(0) * 1000
188 | # _loss = _loss / y.size(1) * 1000
189 |
190 | return _loss
191 |
192 |
--------------------------------------------------------------------------------
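A quick sanity check for `AsymmetricLossOptimized` on random tensors, using the same gamma_neg/gamma_pos/clip values that the config exposes as SOLVER.GAMMA_NEG, SOLVER.GAMMA_POS and SOLVER.CLIP.

# Sanity check: ASL on random logits and multi-hot targets.
import torch
from loss.asymmetric_loss import AsymmetricLossOptimized

criterion = AsymmetricLossOptimized(gamma_neg=2, gamma_pos=0, clip=0.)
logits = torch.randn(4, 80)                      # [B, C] raw logits
targets = torch.randint(0, 2, (4, 80)).float()   # multi-hot labels
loss = criterion(logits, targets)
print(loss.item())   # scalar; note the per-sample and per-class (x1000) scaling in forward()
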
/loss/mmc_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 |
6 | def mmc_loss(logits, mask, logits_mask=None):
7 | """
8 | MMC Loss (Multi-Matching Contrastive loss)
9 | Inspired by SupCon Loss: https://github.com/google-research/google-research/tree/master/supcon
10 | MMC extends it into (1) multi-modal setting and (2) batched contrastive process
11 | Args:
12 | logits: torch.Tensor[B, C], B - mini-batch size, C - number of classes
13 | mask: torch.Tensor[B, C], binary mask, 1 if the class is present in the image, 0 otherwise
14 | logits_mask: torch.Tensor[B, C], mask out self-matching logits, not applied in multi-modal setting
15 | Returns:
16 | loss_cl: torch.Tensor[1], mean cross-entropy loss over positive pairs
17 | """
18 | # flatten the batch dimension
19 | logits = logits.reshape(-1)
20 | mask = mask.reshape(-1)
21 |
22 | # for numerical stability
23 | logits_max = torch.max(logits)
24 | logits = logits - logits_max.detach()
25 | exp_mixed_logits = torch.exp(logits)
26 |
27 | # mask out self-matching logits
28 | if logits_mask is not None:
29 | logits_mask = logits_mask.reshape(-1)
30 | exp_mixed_logits = exp_mixed_logits * logits_mask
31 |
32 | # cross entropy + softmax
33 | log_prob = logits - torch.log(exp_mixed_logits.sum())
34 | num_pos_pairs = mask.sum()
35 |
36 | # sum over positive pairs, division is outside the log
37 | num_pos_pairs = torch.where(num_pos_pairs < 1e-6, 1, num_pos_pairs)
38 | mean_log_prob_pos = (mask * log_prob).sum() / num_pos_pairs
39 |
40 | # mean over batch samples
41 | loss = -mean_log_prob_pos
42 | loss = loss.mean()
43 |
44 | return loss
45 |
46 |
--------------------------------------------------------------------------------
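A minimal example of `mmc_loss` on random scores. In `RAM_Model.forward` the logits are image-text similarities divided by a learnable temperature; here they are simply random values at a comparable scale.

# Minimal example: MMC loss over a random batch of class scores.
import torch
from loss.mmc_loss import mmc_loss

B, C = 4, 80
logits = torch.randn(B, C) / 0.02            # e.g. similarity / temperature
mask = torch.randint(0, 2, (B, C)).float()   # 1 where the class is present in the image
print(mmc_loss(logits, mask).item())
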
/loss/seesawloss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from typing import Union
6 |
7 |
8 |
9 | class focal_loss(nn.Module):
10 | def __init__(self, alpha=None, gamma=2.0, num_classes=2):
11 | super(focal_loss, self).__init__()
12 | self.alpha = alpha
13 | self.gamma = gamma
14 | self.num_classes = num_classes
15 |
16 | def forward(self, y_pred, y_labels):
17 | y_pred = torch.softmax(y_pred, dim=1)
18 | class_mask = F.one_hot(y_labels, num_classes=self.num_classes)
19 | pt = (y_pred * class_mask).sum(dim=1)
20 | if self.alpha is None:
21 | loss = -((1 - pt) ** self.gamma) * pt.log()
22 | loss = loss.mean()
23 | else:
24 | alpha = self.alpha[y_labels]
25 | loss = -alpha * ((1 - pt) ** self.gamma) * pt.log()
26 | loss = loss.sum() / alpha.sum()
27 | return loss
28 |
29 |
30 | class SeesawLossWithLogits(nn.Module):
31 | """
32 |     This is an unofficial implementation of Seesaw loss,
33 |     which was proposed in the technical report for the LVIS workshop at ECCV 2020.
34 |     For more detail, please refer to https://arxiv.org/pdf/2008.10032.pdf.
35 |     Args:
36 |         class_counts: The list with the number of samples for each class.
37 |             Should have the same length as num_classes.
38 |         p: Scale parameter which adjusts the strength of the punishment.
39 |             Set to 0.8 by default, following the original paper.
40 | """
41 |
42 |     def __init__(self, class_counts: Union[list, np.ndarray], p: float = 0.8):
43 | super().__init__()
44 |
45 | class_counts = torch.FloatTensor(class_counts)
46 | conditions = class_counts[:, None] > class_counts[None, :]
47 | trues = (class_counts[None, :] / class_counts[:, None]) ** p
48 | print(trues.dtype)
49 | falses = torch.ones(len(class_counts), len(class_counts))
50 | self.s = torch.where(conditions, trues, falses)
51 | self.num_labels = len(class_counts)
52 | self.eps = 1.0e-6
53 |
54 | def forward(self, logits, targets):
55 | targets = F.one_hot(targets, self.num_labels)
56 | # print(targets.shape)
57 | self.s = self.s.to(targets.device)
58 | max_element, _ = logits.max(axis=-1)
59 | logits = logits - max_element[:, None] # to prevent overflow
60 |
61 | numerator = torch.exp(logits)
62 | denominator = (
63 | (1 - targets)[:, None, :]
64 | * self.s[None, :, :]
65 | * torch.exp(logits)[:, None, :]).sum(axis=-1) \
66 | + torch.exp(logits)
67 |
68 | sigma = numerator / (denominator + self.eps)
69 | loss = (- targets * torch.log(sigma + self.eps)).sum(-1)
70 | return loss.mean()
71 |
72 |
73 | class DistibutionAgnosticSeesawLossWithLogits(nn.Module):
74 | """
75 |     This is an unofficial implementation of Seesaw loss,
76 |     which was proposed in the technical report for the LVIS workshop at ECCV 2020.
77 |     For more detail, please refer to https://arxiv.org/pdf/2008.10032.pdf.
78 |     Args:
79 |         p: Parameter for the Mitigation Factor.
80 |             Set to 0.8 by default, following the paper.
81 |         q: Parameter for the Compensation Factor.
82 |             Set to 2 by default, following the paper.
83 |         num_labels: Number of classes.
84 | """
85 |
86 | def __init__(self, p: float = 0.8, q: float = 2, num_labels=2, data_name='PETA'): #num_labels=2
87 | super().__init__()
88 | self.eps = 1.0e-6
89 | self.p = p
90 | self.q = q
91 | self.class_counts = None
92 | self.num_labels = num_labels
93 | self.data_name = data_name
94 |
95 | def forward(self, logits, targets):
96 | # Mitigation Factor
97 | if self.class_counts is None:
98 |             self.class_counts = (targets.sum(axis=0) + 1).float()  # to prevent division by zero
99 | else:
100 | self.class_counts += targets.sum(axis=0)
101 |
102 | m_conditions = self.class_counts[:, None] > self.class_counts[None, :]
103 | m_trues = (self.class_counts[None, :] / self.class_counts[:, None]) ** self.p
104 | m_falses = torch.ones(len(self.class_counts), len(self.class_counts)).to(targets.device)
105 | m = torch.where(m_conditions, m_trues, m_falses) # [num_labels, num_labels]
106 |
107 |         # Compensation Factor
108 |         probability = logits
109 |
110 |         c_condition = probability / (probability * targets).sum(dim=-1)[:, None]
111 | c_condition = torch.stack([c_condition] * targets.shape[-1], dim=1)
112 | c_condition = c_condition * targets[:, :, None]
113 | false = torch.ones(c_condition.shape).to(targets.device)
114 | c = torch.where(c_condition>1, c_condition ** self.q, false)
115 |
116 | # Sij = Mij * Cij
117 | s = m[None, :, :] * c
118 | # s = c
119 |
120 | # softmax trick to prevent overflow (like logsumexp trick)
121 | max_element, _ = logits.max(axis=-1)
122 | logits = logits - max_element[:, None] # to prevent overflow
123 | numerator = torch.exp(logits)
124 | denominator = (
125 | (1 - targets)[:, None, :]
126 | * s[None, :, :]
127 | * torch.exp(logits)[:, None, :]).sum(axis=-1) \
128 | + torch.exp(logits)
129 |
130 | sigma = numerator / (denominator + self.eps)
131 |
132 | #seesaw loss
133 | loss = (- targets * torch.log(sigma + self.eps)).sum(-1)
134 | return loss.mean()
135 |
136 |
--------------------------------------------------------------------------------
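An illustrative call to `DistibutionAgnosticSeesawLossWithLogits` (note it is not exported in loss/__init__.py). The targets are fixed so that every sample has at least one positive label, since the compensation factor divides by the per-sample sum of positive scores.

# Illustrative only: the distribution-agnostic Seesaw loss on toy inputs.
import torch
from loss.seesawloss import DistibutionAgnosticSeesawLossWithLogits

criterion = DistibutionAgnosticSeesawLossWithLogits(p=0.8, q=2, num_labels=5)
scores = torch.rand(4, 5)                       # forward() treats these as probabilities
targets = torch.tensor([[1., 0., 1., 0., 0.],
                        [0., 1., 0., 0., 1.],
                        [1., 1., 0., 0., 0.],
                        [0., 0., 0., 1., 1.]])  # at least one positive per sample
print(criterion(scores, targets).item())
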
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import RAM_Model
2 |
3 |
4 |
5 | def build_model(cfg, clip_model, dataset, clip_model_teacher=None):
6 | model = RAM_Model(
7 | cfg,
8 | clip_model,
9 | dataset.classnames_seen,
10 | dataset.classnames_unseen,
11 | clip_model_teacher
12 | )
13 | print('Build RAM Model Done')
14 | return model
15 |
--------------------------------------------------------------------------------
/model/base.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | from collections import OrderedDict
6 |
7 | from utils.model_utils import save_checkpoint, load_checkpoint, tolist_if_not
8 | from loss import AsymmetricLossOptimized, AsymmetricLossOptimized_partial, mmc_loss
9 |
10 |
11 | class BaseModel(nn.Module):
12 | """
13 | Basic Model
14 | Implementation of some common functions
15 | """
16 | def __init__(self):
17 | super().__init__()
18 | self._models = OrderedDict()
19 | self._optims = OrderedDict()
20 | self._scheds = OrderedDict()
21 |
22 | def get_model_names(self, names=None):
23 | names_real = list(self._models.keys())
24 | if names is not None:
25 | names = tolist_if_not(names)
26 | for name in names:
27 | assert name in names_real
28 | return names
29 | else:
30 | return names_real
31 |
32 | def register_model(self, name="model", model=None, optim=None, sched=None):
33 | assert name not in self._models, "Found duplicate model names"
34 |
35 | self._models[name] = model
36 | self._optims[name] = optim
37 | self._scheds[name] = sched
38 |
39 | def get_current_lr(self, names=None):
40 | names = self.get_model_names(names)
41 | name = names[0]
42 | return self._optims[name].param_groups[0]["lr"]
43 |
44 | def get_specific_lr(self, names=None):
45 | if names is None:
46 | names = self.get_model_names(names)
47 | name = names[0]
48 | else:
49 | name = names
50 | return self._optims[name].param_groups[0]["lr"]
51 |
52 | def update_lr(self, names=None):
53 | names = self.get_model_names(names)
54 |
55 | for name in names:
56 | if self._scheds[name] is not None:
57 | self._scheds[name].step()
58 |
59 | def set_model_mode(self, mode="train", names=None):
60 | names = self.get_model_names(names)
61 |
62 | for name in names:
63 | if mode == "train":
64 | self._models[name].train()
65 | elif mode in ["test", "eval"]:
66 | self._models[name].eval()
67 | else:
68 | raise KeyError
69 |
70 | def detect_anomaly(self, loss):
71 | if not torch.isfinite(loss).all():
72 | raise FloatingPointError("Loss is infinite or NaN!")
73 |
74 | def save_model(self, iters, directory, is_best=False):
75 | # save registered_module
76 | names = self.get_model_names()
77 |
78 | for name in names:
79 | model_dict = self._models[name].state_dict()
80 | save_dict = OrderedDict()
81 | for k, v in self._models[name].named_parameters():
82 | if v.requires_grad:
83 | save_dict[k] = model_dict[k]
84 |
85 | sdir = os.path.join(directory, name)
86 | save_checkpoint(
87 | {
88 | "state_dict": save_dict,
89 | "iters": iters,
90 | },
91 | sdir,
92 | is_best
93 | )
94 |
95 | print(f"Checkpoint of {name} saved to {sdir}")
96 |
97 | def load_model(self, directory, iters):
98 | model_file = f"model-iters{iters}.pth"
99 | names = self.get_model_names()
100 |
101 | for name in names:
102 | model_path = os.path.join(directory, name, model_file)
103 | if not os.path.exists(model_path):
104 | raise FileNotFoundError('Model not found at "{}"'.format(model_path))
105 | checkpoint = load_checkpoint(model_path)
106 | state_dict = checkpoint["state_dict"]
107 | iters = checkpoint["iters"]
108 |
109 | # Ignore fixed token vectors
110 | if "token_prefix" in state_dict:
111 | del state_dict["token_prefix"]
112 | if "token_suffix" in state_dict:
113 | del state_dict["token_suffix"]
114 |
115 | print("Loading weights to {} " 'from "{}"'.format(name, model_path))
116 | self._models[name].load_state_dict(state_dict, strict=False)
117 |
118 | def make_criterion(self, cfg):
119 | """
120 | Classification loss
121 | - Zero-shot setting: MMC loss
122 | - Partial-label setting: ASL partial loss
123 | """
124 | if cfg.MODEL.LOSS_TYPE == 'MMC':
125 | criterion = mmc_loss
126 | elif cfg.MODEL.LOSS_TYPE == 'ASL':
127 | criterion = AsymmetricLossOptimized(cfg.SOLVER.GAMMA_NEG, cfg.SOLVER.GAMMA_POS, cfg.SOLVER.CLIP)
128 | elif cfg.MODEL.LOSS_TYPE == 'ASL-partial':
129 | criterion = AsymmetricLossOptimized_partial(cfg.SOLVER.GAMMA_NEG, cfg.SOLVER.GAMMA_POS, cfg.SOLVER.CLIP)
130 | else:
131 | raise NotImplementedError
132 |
133 | return criterion
134 |
135 | def freeze(self, transfer_type):
136 | if hasattr(self, "clip_model_teacher") and self.clip_model_teacher is not None:
137 | for name, param in self.clip_model_teacher.named_parameters():
138 | param.requires_grad = False
139 |
140 | if transfer_type == "no_freeze":
141 | pass
142 |
143 | elif transfer_type == "freeze_all":
144 | for name, param in self.clip_model.named_parameters():
145 | param.requires_grad = False
146 |
147 | elif transfer_type == "freeze_text":
148 | for name, param in self.clip_model.named_parameters():
149 | if 'visual.' in name:
150 | continue
151 | else:
152 | param.requires_grad = False
153 |
154 | elif transfer_type == "Adapter":
155 | for name, param in self.clip_model.named_parameters():
156 | if "adapter" in name or "embeds" in name: # embeds
157 | param.requires_grad = True
158 | else:
159 | param.requires_grad = False
160 |
161 | elif "partial" in transfer_type:
162 | total_layer = len(self.clip_model.visual.transformer.resblocks)
163 | partial_layer = int(transfer_type.split("-")[-1])
164 | if partial_layer > total_layer:
165 | raise NotImplementedError
166 | for name, param in self.clip_model.named_parameters():
167 | find = False
168 | for l in range(total_layer-partial_layer, total_layer):
169 | if "visual.transformer.resblocks.{}".format(l) in name:
170 | param.requires_grad = True
171 | find = True
172 | break
173 | if not find:
174 | param.requires_grad = False
175 |
176 | else:
177 | raise NotImplementedError
178 |
179 | def load_param(self, trained_path):
180 | param_dict = torch.load(trained_path)
181 | for i in param_dict:
182 | self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
183 | print('Loading pretrained model from {}'.format(trained_path))
184 |
185 |
186 |
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.distributed as dist
5 |
6 | import clip
7 | from clip import clip
8 | from .base import BaseModel
9 | from .ot_solver import Sinkhorn
10 |
11 |
12 | class RAM_Model(BaseModel):
13 | def __init__(
14 | self,
15 | cfg,
16 | clip_model,
17 | classnames_seen,
18 | classnames_unseen,
19 | clip_model_teacher
20 | ):
21 | super().__init__()
22 | self.cfg = cfg
23 | self.classnames_seen = classnames_seen
24 | self.classnames_unseen = classnames_unseen
25 | self.num_classes_seen = len(classnames_seen)
26 | self.num_classes_unseen = len(classnames_unseen)
27 | self.criterion = self.make_criterion(cfg)
28 | self.device = dist.get_rank() if cfg.MODEL.DIST_TRAIN else 'cuda'
29 |
30 | # create teacher model, preserve teacher text features
31 | self.clip_model = clip_model
32 | self.text_tokens = self.get_text_templates()
33 | self.text_fea_teacher = self.get_text_fea(clip_model_teacher, classnames_seen+classnames_unseen)
34 | self.clip_model_teacher = clip_model_teacher
35 |
36 | # freeze the model
37 | self.freeze(cfg.MODEL.TRANSFER_TYPE)
38 |
39 | # pre-trained logit scale
40 | self.logit_scale = self.clip_model.logit_scale.exp()
41 | # learnable temperature
42 | self.temperature_loc = nn.Parameter(torch.tensor(cfg.MODEL.TEMPERATURE))
43 | self.temperature = nn.Parameter(1./self.logit_scale)
44 |
45 | # KCOT parameters
46 | self.reg = cfg.MODEL.OT_REG
47 | self.reg_sc = cfg.MODEL.OT_REGSC
48 |
49 | def get_text_fea(self, clip_model, classnames):
50 | text_templates = "A photo of a {}."
51 | text_templates = [text_templates.format(classnames[i]) for i in range(len(classnames))]
52 | text_tok = clip.tokenize(text_templates)
53 | with torch.no_grad():
54 | text_fea = clip_model.encode_text(text_tok)
55 | return text_fea.unsqueeze(1).detach()
56 |
57 | def get_text_templates(self):
58 | templates = "A photo of a {}."
59 | texts = [templates.format(name) for name in self.classnames_seen+self.classnames_unseen]
60 | text_tokens = clip.tokenize(texts)
61 | return text_tokens.cuda()
62 |
63 | def build_weights(self, sim, dim=-1, temperature=0.1):
64 | with torch.no_grad():
65 | sim_max = sim.max(dim=dim)[0]
66 | weights = (sim_max / temperature).softmax(dim=-1)
67 | return weights
68 |
69 | def generate_teacher_distribution(self, img_teacher, zsl=False, gzsl=False):
70 | with torch.no_grad():
71 | _, img_loc = self.clip_model_teacher.visual(img_teacher)
72 | img_loc = img_loc[0][:, 1:]
73 | text_fea = self.text_fea_teacher.clone().cuda()
74 | if zsl:
75 | text_fea = text_fea[self.num_classes_seen:]
76 | elif gzsl:
77 | pass
78 | else:
79 | text_fea = text_fea[:self.num_classes_seen]
80 |             B, tok, dim = img_loc.shape
81 | C, gp, dim = text_fea.shape
82 |
83 | text_fea = F.normalize(text_fea, dim=-1)
84 | img_loc = F.normalize(img_loc, dim=-1)
85 | text_fea = text_fea.unsqueeze(0).expand(B, -1, -1, -1).reshape(B, -1, dim)
86 |
87 | # generate weight
88 | logit_scale = self.clip_model_teacher.logit_scale.exp()
89 | logits_loc = logit_scale * img_loc @ text_fea.transpose(-2, -1)
90 | logits_loc = logits_loc.reshape(B, -1, C, gp)
91 | local_similarity = logits_loc.softmax(dim=2)
92 | prob = (local_similarity*20.).softmax(dim=1)
93 | prob = prob.mean(dim=-1)
94 | return prob
95 |
96 | def forward(self, img, target=None, zsl=False, gzsl=False):
97 | seen = True if not zsl and not gzsl else False
98 | if seen:
99 | text_tokens = self.text_tokens[:self.num_classes_seen].clone()
100 | elif zsl:
101 | text_tokens = self.text_tokens[self.num_classes_seen:self.num_classes_seen+self.num_classes_unseen].clone()
102 | else:
103 | text_tokens = self.text_tokens[:self.num_classes_seen+self.num_classes_unseen].clone()
104 | prompt_fea_loc = self.clip_model.encode_text(text_tokens)
105 | prompt_fea_loc = prompt_fea_loc.unsqueeze(1)
106 |
107 | img_glb, img_loc = self.clip_model.visual(img, text_fea=prompt_fea_loc)
108 | img_loc = img_loc[0][:, 1:]
109 |
110 | B, tok, dim = img_loc.shape
111 | C, gp, dim = prompt_fea_loc.shape
112 |
113 | prompt_fea_loc = prompt_fea_loc.permute(1, 0, 2)
114 |
115 | img_glb = F.normalize(img_glb, dim=-1)
116 | img_loc = F.normalize(img_loc, dim=-1)
117 | prompt_fea_loc = F.normalize(prompt_fea_loc, dim=-1)
118 |
119 | logits_glb = img_glb @ prompt_fea_loc.transpose(1, 2) / self.temperature
120 | score_glb = logits_glb.squeeze(1).softmax(dim=-1)
121 | if self.training:
122 | mask = target
123 | loss_glb = self.criterion(logits_glb, mask)
124 |
125 | # Cost matrix
126 | sim = img_loc @ prompt_fea_loc.transpose(1, 2)
127 | cost = (sim * self.logit_scale).softmax(dim=-1)
128 | cost = 1.0 - cost
129 |
130 | if self.training:
131 | # Teacher is only applied in training
132 | frozen_mask = self.generate_teacher_distribution(img, zsl, gzsl)
133 | gt_mask = target.unsqueeze(1).expand(-1, tok, -1)
134 | frozen_mask[gt_mask==0] = frozen_mask.min()
135 | cost_tr = -torch.log(frozen_mask) * self.reg_sc
136 | cost = cost + cost_tr
137 | reg = self.reg + self.reg_sc
138 | else:
139 | reg = self.reg
140 |
141 | u = self.build_weights(sim.detach(), dim=2, temperature=0.1)
142 | v = torch.zeros((B, C), dtype=sim.dtype, device=sim.device).fill_(1./C)
143 | with torch.no_grad():
144 | T = Sinkhorn(u, v, cost, reg=reg)
145 | if torch.isnan(T).any():
146 | raise ValueError("Found nan in OT matrix!")
147 |
148 | sim_op = T * sim
149 | sim_op = sim_op.sum(dim=1) / self.temperature_loc
150 | score_loc = sim_op.softmax(dim=-1)
151 | score = (score_glb + score_loc) / 2.
152 | if self.training:
153 | mask = target
154 | loss_loc = self.criterion(sim_op, mask)
155 | loss = loss_glb + loss_loc
156 | return {"score": score, "loss": loss}
157 | else:
158 | return {"score": score}
159 |
160 |
161 |
162 |
--------------------------------------------------------------------------------
/model/ot_solver.py:
--------------------------------------------------------------------------------
1 | """
2 | Sinkhorn Algorithm for different settings of Optimal Transport
3 | Implementation inspired by: https://github.com/PythonOT/POT, thanks
4 | """
5 |
6 |
7 | import torch
8 |
9 |
10 |
11 | def Sinkhorn(a, b, M, reg, max_iter=100, thresh=1e-3):
12 | """
13 | Sinkhorn Iteration
14 | Solving Entropic Optimal Transport (EOT)
15 | Args:
16 | a: torch.Tensor[B, N], B - batch size, N - number of points in the source distribution
17 | b: torch.Tensor[B, M], B - batch size, M - number of points in the target distribution
18 | M: torch.Tensor[B, N, M], cost matrix
19 | reg: float, regularization strength
20 | max_iter: int, maximum number of iterations
21 | thresh: float, convergence threshold
22 | Returns:
23 | T: torch.Tensor[B, N, M], transport plan
24 | """
25 | K = torch.exp(-M / reg)
26 | r = torch.ones_like(a)
27 | c = torch.ones_like(b)
28 | thresh = 1e-3
29 |
30 | for i in range(max_iter):
31 | r0 = r
32 | r = a / torch.matmul(K, c.unsqueeze(-1)).squeeze(-1)
33 | c = b / torch.matmul(K.permute(0, 2, 1).contiguous(), r.unsqueeze(-1)).squeeze(-1)
34 | err = (r - r0).abs().mean(dim=1)
35 | if torch.all(err < thresh):
36 | break
37 | T = torch.matmul(r.unsqueeze(-1), c.unsqueeze(-2)) * K
38 | return T
39 |
40 |
41 |
42 | def Sinkhorn_entropic_unbalanced(a, b, M, reg, reg_m, max_iter=100, thresh=1e-3):
43 | """
44 | Sinkhorn Iteration
45 | Solving Entropic Unbalanced Optimal Transport (EUOT)
46 | Args:
47 | a: torch.Tensor[B, N], B - batch size, N - number of points in the source distribution
48 | b: torch.Tensor[B, M], B - batch size, M - number of points in the target distribution
49 | M: torch.Tensor[B, N, M], cost matrix
50 | reg: float, entropy regularization strength
51 | reg_m: float, marginal regularization strength
52 | max_iter: int, maximum number of iterations
53 | thresh: float, convergence threshold
54 | Returns:
55 | T: torch.Tensor[B, N, M], transport plan
56 | """
57 | if isinstance(reg_m, float) or isinstance(reg_m, int):
58 | reg_m1, reg_m2 = reg_m, reg_m
59 | else:
60 | reg_m1, reg_m2 = reg_m[0], reg_m[1]
61 |
62 | u = torch.ones_like(a)
63 | v = torch.ones_like(b)
64 |
65 | # entropic reg
66 | K = torch.exp(-M / reg)
67 | # kl unbalanced
68 | fi_1 = reg_m1 / (reg_m1 + reg)
69 | fi_2 = reg_m2 / (reg_m2 + reg)
70 |
71 | thresh = 1e-3
72 | for i in range(max_iter):
73 | uprev = u
74 | vprev = v
75 |
76 | Kv = torch.matmul(K, v.unsqueeze(-1)).squeeze(-1)
77 | u = (a / Kv) ** fi_1
78 | Ktu = torch.matmul(K.permute(0, 2, 1).contiguous(), u.unsqueeze(-1)).squeeze(-1)
79 | v = (b / Ktu) ** fi_2
80 |
81 | max_u = torch.cat([torch.max(torch.abs(u), dim=1, keepdim=True)[0], torch.max(torch.abs(uprev), dim=1, keepdim=True)[0], torch.ones((u.shape[0], 1)).cuda()], dim=1)
82 | max_v = torch.cat([torch.max(torch.abs(v), dim=1, keepdim=True)[0], torch.max(torch.abs(vprev), dim=1, keepdim=True)[0], torch.ones((v.shape[0], 1)).cuda()], dim=1)
83 |
84 | err_u = torch.max(torch.abs(u - uprev), dim=1)[0] / torch.max(max_u, dim=1)[0]
85 | err_v = torch.max(torch.abs(v - vprev), dim=1)[0] / torch.max(max_v, dim=1)[0]
86 |
87 | err = 0.5 * (err_u.mean() + err_v.mean())
88 | if err.item() < thresh:
89 | break
90 |
91 | T = torch.matmul(u.unsqueeze(-1), v.unsqueeze(-2)) * K
92 | return T
93 |
94 |
95 |
96 | def Sinkhorn_unbalanced(a, b, M, reg_m, div='kl', reg=0, max_iter=100, thresh=1e-3):
97 | """
98 | Sinkhorn Iteration
99 | Solving Unbalanced Optimal Transport (UOT)
100 | Args:
101 | a: torch.Tensor[B, N], B - batch size, N - number of points in the source distribution
102 | b: torch.Tensor[B, M], B - batch size, M - number of points in the target distribution
103 | M: torch.Tensor[B, N, M], cost matrix
104 | reg_m: float, marginals regularization strength
105 | div: regularization method ("kl", "l2")
106 | max_iter: int, maximum number of iterations
107 | thresh: float, convergence threshold
108 | Returns:
109 | T: torch.Tensor[B, N, M], transport plan
110 | """
111 | if isinstance(reg_m, float) or isinstance(reg_m, int):
112 | reg_m1, reg_m2 = reg_m, reg_m
113 | else:
114 | reg_m1, reg_m2 = reg_m[0], reg_m[1]
115 |
116 | G = torch.matmul(a.unsqueeze(-1), b.unsqueeze(-2))
117 | c = torch.matmul(a.unsqueeze(-1), b.unsqueeze(-2))
118 | assert div in ["kl", "l2"]
119 |
120 | if div == 'kl':
121 | sum_r = reg + reg_m1 + reg_m2
122 | r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r
123 | K = torch.matmul(a.unsqueeze(-1)**r1, b.unsqueeze(-2)**r2) * (c**r) * torch.exp(-M / sum_r)
124 | elif div == 'l2':
125 | K = reg_m1 * a.unsqueeze(-1) + reg_m2 * b.unsqueeze(-2) + reg * c - M
126 | K = torch.max(K, torch.zeros_like(M))
127 |
128 | thresh = 1e-3
129 | for i in range(max_iter):
130 | Gprev = G
131 |
132 | if div == 'kl':
133 | Gd = torch.matmul(torch.sum(G, dim=-1, keepdim=True)**r1, torch.sum(G, dim=1, keepdim=True)**r2) + 1e-16
134 | G = K * G**(r1 + r2) / Gd
135 | elif div == 'l2':
136 | Gd = reg_m1 * torch.sum(G, dim=-1, keepdim=True) + \
137 | reg_m2 * torch.sum(G, dim=1, keepdim=True) + reg * G + 1e-16
138 | G = K * G / Gd
139 |
140 | err = torch.sqrt(torch.sum((G - Gprev) ** 2, dim=(1,2)).mean())
141 | if err < thresh:
142 | break
143 |
144 | return G
145 |
146 |
147 |
148 |
149 |
--------------------------------------------------------------------------------
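A quick check of the balanced `Sinkhorn` solver with uniform marginals, mirroring how `RAM_Model` transports image tokens onto class prototypes; the sizes and the cost matrix here are arbitrary.

# Sanity check: entropic OT with uniform marginals and a random cost matrix.
import torch
from model.ot_solver import Sinkhorn

B, N, M = 2, 196, 48                     # batch, image tokens, classes
a = torch.full((B, N), 1. / N)           # source marginal
b = torch.full((B, M), 1. / M)           # target marginal
cost = torch.rand(B, N, M)
T = Sinkhorn(a, b, cost, reg=0.1)
print(T.shape)                           # torch.Size([2, 196, 48])
print(T.sum(dim=1)[0, :3])               # column sums ≈ b (about 1/48 each)
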
/processor/__init__.py:
--------------------------------------------------------------------------------
1 | from .processor import do_train
--------------------------------------------------------------------------------
/processor/processor.py:
--------------------------------------------------------------------------------
1 |
2 | import logging
3 | import os
4 | import time
5 | import wandb
6 | import datetime
7 | import numpy as np
8 | from clip import convert_weights
9 | import torch
10 | import torch.nn as nn
11 | from torch.cuda import amp
12 | import torch.distributed as dist
13 | from torch.autograd import Variable
14 |
15 | from utils.model_utils import thread_flag
16 | from utils.model_utils import ModelEma
17 | from utils.meter import AverageMeter
18 | from utils.metrics import multilabel_evaluation
19 |
20 | import warnings
21 | warnings.filterwarnings("ignore")
22 |
23 |
24 |
25 |
26 | def do_train(
27 | cfg,
28 | model,
29 | train_loader,
30 | val_loader,
31 | val_loader_gzsl,
32 | optimizer,
33 | optimizer_sgd,
34 | scheduler,
35 | scheduler_sgd,
36 | output_dir,
37 | train_sampler=None,
38 | ):
39 | log_period = cfg.SOLVER.LOG_PERIOD
40 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
41 | eval_period = cfg.SOLVER.EVAL_PERIOD
42 | device = "cuda"
43 | epochs = cfg.SOLVER.MAX_EPOCHS
44 | logger = logging.getLogger("RAM.train")
45 | logger.info('start training')
46 |
47 | if device:
48 | model.to(cfg.LOCAL_RANK)
49 | if torch.cuda.device_count() > 1 and cfg.MODEL.DIST_TRAIN:
50 | print('Using {} GPUs for training'.format(torch.cuda.device_count()))
51 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
52 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[cfg.LOCAL_RANK], find_unused_parameters=True)
53 |
54 | loss_meter = AverageMeter()
55 | acc_meter = AverageMeter()
56 | batch_time = AverageMeter()
57 | scaler = amp.GradScaler()
58 |
59 | # Training
60 | label_smoothing = cfg.SOLVER.LABEL_SMOOTHING
61 | torch.cuda.empty_cache()
62 | thresh = cfg.SOLVER.THRESH
63 | ema_m = None
64 | tot_iters = len(train_loader) * epochs
65 |
66 | for epoch in range(1, epochs + 1):
67 | if cfg.MODEL.USE_EMA:
68 | if cfg.MODEL.DIST_TRAIN:
69 | ema_m = ModelEma(model.module, cfg.MODEL.EMA_DECAY, device=dist.get_rank())
70 | else:
71 | ema_m = ModelEma(model, cfg.MODEL.EMA_DECAY, device=device)
72 |
73 | if train_sampler is not None:
74 | train_sampler.set_epoch(epoch)
75 | torch.cuda.empty_cache()
76 |
77 | loss_meter.reset()
78 | acc_meter.reset()
79 |
80 | scheduler.step(epoch)
81 | scheduler_sgd.step(epoch)
82 |
83 | model.train()
84 | for n_iter, (img, label) in enumerate(train_loader):
85 |
86 | start = time.time()
87 | if cfg.SOLVER.LR_SCHEDULER == 'onecycle':
88 | scheduler.step()
89 | scheduler_sgd.step(epoch)
90 | correct = 0
91 | total = 0
92 | optimizer.zero_grad()
93 | optimizer_sgd.zero_grad()
94 | img = img.to(device)
95 |
96 | # construct GT matrix
97 | if label_smoothing:
98 | label_f = label.float()
99 | label_soft = torch.where(label_f == 1, torch.tensor(0.9), label_f)
100 | label_soft = torch.where(label_soft == 0, torch.tensor(0.1), label_soft)
101 | target = label_soft.to(device)
102 | else:
103 | target = label.to(device)
104 |
105 | with amp.autocast(enabled=True):
106 | output = model(img, target=target)
107 |
108 | score = output["score"]
109 | loss = output["loss"]
110 |
111 | # score, loss
112 | scaler.scale(loss).backward()
113 | gpu_mem = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
114 | scaler.step(optimizer)
115 | scaler.step(optimizer_sgd)
116 | scaler.update()
117 | if ema_m is not None:
118 | if cfg.MODEL.DIST_TRAIN:
119 | ema_m.update(model.module)
120 | else:
121 | ema_m.update(model)
122 |
123 | targets = Variable(label)
124 | label = label.numpy()
125 | outputs_np = score.data.cpu().numpy()
126 | predicted = outputs_np > thresh
127 | correct += np.sum(predicted == label, axis=0)
128 | total += targets.size(0)
129 | acc = np.mean(correct / total)
130 |
131 | loss_meter.update(loss.item(), img.shape[0])
132 | acc_meter.update(acc, 1)
133 | torch.cuda.synchronize()
134 |
135 | batch_time.update(time.time() - start)
136 |
137 | # Logging
138 | if (n_iter + 1) % log_period == 0:
139 | if thread_flag(cfg.MODEL.DIST_TRAIN):
140 | now_iter = (epoch-1) * len(train_loader) + n_iter
141 | nb_remain = tot_iters - now_iter
142 | eta_seconds = batch_time.avg * nb_remain
143 | eta = str(datetime.timedelta(seconds=int(eta_seconds)))
144 | cur_lr = optimizer.param_groups[0]['lr']
145 |
146 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, lr: {:.2e}, mem: {:.2f}MB, speed:{:.2f}[img/s], ETA: {}"
147 | .format(epoch, n_iter+1, len(train_loader), loss_meter.avg, acc, cur_lr, gpu_mem, train_loader.batch_size/batch_time.avg, eta))
148 |
149 | if cfg.WANDB:
150 | wandb.log({
151 | "epoch": epoch,
152 | "lr": cur_lr,
153 | "train loss": loss_meter.avg,
154 | "train acc": acc
155 | })
156 |
157 | output_path = os.path.join(output_dir, f'epoch{epoch}.pth')
158 | if cfg.SOLVER.SAVE_MODEL and epoch % checkpoint_period == 0:
159 | if thread_flag(cfg.MODEL.DIST_TRAIN):
160 | torch.save(model.state_dict(), output_path)
161 |
162 | # Testing
163 | if epoch % eval_period == 0:
164 | if thread_flag(cfg.MODEL.DIST_TRAIN):
165 | Result_k, Result_k5 = validate(cfg, val_loader, model, device, zsl=True, gzsl=False)
166 | Result_k_gzsl, Result_k5_gzsl = validate(cfg, val_loader_gzsl, model, device, zsl=False, gzsl=True)
167 | now_metric = (Result_k["OF1"] + Result_k_gzsl["OF1"]) / 2.
168 |
169 | if ema_m is not None:
170 | Result_k_ema, Result_k5_ema = validate(cfg, val_loader, ema_m.module, device, zsl=True, gzsl=False)
171 | Result_k_gzsl_ema, Result_k5_gzsl_ema = validate(cfg, val_loader_gzsl, ema_m.module, device, zsl=False, gzsl=True)
172 |
173 | logger.info("Validation Results - Epoch: {}, F1_avg: {:.3%}".format(epoch, now_metric))
174 | logger.info("ZSL:")
175 | logger.info("OP_3: {:.3%}, OR_3: {:.3%}, OF1_3: {:.3%}".format(Result_k['OP'], Result_k['OR'], Result_k['OF1']))
176 | logger.info("OP_5: {:.3%}, OR_5: {:.3%}, OF1_5: {:.3%}".format(Result_k5['OP'], Result_k5['OR'], Result_k5['OF1']))
177 | logger.info("GZSL:")
178 | logger.info("OP_3: {:.3%}, OR_3: {:.3%}, OF1_3: {:.3%}".format(Result_k_gzsl['OP'], Result_k_gzsl['OR'], Result_k_gzsl['OF1']))
179 | logger.info("OP_5: {:.3%}, OR_5: {:.3%}, OF1_5: {:.3%}".format(Result_k5_gzsl['OP'], Result_k5_gzsl['OR'], Result_k5_gzsl['OF1']))
180 | if ema_m is not None:
181 | logger.info("EMA Results:")
182 | logger.info("ZSL:")
183 | logger.info("OP_3: {:.3%}, OR_3: {:.3%}, OF1_3: {:.3%}".format(Result_k_ema['OP'], Result_k_ema['OR'], Result_k_ema['OF1']))
184 | logger.info("OP_5: {:.3%}, OR_5: {:.3%}, OF1_5: {:.3%}".format(Result_k5_ema['OP'], Result_k5_ema['OR'], Result_k5_ema['OF1']))
185 | logger.info("GZSL:")
186 | logger.info("OP_3: {:.3%}, OR_3: {:.3%}, OF1_3: {:.3%}".format(Result_k_gzsl_ema['OP'], Result_k_gzsl_ema['OR'], Result_k_gzsl_ema['OF1']))
187 | logger.info("OP_5: {:.3%}, OR_5: {:.3%}, OF1_5: {:.3%}".format(Result_k5_gzsl_ema['OP'], Result_k5_gzsl_ema['OR'], Result_k5_gzsl_ema['OF1']))
188 |
189 |                 # log validation metrics to wandb
190 | if cfg.WANDB:
191 | wandb.log({
192 | "F1_avg ZSL-GZSL": now_metric,
193 | "OP_3": Result_k['OP'], "OR_3": Result_k['OR'], "OF1_3": Result_k['OF1'],
194 | "OP_5": Result_k5['OP'], "OR_5": Result_k5['OR'], "OF1_5": Result_k5['OF1'],
195 | "OP_3 GZSL": Result_k_gzsl['OP'], "OR_3 GZSL": Result_k_gzsl['OR'], "OF1_3 GZSL": Result_k_gzsl['OF1'],
196 | "OP_5 GZSL": Result_k5_gzsl['OP'], "OR_5 GZSL": Result_k5_gzsl['OR'], "OF1_5 GZSL": Result_k5_gzsl['OF1'],
197 | })
198 | if ema_m is not None:
199 | wandb.log({
200 | "OP_3 ema": Result_k_ema['OP'], "OR_3 ema": Result_k_ema['OR'], "OF1_3 ema": Result_k_ema['OF1'],
201 | "OP_5 ema": Result_k5_ema['OP'], "OR_5 ema": Result_k5_ema['OR'], "OF1_5 ema": Result_k5_ema['OF1'],
202 | "OP_3 GZSL ema": Result_k_gzsl_ema['OP'], "OR_3 GZSL ema": Result_k_gzsl_ema['OR'], "OF1_3 GZSL ema": Result_k_gzsl_ema['OF1'],
203 | "OP_5 GZSL ema": Result_k5_gzsl_ema['OP'], "OR_5 GZSL ema": Result_k5_gzsl_ema['OR'], "OF1_5 GZSL ema": Result_k5_gzsl_ema['OF1'],
204 | })
205 |
206 | torch.cuda.empty_cache()
207 |
208 |
209 |
210 | def validate(cfg, val_loader, model, device, zsl=True, gzsl=False):
211 | model.eval()
212 | total = 0
213 | batch_idx = 0
214 | batch_time = AverageMeter()
215 |
216 | for n_iter, (img, label) in enumerate(val_loader):
217 | with torch.no_grad():
218 | img = img.to(device)
219 | target = label.to(device)
220 | start = time.time()
221 | output = model(img, target=target, zsl=zsl, gzsl=gzsl)
222 | score = output["score"]
223 | label = label.numpy()
224 | outputs_np = score.data.cpu().numpy()
225 | gpu_mem = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
226 |
227 | batch_time.update(time.time() - start)
228 |
229 | if total == 0:
230 | g_labels = label
231 | p_score = outputs_np
232 | else:
233 |                 g_labels = np.vstack((g_labels, label))
234 |                 p_score = np.vstack((p_score, outputs_np))
235 |
236 | total += label.shape[0]
237 | batch_idx += 1
238 |
239 | if (n_iter + 1) % cfg.SOLVER.LOG_PERIOD == 0:
240 |                 print("mem: {:.2f}MB, test speed: {:.2f}[img/s]".format(gpu_mem, val_loader.batch_size / batch_time.avg))
241 |
242 | Result_k = multilabel_evaluation(p_score, g_labels, k=3)
243 | Result_k5 = multilabel_evaluation(p_score, g_labels, k=5)
244 | torch.cuda.empty_cache()
245 |
246 | return Result_k, Result_k5
247 |
248 |
249 |
250 | def parse_batch(batch):
251 | input, label = batch
252 | return input, label
253 |
254 |
255 |
256 |
257 |
--------------------------------------------------------------------------------
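
A note on the training loop above: the label-smoothing and thresholded-accuracy steps are easy to check in isolation. The sketch below uses toy tensors rather than the project's dataloader, and `thresh = 0.5` simply stands in for `cfg.SOLVER.THRESH`.

    import torch
    import numpy as np

    # Toy multi-hot labels: batch of 2 images, 4 classes.
    label = torch.tensor([[1, 0, 0, 1],
                          [0, 1, 0, 0]])

    # Same smoothing as do_train: positives -> 0.9, negatives -> 0.1.
    label_f = label.float()
    label_soft = torch.where(label_f == 1, torch.tensor(0.9), label_f)
    label_soft = torch.where(label_soft == 0, torch.tensor(0.1), label_soft)
    # -> [[0.9, 0.1, 0.1, 0.9], [0.1, 0.9, 0.1, 0.1]]

    # Thresholded per-class accuracy, as computed per batch in the loop.
    score = torch.tensor([[0.8, 0.2, 0.4, 0.7],
                          [0.3, 0.6, 0.1, 0.6]])
    thresh = 0.5                       # stands in for cfg.SOLVER.THRESH
    predicted = score.numpy() > thresh
    correct = np.sum(predicted == label.numpy(), axis=0)
    acc = np.mean(correct / label.shape[0])
    print(acc)                         # 0.875: 7 of the 8 class decisions match
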
/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy==6.3.1
2 | imgaug==0.4.0
3 | ipdb==0.13.13
4 | numpy==1.25.2
5 | opencv_contrib_python==4.8.0.76
6 | opencv_python==4.8.0.76
7 | Pillow==9.4.0
9 | regex==2023.8.8
10 | setuptools==68.0.0
11 | timm==0.5.4
12 | torch==1.13.1
13 | torchvision==0.14.1
14 | tqdm==4.66.1
15 | yacs==0.1.8
16 |
--------------------------------------------------------------------------------
/solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_optimizer import make_optimizer
2 | from .make_scheduler import make_scheduler
--------------------------------------------------------------------------------
/solver/make_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 | def make_optimizer(cfg, model, lr_mult=0.1):
6 | clip_params, sgd_params, other_params = [], [], []
7 | for pname, p in model.named_parameters():
8 | if not p.requires_grad:
9 | continue
10 | elif 'embeds' in pname:
11 | sgd_params.append(p)
12 | elif pname.startswith('clip'):
13 | clip_params.append(p)
14 | else:
15 | other_params.append(p)
16 |
17 | # Optimizer1
18 | param_groups = [
19 | {'params': clip_params, 'lr': cfg.SOLVER.BASE_LR * lr_mult, 'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
20 | {'params': other_params, 'lr': cfg.SOLVER.BASE_LR, 'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
21 | ]
22 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
23 | optimizer = torch.optim.SGD(param_groups, momentum=cfg.SOLVER.MOMENTUM)
24 | elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW':
25 | optimizer = torch.optim.AdamW(param_groups)
26 | else:
27 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(param_groups, momentum=cfg.SOLVER.MOMENTUM)
28 |
29 | # Optimizer2
30 | param_groups_sgd = [{'params': sgd_params, 'lr': cfg.SOLVER.BASE_LR_SGD, 'weight_decay': cfg.SOLVER.WEIGHT_DECAY_SGD}]
31 | optimizer_sgd = torch.optim.SGD(param_groups_sgd, momentum=cfg.SOLVER.MOMENTUM)
32 |
33 | return optimizer, optimizer_sgd
34 |
35 |
36 |
--------------------------------------------------------------------------------
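
How this split plays out can be sanity-checked with a toy module. The `Toy` class and the `SimpleNamespace` config below are made-up stand-ins that only cover the fields `make_optimizer` actually reads; run from the repo root so `solver` is importable.

    from types import SimpleNamespace
    import torch
    import torch.nn as nn
    from solver import make_optimizer

    # Minimal stand-in for the yacs cfg: only the SOLVER fields used above.
    cfg = SimpleNamespace(SOLVER=SimpleNamespace(
        BASE_LR=1e-4, WEIGHT_DECAY=1e-4, OPTIMIZER_NAME='AdamW',
        MOMENTUM=0.9, BASE_LR_SGD=1e-3, WEIGHT_DECAY_SGD=0.0))

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.clip = nn.Linear(8, 8)                          # 'clip*'    -> lr * 0.1 group
            self.label_embeds = nn.Parameter(torch.zeros(4, 8))  # '*embeds*' -> SGD optimizer
            self.head = nn.Linear(8, 4)                          # rest       -> base-lr group

    optimizer, optimizer_sgd = make_optimizer(cfg, Toy())
    print([g['lr'] for g in optimizer.param_groups])      # [1e-05, 0.0001]
    print([g['lr'] for g in optimizer_sgd.param_groups])  # [0.001]
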
/solver/make_scheduler.py:
--------------------------------------------------------------------------------
1 | """ Scheduler Factory
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import math
5 | import torch
6 | from torch.optim.lr_scheduler import OneCycleLR
7 | from typing import Dict, Any
8 |
9 |
10 |
11 | class Scheduler:
12 | """ Parameter Scheduler Base Class
13 | A scheduler base class that can be used to schedule any optimizer parameter groups.
14 |
15 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called
16 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
17 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value
18 |
19 | The schedulers built on this should try to remain as stateless as possible (for simplicity).
20 |
21 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
22 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training
23 | code and explicitly passed in to the schedulers on the corresponding step or step_update call.
24 |
25 | Based on ideas from:
26 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
27 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
28 | """
29 |
30 | def __init__(self,
31 | optimizer: torch.optim.Optimizer,
32 | param_group_field: str,
33 | noise_range_t=None,
34 | noise_type='normal',
35 | noise_pct=0.67,
36 | noise_std=1.0,
37 | noise_seed=None,
38 | initialize: bool = True) -> None:
39 | self.optimizer = optimizer
40 | self.param_group_field = param_group_field
41 | self._initial_param_group_field = f"initial_{param_group_field}"
42 | if initialize:
43 | for i, group in enumerate(self.optimizer.param_groups):
44 | if param_group_field not in group:
45 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
46 | group.setdefault(self._initial_param_group_field, group[param_group_field])
47 | else:
48 | for i, group in enumerate(self.optimizer.param_groups):
49 | if self._initial_param_group_field not in group:
50 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
51 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
52 | self.metric = None # any point to having this for all?
53 | self.noise_range_t = noise_range_t
54 | self.noise_pct = noise_pct
55 | self.noise_type = noise_type
56 | self.noise_std = noise_std
57 | self.noise_seed = noise_seed if noise_seed is not None else 42
58 | self.update_groups(self.base_values)
59 |
60 | def state_dict(self) -> Dict[str, Any]:
61 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
62 |
63 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
64 | self.__dict__.update(state_dict)
65 |
66 | def get_epoch_values(self, epoch: int):
67 | return None
68 |
69 | def get_update_values(self, num_updates: int):
70 | return None
71 |
72 | def step(self, epoch: int, metric: float = None) -> None:
73 | self.metric = metric
74 | values = self.get_epoch_values(epoch)
75 | if values is not None:
76 | values = self._add_noise(values, epoch)
77 | self.update_groups(values)
78 |
79 | def step_update(self, num_updates: int, metric: float = None):
80 | self.metric = metric
81 | values = self.get_update_values(num_updates)
82 | if values is not None:
83 | values = self._add_noise(values, num_updates)
84 | self.update_groups(values)
85 |
86 | def update_groups(self, values):
87 | if not isinstance(values, (list, tuple)):
88 | values = [values] * len(self.optimizer.param_groups)
89 | for param_group, value in zip(self.optimizer.param_groups, values):
90 | param_group[self.param_group_field] = value
91 |
92 | def _add_noise(self, lrs, t):
93 | if self.noise_range_t is not None:
94 | if isinstance(self.noise_range_t, (list, tuple)):
95 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
96 | else:
97 | apply_noise = t >= self.noise_range_t
98 | if apply_noise:
99 | g = torch.Generator()
100 | g.manual_seed(self.noise_seed + t)
101 | if self.noise_type == 'normal':
102 | while True:
103 | # resample if noise out of percent limit, brute force but shouldn't spin much
104 | noise = torch.randn(1, generator=g).item()
105 | if abs(noise) < self.noise_pct:
106 | break
107 | else:
108 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
109 | lrs = [v + v * noise for v in lrs]
110 | return lrs
111 |
112 |
113 | class CosineLRScheduler(Scheduler):
114 | """
115 | Cosine decay with restarts.
116 | This is described in the paper https://arxiv.org/abs/1608.03983.
117 |
118 | Inspiration from
119 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
120 | """
121 |
122 | def __init__(self,
123 | optimizer: torch.optim.Optimizer,
124 | t_initial: int,
125 | t_mul: float = 1.,
126 | lr_min: float = 0.,
127 | decay_rate: float = 1.,
128 | warmup_t=0,
129 | warmup_lr_init=0,
130 | warmup_prefix=False,
131 | cycle_limit=0,
132 | t_in_epochs=True,
133 | noise_range_t=None,
134 | noise_pct=0.67,
135 | noise_std=1.0,
136 | noise_seed=42,
137 | initialize=True) -> None:
138 | super().__init__(
139 | optimizer, param_group_field="lr",
140 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
141 | initialize=initialize)
142 |
143 | assert t_initial > 0
144 | assert lr_min >= 0
145 | self.t_initial = t_initial
146 | self.t_mul = t_mul
147 | self.lr_min = lr_min
148 | self.decay_rate = decay_rate
149 | self.cycle_limit = cycle_limit
150 | self.warmup_t = warmup_t
151 | self.warmup_lr_init = warmup_lr_init
152 | self.warmup_prefix = warmup_prefix
153 | self.t_in_epochs = t_in_epochs
154 | if self.warmup_t:
155 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
156 | super().update_groups(self.warmup_lr_init)
157 | else:
158 | self.warmup_steps = [1 for _ in self.base_values]
159 |
160 | def _get_lr(self, t):
161 | if t < self.warmup_t:
162 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
163 | else:
164 | if self.warmup_prefix:
165 | t = t - self.warmup_t
166 |
167 | if self.t_mul != 1:
168 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
169 | t_i = self.t_mul ** i * self.t_initial
170 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
171 | else:
172 | i = t // self.t_initial
173 | t_i = self.t_initial
174 | t_curr = t - (self.t_initial * i)
175 |
176 | gamma = self.decay_rate ** i
177 | lr_min = self.lr_min * gamma
178 | lr_max_values = [v * gamma for v in self.base_values]
179 |
180 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
181 | lrs = [
182 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
183 | ]
184 | else:
185 | lrs = [self.lr_min for _ in self.base_values]
186 |
187 | return lrs
188 |
189 | def get_epoch_values(self, epoch: int):
190 | if self.t_in_epochs:
191 | return self._get_lr(epoch)
192 | else:
193 | return None
194 |
195 | def get_update_values(self, num_updates: int):
196 | if not self.t_in_epochs:
197 | return self._get_lr(num_updates)
198 | else:
199 | return None
200 |
201 | def get_cycle_length(self, cycles=0):
202 | if not cycles:
203 | cycles = self.cycle_limit
204 | cycles = max(1, cycles)
205 | if self.t_mul == 1.0:
206 | return self.t_initial * cycles
207 | else:
208 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
209 |
210 |
211 | def make_scheduler(cfg, optimizer, dataloader_len):
212 | lr_scheduler = cfg.SOLVER.LR_SCHEDULER
213 | lr_mult = cfg.SOLVER.IMS_PER_BATCH / 256
214 | num_epochs = cfg.SOLVER.SCHEDULER_MAX_EPOCHS
215 | lr_min = 0.002 * cfg.SOLVER.BASE_LR
216 | warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR
217 |
218 | warmup_t = cfg.SOLVER.WARMUP_EPOCHS
219 | noise_range = None
220 |
221 | if lr_scheduler == 'onecycle':
222 | lr_scheduler = OneCycleLR(
223 | optimizer,
224 | max_lr=[cfg.SOLVER.BASE_LR_CLIP * lr_mult, cfg.SOLVER.BASE_LR * lr_mult],
225 | steps_per_epoch=dataloader_len,
226 | epochs=cfg.SOLVER.MAX_EPOCHS,
227 | pct_start=0.2
228 | )
229 | else:
230 | lr_scheduler = CosineLRScheduler(
231 | optimizer,
232 | t_initial=num_epochs,
233 | lr_min=lr_min,
234 |             t_mul=1.,
235 | decay_rate=0.1,
236 | warmup_lr_init=warmup_lr_init,
237 | warmup_t=warmup_t,
238 | cycle_limit=1,
239 | t_in_epochs=True,
240 | noise_range_t=noise_range,
241 |             noise_pct=0.67,
242 |             noise_std=1.,
243 | noise_seed=42,
244 | )
245 |
246 | return lr_scheduler
247 |
248 |
249 |
--------------------------------------------------------------------------------
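
As the base-class docstring notes, these timm-style schedulers are driven by explicit epoch (or update) indices rather than an internal counter: `do_train` calls `scheduler.step(epoch)` once per epoch, while the 'onecycle' branch returns torch's OneCycleLR, stepped per iteration. A minimal sketch of the epoch-wise convention; the optimizer and hyperparameters here are placeholders, not the repo's configured values.

    import torch
    from solver.make_scheduler import CosineLRScheduler

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([param], lr=1e-4)

    sched = CosineLRScheduler(opt, t_initial=10, lr_min=2e-7,
                              warmup_t=1, warmup_lr_init=1e-6,
                              cycle_limit=1, t_in_epochs=True)

    for epoch in range(1, 11):
        sched.step(epoch)                 # explicit epoch index, as in do_train
        lr = opt.param_groups[0]['lr']
        print(epoch, f"{lr:.2e}")         # cosine decay from 1e-4 down towards lr_min
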
/src/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EricTan7/RAM/477f5051234819a11ca62ec441ecb3ee33abda65/src/method.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from utils.logger import setup_logger
3 | from datasets import make_dataloader
4 | from model import build_model
5 | from solver import make_optimizer, make_scheduler
6 | from processor import do_train
7 | from utils.model_utils import set_seed, load_clip_to_cpu, thread_flag
8 | import torch
9 | import torch.distributed as dist
10 | import argparse
11 | from config import cfg
12 | import warnings
13 | warnings.filterwarnings("ignore")
14 | import time
15 | import wandb
16 |
17 |
18 |
19 | def main(cfg):
20 | set_seed(cfg.SOLVER.SEED)
21 |
22 | if cfg.MODEL.DIST_TRAIN:
23 | torch.cuda.set_device(cfg.LOCAL_RANK)
24 | dist.init_process_group(backend='nccl', init_method='env://')
25 | dist.barrier()
26 |
27 | # 1. Logging
28 | if cfg.WANDB:
29 | run = wandb.init(project=cfg.WANDB_PROJ, config=cfg)
30 | run.name = f'{cfg.DATASETS.NAMES}-{cfg.SOLVER.OPTIMIZER_NAME}-lr{cfg.SOLVER.BASE_LR}'
31 |
32 | output_dir = os.path.join(cfg.OUTPUT_DIR, 'RAM-' + time.strftime('%Y-%m-%d-%H-%M-%S'))
33 | if thread_flag(cfg.MODEL.DIST_TRAIN):
34 | if output_dir and not os.path.exists(output_dir):
35 | os.makedirs(output_dir)
36 | logger = setup_logger("RAM", output_dir, if_train=True)
37 |         logger.info("Saving model to: {}".format(output_dir))
38 | logger.info("Running with config:\n{}".format(cfg))
39 |
40 | # 2. Data
41 | train_loader, val_loader, val_loader_gzsl, train_sampler, dataset = make_dataloader(cfg)
42 |
43 | # 3. Model
44 | clip_model = load_clip_to_cpu(cfg)
45 | clip_model_teacher = load_clip_to_cpu(cfg, zero_shot=True).eval()
46 | model = build_model(cfg, clip_model, dataset, clip_model_teacher)
47 |
48 |
49 | if cfg.MODEL.LOAD:
50 | state_dict = torch.load(cfg.TEST.WEIGHT)
51 | model.load_state_dict(state_dict)
52 | print(f"Load weights from: {cfg.TEST.WEIGHT}")
53 |
54 | # 4. Optimizer
55 | optimizer, optimizer_sgd = make_optimizer(cfg, model)
56 | scheduler, scheduler_sgd = make_scheduler(cfg, optimizer, len(train_loader)), make_scheduler(cfg, optimizer_sgd, len(train_loader))
57 |
58 | # 5. Start training
59 | do_train(
60 | cfg,
61 | model,
62 | train_loader,
63 | val_loader,
64 | val_loader_gzsl,
65 | optimizer,
66 | optimizer_sgd,
67 | scheduler,
68 | scheduler_sgd,
69 | output_dir,
70 | train_sampler,
71 | )
72 |
73 |
74 | if __name__ == '__main__':
75 | parser = argparse.ArgumentParser(description="RAM Training")
76 | parser.add_argument(
77 | "--config_file",
78 | help="path to config file", type=str, default="configs/coco.yml"
79 | )
80 | parser.add_argument(
81 | "--local_rank",
82 | type=int, default=0
83 | )
84 | parser.add_argument(
85 | "opts",
86 | help="Modify config options from command-line", nargs=argparse.REMAINDER, default=None
87 | )
88 | args = parser.parse_args()
89 |
90 | if args.config_file != "":
91 | cfg.merge_from_file(args.config_file)
92 | cfg.merge_from_list(args.opts)
93 | cfg.LOCAL_RANK = args.local_rank
94 |
95 | main(cfg)
96 |
97 |
98 |
99 |
--------------------------------------------------------------------------------
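
The trailing `opts` argument is consumed by yacs as alternating KEY VALUE pairs layered on top of the yml file, mirroring the merge order in `__main__`. A small sketch (the override values are arbitrary examples):

    from config import cfg

    # Same merge order as train.py: yml file first, then KEY VALUE overrides.
    cfg.merge_from_file("configs/coco.yml")
    cfg.merge_from_list(["SOLVER.BASE_LR", "0.0001", "SOLVER.MAX_EPOCHS", "20"])
    print(cfg.SOLVER.BASE_LR, cfg.SOLVER.MAX_EPOCHS)   # 0.0001 20

This corresponds to launching with `python train.py --config_file configs/coco.yml SOLVER.BASE_LR 0.0001 SOLVER.MAX_EPOCHS 20`, since argparse's REMAINDER collects the trailing pairs into `args.opts`.
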
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EricTan7/RAM/477f5051234819a11ca62ec441ecb3ee33abda65/utils/__init__.py
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
5 |
6 |
7 | def setup_logger(name, save_dir, if_train):
8 | logger = logging.getLogger(name)
9 | logger.setLevel(logging.DEBUG)
10 |
11 | ch = logging.StreamHandler(stream=sys.stdout)
12 | ch.setLevel(logging.DEBUG)
13 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
14 | ch.setFormatter(formatter)
15 | logger.addHandler(ch)
16 |
17 | if save_dir:
18 | if not os.path.exists(save_dir):
19 | os.makedirs(save_dir)
20 | if if_train:
21 | fh = logging.FileHandler(os.path.join(save_dir, "train_log.txt"), mode='a')
22 | else:
23 | fh = logging.FileHandler(os.path.join(save_dir, "test_log.txt"), mode='a')
24 | fh.setLevel(logging.DEBUG)
25 | fh.setFormatter(formatter)
26 | logger.addHandler(fh)
27 |
28 | return logger
--------------------------------------------------------------------------------
/utils/meter.py:
--------------------------------------------------------------------------------
1 | class AverageMeter(object):
2 | """Computes and stores the average and current value"""
3 |
4 | def __init__(self):
5 | self.val = 0
6 | self.avg = 0
7 | self.sum = 0
8 | self.count = 0
9 |
10 | def reset(self):
11 | self.val = 0
12 | self.avg = 0
13 | self.sum = 0
14 | self.count = 0
15 |
16 | def update(self, val, n=1):
17 | self.val = val
18 | self.sum += val * n
19 | self.count += n
20 | self.avg = self.sum / self.count
--------------------------------------------------------------------------------
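
Usage matches the loss/accuracy meters in the training loop: `update(value, n)` accumulates a sample-weighted sum, and `avg` is read out for logging. For example:

    from utils.meter import AverageMeter

    meter = AverageMeter()
    meter.update(0.8, n=32)      # e.g. a batch loss of 0.8 averaged over 32 images
    meter.update(0.6, n=32)
    print(meter.val, meter.avg)  # 0.6 (last value)  0.7 (weighted mean over 64 samples)
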
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.parallel
3 | import torch.optim
4 | import torch.utils.data.distributed
5 | import numpy as np
6 |
7 |
8 |
9 |
10 | def cal_metrics(p_labels, g_labels):
11 | tp, fp, fn, tn = 0, 0, 0, 0
12 | target = g_labels
13 |
14 | pred = p_labels
15 |
16 | tp += ((pred + target)==2).sum(axis=0)
17 | fp += ((pred - target)==1).sum(axis=0)
18 | fn += ((pred - target)==-1).sum(axis=0)
19 | tn += ((pred + target)==0).sum(axis=0)
20 | p_c = [float(tp[i] / (tp[i] + fp[i])) if tp[i] > 0 else 0.0 for i in range(len(tp))]
21 | r_c = [float(tp[i] / (tp[i] + fn[i])) if tp[i] > 0 else 0.0 for i in range(len(tp))]
22 |
23 | mean_p_c = sum(p_c) / len(p_c)
24 | mean_r_c = sum(r_c) / len(r_c)
25 | if mean_p_c==0 and mean_r_c ==0:
26 | mean_f_c = 0.
27 | else:
28 | mean_f_c = 2 * mean_p_c * mean_r_c / (mean_p_c + mean_r_c)
29 |
30 | p_o = tp.sum() / (tp + fp).sum()
31 | r_o = tp.sum() / (tp + fn).sum()
32 | if p_o==0 and r_o ==0:
33 | f_o = 0.
34 | else:
35 | f_o = 2 * p_o * r_o / (p_o + r_o)
36 |
37 | Result = {}
38 | Result['CP'] = mean_p_c
39 | Result['CR'] = mean_r_c
40 | Result['CF1'] = mean_f_c
41 | Result['OP'] = p_o
42 | Result['OR'] = r_o
43 | Result['OF1'] = f_o
44 |
45 | return Result
46 |
47 |
48 |
49 | def multilabel_evaluation(scores, targets, k=1):
50 | scores = torch.tensor(scores)
51 | targets[targets == -1] = 0
52 | n, c = scores.size()
53 | pred = np.zeros((n, c))
54 | index = scores.topk(k, 1, True, True)[1].numpy()
55 | for i in range(n):
56 | for ind in index[i]:
57 | pred[i, ind] = 1
58 | return cal_metrics(pred, targets)
59 |
--------------------------------------------------------------------------------
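
`multilabel_evaluation` keeps only the top-k scores per image as positive predictions, then reports per-class (CP/CR/CF1) and overall (OP/OR/OF1) precision, recall and F1. A small worked example with made-up scores:

    import numpy as np
    from utils.metrics import multilabel_evaluation

    # 2 images, 4 classes; scores and targets are illustrative only.
    scores = np.array([[0.9, 0.1, 0.8, 0.3],
                       [0.2, 0.7, 0.1, 0.6]])
    targets = np.array([[1, 0, 1, 0],
                        [0, 1, 0, 0]])

    # k=2: image 1 predicts classes {0, 2}, image 2 predicts {1, 3}.
    result = multilabel_evaluation(scores, targets, k=2)
    print(result['OP'], result['OR'], result['OF1'])   # 0.75  1.0  ~0.857
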
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 | import torch.distributed as dist
5 | from clip import clip
6 | from functools import partial
7 | from collections import OrderedDict
8 | from copy import deepcopy
9 | import os
10 | import errno
11 | import pickle
12 |
13 |
14 |
15 | class ModelEma(torch.nn.Module):
16 | def __init__(self, model, decay=0.9997, device=None):
17 | super(ModelEma, self).__init__()
18 | # make a copy of the model for accumulating moving average of weights
19 | self.module = deepcopy(model)
20 | self.module.eval()
21 |
22 | self.decay = decay
23 | self.device = device
24 | if self.device is not None:
25 | self.module.to(device=device)
26 |
27 | def _update(self, model, update_fn):
28 | with torch.no_grad():
29 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
30 | if self.device is not None:
31 | model_v = model_v.to(device=self.device)
32 | ema_v.copy_(update_fn(ema_v, model_v))
33 | self.module.temperature = deepcopy(model.temperature)
34 | if hasattr(model, "temperature_glb"):
35 | self.module.temperature_glb = deepcopy(model.temperature_glb)
36 |
37 | def update(self, model):
38 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
39 |
40 | def set(self, model):
41 | self._update(model, update_fn=lambda e, m: m)
42 |
43 |
44 |
45 |
46 | def set_seed(seed):
47 | torch.manual_seed(seed)
48 | torch.cuda.manual_seed(seed)
49 | torch.cuda.manual_seed_all(seed)
50 | np.random.seed(seed)
51 | random.seed(seed)
52 | torch.backends.cudnn.deterministic = True
53 | torch.backends.cudnn.benchmark = False
54 |
55 |
56 | def thread_flag(dist_train):
57 | if not dist_train:
58 | return True
59 | else:
60 | return dist.get_rank() == 0
61 |
62 |
63 | def getModelSize(model):
64 | param_size = 0
65 | param_sum = 0
66 | grad_param_size = 0
67 | grad_param_sum = 0
68 | for param in model.parameters():
69 | param_size += param.nelement() * param.element_size()
70 | param_sum += param.nelement()
71 |         if param.requires_grad:
72 |             grad_param_size += param.nelement() * param.element_size()
73 |             grad_param_sum += param.nelement()
74 |     print('total number of params: {:.3f}M'.format(param_sum / 1000 / 1000))
75 |     print('trainable number of params: {:.3f}M ({:.5%})'.format(grad_param_sum / 1000 / 1000, grad_param_sum / param_sum))
76 |
77 | return (param_size, param_sum, grad_param_size)
78 |
79 |
80 |
81 | def convert_params_to_value(params):
82 | if params[0] == -1:
83 | return [-1] # not using
84 | elif params[-1] == -1:
85 | return list(range(params[0])) # continuous N layers
86 | else:
87 | return params
88 |
89 |
90 | def load_clip_to_cpu(cfg, zero_shot=False):
91 | backbone_name = cfg.MODEL.BACKBONE
92 | url = clip._MODELS[backbone_name]
93 | model_path = clip._download(url)
94 |
95 | try:
96 | # loading JIT archive
97 | model = torch.jit.load(model_path, map_location="cpu").eval()
98 | state_dict = None
99 |
100 | except RuntimeError:
101 | state_dict = torch.load(model_path, map_location="cpu")
102 |
103 | if zero_shot:
104 | saa_layer = [12, -1] if "ViT-B" in backbone_name else [24, -1]
105 | saa_layer = convert_params_to_value(saa_layer)
106 | design_details = {
107 | "depth_vision": [-1],
108 | "depth_text": [-1],
109 | "SAA_layer": saa_layer
110 | }
111 | print("Build zero-shot CLIP Model")
112 | else:
113 | depth_vision = convert_params_to_value(cfg.MODEL.DEPTH_VISION)
114 | depth_text = convert_params_to_value(cfg.MODEL.DEPTH_TEXT)
115 | saa_layer = convert_params_to_value(cfg.MODEL.SAA_LAYER)
116 | design_details = {
117 | "depth_vision": depth_vision,
118 | "vision_adapt": cfg.MODEL.VISION_ADAPT,
119 | "depth_text": depth_text,
120 | "text_ctx": cfg.MODEL.TEXT_CTX,
121 | "SAA_layer": saa_layer,
122 | "kernel_size": cfg.MODEL.KERNEL_SIZE
123 | }
124 | print("Build CLIP Model")
125 |
126 | model = clip.build_model(state_dict or model.state_dict(), cfg.INPUT.SIZE_TRAIN, design_details)
127 | model.visual.SAA_replace()
128 |
129 | return model.float()
130 |
131 |
132 |
133 | def mkdir_if_missing(dirname):
134 | """Create dirname if it is missing."""
135 | if not os.path.exists(dirname):
136 | try:
137 | os.makedirs(dirname)
138 | except OSError as e:
139 | if e.errno != errno.EEXIST:
140 | raise
141 |
142 |
143 | def tolist_if_not(x):
144 | """Convert to a list."""
145 | if not isinstance(x, list):
146 | x = [x]
147 | return x
148 |
149 |
150 | def save_checkpoint(
151 | state,
152 | save_dir,
153 | is_best=False,
154 | remove_module_from_keys=True
155 | ):
156 | r"""Save checkpoint.
157 |
158 |     Args:
159 |         state (dict): checkpoint dict; expected to contain "state_dict" and "iters".
160 |         save_dir (str): directory to save checkpoint.
161 |         is_best (bool, optional): if True, the checkpoint is saved as
162 |             ``model-best.pth``. Default is False.
163 |         remove_module_from_keys (bool, optional): whether to remove "module."
164 |             from layer names. Default is True.
165 |     """
167 | mkdir_if_missing(save_dir)
168 |
169 | if remove_module_from_keys:
170 | # remove 'module.' in state_dict's keys
171 | state_dict = state["state_dict"]
172 | new_state_dict = OrderedDict()
173 | for k, v in state_dict.items():
174 | if k.startswith("module."):
175 | k = k[7:]
176 | new_state_dict[k] = v
177 | state["state_dict"] = new_state_dict
178 |
179 | # save model
180 | iters = state["iters"]
181 | if is_best:
182 | model_name = "model-best.pth"
183 | else:
184 | model_name = f"model-iters{iters}.pth"
185 | fpath = os.path.join(save_dir, model_name)
186 |
187 | torch.save(state, fpath)
188 |
189 |
190 | def load_checkpoint(fpath):
191 | r"""Load checkpoint.
192 |
193 |     ``UnicodeDecodeError`` is handled, so checkpoints saved with Python 2
194 |     can still be loaded under Python 3.
195 |
196 | Args:
197 | fpath (str): path to checkpoint.
198 |
199 | Returns:
200 | dict
201 |
202 | Examples::
203 | fpath = 'log/my_model/model.pth.tar-10'
204 | checkpoint = load_checkpoint(fpath)
205 | """
206 | if fpath is None:
207 | raise ValueError("File path is None")
208 |
209 | if not os.path.exists(fpath):
210 | raise FileNotFoundError('File is not found at "{}"'.format(fpath))
211 |
212 | map_location = None if torch.cuda.is_available() else "cpu"
213 |
214 | try:
215 | checkpoint = torch.load(fpath, map_location=map_location)
216 |
217 | except UnicodeDecodeError:
218 | pickle.load = partial(pickle.load, encoding="latin1")
219 | pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
220 | checkpoint = torch.load(
221 | fpath, pickle_module=pickle, map_location=map_location
222 | )
223 |
224 | except Exception:
225 | print('Unable to load checkpoint from "{}"'.format(fpath))
226 | raise
227 |
228 | return checkpoint
229 |
230 |
231 |
232 | def load_pretrained_weights(model, weight_path):
233 |     r"""Load pretrained weights to model.
234 |
235 | Features::
236 | - Incompatible layers (unmatched in name or size) will be ignored.
237 | - Can automatically deal with keys containing "module.".
238 |
239 | Args:
240 | model (nn.Module): network model.
241 | weight_path (str): path to pretrained weights.
242 |
243 |     Examples::
244 |         >>> weight_path = 'log/my_model/model-best.pth'
245 |         >>> load_pretrained_weights(model, weight_path)
246 | """
247 | checkpoint = load_checkpoint(weight_path)
248 | if "state_dict" in checkpoint:
249 | state_dict = checkpoint["state_dict"]
250 | else:
251 | state_dict = checkpoint
252 |
253 | model_dict = model.state_dict()
254 | new_state_dict = OrderedDict()
255 | matched_layers, discarded_layers = [], []
256 |
257 | for k, v in state_dict.items():
258 | if k.startswith("module."):
259 | k = k[7:] # discard module.
260 |
261 | if k in model_dict and model_dict[k].size() == v.size():
262 | new_state_dict[k] = v
263 | matched_layers.append(k)
264 | else:
265 | discarded_layers.append(k)
266 |
267 | model_dict.update(new_state_dict)
268 | model.load_state_dict(model_dict)
269 |
270 | if len(matched_layers) == 0:
271 | print(
272 | f"Cannot load {weight_path} (check the key names manually)"
273 | )
274 | else:
275 | print(f"Successfully loaded pretrained weights from {weight_path}")
276 | if len(discarded_layers) > 0:
277 | print(
278 | f"Layers discarded due to unmatched keys or size: {discarded_layers}"
279 | )
280 |
--------------------------------------------------------------------------------
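
Finally, a short sketch of how `ModelEma` is driven, mirroring `do_train` when `cfg.MODEL.USE_EMA` is set. `Toy` is a made-up stand-in; note that `_update` also copies a `temperature` attribute explicitly, so whatever model it wraps is expected to expose one.

    import torch
    import torch.nn as nn
    from utils.model_utils import ModelEma

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)
            self.temperature = nn.Parameter(torch.ones(1))   # copied explicitly by _update

    model = Toy()
    ema = ModelEma(model, decay=0.999)

    # ... after each optimizer step in the epoch:
    ema.update(model)            # ema <- decay * ema + (1 - decay) * model

    # evaluation then runs on the smoothed copy, as in the validate(..., ema_m.module, ...) calls
    ema.module.eval()
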