├── LICENSE ├── README.md ├── assets └── problem_setting.gif ├── data ├── datatest.py └── datatrain.py ├── models ├── dino │ ├── utils.py │ └── vision_transformer.py ├── locate.py └── model_util.py ├── requirements.txt ├── test.py ├── train.py └── utils ├── evaluation.py ├── util.py └── viz.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Gen Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LOCATE: Localize and Transfer Object Parts for Weakly Supervised Affordance Grounding 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2303.09665-b31b1b.svg)](https://arxiv.org/abs/2303.09665) 4 | [![GitHub](https://img.shields.io/website?label=Project%20Page&up_message=page&url=https://reagan1311.github.io/locate/)](https://reagan1311.github.io/locate/) 5 | [![ ](https://img.shields.io/youtube/views/RLHansdFxII?label=Video&style=flat)](https://www.youtube.com/watch?v=RLHansdFxII) 6 | 7 | Official pytorch implementation of our CVPR 2023 paper - LOCATE: Localize and Transfer Object Parts for Weakly 8 | Supervised Affordance Grounding. 9 | 10 | ## Abstract 11 | 12 | Humans excel at acquiring knowledge through observation. For example, we can learn to use new tools by watching 13 | demonstrations. This skill is fundamental for intelligent systems to interact with the world. A key step to acquire this 14 | skill is to identify what part of the object affords each action, which is called affordance grounding. In this paper, 15 | we address this problem and propose a framework called LOCATE that can identify matching object parts across images, to 16 | transfer knowledge from images where an object is being used (exocentric images used for learning), to images where the 17 | object is inactive (egocentric ones used to test). To this end, we first find interaction areas and extract their 18 | feature embeddings. Then we learn to aggregate the embeddings into compact prototypes (human, object part, and 19 | background), and select the one representing the object part. Finally, we use the selected prototype to guide affordance 20 | grounding. We do this in a weakly supervised manner, learning only from image-level affordance and object labels. 
21 | Extensive experiments demonstrate that our approach outperforms state-of-the-art methods by a large margin on both seen 22 | and unseen objects. 23 | 24 |
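In code, each training step pairs a few exocentric (interaction) images with one egocentric image of the same object and is supervised only by the image-level affordance label. The sketch below shows roughly how the pieces defined in `data/datatrain.py` and `models/locate.py` fit together; the optimizer, number of epochs, and loss weights (`w_proto`, `w_con`) are illustrative assumptions rather than the exact settings used in `train.py`.

```
# Minimal training-step sketch (illustrative; see train.py for the actual loop).
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from data.datatrain import TrainData
from models.locate import Net

dataset = TrainData(exocentric_root='AGD20K/Seen/trainset/exocentric',
                    egocentric_root='AGD20K/Seen/trainset/egocentric', divide='Seen')
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=8)
model = Net(aff_classes=36).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)  # placeholder settings
w_proto, w_con = 1.0, 1.0   # placeholder loss weights
warmup_epoch = 0            # prototype guidance is enabled once epoch > warmup_epoch

for epoch in range(20):
    for exo, ego, aff_label in loader:  # exo: B x num_exo x 3 x 224 x 224
        exo, ego, aff_label = exo.cuda(), ego.cuda(), aff_label.cuda()
        masks, logits, loss_proto, loss_con = model(exo, ego, aff_label, (epoch, warmup_epoch))
        # image-level affordance classification on both views, plus the
        # prototype-guidance and concentration terms returned by the model
        loss_cls_exo = F.cross_entropy(logits['aff'].flatten(0, 1),
                                       aff_label.repeat_interleave(exo.shape[1]))
        loss_cls_ego = F.cross_entropy(logits['aff_ego'], aff_label)
        loss = loss_cls_exo + loss_cls_ego + w_proto * loss_proto + w_con * loss_con
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```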

25 | 26 |
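The data loaders in `data/datatrain.py` and `data/datatest.py` expect the AGD20K images to be grouped first by affordance and then by object category, with a PNG ground-truth map per egocentric test image. The layout below reflects the released AGD20K archive as assumed here; adjust the top-level folder names if your local copy differs.

```
AGD20K/
└── Seen/                                   # or Unseen/
    ├── trainset/
    │   ├── exocentric/<affordance>/<object>/xxx.jpg
    │   └── egocentric/<affordance>/<object>/xxx.jpg
    └── testset/
        ├── egocentric/<affordance>/<object>/xxx.jpg
        └── GT/<affordance>/<object>/xxx.png   # ground-truth maps read via mask_root
```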

27 | 28 | ## Usage 29 | 30 | ### 1. Requirements 31 | 32 | Code is tested under PyTorch 1.12.1, Python 3.7, and CUDA 11.6. 33 | 34 | ``` 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | ### 2. Dataset 39 | 40 | Download the AGD20K dataset 41 | from [ [Google Drive](https://drive.google.com/file/d/1OEz25-u1uqKfeuyCqy7hmiOv7lIWfigk/view?usp=sharing) | [Baidu Pan](https://pan.baidu.com/s/1IRfho7xDAT0oJi5_mvP1sg) (g23n) ] 42 | . 43 | 44 | ### 3. Train and Test 45 | 46 | Our pretrained model can be downloaded 47 | from [Google Drive](https://drive.google.com/drive/folders/1-AcTiE9Lz91bPJlp1o-ubgkxKnudohdx?usp=sharing). Run the following commands to start training or testing (replace the bracketed placeholders with your own paths): 48 | 49 | ``` 50 | python train.py --data_root <path_to_AGD20K> 51 | python test.py --data_root <path_to_AGD20K> --model_file <path_to_checkpoint> 52 | ``` 53 | 54 | ## Citation 55 | 56 | ``` 57 | @inproceedings{li:locate:2023, 58 | title = {LOCATE: Localize and Transfer Object Parts for Weakly Supervised Affordance Grounding}, 59 | author = {Li, Gen and Jampani, Varun and Sun, Deqing and Sevilla-Lara, Laura}, 60 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 61 | year={2023} 62 | } 63 | ``` 64 | 65 | ## Acknowledgement 66 | 67 | This repo is based on [Cross-View-AG](https://github.com/lhc1224/Cross-View-AG) 68 | , [dino-vit-features](https://github.com/ShirAmir/dino-vit-features), 69 | and [dino](https://github.com/facebookresearch/dino). Thanks for their great work! 70 | -------------------------------------------------------------------------------- /assets/problem_setting.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Reagan1311/LOCATE/8db3d015809b8b80cd8f1173b78f84686d77c3c0/assets/problem_setting.gif -------------------------------------------------------------------------------- /data/datatest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data 3 | from torchvision import transforms 4 | from PIL import Image 5 | 6 | 7 | class TestData(data.Dataset): 8 | def __init__(self, image_root, crop_size=224, divide="Seen", mask_root=None): 9 | self.image_root = image_root 10 | self.image_list = [] 11 | self.crop_size = crop_size 12 | self.mask_root = mask_root 13 | if divide == "Seen": 14 | self.aff_list = ['beat', "boxing", "brush_with", "carry", "catch", 15 | "cut", "cut_with", "drag", 'drink_with', "eat", 16 | "hit", "hold", "jump", "kick", "lie_on", "lift", 17 | "look_out", "open", "pack", "peel", "pick_up", 18 | "pour", "push", "ride", "sip", "sit_on", "stick", 19 | "stir", "swing", "take_photo", "talk_on", "text_on", 20 | "throw", "type_on", "wash", "write"] 21 | self.obj_list = ['apple', 'axe', 'badminton_racket', 'banana', 'baseball', 'baseball_bat', 22 | 'basketball', 'bed', 'bench', 'bicycle', 'binoculars', 'book', 'bottle', 23 | 'bowl', 'broccoli', 'camera', 'carrot', 'cell_phone', 'chair', 'couch', 24 | 'cup', 'discus', 'drum', 'fork', 'frisbee', 'golf_clubs', 'hammer', 'hot_dog', 25 | 'javelin', 'keyboard', 'knife', 'laptop', 'microwave', 'motorcycle', 'orange', 26 | 'oven', 'pen', 'punching_bag', 'refrigerator', 'rugby_ball', 'scissors', 27 | 'skateboard', 'skis', 'snowboard', 'soccer_ball', 'suitcase', 'surfboard', 28 | 'tennis_racket', 'toothbrush', 'wine_glass'] 29 | else: 30 | self.aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with', 31 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel", 32 | "pick_up", "pour", "push", "ride", "sip",
"sit_on", "stick", 33 | "swing", "take_photo", "throw", "type_on", "wash"] 34 | self.obj_list = ['apple', 'axe', 'badminton_racket', 'banana', 'baseball', 'baseball_bat', 35 | 'basketball', 'bed', 'bench', 'bicycle', 'binoculars', 'book', 'bottle', 36 | 'bowl', 'broccoli', 'camera', 'carrot', 'cell_phone', 'chair', 'couch', 37 | 'cup', 'discus', 'drum', 'fork', 'frisbee', 'golf_clubs', 'hammer', 'hot_dog', 38 | 'javelin', 'keyboard', 'knife', 'laptop', 'microwave', 'motorcycle', 'orange', 39 | 'oven', 'pen', 'punching_bag', 'refrigerator', 'rugby_ball', 'scissors', 40 | 'skateboard', 'skis', 'snowboard', 'soccer_ball', 'suitcase', 'surfboard', 41 | 'tennis_racket', 'toothbrush', 'wine_glass'] 42 | 43 | self.transform = transforms.Compose([ 44 | transforms.Resize((crop_size, crop_size)), 45 | transforms.ToTensor(), 46 | transforms.Normalize(mean=(0.485, 0.456, 0.406), 47 | std=(0.229, 0.224, 0.225))]) 48 | 49 | files = os.listdir(self.image_root) 50 | for file in files: 51 | file_path = os.path.join(self.image_root, file) 52 | obj_files = os.listdir(file_path) 53 | for obj_file in obj_files: 54 | obj_file_path = os.path.join(file_path, obj_file) 55 | images = os.listdir(obj_file_path) 56 | for img in images: 57 | img_path = os.path.join(obj_file_path, img) 58 | mask_path = os.path.join(self.mask_root, file, obj_file, img[:-3] + "png") 59 | 60 | if os.path.exists(mask_path): 61 | self.image_list.append(img_path) 62 | # print(self.image_list) 63 | 64 | self.aff2obj_dict = dict() 65 | for aff in self.aff_list: 66 | aff_path = os.path.join(self.image_root, aff) 67 | aff_obj_list = os.listdir(aff_path) 68 | self.aff2obj_dict.update({aff: aff_obj_list}) 69 | 70 | self.obj2aff_dict = dict() 71 | for obj in self.obj_list: 72 | obj2aff_list = [] 73 | for k, v in self.aff2obj_dict.items(): 74 | if obj in v: 75 | obj2aff_list.append(k) 76 | for i in range(len(obj2aff_list)): 77 | obj2aff_list[i] = self.aff_list.index(obj2aff_list[i]) 78 | self.obj2aff_dict.update({obj: obj2aff_list}) 79 | 80 | def __getitem__(self, item): 81 | 82 | image_path = self.image_list[item] 83 | names = image_path.split("/") 84 | aff_name, object = names[-3], names[-2] 85 | 86 | image = self.load_img(image_path) 87 | label = self.aff_list.index(aff_name) 88 | names = image_path.split("/") 89 | mask_path = os.path.join(self.mask_root, names[-3], names[-2], names[-1][:-3] + "png") 90 | 91 | return image, label, mask_path 92 | 93 | def load_img(self, path): 94 | img = Image.open(path).convert('RGB') 95 | img = self.transform(img) 96 | return img 97 | 98 | def __len__(self): 99 | 100 | return len(self.image_list) 101 | -------------------------------------------------------------------------------- /data/datatrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | from PIL import Image 5 | from torch.utils import data 6 | from torchvision import transforms 7 | 8 | 9 | class TrainData(data.Dataset): 10 | def __init__(self, exocentric_root, egocentric_root, resize_size=256, crop_size=224, divide="Seen"): 11 | 12 | self.exocentric_root = exocentric_root 13 | self.egocentric_root = egocentric_root 14 | 15 | self.image_list = [] 16 | self.exo_image_list = [] 17 | self.resize_size = resize_size 18 | self.crop_size = crop_size 19 | if divide == "Seen": 20 | self.aff_list = ['beat', "boxing", "brush_with", "carry", "catch", "cut", "cut_with", "drag", 'drink_with', 21 | "eat", "hit", "hold", "jump", "kick", "lie_on", "lift", "look_out", "open", "pack", 
"peel", 22 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", "stir", "swing", "take_photo", 23 | "talk_on", "text_on", "throw", "type_on", "wash", "write"] 24 | self.obj_list = ['apple', 'axe', 'badminton_racket', 'banana', 'baseball', 'baseball_bat', 25 | 'basketball', 'bed', 'bench', 'bicycle', 'binoculars', 'book', 'bottle', 26 | 'bowl', 'broccoli', 'camera', 'carrot', 'cell_phone', 'chair', 'couch', 27 | 'cup', 'discus', 'drum', 'fork', 'frisbee', 'golf_clubs', 'hammer', 'hot_dog', 28 | 'javelin', 'keyboard', 'knife', 'laptop', 'microwave', 'motorcycle', 'orange', 29 | 'oven', 'pen', 'punching_bag', 'refrigerator', 'rugby_ball', 'scissors', 30 | 'skateboard', 'skis', 'snowboard', 'soccer_ball', 'suitcase', 'surfboard', 31 | 'tennis_racket', 'toothbrush', 'wine_glass'] 32 | else: 33 | self.aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with', 34 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel", 35 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", 36 | "swing", "take_photo", "throw", "type_on", "wash"] 37 | self.obj_list = ['apple', 'axe', 'badminton_racket', 'banana', 'baseball', 'baseball_bat', 38 | 'basketball', 'bed', 'bench', 'bicycle', 'binoculars', 'book', 'bottle', 39 | 'bowl', 'broccoli', 'camera', 'carrot', 'cell_phone', 'chair', 'couch', 40 | 'cup', 'discus', 'drum', 'fork', 'frisbee', 'golf_clubs', 'hammer', 'hot_dog', 41 | 'javelin', 'keyboard', 'knife', 'laptop', 'microwave', 'motorcycle', 'orange', 42 | 'oven', 'pen', 'punching_bag', 'refrigerator', 'rugby_ball', 'scissors', 43 | 'skateboard', 'skis', 'snowboard', 'soccer_ball', 'suitcase', 'surfboard', 44 | 'tennis_racket', 'toothbrush', 'wine_glass'] 45 | 46 | self.transform = transforms.Compose([ 47 | transforms.Resize(resize_size), 48 | transforms.RandomCrop(crop_size), 49 | transforms.RandomHorizontalFlip(), 50 | transforms.ToTensor(), 51 | transforms.Normalize(mean=(0.485, 0.456, 0.406), 52 | std=(0.229, 0.224, 0.225))]) 53 | 54 | # image list for egocentric images 55 | files = os.listdir(self.exocentric_root) 56 | for file in files: 57 | file_path = os.path.join(self.exocentric_root, file) 58 | obj_files = os.listdir(file_path) 59 | for obj_file in obj_files: 60 | obj_file_path = os.path.join(file_path, obj_file) 61 | images = os.listdir(obj_file_path) 62 | for img in images: 63 | img_path = os.path.join(obj_file_path, img) 64 | self.image_list.append(img_path) 65 | 66 | # multiple affordance labels for exo-centric samples 67 | 68 | def __getitem__(self, item): 69 | 70 | # load egocentric image 71 | exocentric_image_path = self.image_list[item] 72 | names = exocentric_image_path.split("/") 73 | aff_name, object = names[-3], names[-2] 74 | exocentric_image = self.load_img(exocentric_image_path) 75 | aff_label = self.aff_list.index(aff_name) 76 | 77 | ego_path = os.path.join(self.egocentric_root, aff_name, object) 78 | obj_images = os.listdir(ego_path) 79 | idx = random.randint(0, len(obj_images) - 1) 80 | egocentric_image_path = os.path.join(ego_path, obj_images[idx]) 81 | egocentric_image = self.load_img(egocentric_image_path) 82 | 83 | # pick one available affordance, and then choose & load exo-centric images 84 | num_exo = 3 85 | exo_dir = os.path.dirname(exocentric_image_path) 86 | exocentrics = os.listdir(exo_dir) 87 | exo_img_name = [os.path.basename(exocentric_image_path)] 88 | exocentric_images = [exocentric_image] 89 | # exocentric_labels = [] 90 | 91 | if len(exocentrics) > num_exo: 92 | for i in range(num_exo - 1): 93 | exo_img_ = 
random.choice(exocentrics) 94 | while exo_img_ in exo_img_name: 95 | exo_img_ = random.choice(exocentrics) 96 | exo_img_name.append(exo_img_) 97 | tmp_exo = self.load_img(os.path.join(exo_dir, exo_img_)) 98 | exocentric_images.append(tmp_exo) 99 | else: 100 | for i in range(num_exo - 1): 101 | exo_img_ = random.choice(exocentrics) 102 | # while exo_img_ in exo_img_name: 103 | # exo_img_ = random.choice(exocentrics) 104 | exo_img_name.append(exo_img_) 105 | tmp_exo = self.load_img(os.path.join(exo_dir, exo_img_)) 106 | exocentric_images.append(tmp_exo) 107 | 108 | exocentric_images = torch.stack(exocentric_images, dim=0) # n x 3 x 224 x 224 109 | 110 | return exocentric_images, egocentric_image, aff_label 111 | 112 | def load_img(self, path): 113 | img = Image.open(path).convert('RGB') 114 | img = self.transform(img) 115 | return img 116 | 117 | def __len__(self): 118 | 119 | return len(self.image_list) 120 | -------------------------------------------------------------------------------- /models/dino/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Misc functions. 16 | 17 | Mostly copy-paste from torchvision references or other public repos like DETR: 18 | https://github.com/facebookresearch/detr/blob/master/util/misc.py 19 | """ 20 | import os 21 | import sys 22 | import time 23 | import math 24 | import random 25 | import datetime 26 | import subprocess 27 | from collections import defaultdict, deque 28 | 29 | import numpy as np 30 | import torch 31 | from torch import nn 32 | import torch.distributed as dist 33 | from PIL import ImageFilter, ImageOps 34 | 35 | 36 | class GaussianBlur(object): 37 | """ 38 | Apply Gaussian Blur to the PIL image. 39 | """ 40 | 41 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): 42 | self.prob = p 43 | self.radius_min = radius_min 44 | self.radius_max = radius_max 45 | 46 | def __call__(self, img): 47 | do_it = random.random() <= self.prob 48 | if not do_it: 49 | return img 50 | 51 | return img.filter( 52 | ImageFilter.GaussianBlur( 53 | radius=random.uniform(self.radius_min, self.radius_max) 54 | ) 55 | ) 56 | 57 | 58 | class Solarization(object): 59 | """ 60 | Apply Solarization to the PIL image. 
61 | """ 62 | 63 | def __init__(self, p): 64 | self.p = p 65 | 66 | def __call__(self, img): 67 | if random.random() < self.p: 68 | return ImageOps.solarize(img) 69 | else: 70 | return img 71 | 72 | 73 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size): 74 | if os.path.isfile(pretrained_weights): 75 | state_dict = torch.load(pretrained_weights, map_location="cpu") 76 | if checkpoint_key is not None and checkpoint_key in state_dict: 77 | print(f"Take key {checkpoint_key} in provided checkpoint dict") 78 | state_dict = state_dict[checkpoint_key] 79 | # remove `module.` prefix 80 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 81 | # remove `backbone.` prefix induced by multicrop wrapper 82 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 83 | msg = model.load_state_dict(state_dict, strict=False) 84 | print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg)) 85 | else: 86 | # print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.") 87 | url = None 88 | if model_name == "vit_small" and patch_size == 16: 89 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 90 | elif model_name == "vit_small" and patch_size == 8: 91 | url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth" 92 | elif model_name == "vit_base" and patch_size == 16: 93 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 94 | elif model_name == "vit_base" and patch_size == 8: 95 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 96 | if url is not None: 97 | # print("Since no pretrained weights have been provided, we load the reference pretrained dino weights.") 98 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) 99 | model.load_state_dict(state_dict, strict=True) 100 | else: 101 | print("There is no reference weights available for this model => We use random weights.") 102 | 103 | 104 | def clip_gradients(model, clip): 105 | norms = [] 106 | for name, p in model.named_parameters(): 107 | if p.grad is not None: 108 | param_norm = p.grad.data.norm(2) 109 | norms.append(param_norm.item()) 110 | clip_coef = clip / (param_norm + 1e-6) 111 | if clip_coef < 1: 112 | p.grad.data.mul_(clip_coef) 113 | return norms 114 | 115 | 116 | def cancel_gradients_last_layer(epoch, model, freeze_last_layer): 117 | if epoch >= freeze_last_layer: 118 | return 119 | for n, p in model.named_parameters(): 120 | if "last_layer" in n: 121 | p.grad = None 122 | 123 | 124 | def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs): 125 | """ 126 | Re-start from checkpoint 127 | """ 128 | if not os.path.isfile(ckp_path): 129 | return 130 | print("Found checkpoint at {}".format(ckp_path)) 131 | 132 | # open checkpoint file 133 | checkpoint = torch.load(ckp_path, map_location="cpu") 134 | 135 | # key is what to look for in the checkpoint file 136 | # value is the object to load 137 | # example: {'state_dict': model} 138 | for key, value in kwargs.items(): 139 | if key in checkpoint and value is not None: 140 | try: 141 | msg = value.load_state_dict(checkpoint[key], strict=False) 142 | print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg)) 143 | except TypeError: 144 | try: 145 | msg = value.load_state_dict(checkpoint[key]) 146 | print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path)) 147 | except ValueError: 148 | print("=> failed to 
load '{}' from checkpoint: '{}'".format(key, ckp_path)) 149 | else: 150 | print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path)) 151 | 152 | # re load variable important for the run 153 | if run_variables is not None: 154 | for var_name in run_variables: 155 | if var_name in checkpoint: 156 | run_variables[var_name] = checkpoint[var_name] 157 | 158 | 159 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 160 | warmup_schedule = np.array([]) 161 | warmup_iters = warmup_epochs * niter_per_ep 162 | if warmup_epochs > 0: 163 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 164 | 165 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 166 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 167 | 168 | schedule = np.concatenate((warmup_schedule, schedule)) 169 | assert len(schedule) == epochs * niter_per_ep 170 | return schedule 171 | 172 | 173 | def bool_flag(s): 174 | """ 175 | Parse boolean arguments from the command line. 176 | """ 177 | FALSY_STRINGS = {"off", "false", "0"} 178 | TRUTHY_STRINGS = {"on", "true", "1"} 179 | if s.lower() in FALSY_STRINGS: 180 | return False 181 | elif s.lower() in TRUTHY_STRINGS: 182 | return True 183 | else: 184 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 185 | 186 | 187 | def fix_random_seeds(seed=31): 188 | """ 189 | Fix random seeds. 190 | """ 191 | torch.manual_seed(seed) 192 | torch.cuda.manual_seed_all(seed) 193 | np.random.seed(seed) 194 | 195 | 196 | class SmoothedValue(object): 197 | """Track a series of values and provide access to smoothed values over a 198 | window or the global series average. 199 | """ 200 | 201 | def __init__(self, window_size=20, fmt=None): 202 | if fmt is None: 203 | fmt = "{median:.6f} ({global_avg:.6f})" 204 | self.deque = deque(maxlen=window_size) 205 | self.total = 0.0 206 | self.count = 0 207 | self.fmt = fmt 208 | 209 | def update(self, value, n=1): 210 | self.deque.append(value) 211 | self.count += n 212 | self.total += value * n 213 | 214 | def synchronize_between_processes(self): 215 | """ 216 | Warning: does not synchronize the deque! 217 | """ 218 | if not is_dist_avail_and_initialized(): 219 | return 220 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 221 | dist.barrier() 222 | dist.all_reduce(t) 223 | t = t.tolist() 224 | self.count = int(t[0]) 225 | self.total = t[1] 226 | 227 | @property 228 | def median(self): 229 | d = torch.tensor(list(self.deque)) 230 | return d.median().item() 231 | 232 | @property 233 | def avg(self): 234 | d = torch.tensor(list(self.deque), dtype=torch.float32) 235 | return d.mean().item() 236 | 237 | @property 238 | def global_avg(self): 239 | return self.total / self.count 240 | 241 | @property 242 | def max(self): 243 | return max(self.deque) 244 | 245 | @property 246 | def value(self): 247 | return self.deque[-1] 248 | 249 | def __str__(self): 250 | return self.fmt.format( 251 | median=self.median, 252 | avg=self.avg, 253 | global_avg=self.global_avg, 254 | max=self.max, 255 | value=self.value) 256 | 257 | 258 | def reduce_dict(input_dict, average=True): 259 | """ 260 | Args: 261 | input_dict (dict): all the values will be reduced 262 | average (bool): whether to do average or sum 263 | Reduce the values in the dictionary from all processes so that all processes 264 | have the averaged results. 
Returns a dict with the same fields as 265 | input_dict, after reduction. 266 | """ 267 | world_size = get_world_size() 268 | if world_size < 2: 269 | return input_dict 270 | with torch.no_grad(): 271 | names = [] 272 | values = [] 273 | # sort the keys so that they are consistent across processes 274 | for k in sorted(input_dict.keys()): 275 | names.append(k) 276 | values.append(input_dict[k]) 277 | values = torch.stack(values, dim=0) 278 | dist.all_reduce(values) 279 | if average: 280 | values /= world_size 281 | reduced_dict = {k: v for k, v in zip(names, values)} 282 | return reduced_dict 283 | 284 | 285 | class MetricLogger(object): 286 | def __init__(self, delimiter="\t"): 287 | self.meters = defaultdict(SmoothedValue) 288 | self.delimiter = delimiter 289 | 290 | def update(self, **kwargs): 291 | for k, v in kwargs.items(): 292 | if isinstance(v, torch.Tensor): 293 | v = v.item() 294 | assert isinstance(v, (float, int)) 295 | self.meters[k].update(v) 296 | 297 | def __getattr__(self, attr): 298 | if attr in self.meters: 299 | return self.meters[attr] 300 | if attr in self.__dict__: 301 | return self.__dict__[attr] 302 | raise AttributeError("'{}' object has no attribute '{}'".format( 303 | type(self).__name__, attr)) 304 | 305 | def __str__(self): 306 | loss_str = [] 307 | for name, meter in self.meters.items(): 308 | loss_str.append( 309 | "{}: {}".format(name, str(meter)) 310 | ) 311 | return self.delimiter.join(loss_str) 312 | 313 | def synchronize_between_processes(self): 314 | for meter in self.meters.values(): 315 | meter.synchronize_between_processes() 316 | 317 | def add_meter(self, name, meter): 318 | self.meters[name] = meter 319 | 320 | def log_every(self, iterable, print_freq, header=None): 321 | i = 0 322 | if not header: 323 | header = '' 324 | start_time = time.time() 325 | end = time.time() 326 | iter_time = SmoothedValue(fmt='{avg:.6f}') 327 | data_time = SmoothedValue(fmt='{avg:.6f}') 328 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 329 | if torch.cuda.is_available(): 330 | log_msg = self.delimiter.join([ 331 | header, 332 | '[{0' + space_fmt + '}/{1}]', 333 | 'eta: {eta}', 334 | '{meters}', 335 | 'time: {time}', 336 | 'data: {data}', 337 | 'max mem: {memory:.0f}' 338 | ]) 339 | else: 340 | log_msg = self.delimiter.join([ 341 | header, 342 | '[{0' + space_fmt + '}/{1}]', 343 | 'eta: {eta}', 344 | '{meters}', 345 | 'time: {time}', 346 | 'data: {data}' 347 | ]) 348 | MB = 1024.0 * 1024.0 349 | for obj in iterable: 350 | data_time.update(time.time() - end) 351 | yield obj 352 | iter_time.update(time.time() - end) 353 | if i % print_freq == 0 or i == len(iterable) - 1: 354 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 355 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 356 | if torch.cuda.is_available(): 357 | print(log_msg.format( 358 | i, len(iterable), eta=eta_string, 359 | meters=str(self), 360 | time=str(iter_time), data=str(data_time), 361 | memory=torch.cuda.max_memory_allocated() / MB)) 362 | else: 363 | print(log_msg.format( 364 | i, len(iterable), eta=eta_string, 365 | meters=str(self), 366 | time=str(iter_time), data=str(data_time))) 367 | i += 1 368 | end = time.time() 369 | total_time = time.time() - start_time 370 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 371 | print('{} Total time: {} ({:.6f} s / it)'.format( 372 | header, total_time_str, total_time / len(iterable))) 373 | 374 | 375 | def get_sha(): 376 | cwd = os.path.dirname(os.path.abspath(__file__)) 377 | 378 | def _run(command): 379 | 
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 380 | 381 | sha = 'N/A' 382 | diff = "clean" 383 | branch = 'N/A' 384 | try: 385 | sha = _run(['git', 'rev-parse', 'HEAD']) 386 | subprocess.check_output(['git', 'diff'], cwd=cwd) 387 | diff = _run(['git', 'diff-index', 'HEAD']) 388 | diff = "has uncommited changes" if diff else "clean" 389 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 390 | except Exception: 391 | pass 392 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 393 | return message 394 | 395 | 396 | def is_dist_avail_and_initialized(): 397 | if not dist.is_available(): 398 | return False 399 | if not dist.is_initialized(): 400 | return False 401 | return True 402 | 403 | 404 | def get_world_size(): 405 | if not is_dist_avail_and_initialized(): 406 | return 1 407 | return dist.get_world_size() 408 | 409 | 410 | def get_rank(): 411 | if not is_dist_avail_and_initialized(): 412 | return 0 413 | return dist.get_rank() 414 | 415 | 416 | def is_main_process(): 417 | return get_rank() == 0 418 | 419 | 420 | def save_on_master(*args, **kwargs): 421 | if is_main_process(): 422 | torch.save(*args, **kwargs) 423 | 424 | 425 | def setup_for_distributed(is_master): 426 | """ 427 | This function disables printing when not in master process 428 | """ 429 | import builtins as __builtin__ 430 | builtin_print = __builtin__.print 431 | 432 | def print(*args, **kwargs): 433 | force = kwargs.pop('force', False) 434 | if is_master or force: 435 | builtin_print(*args, **kwargs) 436 | 437 | __builtin__.print = print 438 | 439 | 440 | def init_distributed_mode(args): 441 | # launched with torch.distributed.launch 442 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 443 | args.rank = int(os.environ["RANK"]) 444 | args.world_size = int(os.environ['WORLD_SIZE']) 445 | args.gpu = int(os.environ['LOCAL_RANK']) 446 | # launched with submitit on a slurm cluster 447 | elif 'SLURM_PROCID' in os.environ: 448 | args.rank = int(os.environ['SLURM_PROCID']) 449 | args.gpu = args.rank % torch.cuda.device_count() 450 | # launched naively with `python main_dino.py` 451 | # we manually add MASTER_ADDR and MASTER_PORT to env variables 452 | elif torch.cuda.is_available(): 453 | print('Will run the code on one GPU.') 454 | args.rank, args.gpu, args.world_size = 0, 0, 1 455 | os.environ['MASTER_ADDR'] = '127.0.0.1' 456 | os.environ['MASTER_PORT'] = '29500' 457 | else: 458 | print('Does not support training without GPU.') 459 | sys.exit(1) 460 | 461 | dist.init_process_group( 462 | backend="nccl", 463 | init_method=args.dist_url, 464 | world_size=args.world_size, 465 | rank=args.rank, 466 | ) 467 | 468 | torch.cuda.set_device(args.gpu) 469 | print('| distributed init (rank {}): {}'.format( 470 | args.rank, args.dist_url), flush=True) 471 | dist.barrier() 472 | setup_for_distributed(args.rank == 0) 473 | 474 | 475 | def accuracy(output, target, topk=(1,)): 476 | """Computes the accuracy over the k top predictions for the specified values of k""" 477 | maxk = max(topk) 478 | batch_size = target.size(0) 479 | _, pred = output.topk(maxk, 1, True, True) 480 | pred = pred.t() 481 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 482 | return [correct[:k].reshape(-1).float().sum(0) * 100. 
/ batch_size for k in topk] 483 | 484 | 485 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 486 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 487 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 488 | def norm_cdf(x): 489 | # Computes standard normal cumulative distribution function 490 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 491 | 492 | if (mean < a - 2 * std) or (mean > b + 2 * std): 493 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 494 | "The distribution of values may be incorrect.", 495 | stacklevel=2) 496 | 497 | with torch.no_grad(): 498 | # Values are generated by using a truncated uniform distribution and 499 | # then using the inverse CDF for the normal distribution. 500 | # Get upper and lower cdf values 501 | l = norm_cdf((a - mean) / std) 502 | u = norm_cdf((b - mean) / std) 503 | 504 | # Uniformly fill tensor with values from [l, u], then translate to 505 | # [2l-1, 2u-1]. 506 | tensor.uniform_(2 * l - 1, 2 * u - 1) 507 | 508 | # Use inverse cdf transform for normal distribution to get truncated 509 | # standard normal 510 | tensor.erfinv_() 511 | 512 | # Transform to proper mean, std 513 | tensor.mul_(std * math.sqrt(2.)) 514 | tensor.add_(mean) 515 | 516 | # Clamp to ensure it's in the proper range 517 | tensor.clamp_(min=a, max=b) 518 | return tensor 519 | 520 | 521 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 522 | # type: (Tensor, float, float, float, float) -> Tensor 523 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 524 | 525 | 526 | class LARS(torch.optim.Optimizer): 527 | """ 528 | Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py 529 | """ 530 | 531 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001, 532 | weight_decay_filter=None, lars_adaptation_filter=None): 533 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, 534 | eta=eta, weight_decay_filter=weight_decay_filter, 535 | lars_adaptation_filter=lars_adaptation_filter) 536 | super().__init__(params, defaults) 537 | 538 | @torch.no_grad() 539 | def step(self): 540 | for g in self.param_groups: 541 | for p in g['params']: 542 | dp = p.grad 543 | 544 | if dp is None: 545 | continue 546 | 547 | if p.ndim != 1: 548 | dp = dp.add(p, alpha=g['weight_decay']) 549 | 550 | if p.ndim != 1: 551 | param_norm = torch.norm(p) 552 | update_norm = torch.norm(dp) 553 | one = torch.ones_like(param_norm) 554 | q = torch.where(param_norm > 0., 555 | torch.where(update_norm > 0, 556 | (g['eta'] * param_norm / update_norm), one), one) 557 | dp = dp.mul(q) 558 | 559 | param_state = self.state[p] 560 | if 'mu' not in param_state: 561 | param_state['mu'] = torch.zeros_like(p) 562 | mu = param_state['mu'] 563 | mu.mul_(g['momentum']).add_(dp) 564 | 565 | p.add_(mu, alpha=-g['lr']) 566 | 567 | 568 | class MultiCropWrapper(nn.Module): 569 | """ 570 | Perform forward pass separately on each resolution input. 571 | The inputs corresponding to a single resolution are clubbed and single 572 | forward is run on the same resolution inputs. Hence we do several 573 | forward passes = number of different resolutions used. We then 574 | concatenate all the output features and run the head forward on these 575 | concatenated features. 
576 | """ 577 | 578 | def __init__(self, backbone, head): 579 | super(MultiCropWrapper, self).__init__() 580 | # disable layers dedicated to ImageNet labels classification 581 | backbone.fc, backbone.head = nn.Identity(), nn.Identity() 582 | self.backbone = backbone 583 | self.head = head 584 | 585 | def forward(self, x): 586 | # convert to list 587 | if not isinstance(x, list): 588 | x = [x] 589 | idx_crops = torch.cumsum(torch.unique_consecutive( 590 | torch.tensor([inp.shape[-1] for inp in x]), 591 | return_counts=True, 592 | )[1], 0) 593 | start_idx = 0 594 | for end_idx in idx_crops: 595 | _out = self.backbone(torch.cat(x[start_idx: end_idx])) 596 | if start_idx == 0: 597 | output = _out 598 | else: 599 | output = torch.cat((output, _out)) 600 | start_idx = end_idx 601 | # Run the head forward on the concatenated features. 602 | return self.head(output) 603 | 604 | 605 | def get_params_groups(model): 606 | regularized = [] 607 | not_regularized = [] 608 | for name, param in model.named_parameters(): 609 | if not param.requires_grad: 610 | continue 611 | # we do not regularize biases nor Norm parameters 612 | if name.endswith(".bias") or len(param.shape) == 1: 613 | not_regularized.append(param) 614 | else: 615 | regularized.append(param) 616 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] 617 | 618 | 619 | def has_batchnorms(model): 620 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 621 | for name, module in model.named_modules(): 622 | if isinstance(module, bn_types): 623 | return True 624 | return False 625 | -------------------------------------------------------------------------------- /models/dino/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Mostly copy-paste from timm library. 16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 17 | """ 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | from models.dino.utils import trunc_normal_ 25 | 26 | 27 | def drop_path(x, drop_prob: float = 0., training: bool = False): 28 | if drop_prob == 0. or not training: 29 | return x 30 | keep_prob = 1 - drop_prob 31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 32 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 33 | random_tensor.floor_() # binarize 34 | output = x.div(keep_prob) * random_tensor 35 | return output 36 | 37 | 38 | class DropPath(nn.Module): 39 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
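During training each sample's residual branch is dropped with probability drop_prob and the kept activations are rescaled by 1 / keep_prob, so the expected output is unchanged; at inference the module is an identity.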
40 | """ 41 | 42 | def __init__(self, drop_prob=None): 43 | super(DropPath, self).__init__() 44 | self.drop_prob = drop_prob 45 | 46 | def forward(self, x): 47 | return drop_path(x, self.drop_prob, self.training) 48 | 49 | 50 | class Mlp(nn.Module): 51 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 52 | super().__init__() 53 | out_features = out_features or in_features 54 | hidden_features = hidden_features or in_features 55 | self.fc1 = nn.Linear(in_features, hidden_features) 56 | self.act = act_layer() 57 | self.fc2 = nn.Linear(hidden_features, out_features) 58 | self.drop = nn.Dropout(drop) 59 | 60 | def forward(self, x): 61 | x = self.fc1(x) 62 | x = self.act(x) 63 | x = self.drop(x) 64 | x = self.fc2(x) 65 | x = self.drop(x) 66 | return x 67 | 68 | 69 | class Attention(nn.Module): 70 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 71 | super().__init__() 72 | self.num_heads = num_heads 73 | head_dim = dim // num_heads 74 | self.scale = qk_scale or head_dim ** -0.5 75 | 76 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 77 | self.attn_drop = nn.Dropout(attn_drop) 78 | self.proj = nn.Linear(dim, dim) 79 | self.proj_drop = nn.Dropout(proj_drop) 80 | 81 | def forward(self, x, return_key=False): 82 | B, N, C = x.shape 83 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 84 | q, k, v = qkv[0], qkv[1], qkv[2] 85 | 86 | attn = (q @ k.transpose(-2, -1)) * self.scale 87 | attn = attn.softmax(dim=-1) 88 | attn = self.attn_drop(attn) 89 | 90 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 91 | x = self.proj(x) 92 | x = self.proj_drop(x) 93 | if not return_key: 94 | return x, attn 95 | else: 96 | return x, attn, k 97 | 98 | 99 | class Block(nn.Module): 100 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 101 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 102 | super().__init__() 103 | self.norm1 = norm_layer(dim) 104 | self.attn = Attention( 105 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 106 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 107 | self.norm2 = norm_layer(dim) 108 | mlp_hidden_dim = int(dim * mlp_ratio) 109 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 110 | 111 | def forward(self, x, return_attention=False, return_key=False): 112 | if return_key: 113 | y, attn, key = self.attn(self.norm1(x), return_key) 114 | else: 115 | y, attn = self.attn(self.norm1(x)) 116 | x = x + self.drop_path(y) 117 | x = x + self.drop_path(self.mlp(self.norm2(x))) 118 | if return_attention: 119 | return x, attn 120 | elif return_key: 121 | return x, key, attn 122 | else: 123 | return x 124 | 125 | 126 | class PatchEmbed(nn.Module): 127 | """ Image to Patch Embedding 128 | """ 129 | 130 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 131 | super().__init__() 132 | num_patches = (img_size // patch_size) * (img_size // patch_size) 133 | self.img_size = img_size 134 | self.patch_size = patch_size 135 | self.num_patches = num_patches 136 | 137 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 138 | 139 | def forward(self, x): 140 | B, C, H, W = x.shape 141 | x = self.proj(x).flatten(2).transpose(1, 2) 142 | return x 143 | 144 | 145 | class VisionTransformer(nn.Module): 146 | """ Vision Transformer """ 147 | 148 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 149 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 150 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 151 | super().__init__() 152 | self.num_features = self.embed_dim = embed_dim 153 | 154 | self.patch_embed = PatchEmbed( 155 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 156 | num_patches = self.patch_embed.num_patches 157 | 158 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 159 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 160 | self.pos_drop = nn.Dropout(p=drop_rate) 161 | 162 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 163 | self.blocks = nn.ModuleList([ 164 | Block( 165 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 166 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 167 | for i in range(depth)]) 168 | self.norm = norm_layer(embed_dim) 169 | 170 | # Classifier head 171 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 172 | 173 | trunc_normal_(self.pos_embed, std=.02) 174 | trunc_normal_(self.cls_token, std=.02) 175 | self.apply(self._init_weights) 176 | 177 | def _init_weights(self, m): 178 | if isinstance(m, nn.Linear): 179 | trunc_normal_(m.weight, std=.02) 180 | if isinstance(m, nn.Linear) and m.bias is not None: 181 | nn.init.constant_(m.bias, 0) 182 | elif isinstance(m, nn.LayerNorm): 183 | nn.init.constant_(m.bias, 0) 184 | nn.init.constant_(m.weight, 1.0) 185 | 186 | def interpolate_pos_encoding(self, x, w, h): 187 | npatch = x.shape[1] - 1 188 | N = self.pos_embed.shape[1] - 1 189 | if npatch == N and w == h: 190 | return self.pos_embed 191 | class_pos_embed = self.pos_embed[:, 0] 192 | patch_pos_embed = self.pos_embed[:, 1:] 193 | dim = x.shape[-1] 194 | w0 = w // self.patch_embed.patch_size 195 | h0 = h // self.patch_embed.patch_size 196 | # we add a small number to avoid floating point error in the interpolation 197 | # see discussion at 
https://github.com/facebookresearch/dino/issues/8 198 | w0, h0 = w0 + 0.1, h0 + 0.1 199 | patch_pos_embed = nn.functional.interpolate( 200 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 201 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 202 | mode='bicubic', 203 | ) 204 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 205 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 206 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 207 | 208 | def prepare_tokens(self, x): 209 | B, nc, w, h = x.shape 210 | x = self.patch_embed(x) # patch linear embedding 211 | 212 | # add the [CLS] token to the embed patch tokens 213 | cls_tokens = self.cls_token.expand(B, -1, -1) 214 | x = torch.cat((cls_tokens, x), dim=1) 215 | 216 | # add positional encoding to each token 217 | x = x + self.interpolate_pos_encoding(x, w, h) 218 | 219 | return self.pos_drop(x) 220 | 221 | def forward(self, x, return_attention=False): 222 | if return_attention: 223 | atten_weights = [] 224 | x = self.prepare_tokens(x) 225 | for blk_ in self.blocks: 226 | x, weights = blk_(x, return_attention) 227 | atten_weights.append(weights) 228 | x = self.norm(x) 229 | return x, atten_weights 230 | 231 | else: 232 | x = self.prepare_tokens(x) 233 | for blk_ in self.blocks: 234 | x = blk_(x) 235 | x = self.norm(x) 236 | return x 237 | 238 | def get_last_selfattention(self, x): 239 | x = self.prepare_tokens(x) 240 | for i, blk in enumerate(self.blocks): 241 | if i < len(self.blocks) - 1: 242 | x = blk(x) 243 | else: 244 | # return attention of the last block 245 | return blk(x, return_attention=True) 246 | 247 | def get_intermediate_layers(self, x, n=1): 248 | x = self.prepare_tokens(x) 249 | # we return the output tokens from the `n` last blocks 250 | output = [] 251 | for i, blk in enumerate(self.blocks): 252 | x = blk(x) 253 | if len(self.blocks) - i <= n: 254 | output.append(self.norm(x)) 255 | return output 256 | 257 | def get_last_key(self, x, extra_layer=None): 258 | x = self.prepare_tokens(x) 259 | key_mid = 0 260 | for i, blk in enumerate(self.blocks): 261 | if extra_layer != None and i == extra_layer: 262 | x, key, attn = blk(x, return_key=True) 263 | key_mid = key 264 | elif i < len(self.blocks) - 1: 265 | x = blk(x) 266 | else: 267 | # return attention of the last block 268 | x, key, attn = blk(x, return_key=True) 269 | if extra_layer == None: 270 | return x, key, attn 271 | else: 272 | return key_mid, x, key, attn 273 | 274 | 275 | def vit_tiny(patch_size=16, **kwargs): 276 | model = VisionTransformer( 277 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, 278 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 279 | return model 280 | 281 | 282 | def vit_small(patch_size=16, **kwargs): 283 | model = VisionTransformer( 284 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 285 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 286 | return model 287 | 288 | 289 | def vit_base(patch_size=16, **kwargs): 290 | model = VisionTransformer( 291 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 292 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 293 | return model 294 | 295 | 296 | class DINOHead(nn.Module): 297 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, 298 | bottleneck_dim=256): 299 | super().__init__() 300 
| nlayers = max(nlayers, 1) 301 | if nlayers == 1: 302 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 303 | else: 304 | layers = [nn.Linear(in_dim, hidden_dim)] 305 | if use_bn: 306 | layers.append(nn.BatchNorm1d(hidden_dim)) 307 | layers.append(nn.GELU()) 308 | for _ in range(nlayers - 2): 309 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 310 | if use_bn: 311 | layers.append(nn.BatchNorm1d(hidden_dim)) 312 | layers.append(nn.GELU()) 313 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 314 | self.mlp = nn.Sequential(*layers) 315 | self.apply(self._init_weights) 316 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 317 | self.last_layer.weight_g.data.fill_(1) 318 | if norm_last_layer: 319 | self.last_layer.weight_g.requires_grad = False 320 | 321 | def _init_weights(self, m): 322 | if isinstance(m, nn.Linear): 323 | trunc_normal_(m.weight, std=.02) 324 | if isinstance(m, nn.Linear) and m.bias is not None: 325 | nn.init.constant_(m.bias, 0) 326 | 327 | def forward(self, x): 328 | x = self.mlp(x) 329 | x = nn.functional.normalize(x, dim=-1, p=2) 330 | x = self.last_layer(x) 331 | return x 332 | -------------------------------------------------------------------------------- /models/locate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.dino import vision_transformer as vits 5 | from models.dino.utils import load_pretrained_weights 6 | from models.model_util import * 7 | from fast_pytorch_kmeans import KMeans 8 | 9 | 10 | class Mlp(nn.Module): 11 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 12 | super(Mlp, self).__init__() 13 | out_features = out_features or in_features 14 | hidden_features = hidden_features or in_features 15 | self.norm = nn.LayerNorm(in_features) 16 | self.fc1 = nn.Linear(in_features, hidden_features) 17 | self.act = act_layer() 18 | self.fc2 = nn.Linear(hidden_features, out_features) 19 | self.drop = nn.Dropout(drop) 20 | 21 | def forward(self, x): 22 | x = self.norm(x) 23 | x = self.fc1(x) 24 | x = self.act(x) 25 | x = self.drop(x) 26 | x = self.fc2(x) 27 | x = self.drop(x) 28 | return x 29 | 30 | 31 | class Net(nn.Module): 32 | 33 | def __init__(self, aff_classes=36): 34 | super(Net, self).__init__() 35 | 36 | self.aff_classes = aff_classes 37 | self.gap = nn.AdaptiveAvgPool2d(1) 38 | 39 | # --- hyper-parameters --- # 40 | self.aff_cam_thd = 0.6 41 | self.part_iou_thd = 0.6 42 | self.cel_margin = 0.5 43 | 44 | # --- dino-vit features --- # 45 | self.vit_feat_dim = 384 46 | self.cluster_num = 3 47 | self.stride = 16 48 | self.patch = 16 49 | 50 | self.vit_model = vits.__dict__['vit_small'](patch_size=self.patch, num_classes=0) 51 | load_pretrained_weights(self.vit_model, '', None, 'vit_small', self.patch) 52 | 53 | # --- learning parameters --- # 54 | self.aff_proj = Mlp(in_features=self.vit_feat_dim, hidden_features=int(self.vit_feat_dim * 4), 55 | act_layer=nn.GELU, drop=0.) 
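        # Note: the DINO ViT features are extracted under torch.no_grad() in forward(),
        # so only the modules below receive gradients. aff_proj (above) is a residual MLP
        # on the raw DINO key descriptors; aff_ego_proj / aff_exo_proj are small conv heads
        # applied to the reshaped ego / exo feature maps; aff_fc is a 1x1 conv whose
        # per-class activation maps serve as affordance CAMs (pooled by self.gap into logits).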
56 | self.aff_ego_proj = nn.Sequential( 57 | nn.Conv2d(self.vit_feat_dim, self.vit_feat_dim, kernel_size=3, stride=1, padding=1), 58 | nn.BatchNorm2d(self.vit_feat_dim), 59 | nn.ReLU(True), 60 | nn.Conv2d(self.vit_feat_dim, self.vit_feat_dim, kernel_size=3, stride=1, padding=1), 61 | nn.BatchNorm2d(self.vit_feat_dim), 62 | nn.ReLU(True), 63 | ) 64 | self.aff_exo_proj = nn.Sequential( 65 | nn.Conv2d(self.vit_feat_dim, self.vit_feat_dim, kernel_size=3, stride=1, padding=1), 66 | nn.BatchNorm2d(self.vit_feat_dim), 67 | nn.ReLU(True), 68 | nn.Conv2d(self.vit_feat_dim, self.vit_feat_dim, kernel_size=3, stride=1, padding=1), 69 | nn.BatchNorm2d(self.vit_feat_dim), 70 | nn.ReLU(True), 71 | ) 72 | self.aff_fc = nn.Conv2d(self.vit_feat_dim, self.aff_classes, 1) 73 | 74 | def forward(self, exo, ego, aff_label, epoch): 75 | 76 | num_exo = exo.shape[1] 77 | exo = exo.flatten(0, 1) # b*num_exo x 3 x 224 x 224 78 | 79 | # --- Extract deep descriptors from DINO-vit --- # 80 | with torch.no_grad(): 81 | _, ego_key, ego_attn = self.vit_model.get_last_key(ego) # attn: b x 6 x (1+hw) x (1+hw) 82 | _, exo_key, exo_attn = self.vit_model.get_last_key(exo) 83 | ego_desc = ego_key.permute(0, 2, 3, 1).flatten(-2, -1).detach() 84 | exo_desc = exo_key.permute(0, 2, 3, 1).flatten(-2, -1).detach() 85 | 86 | ego_proj = ego_desc[:, 1:] + self.aff_proj(ego_desc[:, 1:]) 87 | exo_proj = exo_desc[:, 1:] + self.aff_proj(exo_desc[:, 1:]) 88 | ego_desc = self._reshape_transform(ego_desc[:, 1:, :], self.patch, self.stride) 89 | exo_desc = self._reshape_transform(exo_desc[:, 1:, :], self.patch, self.stride) 90 | ego_proj = self._reshape_transform(ego_proj, self.patch, self.stride) 91 | exo_proj = self._reshape_transform(exo_proj, self.patch, self.stride) 92 | 93 | b, c, h, w = ego_desc.shape 94 | ego_cls_attn = ego_attn[:, :, 0, 1:].reshape(b, 6, h, w) 95 | ego_cls_attn = (ego_cls_attn > ego_cls_attn.flatten(-2, -1).mean(-1, keepdim=True).unsqueeze(-1)).float() 96 | head_idxs = [0, 1, 3] 97 | ego_sam = ego_cls_attn[:, head_idxs].mean(1) 98 | ego_sam = normalize_minmax(ego_sam) 99 | ego_sam_flat = ego_sam.flatten(-2, -1) 100 | 101 | # --- Affordance CAM generation --- # 102 | exo_proj = self.aff_exo_proj(exo_proj) 103 | aff_cam = self.aff_fc(exo_proj) # b*num_exo x 36 x h x w 104 | aff_logits = self.gap(aff_cam).reshape(b, num_exo, self.aff_classes) 105 | aff_cam_re = aff_cam.reshape(b, num_exo, self.aff_classes, h, w) 106 | 107 | gt_aff_cam = torch.zeros(b, num_exo, h, w).cuda() 108 | for b_ in range(b): 109 | gt_aff_cam[b_, :] = aff_cam_re[b_, :, aff_label[b_]] 110 | 111 | # --- Clustering extracted descriptors based on CAM --- # 112 | ego_desc_flat = ego_desc.flatten(-2, -1) # b x 384 x hw 113 | exo_desc_re_flat = exo_desc.reshape(b, num_exo, c, h, w).flatten(-2, -1) 114 | sim_maps = torch.zeros(b, self.cluster_num, h * w).cuda() 115 | exo_sim_maps = torch.zeros(b, num_exo, self.cluster_num, h * w).cuda() 116 | part_score = torch.zeros(b, self.cluster_num).cuda() 117 | part_proto = torch.zeros(b, c).cuda() 118 | for b_ in range(b): 119 | exo_aff_desc = [] 120 | for n in range(num_exo): 121 | tmp_cam = gt_aff_cam[b_, n].reshape(-1) 122 | tmp_max, tmp_min = tmp_cam.max(), tmp_cam.min() 123 | tmp_cam = (tmp_cam - tmp_min) / (tmp_max - tmp_min + 1e-10) 124 | tmp_desc = exo_desc_re_flat[b_, n] 125 | tmp_top_desc = tmp_desc[:, torch.where(tmp_cam > self.aff_cam_thd)[0]].T # n x c 126 | exo_aff_desc.append(tmp_top_desc) 127 | exo_aff_desc = torch.cat(exo_aff_desc, dim=0) # (n1 + n2 + n3) x c 128 | 129 | if exo_aff_desc.shape[0] < 
self.cluster_num: 130 | continue 131 | 132 | kmeans = KMeans(n_clusters=self.cluster_num, mode='euclidean', max_iter=300) 133 | kmeans.fit_predict(exo_aff_desc.contiguous()) 134 | clu_cens = F.normalize(kmeans.centroids, dim=1) 135 | 136 | # save the exocentric similarity maps for visualization in training 137 | for n_ in range(num_exo): 138 | exo_sim_maps[b_, n_] = torch.mm(clu_cens, F.normalize(exo_desc_re_flat[b_, n_], dim=0)) 139 | 140 | # find object part prototypes and background prototypes 141 | sim_map = torch.mm(clu_cens, F.normalize(ego_desc_flat[b_], dim=0)) # self.cluster_num x hw 142 | tmp_sim_max, tmp_sim_min = torch.max(sim_map, dim=-1, keepdim=True)[0], \ 143 | torch.min(sim_map, dim=-1, keepdim=True)[0] 144 | sim_map_norm = (sim_map - tmp_sim_min) / (tmp_sim_max - tmp_sim_min + 1e-12) 145 | 146 | sim_map_hard = (sim_map_norm > torch.mean(sim_map_norm, 1, keepdim=True)).float() 147 | sam_hard = (ego_sam_flat > torch.mean(ego_sam_flat, 1, keepdim=True)).float() 148 | 149 | inter = (sim_map_hard * sam_hard[b_]).sum(1) 150 | union = sim_map_hard.sum(1) + sam_hard[b_].sum() - inter 151 | p_score = (inter / sim_map_hard.sum(1) + sam_hard[b_].sum() / union) / 2 152 | 153 | sim_maps[b_] = sim_map 154 | part_score[b_] = p_score 155 | 156 | if p_score.max() < self.part_iou_thd: 157 | continue 158 | 159 | part_proto[b_] = clu_cens[torch.argmax(p_score)] 160 | 161 | sim_maps = sim_maps.reshape(b, self.cluster_num, h, w) 162 | exo_sim_maps = exo_sim_maps.reshape(b, num_exo, self.cluster_num, h, w) 163 | ego_proj = self.aff_ego_proj(ego_proj) 164 | ego_pred = self.aff_fc(ego_proj) 165 | aff_logits_ego = self.gap(ego_pred).view(b, self.aff_classes) 166 | 167 | # --- concentration loss --- # 168 | gt_ego_cam = torch.zeros(b, h, w).cuda() 169 | loss_con = torch.zeros(1).cuda() 170 | for b_ in range(b): 171 | gt_ego_cam[b_] = ego_pred[b_, aff_label[b_]] 172 | loss_con += concentration_loss(ego_pred[b_]) 173 | 174 | gt_ego_cam = normalize_minmax(gt_ego_cam) 175 | loss_con /= b 176 | 177 | # --- prototype guidance loss --- # 178 | loss_proto = torch.zeros(1).cuda() 179 | valid_batch = 0 180 | if epoch[0] > epoch[1]: 181 | for b_ in range(b): 182 | if not part_proto[b_].equal(torch.zeros(c).cuda()): 183 | mask = gt_ego_cam[b_] 184 | tmp_feat = ego_desc[b_] * mask 185 | embedding = tmp_feat.reshape(tmp_feat.shape[0], -1).sum(1) / mask.sum() 186 | loss_proto += torch.max( 187 | 1 - F.cosine_similarity(embedding, part_proto[b_], dim=0) - self.cel_margin, 188 | torch.zeros(1).cuda()) 189 | valid_batch += 1 190 | loss_proto = loss_proto / (valid_batch + 1e-15) 191 | 192 | masks = {'exo_aff': gt_aff_cam, 'ego_sam': ego_sam, 193 | 'pred': (sim_maps, exo_sim_maps, part_score, gt_ego_cam)} 194 | logits = {'aff': aff_logits, 'aff_ego': aff_logits_ego} 195 | 196 | return masks, logits, loss_proto, loss_con 197 | 198 | @torch.no_grad() 199 | def test_forward(self, ego, aff_label): 200 | _, ego_key, ego_attn = self.vit_model.get_last_key(ego) # attn: b x 6 x (1+hw) x (1+hw) 201 | ego_desc = ego_key.permute(0, 2, 3, 1).flatten(-2, -1) 202 | ego_proj = ego_desc[:, 1:] + self.aff_proj(ego_desc[:, 1:]) 203 | ego_desc = self._reshape_transform(ego_desc[:, 1:, :], self.patch, self.stride) 204 | ego_proj = self._reshape_transform(ego_proj, self.patch, self.stride) 205 | 206 | b, c, h, w = ego_desc.shape 207 | ego_proj = self.aff_ego_proj(ego_proj) 208 | ego_pred = self.aff_fc(ego_proj) 209 | 210 | gt_ego_cam = torch.zeros(b, h, w).cuda() 211 | for b_ in range(b): 212 | gt_ego_cam[b_] = ego_pred[b_, aff_label[b_]] 
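            # i.e. keep only the CAM channel of the queried affordance for each sample;
            # min-max normalization and resizing to the ground-truth resolution are
            # presumably applied by the caller (e.g. test.py), since the raw map is returned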
213 | 214 | return gt_ego_cam 215 | 216 | def _reshape_transform(self, tensor, patch_size, stride): 217 | height = (224 - patch_size) // stride + 1 218 | width = (224 - patch_size) // stride + 1 219 | result = tensor.reshape(tensor.size(0), height, width, tensor.size(-1)) 220 | result = result.transpose(2, 3).transpose(1, 2).contiguous() 221 | return result 222 | -------------------------------------------------------------------------------- /models/model_util.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | 5 | __all__ = ['normalize_minmax', 'concentration_loss'] 6 | 7 | 8 | def normalize_minmax(cams, eps=1e-15): 9 | B, _, _ = cams.shape 10 | min_value, _ = cams.view(B, -1).min(1) 11 | cams_minmax = cams - min_value.view(B, 1, 1) 12 | max_value, _ = cams_minmax.view(B, -1).max(1) 13 | cams_minmax /= max_value.view(B, 1, 1) + eps 14 | return cams_minmax 15 | 16 | 17 | def get_variance(part_map, x_c, y_c): 18 | h, w = part_map.shape 19 | x_map, y_map = get_coordinate_tensors(h, w) 20 | 21 | v_x_map = (x_map - x_c) * (x_map - x_c) 22 | v_y_map = (y_map - y_c) * (y_map - y_c) 23 | 24 | v_x = (part_map * v_x_map).sum() 25 | v_y = (part_map * v_y_map).sum() 26 | return v_x, v_y 27 | 28 | 29 | def get_coordinate_tensors(x_max, y_max): 30 | x_map = np.tile(np.arange(x_max), (y_max, 1)) / x_max * 2 - 1.0 31 | y_map = np.tile(np.arange(y_max), (x_max, 1)).T / y_max * 2 - 1.0 32 | 33 | x_map_tensor = torch.from_numpy(x_map.astype(np.float32)).cuda() 34 | y_map_tensor = torch.from_numpy(y_map.astype(np.float32)).cuda() 35 | 36 | return x_map_tensor, y_map_tensor 37 | 38 | 39 | def get_center(part_map, self_referenced=False): 40 | h, w = part_map.shape 41 | x_map, y_map = get_coordinate_tensors(h, w) 42 | 43 | x_center = (part_map * x_map).sum() 44 | y_center = (part_map * y_map).sum() 45 | 46 | if self_referenced: 47 | x_c_value = float(x_center.cpu().detach()) 48 | y_c_value = float(y_center.cpu().detach()) 49 | x_center = (part_map * (x_map - x_c_value)).sum() + x_c_value 50 | y_center = (part_map * (y_map - y_c_value)).sum() + y_c_value 51 | 52 | return x_center, y_center 53 | 54 | 55 | def get_centers(part_maps, detach_k=True, epsilon=1e-3, self_ref_coord=False): 56 | H, W = part_maps.shape 57 | part_map = part_maps + epsilon 58 | k = part_map.sum() 59 | part_map_pdf = part_map / k 60 | x_c, y_c = get_center(part_map_pdf, self_ref_coord) 61 | centers = torch.stack((x_c, y_c), dim=0) 62 | return centers 63 | 64 | 65 | def batch_get_centers(pred_norm): 66 | B, H, W = pred_norm.shape 67 | 68 | centers_list = [] 69 | for b in range(B): 70 | centers_list.append(get_centers(pred_norm[b]).unsqueeze(0)) 71 | return torch.cat(centers_list, dim=0) 72 | 73 | 74 | # Code borrowed from SCOPS https://github.com/NVlabs/SCOPS 75 | def concentration_loss(pred): 76 | # b x h x w 77 | B, H, W = pred.shape 78 | tmp_max, tmp_min = pred.max(-1)[0].max(-1)[0].view(B, 1, 1), \ 79 | pred.min(-1)[0].min(-1)[0].view(B, 1, 1) 80 | 81 | pred_norm = ((pred - tmp_min) / (tmp_max - tmp_min + 1e-10)) # b x 28 x 28 82 | 83 | loss = 0 84 | epsilon = 1e-3 85 | centers_all = batch_get_centers(pred_norm) 86 | for b in range(B): 87 | centers = centers_all[b] 88 | # normalize part map as spatial pdf 89 | part_map = pred_norm[b, :, :] + epsilon # prevent gradient explosion 90 | k = part_map.sum() 91 | part_map_pdf = part_map / k 92 | x_c, y_c = centers 93 | v_x, v_y = get_variance(part_map_pdf, x_c, y_c) 94 | loss_per_part = (v_x + 
v_y) 95 | loss = loss_per_part + loss 96 | loss = loss / B 97 | return loss 98 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argparse 2 | tqdm 3 | fast-pytorch-kmeans 4 | numpy 5 | matplotlib 6 | opencv-python -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | 5 | import cv2 6 | import torch 7 | import numpy as np 8 | from models.locate import Net as model 9 | 10 | from utils.viz import viz_pred_test 11 | from utils.util import set_seed, process_gt, normalize_map 12 | from utils.evaluation import cal_kl, cal_sim, cal_nss 13 | 14 | parser = argparse.ArgumentParser() 15 | ## path 16 | parser.add_argument('--data_root', type=str, default='/home/gen/Project/aff_grounding/dataset/AGD20K/') 17 | parser.add_argument('--model_file', type=str, default=None) 18 | parser.add_argument('--save_path', type=str, default='./save_preds') 19 | parser.add_argument("--divide", type=str, default="Seen") 20 | ## image 21 | parser.add_argument('--crop_size', type=int, default=224) 22 | parser.add_argument('--resize_size', type=int, default=256) 23 | #### test 24 | parser.add_argument('--num_workers', type=int, default=8) 25 | parser.add_argument("--test_batch_size", type=int, default=1) 26 | parser.add_argument('--test_num_workers', type=int, default=8) 27 | parser.add_argument('--gpu', type=str, default='0') 28 | parser.add_argument('--viz', action='store_true', default=False) 29 | 30 | args = parser.parse_args() 31 | 32 | if args.divide == "Seen": 33 | aff_list = ['beat', "boxing", "brush_with", "carry", "catch", "cut", "cut_with", "drag", 'drink_with', 34 | "eat", "hit", "hold", "jump", "kick", "lie_on", "lift", "look_out", "open", "pack", "peel", 35 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", "stir", "swing", "take_photo", 36 | "talk_on", "text_on", "throw", "type_on", "wash", "write"] 37 | else: 38 | aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with', 39 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel", 40 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", 41 | "swing", "take_photo", "throw", "type_on", "wash"] 42 | 43 | if args.divide == "Seen": 44 | args.num_classes = 36 45 | else: 46 | args.num_classes = 25 47 | 48 | args.test_root = os.path.join(args.data_root, args.divide, "testset", "egocentric") 49 | args.mask_root = os.path.join(args.data_root, args.divide, "testset", "GT") 50 | 51 | if args.viz: 52 | if not os.path.exists(args.save_path): 53 | os.makedirs(args.save_path, exist_ok=True) 54 | 55 | if __name__ == '__main__': 56 | set_seed(seed=0) 57 | 58 | from data.datatest import TestData 59 | 60 | testset = TestData(image_root=args.test_root, 61 | crop_size=args.crop_size, 62 | divide=args.divide, mask_root=args.mask_root) 63 | TestLoader = torch.utils.data.DataLoader(dataset=testset, 64 | batch_size=args.test_batch_size, 65 | shuffle=False, 66 | num_workers=args.test_num_workers, 67 | pin_memory=True) 68 | 69 | model = model(aff_classes=args.num_classes).cuda() 70 | 71 | KLs = [] 72 | SIM = [] 73 | NSS = [] 74 | model.eval() 75 | assert os.path.exists(args.model_file), "Please provide the correct model file for testing" 76 | model.load_state_dict(torch.load(args.model_file)) 77 | 78 | GT_path = args.divide + "_gt.t7" 79 
| if not os.path.exists(GT_path): 80 | process_gt(args) 81 | GT_masks = torch.load(args.divide + "_gt.t7") 82 | 83 | for step, (image, label, mask_path) in enumerate(tqdm(TestLoader)): 84 | ego_pred = model.test_forward(image.cuda(), label.long().cuda()) 85 | cluster_sim_maps = [] 86 | ego_pred = np.array(ego_pred.squeeze().data.cpu()) 87 | ego_pred = normalize_map(ego_pred, args.crop_size) 88 | 89 | names = mask_path[0].split("/") 90 | key = names[-3] + "_" + names[-2] + "_" + names[-1] 91 | GT_mask = GT_masks[key] 92 | GT_mask = GT_mask / 255.0 93 | 94 | GT_mask = cv2.resize(GT_mask, (args.crop_size, args.crop_size)) 95 | 96 | kld, sim, nss = cal_kl(ego_pred, GT_mask), cal_sim(ego_pred, GT_mask), cal_nss(ego_pred, GT_mask) 97 | KLs.append(kld) 98 | SIM.append(sim) 99 | NSS.append(nss) 100 | 101 | if args.viz: 102 | img_name = key.split(".")[0] 103 | viz_pred_test(args, image, ego_pred, GT_mask, aff_list, label, img_name) 104 | 105 | mKLD = sum(KLs) / len(KLs) 106 | mSIM = sum(SIM) / len(SIM) 107 | mNSS = sum(NSS) / len(NSS) 108 | 109 | print(f"KLD = {round(mKLD, 3)}\nSIM = {round(mSIM, 3)}\nNSS = {round(mNSS, 3)}") 110 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import shutil 5 | import logging 6 | import argparse 7 | 8 | import cv2 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | from models.locate import Net as model 13 | 14 | from utils.viz import viz_pred_train, viz_pred_test 15 | from utils.util import set_seed, process_gt, normalize_map, get_optimizer 16 | from utils.evaluation import cal_kl, cal_sim, cal_nss, AverageMeter, compute_cls_acc 17 | 18 | parser = argparse.ArgumentParser() 19 | ## path 20 | parser.add_argument('--data_root', type=str, default='/home/gen/Project/aff_grounding/dataset/AGD20K/') 21 | parser.add_argument('--save_root', type=str, default='save_models') 22 | parser.add_argument("--divide", type=str, default="Seen") 23 | ## image 24 | parser.add_argument('--crop_size', type=int, default=224) 25 | parser.add_argument('--resize_size', type=int, default=256) 26 | ## dataloader 27 | parser.add_argument('--num_workers', type=int, default=8) 28 | ## train 29 | parser.add_argument('--batch_size', type=int, default=16) 30 | parser.add_argument('--warm_epoch', type=int, default=0) 31 | parser.add_argument('--epochs', type=int, default=15) 32 | parser.add_argument('--lr', type=float, default=0.001) 33 | parser.add_argument('--momentum', type=float, default=0.9) 34 | parser.add_argument('--weight_decay', type=float, default=5e-4) 35 | parser.add_argument('--show_step', type=int, default=100) 36 | parser.add_argument('--gpu', type=str, default='0') 37 | parser.add_argument('--viz', action='store_true', default=False) 38 | 39 | #### test 40 | parser.add_argument("--test_batch_size", type=int, default=1) 41 | parser.add_argument('--test_num_workers', type=int, default=8) 42 | 43 | args = parser.parse_args() 44 | torch.cuda.set_device('cuda:' + args.gpu) 45 | lr = args.lr 46 | 47 | if args.divide == "Seen": 48 | aff_list = ['beat', "boxing", "brush_with", "carry", "catch", "cut", "cut_with", "drag", 'drink_with', 49 | "eat", "hit", "hold", "jump", "kick", "lie_on", "lift", "look_out", "open", "pack", "peel", 50 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", "stir", "swing", "take_photo", 51 | "talk_on", "text_on", "throw", "type_on", "wash", "write"] 52 | 
else:
53 |     aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with',
54 |                 "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel",
55 |                 "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick",
56 |                 "swing", "take_photo", "throw", "type_on", "wash"]
57 | 
58 | if args.divide == "Seen":
59 |     args.num_classes = 36
60 | else:
61 |     args.num_classes = 25
62 | 
63 | args.exocentric_root = os.path.join(args.data_root, args.divide, "trainset", "exocentric")
64 | args.egocentric_root = os.path.join(args.data_root, args.divide, "trainset", "egocentric")
65 | args.test_root = os.path.join(args.data_root, args.divide, "testset", "egocentric")
66 | args.mask_root = os.path.join(args.data_root, args.divide, "testset", "GT")
67 | time_str = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
68 | args.save_path = os.path.join(args.save_root, time_str)
69 | 
70 | if not os.path.exists(args.save_path):
71 |     os.makedirs(args.save_path, exist_ok=True)
72 | dict_args = vars(args)
73 | 
74 | shutil.copy('./models/locate.py', args.save_path)
75 | shutil.copy('./train.py', args.save_path)
76 | 
77 | str_1 = ""
78 | for key, value in dict_args.items():
79 |     str_1 += key + "=" + str(value) + "\n"
80 | 
81 | logging.basicConfig(filename='%s/run.log' % args.save_path, level=logging.INFO, format='%(message)s')
82 | logger = logging.getLogger()
83 | logger.addHandler(logging.StreamHandler(sys.stdout))
84 | logger.info(str_1)
85 | 
86 | if __name__ == '__main__':
87 |     set_seed(seed=0)
88 | 
89 |     from data.datatrain import TrainData
90 | 
91 |     trainset = TrainData(exocentric_root=args.exocentric_root,
92 |                          egocentric_root=args.egocentric_root,
93 |                          resize_size=args.resize_size,
94 |                          crop_size=args.crop_size, divide=args.divide)
95 | 
96 |     TrainLoader = torch.utils.data.DataLoader(dataset=trainset,
97 |                                               batch_size=args.batch_size,
98 |                                               shuffle=True,
99 |                                               num_workers=args.num_workers,
100 |                                               pin_memory=True)
101 | 
102 |     from data.datatest import TestData
103 | 
104 |     testset = TestData(image_root=args.test_root,
105 |                        crop_size=args.crop_size,
106 |                        divide=args.divide, mask_root=args.mask_root)
107 |     TestLoader = torch.utils.data.DataLoader(dataset=testset,
108 |                                              batch_size=args.test_batch_size,
109 |                                              shuffle=False,
110 |                                              num_workers=args.test_num_workers,
111 |                                              pin_memory=True)
112 | 
113 |     model = model(aff_classes=args.num_classes)
114 |     model = model.cuda()
115 |     model.train()
116 |     optimizer, scheduler = get_optimizer(model, args)
117 | 
118 |     best_kld = 1000
119 |     print('Training begins!')
120 |     for epoch in range(args.epochs):
121 |         model.train()
122 |         logger.info('LR = ' + str(scheduler.get_last_lr()))
123 |         exo_aff_acc = AverageMeter()
124 |         ego_obj_acc = AverageMeter()
125 | 
126 |         for step, (exocentric_image, egocentric_image, aff_label) in enumerate(TrainLoader):
127 |             aff_label = aff_label.cuda().long()  # b (affordance class indices)
128 |             exo = exocentric_image.cuda()  # b x n x 3 x 224 x 224
129 |             ego = egocentric_image.cuda()
130 | 
131 |             masks, logits, loss_proto, loss_con = model(exo, ego, aff_label, (epoch, args.warm_epoch))
132 | 
133 |             exo_aff_logits = logits['aff']
134 |             num_exo = exo.shape[1]
135 |             exo_aff_loss = torch.zeros(1).cuda()
136 |             for n in range(num_exo):
137 |                 exo_aff_loss += nn.CrossEntropyLoss().cuda()(exo_aff_logits[:, n], aff_label)
138 |             exo_aff_loss /= num_exo
139 | 
140 |             loss_dict = {'ego_ce': nn.CrossEntropyLoss().cuda()(logits['aff_ego'], aff_label),
141 |                          'exo_ce': exo_aff_loss,
142 |                          'con_loss': loss_proto,
143 |                          'loss_cen': loss_con * 0.07,
144 |                          }
145 | 
146 |             loss = sum(loss_dict.values())
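            # total objective: ego/exo affordance classification (CE), prototype guidance
            # ('con_loss'), and the concentration loss ('loss_cen', weighted by 0.07)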
147 |             optimizer.zero_grad()
148 |             loss.backward()
149 |             optimizer.step()
150 | 
151 |             cur_batch = exo.size(0)
152 |             exo_acc = 100. * compute_cls_acc(logits['aff'].mean(1), aff_label)
153 |             exo_aff_acc.updata(exo_acc, cur_batch)
154 |             metric_dict = {'exo_aff_acc': exo_aff_acc.avg}
155 | 
156 |             if (step + 1) % args.show_step == 0:
157 |                 log_str = 'epoch: %d/%d + %d/%d | ' % (epoch + 1, args.epochs, step + 1, len(TrainLoader))
158 |                 log_str += ' | '.join(['%s: %.3f' % (k, v) for k, v in metric_dict.items()])
159 |                 log_str += ' | '
160 |                 log_str += ' | '.join(['%s: %.3f' % (k, v) for k, v in loss_dict.items()])
161 |                 logger.info(log_str)
162 | 
163 |                 # Visualize the prediction during training
164 |                 if args.viz:
165 |                     viz_pred_train(args, ego, exo, masks, aff_list, aff_label, epoch, step + 1)
166 | 
167 |         scheduler.step()
168 |         KLs = []
169 |         SIM = []
170 |         NSS = []
171 |         model.eval()
172 |         GT_path = args.divide + "_gt.t7"
173 |         if not os.path.exists(GT_path):
174 |             process_gt(args)
175 |         GT_masks = torch.load(args.divide + "_gt.t7")
176 | 
177 |         for step, (image, label, mask_path) in enumerate(TestLoader):
178 |             ego_pred = model.test_forward(image.cuda(), label.long().cuda())
179 |             cluster_sim_maps = []
180 |             ego_pred = np.array(ego_pred.squeeze().data.cpu())
181 |             ego_pred = normalize_map(ego_pred, args.crop_size)
182 | 
183 |             names = mask_path[0].split("/")
184 |             key = names[-3] + "_" + names[-2] + "_" + names[-1]
185 |             GT_mask = GT_masks[key]
186 |             GT_mask = GT_mask / 255.0
187 | 
188 |             GT_mask = cv2.resize(GT_mask, (args.crop_size, args.crop_size))
189 | 
190 |             kld, sim, nss = cal_kl(ego_pred, GT_mask), cal_sim(ego_pred, GT_mask), cal_nss(ego_pred, GT_mask)
191 |             KLs.append(kld)
192 |             SIM.append(sim)
193 |             NSS.append(nss)
194 | 
195 |             # Visualize the prediction during evaluation
196 |             if args.viz:
197 |                 if (step + 1) % args.show_step == 0:
198 |                     img_name = key.split(".")[0]
199 |                     viz_pred_test(args, image, ego_pred, GT_mask, aff_list, label, img_name, epoch)
200 | 
201 |         mKLD = sum(KLs) / len(KLs)
202 |         mSIM = sum(SIM) / len(SIM)
203 |         mNSS = sum(NSS) / len(NSS)
204 | 
205 |         logger.info(
206 |             "epoch=" + str(epoch + 1) + " mKLD = " + str(round(mKLD, 3))
207 |             + " mSIM = " + str(round(mSIM, 3)) + " mNSS = " + str(round(mNSS, 3))
208 |             + " bestKLD = " + str(round(best_kld, 3)))
209 | 
210 |         if mKLD < best_kld:
211 |             best_kld = mKLD
212 |             model_name = 'best_model_' + str(epoch + 1) + '_' + str(round(best_kld, 3)) \
213 |                          + '_' + str(round(mSIM, 3)) \
214 |                          + '_' + str(round(mNSS, 3)) \
215 |                          + '.pth'
216 |             torch.save(model.state_dict(), os.path.join(args.save_path, model_name))
217 | 
--------------------------------------------------------------------------------
/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | def cal_kl(pred: np.ndarray, gt: np.ndarray, eps=1e-12) -> np.ndarray:
6 |     map1, map2 = pred / (pred.sum() + eps), gt / (gt.sum() + eps)
7 |     kld = np.sum(map2 * np.log(map2 / (map1 + eps) + eps))
8 |     return kld
9 | 
10 | 
11 | def cal_sim(pred: np.ndarray, gt: np.ndarray, eps=1e-12) -> np.ndarray:
12 |     map1, map2 = pred / (pred.sum() + eps), gt / (gt.sum() + eps)
13 |     intersection = np.minimum(map1, map2)
14 | 
15 |     return np.sum(intersection)
16 | 
17 | 
18 | def image_binary(image, threshold):
19 |     output = np.zeros(image.size).reshape(image.shape)
20 |     for xx in range(image.shape[0]):
21 |         for yy in range(image.shape[1]):
22 |             if (image[xx][yy] > threshold):
23 |                 output[xx][yy] = 1
24 | 
return output 25 | 26 | 27 | def cal_nss(pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 28 | pred = pred / 255.0 29 | gt = gt / 255.0 30 | std = np.std(pred) 31 | u = np.mean(pred) 32 | 33 | smap = (pred - u) / std 34 | fixation_map = (gt - np.min(gt)) / (np.max(gt) - np.min(gt) + 1e-12) 35 | fixation_map = image_binary(fixation_map, 0.1) 36 | 37 | nss = smap * fixation_map 38 | 39 | nss = np.sum(nss) / np.sum(fixation_map + 1e-12) 40 | 41 | return nss 42 | 43 | 44 | def compute_cls_acc(preds, label): 45 | pred = torch.max(preds, 1)[1] 46 | # label = torch.max(labels, 1)[1] 47 | num_correct = (pred == label).sum() 48 | return float(num_correct) / float(preds.size(0)) 49 | 50 | 51 | class AverageMeter(object): 52 | def __init__(self): 53 | self.reset() 54 | 55 | def reset(self): 56 | self.val = 0.0 57 | self.avg = 0.0 58 | self.sum = 0.0 59 | self.cnt = 0.0 60 | 61 | def updata(self, val, n=1.0): 62 | self.val = val 63 | self.sum += val * n 64 | self.cnt += n 65 | if self.cnt == 0: 66 | self.avg = 1 67 | else: 68 | self.avg = self.sum / self.cnt 69 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import random 4 | import torch 5 | import numpy as np 6 | from PIL import Image 7 | from matplotlib import cm 8 | 9 | 10 | def set_seed(seed=0): 11 | np.random.seed(seed) 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed(seed) 14 | random.seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | torch.backends.cudnn.deterministic = True 17 | torch.backends.cudnn.benchmark = False 18 | 19 | 20 | def process_gt(args): 21 | assert args.divide in ["Seen", "Unseen"], "The divide argument should be Seen or Unseen" 22 | files = os.listdir(args.mask_root) 23 | dict_1 = {} 24 | for file in files: 25 | file_path = os.path.join(args.mask_root, file) 26 | objs = os.listdir(file_path) 27 | for obj in objs: 28 | obj_path = os.path.join(file_path, obj) 29 | images = os.listdir(obj_path) 30 | for img in images: 31 | img_path = os.path.join(obj_path, img) 32 | mask = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 33 | key = file + "_" + obj + "_" + img 34 | dict_1[key] = mask 35 | 36 | torch.save(dict_1, args.divide + "_gt.t7") 37 | 38 | 39 | def normalize_map(atten_map, crop_size): 40 | atten_map = cv2.resize(atten_map, dsize=(crop_size, crop_size)) 41 | min_val = np.min(atten_map) 42 | max_val = np.max(atten_map) 43 | atten_norm = (atten_map - min_val) / (max_val - min_val + 1e-10) 44 | return atten_norm 45 | 46 | 47 | def get_optimizer(model, args): 48 | lr = args.lr 49 | weight_list = [] 50 | bias_list = [] 51 | last_weight_list = [] 52 | last_bias_list = [] 53 | for name, value in model.named_parameters(): 54 | if value.requires_grad: 55 | if 'fc' in name: 56 | if 'weight' in name: 57 | last_weight_list.append(value) 58 | elif 'bias' in name: 59 | last_bias_list.append(value) 60 | else: 61 | if 'weight' in name: 62 | weight_list.append(value) 63 | elif 'bias' in name: 64 | bias_list.append(value) 65 | optimizer = torch.optim.SGD([{'params': weight_list, 66 | 'lr': lr}, 67 | {'params': bias_list, 68 | 'lr': lr * 2}, 69 | {'params': last_weight_list, 70 | 'lr': lr * 10}, 71 | {'params': last_bias_list, 72 | 'lr': lr * 20}], 73 | momentum=args.momentum, 74 | weight_decay=args.weight_decay, 75 | nesterov=True) 76 | 77 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) 78 | return optimizer, scheduler 79 | 80 | 81 | def 
overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = "jet", alpha: float = 0.7) -> Image.Image: 82 | if not isinstance(img, Image.Image) or not isinstance(mask, Image.Image): 83 | raise TypeError("img and mask arguments need to be PIL.Image") 84 | 85 | if not isinstance(alpha, float) or alpha < 0 or alpha >= 1: 86 | raise ValueError("alpha argument is expected to be of type float between 0 and 1") 87 | 88 | cmap = cm.get_cmap(colormap) 89 | # Resize mask and apply colormap 90 | overlay = mask.resize(img.size, resample=Image.BICUBIC) 91 | overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, :3]).astype(np.uint8) 92 | # Overlay the image with the mask 93 | overlayed_img = Image.fromarray((alpha * np.asarray(img) + (1 - alpha) * overlay).astype(np.uint8)) 94 | 95 | return overlayed_img 96 | -------------------------------------------------------------------------------- /utils/viz.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from utils.util import normalize_map, overlay_mask 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | # visualize the prediction of the first batch 10 | def viz_pred_train(args, ego, exo, masks, aff_list, aff_label, epoch, step): 11 | mean = torch.as_tensor([0.485, 0.456, 0.406], dtype=ego.dtype, device=ego.device).view(-1, 1, 1) 12 | std = torch.as_tensor([0.229, 0.224, 0.225], dtype=ego.dtype, device=ego.device).view(-1, 1, 1) 13 | 14 | ego_0 = ego[0].squeeze(0) * std + mean 15 | ego_0 = ego_0.detach().cpu().numpy() * 255 16 | ego_0 = Image.fromarray(ego_0.transpose(1, 2, 0).astype(np.uint8)) 17 | 18 | exo_img = [] 19 | num_exo = exo.shape[1] 20 | for i in range(num_exo): 21 | name = 'exo_' + str(i) 22 | locals()[name] = exo[0][i].squeeze(0) * std + mean 23 | locals()[name] = locals()[name].detach().cpu().numpy() * 255 24 | locals()[name] = Image.fromarray(locals()[name].transpose(1, 2, 0).astype(np.uint8)) 25 | exo_img.append(locals()[name]) 26 | 27 | exo_cam = masks['exo_aff'][0] 28 | 29 | sim_maps, exo_sim_maps, part_score, ego_pred = masks['pred'] 30 | num_clu = sim_maps.shape[1] 31 | part_score = np.array(part_score[0].squeeze().data.cpu()) 32 | 33 | ego_pred = np.array(ego_pred[0].squeeze().data.cpu()) 34 | ego_pred = normalize_map(ego_pred, args.crop_size) 35 | ego_pred = Image.fromarray(ego_pred) 36 | ego_pred = overlay_mask(ego_0, ego_pred, alpha=0.5) 37 | 38 | ego_sam = masks['ego_sam'] 39 | ego_sam = np.array(ego_sam[0].squeeze().data.cpu()) 40 | ego_sam = normalize_map(ego_sam, args.crop_size) 41 | ego_sam = Image.fromarray(ego_sam) 42 | ego_sam = overlay_mask(ego_0, ego_sam, alpha=0.1) 43 | 44 | aff_str = aff_list[aff_label[0].item()] 45 | 46 | for i in range(num_exo): 47 | name = 'exo_aff' + str(i) 48 | locals()[name] = np.array(exo_cam[i].squeeze().data.cpu()) 49 | locals()[name] = normalize_map(locals()[name], args.crop_size) 50 | locals()[name] = Image.fromarray(locals()[name]) 51 | locals()[name] = overlay_mask(exo_img[i], locals()[name], alpha=0.5) 52 | 53 | for i in range(num_clu): 54 | name = 'sim_map' + str(i) 55 | locals()[name] = np.array(sim_maps[0][i].squeeze().data.cpu()) 56 | locals()[name] = normalize_map(locals()[name], args.crop_size) 57 | locals()[name] = Image.fromarray(locals()[name]) 58 | locals()[name] = overlay_mask(ego_0, locals()[name], alpha=0.5) 59 | 60 | # Similarity maps for the first exocentric image 61 | name = 'exo_sim_map' + str(i) 62 | locals()[name] = np.array(exo_sim_maps[0, 
0][i].squeeze().data.cpu()) 63 | locals()[name] = normalize_map(locals()[name], args.crop_size) 64 | locals()[name] = Image.fromarray(locals()[name]) 65 | locals()[name] = overlay_mask(locals()['exo_' + str(0)], locals()[name], alpha=0.5) 66 | 67 | # Exo&Ego plots 68 | fig, ax = plt.subplots(4, max(num_clu, num_exo), figsize=(8, 8)) 69 | for axi in ax.ravel(): 70 | axi.set_axis_off() 71 | for k in range(num_exo): 72 | ax[0, k].imshow(eval('exo_aff' + str(k))) 73 | ax[0, k].set_title("exo_" + aff_str) 74 | for k in range(num_clu): 75 | ax[1, k].imshow(eval('sim_map' + str(k))) 76 | ax[1, k].set_title('PartIoU_' + str(round(part_score[k], 2))) 77 | ax[2, k].imshow(eval('exo_sim_map' + str(k))) 78 | ax[2, k].set_title('sim_map_' + str(k)) 79 | ax[3, 0].imshow(ego_pred) 80 | ax[3, 0].set_title(aff_str) 81 | ax[3, 1].imshow(ego_sam) 82 | ax[3, 1].set_title('Saliency') 83 | 84 | os.makedirs(os.path.join(args.save_path, 'viz_train'), exist_ok=True) 85 | fig_name = os.path.join(args.save_path, 'viz_train', 'cam_' + str(epoch) + '_' + str(step) + '.jpg') 86 | plt.tight_layout() 87 | plt.savefig(fig_name) 88 | plt.close() 89 | 90 | 91 | def viz_pred_test(args, image, ego_pred, GT_mask, aff_list, aff_label, img_name, epoch=None): 92 | mean = torch.as_tensor([0.485, 0.456, 0.406], dtype=image.dtype, device=image.device).view(-1, 1, 1) 93 | std = torch.as_tensor([0.229, 0.224, 0.225], dtype=image.dtype, device=image.device).view(-1, 1, 1) 94 | mean = mean.view(-1, 1, 1) 95 | std = std.view(-1, 1, 1) 96 | img = image.squeeze(0) * std + mean 97 | img = img.detach().cpu().numpy() * 255 98 | img = Image.fromarray(img.transpose(1, 2, 0).astype(np.uint8)) 99 | 100 | gt = Image.fromarray(GT_mask) 101 | gt_result = overlay_mask(img, gt, alpha=0.5) 102 | aff_str = aff_list[aff_label.item()] 103 | 104 | ego_pred = Image.fromarray(ego_pred) 105 | ego_pred = overlay_mask(img, ego_pred, alpha=0.5) 106 | 107 | fig, ax = plt.subplots(1, 3, figsize=(10, 6)) 108 | for axi in ax.ravel(): 109 | axi.set_axis_off() 110 | ax[0].imshow(img) 111 | ax[0].set_title('ego') 112 | ax[1].imshow(ego_pred) 113 | ax[1].set_title(aff_str) 114 | ax[2].imshow(gt_result) 115 | ax[2].set_title('GT') 116 | 117 | os.makedirs(os.path.join(args.save_path, 'viz_test'), exist_ok=True) 118 | if epoch: 119 | fig_name = os.path.join(args.save_path, 'viz_test', "epoch" + str(epoch) + '_' + img_name + '.jpg') 120 | else: 121 | fig_name = os.path.join(args.save_path, 'viz_test', img_name + '.jpg') 122 | plt.savefig(fig_name) 123 | plt.close() 124 | --------------------------------------------------------------------------------
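Usage sketch (not part of the repository): the snippet below outlines single-image inference with a trained checkpoint, mirroring the path used in test.py (Net.test_forward, then normalize_map, then overlay_mask). The checkpoint and image paths are placeholders, and the 224x224 resize plus ImageNet normalization is an assumption chosen to match the constants in utils/viz.py; the actual test-time preprocessing lives in data/datatest.py and may differ slightly.

# minimal single-image inference sketch, assuming a trained "Seen" checkpoint
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from models.locate import Net
from utils.util import normalize_map, overlay_mask

# assumed preprocessing: 224x224 resize + ImageNet normalization (see utils/viz.py)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

model = Net(aff_classes=36).cuda()                              # 36 classes for the "Seen" split
model.load_state_dict(torch.load('path/to/best_model.pth'))     # placeholder checkpoint path
model.eval()

img = Image.open('path/to/egocentric.jpg').convert('RGB')       # placeholder image path
image = preprocess(img).unsqueeze(0).cuda()
label = torch.tensor([11]).cuda()                               # index of "hold" in the Seen aff_list (see test.py)

cam = model.test_forward(image, label)                          # b x h x w affordance heatmap
heat = normalize_map(np.array(cam.squeeze().cpu()), 224)        # resize + min-max to [0, 1]
overlay = overlay_mask(img.resize((224, 224)), Image.fromarray(heat), alpha=0.5)
overlay.save('prediction.jpg')

From there, test.py shows how the same heatmap is scored against the AGD20K ground-truth masks with cal_kl, cal_sim, and cal_nss from utils/evaluation.py.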