├── LICENSE
├── README.md
├── assets
│   └── problem_setting.gif
├── data
│   ├── datatest.py
│   └── datatrain.py
├── models
│   ├── dino
│   │   ├── utils.py
│   │   └── vision_transformer.py
│   ├── locate.py
│   └── model_util.py
├── requirements.txt
├── test.py
├── train.py
└── utils
    ├── evaluation.py
    ├── util.py
    └── viz.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Gen Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LOCATE: Localize and Transfer Object Parts for Weakly Supervised Affordance Grounding
2 |
3 | [arXiv](https://arxiv.org/abs/2303.09665)
4 | [Project Page](https://reagan1311.github.io/locate/)
5 | [Video](https://www.youtube.com/watch?v=RLHansdFxII)
6 |
7 | Official PyTorch implementation of our CVPR 2023 paper - LOCATE: Localize and Transfer Object Parts for Weakly
8 | Supervised Affordance Grounding.
9 |
10 | ## Abstract
11 |
12 | Humans excel at acquiring knowledge through observation. For example, we can learn to use new tools by watching
13 | demonstrations. This skill is fundamental for intelligent systems to interact with the world. A key step to acquire this
14 | skill is to identify what part of the object affords each action, which is called affordance grounding. In this paper,
15 | we address this problem and propose a framework called LOCATE that can identify matching object parts across images, to
16 | transfer knowledge from images where an object is being used (exocentric images used for learning), to images where the
17 | object is inactive (egocentric ones used for testing). To this end, we first find interaction areas and extract their
18 | feature embeddings. Then we learn to aggregate the embeddings into compact prototypes (human, object part, and
19 | background), and select the one representing the object part. Finally, we use the selected prototype to guide affordance
20 | grounding. We do this in a weakly supervised manner, learning only from image-level affordance and object labels.
21 | Extensive experiments demonstrate that our approach outperforms state-of-the-art methods by a large margin on both seen
22 | and unseen objects.
23 |
24 |
25 | <img src="./assets/problem_setting.gif" alt="LOCATE problem setting">
26 |
27 |
28 | ## Usage
29 |
30 | ### 1. Requirements
31 |
32 | The code is tested with PyTorch 1.12.1, Python 3.7, and CUDA 11.6.
33 |
34 | ```
35 | pip install -r requirements.txt
36 | ```
37 |
38 | ### 2. Dataset
39 |
40 | Download the AGD20K dataset
41 | from [ [Google Drive](https://drive.google.com/file/d/1OEz25-u1uqKfeuyCqy7hmiOv7lIWfigk/view?usp=sharing) | [Baidu Pan](https://pan.baidu.com/s/1IRfho7xDAT0oJi5_mvP1sg) (extraction code: g23n) ].
42 |
43 |
44 | ### 3. Train and Test
45 |
46 | Our pretrained model can be downloaded
47 | from [Google Drive](https://drive.google.com/drive/folders/1-AcTiE9Lz91bPJlp1o-ubgkxKnudohdx?usp=sharing). Run the following commands to start training or testing:
48 |
49 | ```
50 | python train.py --data_root <PATH_TO_DATA>
51 | python test.py --data_root <PATH_TO_DATA> --model_file <PATH_TO_MODEL>
52 | ```
53 |
54 | ## Citation
55 |
56 | ```
57 | @inproceedings{li:locate:2023,
58 | title = {LOCATE: Localize and Transfer Object Parts for Weakly Supervised Affordance Grounding},
59 | author = {Li, Gen and Jampani, Varun and Sun, Deqing and Sevilla-Lara, Laura},
60 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
61 | year={2023}
62 | }
63 | ```
64 |
65 | ## Acknowledgement
66 |
67 | This repo is based on [Cross-View-AG](https://github.com/lhc1224/Cross-View-AG),
68 | [dino-vit-features](https://github.com/ShirAmir/dino-vit-features),
69 | and [dino](https://github.com/facebookresearch/dino). Thanks for their great work!
70 |
--------------------------------------------------------------------------------
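The data loaders in `data/` walk each dataset root as `<affordance>/<object>/<image>`, so after downloading AGD20K it is worth verifying the unpacked layout before training. Below is a minimal sanity-check sketch; the root path is a placeholder, and the exact split folder names (e.g. `Seen/trainset/exocentric`) are an assumption about how AGD20K is packaged, not something this repo pins down:

```python
import os

# Placeholder path: point this at one split root of the unpacked AGD20K dataset.
data_root = "/path/to/AGD20K/Seen/trainset/exocentric"

# The loaders expect <affordance>/<object>/<image> under each root,
# so count images per affordance class as a quick consistency check.
for aff in sorted(os.listdir(data_root)):
    aff_dir = os.path.join(data_root, aff)
    n_imgs = sum(len(files) for _, _, files in os.walk(aff_dir))
    print(f"{aff}: {n_imgs} images")
```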
/assets/problem_setting.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reagan1311/LOCATE/8db3d015809b8b80cd8f1173b78f84686d77c3c0/assets/problem_setting.gif
--------------------------------------------------------------------------------
/data/datatest.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils import data
3 | from torchvision import transforms
4 | from PIL import Image
5 |
6 |
7 | class TestData(data.Dataset):
8 | def __init__(self, image_root, crop_size=224, divide="Seen", mask_root=None):
9 | self.image_root = image_root
10 | self.image_list = []
11 | self.crop_size = crop_size
12 | self.mask_root = mask_root
13 | if divide == "Seen":
14 | self.aff_list = ['beat', "boxing", "brush_with", "carry", "catch",
15 | "cut", "cut_with", "drag", 'drink_with', "eat",
16 | "hit", "hold", "jump", "kick", "lie_on", "lift",
17 | "look_out", "open", "pack", "peel", "pick_up",
18 | "pour", "push", "ride", "sip", "sit_on", "stick",
19 | "stir", "swing", "take_photo", "talk_on", "text_on",
20 | "throw", "type_on", "wash", "write"]
21 | self.obj_list = ['apple', 'axe', 'badminton_racket', 'banana', 'baseball', 'baseball_bat',
22 | 'basketball', 'bed', 'bench', 'bicycle', 'binoculars', 'book', 'bottle',
23 | 'bowl', 'broccoli', 'camera', 'carrot', 'cell_phone', 'chair', 'couch',
24 | 'cup', 'discus', 'drum', 'fork', 'frisbee', 'golf_clubs', 'hammer', 'hot_dog',
25 | 'javelin', 'keyboard', 'knife', 'laptop', 'microwave', 'motorcycle', 'orange',
26 | 'oven', 'pen', 'punching_bag', 'refrigerator', 'rugby_ball', 'scissors',
27 | 'skateboard', 'skis', 'snowboard', 'soccer_ball', 'suitcase', 'surfboard',
28 | 'tennis_racket', 'toothbrush', 'wine_glass']
29 | else:
30 | self.aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with',
31 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel",
32 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick",
33 | "swing", "take_photo", "throw", "type_on", "wash"]
34 | self.obj_list = ['apple', 'axe', 'badminton_racket', 'banana', 'baseball', 'baseball_bat',
35 | 'basketball', 'bed', 'bench', 'bicycle', 'binoculars', 'book', 'bottle',
36 | 'bowl', 'broccoli', 'camera', 'carrot', 'cell_phone', 'chair', 'couch',
37 | 'cup', 'discus', 'drum', 'fork', 'frisbee', 'golf_clubs', 'hammer', 'hot_dog',
38 | 'javelin', 'keyboard', 'knife', 'laptop', 'microwave', 'motorcycle', 'orange',
39 | 'oven', 'pen', 'punching_bag', 'refrigerator', 'rugby_ball', 'scissors',
40 | 'skateboard', 'skis', 'snowboard', 'soccer_ball', 'suitcase', 'surfboard',
41 | 'tennis_racket', 'toothbrush', 'wine_glass']
42 |
43 | self.transform = transforms.Compose([
44 | transforms.Resize((crop_size, crop_size)),
45 | transforms.ToTensor(),
46 | transforms.Normalize(mean=(0.485, 0.456, 0.406),
47 | std=(0.229, 0.224, 0.225))])
48 |
49 | files = os.listdir(self.image_root)
50 | for file in files:
51 | file_path = os.path.join(self.image_root, file)
52 | obj_files = os.listdir(file_path)
53 | for obj_file in obj_files:
54 | obj_file_path = os.path.join(file_path, obj_file)
55 | images = os.listdir(obj_file_path)
56 | for img in images:
57 | img_path = os.path.join(obj_file_path, img)
58 | mask_path = os.path.join(self.mask_root, file, obj_file, img[:-3] + "png")
59 |
60 | if os.path.exists(mask_path):
61 | self.image_list.append(img_path)
62 | # print(self.image_list)
63 |
64 | self.aff2obj_dict = dict()
65 | for aff in self.aff_list:
66 | aff_path = os.path.join(self.image_root, aff)
67 | aff_obj_list = os.listdir(aff_path)
68 | self.aff2obj_dict.update({aff: aff_obj_list})
69 |
70 | self.obj2aff_dict = dict()
71 | for obj in self.obj_list:
72 | obj2aff_list = []
73 | for k, v in self.aff2obj_dict.items():
74 | if obj in v:
75 | obj2aff_list.append(k)
76 | for i in range(len(obj2aff_list)):
77 | obj2aff_list[i] = self.aff_list.index(obj2aff_list[i])
78 | self.obj2aff_dict.update({obj: obj2aff_list})
79 |
80 | def __getitem__(self, item):
81 |
82 | image_path = self.image_list[item]
83 | names = image_path.split("/")
84 | aff_name, object = names[-3], names[-2]
85 |
86 | image = self.load_img(image_path)
87 | label = self.aff_list.index(aff_name)
88 | names = image_path.split("/")
89 | mask_path = os.path.join(self.mask_root, names[-3], names[-2], names[-1][:-3] + "png")
90 |
91 | return image, label, mask_path
92 |
93 | def load_img(self, path):
94 | img = Image.open(path).convert('RGB')
95 | img = self.transform(img)
96 | return img
97 |
98 | def __len__(self):
99 |
100 | return len(self.image_list)
101 |
--------------------------------------------------------------------------------
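For reference, a hedged sketch of how `TestData` might be wrapped in a `DataLoader`. The `image_root`/`mask_root` values are placeholders (the exact AGD20K sub-folder names are an assumption), and each batch yields the normalized image tensor, the affordance label index, and the path of the ground-truth mask:

```python
from torch.utils.data import DataLoader

from data.datatest import TestData

# Placeholder paths; both roots must follow the <affordance>/<object>/<image> layout.
testset = TestData(image_root="/path/to/AGD20K/Seen/testset/egocentric",
                   crop_size=224,
                   divide="Seen",
                   mask_root="/path/to/AGD20K/Seen/testset/GT")
loader = DataLoader(testset, batch_size=8, shuffle=False, num_workers=4)

for image, label, mask_path in loader:
    # image: B x 3 x 224 x 224, label: B affordance indices, mask_path: list of GT mask paths
    print(image.shape, label.shape, mask_path[0])
    break
```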
/data/datatrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | from PIL import Image
5 | from torch.utils import data
6 | from torchvision import transforms
7 |
8 |
9 | class TrainData(data.Dataset):
10 | def __init__(self, exocentric_root, egocentric_root, resize_size=256, crop_size=224, divide="Seen"):
11 |
12 | self.exocentric_root = exocentric_root
13 | self.egocentric_root = egocentric_root
14 |
15 | self.image_list = []
16 | self.exo_image_list = []
17 | self.resize_size = resize_size
18 | self.crop_size = crop_size
19 | if divide == "Seen":
20 | self.aff_list = ['beat', "boxing", "brush_with", "carry", "catch", "cut", "cut_with", "drag", 'drink_with',
21 | "eat", "hit", "hold", "jump", "kick", "lie_on", "lift", "look_out", "open", "pack", "peel",
22 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", "stir", "swing", "take_photo",
23 | "talk_on", "text_on", "throw", "type_on", "wash", "write"]
24 | self.obj_list = ['apple', 'axe', 'badminton_racket', 'banana', 'baseball', 'baseball_bat',
25 | 'basketball', 'bed', 'bench', 'bicycle', 'binoculars', 'book', 'bottle',
26 | 'bowl', 'broccoli', 'camera', 'carrot', 'cell_phone', 'chair', 'couch',
27 | 'cup', 'discus', 'drum', 'fork', 'frisbee', 'golf_clubs', 'hammer', 'hot_dog',
28 | 'javelin', 'keyboard', 'knife', 'laptop', 'microwave', 'motorcycle', 'orange',
29 | 'oven', 'pen', 'punching_bag', 'refrigerator', 'rugby_ball', 'scissors',
30 | 'skateboard', 'skis', 'snowboard', 'soccer_ball', 'suitcase', 'surfboard',
31 | 'tennis_racket', 'toothbrush', 'wine_glass']
32 | else:
33 | self.aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with',
34 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel",
35 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick",
36 | "swing", "take_photo", "throw", "type_on", "wash"]
37 | self.obj_list = ['apple', 'axe', 'badminton_racket', 'banana', 'baseball', 'baseball_bat',
38 | 'basketball', 'bed', 'bench', 'bicycle', 'binoculars', 'book', 'bottle',
39 | 'bowl', 'broccoli', 'camera', 'carrot', 'cell_phone', 'chair', 'couch',
40 | 'cup', 'discus', 'drum', 'fork', 'frisbee', 'golf_clubs', 'hammer', 'hot_dog',
41 | 'javelin', 'keyboard', 'knife', 'laptop', 'microwave', 'motorcycle', 'orange',
42 | 'oven', 'pen', 'punching_bag', 'refrigerator', 'rugby_ball', 'scissors',
43 | 'skateboard', 'skis', 'snowboard', 'soccer_ball', 'suitcase', 'surfboard',
44 | 'tennis_racket', 'toothbrush', 'wine_glass']
45 |
46 | self.transform = transforms.Compose([
47 | transforms.Resize(resize_size),
48 | transforms.RandomCrop(crop_size),
49 | transforms.RandomHorizontalFlip(),
50 | transforms.ToTensor(),
51 | transforms.Normalize(mean=(0.485, 0.456, 0.406),
52 | std=(0.229, 0.224, 0.225))])
53 |
54 | # image list for exocentric images
55 | files = os.listdir(self.exocentric_root)
56 | for file in files:
57 | file_path = os.path.join(self.exocentric_root, file)
58 | obj_files = os.listdir(file_path)
59 | for obj_file in obj_files:
60 | obj_file_path = os.path.join(file_path, obj_file)
61 | images = os.listdir(obj_file_path)
62 | for img in images:
63 | img_path = os.path.join(obj_file_path, img)
64 | self.image_list.append(img_path)
65 |
66 | # multiple affordance labels for exo-centric samples
67 |
68 | def __getitem__(self, item):
69 |
70 | # load exocentric image
71 | exocentric_image_path = self.image_list[item]
72 | names = exocentric_image_path.split("/")
73 | aff_name, object = names[-3], names[-2]
74 | exocentric_image = self.load_img(exocentric_image_path)
75 | aff_label = self.aff_list.index(aff_name)
76 |
77 | ego_path = os.path.join(self.egocentric_root, aff_name, object)
78 | obj_images = os.listdir(ego_path)
79 | idx = random.randint(0, len(obj_images) - 1)
80 | egocentric_image_path = os.path.join(ego_path, obj_images[idx])
81 | egocentric_image = self.load_img(egocentric_image_path)
82 |
83 | # choose & load additional exo-centric images from the same affordance/object directory
84 | num_exo = 3
85 | exo_dir = os.path.dirname(exocentric_image_path)
86 | exocentrics = os.listdir(exo_dir)
87 | exo_img_name = [os.path.basename(exocentric_image_path)]
88 | exocentric_images = [exocentric_image]
89 | # exocentric_labels = []
90 |
91 | if len(exocentrics) > num_exo:
92 | for i in range(num_exo - 1):
93 | exo_img_ = random.choice(exocentrics)
94 | while exo_img_ in exo_img_name:
95 | exo_img_ = random.choice(exocentrics)
96 | exo_img_name.append(exo_img_)
97 | tmp_exo = self.load_img(os.path.join(exo_dir, exo_img_))
98 | exocentric_images.append(tmp_exo)
99 | else:
100 | for i in range(num_exo - 1):
101 | exo_img_ = random.choice(exocentrics)
102 | # while exo_img_ in exo_img_name:
103 | # exo_img_ = random.choice(exocentrics)
104 | exo_img_name.append(exo_img_)
105 | tmp_exo = self.load_img(os.path.join(exo_dir, exo_img_))
106 | exocentric_images.append(tmp_exo)
107 |
108 | exocentric_images = torch.stack(exocentric_images, dim=0) # n x 3 x 224 x 224
109 |
110 | return exocentric_images, egocentric_image, aff_label
111 |
112 | def load_img(self, path):
113 | img = Image.open(path).convert('RGB')
114 | img = self.transform(img)
115 | return img
116 |
117 | def __len__(self):
118 |
119 | return len(self.image_list)
120 |
--------------------------------------------------------------------------------
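Similarly, a hedged usage sketch for `TrainData` (paths are placeholders): each sample pairs three exocentric views with one randomly drawn egocentric image of the same affordance/object, so the batch shapes come out as below.

```python
from torch.utils.data import DataLoader

from data.datatrain import TrainData

# Placeholder AGD20K paths; both roots follow <affordance>/<object>/<image>.
trainset = TrainData(exocentric_root="/path/to/AGD20K/Seen/trainset/exocentric",
                     egocentric_root="/path/to/AGD20K/Seen/trainset/egocentric",
                     resize_size=256, crop_size=224, divide="Seen")
loader = DataLoader(trainset, batch_size=16, shuffle=True, num_workers=8)

exo, ego, aff_label = next(iter(loader))
print(exo.shape)        # B x 3 x 3 x 224 x 224 (three exocentric views per sample)
print(ego.shape)        # B x 3 x 224 x 224
print(aff_label.shape)  # B
```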
/models/dino/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Misc functions.
16 |
17 | Mostly copy-paste from torchvision references or other public repos like DETR:
18 | https://github.com/facebookresearch/detr/blob/master/util/misc.py
19 | """
20 | import os
21 | import sys
22 | import time
23 | import math
24 | import random
25 | import argparse  # used by bool_flag below
26 | import datetime
27 | import warnings  # used by _no_grad_trunc_normal_ below
28 | import subprocess
29 | from collections import defaultdict, deque
30 | import numpy as np
31 | import torch
32 | from torch import nn
33 | import torch.distributed as dist
34 | from PIL import ImageFilter, ImageOps
35 |
36 | class GaussianBlur(object):
37 | """
38 | Apply Gaussian Blur to the PIL image.
39 | """
40 |
41 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
42 | self.prob = p
43 | self.radius_min = radius_min
44 | self.radius_max = radius_max
45 |
46 | def __call__(self, img):
47 | do_it = random.random() <= self.prob
48 | if not do_it:
49 | return img
50 |
51 | return img.filter(
52 | ImageFilter.GaussianBlur(
53 | radius=random.uniform(self.radius_min, self.radius_max)
54 | )
55 | )
56 |
57 |
58 | class Solarization(object):
59 | """
60 | Apply Solarization to the PIL image.
61 | """
62 |
63 | def __init__(self, p):
64 | self.p = p
65 |
66 | def __call__(self, img):
67 | if random.random() < self.p:
68 | return ImageOps.solarize(img)
69 | else:
70 | return img
71 |
72 |
73 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
74 | if os.path.isfile(pretrained_weights):
75 | state_dict = torch.load(pretrained_weights, map_location="cpu")
76 | if checkpoint_key is not None and checkpoint_key in state_dict:
77 | print(f"Take key {checkpoint_key} in provided checkpoint dict")
78 | state_dict = state_dict[checkpoint_key]
79 | # remove `module.` prefix
80 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
81 | # remove `backbone.` prefix induced by multicrop wrapper
82 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
83 | msg = model.load_state_dict(state_dict, strict=False)
84 | print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
85 | else:
86 | # print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
87 | url = None
88 | if model_name == "vit_small" and patch_size == 16:
89 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
90 | elif model_name == "vit_small" and patch_size == 8:
91 | url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
92 | elif model_name == "vit_base" and patch_size == 16:
93 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
94 | elif model_name == "vit_base" and patch_size == 8:
95 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
96 | if url is not None:
97 | # print("Since no pretrained weights have been provided, we load the reference pretrained dino weights.")
98 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
99 | model.load_state_dict(state_dict, strict=True)
100 | else:
101 | print("There is no reference weights available for this model => We use random weights.")
102 |
103 |
104 | def clip_gradients(model, clip):
105 | norms = []
106 | for name, p in model.named_parameters():
107 | if p.grad is not None:
108 | param_norm = p.grad.data.norm(2)
109 | norms.append(param_norm.item())
110 | clip_coef = clip / (param_norm + 1e-6)
111 | if clip_coef < 1:
112 | p.grad.data.mul_(clip_coef)
113 | return norms
114 |
115 |
116 | def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
117 | if epoch >= freeze_last_layer:
118 | return
119 | for n, p in model.named_parameters():
120 | if "last_layer" in n:
121 | p.grad = None
122 |
123 |
124 | def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
125 | """
126 | Re-start from checkpoint
127 | """
128 | if not os.path.isfile(ckp_path):
129 | return
130 | print("Found checkpoint at {}".format(ckp_path))
131 |
132 | # open checkpoint file
133 | checkpoint = torch.load(ckp_path, map_location="cpu")
134 |
135 | # key is what to look for in the checkpoint file
136 | # value is the object to load
137 | # example: {'state_dict': model}
138 | for key, value in kwargs.items():
139 | if key in checkpoint and value is not None:
140 | try:
141 | msg = value.load_state_dict(checkpoint[key], strict=False)
142 | print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
143 | except TypeError:
144 | try:
145 | msg = value.load_state_dict(checkpoint[key])
146 | print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
147 | except ValueError:
148 | print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))
149 | else:
150 | print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
151 |
152 | # re load variable important for the run
153 | if run_variables is not None:
154 | for var_name in run_variables:
155 | if var_name in checkpoint:
156 | run_variables[var_name] = checkpoint[var_name]
157 |
158 |
159 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
160 | warmup_schedule = np.array([])
161 | warmup_iters = warmup_epochs * niter_per_ep
162 | if warmup_epochs > 0:
163 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
164 |
165 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
166 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
167 |
168 | schedule = np.concatenate((warmup_schedule, schedule))
169 | assert len(schedule) == epochs * niter_per_ep
170 | return schedule
171 |
172 |
173 | def bool_flag(s):
174 | """
175 | Parse boolean arguments from the command line.
176 | """
177 | FALSY_STRINGS = {"off", "false", "0"}
178 | TRUTHY_STRINGS = {"on", "true", "1"}
179 | if s.lower() in FALSY_STRINGS:
180 | return False
181 | elif s.lower() in TRUTHY_STRINGS:
182 | return True
183 | else:
184 | raise argparse.ArgumentTypeError("invalid value for a boolean flag")
185 |
186 |
187 | def fix_random_seeds(seed=31):
188 | """
189 | Fix random seeds.
190 | """
191 | torch.manual_seed(seed)
192 | torch.cuda.manual_seed_all(seed)
193 | np.random.seed(seed)
194 |
195 |
196 | class SmoothedValue(object):
197 | """Track a series of values and provide access to smoothed values over a
198 | window or the global series average.
199 | """
200 |
201 | def __init__(self, window_size=20, fmt=None):
202 | if fmt is None:
203 | fmt = "{median:.6f} ({global_avg:.6f})"
204 | self.deque = deque(maxlen=window_size)
205 | self.total = 0.0
206 | self.count = 0
207 | self.fmt = fmt
208 |
209 | def update(self, value, n=1):
210 | self.deque.append(value)
211 | self.count += n
212 | self.total += value * n
213 |
214 | def synchronize_between_processes(self):
215 | """
216 | Warning: does not synchronize the deque!
217 | """
218 | if not is_dist_avail_and_initialized():
219 | return
220 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
221 | dist.barrier()
222 | dist.all_reduce(t)
223 | t = t.tolist()
224 | self.count = int(t[0])
225 | self.total = t[1]
226 |
227 | @property
228 | def median(self):
229 | d = torch.tensor(list(self.deque))
230 | return d.median().item()
231 |
232 | @property
233 | def avg(self):
234 | d = torch.tensor(list(self.deque), dtype=torch.float32)
235 | return d.mean().item()
236 |
237 | @property
238 | def global_avg(self):
239 | return self.total / self.count
240 |
241 | @property
242 | def max(self):
243 | return max(self.deque)
244 |
245 | @property
246 | def value(self):
247 | return self.deque[-1]
248 |
249 | def __str__(self):
250 | return self.fmt.format(
251 | median=self.median,
252 | avg=self.avg,
253 | global_avg=self.global_avg,
254 | max=self.max,
255 | value=self.value)
256 |
257 |
258 | def reduce_dict(input_dict, average=True):
259 | """
260 | Args:
261 | input_dict (dict): all the values will be reduced
262 | average (bool): whether to do average or sum
263 | Reduce the values in the dictionary from all processes so that all processes
264 | have the averaged results. Returns a dict with the same fields as
265 | input_dict, after reduction.
266 | """
267 | world_size = get_world_size()
268 | if world_size < 2:
269 | return input_dict
270 | with torch.no_grad():
271 | names = []
272 | values = []
273 | # sort the keys so that they are consistent across processes
274 | for k in sorted(input_dict.keys()):
275 | names.append(k)
276 | values.append(input_dict[k])
277 | values = torch.stack(values, dim=0)
278 | dist.all_reduce(values)
279 | if average:
280 | values /= world_size
281 | reduced_dict = {k: v for k, v in zip(names, values)}
282 | return reduced_dict
283 |
284 |
285 | class MetricLogger(object):
286 | def __init__(self, delimiter="\t"):
287 | self.meters = defaultdict(SmoothedValue)
288 | self.delimiter = delimiter
289 |
290 | def update(self, **kwargs):
291 | for k, v in kwargs.items():
292 | if isinstance(v, torch.Tensor):
293 | v = v.item()
294 | assert isinstance(v, (float, int))
295 | self.meters[k].update(v)
296 |
297 | def __getattr__(self, attr):
298 | if attr in self.meters:
299 | return self.meters[attr]
300 | if attr in self.__dict__:
301 | return self.__dict__[attr]
302 | raise AttributeError("'{}' object has no attribute '{}'".format(
303 | type(self).__name__, attr))
304 |
305 | def __str__(self):
306 | loss_str = []
307 | for name, meter in self.meters.items():
308 | loss_str.append(
309 | "{}: {}".format(name, str(meter))
310 | )
311 | return self.delimiter.join(loss_str)
312 |
313 | def synchronize_between_processes(self):
314 | for meter in self.meters.values():
315 | meter.synchronize_between_processes()
316 |
317 | def add_meter(self, name, meter):
318 | self.meters[name] = meter
319 |
320 | def log_every(self, iterable, print_freq, header=None):
321 | i = 0
322 | if not header:
323 | header = ''
324 | start_time = time.time()
325 | end = time.time()
326 | iter_time = SmoothedValue(fmt='{avg:.6f}')
327 | data_time = SmoothedValue(fmt='{avg:.6f}')
328 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
329 | if torch.cuda.is_available():
330 | log_msg = self.delimiter.join([
331 | header,
332 | '[{0' + space_fmt + '}/{1}]',
333 | 'eta: {eta}',
334 | '{meters}',
335 | 'time: {time}',
336 | 'data: {data}',
337 | 'max mem: {memory:.0f}'
338 | ])
339 | else:
340 | log_msg = self.delimiter.join([
341 | header,
342 | '[{0' + space_fmt + '}/{1}]',
343 | 'eta: {eta}',
344 | '{meters}',
345 | 'time: {time}',
346 | 'data: {data}'
347 | ])
348 | MB = 1024.0 * 1024.0
349 | for obj in iterable:
350 | data_time.update(time.time() - end)
351 | yield obj
352 | iter_time.update(time.time() - end)
353 | if i % print_freq == 0 or i == len(iterable) - 1:
354 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
355 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
356 | if torch.cuda.is_available():
357 | print(log_msg.format(
358 | i, len(iterable), eta=eta_string,
359 | meters=str(self),
360 | time=str(iter_time), data=str(data_time),
361 | memory=torch.cuda.max_memory_allocated() / MB))
362 | else:
363 | print(log_msg.format(
364 | i, len(iterable), eta=eta_string,
365 | meters=str(self),
366 | time=str(iter_time), data=str(data_time)))
367 | i += 1
368 | end = time.time()
369 | total_time = time.time() - start_time
370 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
371 | print('{} Total time: {} ({:.6f} s / it)'.format(
372 | header, total_time_str, total_time / len(iterable)))
373 |
374 |
375 | def get_sha():
376 | cwd = os.path.dirname(os.path.abspath(__file__))
377 |
378 | def _run(command):
379 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
380 |
381 | sha = 'N/A'
382 | diff = "clean"
383 | branch = 'N/A'
384 | try:
385 | sha = _run(['git', 'rev-parse', 'HEAD'])
386 | subprocess.check_output(['git', 'diff'], cwd=cwd)
387 | diff = _run(['git', 'diff-index', 'HEAD'])
388 | diff = "has uncommited changes" if diff else "clean"
389 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
390 | except Exception:
391 | pass
392 | message = f"sha: {sha}, status: {diff}, branch: {branch}"
393 | return message
394 |
395 |
396 | def is_dist_avail_and_initialized():
397 | if not dist.is_available():
398 | return False
399 | if not dist.is_initialized():
400 | return False
401 | return True
402 |
403 |
404 | def get_world_size():
405 | if not is_dist_avail_and_initialized():
406 | return 1
407 | return dist.get_world_size()
408 |
409 |
410 | def get_rank():
411 | if not is_dist_avail_and_initialized():
412 | return 0
413 | return dist.get_rank()
414 |
415 |
416 | def is_main_process():
417 | return get_rank() == 0
418 |
419 |
420 | def save_on_master(*args, **kwargs):
421 | if is_main_process():
422 | torch.save(*args, **kwargs)
423 |
424 |
425 | def setup_for_distributed(is_master):
426 | """
427 | This function disables printing when not in master process
428 | """
429 | import builtins as __builtin__
430 | builtin_print = __builtin__.print
431 |
432 | def print(*args, **kwargs):
433 | force = kwargs.pop('force', False)
434 | if is_master or force:
435 | builtin_print(*args, **kwargs)
436 |
437 | __builtin__.print = print
438 |
439 |
440 | def init_distributed_mode(args):
441 | # launched with torch.distributed.launch
442 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
443 | args.rank = int(os.environ["RANK"])
444 | args.world_size = int(os.environ['WORLD_SIZE'])
445 | args.gpu = int(os.environ['LOCAL_RANK'])
446 | # launched with submitit on a slurm cluster
447 | elif 'SLURM_PROCID' in os.environ:
448 | args.rank = int(os.environ['SLURM_PROCID'])
449 | args.gpu = args.rank % torch.cuda.device_count()
450 | # launched naively with `python main_dino.py`
451 | # we manually add MASTER_ADDR and MASTER_PORT to env variables
452 | elif torch.cuda.is_available():
453 | print('Will run the code on one GPU.')
454 | args.rank, args.gpu, args.world_size = 0, 0, 1
455 | os.environ['MASTER_ADDR'] = '127.0.0.1'
456 | os.environ['MASTER_PORT'] = '29500'
457 | else:
458 | print('Does not support training without GPU.')
459 | sys.exit(1)
460 |
461 | dist.init_process_group(
462 | backend="nccl",
463 | init_method=args.dist_url,
464 | world_size=args.world_size,
465 | rank=args.rank,
466 | )
467 |
468 | torch.cuda.set_device(args.gpu)
469 | print('| distributed init (rank {}): {}'.format(
470 | args.rank, args.dist_url), flush=True)
471 | dist.barrier()
472 | setup_for_distributed(args.rank == 0)
473 |
474 |
475 | def accuracy(output, target, topk=(1,)):
476 | """Computes the accuracy over the k top predictions for the specified values of k"""
477 | maxk = max(topk)
478 | batch_size = target.size(0)
479 | _, pred = output.topk(maxk, 1, True, True)
480 | pred = pred.t()
481 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
482 | return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
483 |
484 |
485 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
486 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
487 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
488 | def norm_cdf(x):
489 | # Computes standard normal cumulative distribution function
490 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
491 |
492 | if (mean < a - 2 * std) or (mean > b + 2 * std):
493 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
494 | "The distribution of values may be incorrect.",
495 | stacklevel=2)
496 |
497 | with torch.no_grad():
498 | # Values are generated by using a truncated uniform distribution and
499 | # then using the inverse CDF for the normal distribution.
500 | # Get upper and lower cdf values
501 | l = norm_cdf((a - mean) / std)
502 | u = norm_cdf((b - mean) / std)
503 |
504 | # Uniformly fill tensor with values from [l, u], then translate to
505 | # [2l-1, 2u-1].
506 | tensor.uniform_(2 * l - 1, 2 * u - 1)
507 |
508 | # Use inverse cdf transform for normal distribution to get truncated
509 | # standard normal
510 | tensor.erfinv_()
511 |
512 | # Transform to proper mean, std
513 | tensor.mul_(std * math.sqrt(2.))
514 | tensor.add_(mean)
515 |
516 | # Clamp to ensure it's in the proper range
517 | tensor.clamp_(min=a, max=b)
518 | return tensor
519 |
520 |
521 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
522 | # type: (Tensor, float, float, float, float) -> Tensor
523 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
524 |
525 |
526 | class LARS(torch.optim.Optimizer):
527 | """
528 | Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
529 | """
530 |
531 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
532 | weight_decay_filter=None, lars_adaptation_filter=None):
533 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
534 | eta=eta, weight_decay_filter=weight_decay_filter,
535 | lars_adaptation_filter=lars_adaptation_filter)
536 | super().__init__(params, defaults)
537 |
538 | @torch.no_grad()
539 | def step(self):
540 | for g in self.param_groups:
541 | for p in g['params']:
542 | dp = p.grad
543 |
544 | if dp is None:
545 | continue
546 |
547 | if p.ndim != 1:
548 | dp = dp.add(p, alpha=g['weight_decay'])
549 |
550 | if p.ndim != 1:
551 | param_norm = torch.norm(p)
552 | update_norm = torch.norm(dp)
553 | one = torch.ones_like(param_norm)
554 | q = torch.where(param_norm > 0.,
555 | torch.where(update_norm > 0,
556 | (g['eta'] * param_norm / update_norm), one), one)
557 | dp = dp.mul(q)
558 |
559 | param_state = self.state[p]
560 | if 'mu' not in param_state:
561 | param_state['mu'] = torch.zeros_like(p)
562 | mu = param_state['mu']
563 | mu.mul_(g['momentum']).add_(dp)
564 |
565 | p.add_(mu, alpha=-g['lr'])
566 |
567 |
568 | class MultiCropWrapper(nn.Module):
569 | """
570 | Perform forward pass separately on each resolution input.
571 | The inputs corresponding to a single resolution are clubbed and single
572 | forward is run on the same resolution inputs. Hence we do several
573 | forward passes = number of different resolutions used. We then
574 | concatenate all the output features and run the head forward on these
575 | concatenated features.
576 | """
577 |
578 | def __init__(self, backbone, head):
579 | super(MultiCropWrapper, self).__init__()
580 | # disable layers dedicated to ImageNet labels classification
581 | backbone.fc, backbone.head = nn.Identity(), nn.Identity()
582 | self.backbone = backbone
583 | self.head = head
584 |
585 | def forward(self, x):
586 | # convert to list
587 | if not isinstance(x, list):
588 | x = [x]
589 | idx_crops = torch.cumsum(torch.unique_consecutive(
590 | torch.tensor([inp.shape[-1] for inp in x]),
591 | return_counts=True,
592 | )[1], 0)
593 | start_idx = 0
594 | for end_idx in idx_crops:
595 | _out = self.backbone(torch.cat(x[start_idx: end_idx]))
596 | if start_idx == 0:
597 | output = _out
598 | else:
599 | output = torch.cat((output, _out))
600 | start_idx = end_idx
601 | # Run the head forward on the concatenated features.
602 | return self.head(output)
603 |
604 |
605 | def get_params_groups(model):
606 | regularized = []
607 | not_regularized = []
608 | for name, param in model.named_parameters():
609 | if not param.requires_grad:
610 | continue
611 | # we do not regularize biases nor Norm parameters
612 | if name.endswith(".bias") or len(param.shape) == 1:
613 | not_regularized.append(param)
614 | else:
615 | regularized.append(param)
616 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
617 |
618 |
619 | def has_batchnorms(model):
620 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
621 | for name, module in model.named_modules():
622 | if isinstance(module, bn_types):
623 | return True
624 | return False
625 |
--------------------------------------------------------------------------------
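Most of this file is distributed-training plumbing, but `cosine_scheduler` is handy on its own: it returns one value per training iteration, with an optional linear warmup followed by cosine decay. A small self-contained check (the hyper-parameter values below are arbitrary, for illustration only):

```python
from models.dino.utils import cosine_scheduler

# Linear warmup for 2 epochs, then cosine decay from base_value to final_value.
epochs, niter_per_ep = 10, 100
lr_schedule = cosine_scheduler(base_value=1e-3, final_value=1e-6,
                               epochs=epochs, niter_per_ep=niter_per_ep,
                               warmup_epochs=2, start_warmup_value=0)

assert len(lr_schedule) == epochs * niter_per_ep  # one entry per iteration
print(lr_schedule[0])    # start of warmup (0 here)
print(lr_schedule[200])  # first iteration after warmup: the base value
print(lr_schedule[-1])   # close to the final value at the end of training
```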
/models/dino/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Mostly copy-paste from timm library.
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17 | """
18 | import math
19 | from functools import partial
20 |
21 | import torch
22 | import torch.nn as nn
23 |
24 | from models.dino.utils import trunc_normal_
25 |
26 |
27 | def drop_path(x, drop_prob: float = 0., training: bool = False):
28 | if drop_prob == 0. or not training:
29 | return x
30 | keep_prob = 1 - drop_prob
31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
32 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
33 | random_tensor.floor_() # binarize
34 | output = x.div(keep_prob) * random_tensor
35 | return output
36 |
37 |
38 | class DropPath(nn.Module):
39 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
40 | """
41 |
42 | def __init__(self, drop_prob=None):
43 | super(DropPath, self).__init__()
44 | self.drop_prob = drop_prob
45 |
46 | def forward(self, x):
47 | return drop_path(x, self.drop_prob, self.training)
48 |
49 |
50 | class Mlp(nn.Module):
51 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
52 | super().__init__()
53 | out_features = out_features or in_features
54 | hidden_features = hidden_features or in_features
55 | self.fc1 = nn.Linear(in_features, hidden_features)
56 | self.act = act_layer()
57 | self.fc2 = nn.Linear(hidden_features, out_features)
58 | self.drop = nn.Dropout(drop)
59 |
60 | def forward(self, x):
61 | x = self.fc1(x)
62 | x = self.act(x)
63 | x = self.drop(x)
64 | x = self.fc2(x)
65 | x = self.drop(x)
66 | return x
67 |
68 |
69 | class Attention(nn.Module):
70 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
71 | super().__init__()
72 | self.num_heads = num_heads
73 | head_dim = dim // num_heads
74 | self.scale = qk_scale or head_dim ** -0.5
75 |
76 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
77 | self.attn_drop = nn.Dropout(attn_drop)
78 | self.proj = nn.Linear(dim, dim)
79 | self.proj_drop = nn.Dropout(proj_drop)
80 |
81 | def forward(self, x, return_key=False):
82 | B, N, C = x.shape
83 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
84 | q, k, v = qkv[0], qkv[1], qkv[2]
85 |
86 | attn = (q @ k.transpose(-2, -1)) * self.scale
87 | attn = attn.softmax(dim=-1)
88 | attn = self.attn_drop(attn)
89 |
90 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
91 | x = self.proj(x)
92 | x = self.proj_drop(x)
93 | if not return_key:
94 | return x, attn
95 | else:
96 | return x, attn, k
97 |
98 |
99 | class Block(nn.Module):
100 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
101 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
102 | super().__init__()
103 | self.norm1 = norm_layer(dim)
104 | self.attn = Attention(
105 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
106 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
107 | self.norm2 = norm_layer(dim)
108 | mlp_hidden_dim = int(dim * mlp_ratio)
109 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
110 |
111 | def forward(self, x, return_attention=False, return_key=False):
112 | if return_key:
113 | y, attn, key = self.attn(self.norm1(x), return_key)
114 | else:
115 | y, attn = self.attn(self.norm1(x))
116 | x = x + self.drop_path(y)
117 | x = x + self.drop_path(self.mlp(self.norm2(x)))
118 | if return_attention:
119 | return x, attn
120 | elif return_key:
121 | return x, key, attn
122 | else:
123 | return x
124 |
125 |
126 | class PatchEmbed(nn.Module):
127 | """ Image to Patch Embedding
128 | """
129 |
130 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
131 | super().__init__()
132 | num_patches = (img_size // patch_size) * (img_size // patch_size)
133 | self.img_size = img_size
134 | self.patch_size = patch_size
135 | self.num_patches = num_patches
136 |
137 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
138 |
139 | def forward(self, x):
140 | B, C, H, W = x.shape
141 | x = self.proj(x).flatten(2).transpose(1, 2)
142 | return x
143 |
144 |
145 | class VisionTransformer(nn.Module):
146 | """ Vision Transformer """
147 |
148 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
149 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
150 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
151 | super().__init__()
152 | self.num_features = self.embed_dim = embed_dim
153 |
154 | self.patch_embed = PatchEmbed(
155 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
156 | num_patches = self.patch_embed.num_patches
157 |
158 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
159 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
160 | self.pos_drop = nn.Dropout(p=drop_rate)
161 |
162 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
163 | self.blocks = nn.ModuleList([
164 | Block(
165 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
166 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
167 | for i in range(depth)])
168 | self.norm = norm_layer(embed_dim)
169 |
170 | # Classifier head
171 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
172 |
173 | trunc_normal_(self.pos_embed, std=.02)
174 | trunc_normal_(self.cls_token, std=.02)
175 | self.apply(self._init_weights)
176 |
177 | def _init_weights(self, m):
178 | if isinstance(m, nn.Linear):
179 | trunc_normal_(m.weight, std=.02)
180 | if isinstance(m, nn.Linear) and m.bias is not None:
181 | nn.init.constant_(m.bias, 0)
182 | elif isinstance(m, nn.LayerNorm):
183 | nn.init.constant_(m.bias, 0)
184 | nn.init.constant_(m.weight, 1.0)
185 |
186 | def interpolate_pos_encoding(self, x, w, h):
187 | npatch = x.shape[1] - 1
188 | N = self.pos_embed.shape[1] - 1
189 | if npatch == N and w == h:
190 | return self.pos_embed
191 | class_pos_embed = self.pos_embed[:, 0]
192 | patch_pos_embed = self.pos_embed[:, 1:]
193 | dim = x.shape[-1]
194 | w0 = w // self.patch_embed.patch_size
195 | h0 = h // self.patch_embed.patch_size
196 | # we add a small number to avoid floating point error in the interpolation
197 | # see discussion at https://github.com/facebookresearch/dino/issues/8
198 | w0, h0 = w0 + 0.1, h0 + 0.1
199 | patch_pos_embed = nn.functional.interpolate(
200 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
201 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
202 | mode='bicubic',
203 | )
204 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
205 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
206 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
207 |
208 | def prepare_tokens(self, x):
209 | B, nc, w, h = x.shape
210 | x = self.patch_embed(x) # patch linear embedding
211 |
212 | # add the [CLS] token to the embed patch tokens
213 | cls_tokens = self.cls_token.expand(B, -1, -1)
214 | x = torch.cat((cls_tokens, x), dim=1)
215 |
216 | # add positional encoding to each token
217 | x = x + self.interpolate_pos_encoding(x, w, h)
218 |
219 | return self.pos_drop(x)
220 |
221 | def forward(self, x, return_attention=False):
222 | if return_attention:
223 | atten_weights = []
224 | x = self.prepare_tokens(x)
225 | for blk_ in self.blocks:
226 | x, weights = blk_(x, return_attention)
227 | atten_weights.append(weights)
228 | x = self.norm(x)
229 | return x, atten_weights
230 |
231 | else:
232 | x = self.prepare_tokens(x)
233 | for blk_ in self.blocks:
234 | x = blk_(x)
235 | x = self.norm(x)
236 | return x
237 |
238 | def get_last_selfattention(self, x):
239 | x = self.prepare_tokens(x)
240 | for i, blk in enumerate(self.blocks):
241 | if i < len(self.blocks) - 1:
242 | x = blk(x)
243 | else:
244 | # return attention of the last block
245 | return blk(x, return_attention=True)
246 |
247 | def get_intermediate_layers(self, x, n=1):
248 | x = self.prepare_tokens(x)
249 | # we return the output tokens from the `n` last blocks
250 | output = []
251 | for i, blk in enumerate(self.blocks):
252 | x = blk(x)
253 | if len(self.blocks) - i <= n:
254 | output.append(self.norm(x))
255 | return output
256 |
257 | def get_last_key(self, x, extra_layer=None):
258 | x = self.prepare_tokens(x)
259 | key_mid = 0
260 | for i, blk in enumerate(self.blocks):
261 | if extra_layer is not None and i == extra_layer:
262 | x, key, attn = blk(x, return_key=True)
263 | key_mid = key
264 | elif i < len(self.blocks) - 1:
265 | x = blk(x)
266 | else:
267 | # return attention of the last block
268 | x, key, attn = blk(x, return_key=True)
269 | if extra_layer is None:
270 | return x, key, attn
271 | else:
272 | return key_mid, x, key, attn
273 |
274 |
275 | def vit_tiny(patch_size=16, **kwargs):
276 | model = VisionTransformer(
277 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
278 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
279 | return model
280 |
281 |
282 | def vit_small(patch_size=16, **kwargs):
283 | model = VisionTransformer(
284 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
285 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
286 | return model
287 |
288 |
289 | def vit_base(patch_size=16, **kwargs):
290 | model = VisionTransformer(
291 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
292 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
293 | return model
294 |
295 |
296 | class DINOHead(nn.Module):
297 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
298 | bottleneck_dim=256):
299 | super().__init__()
300 | nlayers = max(nlayers, 1)
301 | if nlayers == 1:
302 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
303 | else:
304 | layers = [nn.Linear(in_dim, hidden_dim)]
305 | if use_bn:
306 | layers.append(nn.BatchNorm1d(hidden_dim))
307 | layers.append(nn.GELU())
308 | for _ in range(nlayers - 2):
309 | layers.append(nn.Linear(hidden_dim, hidden_dim))
310 | if use_bn:
311 | layers.append(nn.BatchNorm1d(hidden_dim))
312 | layers.append(nn.GELU())
313 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
314 | self.mlp = nn.Sequential(*layers)
315 | self.apply(self._init_weights)
316 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
317 | self.last_layer.weight_g.data.fill_(1)
318 | if norm_last_layer:
319 | self.last_layer.weight_g.requires_grad = False
320 |
321 | def _init_weights(self, m):
322 | if isinstance(m, nn.Linear):
323 | trunc_normal_(m.weight, std=.02)
324 | if isinstance(m, nn.Linear) and m.bias is not None:
325 | nn.init.constant_(m.bias, 0)
326 |
327 | def forward(self, x):
328 | x = self.mlp(x)
329 | x = nn.functional.normalize(x, dim=-1, p=2)
330 | x = self.last_layer(x)
331 | return x
332 |
--------------------------------------------------------------------------------
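As a minimal sketch of the interface used by LOCATE, the snippet below builds a randomly initialized ViT-S/16 (no DINO checkpoint is loaded here) and inspects what `get_last_key` returns for a 224x224 batch; the shape comments follow from `Attention.forward` with 6 heads and 14x14 = 196 patch tokens plus the [CLS] token:

```python
import torch

from models.dino import vision_transformer as vits

model = vits.vit_small(patch_size=16)  # random weights, illustration only
model.eval()

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    tokens, key, attn = model.get_last_key(x)

print(tokens.shape)  # torch.Size([2, 197, 384])   token embeddings after the last block
print(key.shape)     # torch.Size([2, 6, 197, 64]) per-head keys of the last block
print(attn.shape)    # torch.Size([2, 6, 197, 197]) self-attention of the last block
```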
/models/locate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.dino import vision_transformer as vits
5 | from models.dino.utils import load_pretrained_weights
6 | from models.model_util import *
7 | from fast_pytorch_kmeans import KMeans
8 |
9 |
10 | class Mlp(nn.Module):
11 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
12 | super(Mlp, self).__init__()
13 | out_features = out_features or in_features
14 | hidden_features = hidden_features or in_features
15 | self.norm = nn.LayerNorm(in_features)
16 | self.fc1 = nn.Linear(in_features, hidden_features)
17 | self.act = act_layer()
18 | self.fc2 = nn.Linear(hidden_features, out_features)
19 | self.drop = nn.Dropout(drop)
20 |
21 | def forward(self, x):
22 | x = self.norm(x)
23 | x = self.fc1(x)
24 | x = self.act(x)
25 | x = self.drop(x)
26 | x = self.fc2(x)
27 | x = self.drop(x)
28 | return x
29 |
30 |
31 | class Net(nn.Module):
32 |
33 | def __init__(self, aff_classes=36):
34 | super(Net, self).__init__()
35 |
36 | self.aff_classes = aff_classes
37 | self.gap = nn.AdaptiveAvgPool2d(1)
38 |
39 | # --- hyper-parameters --- #
40 | self.aff_cam_thd = 0.6
41 | self.part_iou_thd = 0.6
42 | self.cel_margin = 0.5
43 |
44 | # --- dino-vit features --- #
45 | self.vit_feat_dim = 384
46 | self.cluster_num = 3
47 | self.stride = 16
48 | self.patch = 16
49 |
50 | self.vit_model = vits.__dict__['vit_small'](patch_size=self.patch, num_classes=0)
51 | load_pretrained_weights(self.vit_model, '', None, 'vit_small', self.patch)
52 |
53 | # --- learning parameters --- #
54 | self.aff_proj = Mlp(in_features=self.vit_feat_dim, hidden_features=int(self.vit_feat_dim * 4),
55 | act_layer=nn.GELU, drop=0.)
56 | self.aff_ego_proj = nn.Sequential(
57 | nn.Conv2d(self.vit_feat_dim, self.vit_feat_dim, kernel_size=3, stride=1, padding=1),
58 | nn.BatchNorm2d(self.vit_feat_dim),
59 | nn.ReLU(True),
60 | nn.Conv2d(self.vit_feat_dim, self.vit_feat_dim, kernel_size=3, stride=1, padding=1),
61 | nn.BatchNorm2d(self.vit_feat_dim),
62 | nn.ReLU(True),
63 | )
64 | self.aff_exo_proj = nn.Sequential(
65 | nn.Conv2d(self.vit_feat_dim, self.vit_feat_dim, kernel_size=3, stride=1, padding=1),
66 | nn.BatchNorm2d(self.vit_feat_dim),
67 | nn.ReLU(True),
68 | nn.Conv2d(self.vit_feat_dim, self.vit_feat_dim, kernel_size=3, stride=1, padding=1),
69 | nn.BatchNorm2d(self.vit_feat_dim),
70 | nn.ReLU(True),
71 | )
72 | self.aff_fc = nn.Conv2d(self.vit_feat_dim, self.aff_classes, 1)
73 |
74 | def forward(self, exo, ego, aff_label, epoch):
75 |
76 | num_exo = exo.shape[1]
77 | exo = exo.flatten(0, 1) # b*num_exo x 3 x 224 x 224
78 |
79 | # --- Extract deep descriptors from DINO-vit --- #
80 | with torch.no_grad():
81 | _, ego_key, ego_attn = self.vit_model.get_last_key(ego) # attn: b x 6 x (1+hw) x (1+hw)
82 | _, exo_key, exo_attn = self.vit_model.get_last_key(exo)
83 | ego_desc = ego_key.permute(0, 2, 3, 1).flatten(-2, -1).detach()
84 | exo_desc = exo_key.permute(0, 2, 3, 1).flatten(-2, -1).detach()
85 |
86 | ego_proj = ego_desc[:, 1:] + self.aff_proj(ego_desc[:, 1:])
87 | exo_proj = exo_desc[:, 1:] + self.aff_proj(exo_desc[:, 1:])
88 | ego_desc = self._reshape_transform(ego_desc[:, 1:, :], self.patch, self.stride)
89 | exo_desc = self._reshape_transform(exo_desc[:, 1:, :], self.patch, self.stride)
90 | ego_proj = self._reshape_transform(ego_proj, self.patch, self.stride)
91 | exo_proj = self._reshape_transform(exo_proj, self.patch, self.stride)
92 |
93 | b, c, h, w = ego_desc.shape
94 | ego_cls_attn = ego_attn[:, :, 0, 1:].reshape(b, 6, h, w)
95 | ego_cls_attn = (ego_cls_attn > ego_cls_attn.flatten(-2, -1).mean(-1, keepdim=True).unsqueeze(-1)).float()
96 | head_idxs = [0, 1, 3]
97 | ego_sam = ego_cls_attn[:, head_idxs].mean(1)
98 | ego_sam = normalize_minmax(ego_sam)
99 | ego_sam_flat = ego_sam.flatten(-2, -1)
100 |
101 | # --- Affordance CAM generation --- #
102 | exo_proj = self.aff_exo_proj(exo_proj)
103 | aff_cam = self.aff_fc(exo_proj) # b*num_exo x 36 x h x w
104 | aff_logits = self.gap(aff_cam).reshape(b, num_exo, self.aff_classes)
105 | aff_cam_re = aff_cam.reshape(b, num_exo, self.aff_classes, h, w)
106 |
107 | gt_aff_cam = torch.zeros(b, num_exo, h, w).cuda()
108 | for b_ in range(b):
109 | gt_aff_cam[b_, :] = aff_cam_re[b_, :, aff_label[b_]]
110 |
111 | # --- Clustering extracted descriptors based on CAM --- #
112 | ego_desc_flat = ego_desc.flatten(-2, -1) # b x 384 x hw
113 | exo_desc_re_flat = exo_desc.reshape(b, num_exo, c, h, w).flatten(-2, -1)
114 | sim_maps = torch.zeros(b, self.cluster_num, h * w).cuda()
115 | exo_sim_maps = torch.zeros(b, num_exo, self.cluster_num, h * w).cuda()
116 | part_score = torch.zeros(b, self.cluster_num).cuda()
117 | part_proto = torch.zeros(b, c).cuda()
118 | for b_ in range(b):
119 | exo_aff_desc = []
120 | for n in range(num_exo):
121 | tmp_cam = gt_aff_cam[b_, n].reshape(-1)
122 | tmp_max, tmp_min = tmp_cam.max(), tmp_cam.min()
123 | tmp_cam = (tmp_cam - tmp_min) / (tmp_max - tmp_min + 1e-10)
124 | tmp_desc = exo_desc_re_flat[b_, n]
125 | tmp_top_desc = tmp_desc[:, torch.where(tmp_cam > self.aff_cam_thd)[0]].T # n x c
126 | exo_aff_desc.append(tmp_top_desc)
127 | exo_aff_desc = torch.cat(exo_aff_desc, dim=0) # (n1 + n2 + n3) x c
128 |
129 | if exo_aff_desc.shape[0] < self.cluster_num:
130 | continue
131 |
132 | kmeans = KMeans(n_clusters=self.cluster_num, mode='euclidean', max_iter=300)
133 | kmeans.fit_predict(exo_aff_desc.contiguous())
134 | clu_cens = F.normalize(kmeans.centroids, dim=1)
135 |
136 | # save the exocentric similarity maps for visualization in training
137 | for n_ in range(num_exo):
138 | exo_sim_maps[b_, n_] = torch.mm(clu_cens, F.normalize(exo_desc_re_flat[b_, n_], dim=0))
139 |
140 | # find object part prototypes and background prototypes
141 | sim_map = torch.mm(clu_cens, F.normalize(ego_desc_flat[b_], dim=0)) # self.cluster_num x hw
142 | tmp_sim_max, tmp_sim_min = torch.max(sim_map, dim=-1, keepdim=True)[0], \
143 | torch.min(sim_map, dim=-1, keepdim=True)[0]
144 | sim_map_norm = (sim_map - tmp_sim_min) / (tmp_sim_max - tmp_sim_min + 1e-12)
145 |
146 | sim_map_hard = (sim_map_norm > torch.mean(sim_map_norm, 1, keepdim=True)).float()
147 | sam_hard = (ego_sam_flat > torch.mean(ego_sam_flat, 1, keepdim=True)).float()
148 |
149 | inter = (sim_map_hard * sam_hard[b_]).sum(1)
150 | union = sim_map_hard.sum(1) + sam_hard[b_].sum() - inter
151 | p_score = (inter / sim_map_hard.sum(1) + sam_hard[b_].sum() / union) / 2
152 |
153 | sim_maps[b_] = sim_map
154 | part_score[b_] = p_score
155 |
156 | if p_score.max() < self.part_iou_thd:
157 | continue
158 |
159 | part_proto[b_] = clu_cens[torch.argmax(p_score)]
160 |
161 | sim_maps = sim_maps.reshape(b, self.cluster_num, h, w)
162 | exo_sim_maps = exo_sim_maps.reshape(b, num_exo, self.cluster_num, h, w)
163 | ego_proj = self.aff_ego_proj(ego_proj)
164 | ego_pred = self.aff_fc(ego_proj)
165 | aff_logits_ego = self.gap(ego_pred).view(b, self.aff_classes)
166 |
167 | # --- concentration loss --- #
168 | gt_ego_cam = torch.zeros(b, h, w).cuda()
169 | loss_con = torch.zeros(1).cuda()
170 | for b_ in range(b):
171 | gt_ego_cam[b_] = ego_pred[b_, aff_label[b_]]
172 | loss_con += concentration_loss(ego_pred[b_])
173 |
174 | gt_ego_cam = normalize_minmax(gt_ego_cam)
175 | loss_con /= b
176 |
177 | # --- prototype guidance loss --- #
178 | loss_proto = torch.zeros(1).cuda()
179 | valid_batch = 0
180 | if epoch[0] > epoch[1]:
181 | for b_ in range(b):
182 | if not part_proto[b_].equal(torch.zeros(c).cuda()):
183 | mask = gt_ego_cam[b_]
184 | tmp_feat = ego_desc[b_] * mask
185 | embedding = tmp_feat.reshape(tmp_feat.shape[0], -1).sum(1) / mask.sum()
186 | loss_proto += torch.max(
187 | 1 - F.cosine_similarity(embedding, part_proto[b_], dim=0) - self.cel_margin,
188 | torch.zeros(1).cuda())
189 | valid_batch += 1
190 | loss_proto = loss_proto / (valid_batch + 1e-15)
191 |
192 | masks = {'exo_aff': gt_aff_cam, 'ego_sam': ego_sam,
193 | 'pred': (sim_maps, exo_sim_maps, part_score, gt_ego_cam)}
194 | logits = {'aff': aff_logits, 'aff_ego': aff_logits_ego}
195 |
196 | return masks, logits, loss_proto, loss_con
197 |
198 | @torch.no_grad()
199 | def test_forward(self, ego, aff_label):
200 | _, ego_key, ego_attn = self.vit_model.get_last_key(ego) # attn: b x 6 x (1+hw) x (1+hw)
201 | ego_desc = ego_key.permute(0, 2, 3, 1).flatten(-2, -1)
202 | ego_proj = ego_desc[:, 1:] + self.aff_proj(ego_desc[:, 1:])
203 | ego_desc = self._reshape_transform(ego_desc[:, 1:, :], self.patch, self.stride)
204 | ego_proj = self._reshape_transform(ego_proj, self.patch, self.stride)
205 |
206 | b, c, h, w = ego_desc.shape
207 | ego_proj = self.aff_ego_proj(ego_proj)
208 | ego_pred = self.aff_fc(ego_proj)
209 |
210 | gt_ego_cam = torch.zeros(b, h, w).cuda()
211 | for b_ in range(b):
212 | gt_ego_cam[b_] = ego_pred[b_, aff_label[b_]]
213 |
214 | return gt_ego_cam
215 |
216 | def _reshape_transform(self, tensor, patch_size, stride):
217 | height = (224 - patch_size) // stride + 1
218 | width = (224 - patch_size) // stride + 1
219 | result = tensor.reshape(tensor.size(0), height, width, tensor.size(-1))
220 | result = result.transpose(2, 3).transpose(1, 2).contiguous()
221 | return result
222 |
--------------------------------------------------------------------------------
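For reference, a minimal sketch of the shape bookkeeping that `_reshape_transform` performs on the ViT patch tokens. The patch size 16, stride 16, and 384-dim embedding below are illustrative assumptions, not values read from the repo config:

```
import torch

# Illustrative settings (assumed): a 224 x 224 crop, patch 16, stride 16 -> 14 x 14 token grid.
patch, stride, dim = 16, 16, 384
side = (224 - patch) // stride + 1          # 14
tokens = torch.randn(2, side * side, dim)   # b x hw x c (patch tokens, CLS removed)

grid = tokens.reshape(2, side, side, dim)                  # b x h x w x c
grid = grid.transpose(2, 3).transpose(1, 2).contiguous()   # b x c x h x w, as in _reshape_transform
print(grid.shape)                                          # torch.Size([2, 384, 14, 14])
```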
/models/model_util.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 |
5 | __all__ = ['normalize_minmax', 'concentration_loss']
6 |
7 |
8 | def normalize_minmax(cams, eps=1e-15):
9 | B, _, _ = cams.shape
10 | min_value, _ = cams.view(B, -1).min(1)
11 | cams_minmax = cams - min_value.view(B, 1, 1)
12 | max_value, _ = cams_minmax.view(B, -1).max(1)
13 | cams_minmax /= max_value.view(B, 1, 1) + eps
14 | return cams_minmax
15 |
16 |
17 | def get_variance(part_map, x_c, y_c):
18 | h, w = part_map.shape
19 | x_map, y_map = get_coordinate_tensors(h, w)
20 |
21 | v_x_map = (x_map - x_c) * (x_map - x_c)
22 | v_y_map = (y_map - y_c) * (y_map - y_c)
23 |
24 | v_x = (part_map * v_x_map).sum()
25 | v_y = (part_map * v_y_map).sum()
26 | return v_x, v_y
27 |
28 |
29 | def get_coordinate_tensors(x_max, y_max):
30 | x_map = np.tile(np.arange(x_max), (y_max, 1)) / x_max * 2 - 1.0
31 | y_map = np.tile(np.arange(y_max), (x_max, 1)).T / y_max * 2 - 1.0
32 |
33 | x_map_tensor = torch.from_numpy(x_map.astype(np.float32)).cuda()
34 | y_map_tensor = torch.from_numpy(y_map.astype(np.float32)).cuda()
35 |
36 | return x_map_tensor, y_map_tensor
37 |
38 |
39 | def get_center(part_map, self_referenced=False):
40 | h, w = part_map.shape
41 | x_map, y_map = get_coordinate_tensors(h, w)
42 |
43 | x_center = (part_map * x_map).sum()
44 | y_center = (part_map * y_map).sum()
45 |
46 | if self_referenced:
47 | x_c_value = float(x_center.cpu().detach())
48 | y_c_value = float(y_center.cpu().detach())
49 | x_center = (part_map * (x_map - x_c_value)).sum() + x_c_value
50 | y_center = (part_map * (y_map - y_c_value)).sum() + y_c_value
51 |
52 | return x_center, y_center
53 |
54 |
55 | def get_centers(part_maps, detach_k=True, epsilon=1e-3, self_ref_coord=False):
56 | H, W = part_maps.shape
57 | part_map = part_maps + epsilon
58 | k = part_map.sum()
59 | part_map_pdf = part_map / k
60 | x_c, y_c = get_center(part_map_pdf, self_ref_coord)
61 | centers = torch.stack((x_c, y_c), dim=0)
62 | return centers
63 |
64 |
65 | def batch_get_centers(pred_norm):
66 | B, H, W = pred_norm.shape
67 |
68 | centers_list = []
69 | for b in range(B):
70 | centers_list.append(get_centers(pred_norm[b]).unsqueeze(0))
71 | return torch.cat(centers_list, dim=0)
72 |
73 |
74 | # Code borrowed from SCOPS https://github.com/NVlabs/SCOPS
75 | def concentration_loss(pred):
76 | # b x h x w
77 | B, H, W = pred.shape
78 | tmp_max, tmp_min = pred.max(-1)[0].max(-1)[0].view(B, 1, 1), \
79 | pred.min(-1)[0].min(-1)[0].view(B, 1, 1)
80 |
81 | pred_norm = ((pred - tmp_min) / (tmp_max - tmp_min + 1e-10)) # b x 28 x 28
82 |
83 | loss = 0
84 | epsilon = 1e-3
85 | centers_all = batch_get_centers(pred_norm)
86 | for b in range(B):
87 | centers = centers_all[b]
88 | # normalize part map as spatial pdf
89 | part_map = pred_norm[b, :, :] + epsilon # prevent gradient explosion
90 | k = part_map.sum()
91 | part_map_pdf = part_map / k
92 | x_c, y_c = centers
93 | v_x, v_y = get_variance(part_map_pdf, x_c, y_c)
94 | loss_per_part = (v_x + v_y)
95 | loss = loss_per_part + loss
96 | loss = loss / B
97 | return loss
98 |
--------------------------------------------------------------------------------
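A minimal usage sketch for the two exported helpers. The toy CAMs are made up, and a CUDA device is required because the coordinate grids in `get_coordinate_tensors` are created with `.cuda()`:

```
import torch
from models.model_util import concentration_loss, normalize_minmax

# Toy CAMs (hypothetical); GPU required by the helpers above.
peaked = torch.zeros(1, 28, 28).cuda()
peaked[0, 12:16, 12:16] = 1.0            # mass concentrated in a small region
spread = torch.rand(1, 28, 28).cuda()    # mass scattered over the whole map

print(concentration_loss(peaked).item())  # small spatial variance around the center
print(concentration_loss(spread).item())  # noticeably larger

cams = torch.rand(4, 28, 28).cuda()
cams_01 = normalize_minmax(cams)
print(cams_01.min().item(), cams_01.max().item())  # ~0.0 and ~1.0
```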
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | fast-pytorch-kmeans
3 | numpy
4 | Pillow
5 | matplotlib
6 | opencv-python
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from tqdm import tqdm
4 |
5 | import cv2
6 | import torch
7 | import numpy as np
8 | from models.locate import Net as model
9 |
10 | from utils.viz import viz_pred_test
11 | from utils.util import set_seed, process_gt, normalize_map
12 | from utils.evaluation import cal_kl, cal_sim, cal_nss
13 |
14 | parser = argparse.ArgumentParser()
15 | ## path
16 | parser.add_argument('--data_root', type=str, default='/home/gen/Project/aff_grounding/dataset/AGD20K/')
17 | parser.add_argument('--model_file', type=str, default=None)
18 | parser.add_argument('--save_path', type=str, default='./save_preds')
19 | parser.add_argument("--divide", type=str, default="Seen")
20 | ## image
21 | parser.add_argument('--crop_size', type=int, default=224)
22 | parser.add_argument('--resize_size', type=int, default=256)
23 | ## test
24 | parser.add_argument('--num_workers', type=int, default=8)
25 | parser.add_argument("--test_batch_size", type=int, default=1)
26 | parser.add_argument('--test_num_workers', type=int, default=8)
27 | parser.add_argument('--gpu', type=str, default='0')
28 | parser.add_argument('--viz', action='store_true', default=False)
29 |
30 | args = parser.parse_args()
31 |
32 | if args.divide == "Seen":
33 | aff_list = ['beat', "boxing", "brush_with", "carry", "catch", "cut", "cut_with", "drag", 'drink_with',
34 | "eat", "hit", "hold", "jump", "kick", "lie_on", "lift", "look_out", "open", "pack", "peel",
35 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", "stir", "swing", "take_photo",
36 | "talk_on", "text_on", "throw", "type_on", "wash", "write"]
37 | else:
38 | aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with',
39 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel",
40 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick",
41 | "swing", "take_photo", "throw", "type_on", "wash"]
42 |
43 | if args.divide == "Seen":
44 | args.num_classes = 36
45 | else:
46 | args.num_classes = 25
47 |
48 | args.test_root = os.path.join(args.data_root, args.divide, "testset", "egocentric")
49 | args.mask_root = os.path.join(args.data_root, args.divide, "testset", "GT")
50 |
51 | if args.viz:
52 | if not os.path.exists(args.save_path):
53 | os.makedirs(args.save_path, exist_ok=True)
54 |
55 | if __name__ == '__main__':
56 | set_seed(seed=0)
57 |
58 | from data.datatest import TestData
59 |
60 | testset = TestData(image_root=args.test_root,
61 | crop_size=args.crop_size,
62 | divide=args.divide, mask_root=args.mask_root)
63 | TestLoader = torch.utils.data.DataLoader(dataset=testset,
64 | batch_size=args.test_batch_size,
65 | shuffle=False,
66 | num_workers=args.test_num_workers,
67 | pin_memory=True)
68 |
69 | model = model(aff_classes=args.num_classes).cuda()
70 |
71 | KLs = []
72 | SIM = []
73 | NSS = []
74 | model.eval()
75 | assert os.path.exists(args.model_file), "Please provide the correct model file for testing"
76 | model.load_state_dict(torch.load(args.model_file))
77 |
78 | GT_path = args.divide + "_gt.t7"
79 | if not os.path.exists(GT_path):
80 | process_gt(args)
81 | GT_masks = torch.load(GT_path)
82 |
83 | for step, (image, label, mask_path) in enumerate(tqdm(TestLoader)):
84 | ego_pred = model.test_forward(image.cuda(), label.long().cuda())
85 | cluster_sim_maps = []
86 | ego_pred = np.array(ego_pred.squeeze().data.cpu())
87 | ego_pred = normalize_map(ego_pred, args.crop_size)
88 |
89 | names = mask_path[0].split("/")
90 | key = names[-3] + "_" + names[-2] + "_" + names[-1]
91 | GT_mask = GT_masks[key]
92 | GT_mask = GT_mask / 255.0
93 |
94 | GT_mask = cv2.resize(GT_mask, (args.crop_size, args.crop_size))
95 |
96 | kld, sim, nss = cal_kl(ego_pred, GT_mask), cal_sim(ego_pred, GT_mask), cal_nss(ego_pred, GT_mask)
97 | KLs.append(kld)
98 | SIM.append(sim)
99 | NSS.append(nss)
100 |
101 | if args.viz:
102 | img_name = key.split(".")[0]
103 | viz_pred_test(args, image, ego_pred, GT_mask, aff_list, label, img_name)
104 |
105 | mKLD = sum(KLs) / len(KLs)
106 | mSIM = sum(SIM) / len(SIM)
107 | mNSS = sum(NSS) / len(NSS)
108 |
109 | print(f"KLD = {round(mKLD, 3)}\nSIM = {round(mSIM, 3)}\nNSS = {round(mNSS, 3)}")
110 |
--------------------------------------------------------------------------------
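One detail worth spelling out: the GT dictionary saved by `process_gt` is keyed as `affordance_object_filename`, and the key built from `mask_path` above must match it. A sketch with a hypothetical path:

```
# Hypothetical path under <data_root>/Seen/testset/GT/<affordance>/<object>/<image>
mask_path = "/data/AGD20K/Seen/testset/GT/hold/cup/cup_000001.png"
names = mask_path.split("/")
key = names[-3] + "_" + names[-2] + "_" + names[-1]
print(key)  # hold_cup_cup_000001.png -> index into GT_masks
```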
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import shutil
5 | import logging
6 | import argparse
7 |
8 | import cv2
9 | import torch
10 | import torch.nn as nn
11 | import numpy as np
12 | from models.locate import Net as model
13 |
14 | from utils.viz import viz_pred_train, viz_pred_test
15 | from utils.util import set_seed, process_gt, normalize_map, get_optimizer
16 | from utils.evaluation import cal_kl, cal_sim, cal_nss, AverageMeter, compute_cls_acc
17 |
18 | parser = argparse.ArgumentParser()
19 | ## path
20 | parser.add_argument('--data_root', type=str, default='/home/gen/Project/aff_grounding/dataset/AGD20K/')
21 | parser.add_argument('--save_root', type=str, default='save_models')
22 | parser.add_argument("--divide", type=str, default="Seen")
23 | ## image
24 | parser.add_argument('--crop_size', type=int, default=224)
25 | parser.add_argument('--resize_size', type=int, default=256)
26 | ## dataloader
27 | parser.add_argument('--num_workers', type=int, default=8)
28 | ## train
29 | parser.add_argument('--batch_size', type=int, default=16)
30 | parser.add_argument('--warm_epoch', type=int, default=0)
31 | parser.add_argument('--epochs', type=int, default=15)
32 | parser.add_argument('--lr', type=float, default=0.001)
33 | parser.add_argument('--momentum', type=float, default=0.9)
34 | parser.add_argument('--weight_decay', type=float, default=5e-4)
35 | parser.add_argument('--show_step', type=int, default=100)
36 | parser.add_argument('--gpu', type=str, default='0')
37 | parser.add_argument('--viz', action='store_true', default=False)
38 |
39 | ## test
40 | parser.add_argument("--test_batch_size", type=int, default=1)
41 | parser.add_argument('--test_num_workers', type=int, default=8)
42 |
43 | args = parser.parse_args()
44 | torch.cuda.set_device('cuda:' + args.gpu)
45 | lr = args.lr
46 |
47 | if args.divide == "Seen":
48 | aff_list = ['beat', "boxing", "brush_with", "carry", "catch", "cut", "cut_with", "drag", 'drink_with',
49 | "eat", "hit", "hold", "jump", "kick", "lie_on", "lift", "look_out", "open", "pack", "peel",
50 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", "stir", "swing", "take_photo",
51 | "talk_on", "text_on", "throw", "type_on", "wash", "write"]
52 | else:
53 | aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with',
54 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel",
55 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick",
56 | "swing", "take_photo", "throw", "type_on", "wash"]
57 |
58 | if args.divide == "Seen":
59 | args.num_classes = 36
60 | else:
61 | args.num_classes = 25
62 |
63 | args.exocentric_root = os.path.join(args.data_root, args.divide, "trainset", "exocentric")
64 | args.egocentric_root = os.path.join(args.data_root, args.divide, "trainset", "egocentric")
65 | args.test_root = os.path.join(args.data_root, args.divide, "testset", "egocentric")
66 | args.mask_root = os.path.join(args.data_root, args.divide, "testset", "GT")
67 | time_str = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
68 | args.save_path = os.path.join(args.save_root, time_str)
69 |
70 | if not os.path.exists(args.save_path):
71 | os.makedirs(args.save_path, exist_ok=True)
72 | dict_args = vars(args)
73 |
74 | shutil.copy('./models/locate.py', args.save_path)
75 | shutil.copy('./train.py', args.save_path)
76 |
77 | str_1 = ""
78 | for key, value in dict_args.items():
79 | str_1 += key + "=" + str(value) + "\n"
80 |
81 | logging.basicConfig(filename='%s/run.log' % args.save_path, level=logging.INFO, format='%(message)s')
82 | logger = logging.getLogger()
83 | logger.addHandler(logging.StreamHandler(sys.stdout))
84 | logger.info(str_1)
85 |
86 | if __name__ == '__main__':
87 | set_seed(seed=0)
88 |
89 | from data.datatrain import TrainData
90 |
91 | trainset = TrainData(exocentric_root=args.exocentric_root,
92 | egocentric_root=args.egocentric_root,
93 | resize_size=args.resize_size,
94 | crop_size=args.crop_size, divide=args.divide)
95 |
96 | TrainLoader = torch.utils.data.DataLoader(dataset=trainset,
97 | batch_size=args.batch_size,
98 | shuffle=True,
99 | num_workers=args.num_workers,
100 | pin_memory=True)
101 |
102 | from data.datatest import TestData
103 |
104 | testset = TestData(image_root=args.test_root,
105 | crop_size=args.crop_size,
106 | divide=args.divide, mask_root=args.mask_root)
107 | TestLoader = torch.utils.data.DataLoader(dataset=testset,
108 | batch_size=args.test_batch_size,
109 | shuffle=False,
110 | num_workers=args.test_num_workers,
111 | pin_memory=True)
112 |
113 | model = model(aff_classes=args.num_classes)
114 | model = model.cuda()
115 | model.train()
116 | optimizer, scheduler = get_optimizer(model, args)
117 |
118 | best_kld = 1000
119 | print('Training begins!')
120 | for epoch in range(args.epochs):
121 | model.train()
122 | logger.info('LR = ' + str(scheduler.get_last_lr()))
123 | exo_aff_acc = AverageMeter()
124 | ego_obj_acc = AverageMeter()
125 |
126 | for step, (exocentric_image, egocentric_image, aff_label) in enumerate(TrainLoader):
127 | aff_label = aff_label.cuda().long() # b, one affordance class index per sample
128 | exo = exocentric_image.cuda() # b x n x 3 x 224 x 224
129 | ego = egocentric_image.cuda()
130 |
131 | masks, logits, loss_proto, loss_con = model(exo, ego, aff_label, (epoch, args.warm_epoch))
132 |
133 | exo_aff_logits = logits['aff']
134 | num_exo = exo.shape[1]
135 | exo_aff_loss = torch.zeros(1).cuda()
136 | for n in range(num_exo):
137 | exo_aff_loss += nn.CrossEntropyLoss().cuda()(exo_aff_logits[:, n], aff_label)
138 | exo_aff_loss /= num_exo
139 |
140 | loss_dict = {'ego_ce': nn.CrossEntropyLoss().cuda()(logits['aff_ego'], aff_label),
141 | 'exo_ce': exo_aff_loss,
142 | 'loss_proto': loss_proto,
143 | 'loss_con': loss_con * 0.07,
144 | }
145 |
146 | loss = sum(loss_dict.values())
147 | optimizer.zero_grad()
148 | loss.backward()
149 | optimizer.step()
150 |
151 | cur_batch = exo.size(0)
152 | exo_acc = 100. * compute_cls_acc(logits['aff'].mean(1), aff_label)
153 | exo_aff_acc.update(exo_acc, cur_batch)
154 | metric_dict = {'exo_aff_acc': exo_aff_acc.avg}
155 |
156 | if (step + 1) % args.show_step == 0:
157 | log_str = 'epoch: %d/%d + %d/%d | ' % (epoch + 1, args.epochs, step + 1, len(TrainLoader))
158 | log_str += ' | '.join(['%s: %.3f' % (k, v) for k, v in metric_dict.items()])
159 | log_str += ' | '
160 | log_str += ' | '.join(['%s: %.3f' % (k, v) for k, v in loss_dict.items()])
161 | logger.info(log_str)
162 |
163 | # Visualize the prediction during training
164 | if args.viz:
165 | viz_pred_train(args, ego, exo, masks, aff_list, aff_label, epoch, step + 1)
166 |
167 | scheduler.step()
168 | KLs = []
169 | SIM = []
170 | NSS = []
171 | model.eval()
172 | GT_path = args.divide + "_gt.t7"
173 | if not os.path.exists(GT_path):
174 | process_gt(args)
175 | GT_masks = torch.load(GT_path)
176 |
177 | for step, (image, label, mask_path) in enumerate(TestLoader):
178 | ego_pred = model.test_forward(image.cuda(), label.long().cuda())
179 | cluster_sim_maps = []
180 | ego_pred = np.array(ego_pred.squeeze().data.cpu())
181 | ego_pred = normalize_map(ego_pred, args.crop_size)
182 |
183 | names = mask_path[0].split("/")
184 | key = names[-3] + "_" + names[-2] + "_" + names[-1]
185 | GT_mask = GT_masks[key]
186 | GT_mask = GT_mask / 255.0
187 |
188 | GT_mask = cv2.resize(GT_mask, (args.crop_size, args.crop_size))
189 |
190 | kld, sim, nss = cal_kl(ego_pred, GT_mask), cal_sim(ego_pred, GT_mask), cal_nss(ego_pred, GT_mask)
191 | KLs.append(kld)
192 | SIM.append(sim)
193 | NSS.append(nss)
194 |
195 | # Visualize the prediction during evaluation
196 | if args.viz:
197 | if (step + 1) % args.show_step == 0:
198 | img_name = key.split(".")[0]
199 | viz_pred_test(args, image, ego_pred, GT_mask, aff_list, label, img_name, epoch)
200 |
201 | mKLD = sum(KLs) / len(KLs)
202 | mSIM = sum(SIM) / len(SIM)
203 | mNSS = sum(NSS) / len(NSS)
204 |
205 | logger.info(
206 | "epoch=" + str(epoch + 1) + " mKLD = " + str(round(mKLD, 3))
207 | + " mSIM = " + str(round(mSIM, 3)) + " mNSS = " + str(round(mNSS, 3))
208 | + " bestKLD = " + str(round(best_kld, 3)))
209 |
210 | if mKLD < best_kld:
211 | best_kld = mKLD
212 | model_name = 'best_model_' + str(epoch + 1) + '_' + str(round(best_kld, 3)) \
213 | + '_' + str(round(mSIM, 3)) \
214 | + '_' + str(round(mNSS, 3)) \
215 | + '.pth'
216 | torch.save(model.state_dict(), os.path.join(args.save_path, model_name))
217 |
--------------------------------------------------------------------------------
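For readability, a sketch of how the four entries of `loss_dict` above combine into the optimized objective. The numbers are stand-ins, and the prototype term stays at zero until `epoch > args.warm_epoch` (the tuple handed to the model's forward):

```
import torch

# Stand-in values for one batch (not measured numbers).
ego_ce = torch.tensor(1.20)      # CE on egocentric affordance logits
exo_ce = torch.tensor(0.90)      # CE averaged over the exocentric views
loss_proto = torch.tensor(0.30)  # prototype guidance loss (zero before the warm-up epoch)
loss_con = torch.tensor(2.00)    # concentration loss, scaled by 0.07 in loss_dict

total = ego_ce + exo_ce + loss_proto + 0.07 * loss_con
print(total.item())  # ~2.54
```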
/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def cal_kl(pred: np.ndarray, gt: np.ndarray, eps=1e-12) -> float:
6 | map1, map2 = pred / (pred.sum() + eps), gt / (gt.sum() + eps)
7 | kld = np.sum(map2 * np.log(map2 / (map1 + eps) + eps))
8 | return kld
9 |
10 |
11 | def cal_sim(pred: np.ndarray, gt: np.ndarray, eps=1e-12) -> float:
12 | map1, map2 = pred / (pred.sum() + eps), gt / (gt.sum() + eps)
13 | intersection = np.minimum(map1, map2)
14 |
15 | return np.sum(intersection)
16 |
17 |
18 | def image_binary(image, threshold):
19 | output = np.zeros(image.size).reshape(image.shape)
20 | for xx in range(image.shape[0]):
21 | for yy in range(image.shape[1]):
22 | if (image[xx][yy] > threshold):
23 | output[xx][yy] = 1
24 | return output
25 |
26 |
27 | def cal_nss(pred: np.ndarray, gt: np.ndarray) -> float:
28 | pred = pred / 255.0
29 | gt = gt / 255.0
30 | std = np.std(pred)
31 | u = np.mean(pred)
32 |
33 | smap = (pred - u) / std
34 | fixation_map = (gt - np.min(gt)) / (np.max(gt) - np.min(gt) + 1e-12)
35 | fixation_map = image_binary(fixation_map, 0.1)
36 |
37 | nss = smap * fixation_map
38 |
39 | nss = np.sum(nss) / np.sum(fixation_map + 1e-12)
40 |
41 | return nss
42 |
43 |
44 | def compute_cls_acc(preds, label):
45 | pred = torch.max(preds, 1)[1]
46 | # label = torch.max(labels, 1)[1]
47 | num_correct = (pred == label).sum()
48 | return float(num_correct) / float(preds.size(0))
49 |
50 |
51 | class AverageMeter(object):
52 | def __init__(self):
53 | self.reset()
54 |
55 | def reset(self):
56 | self.val = 0.0
57 | self.avg = 0.0
58 | self.sum = 0.0
59 | self.cnt = 0.0
60 |
61 | def update(self, val, n=1.0):
62 | self.val = val
63 | self.sum += val * n
64 | self.cnt += n
65 | if self.cnt == 0:
66 | self.avg = 1
67 | else:
68 | self.avg = self.sum / self.cnt
69 |
--------------------------------------------------------------------------------
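A small self-contained check of the three metrics on synthetic maps (made-up arrays, shown only to illustrate the call pattern and the direction of each score):

```
import numpy as np
from utils.evaluation import cal_kl, cal_sim, cal_nss

rng = np.random.default_rng(0)
pred = rng.random((224, 224)).astype(np.float32)   # stand-in for a normalized prediction
gt = np.zeros((224, 224), dtype=np.float32)
gt[80:140, 80:140] = 1.0                           # stand-in for a GT heatmap region

print(cal_kl(pred, gt))    # lower is better
print(cal_sim(pred, gt))   # higher is better
print(cal_nss(pred, gt))   # higher is better
```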
/utils/util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import random
4 | import torch
5 | import numpy as np
6 | from PIL import Image
7 | from matplotlib import cm
8 |
9 |
10 | def set_seed(seed=0):
11 | np.random.seed(seed)
12 | torch.manual_seed(seed)
13 | torch.cuda.manual_seed(seed)
14 | random.seed(seed)
15 | torch.cuda.manual_seed_all(seed)
16 | torch.backends.cudnn.deterministic = True
17 | torch.backends.cudnn.benchmark = False
18 |
19 |
20 | def process_gt(args):
21 | assert args.divide in ["Seen", "Unseen"], "The divide argument should be Seen or Unseen"
22 | files = os.listdir(args.mask_root)
23 | dict_1 = {}
24 | for file in files:
25 | file_path = os.path.join(args.mask_root, file)
26 | objs = os.listdir(file_path)
27 | for obj in objs:
28 | obj_path = os.path.join(file_path, obj)
29 | images = os.listdir(obj_path)
30 | for img in images:
31 | img_path = os.path.join(obj_path, img)
32 | mask = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
33 | key = file + "_" + obj + "_" + img
34 | dict_1[key] = mask
35 |
36 | torch.save(dict_1, args.divide + "_gt.t7")
37 |
38 |
39 | def normalize_map(atten_map, crop_size):
40 | atten_map = cv2.resize(atten_map, dsize=(crop_size, crop_size))
41 | min_val = np.min(atten_map)
42 | max_val = np.max(atten_map)
43 | atten_norm = (atten_map - min_val) / (max_val - min_val + 1e-10)
44 | return atten_norm
45 |
46 |
47 | def get_optimizer(model, args):
48 | lr = args.lr
49 | weight_list = []
50 | bias_list = []
51 | last_weight_list = []
52 | last_bias_list = []
53 | for name, value in model.named_parameters():
54 | if value.requires_grad:
55 | if 'fc' in name:
56 | if 'weight' in name:
57 | last_weight_list.append(value)
58 | elif 'bias' in name:
59 | last_bias_list.append(value)
60 | else:
61 | if 'weight' in name:
62 | weight_list.append(value)
63 | elif 'bias' in name:
64 | bias_list.append(value)
65 | optimizer = torch.optim.SGD([{'params': weight_list,
66 | 'lr': lr},
67 | {'params': bias_list,
68 | 'lr': lr * 2},
69 | {'params': last_weight_list,
70 | 'lr': lr * 10},
71 | {'params': last_bias_list,
72 | 'lr': lr * 20}],
73 | momentum=args.momentum,
74 | weight_decay=args.weight_decay,
75 | nesterov=True)
76 |
77 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
78 | return optimizer, scheduler
79 |
80 |
81 | def overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = "jet", alpha: float = 0.7) -> Image.Image:
82 | if not isinstance(img, Image.Image) or not isinstance(mask, Image.Image):
83 | raise TypeError("img and mask arguments need to be PIL.Image")
84 |
85 | if not isinstance(alpha, float) or alpha < 0 or alpha >= 1:
86 | raise ValueError("alpha argument is expected to be of type float between 0 and 1")
87 |
88 | cmap = cm.get_cmap(colormap)
89 | # Resize mask and apply colormap
90 | overlay = mask.resize(img.size, resample=Image.BICUBIC)
91 | overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, :3]).astype(np.uint8)
92 | # Overlay the image with the mask
93 | overlayed_img = Image.fromarray((alpha * np.asarray(img) + (1 - alpha) * overlay).astype(np.uint8))
94 |
95 | return overlayed_img
96 |
--------------------------------------------------------------------------------
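A minimal sketch of the four learning-rate groups that `get_optimizer` builds (weights x1, biases x2, and x10/x20 for parameters whose name contains 'fc', such as `aff_fc` in the model). The toy module and namespace below are placeholders:

```
import torch.nn as nn
from types import SimpleNamespace
from utils.util import get_optimizer

# Toy model: any parameter whose name contains 'fc' is routed to the high-LR groups.
toy = nn.Sequential()
toy.add_module("backbone", nn.Conv2d(3, 8, 3))
toy.add_module("fc", nn.Conv2d(8, 36, 1))

cfg = SimpleNamespace(lr=1e-3, momentum=0.9, weight_decay=5e-4)
optimizer, scheduler = get_optimizer(toy, cfg)
print([group["lr"] for group in optimizer.param_groups])  # [0.001, 0.002, 0.01, 0.02]
```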
/utils/viz.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | from utils.util import normalize_map, overlay_mask
6 | import matplotlib.pyplot as plt
7 |
8 |
9 | # visualize the predictions for the first sample in the batch
10 | def viz_pred_train(args, ego, exo, masks, aff_list, aff_label, epoch, step):
11 | mean = torch.as_tensor([0.485, 0.456, 0.406], dtype=ego.dtype, device=ego.device).view(-1, 1, 1)
12 | std = torch.as_tensor([0.229, 0.224, 0.225], dtype=ego.dtype, device=ego.device).view(-1, 1, 1)
13 |
14 | ego_0 = ego[0].squeeze(0) * std + mean
15 | ego_0 = ego_0.detach().cpu().numpy() * 255
16 | ego_0 = Image.fromarray(ego_0.transpose(1, 2, 0).astype(np.uint8))
17 |
18 | exo_img = []
19 | num_exo = exo.shape[1]
20 | for i in range(num_exo):
21 | name = 'exo_' + str(i)
22 | locals()[name] = exo[0][i].squeeze(0) * std + mean
23 | locals()[name] = locals()[name].detach().cpu().numpy() * 255
24 | locals()[name] = Image.fromarray(locals()[name].transpose(1, 2, 0).astype(np.uint8))
25 | exo_img.append(locals()[name])
26 |
27 | exo_cam = masks['exo_aff'][0]
28 |
29 | sim_maps, exo_sim_maps, part_score, ego_pred = masks['pred']
30 | num_clu = sim_maps.shape[1]
31 | part_score = np.array(part_score[0].squeeze().data.cpu())
32 |
33 | ego_pred = np.array(ego_pred[0].squeeze().data.cpu())
34 | ego_pred = normalize_map(ego_pred, args.crop_size)
35 | ego_pred = Image.fromarray(ego_pred)
36 | ego_pred = overlay_mask(ego_0, ego_pred, alpha=0.5)
37 |
38 | ego_sam = masks['ego_sam']
39 | ego_sam = np.array(ego_sam[0].squeeze().data.cpu())
40 | ego_sam = normalize_map(ego_sam, args.crop_size)
41 | ego_sam = Image.fromarray(ego_sam)
42 | ego_sam = overlay_mask(ego_0, ego_sam, alpha=0.1)
43 |
44 | aff_str = aff_list[aff_label[0].item()]
45 |
46 | for i in range(num_exo):
47 | name = 'exo_aff' + str(i)
48 | locals()[name] = np.array(exo_cam[i].squeeze().data.cpu())
49 | locals()[name] = normalize_map(locals()[name], args.crop_size)
50 | locals()[name] = Image.fromarray(locals()[name])
51 | locals()[name] = overlay_mask(exo_img[i], locals()[name], alpha=0.5)
52 |
53 | for i in range(num_clu):
54 | name = 'sim_map' + str(i)
55 | locals()[name] = np.array(sim_maps[0][i].squeeze().data.cpu())
56 | locals()[name] = normalize_map(locals()[name], args.crop_size)
57 | locals()[name] = Image.fromarray(locals()[name])
58 | locals()[name] = overlay_mask(ego_0, locals()[name], alpha=0.5)
59 |
60 | # Similarity maps for the first exocentric image
61 | name = 'exo_sim_map' + str(i)
62 | locals()[name] = np.array(exo_sim_maps[0, 0][i].squeeze().data.cpu())
63 | locals()[name] = normalize_map(locals()[name], args.crop_size)
64 | locals()[name] = Image.fromarray(locals()[name])
65 | locals()[name] = overlay_mask(locals()['exo_' + str(0)], locals()[name], alpha=0.5)
66 |
67 | # Exo&Ego plots
68 | fig, ax = plt.subplots(4, max(num_clu, num_exo), figsize=(8, 8))
69 | for axi in ax.ravel():
70 | axi.set_axis_off()
71 | for k in range(num_exo):
72 | ax[0, k].imshow(eval('exo_aff' + str(k)))
73 | ax[0, k].set_title("exo_" + aff_str)
74 | for k in range(num_clu):
75 | ax[1, k].imshow(eval('sim_map' + str(k)))
76 | ax[1, k].set_title('PartIoU_' + str(round(part_score[k], 2)))
77 | ax[2, k].imshow(eval('exo_sim_map' + str(k)))
78 | ax[2, k].set_title('sim_map_' + str(k))
79 | ax[3, 0].imshow(ego_pred)
80 | ax[3, 0].set_title(aff_str)
81 | ax[3, 1].imshow(ego_sam)
82 | ax[3, 1].set_title('Saliency')
83 |
84 | os.makedirs(os.path.join(args.save_path, 'viz_train'), exist_ok=True)
85 | fig_name = os.path.join(args.save_path, 'viz_train', 'cam_' + str(epoch) + '_' + str(step) + '.jpg')
86 | plt.tight_layout()
87 | plt.savefig(fig_name)
88 | plt.close()
89 |
90 |
91 | def viz_pred_test(args, image, ego_pred, GT_mask, aff_list, aff_label, img_name, epoch=None):
92 | mean = torch.as_tensor([0.485, 0.456, 0.406], dtype=image.dtype, device=image.device).view(-1, 1, 1)
93 | std = torch.as_tensor([0.229, 0.224, 0.225], dtype=image.dtype, device=image.device).view(-1, 1, 1)
94 | mean = mean.view(-1, 1, 1)
95 | std = std.view(-1, 1, 1)
96 | img = image.squeeze(0) * std + mean
97 | img = img.detach().cpu().numpy() * 255
98 | img = Image.fromarray(img.transpose(1, 2, 0).astype(np.uint8))
99 |
100 | gt = Image.fromarray(GT_mask)
101 | gt_result = overlay_mask(img, gt, alpha=0.5)
102 | aff_str = aff_list[aff_label.item()]
103 |
104 | ego_pred = Image.fromarray(ego_pred)
105 | ego_pred = overlay_mask(img, ego_pred, alpha=0.5)
106 |
107 | fig, ax = plt.subplots(1, 3, figsize=(10, 6))
108 | for axi in ax.ravel():
109 | axi.set_axis_off()
110 | ax[0].imshow(img)
111 | ax[0].set_title('ego')
112 | ax[1].imshow(ego_pred)
113 | ax[1].set_title(aff_str)
114 | ax[2].imshow(gt_result)
115 | ax[2].set_title('GT')
116 |
117 | os.makedirs(os.path.join(args.save_path, 'viz_test'), exist_ok=True)
118 | if epoch is not None:
119 | fig_name = os.path.join(args.save_path, 'viz_test', "epoch" + str(epoch) + '_' + img_name + '.jpg')
120 | else:
121 | fig_name = os.path.join(args.save_path, 'viz_test', img_name + '.jpg')
122 | plt.savefig(fig_name)
123 | plt.close()
124 |
--------------------------------------------------------------------------------