├── ModifiedCLIP ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── README.md ├── datasets ├── __init__.py ├── datasets_gen │ ├── __init__.py │ ├── hico.py │ ├── hico_eval.py │ ├── hico_eval_triplet.py │ ├── vcoco.py │ └── vcoco_eval.py ├── datasets_generate_feature │ ├── __init__.py │ ├── hico.py │ ├── hico_eval.py │ ├── hico_eval_triplet.py │ ├── hico_text_label.py │ ├── swig.py │ ├── swig_evaluator.py │ ├── swig_v1_categories.py │ ├── transforms.py │ ├── vcoco.py │ ├── vcoco_eval.py │ └── vcoco_text_label.py ├── hico_text_label.py ├── static_hico.py ├── transforms.py └── vcoco_text_label.py ├── engine.py ├── main.py ├── models ├── __init__.py ├── backbone.py ├── generate_image_feature │ ├── gen.py │ └── generate_verb.py ├── matcher.py ├── models_gen │ ├── gen.py │ └── gen_vlkt.py ├── models_hoiclip │ ├── gen.py │ └── hoiclip.py ├── position_encoding.py └── visualization_hoiclip │ ├── et_gen.py │ ├── gen.py │ ├── gen_vlkt.py │ └── oir_gen.py ├── paper_images └── intro.png ├── requirements.txt ├── scripts ├── generate_verb.sh ├── train_hico.sh ├── train_hico_frac.sh ├── train_hico_nrf_uc.sh ├── train_hico_rf_uc.sh ├── train_hico_uo.sh ├── train_hico_uv.sh ├── train_vcoco.sh └── visualization_hico.sh ├── tmp ├── vcoco_verb.pth ├── verb.pth └── vis_file_names.json ├── tools ├── convert_parameters.py ├── convert_vcoco_annotations.py └── covert_annot_for_official_eval.py └── util ├── box_ops.py ├── logger.py ├── misc.py ├── scheduler.py └── topk.py /ModifiedCLIP/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /ModifiedCLIP/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Artanic30/HOICLIP/ee4db062097410abdd20fa96d40d26aaca1f19da/ModifiedCLIP/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /ModifiedCLIP/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": 
"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | with open(model_path, 'rb') as opened_file: 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(opened_file, map_location="cpu") 137 | 138 | if not jit: 139 | model = build_model(state_dict or model.state_dict()).to(device) 140 | if str(device) == "cpu": 141 | model.float() 142 | return model, _transform(model.visual.input_resolution) 143 | 144 | # patch the device names 145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 147 | 148 | def patch_device(module): 149 | try: 150 | graphs = [module.graph] if hasattr(module, "graph") else [] 151 | except RuntimeError: 152 | graphs = [] 153 | 154 | if hasattr(module, "forward1"): 155 | graphs.append(module.forward1.graph) 156 | 157 | for graph in graphs: 158 | for node in graph.findAllNodes("prim::Constant"): 159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 160 | node.copyAttributes(device_node) 161 | 162 | model.apply(patch_device) 163 | patch_device(model.encode_image) 164 | patch_device(model.encode_text) 165 | 166 | # patch dtype to float32 on CPU 167 | if str(device) == "cpu": 168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 170 | float_node = float_input.node() 171 | 172 | def patch_float(module): 173 | try: 174 | graphs = [module.graph] if hasattr(module, "graph") else [] 175 | except RuntimeError: 176 | graphs = [] 177 | 178 | if hasattr(module, "forward1"): 179 | graphs.append(module.forward1.graph) 180 | 181 | for graph in graphs: 182 | for node in graph.findAllNodes("aten::to"): 183 | inputs = list(node.inputs()) 184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 185 | if inputs[i].node()["value"] == 5: 186 | inputs[i].node().copyAttributes(float_node) 187 | 188 | model.apply(patch_float) 189 | patch_float(model.encode_image) 190 | patch_float(model.encode_text) 191 | 192 | model.float() 193 | 194 | return model, _transform(model.input_resolution.item()) 195 | 196 | 197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 198 | """ 199 | Returns the tokenized representation of given input string(s) 200 | 201 | Parameters 202 | ---------- 203 | texts : Union[str, List[str]] 204 | An input string or a list of input strings to tokenize 
205 | 206 | context_length : int 207 | The context length to use; all CLIP models use 77 as the context length 208 | 209 | truncate: bool 210 | Whether to truncate the text in case its encoding is longer than the context length 211 | 212 | Returns 213 | ------- 214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 216 | """ 217 | if isinstance(texts, str): 218 | texts = [texts] 219 | 220 | sot_token = _tokenizer.encoder["<|startoftext|>"] 221 | eot_token = _tokenizer.encoder["<|endoftext|>"] 222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 225 | else: 226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 227 | 228 | for i, tokens in enumerate(all_tokens): 229 | if len(tokens) > context_length: 230 | if truncate: 231 | tokens = tokens[:context_length] 232 | tokens[-1] = eot_token 233 | else: 234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 235 | result[i, :len(tokens)] = torch.tensor(tokens) 236 | 237 | return result 238 | -------------------------------------------------------------------------------- /ModifiedCLIP/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HOICLIP: Efficient-Knowledge-Transfer-for-HOI-Detection-with-Visual-Linguistic-Model 2 | 3 | Code for our CVPR 2023 4 | paper "[HOICLIP: Efficient-Knowledge-Transfer-for-HOI-Detection-with-Visual-Linguistic-Model](https://arxiv.org/abs/2303.15786)" 5 | . 6 | 7 | Contributed by Shan Ning*, Longtian Qiu*, Yongfei Liu, Xuming He. 
8 | 9 | ![](paper_images/intro.png) 10 | 11 | ## Installation 12 | 13 | Install the dependencies. 14 | 15 | ``` 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Data preparation 20 | 21 | ### HICO-DET 22 | 23 | HICO-DET dataset can be downloaded [here](https://drive.google.com/open?id=1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk). After 24 | finishing downloading, unpack the tarball (`hico_20160224_det.tar.gz`) to the `data` directory. 25 | 26 | Instead of using the original annotations files, we use the annotation files provided by the PPDM authors. The 27 | annotation files can be downloaded from [here](https://drive.google.com/open?id=1dUByzVzM6z1Oq4gENa1-t0FLhr0UtDaS). The 28 | downloaded annotation files have to be placed as follows. 29 | For fractional data setting, we provide the 30 | annotations [here](https://drive.google.com/file/d/13O_uUv_17-Db9ghDqo4z2s3MZlfZJtgi/view?usp=sharing). After 31 | decompress, the files should be placed under `data/hico_20160224_det/annotations`. 32 | 33 | ``` 34 | data 35 | └─ hico_20160224_det 36 | |─ annotations 37 | | |─ trainval_hico.json 38 | | |─ test_hico.json 39 | | |─ corre_hico.json 40 | | |─ trainval_hico_5%.json 41 | | |─ trainval_hico_15%.json 42 | | |─ trainval_hico_25%.json 43 | | └─ trainval_hico_50%.json 44 | : 45 | ``` 46 | 47 | ### V-COCO 48 | 49 | First clone the repository of V-COCO from [here](https://github.com/s-gupta/v-coco), and then follow the instruction to 50 | generate the file `instances_vcoco_all_2014.json`. Next, download the prior file `prior.pickle` 51 | from [here](https://drive.google.com/drive/folders/10uuzvMUCVVv95-xAZg5KS94QXm7QXZW4). Place the files and make 52 | directories as follows. 53 | 54 | ``` 55 | GEN-VLKT 56 | |─ data 57 | │ └─ v-coco 58 | | |─ data 59 | | | |─ instances_vcoco_all_2014.json 60 | | | : 61 | | |─ prior.pickle 62 | | |─ images 63 | | | |─ train2014 64 | | | | |─ COCO_train2014_000000000009.jpg 65 | | | | : 66 | | | └─ val2014 67 | | | |─ COCO_val2014_000000000042.jpg 68 | | | : 69 | | |─ annotations 70 | : : 71 | ``` 72 | 73 | For our implementation, the annotation file have to be converted to the HOIA format. The conversion can be conducted as 74 | follows. 75 | 76 | ``` 77 | PYTHONPATH=data/v-coco \ 78 | python convert_vcoco_annotations.py \ 79 | --load_path data/v-coco/data \ 80 | --prior_path data/v-coco/prior.pickle \ 81 | --save_path data/v-coco/annotations 82 | ``` 83 | 84 | Note that only Python2 can be used for this conversion because `vsrl_utils.py` in the v-coco repository shows a error 85 | with Python3. 86 | 87 | V-COCO annotations with the HOIA format, `corre_vcoco.npy`, `test_vcoco.json`, and `trainval_vcoco.json` will be 88 | generated to `annotations` directory. 89 | 90 | ## Pre-trained model 91 | 92 | Download the pretrained model of DETR detector for [ResNet50](https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth) 93 | , and put it to the `params` directory. 94 | 95 | ``` 96 | python ./tools/convert_parameters.py \ 97 | --load_path params/detr-r50-e632da11.pth \ 98 | --save_path params/detr-r50-pre-2branch-hico.pth \ 99 | --num_queries 64 100 | 101 | python ./tools/convert_parameters.py \ 102 | --load_path params/detr-r50-e632da11.pth \ 103 | --save_path params/detr-r50-pre-2branch-vcoco.pth \ 104 | --dataset vcoco \ 105 | --num_queries 64 106 | ``` 107 | 108 | ## Training 109 | 110 | After the preparation, you can start training with the following commands. 
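(Optionally, before launching a run, you can confirm that the converted DETR checkpoint from the previous step loads cleanly. The snippet below is only an illustrative sanity check; the `"model"` key is an assumption about the checkpoint layout, not part of the official tooling.)

```python
# Illustrative sanity check for the converted DETR weights (layout assumed).
import torch

ckpt = torch.load("params/detr-r50-pre-2branch-hico.pth", map_location="cpu")
state = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
print(f"loaded {len(state)} entries from the converted checkpoint")
```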
111 | 112 | ### HICO-DET 113 | 114 | ``` 115 | # default setting 116 | sh ./scripts/train_hico.sh 117 | ``` 118 | 119 | ### V-COCO 120 | 121 | ``` 122 | sh ./scripts/train_vcoco.sh 123 | ``` 124 | 125 | ### Zero-shot 126 | 127 | ``` 128 | # rare first unseen combination setting 129 | sh ./scripts/train_hico_rf_uc.sh 130 | # non rare first unseen combination setting 131 | sh ./scripts/train_hico_nrf_uc.sh 132 | # unseen object setting 133 | sh ./scripts/train_hico_uo.sh 134 | # unseen verb setting 135 | sh ./scripts/train_hico_uv.sh 136 | ``` 137 | 138 | ### Fractional data 139 | 140 | ``` 141 | # 50% fractional data 142 | sh ./scripts/train_hico_frac.sh 143 | ``` 144 | 145 | ### Generate verb representation for Visual Semantic Arithmetic 146 | 147 | ``` 148 | sh ./scripts/generate_verb.sh 149 | ``` 150 | 151 | We provide the generated verb representation in `./tmp/verb.pth` for hico and `./tmp/vcoco_verb.pth` for vcoco. 152 | 153 | ## Evaluation 154 | 155 | ### HICO-DET 156 | 157 | You can conduct the evaluation with trained parameters for HICO-DET as follows. 158 | 159 | ``` 160 | python -m torch.distributed.launch \ 161 | --nproc_per_node=2 \ 162 | --use_env \ 163 | main.py \ 164 | --pretrained [path to your checkpoint] \ 165 | --dataset_file hico \ 166 | --hoi_path data/hico_20160224_det \ 167 | --num_obj_classes 80 \ 168 | --num_verb_classes 117 \ 169 | --backbone resnet50 \ 170 | --num_queries 64 \ 171 | --dec_layers 3 \ 172 | --eval \ 173 | --zero_shot_type default \ 174 | --with_clip_label \ 175 | --with_obj_clip_label \ 176 | --use_nms_filter 177 | ``` 178 | 179 | For the official evaluation (reported in paper), you need to covert the prediction file to an official prediction format 180 | following [this file](./tools/covert_annot_for_official_eval.py), and then 181 | follow [PPDM](https://github.com/YueLiao/PPDM) evaluation steps. 182 | 183 | [//]: # (### V-COCO) 184 | 185 | [//]: # () 186 | [//]: # (Firstly, you need the add the following main function to the vsrl_eval.py in data/v-coco.) 187 | 188 | [//]: # () 189 | [//]: # (```) 190 | 191 | [//]: # (if __name__ == '__main__':) 192 | 193 | [//]: # ( import sys) 194 | 195 | [//]: # () 196 | [//]: # ( vsrl_annot_file = 'data/vcoco/vcoco_test.json') 197 | 198 | [//]: # ( coco_file = 'data/instances_vcoco_all_2014.json') 199 | 200 | [//]: # ( split_file = 'data/splits/vcoco_test.ids') 201 | 202 | [//]: # () 203 | [//]: # ( vcocoeval = VCOCOeval(vsrl_annot_file, coco_file, split_file)) 204 | 205 | [//]: # () 206 | [//]: # ( det_file = sys.argv[1]) 207 | 208 | [//]: # ( vcocoeval._do_eval(det_file, ovr_thresh=0.5)) 209 | 210 | [//]: # (```) 211 | 212 | [//]: # () 213 | [//]: # (Next, for the official evaluation of V-COCO, a pickle file of detection results have to be generated. You can generate) 214 | 215 | [//]: # (the file with the following command. and then evaluate it as follows.) 
216 | 217 | [//]: # () 218 | [//]: # (```) 219 | 220 | [//]: # (python generate_vcoco_official.py \) 221 | 222 | [//]: # ( --param_path pretrained/VCOCO_GEN_VLKT_S.pth \) 223 | 224 | [//]: # ( --save_path vcoco.pickle \) 225 | 226 | [//]: # ( --hoi_path data/v-coco \) 227 | 228 | [//]: # ( --num_queries 64 \) 229 | 230 | [//]: # ( --dec_layers 3 \) 231 | 232 | [//]: # ( --use_nms_filter \) 233 | 234 | [//]: # ( --with_clip_label \) 235 | 236 | [//]: # ( --with_obj_clip_label) 237 | 238 | [//]: # () 239 | [//]: # (cd data/v-coco) 240 | 241 | [//]: # (python vsrl_eval.py vcoco.pickle) 242 | 243 | [//]: # () 244 | [//]: # (```) 245 | 246 | ### Zero-shot 247 | 248 | ``` 249 | python -m torch.distributed.launch \ 250 | --nproc_per_node=8 \ 251 | --use_env \ 252 | main.py \ 253 | --pretrained [path to your checkpoint] \ 254 | --dataset_file hico \ 255 | --hoi_path data/hico_20160224_det \ 256 | --num_obj_classes 80 \ 257 | --num_verb_classes 117 \ 258 | --backbone resnet50 \ 259 | --num_queries 64 \ 260 | --dec_layers 3 \ 261 | --eval \ 262 | --with_clip_label \ 263 | --with_obj_clip_label \ 264 | --use_nms_filter \ 265 | --zero_shot_type rare_first \ 266 | --del_unseen 267 | ``` 268 | 269 | ### Training Free Enhancement 270 | The `Training Free Enhancement` is used when args.training_free_enhancement_path is not empty. 271 | The results are placed in args.output_dir/args.training_free_enhancement_path. 272 | You may refer to codes in `engine.py:202`. 273 | By default, we set the topk to [10, 20, 30, 40, 50]. 274 | 275 | ## Visualization 276 | 277 | Script for visualization is in `scripts/visualization_hico.sh` 278 | You may need to adjust the file paths with TODO comment in `visualization_hoiclip/gen_vlkt.py` and currently the code 279 | visualize fail cases in some zero-shot setting. For detail information, you may refer to the comments. 280 | 281 | ## Regular HOI Detection Results 282 | 283 | ### HICO-DET 284 | 285 | | | Full (D) |Rare (D)|Non-rare (D)|Full(KO)|Rare (KO)|Non-rare (KO)|Download| Conifg | 286 | |:--------|:--------:| :---: | :---: | :---: |:-------:|:-----------:| :---: |:---------------------------------:| 287 | | HOICLIP | 34.69 | 31.12 |35.74 | 37.61| 34.47 | 38.54 | [model](https://drive.google.com/file/d/1q3JuEzICoppij3Wce9QfwZ1k9a4HZ9or/view?usp=drive_link) | [config](./scripts/train_hico.sh) | 288 | 289 | D: Default, KO: Known object. The best result is achieved with training free enhancement (topk=10). 290 | 291 | ### HICO-DET Fractional Setting 292 | 293 | | | Fractional |Full| Rare | Non-rare | Config | 294 | | :--- |:----------:| :---: |:----:|:---------:|:----------------------------------------------:| 295 | | HOICLIP| 5% |22.64 |21.94 | 24.28 | [config](./scripts/train_hico_frac.sh) | 296 | | HOICLIP| 15% |27.07 | 24.59 | 29.38 | [config](./scripts/train_hico_frac.sh) | 297 | | HOICLIP| 25% |28.44 |25.47| 30.52 | [config](./scripts/train_hico_frac.sh) | 298 | | HOICLIP| 50% |30.88|26.05 | 32.97 | [config](./scripts/train_hico_frac.sh) | 299 | 300 | You may need to change the `--frac [portion]%` in the scripts. 
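Under the hood, `--frac` only switches which training annotation file is loaded (the actual `build()` in `datasets/datasets_gen/hico.py` checks `args.frac > 0` and uses `f'trainval_hico_{args.frac}.json'`). A simplified sketch of that selection logic, assuming the data layout from the Data preparation section:

```python
# Simplified sketch of the fractional-annotation selection in datasets_gen/hico.py.
from pathlib import Path

def hico_train_anno(root: str, frac: str = "") -> Path:
    root = Path(root)
    if frac:  # e.g. "5%", "15%", "25%", "50%"
        return root / "annotations" / f"trainval_hico_{frac}.json"
    return root / "annotations" / "trainval_hico.json"

print(hico_train_anno("data/hico_20160224_det", "50%"))
# data/hico_20160224_det/annotations/trainval_hico_50%.json
```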
301 | 302 | ### V-COCO 303 | 304 | | | Scenario 1 | Scenario 2 | Download | Config | 305 | | :--- | :---: | :---: | :---: |:----------------------------------:| 306 | |HOICLIP| 63.50| 64.81 | [model](https://drive.google.com/file/d/1PAT2P3TaBCwG3AHuFcbe3iOk2__XOf_R/view?usp=drive_link) | [config](./scripts/train_vcoco.sh) | 307 | 308 | ## Zero-shot HOI Detection Results 309 | 310 | | |Type |Unseen| Seen| Full|Download| Conifg | 311 | | :--- | :---: | :---: | :---: | :---: | :---: |:------------------------------------------:| 312 | | HOICLIP|RF-UC |25.53 |34.85 |32.99| [model](https://drive.google.com/file/d/1E7QLhKgsC1qutGUinXIPmANRR3glYQ1h/view?usp=sharing)| [config](./scripts/train_hico_rf_uc.sh) | 313 | | HOICLIP|NF-UC |26.39| 28.10| 27.75| [model](https://drive.google.com/file/d/1W1zUEX3uDJN32UMI8seTDzmXZ5i9uBz7/view?usp=drive_link)| [config](./scripts/train_hico_nrf_uc.sh) | 314 | | HOICLIP|UO |16.20| 30.99| 28.53| [model](https://drive.google.com/file/d/1oOe8rOwGDugIhd5N3-dlwyf5SpxYFkHE/view?usp=drive_link)| [config](./scripts/train_hico_uo.sh) | 315 | | HOICLIP|UV|24.30| 32.19| 31.09| [model](https://drive.google.com/file/d/174J4x0LovEZBnZ_0yAObMsdl5sW9SZ84/view?usp=drive_link)| [config](./scripts/train_hico_uv.sh) | 316 | 317 | We also provide the checkpoints for uc0, uc1, uc2, uc3 settings in [Google Drive](https://drive.google.com/drive/folders/1NddLSPHbNZXlxmIQbcobh2O5KKWAITRo?usp=drive_link) 318 | ## Citation 319 | 320 | Please consider citing our paper if it helps your research. 321 | 322 | ``` 323 | @inproceedings{ning2023hoiclip, 324 | title={HOICLIP: Efficient Knowledge Transfer for HOI Detection with Vision-Language Models}, 325 | author={Ning, Shan and Qiu, Longtian and Liu, Yongfei and He, Xuming}, 326 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 327 | pages={23507--23517}, 328 | year={2023} 329 | } 330 | ``` 331 | 332 | ## Acknowledge 333 | 334 | Codes are built from [GEN-VLKT](https://github.com/YueLiao/gen-vlkt), [PPDM](https://github.com/YueLiao/PPDM) 335 | , [DETR](https://github.com/facebookresearch/detr), [QPIC](https://github.com/hitachi-rd-cv/qpic) 336 | and [CDN](https://github.com/YueLiao/CDN). We thank them for their contributions. 
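As a side note on the `Training Free Enhancement` option described in the Evaluation section: conceptually it re-scores the top-k HOI predictions of each image at inference time. The sketch below is purely illustrative of that idea; the tensor names and additive weighting are assumptions, not the repository's implementation (see `engine.py` for the actual code):

```python
# Purely illustrative top-k re-scoring; names and weighting are assumptions.
import torch

def topk_enhance(hoi_scores: torch.Tensor, aux_scores: torch.Tensor,
                 topk: int = 10, alpha: float = 1.0) -> torch.Tensor:
    """Boost the top-k (query, class) scores with an auxiliary score map."""
    enhanced = hoi_scores.clone()                      # [num_queries, num_hoi_classes]
    flat_idx = hoi_scores.flatten().topk(min(topk, hoi_scores.numel())).indices
    enhanced.view(-1)[flat_idx] += alpha * aux_scores.view(-1)[flat_idx]
    return enhanced
```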
337 | 338 | # Release Schedule 339 | 340 | - [x] Update raw codes(2023/4/14) 341 | - [x] Update readme(2023/7/26) 342 | - [x] Data(2023/7/26) 343 | - [x] Scripts(2023/7/26) 344 | - [x] Performance table(2023/7/26) 345 | - [x] Others(2023/7/26) 346 | - [x] Release trained checkpoints(2023/7/26) 347 | - [x] Default settings(2023/7/26) 348 | - [x] Zero-shot settings(2023/7/26) 349 | - [x] Fractional settings(2023/7/26) 350 | - [x] Clean up codes(2023/7/26) 351 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets_gen.hico import build as build_hico_gen 2 | from .datasets_gen.vcoco import build as build_vcoco_gen 3 | 4 | from .datasets_generate_feature.hico import build as build_hico_generate_verb 5 | from .datasets_generate_feature.vcoco import build as build_vcoco_generate_verb 6 | 7 | def build_dataset(image_set, args): 8 | if args.dataset_root == "GEN": 9 | if args.dataset_file == 'hico': 10 | return build_hico_gen(image_set, args) 11 | if args.dataset_file == 'vcoco': 12 | return build_vcoco_gen(image_set, args) 13 | elif args.dataset_root == "GENERATE_VERB": 14 | if args.dataset_file == 'hico': 15 | return build_hico_generate_verb(image_set, args) 16 | if args.dataset_file == 'vcoco': 17 | return build_vcoco_generate_verb(image_set, args) 18 | 19 | raise ValueError(f'dataset {args.dataset_file} not supported') 20 | -------------------------------------------------------------------------------- /datasets/datasets_gen/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | import torchvision 3 | 4 | from .hico import build as build_hico 5 | from .vcoco import build as build_vcoco 6 | 7 | def build_dataset(image_set, args): 8 | if args.dataset_file == 'hico': 9 | return build_hico(image_set, args) 10 | if args.dataset_file == 'vcoco': 11 | return build_vcoco(image_set, args) 12 | raise ValueError(f'dataset {args.dataset_file} not supported') 13 | -------------------------------------------------------------------------------- /datasets/datasets_gen/hico.py: -------------------------------------------------------------------------------- 1 | """ 2 | HICO detection dataset. 
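Loads HICO-DET images and JSON annotations and builds DETR-style targets (object boxes/labels, verb and HOI label vectors, subject/object boxes) together with CLIP-preprocessed image inputs.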
3 | """ 4 | from pathlib import Path 5 | 6 | from PIL import Image 7 | import json 8 | from collections import defaultdict 9 | import numpy as np 10 | 11 | import torch 12 | import torch.utils.data 13 | import clip 14 | 15 | import datasets.transforms as T 16 | from datasets.hico_text_label import hico_text_label, hico_unseen_index 17 | 18 | 19 | 20 | class HICODetection(torch.utils.data.Dataset): 21 | def __init__(self, img_set, img_folder, anno_file, clip_feats_folder, transforms, num_queries, args): 22 | self.img_set = img_set 23 | self.img_folder = img_folder 24 | self.clip_feates_folder = clip_feats_folder 25 | with open(anno_file, 'r') as f: 26 | self.annotations = json.load(f) 27 | self._transforms = transforms 28 | 29 | self.num_queries = num_queries 30 | 31 | self.unseen_index = hico_unseen_index.get(args.zero_shot_type, []) 32 | self._valid_obj_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 33 | 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 34 | 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 35 | 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 36 | 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 37 | 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 38 | 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 39 | 82, 84, 85, 86, 87, 88, 89, 90) 40 | self._valid_verb_ids = list(range(1, 118)) 41 | 42 | self.text_label_dict = hico_text_label 43 | self.text_label_ids = list(self.text_label_dict.keys()) 44 | if img_set == 'train' and len(self.unseen_index) != 0 and args.del_unseen: 45 | tmp = [] 46 | for idx, k in enumerate(self.text_label_ids): 47 | if idx in self.unseen_index: 48 | continue 49 | else: 50 | tmp.append(k) 51 | self.text_label_ids = tmp 52 | 53 | total_anno = 0 54 | skip_anno = 0 55 | hoi_count = [0 for i in range(600)] 56 | select_hoi = [0 for i in range(600)] 57 | 58 | if img_set == 'train': 59 | self.ids = [] 60 | for idx, img_anno in enumerate(self.annotations): 61 | new_img_anno = [] 62 | skip_pair = [] 63 | for hoi in img_anno['hoi_annotation']: 64 | hoi_count[hoi['hoi_category_id'] - 1] += 1 65 | if hoi['hoi_category_id'] - 1 in self.unseen_index: 66 | skip_pair.append((hoi['subject_id'], hoi['object_id'])) 67 | 68 | for hoi in img_anno['hoi_annotation']: 69 | if hoi['subject_id'] >= len(img_anno['annotations']) or hoi['object_id'] >= len( 70 | img_anno['annotations']): 71 | new_img_anno = [] 72 | break 73 | if (hoi['subject_id'], hoi['object_id']) not in skip_pair: 74 | new_img_anno.append(hoi) 75 | else: 76 | skip_anno += 1 77 | if len(new_img_anno) > 0: 78 | for pair_info in new_img_anno: 79 | select_hoi[pair_info['hoi_category_id'] - 1] += 1 80 | self.ids.append(idx) 81 | img_anno['hoi_annotation'] = new_img_anno 82 | total_anno += len(new_img_anno) 83 | 84 | # import cv2 85 | # import os 86 | # data_root = os.path.join(os.getcwd(), 'data/hico_20160224_det/images/train2015') 87 | # image = cv2.imread(os.path.join(data_root, img_anno['file_name'])) 88 | else: 89 | self.ids = list(range(len(self.annotations))) 90 | print("{} contains {} images and {} annotations".format(img_set, len(self.ids), total_anno)) 91 | 92 | # UC RF top20 seen instance number 93 | # [4051, 3208, 2385, 2338, 2250, 2206, 2019, 1994, 1904, 1686, 1681, 1625, 94 | # 1590, 1525, 1519, 1488, 1469, 1366, 1269, 1265] 95 | 96 | # Default bottom200 class instance number 97 | # [ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 98 | # 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 99 | # 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 100 | # 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 101 | # 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 
3, 3, 3, 3, 3, 3, 3, 3, 102 | # 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 103 | # 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 104 | # 7, 7, 7, 8, 8, 8, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 105 | # 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 13, 13, 13, 106 | # 13, 13, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 17, 17, 107 | # 17, 17, 17, 18, 18, 18, 18, 19, 19, 19, 20, 20, 20, 20, 21, 21, 21, 21, 108 | # 21, 22] 109 | 110 | device = "cuda" if torch.cuda.is_available() else "cpu" 111 | _, self.clip_preprocess = clip.load(args.clip_model, device) 112 | 113 | def __len__(self): 114 | return len(self.ids) 115 | 116 | def __getitem__(self, idx): 117 | img_anno = self.annotations[self.ids[idx]] 118 | 119 | img = Image.open(self.img_folder / img_anno['file_name']).convert('RGB') 120 | w, h = img.size 121 | 122 | if self.img_set == 'train' and len(img_anno['annotations']) > self.num_queries: 123 | img_anno['annotations'] = img_anno['annotations'][:self.num_queries] 124 | 125 | boxes = [obj['bbox'] for obj in img_anno['annotations']] 126 | # guard against no boxes via resizing 127 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 128 | 129 | if self.img_set == 'train': 130 | # Add index for confirming which boxes are kept after image transformation 131 | classes = [(i, self._valid_obj_ids.index(obj['category_id'])) for i, obj in 132 | enumerate(img_anno['annotations'])] 133 | else: 134 | classes = [self._valid_obj_ids.index(obj['category_id']) for obj in img_anno['annotations']] 135 | classes = torch.tensor(classes, dtype=torch.int64) 136 | 137 | target = {} 138 | target['orig_size'] = torch.as_tensor([int(h), int(w)]) 139 | target['size'] = torch.as_tensor([int(h), int(w)]) 140 | if self.img_set == 'train': 141 | boxes[:, 0::2].clamp_(min=0, max=w) 142 | boxes[:, 1::2].clamp_(min=0, max=h) 143 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 144 | boxes = boxes[keep] 145 | classes = classes[keep] 146 | 147 | target['boxes'] = boxes 148 | target['labels'] = classes 149 | target['iscrowd'] = torch.tensor([0 for _ in range(boxes.shape[0])]) 150 | target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 151 | 152 | if self._transforms is not None: 153 | img_0, target_0 = self._transforms[0](img, target) 154 | img, target = self._transforms[1](img_0, target_0) 155 | clip_inputs = self.clip_preprocess(img_0) 156 | target['clip_inputs'] = clip_inputs 157 | kept_box_indices = [label[0] for label in target['labels']] 158 | 159 | target['labels'] = target['labels'][:, 1] 160 | 161 | obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], [] 162 | sub_obj_pairs = [] 163 | hoi_labels = [] 164 | for hoi in img_anno['hoi_annotation']: 165 | # print('hoi: ', hoi) 166 | if hoi['subject_id'] not in kept_box_indices or hoi['object_id'] not in kept_box_indices: 167 | continue 168 | verb_obj_pair = (self._valid_verb_ids.index(hoi['category_id']), 169 | target['labels'][kept_box_indices.index(hoi['object_id'])]) 170 | if verb_obj_pair not in self.text_label_ids: 171 | continue 172 | 173 | sub_obj_pair = (hoi['subject_id'], hoi['object_id']) 174 | if sub_obj_pair in sub_obj_pairs: 175 | verb_labels[sub_obj_pairs.index(sub_obj_pair)][self._valid_verb_ids.index(hoi['category_id'])] = 1 176 | hoi_labels[sub_obj_pairs.index(sub_obj_pair)][self.text_label_ids.index(verb_obj_pair)] = 1 177 | else: 178 | sub_obj_pairs.append(sub_obj_pair) 179 | 
obj_labels.append(target['labels'][kept_box_indices.index(hoi['object_id'])]) 180 | verb_label = [0 for _ in range(len(self._valid_verb_ids))] 181 | hoi_label = [0] * len(self.text_label_ids) 182 | hoi_label[self.text_label_ids.index(verb_obj_pair)] = 1 183 | verb_label[self._valid_verb_ids.index(hoi['category_id'])] = 1 184 | sub_box = target['boxes'][kept_box_indices.index(hoi['subject_id'])] 185 | obj_box = target['boxes'][kept_box_indices.index(hoi['object_id'])] 186 | verb_labels.append(verb_label) 187 | hoi_labels.append(hoi_label) 188 | sub_boxes.append(sub_box) 189 | obj_boxes.append(obj_box) 190 | 191 | target['filename'] = img_anno['file_name'] 192 | # print('sub_obj_pairs: ', sub_obj_pairs) 193 | if len(sub_obj_pairs) == 0: 194 | target['obj_labels'] = torch.zeros((0,), dtype=torch.int64) 195 | target['verb_labels'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32) 196 | target['hoi_labels'] = torch.zeros((0, len(self.text_label_ids)), dtype=torch.float32) 197 | target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32) 198 | target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32) 199 | else: 200 | target['obj_labels'] = torch.stack(obj_labels) 201 | target['verb_labels'] = torch.as_tensor(verb_labels, dtype=torch.float32) 202 | target['hoi_labels'] = torch.as_tensor(hoi_labels, dtype=torch.float32) 203 | target['sub_boxes'] = torch.stack(sub_boxes) 204 | target['obj_boxes'] = torch.stack(obj_boxes) 205 | else: 206 | target['filename'] = img_anno['file_name'] 207 | target['boxes'] = boxes 208 | target['labels'] = classes 209 | target['id'] = idx 210 | 211 | if self._transforms is not None: 212 | img_0, _ = self._transforms[0](img, None) 213 | img, _ = self._transforms[1](img_0, None) 214 | clip_inputs = self.clip_preprocess(img_0) 215 | target['clip_inputs'] = clip_inputs 216 | 217 | hois = [] 218 | for hoi in img_anno['hoi_annotation']: 219 | hois.append((hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id']))) 220 | target['hois'] = torch.as_tensor(hois, dtype=torch.int64) 221 | 222 | return img, target 223 | 224 | def set_rare_hois(self, anno_file): 225 | with open(anno_file, 'r') as f: 226 | annotations = json.load(f) 227 | 228 | if len(self.unseen_index) == 0: 229 | # no unseen categoruy, use rare to evaluate 230 | counts = defaultdict(lambda: 0) 231 | for img_anno in annotations: 232 | hois = img_anno['hoi_annotation'] 233 | bboxes = img_anno['annotations'] 234 | for hoi in hois: 235 | triplet = (self._valid_obj_ids.index(bboxes[hoi['subject_id']]['category_id']), 236 | self._valid_obj_ids.index(bboxes[hoi['object_id']]['category_id']), 237 | self._valid_verb_ids.index(hoi['category_id'])) 238 | counts[triplet] += 1 239 | self.rare_triplets = [] 240 | self.non_rare_triplets = [] 241 | for triplet, count in counts.items(): 242 | if count < 10: 243 | self.rare_triplets.append(triplet) 244 | else: 245 | self.non_rare_triplets.append(triplet) 246 | print("rare:{}, non-rare:{}".format(len(self.rare_triplets), len(self.non_rare_triplets))) 247 | else: 248 | self.rare_triplets = [] 249 | self.non_rare_triplets = [] 250 | for img_anno in annotations: 251 | hois = img_anno['hoi_annotation'] 252 | bboxes = img_anno['annotations'] 253 | for hoi in hois: 254 | triplet = (self._valid_obj_ids.index(bboxes[hoi['subject_id']]['category_id']), 255 | self._valid_obj_ids.index(bboxes[hoi['object_id']]['category_id']), 256 | self._valid_verb_ids.index(hoi['category_id'])) 257 | if hoi['hoi_category_id'] - 1 in self.unseen_index: 258 | 
self.rare_triplets.append(triplet) 259 | else: 260 | self.non_rare_triplets.append(triplet) 261 | print("unseen:{}, seen:{}".format(len(self.rare_triplets), len(self.non_rare_triplets))) 262 | 263 | def load_correct_mat(self, path): 264 | self.correct_mat = np.load(path) 265 | 266 | 267 | # Add color jitter to coco transforms 268 | def make_hico_transforms(image_set): 269 | normalize = T.Compose([ 270 | T.ToTensor(), 271 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 272 | ]) 273 | 274 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 275 | 276 | if image_set == 'train': 277 | return [T.Compose([ 278 | T.RandomHorizontalFlip(), 279 | T.ColorJitter(.4, .4, .4), 280 | T.RandomSelect( 281 | T.RandomResize(scales, max_size=1333), 282 | T.Compose([ 283 | T.RandomResize([400, 500, 600]), 284 | T.RandomSizeCrop(384, 600), 285 | T.RandomResize(scales, max_size=1333), 286 | ]))] 287 | ), 288 | normalize 289 | ] 290 | 291 | if image_set == 'val': 292 | return [T.Compose([ 293 | T.RandomResize([800], max_size=1333), 294 | ]), 295 | normalize 296 | ] 297 | 298 | raise ValueError(f'unknown {image_set}') 299 | 300 | 301 | def build(image_set, args): 302 | root = Path(args.hoi_path) 303 | assert root.exists(), f'provided HOI path {root} does not exist' 304 | if args.frac > 0: 305 | PATHS = { 306 | 'train': (root / 'images' / 'train2015', root / 'annotations' / f'trainval_hico_{args.frac}.json', 307 | root / 'clip_feats_pool' / 'train2015'), 308 | 'val': ( 309 | root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json', 310 | root / 'clip_feats_pool' / 'test2015') 311 | } 312 | else: 313 | PATHS = { 314 | 'train': (root / 'images' / 'train2015', root / 'annotations' / 'trainval_hico.json', 315 | root / 'clip_feats_pool' / 'train2015'), 316 | 'val': ( 317 | root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json', 318 | root / 'clip_feats_pool' / 'test2015') 319 | } 320 | CORRECT_MAT_PATH = root / 'annotations' / 'corre_hico.npy' 321 | 322 | img_folder, anno_file, clip_feats_folder = PATHS[image_set] 323 | dataset = HICODetection(image_set, img_folder, anno_file, clip_feats_folder, 324 | transforms=make_hico_transforms(image_set), 325 | num_queries=args.num_queries, args=args) 326 | if image_set == 'val': 327 | dataset.set_rare_hois(PATHS['train'][1]) 328 | dataset.load_correct_mat(CORRECT_MAT_PATH) 329 | return dataset 330 | -------------------------------------------------------------------------------- /datasets/datasets_gen/vcoco.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from PIL import Image 3 | import json 4 | import numpy as np 5 | 6 | import torch 7 | import torch.utils.data 8 | 9 | import datasets.transforms as T 10 | import clip 11 | from datasets.vcoco_text_label import * 12 | 13 | 14 | class VCOCO(torch.utils.data.Dataset): 15 | 16 | def __init__(self, img_set, img_folder, anno_file, transforms, num_queries, args): 17 | self.img_set = img_set 18 | self.img_folder = img_folder 19 | with open(anno_file, 'r') as f: 20 | self.annotations = json.load(f) 21 | self._transforms = transforms 22 | 23 | self.num_queries = num_queries 24 | 25 | self._valid_obj_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 26 | 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 27 | 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 28 | 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 29 | 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 30 | 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 31 | 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 32 | 
82, 84, 85, 86, 87, 88, 89, 90) 33 | self._valid_verb_ids = range(29) 34 | 35 | device = "cuda" if torch.cuda.is_available() else "cpu" 36 | _, self.clip_preprocess = clip.load(args.clip_model, device) 37 | 38 | self.text_label_ids = list(vcoco_hoi_text_label.keys()) 39 | 40 | def __len__(self): 41 | return len(self.annotations) 42 | 43 | def __getitem__(self, idx): 44 | img_anno = self.annotations[idx] 45 | 46 | img = Image.open(self.img_folder / img_anno['file_name']).convert('RGB') 47 | w, h = img.size 48 | 49 | if self.img_set == 'train' and len(img_anno['annotations']) > self.num_queries: 50 | img_anno['annotations'] = img_anno['annotations'][:self.num_queries] 51 | 52 | boxes = [obj['bbox'] for obj in img_anno['annotations']] 53 | # guard against no boxes via resizing 54 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 55 | 56 | if self.img_set == 'train': 57 | # Add index for confirming which boxes are kept after image transformation 58 | classes = [(i, self._valid_obj_ids.index(obj['category_id'])) for i, obj in enumerate(img_anno['annotations'])] 59 | else: 60 | classes = [self._valid_obj_ids.index(obj['category_id']) for obj in img_anno['annotations']] 61 | classes = torch.tensor(classes, dtype=torch.int64) 62 | 63 | target = {} 64 | target['orig_size'] = torch.as_tensor([int(h), int(w)]) 65 | target['size'] = torch.as_tensor([int(h), int(w)]) 66 | if self.img_set == 'train': 67 | boxes[:, 0::2].clamp_(min=0, max=w) 68 | boxes[:, 1::2].clamp_(min=0, max=h) 69 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 70 | boxes = boxes[keep] 71 | classes = classes[keep] 72 | 73 | target['boxes'] = boxes 74 | target['labels'] = classes 75 | target['iscrowd'] = torch.tensor([0 for _ in range(boxes.shape[0])]) 76 | target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 77 | 78 | if self._transforms is not None: 79 | img_0, target_0 = self._transforms[0](img, target) 80 | img, target = self._transforms[1](img_0, target_0) 81 | clip_inputs = self.clip_preprocess(img_0) 82 | target['clip_inputs'] = clip_inputs 83 | kept_box_indices = [label[0] for label in target['labels']] 84 | 85 | target['labels'] = target['labels'][:, 1] 86 | 87 | obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], [] 88 | sub_obj_pairs = [] 89 | hoi_labels = [] 90 | for hoi in img_anno['hoi_annotation']: 91 | if hoi['subject_id'] not in kept_box_indices or \ 92 | (hoi['object_id'] != -1 and hoi['object_id'] not in kept_box_indices): 93 | continue 94 | 95 | #if hoi['subject_id'] not in kept_box_indices or hoi['object_id'] not in kept_box_indices: 96 | # continue 97 | 98 | if hoi['object_id'] == -1: 99 | verb_obj_pair = (self._valid_verb_ids.index(hoi['category_id']), 80) 100 | else: 101 | verb_obj_pair = (self._valid_verb_ids.index(hoi['category_id']), 102 | target['labels'][kept_box_indices.index(hoi['object_id'])]) 103 | 104 | if verb_obj_pair not in self.text_label_ids: 105 | continue 106 | 107 | sub_obj_pair = (hoi['subject_id'], hoi['object_id']) 108 | if sub_obj_pair in sub_obj_pairs: 109 | verb_labels[sub_obj_pairs.index(sub_obj_pair)][self._valid_verb_ids.index(hoi['category_id'])] = 1 110 | hoi_labels[sub_obj_pairs.index(sub_obj_pair)][self.text_label_ids.index(verb_obj_pair)] = 1 111 | else: 112 | sub_obj_pairs.append(sub_obj_pair) 113 | if hoi['object_id'] == -1: 114 | obj_labels.append(torch.tensor(len(self._valid_obj_ids))) 115 | else: 116 | obj_labels.append(target['labels'][kept_box_indices.index(hoi['object_id'])]) 117 | verb_label = [0 for _ 
in range(len(self._valid_verb_ids))] 118 | verb_label[self._valid_verb_ids.index(hoi['category_id'])] = 1 119 | hoi_label = [0] * len(self.text_label_ids) 120 | hoi_label[self.text_label_ids.index(verb_obj_pair)] = 1 121 | sub_box = target['boxes'][kept_box_indices.index(hoi['subject_id'])] 122 | if hoi['object_id'] == -1: 123 | obj_box = torch.zeros((4,), dtype=torch.float32) 124 | else: 125 | obj_box = target['boxes'][kept_box_indices.index(hoi['object_id'])] 126 | verb_labels.append(verb_label) 127 | hoi_labels.append(hoi_label) 128 | sub_boxes.append(sub_box) 129 | obj_boxes.append(obj_box) 130 | 131 | target['filename'] = img_anno['file_name'] 132 | if len(sub_obj_pairs) == 0: 133 | target['obj_labels'] = torch.zeros((0,), dtype=torch.int64) 134 | target['verb_labels'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32) 135 | #target['hoi_labels'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32) 136 | target['hoi_labels'] = torch.zeros((0, len(self.text_label_ids)), dtype=torch.float32) 137 | target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32) 138 | target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32) 139 | else: 140 | target['obj_labels'] = torch.stack(obj_labels) 141 | target['verb_labels'] = torch.as_tensor(verb_labels, dtype=torch.float32) 142 | #target['hoi_labels'] = torch.as_tensor(verb_labels, dtype=torch.float32) 143 | target['hoi_labels'] = torch.as_tensor(hoi_labels, dtype=torch.float32) 144 | target['sub_boxes'] = torch.stack(sub_boxes) 145 | target['obj_boxes'] = torch.stack(obj_boxes) 146 | else: 147 | target['filename'] = img_anno['file_name'] 148 | target['boxes'] = boxes 149 | target['labels'] = classes 150 | target['id'] = idx 151 | # target['img_id'] = int(img_anno['file_name'].rstrip('.jpg').split('_')[2]) 152 | target['img_id'] = int(img_anno['file_name'].rstrip('.jpg')) 153 | 154 | if self._transforms is not None: 155 | img_0, _ = self._transforms[0](img, None) 156 | img, _ = self._transforms[1](img_0, None) 157 | clip_inputs = self.clip_preprocess(img_0) 158 | target['clip_inputs'] = clip_inputs 159 | 160 | hois = [] 161 | for hoi in img_anno['hoi_annotation']: 162 | hois.append((hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id']))) 163 | target['hois'] = torch.as_tensor(hois, dtype=torch.int64) 164 | 165 | return img, target 166 | 167 | def load_correct_mat(self, path): 168 | self.correct_mat = np.load(path) 169 | 170 | 171 | # Add color jitter to coco transforms 172 | def make_vcoco_transforms(image_set): 173 | 174 | normalize = T.Compose([ 175 | T.ToTensor(), 176 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 177 | ]) 178 | 179 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 180 | 181 | if image_set == 'train': 182 | return [T.Compose([ 183 | T.RandomHorizontalFlip(), 184 | T.ColorJitter(.4, .4, .4), 185 | T.RandomSelect( 186 | T.RandomResize(scales, max_size=1333), 187 | T.Compose([ 188 | T.RandomResize([400, 500, 600]), 189 | T.RandomSizeCrop(384, 600), 190 | T.RandomResize(scales, max_size=1333), 191 | ]))] 192 | ), 193 | normalize 194 | ] 195 | 196 | if image_set == 'val': 197 | return [T.Compose([ 198 | T.RandomResize([800], max_size=1333), 199 | ]), 200 | normalize 201 | ] 202 | 203 | raise ValueError(f'unknown {image_set}') 204 | 205 | 206 | def build(image_set, args): 207 | root = Path(args.hoi_path) 208 | assert root.exists(), f'provided HOI path {root} does not exist' 209 | PATHS = { 210 | 'train': (root / 'images' / 'train2014', 
root / 'annotations' / 'trainval_vcoco.json'), 211 | 'val': (root / 'images' / 'val2014', root / 'annotations' / 'test_vcoco.json') 212 | } 213 | CORRECT_MAT_PATH = root / 'annotations' / 'corre_vcoco.npy' 214 | 215 | img_folder, anno_file = PATHS[image_set] 216 | dataset = VCOCO(image_set, img_folder, anno_file, transforms=make_vcoco_transforms(image_set), 217 | num_queries=args.num_queries, args=args) 218 | if image_set == 'val': 219 | dataset.load_correct_mat(CORRECT_MAT_PATH) 220 | return dataset 221 | -------------------------------------------------------------------------------- /datasets/datasets_gen/vcoco_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | from datasets.vcoco_text_label import * 4 | 5 | 6 | class VCOCOEvaluator(): 7 | 8 | def __init__(self, preds, gts, correct_mat, use_nms_filter=False): 9 | self.overlap_iou = 0.5 10 | self.max_hois = 100 11 | 12 | self.fp = defaultdict(list) 13 | self.tp = defaultdict(list) 14 | self.score = defaultdict(list) 15 | self.sum_gts = defaultdict(lambda: 0) 16 | 17 | self.verb_classes = ['hold_obj', 'stand', 'sit_instr', 'ride_instr', 'walk', 'look_obj', 'hit_instr', 'hit_obj', 18 | 'eat_obj', 'eat_instr', 'jump_instr', 'lay_instr', 'talk_on_phone_instr', 'carry_obj', 19 | 'throw_obj', 'catch_obj', 'cut_instr', 'cut_obj', 'run', 'work_on_computer_instr', 20 | 'ski_instr', 'surf_instr', 'skateboard_instr', 'smile', 'drink_instr', 'kick_obj', 21 | 'point_instr', 'read_obj', 'snowboard_instr'] 22 | self.thesis_map_indices = [0, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 22, 24, 25, 27, 28] 23 | 24 | self.preds = [] 25 | self.hoi_obj_list = [] 26 | self.verb_hoi_dict = defaultdict(list) 27 | self.vcoco_triplet_labels = list(vcoco_hoi_text_label.keys()) 28 | for index, hoi_pair in enumerate(self.vcoco_triplet_labels): 29 | self.hoi_obj_list.append(hoi_pair[1]) 30 | self.verb_hoi_dict[hoi_pair[0]].append(index) 31 | 32 | self.score_mode = 1 33 | for img_preds in preds: 34 | img_preds = {k: v.to('cpu').numpy() for k, v in img_preds.items()} 35 | bboxes = [{'bbox': bbox, 'category_id': label} for bbox, label in zip(img_preds['boxes'], img_preds['labels'])] 36 | if self.score_mode == 0: 37 | obj_scores = img_preds['obj_scores'] 38 | hoi_scores = img_preds['hoi_scores'] * obj_scores[:, self.hoi_obj_list] 39 | elif self.score_mode == 1: 40 | obj_scores = img_preds['obj_scores'] * img_preds['obj_scores'] 41 | hoi_scores = img_preds['hoi_scores'] + obj_scores[:, self.hoi_obj_list] 42 | else: 43 | raise 44 | 45 | verb_scores = np.zeros((hoi_scores.shape[0], len(self.verb_hoi_dict)))# 64 x 29 46 | for i in range(hoi_scores.shape[0]): 47 | for k,v in self.verb_hoi_dict.items(): 48 | #verb_scores[i][k] = np.sum(hoi_scores[i, v]) 49 | verb_scores[i][k] = np.max(hoi_scores[i, v]) 50 | 51 | verb_labels = np.tile(np.arange(verb_scores.shape[1]), (verb_scores.shape[0], 1)) 52 | subject_ids = np.tile(img_preds['sub_ids'], (verb_scores.shape[1], 1)).T 53 | object_ids = np.tile(img_preds['obj_ids'], (verb_scores.shape[1], 1)).T 54 | 55 | verb_scores = verb_scores.ravel() 56 | verb_labels = verb_labels.ravel() 57 | subject_ids = subject_ids.ravel() 58 | object_ids = object_ids.ravel() 59 | 60 | if len(subject_ids) > 0: 61 | object_labels = np.array([bboxes[object_id]['category_id'] for object_id in object_ids]) 62 | correct_mat = np.concatenate((correct_mat, np.ones((correct_mat.shape[0], 1))), axis=1) 63 | masks = 
correct_mat[verb_labels, object_labels] 64 | verb_scores *= masks 65 | 66 | hois = [{'subject_id': subject_id, 'object_id': object_id, 'category_id': category_id, 'score': score} for 67 | subject_id, object_id, category_id, score in zip(subject_ids, object_ids, verb_labels, verb_scores)] 68 | hois.sort(key=lambda k: (k.get('score', 0)), reverse=True) 69 | hois = hois[:self.max_hois] 70 | else: 71 | hois = [] 72 | 73 | 74 | self.preds.append({ 75 | 'predictions': bboxes, 76 | 'hoi_prediction': hois 77 | }) 78 | 79 | self.gts = [] 80 | for img_gts in gts: 81 | img_gts = {k: v.to('cpu').numpy() for k, v in img_gts.items() if k != 'id' and k != 'img_id' and k != 'filename'} 82 | self.gts.append({ 83 | 'annotations': [{'bbox': bbox, 'category_id': label} for bbox, label in zip(img_gts['boxes'], img_gts['labels'])], 84 | 'hoi_annotation': [{'subject_id': hoi[0], 'object_id': hoi[1], 'category_id': hoi[2]} for hoi in img_gts['hois']] 85 | }) 86 | for hoi in self.gts[-1]['hoi_annotation']: 87 | self.sum_gts[hoi['category_id']] += 1 88 | 89 | def evaluate(self): 90 | for img_preds, img_gts in zip(self.preds, self.gts): 91 | pred_bboxes = img_preds['predictions'] 92 | gt_bboxes = img_gts['annotations'] 93 | pred_hois = img_preds['hoi_prediction'] 94 | gt_hois = img_gts['hoi_annotation'] 95 | if len(gt_bboxes) != 0: 96 | bbox_pairs, bbox_overlaps = self.compute_iou_mat(gt_bboxes, pred_bboxes) 97 | self.compute_fptp(pred_hois, gt_hois, bbox_pairs, pred_bboxes, bbox_overlaps) 98 | else: 99 | for pred_hoi in pred_hois: 100 | self.tp[pred_hoi['category_id']].append(0) 101 | self.fp[pred_hoi['category_id']].append(1) 102 | self.score[pred_hoi['category_id']].append(pred_hoi['score']) 103 | map = self.compute_map() 104 | return map 105 | 106 | def compute_map(self): 107 | print('------------------------------------------------------------') 108 | ap = defaultdict(lambda: 0) 109 | aps = {} 110 | for category_id in sorted(list(self.sum_gts.keys())): 111 | sum_gts = self.sum_gts[category_id] 112 | if sum_gts == 0: 113 | continue 114 | 115 | tp = np.array((self.tp[category_id])) 116 | fp = np.array((self.fp[category_id])) 117 | if len(tp) == 0: 118 | ap[category_id] = 0 119 | else: 120 | score = np.array(self.score[category_id]) 121 | sort_inds = np.argsort(-score) 122 | fp = fp[sort_inds] 123 | tp = tp[sort_inds] 124 | fp = np.cumsum(fp) 125 | tp = np.cumsum(tp) 126 | rec = tp / sum_gts 127 | prec = tp / (fp + tp) 128 | ap[category_id] = self.voc_ap(rec, prec) 129 | print('{:>23s}: #GTs = {:>04d}, AP = {:>.4f}'.format(self.verb_classes[category_id], sum_gts, ap[category_id])) 130 | aps['AP_{}'.format(self.verb_classes[category_id])] = ap[category_id] 131 | 132 | m_ap_all = np.mean(list(ap.values())) 133 | m_ap_thesis = np.mean([ap[category_id] for category_id in self.thesis_map_indices]) 134 | 135 | print('------------------------------------------------------------') 136 | print('mAP all: {:.4f} mAP thesis: {:.4f}'.format(m_ap_all, m_ap_thesis)) 137 | print('------------------------------------------------------------') 138 | 139 | aps.update({'mAP_all': m_ap_all, 'mAP_thesis': m_ap_thesis}) 140 | 141 | return aps 142 | 143 | def voc_ap(self, rec, prec): 144 | ap = 0. 145 | for t in np.arange(0., 1.1, 0.1): 146 | if np.sum(rec >= t) == 0: 147 | p = 0 148 | else: 149 | p = np.max(prec[rec >= t]) 150 | ap = ap + p / 11. 
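# 11-point interpolated AP (PASCAL VOC style): precision is sampled at recall thresholds 0.0, 0.1, ..., 1.0 and the eleven values are averaged.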
151 | return ap 152 | 153 | def compute_fptp(self, pred_hois, gt_hois, match_pairs, pred_bboxes, bbox_overlaps): 154 | pos_pred_ids = match_pairs.keys() 155 | vis_tag = np.zeros(len(gt_hois)) 156 | pred_hois.sort(key=lambda k: (k.get('score', 0)), reverse=True) 157 | if len(pred_hois) != 0: 158 | for pred_hoi in pred_hois: 159 | is_match = 0 160 | max_overlap = 0 161 | max_gt_hoi = 0 162 | for gt_hoi in gt_hois: 163 | if len(match_pairs) != 0 and pred_hoi['subject_id'] in pos_pred_ids and \ 164 | gt_hoi['object_id'] == -1: 165 | pred_sub_ids = match_pairs[pred_hoi['subject_id']] 166 | pred_sub_overlaps = bbox_overlaps[pred_hoi['subject_id']] 167 | pred_category_id = pred_hoi['category_id'] 168 | if gt_hoi['subject_id'] in pred_sub_ids and pred_category_id == gt_hoi['category_id']: 169 | is_match = 1 170 | min_overlap_gt = pred_sub_overlaps[pred_sub_ids.index(gt_hoi['subject_id'])] 171 | if min_overlap_gt > max_overlap: 172 | max_overlap = min_overlap_gt 173 | max_gt_hoi = gt_hoi 174 | elif len(match_pairs) != 0 and pred_hoi['subject_id'] in pos_pred_ids and \ 175 | pred_hoi['object_id'] in pos_pred_ids: 176 | pred_sub_ids = match_pairs[pred_hoi['subject_id']] 177 | pred_obj_ids = match_pairs[pred_hoi['object_id']] 178 | pred_sub_overlaps = bbox_overlaps[pred_hoi['subject_id']] 179 | pred_obj_overlaps = bbox_overlaps[pred_hoi['object_id']] 180 | pred_category_id = pred_hoi['category_id'] 181 | if gt_hoi['subject_id'] in pred_sub_ids and gt_hoi['object_id'] in pred_obj_ids and \ 182 | pred_category_id == gt_hoi['category_id']: 183 | is_match = 1 184 | min_overlap_gt = min(pred_sub_overlaps[pred_sub_ids.index(gt_hoi['subject_id'])], 185 | pred_obj_overlaps[pred_obj_ids.index(gt_hoi['object_id'])]) 186 | if min_overlap_gt > max_overlap: 187 | max_overlap = min_overlap_gt 188 | max_gt_hoi = gt_hoi 189 | if is_match == 1 and vis_tag[gt_hois.index(max_gt_hoi)] == 0: 190 | self.fp[pred_hoi['category_id']].append(0) 191 | self.tp[pred_hoi['category_id']].append(1) 192 | vis_tag[gt_hois.index(max_gt_hoi)] = 1 193 | else: 194 | self.fp[pred_hoi['category_id']].append(1) 195 | self.tp[pred_hoi['category_id']].append(0) 196 | self.score[pred_hoi['category_id']].append(pred_hoi['score']) 197 | 198 | def compute_iou_mat(self, bbox_list1, bbox_list2): 199 | iou_mat = np.zeros((len(bbox_list1), len(bbox_list2))) 200 | if len(bbox_list1) == 0 or len(bbox_list2) == 0: 201 | return {} 202 | for i, bbox1 in enumerate(bbox_list1): 203 | for j, bbox2 in enumerate(bbox_list2): 204 | iou_i = self.compute_IOU(bbox1, bbox2) 205 | iou_mat[i, j] = iou_i 206 | 207 | iou_mat_ov=iou_mat.copy() 208 | iou_mat[iou_mat>=self.overlap_iou] = 1 209 | iou_mat[iou_mat<self.overlap_iou] = 0 210 | 211 | match_pairs = np.nonzero(iou_mat) 212 | match_pairs_dict = {} 213 | match_pair_overlaps = {} 214 | if iou_mat.max() > 0: 215 | for i, pred_id in enumerate(match_pairs[1]): 216 | if pred_id not in match_pairs_dict.keys(): 217 | match_pairs_dict[pred_id] = [] 218 | match_pair_overlaps[pred_id]=[] 219 | match_pairs_dict[pred_id].append(match_pairs[0][i]) 220 | match_pair_overlaps[pred_id].append(iou_mat_ov[match_pairs[0][i],pred_id]) 221 | return match_pairs_dict, match_pair_overlaps 222 | 223 | def compute_IOU(self, bbox1, bbox2): 224 | if isinstance(bbox1['category_id'], str): 225 | bbox1['category_id'] = int(bbox1['category_id'].replace('\n', '')) 226 | if isinstance(bbox2['category_id'], str): 227 | bbox2['category_id'] = int(bbox2['category_id'].replace('\n', '')) 228 | if bbox1['category_id'] == bbox2['category_id']: 229 | rec1 = bbox1['bbox'] 230 | rec2 = bbox2['bbox'] 231 | # computing area of each rectangles 232 | S_rec1 = (rec1[2] - rec1[0]+1) * (rec1[3] - rec1[1]+1) 233 |
S_rec2 = (rec2[2] - rec2[0]+1) * (rec2[3] - rec2[1]+1) 234 | 235 | # computing the sum_area 236 | sum_area = S_rec1 + S_rec2 237 | 238 | # find the each edge of intersect rectangle 239 | left_line = max(rec1[1], rec2[1]) 240 | right_line = min(rec1[3], rec2[3]) 241 | top_line = max(rec1[0], rec2[0]) 242 | bottom_line = min(rec1[2], rec2[2]) 243 | # judge if there is an intersect 244 | if left_line >= right_line or top_line >= bottom_line: 245 | return 0 246 | else: 247 | intersect = (right_line - left_line+1) * (bottom_line - top_line+1) 248 | return intersect / (sum_area - intersect) 249 | else: 250 | return 0 251 | -------------------------------------------------------------------------------- /datasets/datasets_generate_feature/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | import torchvision 3 | 4 | from .hico import build as build_hico 5 | from .vcoco import build as build_vcoco 6 | 7 | def build_dataset(image_set, args): 8 | if args.dataset_file == 'hico': 9 | return build_hico(image_set, args) 10 | if args.dataset_file == 'vcoco': 11 | return build_vcoco(image_set, args) 12 | raise ValueError(f'dataset {args.dataset_file} not supported') 13 | -------------------------------------------------------------------------------- /datasets/datasets_generate_feature/swig.py: -------------------------------------------------------------------------------- 1 | """ 2 | SWiG-HOI dataset utils. 3 | """ 4 | import os 5 | import json 6 | import torch 7 | import torch.utils.data 8 | from torchvision.datasets import CocoDetection 9 | import datasets.transforms as T 10 | from PIL import Image 11 | from .swig_v1_categories import SWIG_INTERACTIONS, SWIG_ACTIONS, SWIG_CATEGORIES 12 | from util.sampler import repeat_factors_from_category_frequency, get_dataset_indices 13 | import clip 14 | import copy 15 | 16 | # NOTE: Replace the path to your file 17 | SWIG_ROOT = "./data/swig_hoi/images_512" 18 | SWIG_TRAIN_ANNO = "./data/swig_hoi/annotations/swig_trainval_1000.json" 19 | SWIG_VAL_ANNO = "./data/swig_hoi/annotations/swig_test_1000.json" 20 | SWIG_TEST_ANNO = "./data/swig_hoi/annotations/swig_test_1000.json" 21 | 22 | 23 | class SWiGHOIDetection(CocoDetection): 24 | def __init__(self, img_folder, ann_file, transforms, image_set, repeat_factor_sampling, args): 25 | self.root = img_folder 26 | self.transforms = transforms 27 | # Text description of human-object interactions 28 | dataset_texts, text_mapper = prepare_dataset_text(image_set) 29 | self.dataset_texts = dataset_texts 30 | self.text_mapper = text_mapper 31 | # Load dataset 32 | repeat_factor_sampling = repeat_factor_sampling and image_set == "train" 33 | reverse_text_mapper = {v: k for k, v in text_mapper.items()} 34 | self.dataset_dicts = load_swig_json(ann_file, img_folder, reverse_text_mapper, repeat_factor_sampling) 35 | device = "cuda" if torch.cuda.is_available() else "cpu" 36 | _, self.clip_preprocess = clip.load(args.clip_model, device) 37 | 38 | def __getitem__(self, idx: int): 39 | 40 | filename = self.dataset_dicts[idx]["file_name"] 41 | image = Image.open(filename).convert("RGB") 42 | img_raw = Image.open(filename).convert('RGB') 43 | 44 | w, h = image.size 45 | assert w == self.dataset_dicts[idx]["width"], "image shape is not consistent." 46 | assert h == self.dataset_dicts[idx]["height"], "image shape is not consistent." 
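# The remainder of __getitem__ assembles the supervision used for CLIP feature generation:
# boxes are clamped to the image, and for every annotated interaction the human box, the
# object box and their union box are cropped from the raw image, run through the CLIP
# preprocessing transform, and stored in target['human_img'] / ['object_img'] /
# ['hoi_area_img'], alongside a one-hot hoi_label over the 14130 SWiG-HOI classes and the
# object / interaction class ids.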
47 | 48 | image_id = self.dataset_dicts[idx]["image_id"] 49 | annos = self.dataset_dicts[idx]["annotations"] 50 | 51 | boxes = torch.as_tensor(annos["boxes"], dtype=torch.float32).reshape(-1, 4) 52 | boxes[:, 0::2].clamp_(min=0, max=w) 53 | boxes[:, 1::2].clamp_(min=0, max=h) 54 | 55 | bbox_raw = copy.deepcopy(boxes) 56 | 57 | classes = torch.tensor(annos["classes"], dtype=torch.int64) 58 | aux_classes = torch.tensor(annos["aux_classes"], dtype=torch.int64) 59 | 60 | hoi = annos["hois"] 61 | hoi_label = torch.zeros(len(hoi), 14130) 62 | human_img = [] 63 | object_img = [] 64 | hoi_area_img = [] 65 | obj_cls = [] 66 | hoi_cls = [] 67 | for idx, i in enumerate(hoi): 68 | hoi_label[idx][i['hoi_id']] = 1 69 | 70 | h = bbox_raw[i['subject_id']] 71 | o = bbox_raw[i['object_id']] 72 | obj_cls.append(classes[i['object_id']][1]) 73 | hoi_cls.append(i['hoi_category_id']) 74 | 75 | hoi_bbox = torch.zeros_like(h) 76 | hoi_bbox[0] = torch.min(h[0], o[0]) 77 | hoi_bbox[2] = torch.max(h[2], o[2]) 78 | hoi_bbox[1] = torch.min(h[1], o[1]) 79 | hoi_bbox[3] = torch.max(h[3], o[3]) 80 | 81 | h_img = img_raw.crop(h.tolist()) 82 | o_img = img_raw.crop(o.tolist()) 83 | hoi_img = img_raw.crop(hoi_bbox.tolist()) 84 | human_img.append(self.clip_preprocess(h_img)) 85 | object_img.append(self.clip_preprocess(o_img)) 86 | hoi_area_img.append(self.clip_preprocess(hoi_img)) 87 | 88 | human_img = torch.stack(human_img) 89 | object_img = torch.stack(object_img) 90 | hoi_area_img = torch.stack(hoi_area_img) 91 | obj_cls = torch.tensor(obj_cls) 92 | hoi_cls = torch.tensor(hoi_cls) 93 | 94 | target = {"image_id": torch.tensor(image_id), "orig_size": torch.tensor([h, w]), "boxes": boxes, 95 | "labels": classes, "aux_classes": aux_classes, 'iscrowd': torch.zeros(len(boxes)), 96 | 'hoi_label': hoi_label, 'human_img': human_img, 'object_img': object_img, 97 | 'hoi_area_img': hoi_area_img, 'obj_cls': obj_cls, 'hoi_cls': hoi_cls} 98 | 99 | # if self.transforms is not None: 100 | # image, target = self.transforms(image, target) 101 | 102 | if self.transforms is not None: 103 | img_0, target_0 = self.transforms[0](image, target) 104 | image, target = self.transforms[1](img_0, target_0) 105 | clip_inputs = self.clip_preprocess(img_0) 106 | target['clip_inputs'] = clip_inputs 107 | 108 | return image, target 109 | 110 | def __len__(self): 111 | return len(self.dataset_dicts) 112 | 113 | 114 | def load_swig_json(json_file, image_root, text_mapper, repeat_factor_sampling=False): 115 | """ 116 | Load a json file with HOI's instances annotation. 117 | 118 | Args: 119 | json_file (str): full path to the json file in HOI instances annotation format. 120 | image_root (str or path-like): the directory where the images in this json file exists. 121 | text_mapper (dict): a dictionary to map text descriptions of HOIs to contiguous ids. 122 | repeat_factor_sampling (bool): resampling training data to increase the rate of tail 123 | categories to be observed by oversampling the images that contain them. 124 | Returns: 125 | list[dict]: a list of dicts in the following format. 
126 | { 127 | 'file_name': path-like str to load image, 128 | 'height': 480, 129 | 'width': 640, 130 | 'image_id': 222, 131 | 'annotations': { 132 | 'boxes': list[list[int]], # n x 4, bounding box annotations 133 | 'classes': list[int], # n, object category annotation of the bounding boxes 134 | 'aux_classes': list[list], # n x 3, a list of auxiliary object annotations 135 | 'hois': [ 136 | { 137 | 'subject_id': 0, # person box id (corresponding to the list of boxes above) 138 | 'object_id': 1, # object box id (corresponding to the list of boxes above) 139 | 'action_id', 76, # person action category 140 | 'hoi_id', 459, # interaction category 141 | 'text': ('ride', 'skateboard') # text description of human action and object 142 | } 143 | ] 144 | } 145 | } 146 | """ 147 | HOI_MAPPER = {(x["action_id"], x["object_id"]): x["id"] for x in SWIG_INTERACTIONS} 148 | 149 | imgs_anns = json.load(open(json_file, "r")) 150 | 151 | dataset_dicts = [] 152 | images_without_valid_annotations = [] 153 | for anno_dict in imgs_anns: 154 | record = {} 155 | record["file_name"] = os.path.join(image_root, anno_dict["file_name"]) 156 | record["height"] = anno_dict["height"] 157 | record["width"] = anno_dict["width"] 158 | record["image_id"] = anno_dict["img_id"] 159 | 160 | if len(anno_dict["box_annotations"]) == 0 or len(anno_dict["hoi_annotations"]) == 0: 161 | images_without_valid_annotations.append(anno_dict) 162 | continue 163 | 164 | boxes = [obj["bbox"] for obj in anno_dict["box_annotations"]] 165 | classes = [obj["category_id"] for obj in anno_dict["box_annotations"]] 166 | aux_classes = [] 167 | for obj in anno_dict["box_annotations"]: 168 | aux_categories = obj["aux_category_id"] 169 | while len(aux_categories) < 3: 170 | aux_categories.append(-1) 171 | aux_classes.append(aux_categories) 172 | 173 | for hoi in anno_dict["hoi_annotations"]: 174 | target_id = hoi["object_id"] 175 | object_id = classes[target_id] 176 | action_id = hoi["action_id"] 177 | hoi["text"] = generate_text(action_id, object_id) 178 | continguous_id = HOI_MAPPER[(action_id, object_id)] 179 | hoi["hoi_id"] = text_mapper[continguous_id] 180 | 181 | targets = { 182 | "boxes": boxes, 183 | "classes": classes, 184 | "aux_classes": aux_classes, 185 | "hois": anno_dict["hoi_annotations"], 186 | } 187 | 188 | record["annotations"] = targets 189 | dataset_dicts.append(record) 190 | 191 | if repeat_factor_sampling: 192 | repeat_factors = repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh=0.0001) 193 | dataset_indices = get_dataset_indices(repeat_factors) 194 | dataset_dicts = [dataset_dicts[i] for i in dataset_indices] 195 | 196 | return dataset_dicts 197 | 198 | 199 | def generate_text(action_id, object_id): 200 | act = SWIG_ACTIONS[action_id]["name"] 201 | obj = SWIG_CATEGORIES[object_id]["name"] 202 | act_def = SWIG_ACTIONS[action_id]["def"] 203 | obj_def = SWIG_CATEGORIES[object_id]["def"] 204 | obj_gloss = SWIG_CATEGORIES[object_id]["gloss"] 205 | obj_gloss = [obj] + [x for x in obj_gloss if x != obj] 206 | if len(obj_gloss) > 1: 207 | obj_gloss = " or ".join(obj_gloss) 208 | else: 209 | obj_gloss = obj_gloss[0] 210 | 211 | # s = [act, obj_gloss] 212 | s = [act, obj] 213 | return s 214 | 215 | 216 | ''' deprecated, text 217 | # def generate_text(action_id, object_id): 218 | # act = SWIG_ACTIONS[action_id]["name"] 219 | # obj = SWIG_CATEGORIES[object_id]["name"] 220 | # act_def = SWIG_ACTIONS[action_id]["def"] 221 | # obj_def = SWIG_CATEGORIES[object_id]["def"] 222 | # obj_gloss = SWIG_CATEGORIES[object_id]["gloss"] 223 
| # obj_gloss = [obj] + [x for x in obj_gloss if x != obj] 224 | # if len(obj_gloss) > 1: 225 | # obj_gloss = " or ".join(obj_gloss) 226 | # else: 227 | # obj_gloss = obj_gloss[0] 228 | # # s = f"A photo of a person {act} with object {obj}. The object {obj} means {obj_def}." 229 | # # s = f"a photo of a person {act} with object {obj}" 230 | # # s = f"A photo of a person {act} with {obj}. The {act} means to {act_def}." 231 | # s = f"A photo of a person {act} with {obj_gloss}. The {act} means to {act_def}." 232 | # return s 233 | ''' 234 | 235 | 236 | def prepare_dataset_text(image_set): 237 | texts = [] 238 | text_mapper = {} 239 | for i, hoi in enumerate(SWIG_INTERACTIONS): 240 | if image_set != "train" and hoi["evaluation"] == 0: continue 241 | action_id = hoi["action_id"] 242 | object_id = hoi["object_id"] 243 | s = generate_text(action_id, object_id) 244 | text_mapper[len(texts)] = i 245 | texts.append(s) 246 | return texts, text_mapper 247 | 248 | 249 | def make_transforms(image_set): 250 | normalize = T.Compose([ 251 | T.ToTensor(), 252 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 253 | ]) 254 | 255 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 256 | 257 | if image_set == 'train': 258 | return [T.Compose([ 259 | T.RandomHorizontalFlip(), 260 | T.ColorJitter(.4, .4, .4), 261 | T.RandomSelect( 262 | T.RandomResize(scales, max_size=1333), 263 | T.Compose([ 264 | T.RandomResize([400, 500, 600]), 265 | T.RandomSizeCrop(384, 600), 266 | T.RandomResize(scales, max_size=1333), 267 | ]))] 268 | ), 269 | normalize 270 | ] 271 | 272 | if image_set == 'val': 273 | return [T.Compose([ 274 | T.RandomResize([800], max_size=1333), 275 | ]), 276 | normalize 277 | ] 278 | 279 | raise ValueError(f'unknown {image_set}') 280 | 281 | 282 | ''' deprecated (Fixed image resolution + random cropping + centering) 283 | def make_transforms(image_set): 284 | 285 | normalize = T.Compose([ 286 | T.ToTensor(), 287 | T.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]), 288 | ]) 289 | 290 | if image_set == "train": 291 | return T.Compose([ 292 | T.RandomHorizontalFlip(), 293 | T.ColorJitter(brightness=[0.8, 1.2], contrast=[0.8, 1.2], saturation=[0.8, 1.2]), 294 | T.RandomSelect( 295 | T.ResizeAndCenterCrop(224), 296 | T.Compose([ 297 | T.RandomCrop_InteractionConstraint((0.8, 0.8), 0.9), 298 | T.ResizeAndCenterCrop(224) 299 | ]), 300 | ), 301 | normalize 302 | ]) 303 | if image_set == "val": 304 | return T.Compose([ 305 | T.ResizeAndCenterCrop(224), 306 | normalize 307 | ]) 308 | 309 | raise ValueError(f'unknown {image_set}') 310 | ''' 311 | 312 | 313 | def build(image_set, args): 314 | # NOTE: Replace the path to your file 315 | PATHS = { 316 | "train": (SWIG_ROOT, SWIG_TRAIN_ANNO), 317 | "val": (SWIG_ROOT, SWIG_VAL_ANNO), 318 | "dev": (SWIG_ROOT, SWIG_TEST_ANNO), 319 | } 320 | 321 | img_folder, ann_file = PATHS[image_set] 322 | dataset = SWiGHOIDetection( 323 | img_folder, 324 | ann_file, 325 | transforms=make_transforms(image_set), 326 | image_set=image_set, 327 | repeat_factor_sampling=args.repeat_factor_sampling, 328 | args=args 329 | ) 330 | 331 | return dataset 332 | -------------------------------------------------------------------------------- /datasets/datasets_generate_feature/swig_evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import collections 4 | import json 5 | import numpy as np 6 | from .swig_v1_categories import SWIG_INTERACTIONS 7 | 8 | 9 | 
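# A minimal usage sketch (assuming the calling engine collects `predictions` as a dict that
# maps image id to a list of detections [hoi_id, score, person box (x1, y1, x2, y2),
# object box (x1, y1, x2, y2)], which is how update() and calc_ap() below read them):
#
#   evaluator = SWiGEvaluator(anno_file, output_dir)
#   evaluator.update(predictions)   # can be called repeatedly, e.g. once per batch
#   evaluator.accumulate()          # per-interaction AP / recall via calc_ap
#   evaluator.summarize()           # prints zero-shot / rare / nonrare / full mAP
#   evaluator.save_preds()          # optionally dump raw scores, boxes and image keys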
class SWiGEvaluator(object): 10 | ''' Evaluator for SWIG-HOI dataset ''' 11 | def __init__(self, anno_file, output_dir): 12 | eval_hois = [x["id"] for x in SWIG_INTERACTIONS if x["evaluation"] == 1] 13 | size = max(eval_hois) + 1 14 | self.eval_hois = eval_hois 15 | 16 | self.gts = self.load_anno(anno_file) 17 | self.scores = {i: [] for i in range(size)} 18 | self.boxes = {i: [] for i in range(size)} 19 | self.keys = {i: [] for i in range(size)} 20 | self.swig_ap = np.zeros(size) 21 | self.swig_rec = np.zeros(size) 22 | self.output_dir = output_dir 23 | 24 | def update(self, predictions): 25 | # update predictions 26 | for img_id, preds in predictions.items(): 27 | for pred in preds: 28 | hoi_id = pred[0] 29 | score = pred[1] 30 | boxes = pred[2:] 31 | self.scores[hoi_id].append(score) 32 | self.boxes[hoi_id].append(boxes) 33 | self.keys[hoi_id].append(img_id) 34 | 35 | def accumulate(self): 36 | for hoi_id in self.eval_hois: 37 | gts_per_hoi = self.gts[hoi_id] 38 | ap, rec = calc_ap(self.scores[hoi_id], self.boxes[hoi_id], self.keys[hoi_id], gts_per_hoi) 39 | self.swig_ap[hoi_id], self.swig_rec[hoi_id] = ap, rec 40 | 41 | def summarize(self): 42 | eval_hois = np.asarray([x["id"] for x in SWIG_INTERACTIONS if x["evaluation"] == 1]) 43 | zero_hois = np.asarray([x["id"] for x in SWIG_INTERACTIONS if x["frequency"] == 0 and x["evaluation"] == 1]) 44 | rare_hois = np.asarray([x["id"] for x in SWIG_INTERACTIONS if x["frequency"] == 1 and x["evaluation"] == 1]) 45 | nonrare_hois = np.asarray([x["id"] for x in SWIG_INTERACTIONS if x["frequency"] == 2 and x["evaluation"] == 1]) 46 | 47 | full_mAP = np.mean(self.swig_ap[eval_hois]) 48 | zero_mAP = np.mean(self.swig_ap[zero_hois]) 49 | rare_mAP = np.mean(self.swig_ap[rare_hois]) 50 | nonrare_mAP = np.mean(self.swig_ap[nonrare_hois]) 51 | print("zero-shot mAP: {:.2f}".format(zero_mAP * 100.)) 52 | print("rare mAP: {:.2f}".format(rare_mAP * 100.)) 53 | print("nonrare mAP: {:.2f}".format(nonrare_mAP * 100.)) 54 | print("full mAP: {:.2f}".format(full_mAP * 100.)) 55 | 56 | def save_preds(self): 57 | with open(os.path.join(self.output_dir, "preds.pkl"), "wb") as f: 58 | pickle.dump({"scores": self.scores, "boxes": self.boxes, "keys": self.keys}, f) 59 | 60 | def save(self, output_dir=None): 61 | if output_dir is None: 62 | output_dir = self.output_dir 63 | with open(os.path.join(output_dir, "dets.pkl"), "wb") as f: 64 | pickle.dump({"gts": self.gts, "scores": self.scores, "boxes": self.boxes, "keys": self.keys}, f) 65 | 66 | def load_anno(self, anno_file): 67 | with open(anno_file, "r") as f: 68 | dataset_dicts = json.load(f) 69 | 70 | hoi_mapper = {(x["action_id"], x["object_id"]): x["id"] for x in SWIG_INTERACTIONS} 71 | 72 | size = max(self.eval_hois) + 1 73 | gts = {i: collections.defaultdict(list) for i in range(size)} 74 | for anno_dict in dataset_dicts: 75 | image_id = anno_dict["img_id"] 76 | box_annos = anno_dict.get("box_annotations", []) 77 | hoi_annos = anno_dict.get("hoi_annotations", []) 78 | for hoi in hoi_annos: 79 | person_box = box_annos[hoi["subject_id"]]["bbox"] 80 | object_box = box_annos[hoi["object_id"]]["bbox"] 81 | action_id = hoi["action_id"] 82 | object_id = box_annos[hoi["object_id"]]["category_id"] 83 | hoi_id = hoi_mapper[(action_id, object_id)] 84 | gts[hoi_id][image_id].append(person_box + object_box) 85 | 86 | for hoi_id in gts: 87 | for img_id in gts[hoi_id]: 88 | gts[hoi_id][img_id] = np.array(gts[hoi_id][img_id]) 89 | 90 | return gts 91 | 92 | 93 | def calc_ap(scores, boxes, keys, gt_boxes): 94 | if len(keys) == 0: 95 
| return 0, 0 96 | 97 | if isinstance(boxes, list): 98 | scores, boxes, key = np.array(scores), np.array(boxes), np.array(keys) 99 | 100 | hit = [] 101 | idx = np.argsort(scores)[::-1] 102 | npos = 0 103 | used = {} 104 | 105 | for key in gt_boxes.keys(): 106 | npos += gt_boxes[key].shape[0] 107 | used[key] = set() 108 | 109 | for i in range(min(len(idx), 19999)): 110 | pair_id = idx[i] 111 | box = boxes[pair_id, :] 112 | key = keys[pair_id] 113 | if key in gt_boxes: 114 | maxi = 0.0 115 | k = -1 116 | for i in range(gt_boxes[key].shape[0]): 117 | tmp = calc_hit(box, gt_boxes[key][i, :]) 118 | if maxi < tmp: 119 | maxi = tmp 120 | k = i 121 | if k in used[key] or maxi < 0.5: 122 | hit.append(0) 123 | else: 124 | hit.append(1) 125 | used[key].add(k) 126 | else: 127 | hit.append(0) 128 | bottom = np.array(range(len(hit))) + 1 129 | hit = np.cumsum(hit) 130 | rec = hit / npos 131 | prec = hit / bottom 132 | ap = 0.0 133 | for i in range(11): 134 | mask = rec >= (i / 10.0) 135 | if np.sum(mask) > 0: 136 | ap += np.max(prec[mask]) / 11.0 137 | 138 | return ap, np.max(rec) 139 | 140 | 141 | def calc_hit(det, gtbox): 142 | gtbox = gtbox.astype(np.float64) 143 | hiou = iou(det[:4], gtbox[:4]) 144 | oiou = iou(det[4:], gtbox[4:]) 145 | return min(hiou, oiou) 146 | 147 | 148 | def iou(bb1, bb2, debug = False): 149 | x1 = bb1[2] - bb1[0] 150 | y1 = bb1[3] - bb1[1] 151 | if x1 < 0: 152 | x1 = 0 153 | if y1 < 0: 154 | y1 = 0 155 | 156 | x2 = bb2[2] - bb2[0] 157 | y2 = bb2[3] - bb2[1] 158 | if x2 < 0: 159 | x2 = 0 160 | if y2 < 0: 161 | y2 = 0 162 | 163 | xiou = min(bb1[2], bb2[2]) - max(bb1[0], bb2[0]) 164 | yiou = min(bb1[3], bb2[3]) - max(bb1[1], bb2[1]) 165 | if xiou < 0: 166 | xiou = 0 167 | if yiou < 0: 168 | yiou = 0 169 | 170 | if debug: 171 | print(x1, y1, x2, y2, xiou, yiou) 172 | print(x1 * y1, x2 * y2, xiou * yiou) 173 | if xiou * yiou <= 0: 174 | return 0 175 | else: 176 | return xiou * yiou / (x1 * y1 + x2 * y2 - xiou * yiou) 177 | 178 | 179 | ''' deprecated, evaluator 180 | eval_hois = [x["id"] for x in SWIG_INTERACTIONS if x["evaluation"] == 1] 181 | def swig_evaluation(predictions, gts): 182 | images, results = [], [] 183 | for img_key, ps in predictions.items(): 184 | images.extend([img_key] * len(ps)) 185 | results.extend(ps) 186 | 187 | size = max(eval_hois) + 1 188 | swig_ap, swig_rec = np.zeros(size), np.zeros(size) 189 | 190 | scores = [[] for _ in range(size)] 191 | boxes = [[] for _ in range(size)] 192 | keys = [[] for _ in range(size)] 193 | 194 | for img_id, det in zip(images, results): 195 | hoi_id, person_box, object_box, score = int(det[0]), det[1], det[2], det[-1] 196 | scores[hoi_id].append(score) 197 | boxes[hoi_id].append([float(x) for x in person_box] + [float(x) for x in object_box]) 198 | keys[hoi_id].append(img_id) 199 | 200 | for hoi_id in eval_hois: 201 | gts_per_hoi = gts[hoi_id] 202 | ap, rec = calc_ap(scores[hoi_id], boxes[hoi_id], keys[hoi_id], gts_per_hoi) 203 | swig_ap[hoi_id], swig_rec[hoi_id] = ap, rec 204 | 205 | return swig_ap, swig_rec 206 | 207 | 208 | def prepare_swig_gts(anno_file): 209 | """ 210 | Convert dataset to the format required by evaluator. 
211 | """ 212 | with open(anno_file, "r") as f: 213 | dataset_dicts = json.load(f) 214 | 215 | filename_to_id_mapper = {x["file_name"]: i for i, x in enumerate(dataset_dicts)} 216 | hoi_mapper = {(x["action_id"], x["object_id"]): x["id"] for x in SWIG_INTERACTIONS} 217 | 218 | size = max(eval_hois) + 1 219 | gts = {i: collections.defaultdict(list) for i in range(size)} 220 | for anno_dict in dataset_dicts: 221 | image_id = filename_to_id_mapper[anno_dict["file_name"]] 222 | box_annos = anno_dict.get("box_annotations", []) 223 | hoi_annos = anno_dict.get("hoi_annotations", []) 224 | for hoi in hoi_annos: 225 | person_box = box_annos[hoi["subject_id"]]["bbox"] 226 | object_box = box_annos[hoi["object_id"]]["bbox"] 227 | action_id = hoi["action_id"] 228 | object_id = box_annos[hoi["object_id"]]["category_id"] 229 | hoi_id = hoi_mapper[(action_id, object_id)] 230 | gts[hoi_id][image_id].append(person_box + object_box) 231 | 232 | for hoi_id in gts: 233 | for img_id in gts[hoi_id]: 234 | gts[hoi_id][img_id] = np.array(gts[hoi_id][img_id]) 235 | 236 | return gts, filename_to_id_mapper 237 | ''' -------------------------------------------------------------------------------- /datasets/datasets_generate_feature/transforms.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | """ 6 | Transforms and data augmentation for both image + bbox. 7 | """ 8 | import random 9 | 10 | import PIL 11 | import torch 12 | import torchvision.transforms as T 13 | import torchvision.transforms.functional as F 14 | 15 | from util.box_ops import box_xyxy_to_cxcywh 16 | from util.misc import interpolate 17 | 18 | 19 | def crop(image, target, region): 20 | cropped_image = F.crop(image, *region) 21 | 22 | target = target.copy() 23 | i, j, h, w = region 24 | 25 | # should we do something wrt the original size? 26 | target["size"] = torch.tensor([h, w]) 27 | 28 | fields = ["labels", "area", "iscrowd"] 29 | 30 | if "boxes" in target: 31 | boxes = target["boxes"] 32 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 33 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 34 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 35 | cropped_boxes = cropped_boxes.clamp(min=0) 36 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 37 | target["boxes"] = cropped_boxes.reshape(-1, 4) 38 | target["area"] = area 39 | fields.append("boxes") 40 | 41 | if "masks" in target: 42 | # FIXME should we update the area here if there are no boxes? 
43 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 44 | fields.append("masks") 45 | 46 | # remove elements for which the boxes or masks that have zero area 47 | if "boxes" in target or "masks" in target: 48 | # favor boxes selection when defining which elements to keep 49 | # this is compatible with previous implementation 50 | if "boxes" in target: 51 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 52 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 53 | else: 54 | keep = target['masks'].flatten(1).any(1) 55 | 56 | for field in fields: 57 | target[field] = target[field][keep] 58 | 59 | return cropped_image, target 60 | 61 | 62 | def hflip(image, target): 63 | flipped_image = F.hflip(image) 64 | 65 | w, h = image.size 66 | 67 | target = target.copy() 68 | if "boxes" in target: 69 | boxes = target["boxes"] 70 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 71 | target["boxes"] = boxes 72 | 73 | if "masks" in target: 74 | target['masks'] = target['masks'].flip(-1) 75 | 76 | return flipped_image, target 77 | 78 | 79 | def resize(image, target, size, max_size=None): 80 | # size can be min_size (scalar) or (w, h) tuple 81 | 82 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 83 | w, h = image_size 84 | if max_size is not None: 85 | min_original_size = float(min((w, h))) 86 | max_original_size = float(max((w, h))) 87 | if max_original_size / min_original_size * size > max_size: 88 | size = int(round(max_size * min_original_size / max_original_size)) 89 | 90 | if (w <= h and w == size) or (h <= w and h == size): 91 | return (h, w) 92 | 93 | if w < h: 94 | ow = size 95 | oh = int(size * h / w) 96 | else: 97 | oh = size 98 | ow = int(size * w / h) 99 | 100 | return (oh, ow) 101 | 102 | def get_size(image_size, size, max_size=None): 103 | if isinstance(size, (list, tuple)): 104 | return size[::-1] 105 | else: 106 | return get_size_with_aspect_ratio(image_size, size, max_size) 107 | 108 | size = get_size(image.size, size, max_size) 109 | rescaled_image = F.resize(image, size) 110 | 111 | if target is None: 112 | return rescaled_image, None 113 | 114 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 115 | ratio_width, ratio_height = ratios 116 | 117 | target = target.copy() 118 | if "boxes" in target: 119 | boxes = target["boxes"] 120 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 121 | target["boxes"] = scaled_boxes 122 | 123 | if "area" in target: 124 | area = target["area"] 125 | scaled_area = area * (ratio_width * ratio_height) 126 | target["area"] = scaled_area 127 | 128 | h, w = size 129 | target["size"] = torch.tensor([h, w]) 130 | 131 | if "masks" in target: 132 | target['masks'] = interpolate( 133 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 134 | 135 | return rescaled_image, target 136 | 137 | 138 | def pad(image, target, padding): 139 | # assumes that we only pad on the bottom right corners 140 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 141 | if target is None: 142 | return padded_image, None 143 | target = target.copy() 144 | # should we do something wrt the original size? 
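# Note: a PIL Image does not support slicing, so the line below presumably intends
# torch.tensor(padded_image.size[::-1]) (the padded (h, w)), as in the DETR reference this
# file is modified from.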
145 | target["size"] = torch.tensor(padded_image[::-1]) 146 | if "masks" in target: 147 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 148 | return padded_image, target 149 | 150 | 151 | class RandomCrop(object): 152 | def __init__(self, size): 153 | self.size = size 154 | 155 | def __call__(self, img, target): 156 | region = T.RandomCrop.get_params(img, self.size) 157 | return crop(img, target, region) 158 | 159 | 160 | class RandomSizeCrop(object): 161 | def __init__(self, min_size: int, max_size: int): 162 | self.min_size = min_size 163 | self.max_size = max_size 164 | 165 | def __call__(self, img: PIL.Image.Image, target: dict): 166 | w = random.randint(self.min_size, min(img.width, self.max_size)) 167 | h = random.randint(self.min_size, min(img.height, self.max_size)) 168 | region = T.RandomCrop.get_params(img, [h, w]) 169 | return crop(img, target, region) 170 | 171 | 172 | class CenterCrop(object): 173 | def __init__(self, size): 174 | self.size = size 175 | 176 | def __call__(self, img, target): 177 | image_width, image_height = img.size 178 | crop_height, crop_width = self.size 179 | crop_top = int(round((image_height - crop_height) / 2.)) 180 | crop_left = int(round((image_width - crop_width) / 2.)) 181 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 182 | 183 | 184 | class RandomHorizontalFlip(object): 185 | def __init__(self, p=0.5): 186 | self.p = p 187 | 188 | def __call__(self, img, target): 189 | if random.random() < self.p: 190 | return hflip(img, target) 191 | return img, target 192 | 193 | 194 | class RandomResize(object): 195 | def __init__(self, sizes, max_size=None): 196 | assert isinstance(sizes, (list, tuple)) 197 | self.sizes = sizes 198 | self.max_size = max_size 199 | 200 | def __call__(self, img, target=None): 201 | size = random.choice(self.sizes) 202 | return resize(img, target, size, self.max_size) 203 | 204 | 205 | class RandomPad(object): 206 | def __init__(self, max_pad): 207 | self.max_pad = max_pad 208 | 209 | def __call__(self, img, target): 210 | pad_x = random.randint(0, self.max_pad) 211 | pad_y = random.randint(0, self.max_pad) 212 | return pad(img, target, (pad_x, pad_y)) 213 | 214 | 215 | class RandomSelect(object): 216 | """ 217 | Randomly selects between transforms1 and transforms2, 218 | with probability p for transforms1 and (1 - p) for transforms2 219 | """ 220 | def __init__(self, transforms1, transforms2, p=0.5): 221 | self.transforms1 = transforms1 222 | self.transforms2 = transforms2 223 | self.p = p 224 | 225 | def __call__(self, img, target): 226 | if random.random() < self.p: 227 | return self.transforms1(img, target) 228 | return self.transforms2(img, target) 229 | 230 | 231 | class ToTensor(object): 232 | def __call__(self, img, target): 233 | return F.to_tensor(img), target 234 | 235 | 236 | class RandomErasing(object): 237 | 238 | def __init__(self, *args, **kwargs): 239 | self.eraser = T.RandomErasing(*args, **kwargs) 240 | 241 | def __call__(self, img, target): 242 | return self.eraser(img), target 243 | 244 | 245 | class Normalize(object): 246 | def __init__(self, mean, std): 247 | self.mean = mean 248 | self.std = std 249 | 250 | def __call__(self, image, target=None): 251 | image = F.normalize(image, mean=self.mean, std=self.std) 252 | if target is None: 253 | return image, None 254 | target = target.copy() 255 | h, w = image.shape[-2:] 256 | if "boxes" in target: 257 | boxes = target["boxes"] 258 | boxes = box_xyxy_to_cxcywh(boxes) 259 | boxes = boxes / 
torch.tensor([w, h, w, h], dtype=torch.float32) 260 | target["boxes"] = boxes 261 | return image, target 262 | 263 | 264 | class Compose(object): 265 | def __init__(self, transforms): 266 | self.transforms = transforms 267 | 268 | def __call__(self, image, target): 269 | for t in self.transforms: 270 | image, target = t(image, target) 271 | return image, target 272 | 273 | def __repr__(self): 274 | format_string = self.__class__.__name__ + "(" 275 | for t in self.transforms: 276 | format_string += "\n" 277 | format_string += " {0}".format(t) 278 | format_string += "\n)" 279 | return format_string 280 | 281 | class ColorJitter(object): 282 | def __init__(self, brightness=0, contrast=0, saturatio=0, hue=0): 283 | self.color_jitter = T.ColorJitter(brightness, contrast, saturatio, hue) 284 | 285 | def __call__(self, img, target): 286 | return self.color_jitter(img), target 287 | -------------------------------------------------------------------------------- /datasets/datasets_generate_feature/vcoco.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from PIL import Image 3 | import json 4 | import numpy as np 5 | 6 | import torch 7 | import torch.utils.data 8 | import torchvision 9 | 10 | import datasets.transforms as T 11 | import clip 12 | from .vcoco_text_label import * 13 | import copy 14 | 15 | 16 | class VCOCO(torch.utils.data.Dataset): 17 | 18 | def __init__(self, img_set, img_folder, anno_file, transforms, num_queries, args): 19 | self.img_set = img_set 20 | self.img_folder = img_folder 21 | with open(anno_file, 'r') as f: 22 | self.annotations = json.load(f) 23 | self._transforms = transforms 24 | 25 | self.num_queries = num_queries 26 | 27 | self._valid_obj_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 28 | 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 29 | 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 30 | 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 31 | 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 32 | 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 33 | 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 34 | 82, 84, 85, 86, 87, 88, 89, 90) 35 | self._valid_verb_ids = range(29) 36 | 37 | device = "cuda" if torch.cuda.is_available() else "cpu" 38 | _, self.clip_preprocess = clip.load(args.clip_model, device) 39 | 40 | self.text_label_ids = list(vcoco_hoi_text_label.keys()) 41 | 42 | def __len__(self): 43 | return len(self.annotations) 44 | 45 | def __getitem__(self, idx): 46 | img_anno = self.annotations[idx] 47 | 48 | img = Image.open(self.img_folder / img_anno['file_name']).convert('RGB') 49 | img_raw = Image.open(self.img_folder / img_anno['file_name']).convert('RGB') 50 | w, h = img.size 51 | 52 | if self.img_set == 'train' and len(img_anno['annotations']) > self.num_queries: 53 | img_anno['annotations'] = img_anno['annotations'][:self.num_queries] 54 | 55 | boxes = [obj['bbox'] for obj in img_anno['annotations']] 56 | # guard against no boxes via resizing 57 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 58 | 59 | if self.img_set == 'train': 60 | # Add index for confirming which boxes are kept after image transformation 61 | classes = [(i, self._valid_obj_ids.index(obj['category_id'])) for i, obj in enumerate(img_anno['annotations'])] 62 | else: 63 | classes = [self._valid_obj_ids.index(obj['category_id']) for obj in img_anno['annotations']] 64 | classes = torch.tensor(classes, dtype=torch.int64) 65 | 66 | target = {} 67 | target['orig_size'] = torch.as_tensor([int(h), int(w)]) 68 | target['size'] = torch.as_tensor([int(h), 
int(w)]) 69 | if self.img_set == 'train': 70 | boxes[:, 0::2].clamp_(min=0, max=w) 71 | boxes[:, 1::2].clamp_(min=0, max=h) 72 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 73 | boxes = boxes[keep] 74 | classes = classes[keep] 75 | bbox_raw = copy.deepcopy(boxes) 76 | 77 | target['boxes'] = boxes 78 | target['labels'] = classes 79 | target['iscrowd'] = torch.tensor([0 for _ in range(boxes.shape[0])]) 80 | target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 81 | 82 | if self._transforms is not None: 83 | img_0, target_0 = self._transforms[0](img, target) 84 | img, target = self._transforms[1](img_0, target_0) 85 | clip_inputs = self.clip_preprocess(img_0) 86 | target['clip_inputs'] = clip_inputs 87 | kept_box_indices = [label[0] for label in target['labels']] 88 | 89 | target['labels'] = target['labels'][:, 1] 90 | 91 | obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], [] 92 | sub_obj_pairs = [] 93 | hoi_labels = [] 94 | 95 | human_img = [] 96 | object_img = [] 97 | hoi_area_img = [] 98 | obj_cls = [] 99 | hoi_cls = [] 100 | for hoi in img_anno['hoi_annotation']: 101 | if hoi['subject_id'] not in kept_box_indices or \ 102 | (hoi['object_id'] != -1 and hoi['object_id'] not in kept_box_indices): 103 | continue 104 | 105 | #if hoi['subject_id'] not in kept_box_indices or hoi['object_id'] not in kept_box_indices: 106 | # continue 107 | 108 | if hoi['object_id'] == -1: 109 | verb_obj_pair = (self._valid_verb_ids.index(hoi['category_id']), 80) 110 | else: 111 | verb_obj_pair = (self._valid_verb_ids.index(hoi['category_id']), 112 | target['labels'][kept_box_indices.index(hoi['object_id'])]) 113 | 114 | if verb_obj_pair not in self.text_label_ids: 115 | continue 116 | 117 | h = bbox_raw[hoi['subject_id']] 118 | o = bbox_raw[hoi['object_id']] 119 | obj_cls.append(verb_obj_pair[1]) 120 | hoi_cls.append(list(vcoco_hoi_text_label).index(verb_obj_pair)) 121 | 122 | if list(vcoco_hoi_text_label).index(verb_obj_pair) > 260: 123 | print(list(vcoco_hoi_text_label).index(verb_obj_pair)) 124 | if list(vcoco_hoi_text_label).index(verb_obj_pair) < 1: 125 | print(list(vcoco_hoi_text_label).index(verb_obj_pair)) 126 | hoi_bbox = torch.zeros_like(h) 127 | hoi_bbox[0] = torch.min(h[0], o[0]) 128 | hoi_bbox[2] = torch.max(h[2], o[2]) 129 | hoi_bbox[1] = torch.min(h[1], o[1]) 130 | hoi_bbox[3] = torch.max(h[3], o[3]) 131 | 132 | h_img = img_raw.crop(h.tolist()) 133 | o_img = img_raw.crop(o.tolist()) 134 | hoi_img = img_raw.crop(hoi_bbox.tolist()) 135 | 136 | human_img.append(self.clip_preprocess(h_img)) 137 | object_img.append(self.clip_preprocess(o_img)) 138 | hoi_area_img.append(self.clip_preprocess(hoi_img)) 139 | 140 | sub_obj_pair = (hoi['subject_id'], hoi['object_id']) 141 | if sub_obj_pair in sub_obj_pairs: 142 | verb_labels[sub_obj_pairs.index(sub_obj_pair)][self._valid_verb_ids.index(hoi['category_id'])] = 1 143 | hoi_labels[sub_obj_pairs.index(sub_obj_pair)][self.text_label_ids.index(verb_obj_pair)] = 1 144 | else: 145 | sub_obj_pairs.append(sub_obj_pair) 146 | if hoi['object_id'] == -1: 147 | obj_labels.append(torch.tensor(len(self._valid_obj_ids))) 148 | else: 149 | obj_labels.append(target['labels'][kept_box_indices.index(hoi['object_id'])]) 150 | verb_label = [0 for _ in range(len(self._valid_verb_ids))] 151 | verb_label[self._valid_verb_ids.index(hoi['category_id'])] = 1 152 | hoi_label = [0] * len(self.text_label_ids) 153 | hoi_label[self.text_label_ids.index(verb_obj_pair)] = 1 154 | sub_box = 
target['boxes'][kept_box_indices.index(hoi['subject_id'])] 155 | if hoi['object_id'] == -1: 156 | obj_box = torch.zeros((4,), dtype=torch.float32) 157 | else: 158 | obj_box = target['boxes'][kept_box_indices.index(hoi['object_id'])] 159 | verb_labels.append(verb_label) 160 | hoi_labels.append(hoi_label) 161 | sub_boxes.append(sub_box) 162 | obj_boxes.append(obj_box) 163 | 164 | target['filename'] = img_anno['file_name'] 165 | if len(sub_obj_pairs) == 0: 166 | target['obj_labels'] = torch.zeros((0,), dtype=torch.int64) 167 | target['verb_labels'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32) 168 | #target['hoi_labels'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32) 169 | target['hoi_labels'] = torch.zeros((0, len(self.text_label_ids)), dtype=torch.float32) 170 | target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32) 171 | target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32) 172 | else: 173 | target['obj_labels'] = torch.stack(obj_labels) 174 | target['verb_labels'] = torch.as_tensor(verb_labels, dtype=torch.float32) 175 | #target['hoi_labels'] = torch.as_tensor(verb_labels, dtype=torch.float32) 176 | target['hoi_labels'] = torch.as_tensor(hoi_labels, dtype=torch.float32) 177 | target['sub_boxes'] = torch.stack(sub_boxes) 178 | target['obj_boxes'] = torch.stack(obj_boxes) 179 | 180 | human_img = torch.stack(human_img) 181 | object_img = torch.stack(object_img) 182 | hoi_area_img = torch.stack(hoi_area_img) 183 | obj_cls = torch.tensor(obj_cls) 184 | hoi_cls = torch.tensor(hoi_cls) 185 | 186 | target['human_img'] = human_img 187 | target['object_img'] = object_img 188 | target['hoi_area_img'] = hoi_area_img 189 | target['obj_cls'] = obj_cls 190 | target['hoi_cls'] = hoi_cls 191 | else: 192 | target['filename'] = img_anno['file_name'] 193 | target['boxes'] = boxes 194 | target['labels'] = classes 195 | target['id'] = idx 196 | target['img_id'] = int(img_anno['file_name'].rstrip('.jpg').split('_')[2]) 197 | 198 | if self._transforms is not None: 199 | img, _ = self._transforms(img, None) 200 | 201 | hois = [] 202 | for hoi in img_anno['hoi_annotation']: 203 | hois.append((hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id']))) 204 | target['hois'] = torch.as_tensor(hois, dtype=torch.int64) 205 | 206 | return img, target 207 | 208 | def load_correct_mat(self, path): 209 | self.correct_mat = np.load(path) 210 | 211 | 212 | # Add color jitter to coco transforms 213 | def make_vcoco_transforms(image_set): 214 | 215 | normalize = T.Compose([ 216 | T.ToTensor(), 217 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 218 | ]) 219 | 220 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 221 | 222 | if image_set == 'train': 223 | return [T.Compose([ 224 | T.RandomHorizontalFlip(), 225 | T.ColorJitter(.4, .4, .4), 226 | T.RandomSelect( 227 | T.RandomResize(scales, max_size=1333), 228 | T.Compose([ 229 | T.RandomResize([400, 500, 600]), 230 | T.RandomSizeCrop(384, 600), 231 | T.RandomResize(scales, max_size=1333), 232 | ]))] 233 | ), 234 | normalize 235 | ] 236 | 237 | if image_set == 'val': 238 | return T.Compose([ 239 | T.RandomResize([800], max_size=1333), 240 | normalize, 241 | ]) 242 | 243 | raise ValueError(f'unknown {image_set}') 244 | 245 | 246 | def build(image_set, args): 247 | root = Path(args.hoi_path) 248 | assert root.exists(), f'provided HOI path {root} does not exist' 249 | PATHS = { 250 | 'train': (root / 'images' / 'train2014', root / 'annotations' / 'trainval_vcoco.json'), 
251 | 'val': (root / 'images' / 'val2014', root / 'annotations' / 'test_vcoco.json') 252 | } 253 | CORRECT_MAT_PATH = root / 'annotations' / 'corre_vcoco.npy' 254 | 255 | img_folder, anno_file = PATHS[image_set] 256 | dataset = VCOCO(image_set, img_folder, anno_file, transforms=make_vcoco_transforms(image_set), 257 | num_queries=args.num_queries, args=args) 258 | if image_set == 'val': 259 | dataset.load_correct_mat(CORRECT_MAT_PATH) 260 | return dataset 261 | -------------------------------------------------------------------------------- /datasets/datasets_generate_feature/vcoco_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | import os, cv2, json 4 | from .vcoco_text_label import * 5 | from util.topk import top_k 6 | 7 | class VCOCOEvaluator(): 8 | 9 | def __init__(self, preds, gts, correct_mat, use_nms_filter=False): 10 | self.overlap_iou = 0.5 11 | self.max_hois = 100 12 | 13 | self.fp = defaultdict(list) 14 | self.tp = defaultdict(list) 15 | self.score = defaultdict(list) 16 | self.sum_gts = defaultdict(lambda: 0) 17 | 18 | self.verb_classes = ['hold_obj', 'stand', 'sit_instr', 'ride_instr', 'walk', 'look_obj', 'hit_instr', 'hit_obj', 19 | 'eat_obj', 'eat_instr', 'jump_instr', 'lay_instr', 'talk_on_phone_instr', 'carry_obj', 20 | 'throw_obj', 'catch_obj', 'cut_instr', 'cut_obj', 'run', 'work_on_computer_instr', 21 | 'ski_instr', 'surf_instr', 'skateboard_instr', 'smile', 'drink_instr', 'kick_obj', 22 | 'point_instr', 'read_obj', 'snowboard_instr'] 23 | self.thesis_map_indices = [0, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 22, 24, 25, 27, 28] 24 | 25 | self.preds = [] 26 | self.hoi_obj_list = [] 27 | self.verb_hoi_dict = defaultdict(list) 28 | self.vcoco_triplet_labels = list(vcoco_hoi_text_label.keys()) 29 | for index, hoi_pair in enumerate(self.vcoco_triplet_labels): 30 | self.hoi_obj_list.append(hoi_pair[1]) 31 | self.verb_hoi_dict[hoi_pair[0]].append(index) 32 | 33 | self.score_mode = 1 34 | for img_preds in preds: 35 | img_preds = {k: v.to('cpu').numpy() for k, v in img_preds.items()} 36 | bboxes = [{'bbox': bbox, 'category_id': label} for bbox, label in zip(img_preds['boxes'], img_preds['labels'])] 37 | if self.score_mode == 0: 38 | obj_scores = img_preds['obj_scores'] 39 | hoi_scores = img_preds['hoi_scores'] * obj_scores[:, self.hoi_obj_list] 40 | elif self.score_mode == 1: 41 | obj_scores = img_preds['obj_scores'] * img_preds['obj_scores'] 42 | hoi_scores = img_preds['hoi_scores'] + obj_scores[:, self.hoi_obj_list] 43 | else: 44 | raise 45 | 46 | verb_scores = np.zeros((hoi_scores.shape[0], len(self.verb_hoi_dict)))# 64 x 29 47 | for i in range(hoi_scores.shape[0]): 48 | for k,v in self.verb_hoi_dict.items(): 49 | #verb_scores[i][k] = np.sum(hoi_scores[i, v]) 50 | verb_scores[i][k] = np.max(hoi_scores[i, v]) 51 | 52 | verb_labels = np.tile(np.arange(verb_scores.shape[1]), (verb_scores.shape[0], 1)) 53 | subject_ids = np.tile(img_preds['sub_ids'], (verb_scores.shape[1], 1)).T 54 | object_ids = np.tile(img_preds['obj_ids'], (verb_scores.shape[1], 1)).T 55 | 56 | verb_scores = verb_scores.ravel() 57 | verb_labels = verb_labels.ravel() 58 | subject_ids = subject_ids.ravel() 59 | object_ids = object_ids.ravel() 60 | 61 | if len(subject_ids) > 0: 62 | object_labels = np.array([bboxes[object_id]['category_id'] for object_id in object_ids]) 63 | correct_mat = np.concatenate((correct_mat, np.ones((correct_mat.shape[0], 1))), axis=1) 64 | masks = 
correct_mat[verb_labels, object_labels] 65 | verb_scores *= masks 66 | 67 | hois = [{'subject_id': subject_id, 'object_id': object_id, 'category_id': category_id, 'score': score} for 68 | subject_id, object_id, category_id, score in zip(subject_ids, object_ids, verb_labels, verb_scores)] 69 | hois.sort(key=lambda k: (k.get('score', 0)), reverse=True) 70 | hois = hois[:self.max_hois] 71 | else: 72 | hois = [] 73 | 74 | 75 | self.preds.append({ 76 | 'predictions': bboxes, 77 | 'hoi_prediction': hois 78 | }) 79 | 80 | self.gts = [] 81 | for img_gts in gts: 82 | img_gts = {k: v.to('cpu').numpy() for k, v in img_gts.items() if k != 'id' and k != 'img_id' and k != 'filename'} 83 | self.gts.append({ 84 | 'annotations': [{'bbox': bbox, 'category_id': label} for bbox, label in zip(img_gts['boxes'], img_gts['labels'])], 85 | 'hoi_annotation': [{'subject_id': hoi[0], 'object_id': hoi[1], 'category_id': hoi[2]} for hoi in img_gts['hois']] 86 | }) 87 | for hoi in self.gts[-1]['hoi_annotation']: 88 | self.sum_gts[hoi['category_id']] += 1 89 | 90 | def evaluate(self): 91 | for img_preds, img_gts in zip(self.preds, self.gts): 92 | pred_bboxes = img_preds['predictions'] 93 | gt_bboxes = img_gts['annotations'] 94 | pred_hois = img_preds['hoi_prediction'] 95 | gt_hois = img_gts['hoi_annotation'] 96 | if len(gt_bboxes) != 0: 97 | bbox_pairs, bbox_overlaps = self.compute_iou_mat(gt_bboxes, pred_bboxes) 98 | self.compute_fptp(pred_hois, gt_hois, bbox_pairs, pred_bboxes, bbox_overlaps) 99 | else: 100 | for pred_hoi in pred_hois: 101 | self.tp[pred_hoi['category_id']].append(0) 102 | self.fp[pred_hoi['category_id']].append(1) 103 | self.score[pred_hoi['category_id']].append(pred_hoi['score']) 104 | map = self.compute_map() 105 | return map 106 | 107 | def compute_map(self): 108 | print('------------------------------------------------------------') 109 | ap = defaultdict(lambda: 0) 110 | aps = {} 111 | for category_id in sorted(list(self.sum_gts.keys())): 112 | sum_gts = self.sum_gts[category_id] 113 | if sum_gts == 0: 114 | continue 115 | 116 | tp = np.array((self.tp[category_id])) 117 | fp = np.array((self.fp[category_id])) 118 | if len(tp) == 0: 119 | ap[category_id] = 0 120 | else: 121 | score = np.array(self.score[category_id]) 122 | sort_inds = np.argsort(-score) 123 | fp = fp[sort_inds] 124 | tp = tp[sort_inds] 125 | fp = np.cumsum(fp) 126 | tp = np.cumsum(tp) 127 | rec = tp / sum_gts 128 | prec = tp / (fp + tp) 129 | ap[category_id] = self.voc_ap(rec, prec) 130 | print('{:>23s}: #GTs = {:>04d}, AP = {:>.4f}'.format(self.verb_classes[category_id], sum_gts, ap[category_id])) 131 | aps['AP_{}'.format(self.verb_classes[category_id])] = ap[category_id] 132 | 133 | m_ap_all = np.mean(list(ap.values())) 134 | m_ap_thesis = np.mean([ap[category_id] for category_id in self.thesis_map_indices]) 135 | 136 | print('------------------------------------------------------------') 137 | print('mAP all: {:.4f} mAP thesis: {:.4f}'.format(m_ap_all, m_ap_thesis)) 138 | print('------------------------------------------------------------') 139 | 140 | aps.update({'mAP_all': m_ap_all, 'mAP_thesis': m_ap_thesis}) 141 | 142 | return aps 143 | 144 | def voc_ap(self, rec, prec): 145 | ap = 0. 146 | for t in np.arange(0., 1.1, 0.1): 147 | if np.sum(rec >= t) == 0: 148 | p = 0 149 | else: 150 | p = np.max(prec[rec >= t]) 151 | ap = ap + p / 11. 
152 | return ap 153 | 154 | def compute_fptp(self, pred_hois, gt_hois, match_pairs, pred_bboxes, bbox_overlaps): 155 | pos_pred_ids = match_pairs.keys() 156 | vis_tag = np.zeros(len(gt_hois)) 157 | pred_hois.sort(key=lambda k: (k.get('score', 0)), reverse=True) 158 | if len(pred_hois) != 0: 159 | for pred_hoi in pred_hois: 160 | is_match = 0 161 | max_overlap = 0 162 | max_gt_hoi = 0 163 | for gt_hoi in gt_hois: 164 | if len(match_pairs) != 0 and pred_hoi['subject_id'] in pos_pred_ids and \ 165 | gt_hoi['object_id'] == -1: 166 | pred_sub_ids = match_pairs[pred_hoi['subject_id']] 167 | pred_sub_overlaps = bbox_overlaps[pred_hoi['subject_id']] 168 | pred_category_id = pred_hoi['category_id'] 169 | if gt_hoi['subject_id'] in pred_sub_ids and pred_category_id == gt_hoi['category_id']: 170 | is_match = 1 171 | min_overlap_gt = pred_sub_overlaps[pred_sub_ids.index(gt_hoi['subject_id'])] 172 | if min_overlap_gt > max_overlap: 173 | max_overlap = min_overlap_gt 174 | max_gt_hoi = gt_hoi 175 | elif len(match_pairs) != 0 and pred_hoi['subject_id'] in pos_pred_ids and \ 176 | pred_hoi['object_id'] in pos_pred_ids: 177 | pred_sub_ids = match_pairs[pred_hoi['subject_id']] 178 | pred_obj_ids = match_pairs[pred_hoi['object_id']] 179 | pred_sub_overlaps = bbox_overlaps[pred_hoi['subject_id']] 180 | pred_obj_overlaps = bbox_overlaps[pred_hoi['object_id']] 181 | pred_category_id = pred_hoi['category_id'] 182 | if gt_hoi['subject_id'] in pred_sub_ids and gt_hoi['object_id'] in pred_obj_ids and \ 183 | pred_category_id == gt_hoi['category_id']: 184 | is_match = 1 185 | min_overlap_gt = min(pred_sub_overlaps[pred_sub_ids.index(gt_hoi['subject_id'])], 186 | pred_obj_overlaps[pred_obj_ids.index(gt_hoi['object_id'])]) 187 | if min_overlap_gt > max_overlap: 188 | max_overlap = min_overlap_gt 189 | max_gt_hoi = gt_hoi 190 | if is_match == 1 and vis_tag[gt_hois.index(max_gt_hoi)] == 0: 191 | self.fp[pred_hoi['category_id']].append(0) 192 | self.tp[pred_hoi['category_id']].append(1) 193 | vis_tag[gt_hois.index(max_gt_hoi)] = 1 194 | else: 195 | self.fp[pred_hoi['category_id']].append(1) 196 | self.tp[pred_hoi['category_id']].append(0) 197 | self.score[pred_hoi['category_id']].append(pred_hoi['score']) 198 | 199 | def compute_iou_mat(self, bbox_list1, bbox_list2): 200 | iou_mat = np.zeros((len(bbox_list1), len(bbox_list2))) 201 | if len(bbox_list1) == 0 or len(bbox_list2) == 0: 202 | return {} 203 | for i, bbox1 in enumerate(bbox_list1): 204 | for j, bbox2 in enumerate(bbox_list2): 205 | iou_i = self.compute_IOU(bbox1, bbox2) 206 | iou_mat[i, j] = iou_i 207 | 208 | iou_mat_ov=iou_mat.copy() 209 | iou_mat[iou_mat>=self.overlap_iou] = 1 210 | iou_mat[iou_mat<self.overlap_iou] = 0 211 | 212 | match_pairs = np.nonzero(iou_mat) 213 | match_pairs_dict = {} 214 | match_pair_overlaps = {} 215 | if iou_mat.max() > 0: 216 | for i, pred_id in enumerate(match_pairs[1]): 217 | if pred_id not in match_pairs_dict.keys(): 218 | match_pairs_dict[pred_id] = [] 219 | match_pair_overlaps[pred_id]=[] 220 | match_pairs_dict[pred_id].append(match_pairs[0][i]) 221 | match_pair_overlaps[pred_id].append(iou_mat_ov[match_pairs[0][i],pred_id]) 222 | return match_pairs_dict, match_pair_overlaps 223 | 224 | def compute_IOU(self, bbox1, bbox2): 225 | if isinstance(bbox1['category_id'], str): 226 | bbox1['category_id'] = int(bbox1['category_id'].replace('\n', '')) 227 | if isinstance(bbox2['category_id'], str): 228 | bbox2['category_id'] = int(bbox2['category_id'].replace('\n', '')) 229 | if bbox1['category_id'] == bbox2['category_id']: 230 | rec1 = bbox1['bbox'] 231 | rec2 = bbox2['bbox'] 232 | # computing area of each rectangles 233 | S_rec1 = (rec1[2] - rec1[0]+1) * (rec1[3] - rec1[1]+1) 234 |
S_rec2 = (rec2[2] - rec2[0]+1) * (rec2[3] - rec2[1]+1) 235 | 236 | # computing the sum_area 237 | sum_area = S_rec1 + S_rec2 238 | 239 | # find the each edge of intersect rectangle 240 | left_line = max(rec1[1], rec2[1]) 241 | right_line = min(rec1[3], rec2[3]) 242 | top_line = max(rec1[0], rec2[0]) 243 | bottom_line = min(rec1[2], rec2[2]) 244 | # judge if there is an intersect 245 | if left_line >= right_line or top_line >= bottom_line: 246 | return 0 247 | else: 248 | intersect = (right_line - left_line+1) * (bottom_line - top_line+1) 249 | return intersect / (sum_area - intersect) 250 | else: 251 | return 0 252 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | """ 6 | Transforms and data augmentation for both image + bbox. 7 | """ 8 | import random 9 | 10 | import PIL 11 | import torch 12 | import torchvision.transforms as T 13 | import torchvision.transforms.functional as F 14 | 15 | from util.box_ops import box_xyxy_to_cxcywh 16 | from util.misc import interpolate 17 | 18 | 19 | def crop(image, target, region): 20 | cropped_image = F.crop(image, *region) 21 | 22 | target = target.copy() 23 | i, j, h, w = region 24 | 25 | # should we do something wrt the original size? 26 | target["size"] = torch.tensor([h, w]) 27 | 28 | fields = ["labels", "area", "iscrowd"] 29 | 30 | if "boxes" in target: 31 | boxes = target["boxes"] 32 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 33 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 34 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 35 | cropped_boxes = cropped_boxes.clamp(min=0) 36 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 37 | target["boxes"] = cropped_boxes.reshape(-1, 4) 38 | target["area"] = area 39 | fields.append("boxes") 40 | 41 | if "masks" in target: 42 | # FIXME should we update the area here if there are no boxes? 
43 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 44 | fields.append("masks") 45 | 46 | # remove elements for which the boxes or masks that have zero area 47 | if "boxes" in target or "masks" in target: 48 | # favor boxes selection when defining which elements to keep 49 | # this is compatible with previous implementation 50 | if "boxes" in target: 51 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 52 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 53 | else: 54 | keep = target['masks'].flatten(1).any(1) 55 | 56 | for field in fields: 57 | target[field] = target[field][keep] 58 | 59 | return cropped_image, target 60 | 61 | 62 | def hflip(image, target): 63 | flipped_image = F.hflip(image) 64 | 65 | w, h = image.size 66 | 67 | target = target.copy() 68 | if "boxes" in target: 69 | boxes = target["boxes"] 70 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 71 | target["boxes"] = boxes 72 | 73 | if "masks" in target: 74 | target['masks'] = target['masks'].flip(-1) 75 | 76 | return flipped_image, target 77 | 78 | 79 | def resize(image, target, size, max_size=None): 80 | # size can be min_size (scalar) or (w, h) tuple 81 | 82 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 83 | w, h = image_size 84 | if max_size is not None: 85 | min_original_size = float(min((w, h))) 86 | max_original_size = float(max((w, h))) 87 | if max_original_size / min_original_size * size > max_size: 88 | size = int(round(max_size * min_original_size / max_original_size)) 89 | 90 | if (w <= h and w == size) or (h <= w and h == size): 91 | return (h, w) 92 | 93 | if w < h: 94 | ow = size 95 | oh = int(size * h / w) 96 | else: 97 | oh = size 98 | ow = int(size * w / h) 99 | 100 | return (oh, ow) 101 | 102 | def get_size(image_size, size, max_size=None): 103 | if isinstance(size, (list, tuple)): 104 | return size[::-1] 105 | else: 106 | return get_size_with_aspect_ratio(image_size, size, max_size) 107 | 108 | size = get_size(image.size, size, max_size) 109 | rescaled_image = F.resize(image, size) 110 | 111 | if target is None: 112 | return rescaled_image, None 113 | 114 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 115 | ratio_width, ratio_height = ratios 116 | 117 | target = target.copy() 118 | if "boxes" in target: 119 | boxes = target["boxes"] 120 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 121 | target["boxes"] = scaled_boxes 122 | 123 | if "area" in target: 124 | area = target["area"] 125 | scaled_area = area * (ratio_width * ratio_height) 126 | target["area"] = scaled_area 127 | 128 | h, w = size 129 | target["size"] = torch.tensor([h, w]) 130 | 131 | if "masks" in target: 132 | target['masks'] = interpolate( 133 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 134 | 135 | return rescaled_image, target 136 | 137 | 138 | def pad(image, target, padding): 139 | # assumes that we only pad on the bottom right corners 140 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 141 | if target is None: 142 | return padded_image, None 143 | target = target.copy() 144 | # should we do something wrt the original size? 
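# padding is only added on the right/bottom edges, so existing box coordinates (measured from the top-left origin) remain valid; only the recorded size changes, and the masks below are padded by the same amounts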
145 | target["size"] = torch.tensor(padded_image[::-1]) 146 | if "masks" in target: 147 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 148 | return padded_image, target 149 | 150 | 151 | class RandomCrop(object): 152 | def __init__(self, size): 153 | self.size = size 154 | 155 | def __call__(self, img, target): 156 | region = T.RandomCrop.get_params(img, self.size) 157 | return crop(img, target, region) 158 | 159 | 160 | class RandomSizeCrop(object): 161 | def __init__(self, min_size: int, max_size: int): 162 | self.min_size = min_size 163 | self.max_size = max_size 164 | 165 | def __call__(self, img: PIL.Image.Image, target: dict): 166 | w = random.randint(self.min_size, min(img.width, self.max_size)) 167 | h = random.randint(self.min_size, min(img.height, self.max_size)) 168 | region = T.RandomCrop.get_params(img, [h, w]) 169 | return crop(img, target, region) 170 | 171 | 172 | class CenterCrop(object): 173 | def __init__(self, size): 174 | self.size = size 175 | 176 | def __call__(self, img, target): 177 | image_width, image_height = img.size 178 | crop_height, crop_width = self.size 179 | crop_top = int(round((image_height - crop_height) / 2.)) 180 | crop_left = int(round((image_width - crop_width) / 2.)) 181 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 182 | 183 | 184 | class RandomHorizontalFlip(object): 185 | def __init__(self, p=0.5): 186 | self.p = p 187 | 188 | def __call__(self, img, target): 189 | if random.random() < self.p: 190 | return hflip(img, target) 191 | return img, target 192 | 193 | 194 | class RandomResize(object): 195 | def __init__(self, sizes, max_size=None): 196 | assert isinstance(sizes, (list, tuple)) 197 | self.sizes = sizes 198 | self.max_size = max_size 199 | 200 | def __call__(self, img, target=None): 201 | size = random.choice(self.sizes) 202 | return resize(img, target, size, self.max_size) 203 | 204 | 205 | class RandomPad(object): 206 | def __init__(self, max_pad): 207 | self.max_pad = max_pad 208 | 209 | def __call__(self, img, target): 210 | pad_x = random.randint(0, self.max_pad) 211 | pad_y = random.randint(0, self.max_pad) 212 | return pad(img, target, (pad_x, pad_y)) 213 | 214 | 215 | class RandomSelect(object): 216 | """ 217 | Randomly selects between transforms1 and transforms2, 218 | with probability p for transforms1 and (1 - p) for transforms2 219 | """ 220 | def __init__(self, transforms1, transforms2, p=0.5): 221 | self.transforms1 = transforms1 222 | self.transforms2 = transforms2 223 | self.p = p 224 | 225 | def __call__(self, img, target): 226 | if random.random() < self.p: 227 | return self.transforms1(img, target) 228 | return self.transforms2(img, target) 229 | 230 | 231 | class ToTensor(object): 232 | def __call__(self, img, target): 233 | return F.to_tensor(img), target 234 | 235 | 236 | class RandomErasing(object): 237 | 238 | def __init__(self, *args, **kwargs): 239 | self.eraser = T.RandomErasing(*args, **kwargs) 240 | 241 | def __call__(self, img, target): 242 | return self.eraser(img), target 243 | 244 | 245 | class Normalize(object): 246 | def __init__(self, mean, std): 247 | self.mean = mean 248 | self.std = std 249 | 250 | def __call__(self, image, target=None): 251 | image = F.normalize(image, mean=self.mean, std=self.std) 252 | if target is None: 253 | return image, None 254 | target = target.copy() 255 | h, w = image.shape[-2:] 256 | if "boxes" in target: 257 | boxes = target["boxes"] 258 | boxes = box_xyxy_to_cxcywh(boxes) 259 | boxes = boxes / 
torch.tensor([w, h, w, h], dtype=torch.float32) 260 | target["boxes"] = boxes 261 | return image, target 262 | 263 | 264 | class Compose(object): 265 | def __init__(self, transforms): 266 | self.transforms = transforms 267 | 268 | def __call__(self, image, target): 269 | for t in self.transforms: 270 | image, target = t(image, target) 271 | return image, target 272 | 273 | def __repr__(self): 274 | format_string = self.__class__.__name__ + "(" 275 | for t in self.transforms: 276 | format_string += "\n" 277 | format_string += " {0}".format(t) 278 | format_string += "\n)" 279 | return format_string 280 | 281 | class ColorJitter(object): 282 | def __init__(self, brightness=0, contrast=0, saturatio=0, hue=0): 283 | self.color_jitter = T.ColorJitter(brightness, contrast, saturatio, hue) 284 | 285 | def __call__(self, img, target): 286 | return self.color_jitter(img), target 287 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import sys 5 | from typing import Iterable 6 | import numpy as np 7 | import copy 8 | import itertools 9 | 10 | import torch 11 | 12 | import util.misc as utils 13 | from datasets.datasets_gen.hico_eval_triplet import HICOEvaluator as HICOEvaluator_gen 14 | from datasets.datasets_gen.vcoco_eval import VCOCOEvaluator as VCOCOEvaluator_gen 15 | import json 16 | import torch.nn.functional as F 17 | from tqdm import tqdm 18 | import datetime 19 | import time 20 | 21 | 22 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 23 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 24 | device: torch.device, epoch: int, max_norm: float = 0, lr_scheduler=None, 25 | gradient_accumulation_steps=1, enable_amp=False, no_training=False, args=None): 26 | model.train() 27 | criterion.train() 28 | metric_logger = utils.MetricLogger(delimiter=" ") 29 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 30 | if hasattr(criterion, 'loss_labels') and False: 31 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 32 | elif hasattr(criterion, 'loss_hoi_labels'): 33 | metric_logger.add_meter('hoi_class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 34 | else: 35 | metric_logger.add_meter('obj_class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 36 | header = 'Epoch: [{}]'.format(epoch) 37 | print_freq = 100 38 | 39 | if enable_amp: 40 | print('\nEnable half precision training\n') 41 | 42 | # scaler = GradScaler() 43 | # debug 44 | debug_count = 0 45 | step = 0 46 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 47 | if no_training: 48 | samples = samples.to(device) 49 | targets = [{k: v.to(device) for k, v in t.items() if k != 'filename' and k != 'raw_img'} for t in targets] 50 | clip_img = torch.stack([v['clip_inputs'] for v in targets]) 51 | # with autocast(): 52 | obj_feature, hoi_feature, verb_feature = model(samples, clip_input=clip_img, targets=targets) 53 | 54 | metric_logger.update(loss=0) 55 | if hasattr(criterion, 'loss_labels'): 56 | metric_logger.update(class_error=0) 57 | elif hasattr(criterion, 'loss_hoi_labels'): 58 | metric_logger.update(hoi_class_error=0) 59 | else: 60 | metric_logger.update(obj_class_error=0) 61 | metric_logger.update(lr=0) 62 | continue 63 | 64 | samples = samples.to(device) 65 | file_names = [{'filename': 
i['filename']} for i in targets] 66 | targets = [{k: v.to(device) for k, v in t.items() if k != 'filename' and k != 'raw_img'} for t in targets] 67 | for t, f in zip(targets, file_names): 68 | t.update(f) 69 | clip_img = torch.stack([v['clip_inputs'] for v in targets]) 70 | 71 | outputs = model(samples, clip_input=clip_img, targets=targets) 72 | loss_dict = criterion(outputs, targets) 73 | 74 | weight_dict = criterion.weight_dict 75 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 76 | 77 | # reduce losses over all GPUs for logging purposes 78 | loss_dict_reduced = utils.reduce_dict(loss_dict) 79 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 80 | for k, v in loss_dict_reduced.items()} 81 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 82 | for k, v in loss_dict_reduced.items() if k in weight_dict} 83 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 84 | 85 | loss_value = losses_reduced_scaled.item() 86 | # print(loss_value) 87 | # sys.exit() 88 | 89 | if not math.isfinite(loss_value): 90 | print("Loss is {}, stopping training".format(loss_value)) 91 | print(loss_dict_reduced) 92 | sys.exit(1) 93 | 94 | delay_unscale = (step + 1) % gradient_accumulation_steps != 0 95 | losses = losses / gradient_accumulation_steps 96 | if enable_amp: 97 | raise NotImplementedError 98 | # with amp.scale_loss(losses, optimizer, delay_unscale=delay_unscale) as scaled_loss: 99 | # scaled_loss.backward() 100 | else: 101 | losses.backward() 102 | 103 | if (step + 1) % gradient_accumulation_steps == 0: 104 | if max_norm > 0: 105 | if enable_amp: 106 | pass 107 | # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm) 108 | else: 109 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 110 | optimizer.step() 111 | optimizer.zero_grad() 112 | 113 | if lr_scheduler: 114 | lr_scheduler.iter_step() 115 | 116 | step += 1 117 | 118 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) 119 | if hasattr(criterion, 'loss_labels') and False: 120 | metric_logger.update(class_error=loss_dict_reduced['class_error']) 121 | elif hasattr(criterion, 'loss_hoi_labels'): 122 | if 'hoi_class_error' in loss_dict_reduced: 123 | metric_logger.update(hoi_class_error=loss_dict_reduced['hoi_class_error']) 124 | else: 125 | metric_logger.update(hoi_class_error=-1) 126 | else: 127 | metric_logger.update(obj_class_error=loss_dict_reduced['obj_class_error']) 128 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 129 | 130 | # trick for generate verb 131 | if no_training: 132 | from datasets.static_hico import HOI_IDX_TO_ACT_IDX, HOI_IDX_TO_OBJ_IDX 133 | hoi_feature = hoi_feature / hoi_feature.norm(dim=1, keepdim=True) 134 | obj_feature = obj_feature / obj_feature.norm(dim=1, keepdim=True) 135 | 136 | y_verb = [HOI_IDX_TO_ACT_IDX[i] for i in range(600)] 137 | y_obj = [HOI_IDX_TO_OBJ_IDX[i] for i in range(600)] 138 | 139 | # composite image feature verb + text feature object 140 | obj_human = [] 141 | for i in range(600): 142 | obj_human.append(obj_feature[y_obj[i]]) 143 | obj_human = torch.stack(obj_human) 144 | verb_human = hoi_feature - obj_human 145 | 146 | verb_feature = torch.zeros(117, 512) 147 | for idx, v in zip(y_verb, verb_human): 148 | verb_feature[idx] += v 149 | 150 | for i in range(117): 151 | verb_feature[i] /= y_verb.count(i) 152 | 153 | v_feature = verb_feature / verb_feature.norm(dim=-1, keepdim=True) 154 | torch.save(v_feature, f'./verb_{args.dataset_file}.pth') 155 | exit() 156 | 
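# Summary of the verb-generation trick above (HICO-DET shapes): hoi_feature is (600, 512), obj_feature is (81, 512) and verb_feature is (117, 512). For each HOI class i the object prototype obj_feature[HOI_IDX_TO_OBJ_IDX[i]] is subtracted from hoi_feature[i]; the residuals are summed into slot HOI_IDX_TO_ACT_IDX[i], averaged over the HOI classes that share the verb (e.g. the "ride" slot pools ride-bicycle, ride-horse, ...), L2-normalised and saved. The resulting file is what --verb_pth expects at training time (the shipped ./tmp/verb.pth was presumably produced this way).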
157 | # gather the stats from all processes 158 | metric_logger.synchronize_between_processes() 159 | print("Averaged stats:", metric_logger) 160 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 161 | 162 | 163 | @torch.no_grad() 164 | def evaluate_hoi(dataset_file, model, postprocessors, data_loader, 165 | subject_category_id, device, args): 166 | model.eval() 167 | 168 | metric_logger = utils.MetricLogger(delimiter=" ") 169 | header = 'Test:' 170 | 171 | preds = [] 172 | gts = [] 173 | counter = 0 174 | 175 | for samples, targets in metric_logger.log_every(data_loader, 10, header): 176 | samples = samples.to(device) 177 | clip_img = torch.stack([v['clip_inputs'] for v in targets]).to(device) 178 | 179 | outputs = model(samples, is_training=False, clip_input=clip_img, targets=targets) 180 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 181 | results = postprocessors['hoi'](outputs, orig_target_sizes) 182 | 183 | preds.extend(list(itertools.chain.from_iterable(utils.all_gather(results)))) 184 | # For avoiding a runtime error, the copy is used 185 | gts.extend(list(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets))))) 186 | 187 | counter += 1 188 | if counter >= 20 and args.no_training: 189 | break 190 | # gather the stats from all processes 191 | metric_logger.synchronize_between_processes() 192 | 193 | img_ids = [img_gts['id'] for img_gts in gts] 194 | _, indices = np.unique(img_ids, return_index=True) 195 | preds = [img_preds for i, img_preds in enumerate(preds) if i in indices] 196 | gts = [img_gts for i, img_gts in enumerate(gts) if i in indices] 197 | 198 | """ 199 | For zero-shot enhancement 200 | args.training_free_enhancement_path is the path to store performance for different hyper-parameter 201 | """ 202 | root = os.path.join(args.output_dir, args.training_free_enhancement_path) 203 | if args.training_free_enhancement_path: 204 | 205 | with open(os.path.join(root, 'log.txt'), 'a') as f: 206 | log = f'\n=========The great hyperparameter tuning begins============\n' 207 | print(log) 208 | f.write(log) 209 | 210 | test_pred = copy.deepcopy(preds) 211 | 212 | # testing 213 | if dataset_file == 'hico': 214 | evaluator = HICOEvaluator_gen(test_pred, gts, data_loader.dataset.rare_triplets, 215 | data_loader.dataset.non_rare_triplets, 216 | data_loader.dataset.correct_mat, args=args) 217 | else: 218 | evaluator = VCOCOEvaluator_gen(preds, gts, data_loader.dataset.correct_mat, 219 | use_nms_filter=args.use_nms_filter) 220 | stats = evaluator.evaluate() 221 | 222 | text_hoi_feature = model.transformer.hoi_cls 223 | spatial_feature = torch.cat([i['clip_visual'].unsqueeze(0) for i in preds]) 224 | spatial_feature /= spatial_feature.norm(dim=-1, keepdim=True) 225 | spatial_cls = spatial_feature[:, 0, :] # M, c 226 | 227 | cls_scores = spatial_cls @ text_hoi_feature 228 | with open(os.path.join(root, 'log.txt'), 'a') as f: 229 | log = f'\n=========Baseline Performance============\n{stats}\n============================\n' 230 | print(log) 231 | f.write(log) 232 | 233 | best_performance_1 = 0 234 | for a in [1]: 235 | for co in [1.0]: 236 | for topk in [10, 20, 30, 40, 50]: 237 | print(f'current at topk: {topk} as: {a}') 238 | test_pred = copy.deepcopy(preds) 239 | clip_hoi_score = cls_scores 240 | # clip_hoi_score /= (1 + alpha + beta) 241 | clip_hoi_score_ori = clip_hoi_score.clone() 242 | 243 | ignore_idx = clip_hoi_score.sort(descending=True).indices[:, topk:] 244 | for idx, igx in enumerate(ignore_idx): 245 | 
clip_hoi_score[idx][igx] *= 0 246 | clip_hoi_score = clip_hoi_score.unsqueeze(1) 247 | 248 | # update logits 249 | for i in range(len(test_pred)): 250 | test_pred[i]['hoi_scores'] += clip_hoi_score[i].sigmoid() * co 251 | # testing 252 | if dataset_file == 'hico': 253 | evaluator = HICOEvaluator_gen(test_pred, gts, data_loader.dataset.rare_triplets, 254 | data_loader.dataset.non_rare_triplets, 255 | data_loader.dataset.correct_mat, args=args) 256 | 257 | else: 258 | evaluator = VCOCOEvaluator_gen(test_pred, gts, data_loader.dataset.correct_mat, 259 | use_nms_filter=args.use_nms_filter) 260 | stats = evaluator.evaluate() 261 | if dataset_file == 'hico': 262 | re_map = stats['mAP'] 263 | elif dataset_file == 'vcoco': 264 | re_map = stats['mAP_all'] 265 | elif dataset_file == 'hoia': 266 | re_map = stats['mAP'] 267 | else: 268 | raise NotImplementedError 269 | 270 | if best_performance_1 < re_map: 271 | best_performance_1 = re_map 272 | 273 | with open(os.path.join(root, 'log.txt'), 'a') as f: 274 | log = f'sigmoid after topk: {topk} as: {a} co: {co}' \ 275 | f'\n performance: {stats}\n' 276 | print(log) 277 | f.write(log) 278 | 279 | if dataset_file == 'hico': 280 | if args.dataset_root == 'GEN': 281 | evaluator = HICOEvaluator_gen(preds, gts, data_loader.dataset.rare_triplets, 282 | data_loader.dataset.non_rare_triplets, data_loader.dataset.correct_mat, 283 | args=args) 284 | elif dataset_file == 'vcoco': 285 | if args.dataset_root == 'GEN': 286 | evaluator = VCOCOEvaluator_gen(preds, gts, data_loader.dataset.correct_mat, 287 | use_nms_filter=args.use_nms_filter) 288 | else: 289 | raise NotImplementedError 290 | start_time = time.time() 291 | stats = evaluator.evaluate() 292 | total_time = time.time() - start_time 293 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 294 | print('Total time computing mAP: {}'.format(total_time_str)) 295 | 296 | return stats 297 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models_gen.gen_vlkt import build as build_gen 2 | from .models_hoiclip.hoiclip import build as build_models_hoiclip 3 | from .visualization_hoiclip.gen_vlkt import build as visualization 4 | from .generate_image_feature.generate_verb import build as generate_verb 5 | 6 | 7 | def build_model(args): 8 | if args.model_name == "HOICLIP": 9 | return build_models_hoiclip(args) 10 | elif args.model_name == "GEN": 11 | return build_gen(args) 12 | elif args.model_name == "VISUALIZATION": 13 | return visualization(args) 14 | elif args.model_name == "GENERATE_VERB": 15 | return generate_verb(args) 16 | 17 | raise ValueError(f'Model {args.model_name} not supported') 18 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | """ 6 | Backbone modules. 
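Wraps torchvision ResNets with frozen BatchNorm and joins the returned feature maps (as NestedTensors) with their positional encodings.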
7 | """ 8 | from collections import OrderedDict 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import torchvision 13 | from torch import nn 14 | from torchvision.models._utils import IntermediateLayerGetter 15 | from typing import Dict, List 16 | 17 | from util.misc import NestedTensor, is_main_process 18 | 19 | from .position_encoding import build_position_encoding 20 | 21 | 22 | class FrozenBatchNorm2d(torch.nn.Module): 23 | """ 24 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 25 | 26 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 27 | without which any other models than torchvision.models.resnet[18,34,50,101] 28 | produce nans. 29 | """ 30 | 31 | def __init__(self, n): 32 | super(FrozenBatchNorm2d, self).__init__() 33 | self.register_buffer("weight", torch.ones(n)) 34 | self.register_buffer("bias", torch.zeros(n)) 35 | self.register_buffer("running_mean", torch.zeros(n)) 36 | self.register_buffer("running_var", torch.ones(n)) 37 | 38 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 39 | missing_keys, unexpected_keys, error_msgs): 40 | num_batches_tracked_key = prefix + 'num_batches_tracked' 41 | if num_batches_tracked_key in state_dict: 42 | del state_dict[num_batches_tracked_key] 43 | 44 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 45 | state_dict, prefix, local_metadata, strict, 46 | missing_keys, unexpected_keys, error_msgs) 47 | 48 | def forward(self, x): 49 | # move reshapes to the beginning 50 | # to make it fuser-friendly 51 | w = self.weight.reshape(1, -1, 1, 1) 52 | b = self.bias.reshape(1, -1, 1, 1) 53 | rv = self.running_var.reshape(1, -1, 1, 1) 54 | rm = self.running_mean.reshape(1, -1, 1, 1) 55 | eps = 1e-5 56 | scale = w * (rv + eps).rsqrt() 57 | bias = b - rm * scale 58 | return x * scale + bias 59 | 60 | 61 | class BackboneBase(nn.Module): 62 | 63 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 64 | super().__init__() 65 | for name, parameter in backbone.named_parameters(): 66 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 67 | parameter.requires_grad_(False) 68 | if return_interm_layers: 69 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 70 | else: 71 | return_layers = {'layer4': "0"} 72 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 73 | self.num_channels = num_channels 74 | 75 | def forward(self, tensor_list: NestedTensor): 76 | xs = self.body(tensor_list.tensors) 77 | out: Dict[str, NestedTensor] = {} 78 | for name, x in xs.items(): 79 | m = tensor_list.mask 80 | assert m is not None 81 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 82 | out[name] = NestedTensor(x, mask) 83 | return out 84 | 85 | 86 | class Backbone(BackboneBase): 87 | """ResNet backbone with frozen BatchNorm.""" 88 | def __init__(self, name: str, 89 | train_backbone: bool, 90 | return_interm_layers: bool, 91 | dilation: bool): 92 | backbone = getattr(torchvision.models, name)( 93 | replace_stride_with_dilation=[False, False, dilation], 94 | pretrained=False, norm_layer=FrozenBatchNorm2d) 95 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 96 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 97 | 98 | 99 | class Joiner(nn.Sequential): 100 | def __init__(self, backbone, position_embedding): 101 | super().__init__(backbone, position_embedding) 
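# Joiner is an nn.Sequential of (backbone, position_embedding): self[0] extracts the feature maps and self[1] computes a positional encoding for each of them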
102 | 103 | def forward(self, tensor_list: NestedTensor): 104 | xs = self[0](tensor_list) 105 | out: List[NestedTensor] = [] 106 | pos = [] 107 | for name, x in xs.items(): 108 | out.append(x) 109 | # position encoding 110 | pos.append(self[1](x).to(x.tensors.dtype)) 111 | 112 | return out, pos 113 | 114 | 115 | def build_backbone(args): 116 | position_embedding = build_position_encoding(args) 117 | train_backbone = args.lr_backbone > 0 118 | return_interm_layers = args.masks 119 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 120 | model = Joiner(backbone, position_embedding) 121 | model.num_channels = backbone.num_channels 122 | return model 123 | -------------------------------------------------------------------------------- /models/generate_image_feature/generate_verb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 6 | from util.misc import (NestedTensor, nested_tensor_from_tensor_list, 7 | accuracy, get_world_size, 8 | is_dist_avail_and_initialized) 9 | import numpy as np 10 | import clip 11 | from datasets.hico_text_label import hico_text_label, hico_obj_text_label, hico_unseen_index 12 | from datasets.vcoco_text_label import vcoco_hoi_text_label, vcoco_obj_text_label 13 | from datasets.static_hico import HOI_IDX_TO_ACT_IDX 14 | 15 | from ..backbone import build_backbone 16 | from ..matcher import build_matcher 17 | from .gen import build_gen 18 | from PIL import Image 19 | import torchvision.transforms as T 20 | 21 | 22 | def _sigmoid(x): 23 | y = torch.clamp(x.sigmoid(), min=1e-4, max=1 - 1e-4) 24 | return y 25 | 26 | 27 | class GEN_VLKT(nn.Module): 28 | def __init__(self, backbone, transformer, num_queries, aux_loss=False, args=None): 29 | super().__init__() 30 | self.args = args 31 | self.clip_model, self.preprocess = clip.load(self.args.clip_model) 32 | 33 | self.obj_feature = nn.Parameter(torch.zeros(81, 512)) 34 | self.hoi_feature = nn.Parameter(torch.zeros(600, 512)) 35 | self.verb_feature = nn.Parameter(torch.zeros(117, 512)) 36 | 37 | 38 | def forward(self, samples: NestedTensor, is_training=True, clip_input=None, targets=None): 39 | for t in targets: 40 | if t['obj_boxes'].shape[0] == 0 or t['human_img'].shape[0] == 0 or t['hoi_area_img'].shape[0] == 0: 41 | continue 42 | # print(f"human_img: {t['human_img'].shape}") 43 | h_feature = self.clip_model.encode_image(t['human_img'])[0] 44 | o_feature = self.clip_model.encode_image(t['object_img'])[0] 45 | hoi_feature = self.clip_model.encode_image(t['hoi_area_img'])[0] 46 | verb_feature = hoi_feature.clone() * 2 - o_feature.clone() - h_feature.clone() 47 | 48 | # h_feature = h_feature / h_feature.norm(dim=1, keepdim=True) 49 | # o_feature = o_feature / o_feature.norm(dim=1, keepdim=True) 50 | # hoi_feature = hoi_feature / hoi_feature.norm(dim=1, keepdim=True) 51 | if h_feature.shape[0] != o_feature.shape[0] or h_feature.shape[0] != hoi_feature.shape[0]: 52 | raise ValueError 53 | 54 | obj_label = t['obj_cls'] 55 | hoi_label = t['hoi_cls'] - 1 56 | if obj_label.shape[0] != o_feature.shape[0] or hoi_label.shape[0] != hoi_feature.shape[0]: 57 | raise ValueError 58 | 59 | if obj_label.max() > 80 or hoi_label.max() >= 600: 60 | raise ValueError 61 | ver_label = torch.tensor(HOI_IDX_TO_ACT_IDX)[hoi_label] 62 | 63 | # print(f"obj_label: {obj_label}") 64 | # print(f"hoi_label: {hoi_label}") 65 | 66 | 67 | 68 
| if torch.isnan(o_feature.sum()) or torch.isnan(hoi_feature.sum()) or torch.isnan(h_feature.sum()): 69 | print(obj_label) 70 | print(hoi_label) 71 | continue 72 | 73 | if 66 in hoi_label or 166 in hoi_label: 74 | print(hoi_label) 75 | print(hoi_feature) 76 | 77 | self.obj_feature.data[0] += h_feature.mean(dim=0) 78 | self.obj_feature.data[obj_label] += o_feature 79 | self.hoi_feature.data[hoi_label] += hoi_feature 80 | self.verb_feature.data[ver_label] += verb_feature 81 | 82 | return self.obj_feature.data, self.hoi_feature.data, self.verb_feature.data 83 | 84 | 85 | class PostProcessHOITriplet(nn.Module): 86 | 87 | def __init__(self, args): 88 | super().__init__() 89 | self.subject_category_id = args.subject_category_id 90 | 91 | @torch.no_grad() 92 | def forward(self, outputs, target_sizes): 93 | out_hoi_logits = outputs['pred_hoi_logits'] 94 | out_obj_logits = outputs['pred_obj_logits'] 95 | out_sub_boxes = outputs['pred_sub_boxes'] 96 | out_obj_boxes = outputs['pred_obj_boxes'] 97 | 98 | assert len(out_hoi_logits) == len(target_sizes) 99 | assert target_sizes.shape[1] == 2 100 | 101 | hoi_scores = out_hoi_logits.sigmoid() 102 | obj_scores = out_obj_logits.sigmoid() 103 | obj_labels = F.softmax(out_obj_logits, -1)[..., :-1].max(-1)[1] 104 | 105 | img_h, img_w = target_sizes.unbind(1) 106 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(hoi_scores.device) 107 | sub_boxes = box_cxcywh_to_xyxy(out_sub_boxes) 108 | sub_boxes = sub_boxes * scale_fct[:, None, :] 109 | obj_boxes = box_cxcywh_to_xyxy(out_obj_boxes) 110 | obj_boxes = obj_boxes * scale_fct[:, None, :] 111 | 112 | results = [] 113 | for index in range(len(hoi_scores)): 114 | hs, os, ol, sb, ob = hoi_scores[index], obj_scores[index], obj_labels[index], sub_boxes[index], obj_boxes[ 115 | index] 116 | sl = torch.full_like(ol, self.subject_category_id) 117 | l = torch.cat((sl, ol)) 118 | b = torch.cat((sb, ob)) 119 | results.append({'labels': l.to('cpu'), 'boxes': b.to('cpu')}) 120 | 121 | ids = torch.arange(b.shape[0]) 122 | 123 | results[-1].update({'hoi_scores': hs.to('cpu'), 'obj_scores': os.to('cpu'), 124 | 'sub_ids': ids[:ids.shape[0] // 2], 'obj_ids': ids[ids.shape[0] // 2:]}) 125 | 126 | return results 127 | 128 | 129 | def build(args): 130 | device = torch.device(args.device) 131 | 132 | backbone = build_backbone(args) 133 | 134 | gen = build_gen(args) 135 | 136 | model = GEN_VLKT( 137 | backbone, 138 | gen, 139 | num_queries=args.num_queries, 140 | aux_loss=args.aux_loss, 141 | args=args 142 | ) 143 | 144 | matcher = build_matcher(args) 145 | weight_dict = {} 146 | if args.with_clip_label: 147 | weight_dict['loss_hoi_labels'] = args.hoi_loss_coef 148 | weight_dict['loss_obj_ce'] = args.obj_loss_coef 149 | else: 150 | weight_dict['loss_hoi_labels'] = args.hoi_loss_coef 151 | weight_dict['loss_obj_ce'] = args.obj_loss_coef 152 | 153 | weight_dict['loss_sub_bbox'] = args.bbox_loss_coef 154 | weight_dict['loss_obj_bbox'] = args.bbox_loss_coef 155 | weight_dict['loss_sub_giou'] = args.giou_loss_coef 156 | weight_dict['loss_obj_giou'] = args.giou_loss_coef 157 | if args.with_mimic: 158 | weight_dict['loss_feat_mimic'] = args.mimic_loss_coef 159 | 160 | if args.with_rec_loss: 161 | weight_dict['loss_rec'] = args.rec_loss_coef 162 | 163 | 164 | return model, model, model 165 | -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.optimize import 
linear_sum_assignment 3 | from torch import nn 4 | 5 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 6 | 7 | class HungarianMatcherHOI(nn.Module): 8 | 9 | def __init__(self, cost_obj_class: float = 1, cost_verb_class: float = 1, cost_bbox: float = 1, 10 | cost_giou: float = 1, cost_hoi_class: float = 1): 11 | super().__init__() 12 | self.cost_obj_class = cost_obj_class 13 | self.cost_verb_class = cost_verb_class 14 | self.cost_hoi_class = cost_hoi_class 15 | self.cost_bbox = cost_bbox 16 | self.cost_giou = cost_giou 17 | assert cost_obj_class != 0 or cost_verb_class != 0 or cost_bbox != 0 or cost_giou != 0, 'all costs cant be 0' 18 | 19 | @torch.no_grad() 20 | def forward(self, outputs, targets): 21 | bs, num_queries = outputs['pred_sub_boxes'].shape[:2] 22 | if 'pred_hoi_logits' in outputs.keys(): 23 | out_hoi_prob = outputs['pred_hoi_logits'].flatten(0, 1).sigmoid() 24 | tgt_hoi_labels = torch.cat([v['hoi_labels'] for v in targets]) 25 | tgt_hoi_labels_permute = tgt_hoi_labels.permute(1, 0) 26 | cost_hoi_class = -(out_hoi_prob.matmul(tgt_hoi_labels_permute) / \ 27 | (tgt_hoi_labels_permute.sum(dim=0, keepdim=True) + 1e-4) + \ 28 | (1 - out_hoi_prob).matmul(1 - tgt_hoi_labels_permute) / \ 29 | ((1 - tgt_hoi_labels_permute).sum(dim=0, keepdim=True) + 1e-4)) / 2 30 | cost_hoi_class = self.cost_hoi_class * cost_hoi_class 31 | elif 'pred_verb_logits' in outputs.keys(): 32 | out_verb_prob = outputs['pred_verb_logits'].flatten(0, 1).sigmoid() 33 | tgt_verb_labels = torch.cat([v['verb_labels'] for v in targets]) 34 | tgt_verb_labels_permute = tgt_verb_labels.permute(1, 0) 35 | cost_verb_class = -(out_verb_prob.matmul(tgt_verb_labels_permute) / \ 36 | (tgt_verb_labels_permute.sum(dim=0, keepdim=True) + 1e-4) + \ 37 | (1 - out_verb_prob).matmul(1 - tgt_verb_labels_permute) / \ 38 | ((1 - tgt_verb_labels_permute).sum(dim=0, keepdim=True) + 1e-4)) / 2 39 | cost_hoi_class = self.cost_verb_class * cost_verb_class 40 | else: 41 | cost_hoi_class = 0 42 | 43 | 44 | tgt_obj_labels = torch.cat([v['obj_labels'] for v in targets]) 45 | out_obj_prob = outputs['pred_obj_logits'].flatten(0, 1).softmax(-1) 46 | cost_obj_class = -out_obj_prob[:, tgt_obj_labels] 47 | out_sub_bbox = outputs['pred_sub_boxes'].flatten(0, 1) 48 | out_obj_bbox = outputs['pred_obj_boxes'].flatten(0, 1) 49 | 50 | tgt_sub_boxes = torch.cat([v['sub_boxes'] for v in targets]) 51 | tgt_obj_boxes = torch.cat([v['obj_boxes'] for v in targets]) 52 | 53 | if out_sub_bbox.dtype == torch.float16: 54 | out_sub_bbox = out_sub_bbox.type(torch.float32) 55 | out_obj_bbox = out_obj_bbox.type(torch.float32) 56 | 57 | cost_sub_bbox = torch.cdist(out_sub_bbox, tgt_sub_boxes, p=1) 58 | cost_obj_bbox = torch.cdist(out_obj_bbox, tgt_obj_boxes, p=1) * (tgt_obj_boxes != 0).any(dim=1).unsqueeze(0) 59 | if cost_sub_bbox.shape[1] == 0: 60 | cost_bbox = cost_sub_bbox 61 | else: 62 | cost_bbox = torch.stack((cost_sub_bbox, cost_obj_bbox)).max(dim=0)[0] 63 | 64 | cost_sub_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_sub_bbox), box_cxcywh_to_xyxy(tgt_sub_boxes)) 65 | cost_obj_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_obj_bbox), box_cxcywh_to_xyxy(tgt_obj_boxes)) + \ 66 | cost_sub_giou * (tgt_obj_boxes == 0).all(dim=1).unsqueeze(0) 67 | if cost_sub_giou.shape[1] == 0: 68 | cost_giou = cost_sub_giou 69 | else: 70 | cost_giou = torch.stack((cost_sub_giou, cost_obj_giou)).max(dim=0)[0] 71 | 72 | C = self.cost_hoi_class * cost_hoi_class + self.cost_bbox * cost_bbox + \ 73 | self.cost_giou * cost_giou + self.cost_obj_class * cost_obj_class 
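# C stacks every query against every ground-truth pair in the batch (bs * num_queries rows, one column per GT pair); below it is reshaped per image and scipy's Hungarian solver returns the minimum-cost one-to-one assignment, e.g. linear_sum_assignment([[0.2, 0.9], [0.6, 0.1]]) gives (array([0, 1]), array([0, 1]))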
74 | 75 | C = C.view(bs, num_queries, -1).cpu() 76 | 77 | sizes = [len(v['sub_boxes']) for v in targets] 78 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 79 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 80 | 81 | 82 | def build_matcher(args): 83 | return HungarianMatcherHOI(cost_obj_class=args.set_cost_obj_class, cost_verb_class=args.set_cost_verb_class, 84 | cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou, 85 | cost_hoi_class=args.set_cost_hoi) 86 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | """ 6 | Various positional encodings for the transformer. 7 | """ 8 | import math 9 | import torch 10 | from torch import nn 11 | 12 | from util.misc import NestedTensor 13 | 14 | 15 | class PositionEmbeddingSine(nn.Module): 16 | """ 17 | This is a more standard version of the position embedding, very similar to the one 18 | used by the Attention is all you need paper, generalized to work on images. 19 | """ 20 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 21 | super().__init__() 22 | self.num_pos_feats = num_pos_feats 23 | self.temperature = temperature 24 | self.normalize = normalize 25 | if scale is not None and normalize is False: 26 | raise ValueError("normalize should be True if scale is passed") 27 | if scale is None: 28 | scale = 2 * math.pi 29 | self.scale = scale 30 | 31 | def forward(self, tensor_list: NestedTensor): 32 | x = tensor_list.tensors 33 | mask = tensor_list.mask 34 | assert mask is not None 35 | not_mask = ~mask 36 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 37 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 38 | if self.normalize: 39 | eps = 1e-6 40 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 41 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 42 | 43 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 44 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 45 | 46 | pos_x = x_embed[:, :, :, None] / dim_t 47 | pos_y = y_embed[:, :, :, None] / dim_t 48 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 49 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 50 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 51 | return pos 52 | 53 | 54 | class PositionEmbeddingLearned(nn.Module): 55 | """ 56 | Absolute pos embedding, learned. 
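Row and column indices are looked up in two 50-entry embedding tables, so feature maps larger than 50 in either dimension are not supported.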
57 | """ 58 | def __init__(self, num_pos_feats=256): 59 | super().__init__() 60 | self.row_embed = nn.Embedding(50, num_pos_feats) 61 | self.col_embed = nn.Embedding(50, num_pos_feats) 62 | self.reset_parameters() 63 | 64 | def reset_parameters(self): 65 | nn.init.uniform_(self.row_embed.weight) 66 | nn.init.uniform_(self.col_embed.weight) 67 | 68 | def forward(self, tensor_list: NestedTensor): 69 | x = tensor_list.tensors 70 | h, w = x.shape[-2:] 71 | i = torch.arange(w, device=x.device) 72 | j = torch.arange(h, device=x.device) 73 | x_emb = self.col_embed(i) 74 | y_emb = self.row_embed(j) 75 | pos = torch.cat([ 76 | x_emb.unsqueeze(0).repeat(h, 1, 1), 77 | y_emb.unsqueeze(1).repeat(1, w, 1), 78 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 79 | return pos 80 | 81 | 82 | def build_position_encoding(args): 83 | N_steps = args.hidden_dim // 2 84 | if args.position_embedding in ('v2', 'sine'): 85 | # TODO find a better way of exposing other arguments 86 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 87 | elif args.position_embedding in ('v3', 'learned'): 88 | position_embedding = PositionEmbeddingLearned(N_steps) 89 | else: 90 | raise ValueError(f"not supported {args.position_embedding}") 91 | 92 | return position_embedding 93 | -------------------------------------------------------------------------------- /paper_images/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Artanic30/HOICLIP/ee4db062097410abdd20fa96d40d26aaca1f19da/paper_images/intro.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython 2 | pycocotools 3 | torch==1.7.1 4 | torchvision==0.8.2 5 | scipy==1.3.1 6 | opencv-python 7 | ftfy 8 | regex 9 | tqdm -------------------------------------------------------------------------------- /scripts/generate_verb.sh: -------------------------------------------------------------------------------- 1 | ulimit -n 4096 2 | set -x 3 | EXP_DIR=exps/hico/hoiclip 4 | 5 | for i in 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 2 3 4 5 6 7 8 9 6 | do 7 | swapon --show 8 | free -h 9 | export NCCL_P2P_LEVEL=NVL 10 | export OMP_NUM_THREADS=8 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=1 \ 13 | --master_port $[29403 + i] \ 14 | --use_env \ 15 | main.py \ 16 | --output_dir ${EXP_DIR} \ 17 | --dataset_file hico \ 18 | --hoi_path data/hico_20160224_det \ 19 | --num_obj_classes 80 \ 20 | --num_verb_classes 117 \ 21 | --backbone resnet50 \ 22 | --num_queries 64 \ 23 | --dec_layers 3 \ 24 | --epochs 1 \ 25 | --use_nms_filter \ 26 | --fix_clip \ 27 | --batch_size 16 \ 28 | --pretrained params/detr-r50-pre-2branch-hico.pth \ 29 | --with_clip_label \ 30 | --with_obj_clip_label \ 31 | --gradient_accumulation_steps 1 \ 32 | --num_workers 8 \ 33 | --opt_sched "multiStep" \ 34 | --dataset_root GENERATE_VERB \ 35 | --model_name GENERATE_VERB \ 36 | --no_training 37 | sleep 120 38 | done 39 | -------------------------------------------------------------------------------- /scripts/train_hico.sh: -------------------------------------------------------------------------------- 1 | ulimit -n 4096 2 | set -x 3 | EXP_DIR=exps/hico/hoiclip 4 | 5 | for i in 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 2 3 4 5 6 7 8 9 6 | do 7 | swapon --show 8 | free -h 9 | export NCCL_P2P_LEVEL=NVL 10 | export OMP_NUM_THREADS=8 11 | python -m torch.distributed.launch \ 12 | 
--nproc_per_node=2 \ 13 | --master_port $[29403 + i] \ 14 | --use_env \ 15 | main.py \ 16 | --output_dir ${EXP_DIR} \ 17 | --dataset_file hico \ 18 | --hoi_path data/hico_20160224_det \ 19 | --num_obj_classes 80 \ 20 | --num_verb_classes 117 \ 21 | --backbone resnet50 \ 22 | --num_queries 64 \ 23 | --dec_layers 3 \ 24 | --epochs 90 \ 25 | --lr_drop 60 \ 26 | --use_nms_filter \ 27 | --fix_clip \ 28 | --batch_size 8 \ 29 | --pretrained params/detr-r50-pre-2branch-hico.pth \ 30 | --with_clip_label \ 31 | --with_obj_clip_label \ 32 | --gradient_accumulation_steps 1 \ 33 | --num_workers 8 \ 34 | --opt_sched "multiStep" \ 35 | --dataset_root GEN \ 36 | --model_name HOICLIP \ 37 | --zero_shot_type default \ 38 | --resume ${EXP_DIR}/checkpoint_last.pth \ 39 | --verb_pth ./tmp/verb.pth \ 40 | --training_free_enhancement_path \ 41 | ./training_free_ehnahcement/ 42 | sleep 120 43 | done 44 | -------------------------------------------------------------------------------- /scripts/train_hico_frac.sh: -------------------------------------------------------------------------------- 1 | ulimit -n 4096 2 | set -x 3 | EXP_DIR=exps/hico/hoiclip 4 | 5 | for i in 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 2 3 4 5 6 7 8 9 6 | do 7 | swapon --show 8 | free -h 9 | export NCCL_P2P_LEVEL=NVL 10 | export OMP_NUM_THREADS=8 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=2 \ 13 | --master_port $[29403 + i] \ 14 | --use_env \ 15 | main.py \ 16 | --output_dir ${EXP_DIR} \ 17 | --dataset_file hico \ 18 | --hoi_path data/hico_20160224_det \ 19 | --num_obj_classes 80 \ 20 | --num_verb_classes 117 \ 21 | --backbone resnet50 \ 22 | --num_queries 64 \ 23 | --dec_layers 3 \ 24 | --epochs 90 \ 25 | --lr_drop 60 \ 26 | --use_nms_filter \ 27 | --fix_clip \ 28 | --batch_size 8 \ 29 | --pretrained params/detr-r50-pre-2branch-hico.pth \ 30 | --with_clip_label \ 31 | --with_obj_clip_label \ 32 | --gradient_accumulation_steps 1 \ 33 | --num_workers 8 \ 34 | --opt_sched "multiStep" \ 35 | --dataset_root GEN \ 36 | --model_name HOICLIP \ 37 | --zero_shot_type default \ 38 | --frac 50% \ 39 | --resume ${EXP_DIR}/checkpoint_last.pth \ 40 | --verb_pth ./tmp/verb.pth \ 41 | --training_free_enhancement_path \ 42 | ./training_free_ehnahcement/ 43 | sleep 120 44 | done 45 | -------------------------------------------------------------------------------- /scripts/train_hico_nrf_uc.sh: -------------------------------------------------------------------------------- 1 | ulimit -n 4096 2 | set -x 3 | EXP_DIR=exps/hico/hoiclip 4 | 5 | for i in 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 2 3 4 5 6 7 8 9 6 | do 7 | swapon --show 8 | free -h 9 | export NCCL_P2P_LEVEL=NVL 10 | export OMP_NUM_THREADS=8 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=2 \ 13 | --master_port $[29403 + i] \ 14 | --use_env \ 15 | main.py \ 16 | --output_dir ${EXP_DIR} \ 17 | --dataset_file hico \ 18 | --hoi_path data/hico_20160224_det \ 19 | --num_obj_classes 80 \ 20 | --num_verb_classes 117 \ 21 | --backbone resnet50 \ 22 | --num_queries 64 \ 23 | --dec_layers 3 \ 24 | --epochs 90 \ 25 | --lr_drop 60 \ 26 | --use_nms_filter \ 27 | --fix_clip \ 28 | --batch_size 8 \ 29 | --pretrained params/detr-r50-pre-2branch-hico.pth \ 30 | --with_clip_label \ 31 | --with_obj_clip_label \ 32 | --gradient_accumulation_steps 1 \ 33 | --num_workers 8 \ 34 | --opt_sched "multiStep" \ 35 | --dataset_root GEN \ 36 | --model_name HOICLIP \ 37 | --del_unseen \ 38 | --zero_shot_type non_rare_first \ 39 | --resume ${EXP_DIR}/checkpoint_last.pth \ 40 | --verb_pth ./tmp/verb.pth \ 41 | 
--training_free_enhancement_path \ 42 | ./training_free_ehnahcement/ 43 | sleep 120 44 | done 45 | -------------------------------------------------------------------------------- /scripts/train_hico_rf_uc.sh: -------------------------------------------------------------------------------- 1 | ulimit -n 4096 2 | set -x 3 | EXP_DIR=exps/hico/hoiclip 4 | 5 | for i in 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 2 3 4 5 6 7 8 9 6 | do 7 | swapon --show 8 | free -h 9 | export NCCL_P2P_LEVEL=NVL 10 | export OMP_NUM_THREADS=8 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=2 \ 13 | --master_port $[29403 + i] \ 14 | --use_env \ 15 | main.py \ 16 | --output_dir ${EXP_DIR} \ 17 | --dataset_file hico \ 18 | --hoi_path data/hico_20160224_det \ 19 | --num_obj_classes 80 \ 20 | --num_verb_classes 117 \ 21 | --backbone resnet50 \ 22 | --num_queries 64 \ 23 | --dec_layers 3 \ 24 | --epochs 90 \ 25 | --lr_drop 60 \ 26 | --use_nms_filter \ 27 | --fix_clip \ 28 | --batch_size 8 \ 29 | --pretrained params/detr-r50-pre-2branch-hico.pth \ 30 | --with_clip_label \ 31 | --with_obj_clip_label \ 32 | --gradient_accumulation_steps 1 \ 33 | --num_workers 8 \ 34 | --opt_sched "multiStep" \ 35 | --dataset_root GEN \ 36 | --model_name HOICLIP \ 37 | --del_unseen \ 38 | --zero_shot_type rare_first \ 39 | --resume ${EXP_DIR}/checkpoint_last.pth \ 40 | --verb_pth ./tmp/verb.pth \ 41 | --training_free_enhancement_path \ 42 | ./training_free_ehnahcement/ 43 | sleep 120 44 | done 45 | -------------------------------------------------------------------------------- /scripts/train_hico_uo.sh: -------------------------------------------------------------------------------- 1 | ulimit -n 4096 2 | set -x 3 | EXP_DIR=exps/hico/hoiclip 4 | 5 | for i in 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 2 3 4 5 6 7 8 9 6 | do 7 | swapon --show 8 | free -h 9 | export NCCL_P2P_LEVEL=NVL 10 | export OMP_NUM_THREADS=8 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=2 \ 13 | --master_port $[29403 + i] \ 14 | --use_env \ 15 | main.py \ 16 | --output_dir ${EXP_DIR} \ 17 | --dataset_file hico \ 18 | --hoi_path data/hico_20160224_det \ 19 | --num_obj_classes 80 \ 20 | --num_verb_classes 117 \ 21 | --backbone resnet50 \ 22 | --num_queries 64 \ 23 | --dec_layers 3 \ 24 | --epochs 90 \ 25 | --lr_drop 60 \ 26 | --use_nms_filter \ 27 | --fix_clip \ 28 | --batch_size 8 \ 29 | --pretrained params/detr-r50-pre-2branch-hico.pth \ 30 | --with_clip_label \ 31 | --with_obj_clip_label \ 32 | --gradient_accumulation_steps 1 \ 33 | --num_workers 8 \ 34 | --opt_sched "multiStep" \ 35 | --dataset_root GEN \ 36 | --model_name HOICLIP \ 37 | --del_unseen \ 38 | --zero_shot_type unseen_object \ 39 | --resume ${EXP_DIR}/checkpoint_last.pth \ 40 | --verb_pth ./tmp/verb.pth \ 41 | --training_free_enhancement_path \ 42 | ./training_free_ehnahcement/ 43 | sleep 120 44 | done 45 | -------------------------------------------------------------------------------- /scripts/train_hico_uv.sh: -------------------------------------------------------------------------------- 1 | ulimit -n 4096 2 | set -x 3 | EXP_DIR=exps/hico/hoiclip 4 | 5 | for i in 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 2 3 4 5 6 7 8 9 6 | do 7 | swapon --show 8 | free -h 9 | export NCCL_P2P_LEVEL=NVL 10 | export OMP_NUM_THREADS=8 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=2 \ 13 | --master_port $[29403 + i] \ 14 | --use_env \ 15 | main.py \ 16 | --output_dir ${EXP_DIR} \ 17 | --dataset_file hico \ 18 | --hoi_path data/hico_20160224_det \ 19 | --num_obj_classes 80 \ 20 | 
--num_verb_classes 117 \ 21 | --backbone resnet50 \ 22 | --num_queries 64 \ 23 | --dec_layers 3 \ 24 | --epochs 90 \ 25 | --lr_drop 60 \ 26 | --use_nms_filter \ 27 | --fix_clip \ 28 | --batch_size 8 \ 29 | --pretrained params/detr-r50-pre-2branch-hico.pth \ 30 | --with_clip_label \ 31 | --with_obj_clip_label \ 32 | --gradient_accumulation_steps 1 \ 33 | --num_workers 8 \ 34 | --opt_sched "multiStep" \ 35 | --dataset_root GEN \ 36 | --model_name HOICLIP \ 37 | --del_unseen \ 38 | --zero_shot_type unseen_verb \ 39 | --resume ${EXP_DIR}/checkpoint_last.pth \ 40 | --verb_pth ./tmp/verb.pth \ 41 | --training_free_enhancement_path \ 42 | ./training_free_ehnahcement/ 43 | sleep 120 44 | done 45 | -------------------------------------------------------------------------------- /scripts/train_vcoco.sh: -------------------------------------------------------------------------------- 1 | ulimit -n 4096 2 | set -x 3 | EXP_DIR=exps/hico/hoiclip 4 | 5 | for i in 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 2 3 4 5 6 7 8 9 6 | do 7 | swapon --show 8 | free -h 9 | export NCCL_P2P_LEVEL=NVL 10 | export OMP_NUM_THREADS=8 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=2 \ 13 | --master_port $[29403 + i] \ 14 | --use_env \ 15 | main.py \ 16 | --output_dir ${EXP_DIR} \ 17 | --dataset_file vcoco \ 18 | --hoi_path data/v-coco \ 19 | --num_obj_classes 81 \ 20 | --num_verb_classes 29 \ 21 | --backbone resnet50 \ 22 | --num_queries 64 \ 23 | --dec_layers 3 \ 24 | --epochs 90 \ 25 | --lr_drop 60 \ 26 | --use_nms_filter \ 27 | --fix_clip \ 28 | --batch_size 8 \ 29 | --pretrained params/detr-r50-pre-2branch-hico.pth \ 30 | --with_clip_label \ 31 | --with_obj_clip_label \ 32 | --gradient_accumulation_steps 1 \ 33 | --num_workers 8 \ 34 | --opt_sched "multiStep" \ 35 | --dataset_root GEN \ 36 | --model_name HOICLIP \ 37 | --zero_shot_type default \ 38 | --resume ${EXP_DIR}/checkpoint_last.pth \ 39 | --verb_pth ./tmp/verb.pth \ 40 | --verb_weight 0.1 \ 41 | --training_free_enhancement_path \ 42 | ./training_free_ehnahcement/ 43 | sleep 120 44 | done 45 | -------------------------------------------------------------------------------- /scripts/visualization_hico.sh: -------------------------------------------------------------------------------- 1 | ulimit -n 4096 2 | set -x 3 | EXP_DIR=exps/hico/hoiclip 4 | 5 | python main.py \ 6 | --output_dir ${EXP_DIR} \ 7 | --dataset_file hico \ 8 | --hoi_path data/hico_20160224_det \ 9 | --num_obj_classes 80 \ 10 | --num_verb_classes 117 \ 11 | --backbone resnet50 \ 12 | --num_queries 64 \ 13 | --dec_layers 3 \ 14 | --epochs 90 \ 15 | --lr_drop 60 \ 16 | --use_nms_filter \ 17 | --fix_clip \ 18 | --batch_size 8 \ 19 | --fs_num -1 \ 20 | --pretrained params/detr-r50-pre-2branch-hico.pth \ 21 | --with_clip_label \ 22 | --with_obj_clip_label \ 23 | --gradient_accumulation_steps 1 \ 24 | --num_workers 8 \ 25 | --opt_sched "multiStep" \ 26 | --dataset_root GEN \ 27 | --model_name VISUALIZATION \ 28 | --zero_shot_type default \ 29 | --resume ${EXP_DIR}/checkpoint_last.pth \ 30 | --verb_pth ./tmp/verb.pth \ 31 | --eval -------------------------------------------------------------------------------- /tmp/vcoco_verb.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Artanic30/HOICLIP/ee4db062097410abdd20fa96d40d26aaca1f19da/tmp/vcoco_verb.pth -------------------------------------------------------------------------------- /tmp/verb.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Artanic30/HOICLIP/ee4db062097410abdd20fa96d40d26aaca1f19da/tmp/verb.pth -------------------------------------------------------------------------------- /tmp/vis_file_names.json: -------------------------------------------------------------------------------- 1 | ["HICO_test2015_00000040.jpg", "HICO_test2015_00000133.jpg", "HICO_test2015_00000143.jpg", "HICO_test2015_00000283.jpg", "HICO_test2015_00000317.jpg", "HICO_test2015_00000325.jpg", "HICO_test2015_00000327.jpg", "HICO_test2015_00000358.jpg", "HICO_test2015_00000369.jpg", "HICO_test2015_00000394.jpg", "HICO_test2015_00000413.jpg", "HICO_test2015_00000424.jpg", "HICO_test2015_00000457.jpg", "HICO_test2015_00000469.jpg", "HICO_test2015_00000503.jpg", "HICO_test2015_00000559.jpg", "HICO_test2015_00000651.jpg", "HICO_test2015_00000671.jpg", "HICO_test2015_00000699.jpg", "HICO_test2015_00000719.jpg", "HICO_test2015_00000768.jpg", "HICO_test2015_00000783.jpg", "HICO_test2015_00000798.jpg", "HICO_test2015_00000881.jpg", "HICO_test2015_00000882.jpg", "HICO_test2015_00000896.jpg", "HICO_test2015_00000927.jpg", "HICO_test2015_00000987.jpg", "HICO_test2015_00001004.jpg", "HICO_test2015_00001016.jpg", "HICO_test2015_00001019.jpg", "HICO_test2015_00001057.jpg", "HICO_test2015_00001086.jpg", "HICO_test2015_00001089.jpg", "HICO_test2015_00001234.jpg", "HICO_test2015_00001251.jpg", "HICO_test2015_00001285.jpg", "HICO_test2015_00001304.jpg", "HICO_test2015_00001316.jpg", "HICO_test2015_00001323.jpg", "HICO_test2015_00001384.jpg", "HICO_test2015_00001406.jpg", "HICO_test2015_00001425.jpg", "HICO_test2015_00001440.jpg", "HICO_test2015_00001449.jpg", "HICO_test2015_00001483.jpg", "HICO_test2015_00001512.jpg", "HICO_test2015_00001574.jpg", "HICO_test2015_00001601.jpg", "HICO_test2015_00001630.jpg", "HICO_test2015_00001643.jpg", "HICO_test2015_00001714.jpg", "HICO_test2015_00001741.jpg", "HICO_test2015_00001767.jpg", "HICO_test2015_00001961.jpg", "HICO_test2015_00001976.jpg", "HICO_test2015_00001985.jpg", "HICO_test2015_00002042.jpg", "HICO_test2015_00002109.jpg", "HICO_test2015_00002131.jpg", "HICO_test2015_00002144.jpg", "HICO_test2015_00002156.jpg", "HICO_test2015_00002164.jpg", "HICO_test2015_00002173.jpg", "HICO_test2015_00002237.jpg", "HICO_test2015_00002254.jpg", "HICO_test2015_00002257.jpg", "HICO_test2015_00002283.jpg", "HICO_test2015_00002306.jpg", "HICO_test2015_00002313.jpg", "HICO_test2015_00002325.jpg", "HICO_test2015_00002326.jpg", "HICO_test2015_00002373.jpg", "HICO_test2015_00002399.jpg", "HICO_test2015_00002419.jpg", "HICO_test2015_00002462.jpg", "HICO_test2015_00002465.jpg", "HICO_test2015_00002479.jpg", "HICO_test2015_00002503.jpg", "HICO_test2015_00002546.jpg", "HICO_test2015_00002574.jpg", "HICO_test2015_00002578.jpg", "HICO_test2015_00002618.jpg", "HICO_test2015_00002634.jpg", "HICO_test2015_00002642.jpg", "HICO_test2015_00002653.jpg", "HICO_test2015_00002665.jpg", "HICO_test2015_00002728.jpg", "HICO_test2015_00002731.jpg", "HICO_test2015_00002772.jpg", "HICO_test2015_00002815.jpg", "HICO_test2015_00002844.jpg", "HICO_test2015_00002858.jpg", "HICO_test2015_00002909.jpg", "HICO_test2015_00002926.jpg", "HICO_test2015_00002936.jpg", "HICO_test2015_00002971.jpg", "HICO_test2015_00002983.jpg", "HICO_test2015_00002989.jpg", "HICO_test2015_00002994.jpg", "HICO_test2015_00003097.jpg", "HICO_test2015_00003148.jpg", "HICO_test2015_00003399.jpg", "HICO_test2015_00003530.jpg", "HICO_test2015_00003537.jpg", "HICO_test2015_00003560.jpg", "HICO_test2015_00003597.jpg", "HICO_test2015_00003645.jpg", 
"HICO_test2015_00003681.jpg", "HICO_test2015_00003692.jpg", "HICO_test2015_00003696.jpg", "HICO_test2015_00003816.jpg", "HICO_test2015_00003857.jpg", "HICO_test2015_00003883.jpg", "HICO_test2015_00003911.jpg", "HICO_test2015_00003932.jpg", "HICO_test2015_00003971.jpg", "HICO_test2015_00003991.jpg", "HICO_test2015_00004011.jpg", "HICO_test2015_00004072.jpg", "HICO_test2015_00004160.jpg", "HICO_test2015_00004184.jpg", "HICO_test2015_00004193.jpg", "HICO_test2015_00004239.jpg", "HICO_test2015_00004283.jpg", "HICO_test2015_00004295.jpg", "HICO_test2015_00004400.jpg", "HICO_test2015_00004414.jpg", "HICO_test2015_00004415.jpg", "HICO_test2015_00004422.jpg", "HICO_test2015_00004428.jpg", "HICO_test2015_00004476.jpg", "HICO_test2015_00004497.jpg", "HICO_test2015_00004605.jpg", "HICO_test2015_00004680.jpg", "HICO_test2015_00004693.jpg", "HICO_test2015_00004716.jpg", "HICO_test2015_00004723.jpg", "HICO_test2015_00004800.jpg", "HICO_test2015_00004812.jpg", "HICO_test2015_00004861.jpg", "HICO_test2015_00004876.jpg", "HICO_test2015_00004967.jpg", "HICO_test2015_00004985.jpg", "HICO_test2015_00005033.jpg", "HICO_test2015_00005102.jpg", "HICO_test2015_00005138.jpg", "HICO_test2015_00005145.jpg", "HICO_test2015_00005212.jpg", "HICO_test2015_00005222.jpg", "HICO_test2015_00005242.jpg", "HICO_test2015_00005273.jpg", "HICO_test2015_00005335.jpg", "HICO_test2015_00005360.jpg", "HICO_test2015_00005553.jpg", "HICO_test2015_00005566.jpg", "HICO_test2015_00005598.jpg", "HICO_test2015_00005641.jpg", "HICO_test2015_00005658.jpg", "HICO_test2015_00005665.jpg", "HICO_test2015_00005708.jpg", "HICO_test2015_00005711.jpg", "HICO_test2015_00005727.jpg", "HICO_test2015_00005731.jpg", "HICO_test2015_00005776.jpg", "HICO_test2015_00005811.jpg", "HICO_test2015_00005831.jpg", "HICO_test2015_00005872.jpg", "HICO_test2015_00005883.jpg", "HICO_test2015_00005906.jpg", "HICO_test2015_00005942.jpg", "HICO_test2015_00005947.jpg", "HICO_test2015_00005949.jpg", "HICO_test2015_00005997.jpg", "HICO_test2015_00006005.jpg", "HICO_test2015_00006022.jpg", "HICO_test2015_00006054.jpg", "HICO_test2015_00006095.jpg", "HICO_test2015_00006106.jpg", "HICO_test2015_00006121.jpg", "HICO_test2015_00006136.jpg", "HICO_test2015_00006163.jpg", "HICO_test2015_00006168.jpg", "HICO_test2015_00006201.jpg", "HICO_test2015_00006210.jpg", "HICO_test2015_00006243.jpg", "HICO_test2015_00006261.jpg", "HICO_test2015_00006263.jpg", "HICO_test2015_00006265.jpg", "HICO_test2015_00006287.jpg", "HICO_test2015_00006297.jpg", "HICO_test2015_00006306.jpg", "HICO_test2015_00006336.jpg", "HICO_test2015_00006362.jpg", "HICO_test2015_00006411.jpg", "HICO_test2015_00006436.jpg", "HICO_test2015_00006521.jpg", "HICO_test2015_00006538.jpg", "HICO_test2015_00006553.jpg", "HICO_test2015_00006578.jpg", "HICO_test2015_00006584.jpg", "HICO_test2015_00006613.jpg", "HICO_test2015_00006614.jpg", "HICO_test2015_00006668.jpg", "HICO_test2015_00006683.jpg", "HICO_test2015_00006710.jpg", "HICO_test2015_00006723.jpg", "HICO_test2015_00006733.jpg", "HICO_test2015_00006764.jpg", "HICO_test2015_00006787.jpg", "HICO_test2015_00006796.jpg", "HICO_test2015_00006805.jpg", "HICO_test2015_00006842.jpg", "HICO_test2015_00006969.jpg", "HICO_test2015_00006994.jpg", "HICO_test2015_00007005.jpg", "HICO_test2015_00007038.jpg", "HICO_test2015_00007091.jpg", "HICO_test2015_00007124.jpg", "HICO_test2015_00007136.jpg", "HICO_test2015_00007180.jpg", "HICO_test2015_00007186.jpg", "HICO_test2015_00007245.jpg", "HICO_test2015_00007267.jpg", "HICO_test2015_00007268.jpg", "HICO_test2015_00007288.jpg", 
"HICO_test2015_00007293.jpg", "HICO_test2015_00007295.jpg", "HICO_test2015_00007316.jpg", "HICO_test2015_00007325.jpg", "HICO_test2015_00007337.jpg", "HICO_test2015_00007356.jpg", "HICO_test2015_00007387.jpg", "HICO_test2015_00007399.jpg", "HICO_test2015_00007418.jpg", "HICO_test2015_00007436.jpg", "HICO_test2015_00007451.jpg", "HICO_test2015_00007454.jpg", "HICO_test2015_00007487.jpg", "HICO_test2015_00007541.jpg", "HICO_test2015_00007550.jpg", "HICO_test2015_00007560.jpg", "HICO_test2015_00007585.jpg", "HICO_test2015_00007588.jpg", "HICO_test2015_00007649.jpg", "HICO_test2015_00007650.jpg", "HICO_test2015_00007710.jpg", "HICO_test2015_00007745.jpg", "HICO_test2015_00007762.jpg", "HICO_test2015_00007772.jpg", "HICO_test2015_00007791.jpg", "HICO_test2015_00007796.jpg", "HICO_test2015_00007797.jpg", "HICO_test2015_00007813.jpg", "HICO_test2015_00007836.jpg", "HICO_test2015_00007865.jpg", "HICO_test2015_00007905.jpg", "HICO_test2015_00007943.jpg", "HICO_test2015_00007953.jpg", "HICO_test2015_00007969.jpg", "HICO_test2015_00008077.jpg", "HICO_test2015_00008122.jpg", "HICO_test2015_00008145.jpg", "HICO_test2015_00008178.jpg", "HICO_test2015_00008236.jpg", "HICO_test2015_00008273.jpg", "HICO_test2015_00008289.jpg", "HICO_test2015_00008373.jpg", "HICO_test2015_00008376.jpg", "HICO_test2015_00008380.jpg", "HICO_test2015_00008447.jpg", "HICO_test2015_00008454.jpg", "HICO_test2015_00008469.jpg", "HICO_test2015_00008505.jpg", "HICO_test2015_00008519.jpg", "HICO_test2015_00008526.jpg", "HICO_test2015_00008583.jpg", "HICO_test2015_00008615.jpg", "HICO_test2015_00008637.jpg", "HICO_test2015_00008690.jpg", "HICO_test2015_00008753.jpg", "HICO_test2015_00008755.jpg", "HICO_test2015_00008792.jpg", "HICO_test2015_00008874.jpg", "HICO_test2015_00008929.jpg", "HICO_test2015_00008941.jpg", "HICO_test2015_00008970.jpg", "HICO_test2015_00008999.jpg", "HICO_test2015_00009024.jpg", "HICO_test2015_00009043.jpg", "HICO_test2015_00009218.jpg", "HICO_test2015_00009219.jpg", "HICO_test2015_00009222.jpg", "HICO_test2015_00009225.jpg", "HICO_test2015_00009226.jpg", "HICO_test2015_00009227.jpg", "HICO_test2015_00009228.jpg", "HICO_test2015_00009237.jpg", "HICO_test2015_00009242.jpg", "HICO_test2015_00009244.jpg", "HICO_test2015_00009245.jpg", "HICO_test2015_00009248.jpg", "HICO_test2015_00009249.jpg", "HICO_test2015_00009250.jpg", "HICO_test2015_00009254.jpg", "HICO_test2015_00009418.jpg", "HICO_test2015_00009419.jpg", "HICO_test2015_00009438.jpg", "HICO_test2015_00009452.jpg", "HICO_test2015_00009498.jpg", "HICO_test2015_00009500.jpg", "HICO_test2015_00009501.jpg", "HICO_test2015_00009508.jpg", "HICO_test2015_00009510.jpg", "HICO_test2015_00009533.jpg", "HICO_test2015_00009535.jpg", "HICO_test2015_00009536.jpg", "HICO_test2015_00009537.jpg", "HICO_test2015_00009671.jpg"] -------------------------------------------------------------------------------- /tools/convert_parameters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument( 11 | '--load_path', type=str, required=True, 12 | ) 13 | parser.add_argument( 14 | '--save_path', type=str, required=True, 15 | ) 16 | parser.add_argument( 17 | '--dataset', type=str, default='hico', 18 | ) 19 | parser.add_argument( 20 | '--num_queries', type=int, default=100, 21 | ) 22 | 23 | args = parser.parse_args() 24 | 25 | return args 26 | 27 | 28 | def main(args): 29 | ps = 
torch.load(args.load_path) 30 | 31 | obj_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 32 | 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 33 | 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 34 | 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 35 | 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 36 | 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 37 | 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 38 | 82, 84, 85, 86, 87, 88, 89, 90] 39 | 40 | # For no pair 41 | obj_ids.append(91) 42 | 43 | for k in list(ps['model'].keys()): 44 | print(k) 45 | if len(k.split('.')) > 1 and k.split('.')[1] == 'decoder': 46 | ps['model'][k.replace('decoder', 'instance_decoder')] = ps['model'][k].clone() 47 | ps['model'][k.replace('decoder', 'interaction_decoder')] = ps['model'][k].clone() 48 | 49 | del ps['model'][k] 50 | 51 | ps['model']['hum_bbox_embed.layers.0.weight'] = ps['model']['bbox_embed.layers.0.weight'].clone() 52 | ps['model']['hum_bbox_embed.layers.0.bias'] = ps['model']['bbox_embed.layers.0.bias'].clone() 53 | ps['model']['hum_bbox_embed.layers.1.weight'] = ps['model']['bbox_embed.layers.1.weight'].clone() 54 | ps['model']['hum_bbox_embed.layers.1.bias'] = ps['model']['bbox_embed.layers.1.bias'].clone() 55 | ps['model']['hum_bbox_embed.layers.2.weight'] = ps['model']['bbox_embed.layers.2.weight'].clone() 56 | ps['model']['hum_bbox_embed.layers.2.bias'] = ps['model']['bbox_embed.layers.2.bias'].clone() 57 | 58 | ps['model']['obj_bbox_embed.layers.0.weight'] = ps['model']['bbox_embed.layers.0.weight'].clone() 59 | ps['model']['obj_bbox_embed.layers.0.bias'] = ps['model']['bbox_embed.layers.0.bias'].clone() 60 | ps['model']['obj_bbox_embed.layers.1.weight'] = ps['model']['bbox_embed.layers.1.weight'].clone() 61 | ps['model']['obj_bbox_embed.layers.1.bias'] = ps['model']['bbox_embed.layers.1.bias'].clone() 62 | ps['model']['obj_bbox_embed.layers.2.weight'] = ps['model']['bbox_embed.layers.2.weight'].clone() 63 | ps['model']['obj_bbox_embed.layers.2.bias'] = ps['model']['bbox_embed.layers.2.bias'].clone() 64 | 65 | ps['model']['obj_class_embed.weight'] = ps['model']['class_embed.weight'].clone()[obj_ids] 66 | ps['model']['obj_class_embed.bias'] = ps['model']['class_embed.bias'].clone()[obj_ids] 67 | 68 | ps['model']['query_embed.weight'] = ps['model']['query_embed.weight'].clone()[:args.num_queries] 69 | 70 | if args.dataset == 'vcoco': 71 | l = nn.Linear(ps['model']['obj_class_embed.weight'].shape[1], 1) 72 | l.to(ps['model']['obj_class_embed.weight'].device) 73 | ps['model']['obj_class_embed.weight'] = torch.cat(( 74 | ps['model']['obj_class_embed.weight'][:-1], l.weight, ps['model']['obj_class_embed.weight'][[-1]])) 75 | ps['model']['obj_class_embed.bias'] = torch.cat( 76 | (ps['model']['obj_class_embed.bias'][:-1], l.bias, ps['model']['obj_class_embed.bias'][[-1]])) 77 | 78 | torch.save(ps, args.save_path) 79 | 80 | 81 | if __name__ == '__main__': 82 | args = get_args() 83 | main(args) 84 | -------------------------------------------------------------------------------- /tools/convert_vcoco_annotations.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from collections import defaultdict 4 | import json 5 | import pickle 6 | import os 7 | 8 | import vsrl_utils as vu 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument( 15 | '--load_path', type=str, required=True, 16 | ) 17 | parser.add_argument( 18 | '--prior_path', type=str, required=True, 19 | ) 20 | parser.add_argument( 21 | '--save_path', type=str, 
required=True, 22 | ) 23 | 24 | args = parser.parse_args() 25 | 26 | return args 27 | 28 | 29 | def set_hoi(box_annotations, hoi_annotations, verb_classes): 30 | no_object_id = -1 31 | 32 | hoia_annotations = defaultdict(lambda: { 33 | 'annotations': [], 34 | 'hoi_annotation': [] 35 | }) 36 | 37 | for action_annotation in hoi_annotations: 38 | for label, img_id, role_ids in zip(action_annotation['label'][:, 0], 39 | action_annotation['image_id'][:, 0], 40 | action_annotation['role_object_id']): 41 | hoia_annotations[img_id]['file_name'] = box_annotations[img_id]['file_name'] 42 | hoia_annotations[img_id]['annotations'] = box_annotations[img_id]['annotations'] 43 | 44 | if label == 0: 45 | continue 46 | 47 | subject_id = box_annotations[img_id]['annotation_ids'].index(role_ids[0]) 48 | 49 | if len(role_ids) == 1: 50 | hoia_annotations[img_id]['hoi_annotation'].append( 51 | {'subject_id': subject_id, 'object_id': no_object_id, 52 | 'category_id': verb_classes.index(action_annotation['action_name'])}) 53 | continue 54 | 55 | for role_name, role_id in zip(action_annotation['role_name'][1:], role_ids[1:]): 56 | if role_id == 0: 57 | object_id = no_object_id 58 | else: 59 | object_id = box_annotations[img_id]['annotation_ids'].index(role_id) 60 | 61 | hoia_annotations[img_id]['hoi_annotation'].append( 62 | {'subject_id': subject_id, 'object_id': object_id, 63 | 'category_id': verb_classes.index('{}_{}'.format(action_annotation['action_name'], role_name))}) 64 | 65 | hoia_annotations = [v for v in hoia_annotations.values()] 66 | 67 | return hoia_annotations 68 | 69 | 70 | def main(args): 71 | vsgnet_verbs_classes = { 72 | 'carry_obj': 0, 73 | 'catch_obj': 1, 74 | 'cut_instr':2, 75 | 'cut_obj': 3, 76 | 'drink_instr': 4, 77 | 'eat_instr':5, 78 | 'eat_obj': 6, 79 | 'hit_instr':7, 80 | 'hit_obj': 8, 81 | 'hold_obj': 9, 82 | 'jump_instr': 10, 83 | 'kick_obj': 11, 84 | 'lay_instr': 12, 85 | 'look_obj': 13, 86 | 'point_instr': 14, 87 | 'read_obj': 15, 88 | 'ride_instr': 16, 89 | 'run': 17, 90 | 'sit_instr': 18, 91 | 'skateboard_instr': 19, 92 | 'ski_instr': 20, 93 | 'smile': 21, 94 | 'snowboard_instr': 22, 95 | 'stand': 23, 96 | 'surf_instr': 24, 97 | 'talk_on_phone_instr': 25, 98 | 'throw_obj': 26, 99 | 'walk': 27, 100 | 'work_on_computer_instr': 28 101 | } 102 | 103 | box_annotations = defaultdict(lambda: { 104 | 'annotations': [], 105 | 'annotation_ids': [] 106 | }) 107 | 108 | coco = vu.load_coco(args.load_path) 109 | 110 | img_ids = coco.getImgIds() 111 | img_infos = coco.loadImgs(img_ids) 112 | 113 | for img_info in img_infos: 114 | box_annotations[img_info['id']]['file_name'] = img_info['file_name'] 115 | 116 | annotation_ids = coco.getAnnIds(imgIds=img_ids) 117 | annotations = coco.loadAnns(annotation_ids) 118 | for annotation in annotations: 119 | img_id = annotation['image_id'] 120 | category_id = annotation['category_id'] 121 | box = np.array(annotation['bbox']) 122 | box[2:] += box[:2] 123 | 124 | box_annotations[img_id]['annotations'].append({'category_id': category_id, 'bbox': box.tolist()}) 125 | box_annotations[img_id]['annotation_ids'].append(annotation['id']) 126 | 127 | hoi_trainval = vu.load_vcoco('vcoco_trainval') 128 | hoi_test = vu.load_vcoco('vcoco_test') 129 | 130 | action_classes = [x['action_name'] for x in hoi_trainval] 131 | verb_classes = [] 132 | for action in hoi_trainval: 133 | if len(action['role_name']) == 1: 134 | verb_classes.append(action['action_name']) 135 | else: 136 | verb_classes += ['{}_{}'.format(action['action_name'], r) for r in action['role_name'][1:]] 137 
| 138 | print('Verb class') 139 | for i, verb_class in enumerate(verb_classes): 140 | print('{:02d}: {}'.format(i, verb_class)) 141 | 142 | hoia_trainval_annotations = set_hoi(box_annotations, hoi_trainval, verb_classes) 143 | hoia_test_annotations = set_hoi(box_annotations, hoi_test, verb_classes) 144 | 145 | print('#Training images: {}, #Test images: {}'.format(len(hoia_trainval_annotations), len(hoia_test_annotations))) 146 | 147 | with open(os.path.join(args.save_path, 'trainval_vcoco.json'), 'w') as f: 148 | json.dump(hoia_trainval_annotations, f) 149 | 150 | with open(os.path.join(args.save_path, 'test_vcoco.json'), 'w') as f: 151 | json.dump(hoia_test_annotations, f) 152 | 153 | with open(args.prior_path, 'rb') as f: 154 | prior = pickle.load(f) 155 | 156 | prior = [prior[k] for k in sorted(prior.keys())] 157 | prior = np.concatenate(prior).T 158 | prior = prior[[vsgnet_verbs_classes[verb_class] for verb_class in verb_classes]] 159 | np.save(os.path.join(args.save_path, 'corre_vcoco.npy'), prior) 160 | 161 | 162 | if __name__ == '__main__': 163 | args = get_args() 164 | main(args) 165 | -------------------------------------------------------------------------------- /tools/covert_annot_for_official_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import scipy.io as sio 4 | import os 5 | 6 | 7 | def Format_Pred(pred_file): 8 | orig_file = json.load(open(pred_file, 'r')) 9 | if isinstance(orig_file, str): 10 | orig_file = eval(orig_file)['preds'] 11 | out_pred = {} 12 | for annot in orig_file: 13 | annot_bbox = annot['predictions'] 14 | annot_hoi = annot['hoi_prediction'] 15 | img_id = int((annot['filename'].split('.')[0]).split('_')[-1]) 16 | for hoi in annot_hoi: 17 | sub_bbox = annot_bbox[hoi['subject_id']] 18 | obj_bbox = annot_bbox[hoi['object_id']] 19 | hoi_cls = int(hoi['category_id'] + 1) 20 | score = hoi['score'] 21 | this_out = {'img_id': img_id, 'human_box': sub_bbox['bbox'], 'object_box': obj_bbox['bbox'], 'score': score} 22 | if hoi_cls not in out_pred.keys(): 23 | out_pred[hoi_cls] = [] 24 | out_pred[hoi_cls].append(this_out) 25 | return out_pred 26 | 27 | 28 | def save_HICO(HICO, HICO_dir, classid, begin, finish): 29 | all_boxes = [] 30 | for i in range(begin, finish + 1): 31 | total = [] 32 | score = [] 33 | if i in HICO.keys(): 34 | for element in HICO[i]: 35 | temp = [] 36 | temp.append(element['human_box']) # Human box 37 | temp.append(element['object_box']) # Object box 38 | temp.append(element['img_id']) # image id 39 | temp.append(int(i - begin)) # action id (0-599) 40 | temp.append(element['score'] * 1000) 41 | total.append(temp) 42 | score.append(element['score'] * 1000) 43 | 44 | idx = np.argsort(score, axis=0)[::-1] 45 | for i_idx in range(min(len(idx), 19999)): 46 | all_boxes.append(total[idx[i_idx]]) 47 | else: 48 | print(i) 49 | savefile = HICO_dir + 'detections_' + str(classid).zfill(2) + '.mat' 50 | sio.savemat(savefile, {'all_boxes': all_boxes}) 51 | 52 | def Generate_HICO_detection(output_file, HICO_dir): 53 | if not os.path.exists(HICO_dir): 54 | os.makedirs(HICO_dir) 55 | 56 | # Remove previous results 57 | filelist = [f for f in os.listdir(HICO_dir)] 58 | for f in filelist: 59 | os.remove(os.path.join(HICO_dir, f)) 60 | 61 | HICO = Format_Pred(output_file) 62 | 63 | save_HICO(HICO, HICO_dir, 1, 161, 170) # 1 person 64 | save_HICO(HICO, HICO_dir, 2, 11, 24) # 2 bicycle 65 | save_HICO(HICO, HICO_dir, 3, 66, 76) # 3 car 66 | save_HICO(HICO, HICO_dir, 4, 147, 160) # 4 
motorcycle 67 | save_HICO(HICO, HICO_dir, 5, 1, 10) # 5 airplane 68 | save_HICO(HICO, HICO_dir, 6, 55, 65) # 6 bus 69 | save_HICO(HICO, HICO_dir, 7, 187, 194) # 7 train 70 | save_HICO(HICO, HICO_dir, 8, 568, 576) # 8 truck 71 | save_HICO(HICO, HICO_dir, 9, 32, 46) # 9 boat 72 | save_HICO(HICO, HICO_dir, 10, 563, 567) # 10 traffic light 73 | save_HICO(HICO, HICO_dir, 11, 326, 330) # 11 fire_hydrant 74 | save_HICO(HICO, HICO_dir, 12, 503, 506) # 12 stop_sign 75 | save_HICO(HICO, HICO_dir, 13, 415, 418) # 13 parking_meter 76 | save_HICO(HICO, HICO_dir, 14, 244, 247) # 14 bench 77 | save_HICO(HICO, HICO_dir, 15, 25, 31) # 15 bird 78 | save_HICO(HICO, HICO_dir, 16, 77, 86) # 16 cat 79 | save_HICO(HICO, HICO_dir, 17, 112, 129) # 17 dog 80 | save_HICO(HICO, HICO_dir, 18, 130, 146) # 18 horse 81 | save_HICO(HICO, HICO_dir, 19, 175, 186) # 19 sheep 82 | save_HICO(HICO, HICO_dir, 20, 97, 107) # 20 cow 83 | save_HICO(HICO, HICO_dir, 21, 314, 325) # 21 elephant 84 | save_HICO(HICO, HICO_dir, 22, 236, 239) # 22 bear 85 | save_HICO(HICO, HICO_dir, 23, 596, 600) # 23 zebra 86 | save_HICO(HICO, HICO_dir, 24, 343, 348) # 24 giraffe 87 | save_HICO(HICO, HICO_dir, 25, 209, 214) # 25 backpack 88 | save_HICO(HICO, HICO_dir, 26, 577, 584) # 26 umbrella 89 | save_HICO(HICO, HICO_dir, 27, 353, 356) # 27 handbag 90 | save_HICO(HICO, HICO_dir, 28, 539, 546) # 28 tie 91 | save_HICO(HICO, HICO_dir, 29, 507, 516) # 29 suitcase 92 | save_HICO(HICO, HICO_dir, 30, 337, 342) # 30 Frisbee 93 | save_HICO(HICO, HICO_dir, 31, 464, 474) # 31 skis 94 | save_HICO(HICO, HICO_dir, 32, 475, 483) # 32 snowboard 95 | save_HICO(HICO, HICO_dir, 33, 489, 502) # 33 sports_ball 96 | save_HICO(HICO, HICO_dir, 34, 369, 376) # 34 kite 97 | save_HICO(HICO, HICO_dir, 35, 225, 232) # 35 baseball_bat 98 | save_HICO(HICO, HICO_dir, 36, 233, 235) # 36 baseball_glove 99 | save_HICO(HICO, HICO_dir, 37, 454, 463) # 37 skateboard 100 | save_HICO(HICO, HICO_dir, 38, 517, 528) # 38 surfboard 101 | save_HICO(HICO, HICO_dir, 39, 534, 538) # 39 tennis_racket 102 | save_HICO(HICO, HICO_dir, 40, 47, 54) # 40 bottle 103 | save_HICO(HICO, HICO_dir, 41, 589, 595) # 41 wine_glass 104 | save_HICO(HICO, HICO_dir, 42, 296, 305) # 42 cup 105 | save_HICO(HICO, HICO_dir, 43, 331, 336) # 43 fork 106 | save_HICO(HICO, HICO_dir, 44, 377, 383) # 44 knife 107 | save_HICO(HICO, HICO_dir, 45, 484, 488) # 45 spoon 108 | save_HICO(HICO, HICO_dir, 46, 253, 257) # 46 bowl 109 | save_HICO(HICO, HICO_dir, 47, 215, 224) # 47 banana 110 | save_HICO(HICO, HICO_dir, 48, 199, 208) # 48 apple 111 | save_HICO(HICO, HICO_dir, 49, 439, 445) # 49 sandwich 112 | save_HICO(HICO, HICO_dir, 50, 398, 407) # 50 orange 113 | save_HICO(HICO, HICO_dir, 51, 258, 264) # 51 broccoli 114 | save_HICO(HICO, HICO_dir, 52, 274, 283) # 52 carrot 115 | save_HICO(HICO, HICO_dir, 53, 357, 363) # 53 hot_dog 116 | save_HICO(HICO, HICO_dir, 54, 419, 429) # 54 pizza 117 | save_HICO(HICO, HICO_dir, 55, 306, 313) # 55 donut 118 | save_HICO(HICO, HICO_dir, 56, 265, 273) # 56 cake 119 | save_HICO(HICO, HICO_dir, 57, 87, 92) # 57 chair 120 | save_HICO(HICO, HICO_dir, 58, 93, 96) # 58 couch 121 | save_HICO(HICO, HICO_dir, 59, 171, 174) # 59 potted_plant 122 | save_HICO(HICO, HICO_dir, 60, 240, 243) # 60 bed 123 | save_HICO(HICO, HICO_dir, 61, 108, 111) # 61 dining_table 124 | save_HICO(HICO, HICO_dir, 62, 551, 558) # 62 toilet 125 | save_HICO(HICO, HICO_dir, 63, 195, 198) # 63 TV 126 | save_HICO(HICO, HICO_dir, 64, 384, 389) # 64 laptop 127 | save_HICO(HICO, HICO_dir, 65, 394, 397) # 65 mouse 128 | save_HICO(HICO, 
HICO_dir, 66, 435, 438) # 66 remote 129 | save_HICO(HICO, HICO_dir, 67, 364, 368) # 67 keyboard 130 | save_HICO(HICO, HICO_dir, 68, 284, 290) # 68 cell_phone 131 | save_HICO(HICO, HICO_dir, 69, 390, 393) # 69 microwave 132 | save_HICO(HICO, HICO_dir, 70, 408, 414) # 70 oven 133 | save_HICO(HICO, HICO_dir, 71, 547, 550) # 71 toaster 134 | save_HICO(HICO, HICO_dir, 72, 450, 453) # 72 sink 135 | save_HICO(HICO, HICO_dir, 73, 430, 434) # 73 refrigerator 136 | save_HICO(HICO, HICO_dir, 74, 248, 252) # 74 book 137 | save_HICO(HICO, HICO_dir, 75, 291, 295) # 75 clock 138 | save_HICO(HICO, HICO_dir, 76, 585, 588) # 76 vase 139 | save_HICO(HICO, HICO_dir, 77, 446, 449) # 77 scissors 140 | save_HICO(HICO, HICO_dir, 78, 529, 533) # 78 teddy_bear 141 | save_HICO(HICO, HICO_dir, 79, 349, 352) # 79 hair_drier 142 | save_HICO(HICO, HICO_dir, 80, 559, 562) # 80 toothbrush 143 | 144 | 145 | if __name__ == '__main__': 146 | Generate_HICO_detection('./results.json', './ppdm_results/') 147 | -------------------------------------------------------------------------------- /util/box_ops.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) Hitachi, Ltd. All Rights Reserved. 3 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | """ 9 | Utilities for bounding box manipulation and GIoU. 10 | """ 11 | import torch 12 | from torchvision.ops.boxes import box_area 13 | 14 | 15 | def box_cxcywh_to_xyxy(x): 16 | x_c, y_c, w, h = x.unbind(-1) 17 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 18 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 19 | return torch.stack(b, dim=-1) 20 | 21 | 22 | def box_xyxy_to_cxcywh(x): 23 | x0, y0, x1, y1 = x.unbind(-1) 24 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 25 | (x1 - x0), (y1 - y0)] 26 | return torch.stack(b, dim=-1) 27 | 28 | 29 | # modified from torchvision to also return the union 30 | def box_iou(boxes1, boxes2): 31 | area1 = box_area(boxes1) 32 | area2 = box_area(boxes2) 33 | 34 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 35 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 36 | 37 | wh = (rb - lt).clamp(min=0) # [N,M,2] 38 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 39 | 40 | union = area1[:, None] + area2 - inter 41 | 42 | iou = inter / union 43 | return iou, union 44 | 45 | 46 | def generalized_box_iou(boxes1, boxes2): 47 | """ 48 | Generalized IoU from https://giou.stanford.edu/ 49 | 50 | The boxes should be in [x0, y0, x1, y1] format 51 | 52 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 53 | and M = len(boxes2) 54 | """ 55 | # degenerate boxes gives inf / nan results 56 | # so do an early check 57 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 58 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 59 | iou, union = box_iou(boxes1, boxes2) 60 | 61 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 62 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 63 | 64 | wh = (rb - lt).clamp(min=0) # [N,M,2] 65 | area = wh[:, :, 0] * wh[:, :, 1] 66 | 67 | return iou - (area - union) / area 68 | 69 | 70 | def masks_to_boxes(masks): 71 | """Compute the bounding boxes around the provided masks 72 | 73 | The masks 
should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 74 | 75 | Returns a [N, 4] tensors, with the boxes in xyxy format 76 | """ 77 | if masks.numel() == 0: 78 | return torch.zeros((0, 4), device=masks.device) 79 | 80 | h, w = masks.shape[-2:] 81 | 82 | y = torch.arange(0, h, dtype=torch.float) 83 | x = torch.arange(0, w, dtype=torch.float) 84 | y, x = torch.meshgrid(y, x) 85 | 86 | x_mask = (masks * x.unsqueeze(0)) 87 | x_max = x_mask.flatten(1).max(-1)[0] 88 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 89 | 90 | y_mask = (masks * y.unsqueeze(0)) 91 | y_max = y_mask.flatten(1).max(-1)[0] 92 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 93 | 94 | return torch.stack([x_min, y_min, x_max, y_max], 1) 95 | -------------------------------------------------------------------------------- /util/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | helper for logging 6 | NOTE: loggers are global objects use with caution 7 | """ 8 | import logging 9 | import math 10 | 11 | import tensorboardX 12 | 13 | 14 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 15 | _DATE_FMT = '%m/%d/%Y %H:%M:%S' 16 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO) 17 | LOGGER = logging.getLogger('__main__') # this is the global logger 18 | 19 | 20 | def add_log_to_file(log_path): 21 | fh = logging.FileHandler(log_path) 22 | formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT) 23 | fh.setFormatter(formatter) 24 | LOGGER.addHandler(fh) 25 | 26 | 27 | class TensorboardLogger(object): 28 | def __init__(self): 29 | self._logger = None 30 | self._global_step = 0 31 | 32 | def create(self, path): 33 | self._logger = tensorboardX.SummaryWriter(path) 34 | 35 | def noop(self, *args, **kwargs): 36 | return 37 | 38 | def step(self): 39 | self._global_step += 1 40 | 41 | @property 42 | def global_step(self): 43 | return self._global_step 44 | 45 | def log_scaler_dict(self, log_dict, prefix=''): 46 | """ log a dictionary of scalar values""" 47 | if self._logger is None: 48 | return 49 | if prefix: 50 | prefix = f'{prefix}_' 51 | for name, value in log_dict.items(): 52 | if isinstance(value, dict): 53 | self.log_scaler_dict(value, self._global_step, prefix=f'{prefix}{name}') 54 | else: 55 | self._logger.add_scalar(f'{prefix}{name}', value, self._global_step) 56 | 57 | def __getattr__(self, name): 58 | if self._logger is None: 59 | return self.noop 60 | return self._logger.__getattribute__(name) 61 | 62 | 63 | TB_LOGGER = TensorboardLogger() 64 | 65 | 66 | class RunningMeter(object): 67 | """ running meteor of a scalar value 68 | (useful for monitoring training loss) 69 | """ 70 | def __init__(self, name, val=None, smooth=0.99): 71 | self._name = name 72 | self._sm = smooth 73 | self._val = val 74 | 75 | def __call__(self, value): 76 | val = (value if self._val is None else value*(1-self._sm) + self._val*self._sm) 77 | if not math.isnan(val) and not math.isinf(val): 78 | self._val = val 79 | 80 | def __str__(self): 81 | return f'{self._name}: {self._val:.4f}' 82 | 83 | @property 84 | def val(self): 85 | if self._val is None: 86 | return 0 87 | return self._val 88 | 89 | @property 90 | def name(self): 91 | return self._name -------------------------------------------------------------------------------- /util/scheduler.py: 
-------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingWarmRestarts, CosineAnnealingLR 6 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 7 | from PIL import Image 8 | import os 9 | import torch.distributed as dist 10 | import warnings 11 | import math 12 | 13 | 14 | class MultiStepLRWarmup(MultiStepLR): 15 | 16 | def __init__(self, *args, **kwargs): 17 | self.warmup_iter = kwargs['warmup_iter'] 18 | self.cur_iter = 0 19 | self.warmup_ratio = kwargs['warmup_ratio'] 20 | self.init_lr = None 21 | del kwargs['warmup_iter'] 22 | del kwargs['warmup_ratio'] 23 | super(MultiStepLRWarmup, self).__init__(*args, **kwargs) 24 | self.init_lr = [group['lr'] for group in self.optimizer.param_groups] 25 | 26 | def iter_step(self): 27 | self.cur_iter += 1 28 | if self.cur_iter <= self.warmup_iter and self.init_lr: 29 | values = [lr * (self.warmup_ratio + (1 - self.warmup_ratio) * (self.cur_iter / self.warmup_iter)) 30 | for lr in self.init_lr] 31 | for i, data in enumerate(zip(self.optimizer.param_groups, values)): 32 | param_group, lr = data 33 | param_group['lr'] = lr 34 | self.print_lr(self.verbose, i, lr, 0) 35 | 36 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups] 37 | 38 | 39 | class CosineAnnealingLRWarmup(CosineAnnealingLR): 40 | 41 | def __init__(self, *args, **kwargs): 42 | self.warmup_iter = kwargs['warmup_iter'] 43 | self.cur_iter = 0 44 | self.warmup_ratio = kwargs['warmup_ratio'] 45 | self.init_lr = None 46 | del kwargs['warmup_iter'] 47 | del kwargs['warmup_ratio'] 48 | super(CosineAnnealingLRWarmup, self).__init__(*args, **kwargs) 49 | self.init_lr = [group['lr'] for group in self.optimizer.param_groups] 50 | 51 | def iter_step(self): 52 | self.cur_iter += 1 53 | if self.cur_iter <= self.warmup_iter and self.init_lr: 54 | values = [lr * (self.warmup_ratio + (1 - self.warmup_ratio) * (self.cur_iter / self.warmup_iter)) 55 | for lr in self.init_lr] 56 | for i, data in enumerate(zip(self.optimizer.param_groups, values)): 57 | param_group, lr = data 58 | param_group['lr'] = lr 59 | self.print_lr(self.verbose, i, lr, 0) 60 | 61 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups] 62 | 63 | def get_lr(self): 64 | if not self._get_lr_called_within_step: 65 | warnings.warn("To get the last learning rate computed by the scheduler, " 66 | "please use `get_last_lr()`.", UserWarning) 67 | 68 | if self.last_epoch == 0: 69 | return self.base_lrs 70 | elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: 71 | return [group['lr'] + (base_lr - self.eta_min) * 72 | (1 - math.cos(math.pi / self.T_max)) / 2 73 | for base_lr, group in 74 | zip(self.base_lrs, self.optimizer.param_groups)] 75 | return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 76 | (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * 77 | (group['lr'] - init_lr * self.eta_min) + init_lr * self.eta_min 78 | for init_lr, group in zip(self.init_lr, self.optimizer.param_groups)] 79 | -------------------------------------------------------------------------------- /util/topk.py: -------------------------------------------------------------------------------- 1 | def sift(li, low, higt): 2 | tmp = li[low] 3 | i = low 4 | j = 2 * i + 1 5 | while j <= higt: # 情况2:i已经是最后一层 6 | if j + 1 <= higt and li[j + 1] < li[j]: # 右孩子存在并且小于左孩子 7 | j += 1 8 | if tmp 
> li[j]: 9 | li[i] = li[j] 10 | i = j 11 | j = 2 * i + 1 12 | else: 13 | break # Case 1: tmp is not larger than li[j]; stop sifting down 14 | li[i] = tmp 15 | 16 | def top_k(li, k): 17 | heap = li[0:k] 18 | # build a min-heap from the first k elements 19 | for i in range(k // 2 - 1, -1, -1): 20 | sift(heap, i, k - 1) 21 | for i in range(k, len(li)): 22 | if li[i] > heap[0]: 23 | heap[0] = li[i] 24 | sift(heap, 0, k - 1) 25 | # pop the heap one element at a time so the k largest values come out in descending order 26 | for i in range(k - 1, -1, -1): 27 | heap[0], heap[i] = heap[i], heap[0] 28 | sift(heap, 0, i - 1) 29 | 30 | return heap 31 | 32 | --------------------------------------------------------------------------------
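
For readers skimming util/topk.py above, the short sketch below shows how its heap-based top_k is meant to be called; sorted() is used only as a cross-check. It assumes the repository root is on the import path, and the score list is invented purely for illustration.

# Minimal usage sketch for util/topk.py (not part of the repository).
# Assumption: the repository root is on PYTHONPATH so `util` is importable.
from util.topk import top_k

scores = [0.12, 0.95, 0.33, 0.78, 0.51, 0.07, 0.88]  # made-up confidence scores
result = top_k(scores, 3)
print(result)  # [0.95, 0.88, 0.78] -- the k largest values, largest first
assert result == sorted(scores, reverse=True)[:3]  # agrees with a plain sort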