├── data ├── __init__.py ├── cvt_idx.py ├── rafdb.py ├── affectnet.py ├── randaugment.py ├── fer2013.py ├── transforms.py ├── base_dataset.py └── celeba.py ├── models ├── transformers │ ├── __init__.py │ ├── position_encoding.py │ ├── transformer_predictor.py │ └── transformer.py ├── __init__.py ├── lewel.py └── fra.py ├── docs └── face-framework.png ├── requirements.txt ├── backbone ├── __init__.py └── resnet.py ├── utils ├── __init__.py ├── batch_norm.py ├── extract_backbone.py ├── lr_schedule.py ├── init.py ├── LARS.py ├── utils.py └── dist_utils.py ├── .gitignore ├── launch.py ├── README.md ├── engine.py ├── LICENSE └── main.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/face-framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zaczgao/Facial_Region_Awareness/HEAD/docs/face-framework.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.2+cu111 2 | torchvision==0.11.3+cu111 3 | tensorboard 4 | classy_vision 5 | pandas -------------------------------------------------------------------------------- /backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | from .resnet import * 5 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | from .batch_norm import get_norm 5 | from .LARS import LARS 6 | from .dist_utils import init_distributed_mode -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | from .lewel import LEWELB, LEWELB_EMAN 5 | from .fra import FRAB, FRAB_EMAN 6 | 7 | 8 | 9 | def get_model(model): 10 | """ 11 | Args: 12 | model (str or callable): 13 | 14 | Returns: 15 | Model 16 | """ 17 | if isinstance(model, str): 18 | model = { 19 | "LEWELB": LEWELB, 20 | "LEWELB_EMAN": LEWELB_EMAN, 21 | "FRAB": FRAB, 22 | "FRAB_EMAN": FRAB_EMAN, 23 | }[model] 24 | return model -------------------------------------------------------------------------------- /utils/batch_norm.py: -------------------------------------------------------------------------------- 1 | # Original copyright Amazon.com, Inc. or its affiliates, under CC-BY-NC-4.0 License. 2 | # Modifications Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved. 
3 | # SPDX-License-Identifier: CC-BY-NC-4.0 4 | 5 | from torch import nn 6 | 7 | 8 | def get_norm(norm): 9 | """ 10 | Args: 11 | norm (str or callable): 12 | 13 | Returns: 14 | nn.Module or None: the normalization layer 15 | """ 16 | if isinstance(norm, str): 17 | if len(norm) == 0: 18 | return None 19 | norm = { 20 | "BN": nn.BatchNorm2d, 21 | "BN1d": nn.BatchNorm1d, 22 | "SyncBN": nn.SyncBatchNorm, 23 | "GN": lambda channels: nn.GroupNorm(32, channels), 24 | "IN": lambda channels: nn.InstanceNorm2d(channels, affine=True), 25 | "None": None, 26 | }[norm] 27 | return norm 28 | -------------------------------------------------------------------------------- /utils/extract_backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import sys 5 | import torch 6 | 7 | if __name__ == "__main__": 8 | input = sys.argv[1] 9 | 10 | obj = torch.load(input, map_location="cpu") 11 | print("Loading {} (epoch {})".format(input, obj['epoch'])) 12 | obj = obj["state_dict"] 13 | 14 | newmodel = {} 15 | for k, v in obj.items(): 16 | if not (k.startswith("module.encoder_q.backbone") or k.startswith("module.online_net.backbone")) or 'fc' in k: 17 | continue 18 | old_k = k 19 | k = k.replace("backbone.", "") 20 | k = k.replace("module.encoder_q.", "") 21 | k = k.replace("module.online_net.", "") 22 | print(old_k, "->", k) 23 | newmodel[k] = v 24 | 25 | with open(sys.argv[2], "wb") as f: 26 | torch.save(newmodel, f, _use_new_zipfile_serialization=False) 27 | -------------------------------------------------------------------------------- /data/cvt_idx.py: -------------------------------------------------------------------------------- 1 | # Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import os 5 | 6 | if __name__ == "__main__": 7 | train_file = "/mnt/lustre/share/data/images/meta/train.txt" 8 | idx_file = "data/10percent.txt" 9 | out_file = idx_file + ".ext" 10 | max_class = 1000 11 | 12 | with open(idx_file, "r") as fin, open(train_file, "r") as f_train: 13 | all_samples = {} 14 | idx_samples = [] 15 | selected_samples = [] 16 | for line in f_train.readlines(): 17 | name, label = line.strip().split() 18 | label = int(label) 19 | if label < max_class: 20 | base_name = name.split("/")[1] 21 | all_samples[base_name] = (label, name) 22 | print(f"len of all samples: {len(all_samples)}") 23 | 24 | for line in fin.readlines(): 25 | nm = line.strip() 26 | selected_samples.append(all_samples[nm]) 27 | 28 | print(f"Len of selected samples {len(selected_samples)}") 29 | 30 | with open(out_file, "w") as fout: 31 | for (lb, nm) in selected_samples: 32 | fout.write(f"{lb} {nm}\n") 33 | -------------------------------------------------------------------------------- /utils/lr_schedule.py: -------------------------------------------------------------------------------- 1 | # Original copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import math 5 | 6 | 7 | def warmup_learning_rate(optimizer, curr_step, warmup_step, args): 8 | """linearly warm up learning rate""" 9 | lr = args.lr 10 | scalar = float(curr_step) / float(max(1, warmup_step)) 11 | scalar = min(1., max(0., scalar)) 12 | lr *= scalar 13 | for param_group in optimizer.param_groups: 14 | param_group['lr'] = lr 15 | 16 | 17 | def adjust_learning_rate(optimizer, epoch, args): 18 | """Decay the learning rate based on schedule""" 19 | lr = args.lr 20 | if args.cos: # cosine lr schedule 21 | progress = float(epoch - args.warmup_epoch) / float(args.epochs - args.warmup_epoch) 22 | lr *= 0.5 * (1. + math.cos(math.pi * progress)) 23 | else: # stepwise lr schedule 24 | for milestone in args.schedule: 25 | lr *= 0.1 if epoch >= milestone else 1. 26 | for param_group in optimizer.param_groups: 27 | param_group['lr'] = lr 28 | 29 | 30 | def adjust_learning_rate_with_min(optimizer, epoch, args): 31 | """Decay the learning rate based on schedule""" 32 | lr = args.lr 33 | if args.cos: # cosine lr schedule 34 | min_lr = args.cos_min_lr 35 | progress = float(epoch - args.warmup_epoch) / float(args.epochs - args.warmup_epoch) 36 | lr = min_lr + 0.5 * (lr - min_lr) * (1. + math.cos(math.pi * progress)) 37 | else: # stepwise lr schedule 38 | for milestone in args.schedule: 39 | lr *= 0.1 if epoch >= milestone else 1. 40 | for param_group in optimizer.param_groups: 41 | param_group['lr'] = lr 42 | -------------------------------------------------------------------------------- /utils/init.py: -------------------------------------------------------------------------------- 1 | # Original copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import torch.nn as nn 5 | 6 | 7 | def c2_xavier_fill(module: nn.Module) -> None: 8 | """ 9 | Initialize `module.weight` using the "XavierFill" implemented in Caffe2. 10 | Also initializes `module.bias` to 0. 11 | 12 | Args: 13 | module (torch.nn.Module): module to initialize. 14 | """ 15 | # Caffe2 implementation of XavierFill in fact 16 | # corresponds to kaiming_uniform_ in PyTorch 17 | nn.init.kaiming_uniform_(module.weight, a=1) # pyre-ignore 18 | if module.bias is not None: # pyre-ignore 19 | nn.init.constant_(module.bias, 0) 20 | 21 | 22 | def c2_msra_fill(module: nn.Module) -> None: 23 | """ 24 | Initialize `module.weight` using the "MSRAFill" implemented in Caffe2. 25 | Also initializes `module.bias` to 0. 26 | 27 | Args: 28 | module (torch.nn.Module): module to initialize. 
29 | """ 30 | # pyre-ignore 31 | nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") 32 | if module.bias is not None: # pyre-ignore 33 | nn.init.constant_(module.bias, 0) 34 | 35 | 36 | def normal_init(module: nn.Module, std=0.01): 37 | nn.init.normal_(module.weight, std=std) 38 | if module.bias is not None: 39 | nn.init.constant_(module.bias, 0) 40 | 41 | 42 | def init_weights(module, init_linear='normal'): 43 | assert init_linear in ['normal', 'kaiming'], \ 44 | "Undefined init_linear: {}".format(init_linear) 45 | for m in module.modules(): 46 | if isinstance(m, nn.Linear): 47 | if init_linear == 'normal': 48 | normal_init(m, std=0.01) 49 | else: 50 | c2_msra_fill(m) 51 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): 52 | if m.weight is not None: 53 | nn.init.constant_(m.weight, 1) 54 | if m.bias is not None: 55 | nn.init.constant_(m.bias, 0) 56 | elif isinstance(m, nn.Conv1d): 57 | c2_msra_fill(m) 58 | -------------------------------------------------------------------------------- /models/transformers/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py 3 | """ 4 | Various positional encodings for the transformer. 5 | """ 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | 18 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 19 | super().__init__() 20 | self.num_pos_feats = num_pos_feats 21 | self.temperature = temperature 22 | self.normalize = normalize 23 | if scale is not None and normalize is False: 24 | raise ValueError("normalize should be True if scale is passed") 25 | if scale is None: 26 | scale = 2 * math.pi 27 | self.scale = scale 28 | 29 | def forward(self, x, mask=None): 30 | if mask is None: 31 | mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 34 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack( 46 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 47 | ).flatten(3) 48 | pos_y = torch.stack( 49 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 50 | ).flatten(3) 51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 52 | return pos 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | 
develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # 132 | ckpts/ 133 | 134 | 135 | 136 | # PyCharm 137 | /.idea 138 | 139 | # Sphinx 140 | /doc/build 141 | 142 | # Python 143 | __pycache__ 144 | *.pyc 145 | *.egg-info 146 | 147 | # macOS 148 | .DS_Store 149 | */.DS_Store -------------------------------------------------------------------------------- /data/rafdb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | """ 6 | 7 | import os 8 | import sys 9 | import numpy as np 10 | from tqdm import tqdm 11 | from PIL import Image 12 | import matplotlib.pyplot as plt 13 | 14 | if sys.platform == 'win32': 15 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 16 | 17 | import torch 18 | import torch.utils.data as data 19 | import torchvision.transforms as transforms 20 | 21 | # Root directory of the project 22 | try: 23 | abspath = os.path.abspath(__file__) 24 | except NameError: 25 | abspath = os.getcwd() 26 | ROOT_DIR = os.path.dirname(abspath) 27 | 28 | 29 | IMG_EXTENSIONS = [ 30 | '.jpg', '.JPG', '.jpeg', '.JPEG', 31 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 32 | ] 33 | 34 | 35 | def is_image_file(filename): 36 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 37 | 38 | 39 | class RAFDB(data.Dataset): 
40 | def __init__(self, root='/media/jiaren/DataSet/basic/', split='train', transform=None): 41 | super().__init__() 42 | 43 | self.root = root 44 | 45 | image_list_file = os.path.join(root, "EmoLabel", "list_patition_label.txt") 46 | self.image_list_file = image_list_file 47 | self.split = split 48 | self.transform = transform 49 | 50 | self.samples = [] 51 | self.targets = [] 52 | with open(self.image_list_file, 'r') as f: 53 | for i, img_file in enumerate(f): 54 | img_file = img_file.strip() 55 | img_file = img_file.split(' ') 56 | if split in img_file[0]: 57 | self.samples.append(os.path.join(root, "Image", "aligned", img_file[0][:-4]+'_aligned.jpg')) 58 | self.targets.append(int(img_file[1]) - 1) 59 | 60 | def __getitem__(self, index): 61 | img_file = self.samples[index] 62 | image = Image.open(img_file) 63 | 64 | if image.mode != 'RGB': 65 | image = image.convert("RGB") 66 | 67 | target = self.targets[index] 68 | 69 | if self.transform is not None: 70 | image = self.transform(image) 71 | 72 | return image, target, index 73 | 74 | def __len__(self): 75 | return len(self.samples) #12271 # 76 | 77 | 78 | if __name__ == '__main__': 79 | display_transform = transforms.Compose([ 80 | transforms.Resize((224, 224)), 81 | transforms.ToTensor() 82 | ]) 83 | 84 | split = "train" 85 | dataset = RAFDB(root="../data/RAFDB/basic", split=split, transform=display_transform) 86 | print(len(dataset)) 87 | print(set(dataset.targets)) 88 | 89 | loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True, 90 | drop_last=False) 91 | 92 | with torch.no_grad(): 93 | for i, (images, target, _) in enumerate(tqdm(loader)): 94 | img = np.clip(images.cpu().numpy(), 0, 1) # [0, 1] 95 | img = img.transpose(0, 2, 3, 1) 96 | img = (img * 255).astype(np.uint8) 97 | img = img.squeeze() 98 | 99 | fig, axs = plt.subplots(1, 1, figsize=(8, 8)) 100 | axs.imshow(img) 101 | axs.axis("off") 102 | plt.show() 103 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import os 4 | import sys 5 | import socket 6 | import random 7 | import argparse 8 | import subprocess 9 | import torch 10 | 11 | 12 | def _find_free_port(): 13 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 14 | sock.bind(("", 0)) 15 | port = sock.getsockname()[1] 16 | sock.close() 17 | return port 18 | 19 | 20 | def _get_rand_port(): 21 | return random.randrange(20000, 60000) 22 | 23 | 24 | def init_workdir(): 25 | ROOT = os.path.dirname(os.path.abspath(__file__)) 26 | os.chdir(ROOT) 27 | sys.path.insert(0, ROOT) 28 | 29 | if __name__ == '__main__': 30 | parser = argparse.ArgumentParser(description='Launcher') 31 | parser.add_argument('--launch', type=str, default='tools/train.py', 32 | help='Specify launcher script.') 33 | parser.add_argument('--dist', type=int, default=1, 34 | help='Whether start by torch.distributed.launch.') 35 | parser.add_argument('--np', type=int, default=-1, 36 | help='number of processes per node.') 37 | parser.add_argument('--nn', type=int, default=1, 38 | help='number of workers in total.') 39 | parser.add_argument('--port', type=int, default=-1, 40 | help='master port for communication') 41 | parser.add_argument('--nr', type=int, default=0, 42 | help='node rank.') 43 | parser.add_argument('--master_address', '-ma', type=str, default="127.0.0.1") 44 | parser.add_argument('--device', default=None, type=str, 45 | help='indices of 
GPUs to enable (default: all)') 46 | args, other_args = parser.parse_known_args() 47 | 48 | if args.device: 49 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 50 | cmd = f"CUDA_VISIBLE_DEVICES={args.device} " 51 | else: 52 | cmd = f"" 53 | 54 | init_workdir() 55 | master_address = args.master_address 56 | num_processes_per_worker = torch.cuda.device_count() if args.np < 0 else args.np 57 | num_workers = args.nn 58 | node_rank = args.nr 59 | 60 | if args.port > 0: 61 | master_port = args.port 62 | elif num_workers == 1: 63 | master_port = _find_free_port() 64 | else: 65 | master_port = _get_rand_port() 66 | 67 | if args.dist >= 1: 68 | print(f'Start {args.launch} by torch.distributed.launch with port {master_port}!', flush=True) 69 | os.environ['NPROC_PER_NODE'] = str(num_processes_per_worker) 70 | cmd += f'python3 -m torch.distributed.launch \ 71 | --nproc_per_node={num_processes_per_worker} \ 72 | --nnodes={num_workers} \ 73 | --node_rank={node_rank} \ 74 | --master_addr={master_address} \ 75 | --master_port={master_port} \ 76 | {args.launch}' 77 | else: 78 | print(f'Start {args.launch}!', flush=True) 79 | cmd += f'python3 -u {args.launch}' 80 | 81 | for argv in other_args: 82 | cmd += f' {argv}' 83 | 84 | with open('./log.txt', 'wb') as f: 85 | proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) 86 | while True: 87 | text = proc.stdout.readline() 88 | f.write(text) 89 | f.flush() 90 | sys.stdout.buffer.write(text) 91 | sys.stdout.buffer.flush() 92 | exit_code = proc.poll() 93 | if exit_code is not None: 94 | break 95 | sys.exit(exit_code) 96 | -------------------------------------------------------------------------------- /data/affectnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | https://github.com/yaoing/dan 6 | https://github.com/ElenaRyumina/EMO-AffectNetModel 7 | https://github.com/PanosAntoniadis/emotion-gcn 8 | """ 9 | 10 | __author__ = "GZ" 11 | 12 | import os 13 | import sys 14 | from shutil import copy 15 | import pandas as pd 16 | import numpy as np 17 | from tqdm import tqdm 18 | import matplotlib.pyplot as plt 19 | 20 | if sys.platform == 'win32': 21 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 22 | 23 | # Root directory of the project 24 | try: 25 | abspath = os.path.abspath(__file__) 26 | except NameError: 27 | abspath = os.getcwd() 28 | ROOT_DIR = os.path.dirname(abspath) 29 | 30 | 31 | # def convert_affectnet(label_file, save_dir): 32 | # df = pd.read_csv(label_file) 33 | # 34 | # for i in range(8): 35 | # for j in ['train','val']: 36 | # os.makedirs(os.path.join(save_dir, "AffectNet", j, i), exist_ok=True) 37 | # 38 | # for i, row in df.iterrows(): 39 | # p = row['phase'] 40 | # l = row['label'] 41 | # copy(row['img_path'], os.path.join(save_dir, "AffectNet", p, l)) 42 | # 43 | # print('convert done.') 44 | # 45 | # 46 | # def get_AffectNet(root, split, transform, num_class=7): 47 | # data_dir = os.path.join(root, split) 48 | # dataset = datasets.ImageFolder(data_dir, transform=transform) 49 | # if num_class == 7: # ignore the 8-th class 50 | # idx = [i for i in range(len(dataset)) if dataset.imgs[i][1] != 7] 51 | # dataset = data.Subset(dataset, idx) 52 | # return dataset 53 | 54 | 55 | def generate_affectnet(img_dir, label_dir, split, save_dir, num_class=7): 56 | assert split in ["train", "val"] 57 | label_file = "training.csv" if split == "train" else "validation.csv" 58 | head_list = ['subDirectory_filePath', 'face_x', 'face_y', 
'face_width', 'face_height', 'facial_landmarks', 59 | 'expression', 'valence', 'arousal'] 60 | dict_name_labels = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'} 61 | 62 | df_data_raw = pd.read_csv(os.path.join(label_dir, label_file)) 63 | df_data_raw.expression = pd.to_numeric(df_data_raw.expression, errors='coerce').fillna(100).astype('int64') 64 | 65 | df_data = df_data_raw[df_data_raw['expression'] < num_class] 66 | 67 | for label in range(num_class): 68 | os.makedirs(os.path.join(save_dir, split, str(label)), exist_ok=True) 69 | 70 | file_notfound = [] 71 | for i, row in tqdm(df_data.iterrows(), total=df_data.shape[0]): 72 | label = row['expression'] 73 | img_file = os.path.join(img_dir, row['subDirectory_filePath']) 74 | 75 | if os.path.isfile(img_file): 76 | copy(img_file, os.path.join(save_dir, split, str(label))) 77 | else: 78 | file_notfound.append(img_file) 79 | 80 | # 2/9db2af5a1da8bd77355e8c6a655da519a899ecc42641bf254107bfc0.jpg 81 | print(file_notfound) 82 | 83 | 84 | if __name__ == '__main__': 85 | import torch 86 | import torch.utils.data as data 87 | from torchvision import transforms, datasets 88 | from data.base_dataset import ImageFolderInstance 89 | from data.sampler import DistributedImbalancedSampler, DistributedSamplerWrapper, ImbalancedDatasetSampler 90 | 91 | # label_file = "../data/FER/AffectNet/affectnet.csv" 92 | # save_dir = "../data/FER" 93 | # convert_affectnet(label_file, save_dir) 94 | 95 | img_dir = "../data/FER/AffectNet/Manually_Annotated_Images" 96 | label_dir = '../data/FER/AffectNet/Manually_Annotated_file_lists' 97 | split = "train" 98 | save_dir = "../data/FER/AffectNet_subset" 99 | # generate_affectnet(img_dir, label_dir, split, save_dir) 100 | 101 | data_root = save_dir 102 | display_transform = transforms.Compose([ 103 | transforms.Resize((224, 224)), 104 | transforms.ToTensor() 105 | ]) 106 | 107 | # dataset = get_AffectNet(data_root, split, display_transform, num_class=7) 108 | data_dir = os.path.join(data_root, split) 109 | dataset = ImageFolderInstance(data_dir, transform=display_transform) 110 | print(dataset) 111 | 112 | train_percent = 0.1 113 | if train_percent < 1.0: 114 | num_subset = int(len(dataset) * train_percent) 115 | indices = torch.randperm(len(dataset))[:num_subset] 116 | indices = indices.tolist() 117 | dataset = torch.utils.data.Subset(dataset, indices) 118 | print("Sub train_dataset:\n{}".format(len(dataset))) 119 | 120 | sampler = ImbalancedDatasetSampler(dataset) 121 | loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True, 122 | drop_last=False) 123 | 124 | with torch.no_grad(): 125 | for i, (images, target, _) in enumerate(tqdm(loader)): 126 | img = np.clip(images.cpu().numpy(), 0, 1) # [0, 1] 127 | img = img.transpose(0, 2, 3, 1) 128 | img = (img * 255).astype(np.uint8) 129 | img = img.squeeze() 130 | 131 | fig, axs = plt.subplots(1, 1, figsize=(8, 8)) 132 | axs.imshow(img) 133 | axs.axis("off") 134 | plt.show() 135 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Supervised Facial Representation Learning with Facial Region Awareness 2 | 3 |

4 | 5 | 6 | 7 | Self-Supervised Facial Representation Learning with Facial Region Awareness (CVPR 2024) 8 | By 9 | Zheng Gao and 10 | Ioannis Patras.
12 | 13 | ## Introduction 14 | 15 | > **Abstract**: Self-supervised pre-training has been proved to be effective in learning transferable representations that benefit various visual tasks. This paper asks this question: can self-supervised pre-training learn general facial representations for various facial analysis tasks? Recent efforts toward this goal are limited to treating each face image as a whole, i.e., learning consistent facial representations at the image-level, which overlooks the **consistency of local facial representations** (i.e., facial regions like eyes, nose, etc). In this work, we make a **first attempt** to propose a novel self-supervised facial representation learning framework to learn consistent global and local facial representations, Facial Region Awareness (FRA). Specifically, we explicitly enforce the consistency of facial regions by matching the local facial representations across views, which are extracted with learned heatmaps highlighting the facial regions. Inspired by the mask prediction in supervised semantic segmentation, we obtain the heatmaps via cosine similarity between the per-pixel projection of feature maps and facial mask embeddings computed from learnable positional embeddings, which leverage the attention mechanism to globally look up the facial image for facial regions. To learn such heatmaps, we formulate the learning of facial mask embeddings as a deep clustering problem by assigning the pixel features from the feature maps to them. The transfer learning results on facial classification and regression tasks show that our FRA outperforms previous pre-trained models and more importantly, using ResNet as the unified backbone for various tasks, our FRA achieves comparable or even better performance compared with SOTA methods in facial analysis tasks. 16 | 17 | ![framework](docs/face-framework.png) 18 | 19 | 20 | ## Installation 21 | Please refer to `requirement.txt` for the dependencies. Alternatively, you can install dependencies using the following command: 22 | ``` 23 | pip3 install -r requirement.txt 24 | ``` 25 | The repository works with `PyTorch 1.10.2` or higher and `CUDA 11.1`. 26 | 27 | ## Get started 28 | 29 | We provide basic usage of the implementation in the following sections: 30 | 31 | ### Pre-training on VGGFace2 32 | 33 | Download [VGGFace2](https://academictorrents.com/details/535113b8395832f09121bc53ac85d7bc8ef6fa5b) dataset and specify the path to VGGFace2 by `DATA_ROOT="./data/VGG-Face2-crop"`. 34 | 35 | To perform pre-training of the model with ResNet-50 backbone on VGGFace2 with multi-gpu, run: 36 | ``` 37 | python3 launch.py --device=${DEVICES} --launch main.py \ 38 | --arch FRAB --backbone resnet50_encoder \ 39 | --dataset vggface2 --data-root ${DATA_ROOT} \ 40 | --lr 0.9 -b 512 --wd 0.000001 --epochs 50 --cos --warmup-epoch 10 --workers 16 \ 41 | --enc-m 0.996 \ 42 | --norm SyncBN \ 43 | --lewel-loss-weight 0.5 \ 44 | --mask_type="attn" --num_proto 8 --teacher_temp 0.04 --loss_w_cluster 0.1 \ 45 | --amp \ 46 | --save-dir ./ckpts --save-freq 50 --print-freq 100 47 | ``` 48 | `DEVICES` denotes the gpu indices. 
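To relate the pre-training flags above to the method described in the abstract, the following is a minimal, self-contained sketch (an illustration only, not the repository's implementation, which lives in `models/fra.py` and `models/transformers/transformer_predictor.py`): facial-region heatmaps are computed as the cosine similarity between the per-pixel projection of a feature map and a set of facial mask embeddings. Here `K` plays the role of `--num_proto`, and all tensor shapes are assumed for illustration.
```
# Hedged sketch of the heatmap computation described in the abstract
# (illustrative shapes; the actual code is in models/fra.py).
import torch
import torch.nn.functional as F

B, C, H, W, K = 2, 256, 7, 7, 8          # K facial regions (cf. --num_proto 8)
pixel_proj = torch.randn(B, C, H, W)     # per-pixel projection of the feature map
mask_embed = torch.randn(B, K, C)        # facial mask embeddings from learnable positional embeddings

pixel_proj = F.normalize(pixel_proj, dim=1)   # unit-normalize each pixel feature
mask_embed = F.normalize(mask_embed, dim=2)   # unit-normalize each mask embedding
heatmaps = torch.einsum("bkc,bchw->bkhw", mask_embed, pixel_proj)  # cosine similarity per pixel
regions = heatmaps.softmax(dim=1)        # soft assignment of pixels to K facial regions (assumed)
```
The contraction mirrors the `"bqc,bchw->bqhw"` einsum that `models/transformers/transformer_predictor.py` uses to produce `pred_masks`.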
49 | 50 | ### Evaluation: Facial expression recognition (FER) 51 | The following is an example of evaluating the pre-trained model on RAFDB dataset, under the setting of fine-tuning both encoder backbone and linear classifier: 52 | ``` 53 | python3 launch.py --device=${DEVICES} --launch main_fer.py \ 54 | -a resnet50 \ 55 | --dataset rafdb --data-root ${FER_DATA_ROOT} \ 56 | --lr 0.0002 --lr_head 0.0002 --optimizer adamw --weight-decay 0.05 --scheduler cos \ 57 | --finetune \ 58 | --epochs 100 --batch-size 256 \ 59 | --amp \ 60 | --workers 16 \ 61 | --eval-freq 5 \ 62 | --model-prefix online_net.backbone \ 63 | --pretrained ${PRETRAINED} \ 64 | --image_size 224 \ 65 | --multiprocessing_distributed 66 | ``` 67 | `PRETRAINED` denotes the path to the pre-trained checkpoint and `FER_DATA_ROOT=/path/to/datasets` is the location for FER datasets. 68 | 69 | ### Evaluation: Face alignment 70 | For evaluation on face alignment, we use [STAR Loss](https://github.com/ZhenglinZhou/STAR) as the downstream backbone. Please refer to [STAR Loss](https://github.com/ZhenglinZhou/STAR). 71 | 72 | 73 | ## Citation 74 | 75 | If you find this repository useful, please consider giving a star :star: and citation: 76 | 77 | ```bibteX 78 | @article{gao2023self, 79 | title={Self-Supervised Representation Learning with Cross-Context Learning between Global and Hypercolumn Features}, 80 | author={Gao, Zheng and Patras, Ioannis}, 81 | journal={arXiv preprint arXiv:2308.13392}, 82 | year={2023} 83 | } 84 | ``` 85 | 86 | ## Acknowledgment 87 | Our project is based on [LEWEL](https://github.com/LayneH/LEWEL). Thanks for their wonderful work. 88 | 89 | 90 | ## License 91 | 92 | This project is released under the [CC-BY-NC 4.0 license](LICENSE). -------------------------------------------------------------------------------- /utils/LARS.py: -------------------------------------------------------------------------------- 1 | # code in this file is adapted from 2 | # https://github.com/yaox12/BYOL-PyTorch/blob/master/optimizer/LARSSGD.py 3 | # Copyright 2020 Xin Yao. Licensed under the MIT License. 4 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 5 | # SPDX-License-Identifier: CC-BY-NC-4.0 6 | 7 | """ Layer-wise adaptive rate scaling for SGD in PyTorch! """ 8 | import torch 9 | from torch.optim.optimizer import Optimizer, required 10 | 11 | 12 | class LARS(Optimizer): 13 | r"""Implements layer-wise adaptive rate scaling for SGD. 14 | Args: 15 | params (iterable): iterable of parameters to optimize or dicts defining 16 | parameter groups 17 | lr (float): base learning rate (\gamma_0) 18 | momentum (float, optional): momentum factor (default: 0) ("m") 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 20 | ("\beta") 21 | dampening (float, optional): dampening for momentum (default: 0) 22 | eta (float, optional): LARS coefficient 23 | nesterov (bool, optional): enables Nesterov momentum (default: False) 24 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. 
25 | Large Batch Training of Convolutional Networks: 26 | https://arxiv.org/abs/1708.03888 27 | Example: 28 | >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9, 29 | >>> weight_decay=1e-4, eta=1e-3) 30 | >>> optimizer.zero_grad() 31 | >>> loss_fn(model(input), target).backward() 32 | >>> optimizer.step() 33 | """ 34 | 35 | def __init__(self, 36 | params, 37 | lr=required, 38 | momentum=0, 39 | dampening=0, 40 | weight_decay=0, 41 | eta=0.001, 42 | nesterov=False, 43 | eps=1e-8): 44 | if lr is not required and lr < 0.0: 45 | raise ValueError("Invalid learning rate: {}".format(lr)) 46 | if momentum < 0.0: 47 | raise ValueError("Invalid momentum value: {}".format(momentum)) 48 | if weight_decay < 0.0: 49 | raise ValueError( 50 | "Invalid weight_decay value: {}".format(weight_decay)) 51 | if eta < 0.0: 52 | raise ValueError("Invalid LARS coefficient value: {}".format(eta)) 53 | 54 | defaults = dict( 55 | lr=lr, momentum=momentum, dampening=dampening, 56 | weight_decay=weight_decay, nesterov=nesterov, eta=eta) 57 | if nesterov and (momentum <= 0 or dampening != 0): 58 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 59 | 60 | super(LARS, self).__init__(params, defaults) 61 | 62 | self.eps = eps 63 | 64 | def __setstate__(self, state): 65 | super(LARS, self).__setstate__(state) 66 | for group in self.param_groups: 67 | group.setdefault('nesterov', False) 68 | 69 | @torch.no_grad() 70 | def step(self, closure=None): 71 | """Performs a single optimization step. 72 | Arguments: 73 | closure (callable, optional): A closure that reevaluates the model 74 | and returns the loss. 75 | """ 76 | loss = None 77 | if closure is not None: 78 | with torch.enable_grad(): 79 | loss = closure() 80 | 81 | for group in self.param_groups: 82 | weight_decay = group['weight_decay'] 83 | momentum = group['momentum'] 84 | dampening = group['dampening'] 85 | eta = group['eta'] 86 | nesterov = group['nesterov'] 87 | lr = group['lr'] 88 | lars_exclude = group.get('lars_exclude', False) 89 | 90 | for p in group['params']: 91 | if p.grad is None: 92 | continue 93 | 94 | d_p = p.grad 95 | 96 | if lars_exclude: 97 | local_lr = 1. 98 | else: 99 | weight_norm = torch.norm(p).item() 100 | grad_norm = torch.norm(d_p).item() 101 | # Compute local learning rate for this layer 102 | local_lr = eta * weight_norm / \ 103 | (grad_norm + weight_decay * weight_norm + self.eps) 104 | 105 | actual_lr = local_lr * lr 106 | d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr) 107 | if momentum != 0: 108 | param_state = self.state[p] 109 | if 'momentum_buffer' not in param_state: 110 | buf = param_state['momentum_buffer'] = \ 111 | torch.clone(d_p).detach() 112 | else: 113 | buf = param_state['momentum_buffer'] 114 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 115 | if nesterov: 116 | d_p = d_p.add(buf, alpha=momentum) 117 | else: 118 | d_p = buf 119 | p.add_(-d_p) 120 | 121 | return loss 122 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # Original copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | import os 4 | import numpy as np 5 | import shutil 6 | from sklearn.metrics import accuracy_score 7 | import skimage.io 8 | 9 | import torch 10 | 11 | 12 | def load_netowrk(model, path, checkpoint_key="net"): 13 | if os.path.isfile(path): 14 | print("=> loading checkpoint '{}'".format(path)) 15 | checkpoint = torch.load(path, map_location="cpu") 16 | 17 | # rename pre-trained keys 18 | state_dict = checkpoint[checkpoint_key] 19 | state_dict_new = {k.replace("module.", ""): v for k, v in state_dict.items()} 20 | 21 | msg = model.load_state_dict(state_dict_new) 22 | assert set(msg.missing_keys) == set() 23 | 24 | print("=> loaded pre-trained model '{}'".format(path)) 25 | else: 26 | print("=> no checkpoint found at '{}'".format(path)) 27 | 28 | 29 | def save_checkpoint(state, is_best, epoch, args, filename='checkpoint.pth.tar'): 30 | filename = os.path.join(args.save_dir, filename) 31 | torch.save(state, filename) 32 | # if is_best: 33 | # shutil.copyfile(filename, os.path.join(args.save_dir, 'model_best.pth.tar')) 34 | if args.save_freq > 0 and (epoch + 1) % args.save_freq == 0: 35 | shutil.copyfile(filename, os.path.join(args.save_dir, 'checkpoint_{:04d}.pth.tar'.format(epoch))) 36 | if not args.cos: 37 | if (epoch + 1) in args.schedule: 38 | shutil.copyfile(filename, os.path.join(args.save_dir, 'checkpoint_{:04d}.pth.tar'.format(epoch))) 39 | 40 | 41 | def accuracy(output, target, topk=(1,)): 42 | """Computes the accuracy over the k top predictions for the specified values of k""" 43 | with torch.no_grad(): 44 | maxk = max(topk) 45 | batch_size = target.size(0) 46 | 47 | _, pred = output.topk(maxk, 1, True, True) 48 | pred = pred.t() 49 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 50 | 51 | res = [] 52 | for k in topk: 53 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 54 | res.append(correct_k.mul_(100.0 / batch_size)) 55 | return res 56 | 57 | 58 | def accuracy_multilabel(output, target, threshold=0.5): 59 | """ 60 | https://www.kaggle.com/code/kmkarakaya/multi-label-model-evaluation 61 | """ 62 | with torch.no_grad(): 63 | batch_size, n_class = target.shape 64 | pred = (output >= threshold).to(torch.float32) 65 | 66 | acc = (pred == target).float().sum() * 100.0 / (batch_size * n_class) 67 | 68 | # acc = sklearn.metrics.accuracy_score(gt_S,pred_S) 69 | # f1m = sklearn.metrics.f1_score(gt_S,pred_S,average = 'macro', zero_division=1) 70 | # f1mi = sklearn.metrics.f1_score(gt_S,pred_S,average = 'micro', zero_division=1) 71 | # print('f1_Macro_Score{}'.format(f1m)) 72 | # print('f1_Micro_Score{}'.format(f1mi)) 73 | # print('Accuracy{}'.format(acc)) 74 | 75 | return acc 76 | 77 | 78 | class AverageMeter(object): 79 | """Computes and stores the average and current value""" 80 | def __init__(self, name, fmt=':f'): 81 | self.name = name 82 | self.fmt = fmt 83 | self.reset() 84 | 85 | def reset(self): 86 | self.val = 0 87 | self.avg = 0 88 | self.sum = 0 89 | self.count = 0 90 | 91 | def update(self, val, n=1): 92 | self.val = val 93 | self.sum += val * n 94 | self.count += n 95 | self.avg = self.sum / self.count 96 | 97 | def __str__(self): 98 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 99 | return fmtstr.format(**self.__dict__) 100 | 101 | 102 | class ProgressMeter(object): 103 | def __init__(self, num_batches, meters, prefix=""): 104 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 105 | self.meters = meters 106 | self.prefix = prefix 107 | 108 | def display(self, batch): 109 | 
entries = [self.prefix + self.batch_fmtstr.format(batch)] 110 | entries += [str(meter) for meter in self.meters] 111 | print('\t'.join(entries), flush=True) 112 | 113 | def _get_batch_fmtstr(self, num_batches): 114 | num_digits = len(str(num_batches // 1)) 115 | fmt = '{:' + str(num_digits) + 'd}' 116 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 117 | 118 | 119 | class InstantMeter(object): 120 | """Computes and stores the average and current value""" 121 | def __init__(self, name, fmt=':f'): 122 | self.name = name 123 | self.fmt = fmt 124 | self.val = 0 125 | 126 | def update(self, val): 127 | self.val = val 128 | 129 | def __str__(self): 130 | fmtstr = '{name} {val' + self.fmt + '}' 131 | return fmtstr.format(**self.__dict__) 132 | 133 | 134 | def denormalize_batch(batch, mean, std): 135 | """denormalize for visualization""" 136 | dtype = batch.dtype 137 | mean = torch.as_tensor(mean, dtype=dtype, device=batch.device) 138 | std = torch.as_tensor(std, dtype=dtype, device=batch.device) 139 | mean = mean.view(-1, 1, 1) 140 | std = std.view(-1, 1, 1) 141 | batch = batch * std + mean 142 | return batch 143 | 144 | def dump_image(imgNorm, mean, std, filepath=None, verbose=False): 145 | """Denormalizes the output image and optionally plots the landmark coordinates onto the image 146 | 147 | Args: 148 | normalized_image (torch.tensor): Image reconstruction output from the model (normalized) 149 | landmark_coords (torch.tensor): x, y coordinates in normalized range -1 to 1 150 | out_name (str, optional): file to write to 151 | Returns: 152 | np.array: uint8 image data stored in numpy format 153 | """ 154 | if imgNorm.dim() < 4: 155 | imgNorm = imgNorm.unsqueeze(0) 156 | 157 | img = denormalize_batch(imgNorm, mean, std) 158 | img = np.clip(img.cpu().numpy(), 0, 1) 159 | img = (img.transpose(0, 2, 3, 1) * 255).astype(np.uint8) 160 | 161 | if filepath is not None: 162 | skimage.io.imsave(filepath, img[0]) 163 | 164 | if verbose: 165 | num = min(img.shape[0], 9) 166 | show_images(img[:num], 3, 3) 167 | plt.show() 168 | return img 169 | 170 | 171 | def calc_params(net, verbose=False): 172 | num_params = 0 173 | for param in net.parameters(): 174 | num_params += param.numel() 175 | if verbose: 176 | print(net) 177 | print('Total number of parameters : %.3f M' % (num_params / 1e6)) 178 | 179 | return num_params 180 | 181 | 182 | if __name__ == '__main__': 183 | output = torch.tensor([[0.35,0.4,0.9], [0.2,0.6,0.8]]) 184 | target = torch.tensor([[1, 0, 1], [0, 1, 1]]) 185 | acc = accuracy_multilabel(output, target) 186 | print(acc) 187 | -------------------------------------------------------------------------------- /models/transformers/transformer_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py 3 | import os 4 | import sys 5 | 6 | import fvcore.nn.weight_init as weight_init 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 13 | 14 | from transformers.position_encoding import PositionEmbeddingSine 15 | from transformers.transformer import Transformer 16 | 17 | 18 | class TransformerPredictor(nn.Module): 19 | def __init__( 20 | self, 21 | in_channels, 22 | mask_classification=True, 23 | *, 24 | num_classes: int, 25 | hidden_dim: int, 26 | num_queries: int, 27 | nheads: int, 28 | dropout: float, 29 | dim_feedforward: int, 30 | enc_layers: int, 31 | dec_layers: int, 32 | pre_norm: bool, 33 | deep_supervision: bool, 34 | mask_dim: int, 35 | enforce_input_project: bool, 36 | ): 37 | """ 38 | NOTE: this interface is experimental. 39 | Args: 40 | in_channels: channels of the input features 41 | mask_classification: whether to add mask classifier or not 42 | num_classes: number of classes 43 | hidden_dim: Transformer feature dimension 44 | num_queries: number of queries 45 | nheads: number of heads 46 | dropout: dropout in Transformer 47 | dim_feedforward: feature dimension in feedforward network 48 | enc_layers: number of Transformer encoder layers 49 | dec_layers: number of Transformer decoder layers 50 | pre_norm: whether to use pre-LayerNorm or not 51 | deep_supervision: whether to add supervision to every decoder layers 52 | mask_dim: mask feature dimension 53 | enforce_input_project: add input project 1x1 conv even if input 54 | channels and hidden dim is identical 55 | """ 56 | super().__init__() 57 | 58 | self.mask_classification = mask_classification 59 | 60 | # positional encoding 61 | N_steps = hidden_dim // 2 62 | self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) 63 | 64 | transformer = Transformer( 65 | d_model=hidden_dim, 66 | dropout=dropout, 67 | nhead=nheads, 68 | dim_feedforward=dim_feedforward, 69 | num_encoder_layers=enc_layers, 70 | num_decoder_layers=dec_layers, 71 | normalize_before=pre_norm, 72 | return_intermediate_dec=deep_supervision, 73 | ) 74 | 75 | self.num_queries = num_queries 76 | self.transformer = transformer 77 | hidden_dim = transformer.d_model 78 | 79 | self.query_embed = nn.Embedding(num_queries, hidden_dim) 80 | 81 | if in_channels != hidden_dim or enforce_input_project: 82 | # self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1) 83 | self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1) 84 | weight_init.c2_xavier_fill(self.input_proj) 85 | else: 86 | self.input_proj = nn.Sequential() 87 | self.aux_loss = deep_supervision 88 | 89 | # output FFNs 90 | if self.mask_classification: 91 | self.class_embed = nn.Linear(hidden_dim, num_classes + 1) 92 | self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) 93 | 94 | def forward(self, x, mask_features=None): 95 | pos = self.pe_layer(x) 96 | 97 | src = x 98 | mask = None 99 | hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos) 100 | 101 | if self.mask_classification: 102 | outputs_class = self.class_embed(hs) 103 | out = {"pred_logits": outputs_class[-1]} 104 | else: 105 | out = {} 106 | 107 | if self.aux_loss: 108 | # [l, bs, queries, embed] 109 | mask_embed = self.mask_embed(hs) 110 | outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features) 111 | 
out["pred_masks"] = outputs_seg_masks[-1] 112 | out["aux_outputs"] = self._set_aux_loss( 113 | outputs_class if self.mask_classification else None, outputs_seg_masks 114 | ) 115 | else: 116 | # FIXME h_boxes takes the last one computed, keep this in mind 117 | # [bs, queries, embed] 118 | mask_embed = self.mask_embed(hs[-1]) 119 | out["mask_embed"] = mask_embed 120 | if mask_features is not None: 121 | outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) 122 | out["pred_masks"] = outputs_seg_masks 123 | return out 124 | 125 | @torch.jit.unused 126 | def _set_aux_loss(self, outputs_class, outputs_seg_masks): 127 | # this is a workaround to make torchscript happy, as torchscript 128 | # doesn't support dictionary with non-homogeneous values, such 129 | # as a dict having both a Tensor and a list. 130 | if self.mask_classification: 131 | return [ 132 | {"pred_logits": a, "pred_masks": b} 133 | for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) 134 | ] 135 | else: 136 | return [{"pred_masks": b} for b in outputs_seg_masks[:-1]] 137 | 138 | 139 | class MLP(nn.Module): 140 | """Very simple multi-layer perceptron (also called FFN)""" 141 | 142 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 143 | super().__init__() 144 | self.num_layers = num_layers 145 | h = [hidden_dim] * (num_layers - 1) 146 | self.layers = nn.ModuleList( 147 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 148 | ) 149 | 150 | def forward(self, x): 151 | for i, layer in enumerate(self.layers): 152 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 153 | return x 154 | 155 | 156 | if __name__ == '__main__': 157 | from utils.utils import calc_params 158 | 159 | model = TransformerPredictor(in_channels=2048, hidden_dim=256, num_queries=100, nheads=8, dropout=0.1, dim_feedforward=2048, 160 | enc_layers=0, dec_layers=1, pre_norm=False, deep_supervision=False, mask_dim=256, 161 | enforce_input_project=False, mask_classification=False, num_classes=0) 162 | print(model) 163 | 164 | x = torch.randn(16, 2048, 7, 7) 165 | mask_features = torch.randn(16, 256, 7, 7) 166 | out = model(x, mask_features) 167 | 168 | calc_params(model) 169 | -------------------------------------------------------------------------------- /utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | # some code in this file is adapted from 2 | # https://github.com/facebookresearch/moco 3 | # Original Copyright 2020 Facebook, Inc. and its affiliates. Licensed under the CC-BY-NC 4.0 License. 4 | # Modifications Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved. 5 | # SPDX-License-Identifier: CC-BY-NC-4.0 6 | 7 | import os 8 | import sys 9 | import random 10 | import datetime 11 | import torch 12 | import torch.distributed as dist 13 | 14 | 15 | @torch.no_grad() 16 | def batch_shuffle_ddp(x): 17 | """ 18 | Batch shuffle, for making use of BatchNorm. 19 | *** Only support DistributedDataParallel (DDP) model. 
*** 20 | """ 21 | # gather from all gpus 22 | batch_size_this = x.shape[0] 23 | x_gather = concat_all_gather(x) 24 | batch_size_all = x_gather.shape[0] 25 | 26 | num_gpus = batch_size_all // batch_size_this 27 | 28 | # random shuffle index 29 | idx_shuffle = torch.randperm(batch_size_all).cuda() 30 | 31 | # broadcast to all gpus 32 | torch.distributed.broadcast(idx_shuffle, src=0) 33 | 34 | # index for restoring 35 | idx_unshuffle = torch.argsort(idx_shuffle) 36 | 37 | # shuffled index for this gpu 38 | gpu_idx = torch.distributed.get_rank() 39 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 40 | 41 | return x_gather[idx_this], idx_unshuffle 42 | 43 | 44 | @torch.no_grad() 45 | def batch_unshuffle_ddp(x, idx_unshuffle): 46 | """ 47 | Undo batch shuffle. 48 | *** Only support DistributedDataParallel (DDP) model. *** 49 | """ 50 | # gather from all gpus 51 | batch_size_this = x.shape[0] 52 | x_gather = concat_all_gather(x) 53 | batch_size_all = x_gather.shape[0] 54 | 55 | num_gpus = batch_size_all // batch_size_this 56 | 57 | # restored index for this gpu 58 | gpu_idx = torch.distributed.get_rank() 59 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 60 | 61 | return x_gather[idx_this] 62 | 63 | 64 | @torch.no_grad() 65 | def concat_all_gather(tensor): 66 | """ 67 | Performs all_gather operation on the provided tensors. 68 | *** Warning ***: torch.distributed.all_gather has no gradient. 69 | """ 70 | tensors_gather = [torch.ones_like(tensor) 71 | for _ in range(torch.distributed.get_world_size())] 72 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 73 | 74 | output = torch.cat(tensors_gather, dim=0) 75 | return output 76 | 77 | 78 | def is_dist_avail_and_initialized(): 79 | if not dist.is_available(): 80 | return False 81 | if not dist.is_initialized(): 82 | return False 83 | return True 84 | 85 | 86 | def get_world_size(): 87 | if not is_dist_avail_and_initialized(): 88 | return 1 89 | return dist.get_world_size() 90 | 91 | 92 | def get_rank(): 93 | if not is_dist_avail_and_initialized(): 94 | return 0 95 | return dist.get_rank() 96 | 97 | def init_distributed_mode(args): 98 | if is_dist_avail_and_initialized(): 99 | return 100 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 101 | args.rank = int(os.environ["RANK"]) 102 | args.world_size = int(os.environ['WORLD_SIZE']) 103 | args.gpu = int(os.environ['LOCAL_RANK']) 104 | 105 | elif 'SLURM_PROCID' in os.environ: 106 | args.rank = int(os.environ['SLURM_PROCID']) 107 | args.gpu = args.rank % torch.cuda.device_count() 108 | elif torch.cuda.is_available(): 109 | print('Will run the code on one GPU.') 110 | args.rank, args.gpu, args.world_size = 0, 0, 1 111 | os.environ['MASTER_ADDR'] = '127.0.0.1' 112 | os.environ['MASTER_PORT'] = str(random.randint(0, 9999) + 40000) 113 | else: 114 | print('Does not support training without GPU.') 115 | sys.exit(1) 116 | 117 | print("Use GPU: {} ranked {} out of {} gpus for training".format(args.gpu, args.rank, args.world_size)) 118 | if args.multiprocessing_distributed: 119 | dist.init_process_group( 120 | backend="nccl", 121 | init_method=args.dist_url, 122 | world_size=args.world_size, 123 | timeout=datetime.timedelta(hours=5), 124 | rank=args.rank, 125 | ) 126 | print('| distributed init (rank {}): {}'.format( 127 | args.rank, args.dist_url), flush=True) 128 | dist.barrier() 129 | 130 | torch.cuda.set_device(args.gpu) 131 | setup_for_distributed(args.rank == 0) 132 | 133 | 134 | def setup_for_distributed(is_master): 135 | """ 136 | This function disables printing 
when not in master process 137 | """ 138 | import builtins as __builtin__ 139 | builtin_print = __builtin__.print 140 | 141 | def print(*args, **kwargs): 142 | force = kwargs.pop('force', False) 143 | if is_master or force: 144 | builtin_print(*args, **kwargs) 145 | 146 | __builtin__.print = print 147 | 148 | 149 | def all_reduce_mean(x): 150 | # reduce tensore for DDP 151 | # source: https://raw.githubusercontent.com/NVIDIA/apex/master/examples/imagenet/main_amp.py 152 | world_size = get_world_size() 153 | if world_size > 1: 154 | rt = x.clone() 155 | torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM) 156 | rt /= world_size 157 | return rt 158 | else: 159 | return x 160 | 161 | # def dist_init(port=23456): 162 | # 163 | # def init_parrots(host_addr, rank, local_rank, world_size, port): 164 | # os.environ['MASTER_ADDR'] = str(host_addr) 165 | # os.environ['MASTER_PORT'] = str(port) 166 | # os.environ['WORLD_SIZE'] = str(world_size) 167 | # os.environ['RANK'] = str(rank) 168 | # torch.distributed.init_process_group(backend="nccl") 169 | # torch.cuda.set_device(local_rank) 170 | # 171 | # def init(host_addr, rank, local_rank, world_size, port): 172 | # host_addr_full = 'tcp://' + host_addr + ':' + str(port) 173 | # torch.distributed.init_process_group("nccl", init_method=host_addr_full, 174 | # rank=rank, world_size=world_size) 175 | # torch.cuda.set_device(local_rank) 176 | # assert torch.distributed.is_initialized() 177 | # 178 | # 179 | # def parse_host_addr(s): 180 | # if '[' in s: 181 | # left_bracket = s.index('[') 182 | # right_bracket = s.index(']') 183 | # prefix = s[:left_bracket] 184 | # first_number = s[left_bracket+1:right_bracket].split(',')[0].split('-')[0] 185 | # return prefix + first_number 186 | # else: 187 | # return s 188 | # 189 | # if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 190 | # rank = int(os.environ["RANK"]) 191 | # local_rank = int(os.environ['LOCAL_RANK']) 192 | # world_size = int(os.environ['WORLD_SIZE']) 193 | # ip = 'env://' 194 | # 195 | # elif 'SLURM_PROCID' in os.environ: 196 | # rank = int(os.environ['SLURM_PROCID']) 197 | # local_rank = int(os.environ['SLURM_LOCALID']) 198 | # world_size = int(os.environ['SLURM_NTASKS']) 199 | # ip = parse_host_addr(os.environ['SLURM_STEP_NODELIST']) 200 | # else: 201 | # raise RuntimeError() 202 | # 203 | # if torch.__version__ == 'parrots': 204 | # init_parrots(ip, rank, local_rank, world_size, port) 205 | # else: 206 | # init(ip, rank, local_rank, world_size, port) 207 | # 208 | # return rank, local_rank, world_size 209 | 210 | 211 | # https://github.com/facebookresearch/msn 212 | class AllReduce(torch.autograd.Function): 213 | 214 | @staticmethod 215 | def forward(ctx, x): 216 | if ( 217 | dist.is_available() 218 | and dist.is_initialized() 219 | and (dist.get_world_size() > 1) 220 | ): 221 | x = x.contiguous() / dist.get_world_size() 222 | dist.all_reduce(x) 223 | return x 224 | 225 | @staticmethod 226 | def backward(ctx, grads): 227 | return grads -------------------------------------------------------------------------------- /data/randaugment.py: -------------------------------------------------------------------------------- 1 | # some code in this file is adapted from 2 | # https://github.com/kekmodel/FixMatch-pytorch/blob/master/dataset/randaugment.py 3 | # Original Copyright 2019 Jungdae Kim, Qing Yu. Licensed under the MIT License. 4 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
5 | # SPDX-License-Identifier: CC-BY-NC-4.0 6 | 7 | import logging 8 | import random 9 | 10 | import numpy as np 11 | import PIL 12 | import PIL.ImageOps 13 | import PIL.ImageEnhance 14 | import PIL.ImageDraw 15 | from PIL import Image, ImageFilter 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | PARAMETER_MAX = 10 20 | 21 | 22 | def AutoContrast(img, **kwarg): 23 | return PIL.ImageOps.autocontrast(img) 24 | 25 | 26 | def Brightness(img, v, max_v, bias=0): 27 | v = _float_parameter(v, max_v) + bias 28 | return PIL.ImageEnhance.Brightness(img).enhance(v) 29 | 30 | 31 | def Color(img, v, max_v, bias=0): 32 | v = _float_parameter(v, max_v) + bias 33 | return PIL.ImageEnhance.Color(img).enhance(v) 34 | 35 | 36 | def Contrast(img, v, max_v, bias=0): 37 | v = _float_parameter(v, max_v) + bias 38 | return PIL.ImageEnhance.Contrast(img).enhance(v) 39 | 40 | 41 | def Cutout(img, v, max_v, bias=0): 42 | if v == 0: 43 | return img 44 | v = _float_parameter(v, max_v) + bias 45 | v = int(v * min(img.size)) 46 | return CutoutAbs(img, v) 47 | 48 | 49 | def CutoutAbs(img, v, **kwarg): 50 | w, h = img.size 51 | x0 = np.random.uniform(0, w) 52 | y0 = np.random.uniform(0, h) 53 | x0 = int(max(0, x0 - v / 2.)) 54 | y0 = int(max(0, y0 - v / 2.)) 55 | x1 = int(min(w, x0 + v)) 56 | y1 = int(min(h, y0 + v)) 57 | xy = (x0, y0, x1, y1) 58 | # gray 59 | color = (127, 127, 127) 60 | img = img.copy() 61 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 62 | return img 63 | 64 | 65 | def Equalize(img, **kwarg): 66 | return PIL.ImageOps.equalize(img) 67 | 68 | 69 | def Identity(img, **kwarg): 70 | return img 71 | 72 | 73 | def Invert(img, **kwarg): 74 | return PIL.ImageOps.invert(img) 75 | 76 | 77 | def Posterize(img, v, max_v, bias=0): 78 | v = _int_parameter(v, max_v) + bias 79 | return PIL.ImageOps.posterize(img, v) 80 | 81 | 82 | def Rotate(img, v, max_v, bias=0): 83 | v = _int_parameter(v, max_v) + bias 84 | if random.random() < 0.5: 85 | v = -v 86 | return img.rotate(v) 87 | 88 | 89 | def Sharpness(img, v, max_v, bias=0): 90 | v = _float_parameter(v, max_v) + bias 91 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 92 | 93 | 94 | def ShearX(img, v, max_v, bias=0): 95 | v = _float_parameter(v, max_v) + bias 96 | if random.random() < 0.5: 97 | v = -v 98 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 99 | 100 | 101 | def ShearY(img, v, max_v, bias=0): 102 | v = _float_parameter(v, max_v) + bias 103 | if random.random() < 0.5: 104 | v = -v 105 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 106 | 107 | 108 | def Solarize(img, v, max_v, bias=0): 109 | v = _int_parameter(v, max_v) + bias 110 | return PIL.ImageOps.solarize(img, 256 - v) 111 | 112 | 113 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128): 114 | v = _int_parameter(v, max_v) + bias 115 | if random.random() < 0.5: 116 | v = -v 117 | img_np = np.array(img).astype(np.int) 118 | img_np = img_np + v 119 | img_np = np.clip(img_np, 0, 255) 120 | img_np = img_np.astype(np.uint8) 121 | img = Image.fromarray(img_np) 122 | return PIL.ImageOps.solarize(img, threshold) 123 | 124 | 125 | def TranslateX(img, v, max_v, bias=0): 126 | v = _float_parameter(v, max_v) + bias 127 | if random.random() < 0.5: 128 | v = -v 129 | v = int(v * img.size[0]) 130 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 131 | 132 | 133 | def TranslateY(img, v, max_v, bias=0): 134 | v = _float_parameter(v, max_v) + bias 135 | if random.random() < 0.5: 136 | v = -v 137 | v = int(v * img.size[1]) 138 | return 
img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 139 | 140 | 141 | class GaussianBlur(object): 142 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 143 | 144 | def __init__(self, sigma=[.1, 2.]): 145 | self.sigma = sigma 146 | 147 | def __call__(self, x): 148 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 149 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 150 | return x 151 | 152 | 153 | class NoOpTransform(object): 154 | """ 155 | A transform that does nothing. 156 | """ 157 | 158 | def __init__(self): 159 | super().__init__() 160 | 161 | def __call__(self, tensor): 162 | """ 163 | Args: 164 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 165 | 166 | Returns: 167 | Tensor: Original Tensor image. 168 | """ 169 | return tensor 170 | 171 | def __repr__(self): 172 | return self.__class__.__name__ + '()' 173 | 174 | 175 | def _float_parameter(v, max_v): 176 | return float(v) * max_v / PARAMETER_MAX 177 | 178 | 179 | def _int_parameter(v, max_v): 180 | return int(v * max_v / PARAMETER_MAX) 181 | 182 | 183 | def fixmatch_augment_pool(): 184 | # FixMatch paper 185 | augs = [(AutoContrast, None, None), 186 | (Brightness, 0.9, 0.05), 187 | (Color, 0.9, 0.05), 188 | (Contrast, 0.9, 0.05), 189 | (Equalize, None, None), 190 | (Identity, None, None), 191 | (Posterize, 4, 4), 192 | (Rotate, 30, 0), 193 | (Sharpness, 0.9, 0.05), 194 | (ShearX, 0.3, 0), 195 | (ShearY, 0.3, 0), 196 | (Solarize, 256, 0), 197 | (TranslateX, 0.3, 0), 198 | (TranslateY, 0.3, 0)] 199 | return augs 200 | 201 | 202 | def my_augment_pool(): 203 | # Test 204 | augs = [(AutoContrast, None, None), 205 | (Brightness, 1.8, 0.1), 206 | (Color, 1.8, 0.1), 207 | (Contrast, 1.8, 0.1), 208 | (Cutout, 0.2, 0), 209 | (Equalize, None, None), 210 | (Invert, None, None), 211 | (Posterize, 4, 4), 212 | (Rotate, 30, 0), 213 | (Sharpness, 1.8, 0.1), 214 | (ShearX, 0.3, 0), 215 | (ShearY, 0.3, 0), 216 | (Solarize, 256, 0), 217 | (SolarizeAdd, 110, 0), 218 | (TranslateX, 0.45, 0), 219 | (TranslateY, 0.45, 0)] 220 | return augs 221 | 222 | 223 | def imagenet_augment_pool(): 224 | # op, max_v, bias 225 | augs = [(AutoContrast, None, None), 226 | (Brightness, 1.8, 0.1), 227 | (Color, 1.8, 0.1), 228 | (Contrast, 1.8, 0.1), 229 | (Equalize, None, None), 230 | (Identity, None, None), 231 | (Invert, None, None), 232 | (Posterize, 4, 4), 233 | (Rotate, 30, 0), 234 | (Sharpness, 1.8, 0.1), 235 | (ShearX, 0.3, 0), 236 | (ShearY, 0.3, 0), 237 | (Solarize, 256, 0), 238 | (SolarizeAdd, 110, 0), 239 | (TranslateX, 0.45, 0), 240 | (TranslateY, 0.45, 0)] 241 | return augs 242 | 243 | 244 | class RandAugmentPC(object): 245 | def __init__(self, n, m): 246 | assert n >= 1 247 | assert 1 <= m <= 10 248 | self.n = n 249 | self.m = m 250 | self.augment_pool = my_augment_pool() 251 | 252 | def __call__(self, img): 253 | ops = random.choices(self.augment_pool, k=self.n) 254 | for op, max_v, bias in ops: 255 | prob = np.random.uniform(0.2, 0.8) 256 | if random.random() + prob >= 1: 257 | img = op(img, v=self.m, max_v=max_v, bias=bias) 258 | img = CutoutAbs(img, 16) 259 | return img 260 | 261 | 262 | class RandAugmentMC(object): 263 | def __init__(self, n, m): 264 | assert n >= 1 265 | assert 1 <= m <= 10 266 | self.n = n 267 | self.m = m 268 | self.augment_pool = fixmatch_augment_pool() 269 | 270 | def __call__(self, img): 271 | ops = random.choices(self.augment_pool, k=self.n) 272 | for op, max_v, bias in ops: 273 | v = np.random.randint(1, self.m) 274 | if random.random() < 0.5: 275 | img = 
op(img, v=v, max_v=max_v, bias=bias) 276 | img = CutoutAbs(img, 16) 277 | return img 278 | 279 | 280 | class RandAugment(object): 281 | def __init__(self, n, m, prob=None): 282 | assert n >= 1 283 | assert 1 <= m <= 10 284 | if prob is not None: 285 | assert 0. <= prob <= 1. 286 | self.n = n 287 | self.m = m 288 | self.prob = prob 289 | self.augment_pool = imagenet_augment_pool() 290 | 291 | def __call__(self, img): 292 | ops = random.choices(self.augment_pool, k=self.n) 293 | for op, max_v, bias in ops: 294 | v = np.random.randint(1, self.m) 295 | if self.prob is not None: 296 | if random.random() < self.prob: 297 | img = op(img, v=v, max_v=max_v, bias=bias) 298 | else: 299 | img = op(img, v=v, max_v=max_v, bias=bias) 300 | return img 301 | 302 | def __repr__(self): 303 | return self.__class__.__name__ + '(m={0}, n={1}, prob={2})'.format(self.m, self.n, self.prob) 304 | -------------------------------------------------------------------------------- /data/fer2013.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | """ 6 | 7 | __author__ = "GZ" 8 | 9 | import os 10 | import sys 11 | import csv 12 | import pathlib 13 | import numpy as np 14 | from tqdm import tqdm 15 | from typing import Any, Callable, Optional, Tuple 16 | from PIL import Image 17 | import matplotlib.pyplot as plt 18 | 19 | if sys.platform == 'win32': 20 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 21 | 22 | import torch 23 | from torchvision.datasets import VisionDataset 24 | import torchvision.transforms as transforms 25 | 26 | # Root directory of the project 27 | try: 28 | abspath = os.path.abspath(__file__) 29 | except NameError: 30 | abspath = os.getcwd() 31 | ROOT_DIR = os.path.dirname(abspath) 32 | 33 | 34 | # List of folders for training, validation and test. 35 | folder_names = {'Training' : 'FER2013Train', 36 | 'PublicTest' : 'FER2013Valid', 37 | 'PrivateTest': 'FER2013Test'} 38 | 39 | 40 | class FER2013(VisionDataset): 41 | """`FER2013 42 | `_ Dataset. 43 | 44 | Args: 45 | root (string): Root directory of dataset where directory 46 | ``root/fer2013`` exists. 47 | split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. 48 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed 49 | version. E.g, ``transforms.RandomCrop`` 50 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
51 | """ 52 | 53 | def __init__( 54 | self, 55 | root: str, 56 | split: str = "train", 57 | transform: Optional[Callable] = None, 58 | target_transform: Optional[Callable] = None, 59 | convert_rgb=False 60 | ) -> None: 61 | self._split = split 62 | assert split in ['Training', 'PublicTest', 'PrivateTest'] 63 | super().__init__(root, transform=transform, target_transform=target_transform) 64 | 65 | self.convert_rgb = convert_rgb 66 | 67 | base_folder = pathlib.Path(self.root) 68 | data_file = base_folder / "fer2013.csv" 69 | 70 | self._samples = [] 71 | with open(data_file, "r", newline="") as file: 72 | for row in csv.DictReader(file): 73 | if split == row["Usage"]: 74 | data = ( 75 | torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48), 76 | int(row["emotion"]) if "emotion" in row else None, 77 | ) 78 | self._samples.append(data) 79 | 80 | def __len__(self) -> int: 81 | return len(self._samples) 82 | 83 | def __getitem__(self, idx: int): 84 | image_tensor, target = self._samples[idx] 85 | image = Image.fromarray(image_tensor.numpy()) 86 | 87 | if self.convert_rgb: 88 | image = image.convert("RGB") 89 | 90 | if self.transform is not None: 91 | image = self.transform(image) 92 | 93 | if self.target_transform is not None: 94 | target = self.target_transform(target) 95 | 96 | return image, target, idx 97 | 98 | 99 | def extra_repr(self) -> str: 100 | return f"split={self._split}" 101 | 102 | 103 | class FERplus(VisionDataset): 104 | """ 105 | https://github.com/microsoft/FERPlus/blob/master/src/ferplus.py 106 | https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks 107 | """ 108 | def __init__( 109 | self, 110 | root: str, 111 | split: str = "Training", 112 | transform: Optional[Callable] = None, 113 | target_transform: Optional[Callable] = None, 114 | convert_rgb=False 115 | ) -> None: 116 | self._split = split 117 | assert split in ['Training', 'PublicTest', 'PrivateTest'] 118 | super().__init__(root, transform=transform, target_transform=target_transform) 119 | 120 | self.convert_rgb = convert_rgb 121 | self.per_emotion_count = None 122 | 123 | # Default values 124 | self.emotion_count = 8 125 | 126 | # Load data 127 | self.loaded_data = self._load() 128 | print('Size of the loaded set: {}'.format(self.loaded_data[0].shape[0])) 129 | 130 | def __len__(self): 131 | return self.loaded_data[0].shape[0] 132 | 133 | def __getitem__(self, idx): 134 | image = self.loaded_data[0][idx] 135 | image = Image.fromarray(image) 136 | target = self.loaded_data[1][idx] 137 | 138 | if self.transform is not None: 139 | image = self.transform(image) 140 | 141 | return image, target, idx 142 | 143 | # @staticmethod 144 | # def get_class(idx): 145 | # classes = { 146 | # 0: 'Neutral', 147 | # 1: 'Happy', 148 | # 2: 'Sad', 149 | # 3: 'Surprise', 150 | # 4: 'Fear', 151 | # 5: 'Disgust', 152 | # 6: 'Anger', 153 | # 7: 'Contempt'} 154 | # 155 | # return classes[idx] 156 | # 157 | # @staticmethod 158 | # def _parse_to_label(idx): 159 | # """ 160 | # Parse labels to make them compatible with AffectNet. 
161 | # :param idx: 162 | # :return: 163 | # """ 164 | # emo_to_return = np.argmax(idx) 165 | # 166 | # if emo_to_return == 2: 167 | # emo_to_return = 3 168 | # elif emo_to_return == 3: 169 | # emo_to_return = 2 170 | # elif emo_to_return == 4: 171 | # emo_to_return = 6 172 | # elif emo_to_return == 6: 173 | # emo_to_return = 4 174 | # 175 | # return emo_to_return 176 | 177 | @staticmethod 178 | def _process_data(emotion_raw): 179 | size = len(emotion_raw) 180 | emotion_unknown = [0.0] * size 181 | emotion_unknown[-2] = 1.0 182 | 183 | # remove emotions with a single vote (outlier removal) 184 | for i in range(size): 185 | if emotion_raw[i] < 1.0 + sys.float_info.epsilon: 186 | emotion_raw[i] = 0.0 187 | 188 | sum_list = sum(emotion_raw) 189 | emotion = [0.0] * size 190 | 191 | # find the peak value of the emo_raw list 192 | maxval = max(emotion_raw) 193 | if maxval > 0.5 * sum_list: 194 | emotion[np.argmax(emotion_raw)] = maxval 195 | else: 196 | emotion = emotion_unknown # force setting as unknown 197 | 198 | return [float(i) / sum(emotion) for i in emotion] 199 | 200 | def _load(self): 201 | csv_label = [] 202 | data, labels = [], [] 203 | self.per_emotion_count = np.zeros(self.emotion_count, dtype=np.int32) 204 | 205 | path_folders_images = os.path.join(self.root, 'Images', folder_names[self._split]) 206 | path_folders_labels = os.path.join(self.root, 'Labels', folder_names[self._split]) 207 | 208 | with open(os.path.join(path_folders_labels, "label.csv")) as csvfile: 209 | lines = csv.reader(csvfile) 210 | for row in lines: 211 | csv_label.append(row) 212 | 213 | for l in csv_label: 214 | emotion_raw = list(map(float, l[2:len(l)])) 215 | emotion = self._process_data(emotion_raw) 216 | idx = np.argmax(emotion) 217 | 218 | if idx < self.emotion_count: # not unknown or non-face 219 | self.per_emotion_count[idx] += 1 220 | 221 | # emotion = emotion[:-2] 222 | # emotion = [float(i) / sum(emotion) for i in emotion] 223 | # emotion = self._parse_to_label(emotion) 224 | 225 | image = Image.open(os.path.join(path_folders_images, l[0])) 226 | if self.convert_rgb: 227 | image = image.convert("RGB") 228 | image = np.array(image) 229 | 230 | box = list(map(int, l[1][1:-1].split(','))) 231 | if box[-1] != 48: 232 | print("[INFO] Face is not centralized.") 233 | print(os.path.join(path_folders_images, l[0])) 234 | print(box) 235 | exit(-1) 236 | 237 | image = image[box[0]:box[2], box[1]:box[3], :] 238 | 239 | data.append(image) 240 | labels.append(idx) 241 | 242 | return [np.array(data), np.array(labels)] 243 | 244 | 245 | if __name__ == '__main__': 246 | display_transform = transforms.Compose([ 247 | transforms.Resize((96, 96)), 248 | transforms.ToTensor() 249 | ]) 250 | 251 | split = "PrivateTest" 252 | # dataset = FER2013(root="../data/FER/fer2013", split=split, transform=display_transform) 253 | dataset = FERplus(root="../data/FER/FERPlus/data", split=split, transform=display_transform, convert_rgb=True) 254 | print(dataset) 255 | 256 | loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True, 257 | drop_last=False) 258 | 259 | with torch.no_grad(): 260 | for i, (images, target, _) in enumerate(tqdm(loader)): 261 | img = np.clip(images.cpu().numpy(), 0, 1) # [0, 1] 262 | img = img.transpose(0, 2, 3, 1) 263 | img = (img * 255).astype(np.uint8) 264 | img = img.squeeze() 265 | 266 | fig, axs = plt.subplots(1, 1, figsize=(8, 8)) 267 | axs.imshow(img, cmap='gray') 268 | axs.axis("off") 269 | plt.show() 270 | 
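# --- Illustrative sketch (not part of fer2013.py): how a FER2013 CSV row is decoded above. ---
# The `row` dict below is synthetic; real rows come from csv.DictReader over fer2013.csv,
# where "pixels" holds 48*48 space-separated grayscale values and "emotion" a class index in 0..6.
import torch
from PIL import Image

row = {"emotion": "3", "pixels": " ".join(["128"] * (48 * 48)), "Usage": "Training"}
pixels = torch.tensor([int(p) for p in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48)
image = Image.fromarray(pixels.numpy())   # 48x48 grayscale PIL image, as built in FER2013.__getitem__
label = int(row["emotion"])               # emotion class index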
-------------------------------------------------------------------------------- /data/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import os 5 | import random 6 | from io import BytesIO 7 | from PIL import Image 8 | from PIL import ImageOps, ImageFilter 9 | import torch 10 | from torchvision import transforms 11 | 12 | from .randaugment import RandAugment 13 | 14 | 15 | # RGB mean tensor([0.5885, 0.4407, 0.3724]) 16 | # RGB std tensor([0.2271, 0.1961, 0.1827]) 17 | 18 | # RGB mean tensor([0.5231, 0.4044, 0.3489]) 19 | # RGB std tensor([0.2536, 0.2194, 0.2070]) 20 | 21 | IMG_MEAN = {"vggface2": [0.5231, 0.4044, 0.3489], 22 | "laionface": [0.48145466, 0.4578275, 0.40821073], 23 | "in1k": [0.485, 0.456, 0.406], 24 | "in100": [0.485, 0.456, 0.406]} 25 | IMG_STD = {"vggface2": [0.2536, 0.2194, 0.2070], 26 | "laionface": [0.26862954, 0.26130258, 0.27577711], 27 | "in1k": [0.229, 0.224, 0.225], 28 | "in100": [0.229, 0.224, 0.225]} 29 | 30 | 31 | class Solarize(object): 32 | def __init__(self, threshold=128): 33 | self.threshold = threshold 34 | 35 | def __call__(self, img): 36 | return ImageOps.solarize(img, self.threshold) 37 | 38 | def __repr__(self): 39 | repr_str = self.__class__.__name__ 40 | return repr_str 41 | 42 | 43 | class GaussianBlur(object): 44 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 45 | 46 | def __init__(self, sigma=[.1, 2.]): 47 | self.sigma = sigma 48 | 49 | def __call__(self, x): 50 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 51 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 52 | return x 53 | 54 | 55 | class AddGaussianNoise(object): 56 | def __init__(self, mean=0., std=1.): 57 | self.std = std 58 | self.mean = mean 59 | 60 | def __call__(self, tensor): 61 | tensor = tensor + torch.randn(tensor.size()) * self.std + self.mean 62 | tensor = torch.clamp(tensor, min=0., max=1.) 
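        # (clamping keeps the noisy tensor inside the valid [0, 1] image range)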
63 | return tensor 64 | 65 | def __repr__(self): 66 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 67 | 68 | 69 | class PcaAug(object): 70 | _eigval = torch.Tensor([0.2175, 0.0188, 0.0045]) 71 | _eigvec = torch.Tensor([ 72 | [-0.5675, 0.7192, 0.4009], 73 | [-0.5808, -0.0045, -0.8140], 74 | [-0.5836, -0.6948, 0.4203], 75 | ]) 76 | 77 | def __init__(self, alpha=0.1): 78 | self.alpha = alpha 79 | 80 | def __call__(self, im): 81 | alpha = torch.randn(3) * self.alpha 82 | rgb = (self._eigvec * alpha.expand(3, 3) * self._eigval.expand(3, 3)).sum(1) 83 | return im + rgb.reshape(3, 1, 1) 84 | 85 | 86 | class JPEGNoise(object): 87 | def __init__(self, low=30, high=99): 88 | self.low = low 89 | self.high = high 90 | 91 | def __call__(self, im): 92 | H = im.height 93 | W = im.width 94 | rW = max(int(0.8 * W), int(W * (1 + 0.5 * torch.randn([])))) 95 | im = transforms.functional.resize(im, (rW, rW)) 96 | buf = BytesIO() 97 | im.save(buf, format='JPEG', quality=torch.randint(self.low, self.high, 98 | []).item()) 99 | im = Image.open(buf) 100 | im = transforms.functional.resize(im, (H, W)) 101 | return im 102 | 103 | 104 | def get_augmentations(aug_type, dataset): 105 | normalize = transforms.Normalize(mean=IMG_MEAN[dataset.lower()], 106 | std=IMG_STD[dataset.lower()]) 107 | 108 | default_train_augs = [ 109 | transforms.RandomResizedCrop(224), 110 | transforms.RandomHorizontalFlip(), 111 | ] 112 | default_val_augs = [ 113 | transforms.Resize(256), 114 | transforms.CenterCrop(224), 115 | ] 116 | appendix_augs = [ 117 | transforms.ToTensor(), 118 | normalize, 119 | ] 120 | if aug_type == 'DefaultTrain': 121 | augs = default_train_augs + appendix_augs 122 | elif aug_type == 'DefaultVal': 123 | augs = default_val_augs + appendix_augs 124 | elif aug_type == 'RandAugment': 125 | augs = default_train_augs + [RandAugment(n=2, m=10)] + appendix_augs 126 | elif aug_type == 'MoCoV1': 127 | augs = [ 128 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 129 | transforms.RandomGrayscale(p=0.2), 130 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 131 | transforms.RandomHorizontalFlip() 132 | ] + appendix_augs 133 | elif aug_type == 'MoCoV2': 134 | augs = [ 135 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 136 | transforms.RandomApply([ 137 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened 138 | ], p=0.8), 139 | transforms.RandomGrayscale(p=0.2), 140 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 141 | transforms.RandomHorizontalFlip(), 142 | ] + appendix_augs 143 | else: 144 | raise NotImplementedError('augmentation type not found: {}'.format(aug_type)) 145 | 146 | return augs 147 | 148 | 149 | def get_transforms(aug_type, dataset="in1k"): 150 | augs = get_augmentations(aug_type, dataset) 151 | return transforms.Compose(augs) 152 | 153 | 154 | def get_byol_tranforms(): 155 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 156 | std=[0.229, 0.224, 0.225]) 157 | augmentation1 = [ 158 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 159 | transforms.RandomHorizontalFlip(), 160 | transforms.RandomApply([ 161 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened 162 | ], p=0.8), 163 | transforms.RandomGrayscale(p=0.2), 164 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=1.), 165 | transforms.RandomApply([Solarize()], p=0.), 166 | transforms.ToTensor(), 167 | normalize 168 | ] 169 | augmentation2 = [ 170 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 171 | transforms.RandomHorizontalFlip(), 172 | 
transforms.RandomApply([ 173 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened 174 | ], p=0.8), 175 | transforms.RandomGrayscale(p=0.2), 176 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.1), 177 | transforms.RandomApply([Solarize()], p=0.2), 178 | transforms.ToTensor(), 179 | normalize 180 | ] 181 | transform1 = transforms.Compose(augmentation1) 182 | transform2 = transforms.Compose(augmentation2) 183 | return transform1, transform2 184 | 185 | 186 | def get_vggface_tranforms(image_size=128): 187 | normalize = transforms.Normalize(mean=IMG_MEAN["vggface2"], 188 | std=IMG_STD["vggface2"]) 189 | 190 | augmentation1 = [ 191 | transforms.RandomResizedCrop(image_size, scale=(0.2, 1.)), 192 | # transforms.Resize([image_size, image_size]), 193 | transforms.RandomHorizontalFlip(), 194 | transforms.RandomApply([ 195 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened 196 | ], p=0.8), 197 | transforms.RandomGrayscale(p=0.2), 198 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=1.), 199 | transforms.RandomApply([Solarize()], p=0.), 200 | transforms.ToTensor(), 201 | normalize 202 | ] 203 | augmentation2 = [ 204 | transforms.RandomResizedCrop(image_size, scale=(0.2, 1.)), 205 | # transforms.Resize([image_size, image_size]), 206 | transforms.RandomHorizontalFlip(), 207 | transforms.RandomApply([ 208 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened 209 | ], p=0.8), 210 | transforms.RandomGrayscale(p=0.2), 211 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.1), 212 | transforms.RandomApply([Solarize()], p=0.2), 213 | transforms.ToTensor(), 214 | normalize 215 | ] 216 | transform1 = transforms.Compose(augmentation1) 217 | transform2 = transforms.Compose(augmentation2) 218 | return transform1, transform2 219 | 220 | 221 | class TwoCropsTransform: 222 | """Take two random crops of one image.""" 223 | 224 | def __init__(self, transform1, transform2): 225 | self.transform1 = transform1 226 | self.transform2 = transform2 227 | 228 | def __call__(self, x): 229 | out1 = self.transform1(x) 230 | out2 = self.transform2(x) 231 | return out1, out2 232 | 233 | def __repr__(self): 234 | format_string = self.__class__.__name__ + '(' 235 | names = ['transform1', 'transform2'] 236 | for idx, t in enumerate([self.transform1, self.transform2]): 237 | format_string += '\n' 238 | t_string = '{0}={1}'.format(names[idx], t) 239 | t_string_split = t_string.split('\n') 240 | t_string_split = [' ' + tstr for tstr in t_string_split] 241 | t_string = '\n'.join(t_string_split) 242 | format_string += '{0}'.format(t_string) 243 | format_string += '\n)' 244 | return format_string 245 | 246 | 247 | if __name__ == '__main__': 248 | from utils.utils import dump_image 249 | 250 | # Ryan_Gosling Emily_VanCamp 251 | img = Image.open("./vis_data/0008_01.jpg") 252 | 253 | augment = get_vggface_tranforms(image_size=224) 254 | img1 = augment[0](img) 255 | img2 = augment[1](img) 256 | 257 | save_dir = "./output" 258 | os.makedirs(save_dir, exist_ok=True) 259 | 260 | filepath = os.path.join(save_dir, "{}.png".format("img1")) 261 | dump_image(img1, IMG_MEAN["vggface2"], IMG_STD["vggface2"], filepath=filepath) 262 | 263 | filepath = os.path.join(save_dir, "{}.png".format("img2")) 264 | dump_image(img2, IMG_MEAN["vggface2"], IMG_STD["vggface2"], filepath=filepath) 265 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Lang Huang 
(laynehuang@outlook.com). All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import os 5 | import numpy as np 6 | import torch.utils.data as data 7 | import torchvision.datasets as datasets 8 | import torchvision.transforms as transforms 9 | from torchvision.datasets.folder import ImageFolder, default_loader 10 | from PIL import Image 11 | 12 | try: 13 | import mc 14 | except ImportError: 15 | mc = None 16 | import io 17 | 18 | 19 | # class DatasetCache(data.Dataset): 20 | # def __init__(self): 21 | # super().__init__() 22 | # self.initialized = False 23 | # 24 | # 25 | # def _init_memcached(self): 26 | # if not self.initialized: 27 | # server_list_config_file = "/mnt/cache/share/memcached_client/server_list.conf" 28 | # client_config_file = "/mnt/cache/share/memcached_client/client.conf" 29 | # self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, client_config_file) 30 | # self.initialized = True 31 | # 32 | # def load_image(self, filename): 33 | # self._init_memcached() 34 | # value = mc.pyvector() 35 | # self.mclient.Get(filename, value) 36 | # value_str = mc.ConvertBuffer(value) 37 | # 38 | # buff = io.BytesIO(value_str) 39 | # with Image.open(buff) as img: 40 | # img = img.convert('RGB') 41 | # return img 42 | # 43 | # 44 | # 45 | # class BaseDataset(DatasetCache): 46 | # def __init__(self, mode='train', max_class=1000, aug=None, 47 | # prefix='/mnt/cache/share/images/meta', 48 | # image_folder_prefix='/mnt/cache/share/images/'): 49 | # super().__init__() 50 | # self.initialized = False 51 | # 52 | # if mode == 'train': 53 | # image_list = os.path.join(prefix, 'train.txt') 54 | # self.image_folder = os.path.join(image_folder_prefix, 'train') 55 | # elif mode == 'test': 56 | # image_list = os.path.join(prefix, 'test.txt') 57 | # self.image_folder = os.path.join(image_folder_prefix, 'test') 58 | # elif mode == 'val': 59 | # image_list = os.path.join(prefix, 'val.txt') 60 | # self.image_folder = os.path.join(image_folder_prefix, 'val') 61 | # else: 62 | # raise NotImplementedError('mode: ' + mode + ' does not exist please select from [train, test, val]') 63 | # 64 | # 65 | # self.samples = [] 66 | # with open(image_list) as f: 67 | # for line in f: 68 | # name, label = line.split() 69 | # label = int(label) 70 | # if label < max_class: 71 | # self.samples.append((label, name)) 72 | # 73 | # if aug is None: 74 | # if mode == 'train': 75 | # self.transform = transforms.Compose([ 76 | # transforms.RandomResizedCrop(224), 77 | # transforms.RandomHorizontalFlip(), 78 | # transforms.ToTensor(), 79 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], 80 | # std=[0.229, 0.224, 0.225]) 81 | # ]) 82 | # else: 83 | # self.transform = transforms.Compose([ 84 | # transforms.Resize(256), 85 | # transforms.CenterCrop(224), 86 | # transforms.ToTensor(), 87 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], 88 | # std=[0.229, 0.224, 0.225]), 89 | # ]) 90 | # 91 | # else: 92 | # self.transform = aug 93 | # 94 | # 95 | # def get_keep_index(samples, percent, num_classes, shuffle=False): 96 | # labels = np.array([sample[0] for sample in samples]) 97 | # keep_indexs = [] 98 | # for i in range(num_classes): 99 | # idx = np.where(labels == i)[0] 100 | # num_sample = len(idx) 101 | # label_per_class = min(max(1, round(percent * num_sample)), num_sample) 102 | # if shuffle: 103 | # np.random.shuffle(idx) 104 | # keep_indexs.extend(idx[:label_per_class]) 105 | # 106 | # return keep_indexs 107 | # 108 | # 109 | # class ImageNet(BaseDataset): 110 | # def __init__(self, 
mode='train', max_class=1000, num_classes=1000, transform=None, 111 | # percent=1., shuffle=False, **kwargs): 112 | # super().__init__(mode, max_class, aug=transform, **kwargs) 113 | # 114 | # assert 0 <= percent <= 1 115 | # if percent < 1: 116 | # keep_indexs = get_keep_index(self.samples, percent, num_classes, shuffle) 117 | # self.samples = [self.samples[i] for i in keep_indexs] 118 | # 119 | # def __len__(self): 120 | # return self.samples.__len__() 121 | # 122 | # def __getitem__(self, index): 123 | # label, name = self.samples[index] 124 | # filename = os.path.join(self.image_folder, name) 125 | # img = self.load_image(filename) 126 | # return self.transform(img), label, index 127 | # 128 | # 129 | # class ImageNetWithIdx(BaseDataset): 130 | # def __init__(self, mode='train', max_class=1000, num_classes=1000, transform=None, 131 | # idx=None, shuffle=False, **kwargs): 132 | # super().__init__(mode, max_class, aug=transform, **kwargs) 133 | # 134 | # assert idx is not None 135 | # with open(idx, "r") as fin: 136 | # samples = [line.strip().split(" ") for line in fin.readlines()] 137 | # self.samples = samples 138 | # print(f"Len of training set: {len(self.samples)}") 139 | # 140 | # def __len__(self): 141 | # return self.samples.__len__() 142 | # 143 | # def __getitem__(self, index): 144 | # label, name = self.samples[index] 145 | # filename = os.path.join(self.image_folder, name) 146 | # img = self.load_image(filename) 147 | # return self.transform(img), int(label), index 148 | # 149 | # 150 | # class ImageNet100(ImageNet): 151 | # def __init__(self, **kwargs): 152 | # super().__init__( 153 | # num_classes=100, 154 | # prefix='/mnt/lustre/huanglang/research/selfsup/data/imagenet-100/', 155 | # image_folder_prefix='/mnt/lustre/huanglang/research/selfsup/data/images', 156 | # **kwargs) 157 | # 158 | # class ImageFolderWithPercent(ImageFolder): 159 | # 160 | # def __init__(self, root, transform=None, target_transform=None, 161 | # loader=default_loader, is_valid_file=None, percent=1.0, shuffle=False): 162 | # super().__init__(root, transform=transform, target_transform=target_transform, 163 | # loader=loader, is_valid_file=is_valid_file) 164 | # assert 0 <= percent <= 1 165 | # if percent < 1: 166 | # keep_indexs = get_keep_index(self.targets, percent, len(self.classes), shuffle) 167 | # self.samples = [self.samples[i] for i in keep_indexs] 168 | # self.targets = [self.targets[i] for i in keep_indexs] 169 | # self.imgs = self.samples 170 | # 171 | # 172 | # class ImageFolderWithIndex(ImageFolder): 173 | # 174 | # def __init__(self, root, indexs=None, transform=None, target_transform=None, 175 | # loader=default_loader, is_valid_file=None): 176 | # super().__init__(root, transform=transform, target_transform=target_transform, 177 | # loader=loader, is_valid_file=is_valid_file) 178 | # if indexs is not None: 179 | # self.samples = [self.samples[i] for i in indexs] 180 | # self.targets = [self.targets[i] for i in indexs] 181 | # self.imgs = self.samples 182 | 183 | 184 | class ImageFolderInstance(datasets.ImageFolder): 185 | def __getitem__(self, index): 186 | path, target = self.samples[index] 187 | sample = self.loader(path) 188 | if self.transform is not None: 189 | sample = self.transform(sample) 190 | if self.target_transform is not None: 191 | target = self.target_transform(target) 192 | 193 | return sample, target, index 194 | 195 | 196 | class ImageFolderSubset(datasets.ImageFolder): 197 | """Folder datasets which returns the index of the image (for memory_bank) 198 | """ 199 | 
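    # Hedged usage sketch (paths here are assumptions; get_dataset below shows the actual wiring):
    #   ds = ImageFolderSubset("./data/imagenet100.txt", "/path/to/imagenet/train", transform)
    #   sample, target, index = ds[0]   # targets are remapped to a contiguous 0..len(classes)-1 range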
def __init__(self, class_path, root, transform, **kwargs): 200 | super().__init__(root, transform, **kwargs) 201 | self.class_path = class_path 202 | new_samples, sorted_classes = self.get_class_samples() 203 | self.imgs = self.samples = new_samples # len=126689 204 | self.classes = sorted_classes 205 | self.class_to_idx = {cls_name: i for i, cls_name in enumerate(sorted_classes)} 206 | self.targets = [s[1] for s in self.samples] 207 | 208 | def get_class_samples(self): 209 | classes = open(self.class_path).readlines() 210 | classes = [m.strip() for m in classes] 211 | classes = set(classes) 212 | class_to_sample = [[os.path.basename(os.path.dirname(m[0])), m] for m in self.imgs] 213 | selected_samples = [m[1] for m in class_to_sample if m[0] in classes] 214 | 215 | sorted_classes = sorted(list(classes)) 216 | target_mapping = {self.class_to_idx[k]: j for j, k in enumerate(sorted_classes)} 217 | 218 | valid_pairs = [[m[0], target_mapping[m[1]]] for m in selected_samples] 219 | return valid_pairs, sorted_classes 220 | 221 | def __getitem__(self, index): 222 | path, target = self.samples[index] 223 | sample = self.loader(path) 224 | if self.transform is not None: 225 | sample = self.transform(sample) 226 | if self.target_transform is not None: 227 | target = self.target_transform(target) 228 | 229 | return sample, target, index 230 | 231 | 232 | def get_dataset(dataset, mode, transform, data_root=None, **kwargs): 233 | data_dir = os.path.join(data_root, mode) 234 | if mode == "val" and "ImageNet" in data_root and "nobackup_mmv_ioannisp" in data_root: 235 | data_dir = "/import/nobackup_mmv_ioannisp/zg002/data/ImageNet/val" 236 | in100_class_path = "./data/imagenet100.txt" 237 | 238 | if dataset.lower() == 'in1k': 239 | return ImageFolderInstance(data_dir, transform=transform) 240 | elif dataset.lower() == 'in100': 241 | return ImageFolderSubset(in100_class_path, data_dir, transform) 242 | elif dataset.lower() == "vggface2": 243 | return ImageFolderInstance(data_dir, transform=transform) 244 | # elif dataset == 'in1k_idx': 245 | # return ImageNetWithIdx(mode, transform=transform, **kwargs) 246 | # else: # ImageFolder 247 | # data_dir = os.path.join(data_root, mode) 248 | # assert os.path.isdir(data_dir) 249 | # return ImageFolderWithPercent(data_dir, transform, **kwargs) 250 | 251 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Original copyright Amazon.com, Inc. or its affiliates, under CC-BY-NC-4.0 License. 2 | # Modifications Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved. 
3 | # SPDX-License-Identifier: CC-BY-NC-4.0 4 | 5 | import time 6 | from datetime import timedelta 7 | import numpy as np 8 | try: 9 | import faiss 10 | except ImportError: 11 | pass 12 | 13 | import torch 14 | import torch.nn as nn 15 | from classy_vision.generic.distributed_util import is_distributed_training_run 16 | 17 | from utils import utils 18 | from utils.dist_utils import all_reduce_mean 19 | 20 | def validate(val_loader, model, criterion, args): 21 | batch_time = utils.AverageMeter('Time', ':6.3f') 22 | losses = utils.AverageMeter('Loss', ':.4e') 23 | top1 = utils.AverageMeter('Acc@1', ':6.2f') 24 | top5 = utils.AverageMeter('Acc@5', ':6.2f') 25 | progress = utils.ProgressMeter( 26 | len(val_loader), 27 | [batch_time, losses, top1, top5], 28 | prefix='Test: ') 29 | 30 | # switch to evaluate mode 31 | model.eval() 32 | 33 | with torch.no_grad(): 34 | end = time.time() 35 | for i, (images, target, _) in enumerate(val_loader): 36 | if args.gpu is not None: 37 | images = images.cuda(args.gpu, non_blocking=True) 38 | target = target.cuda(args.gpu, non_blocking=True) 39 | 40 | # compute output 41 | output = model(images) 42 | loss = criterion(output, target) 43 | 44 | # measure accuracy and record loss 45 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 46 | 47 | if is_distributed_training_run(): 48 | # torch.distributed.barrier() 49 | acc1 = all_reduce_mean(acc1) 50 | acc5 = all_reduce_mean(acc5) 51 | 52 | losses.update(loss.item(), images.size(0)) 53 | top1.update(acc1[0], images.size(0)) 54 | top5.update(acc5[0], images.size(0)) 55 | 56 | # measure elapsed time 57 | batch_time.update(time.time() - end) 58 | end = time.time() 59 | 60 | if i % args.print_freq == 0: 61 | progress.display(i) 62 | 63 | # TODO: this should also be done with the ProgressMeter 64 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} Loss {loss.avg:.4f}' 65 | .format(top1=top1, top5=top5, loss=losses)) 66 | 67 | return top1.avg 68 | 69 | 70 | def ss_validate(val_loader_base, val_loader_query, model, args): 71 | print("start KNN evaluation with key size={} and query size={}".format( 72 | len(val_loader_base.dataset.samples), len(val_loader_query.dataset.samples))) 73 | batch_time_key = utils.AverageMeter('Time', ':6.3f') 74 | batch_time_query = utils.AverageMeter('Time', ':6.3f') 75 | # switch to evaluate mode 76 | model.eval() 77 | 78 | feats_base = [] 79 | target_base = [] 80 | feats_query = [] 81 | target_query = [] 82 | 83 | with torch.no_grad(): 84 | start = time.time() 85 | end = time.time() 86 | # Memory features 87 | for i, (images, target, _) in enumerate(val_loader_base): 88 | if args.gpu is not None: 89 | images = images.cuda(args.gpu, non_blocking=True) 90 | target = target.cuda(args.gpu, non_blocking=True) 91 | 92 | # compute features 93 | feats = model(images) 94 | # L2 normalization 95 | feats = nn.functional.normalize(feats, dim=1) 96 | 97 | feats_base.append(feats) 98 | target_base.append(target) 99 | 100 | # measure elapsed time 101 | batch_time_key.update(time.time() - end) 102 | end = time.time() 103 | 104 | if i % args.print_freq == 0: 105 | print('Extracting key features: [{0}/{1}]\t' 106 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( 107 | i, len(val_loader_base), batch_time=batch_time_key)) 108 | 109 | end = time.time() 110 | for i, (images, target, _) in enumerate(val_loader_query): 111 | if args.gpu is not None: 112 | images = images.cuda(args.gpu, non_blocking=True) 113 | target = target.cuda(args.gpu, non_blocking=True) 114 | 115 | # compute features 116 | 
feats = model(images) 117 | # L2 normalization 118 | feats = nn.functional.normalize(feats, dim=1) 119 | 120 | feats_query.append(feats) 121 | target_query.append(target) 122 | 123 | # measure elapsed time 124 | batch_time_query.update(time.time() - end) 125 | end = time.time() 126 | 127 | if i % args.print_freq == 0: 128 | print('Extracting query features: [{0}/{1}]\t' 129 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( 130 | i, len(val_loader_query), batch_time=batch_time_query)) 131 | 132 | feats_base = torch.cat(feats_base, dim=0) 133 | target_base = torch.cat(target_base, dim=0) 134 | feats_query = torch.cat(feats_query, dim=0) 135 | target_query = torch.cat(target_query, dim=0) 136 | feats_base = feats_base.detach().cpu().numpy() 137 | target_base = target_base.detach().cpu().numpy() 138 | feats_query = feats_query.detach().cpu().numpy() 139 | target_query = target_query.detach().cpu().numpy() 140 | feat_time = time.time() - start 141 | 142 | # KNN search 143 | index = faiss.IndexFlatL2(feats_base.shape[1]) 144 | index.add(feats_base) 145 | D, I = index.search(feats_query, args.num_nn) 146 | preds = np.array([np.bincount(target_base[n]).argmax() for n in I]) 147 | 148 | NN_acc = (preds == target_query).sum() / len(target_query) * 100.0 149 | knn_time = time.time() - start - feat_time 150 | print("finished KNN evaluation, feature time: {}, knn time: {}".format( 151 | timedelta(seconds=feat_time), timedelta(seconds=knn_time))) 152 | print(' * NN Acc@1 {:.3f}'.format(NN_acc)) 153 | 154 | return NN_acc 155 | 156 | 157 | 158 | def ss_face_validate(val_loader, model, args, threshold=0.6): 159 | """ 160 | https://github.com/sakshamjindal/Face-Matching 161 | """ 162 | batch_time = utils.AverageMeter('Time', ':6.3f') 163 | top1 = utils.AverageMeter('Acc@1', ':6.2f') 164 | progress = utils.ProgressMeter( 165 | len(val_loader), 166 | [batch_time, top1], 167 | prefix='Test: ') 168 | 169 | cos = nn.CosineSimilarity(dim=1, eps=1e-6) 170 | 171 | # switch to evaluate mode 172 | model.eval() 173 | model = model.module if hasattr(model, 'module') else model 174 | 175 | with torch.no_grad(): 176 | end = time.time() 177 | for i, (img1, img2, target) in enumerate(val_loader): 178 | img1 = img1.cuda(non_blocking=True) 179 | img2 = img2.cuda(non_blocking=True) 180 | target = target.cuda(non_blocking=True) 181 | 182 | # compute output 183 | embedding1, _, _ = model.online_net(img1) 184 | embedding2, _, _ = model.online_net(img2) 185 | 186 | embedding1 = embedding1.squeeze(-1) 187 | embedding2 = embedding2.squeeze(-1) 188 | 189 | assert embedding1.ndim == 2 190 | 191 | # measure accuracy and record loss 192 | cosine_similarity = cos(embedding1, embedding2) 193 | pred = (cosine_similarity >= threshold).to(torch.float32) 194 | acc1 = (pred == target).float().sum() * 100.0 / (target.shape[0]) 195 | 196 | top1.update(acc1.item(), img1.size(0)) 197 | 198 | # measure elapsed time 199 | batch_time.update(time.time() - end) 200 | end = time.time() 201 | 202 | if i % args.print_freq == 0: 203 | progress.display(i) 204 | 205 | # TODO: this should also be done with the ProgressMeter 206 | print(' * Acc@1 {top1.avg:.3f}' 207 | .format(top1=top1)) 208 | 209 | return top1.avg 210 | 211 | 212 | def validate_multilabel(val_loader, model, criterion, args): 213 | batch_time = utils.AverageMeter('Time', ':6.3f') 214 | losses = utils.AverageMeter('Loss', ':.4e') 215 | top1 = utils.AverageMeter('Acc@1', ':6.2f') 216 | progress = utils.ProgressMeter( 217 | len(val_loader), 218 | [batch_time, losses, top1], 219 | 
prefix='Test: ') 220 | 221 | # switch to evaluate mode 222 | model.eval() 223 | 224 | with torch.no_grad(): 225 | end = time.time() 226 | for i, (images, target, _) in enumerate(val_loader): 227 | if args.gpu is not None: 228 | images = images.cuda(args.gpu, non_blocking=True) 229 | target = target.cuda(args.gpu, non_blocking=True).float() 230 | 231 | # compute output 232 | output = model(images) 233 | loss = criterion(output, target) 234 | 235 | # measure accuracy and record loss 236 | acc1 = utils.accuracy_multilabel(torch.sigmoid(output), target) 237 | 238 | if is_distributed_training_run(): 239 | # torch.distributed.barrier() 240 | acc1 = all_reduce_mean(acc1) 241 | 242 | losses.update(loss.item(), images.size(0)) 243 | top1.update(acc1.item(), images.size(0)) 244 | 245 | # measure elapsed time 246 | batch_time.update(time.time() - end) 247 | end = time.time() 248 | 249 | if i % args.print_freq == 0: 250 | progress.display(i) 251 | 252 | # TODO: this should also be done with the ProgressMeter 253 | print(' * Acc@1 {top1.avg:.3f} Loss {loss.avg:.4f}' 254 | .format(top1=top1, loss=losses)) 255 | 256 | return top1.avg 257 | 258 | 259 | if __name__ == '__main__': 260 | import backbone as backbone_models 261 | from models import get_model 262 | import torchvision 263 | import torchvision.transforms as transforms 264 | 265 | model_func = get_model("LEWELB_EMAN") 266 | norm_layer = None 267 | model = model_func( 268 | backbone_models.__dict__["resnet50_encoder"], 269 | dim=256, 270 | m=0.996, 271 | hid_dim=4096, 272 | norm_layer=norm_layer, 273 | num_neck_mlp=2, 274 | scale=1., 275 | l2_norm=True, 276 | num_heads=4, 277 | loss_weight=0.5, 278 | mask_type="max" 279 | ) 280 | print(model) 281 | 282 | model.cuda() 283 | 284 | transform_test = transforms.Compose([ 285 | transforms.Resize((224, 224)), 286 | # transforms.CenterCrop(args.image_size), 287 | transforms.ToTensor(), 288 | ]) 289 | val_dataset = torchvision.datasets.LFWPairs(root="../data/lfw", split="test", 290 | transform=transform_test, download=True) 291 | print(set(val_dataset.targets)) 292 | 293 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=8, pin_memory=True, persistent_workers=True) 294 | 295 | ss_face_validate(val_loader, model, None) -------------------------------------------------------------------------------- /data/celeba.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | """ 6 | 7 | __author__ = "GZ" 8 | 9 | import os 10 | import sys 11 | import csv 12 | import pathlib 13 | import numpy as np 14 | from tqdm import tqdm 15 | from collections import namedtuple 16 | import csv 17 | from functools import partial 18 | from typing import Any, Callable, List, Optional, Union, Tuple 19 | import PIL 20 | from PIL import Image 21 | import matplotlib.pyplot as plt 22 | 23 | if sys.platform == 'win32': 24 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 25 | 26 | import torch 27 | from torchvision.datasets.utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg 28 | from torchvision.datasets import VisionDataset 29 | import torchvision.transforms as transforms 30 | 31 | # Root directory of the project 32 | try: 33 | abspath = os.path.abspath(__file__) 34 | except NameError: 35 | abspath = os.getcwd() 36 | ROOT_DIR = os.path.dirname(abspath) 37 | 38 | 39 | CSV = namedtuple("CSV", ["header", "index", "data"]) 40 | 41 | 42 | class 
CelebA(VisionDataset): 43 | """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. 44 | 45 | Args: 46 | root (string): Root directory where images are downloaded to. 47 | split (string): One of {'train', 'valid', 'test', 'all'}. 48 | Accordingly dataset is selected. 49 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, 50 | or ``landmarks``. Can also be a list to output a tuple with all specified target types. 51 | The targets represent: 52 | 53 | - ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes 54 | - ``identity`` (int): label for each person (data points with the same identity are the same person) 55 | - ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) 56 | - ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, 57 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) 58 | 59 | Defaults to ``attr``. If empty, ``None`` will be returned as target. 60 | 61 | transform (callable, optional): A function/transform that takes in an PIL image 62 | and returns a transformed version. E.g, ``transforms.ToTensor`` 63 | target_transform (callable, optional): A function/transform that takes in the 64 | target and transforms it. 65 | download (bool, optional): If true, downloads the dataset from the internet and 66 | puts it in root directory. If dataset is already downloaded, it is not 67 | downloaded again. 68 | """ 69 | 70 | base_folder = "celeba" 71 | # There currently does not appear to be a easy way to extract 7z in python (without introducing additional 72 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available 73 | # right now. 
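    # Hedged usage sketch (the root path is an assumption; see the __main__ block at the bottom of this file):
    #   ds = CelebA(root="../data", split="train", target_type=["attr", "identity"])
    #   img, (attr, identity), idx = ds[0]   # attr: 40-dim 0/1 tensor, identity: person-id label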
74 | file_list = [ 75 | # File ID MD5 Hash Filename 76 | # ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 77 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 78 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 79 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 80 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 81 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 82 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 83 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 84 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 85 | ] 86 | 87 | def __init__( 88 | self, 89 | root: str, 90 | split: str = "train", 91 | target_type: Union[List[str], str] = "attr", 92 | transform: Optional[Callable] = None, 93 | target_transform: Optional[Callable] = None, 94 | download: bool = False, 95 | crop=False 96 | ) -> None: 97 | super(CelebA, self).__init__(root, transform=transform, 98 | target_transform=target_transform) 99 | self.split = split 100 | self.crop = crop 101 | if isinstance(target_type, list): 102 | self.target_type = target_type 103 | else: 104 | self.target_type = [target_type] 105 | 106 | if not self.target_type and self.target_transform is not None: 107 | raise RuntimeError('target_transform is specified but target_type is empty') 108 | 109 | if download: 110 | self.download() 111 | 112 | if not self._check_integrity(): 113 | raise RuntimeError('Dataset not found or corrupted.' 
+ 114 | ' You can use download=True to download it') 115 | 116 | split_map = { 117 | "train": 0, 118 | "valid": 1, 119 | "test": 2, 120 | "all": None, 121 | } 122 | split_ = split_map[verify_str_arg(split.lower(), "split", 123 | ("train", "valid", "test", "all"))] 124 | splits = self._load_csv("list_eval_partition.txt") 125 | identity = self._load_csv("identity_CelebA.txt") 126 | bbox = self._load_csv("list_bbox_celeba.txt", header=1) 127 | landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1) 128 | attr = self._load_csv("list_attr_celeba.txt", header=1) 129 | 130 | mask = slice(None) if split_ is None else (splits.data == split_).squeeze() 131 | 132 | if mask == slice(None): # if split == "all" 133 | self.filename = splits.index 134 | else: 135 | self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))] 136 | self.identity = identity.data[mask] 137 | self.bbox = bbox.data[mask] 138 | self.landmarks_align = landmarks_align.data[mask] 139 | self.attr = attr.data[mask] 140 | # map from {-1, 1} to {0, 1} 141 | self.attr = torch.div(self.attr + 1, 2, rounding_mode='floor') 142 | self.attr_names = attr.header 143 | 144 | def _load_csv( 145 | self, 146 | filename: str, 147 | header: Optional[int] = None, 148 | ) -> CSV: 149 | data, indices, headers = [], [], [] 150 | 151 | fn = partial(os.path.join, self.root, self.base_folder) 152 | with open(fn(filename)) as csv_file: 153 | data = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True)) 154 | 155 | if header is not None: 156 | headers = data[header] 157 | data = data[header + 1:] 158 | 159 | indices = [row[0] for row in data] 160 | data = [row[1:] for row in data] 161 | data_int = [list(map(int, i)) for i in data] 162 | 163 | return CSV(headers, indices, torch.tensor(data_int)) 164 | 165 | def _check_integrity(self) -> bool: 166 | for (_, md5, filename) in self.file_list: 167 | fpath = os.path.join(self.root, self.base_folder, filename) 168 | _, ext = os.path.splitext(filename) 169 | # Allow original archive to be deleted (zip and 7z) 170 | # Only need the extracted images 171 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 172 | return False 173 | 174 | # Should check a hash of the images 175 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) 176 | 177 | def download(self) -> None: 178 | import zipfile 179 | 180 | if self._check_integrity(): 181 | print('Files already downloaded and verified') 182 | return 183 | 184 | for (file_id, md5, filename) in self.file_list: 185 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) 186 | 187 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: 188 | f.extractall(os.path.join(self.root, self.base_folder)) 189 | 190 | def __getitem__(self, index: int): 191 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 192 | 193 | target: Any = [] 194 | for t in self.target_type: 195 | if t == "attr": 196 | target.append(self.attr[index, :]) 197 | elif t == "identity": 198 | target.append(self.identity[index, 0]) 199 | elif t == "bbox": 200 | target.append(self.bbox[index, :]) 201 | elif t == "landmarks": 202 | target.append(self.landmarks_align[index, :]) 203 | else: 204 | # TODO: refactor with utils.verify_str_arg 205 | raise ValueError("Target type \"{}\" is not recognized.".format(t)) 206 | 207 | if self.crop: 208 | bbox = self.bbox[index, :] 209 | width, 
height = X.size 210 | left = bbox[0] 211 | top = bbox[1] 212 | right = bbox[0] + bbox[2] 213 | bottom = bbox[1] + bbox[3] 214 | X = X.crop((left, top, right, bottom)) 215 | 216 | if self.transform is not None: 217 | X = self.transform(X) 218 | 219 | if target: 220 | target = tuple(target) if len(target) > 1 else target[0] 221 | 222 | if self.target_transform is not None: 223 | target = self.target_transform(target) 224 | else: 225 | target = None 226 | 227 | return X, target, index 228 | 229 | def __len__(self) -> int: 230 | return len(self.attr) 231 | 232 | def extra_repr(self) -> str: 233 | lines = ["Target type: {target_type}", "Split: {split}"] 234 | return '\n'.join(lines).format(**self.__dict__) 235 | 236 | 237 | if __name__ == '__main__': 238 | # 218, 178 239 | display_transform = transforms.Compose([ 240 | transforms.Resize((224, 224)), 241 | transforms.ToTensor() 242 | ]) 243 | 244 | split = "train" 245 | dataset = CelebA(root="../data", split=split, transform=display_transform, crop=False) 246 | print(dataset) 247 | print(dataset[0]) 248 | 249 | loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True, 250 | drop_last=False) 251 | 252 | with torch.no_grad(): 253 | for i, (images, target, _) in enumerate(tqdm(loader)): 254 | img = np.clip(images.cpu().numpy(), 0, 1) # [0, 1] 255 | img = img.transpose(0, 2, 3, 1) 256 | img = (img * 255).astype(np.uint8) 257 | img = img.squeeze() 258 | 259 | fig, axs = plt.subplots(1, 1, figsize=(8, 8)) 260 | axs.imshow(img) 261 | axs.axis("off") 262 | plt.show() 263 | -------------------------------------------------------------------------------- /backbone/resnet.py: -------------------------------------------------------------------------------- 1 | # some code in this file is adapted from 2 | # https://github.com/pytorch/pytorch 3 | # Licensed under a BSD-style license. 4 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
5 | # SPDX-License-Identifier: CC-BY-NC-4.0 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | __all__ = ['resnet18_encoder', 'resnet34_encoder', 'resnet50_encoder', 'resnet101_encoder', 11 | 'resnet50w2x_encoder', 'resnet50w2x_cls'] 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=dilation, groups=groups, bias=False, dilation=dilation) 18 | 19 | 20 | def conv1x1(in_planes, out_planes, stride=1): 21 | """1x1 convolution""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 29 | base_width=64, dilation=1, norm_layer=None): 30 | super(BasicBlock, self).__init__() 31 | if norm_layer is None: 32 | norm_layer = nn.BatchNorm2d 33 | if groups != 1 or base_width != 64: 34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 35 | if dilation > 1: 36 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 37 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 38 | self.conv1 = conv3x3(inplanes, planes, stride) 39 | self.bn1 = norm_layer(planes) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = norm_layer(planes) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | identity = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | identity = self.downsample(x) 58 | 59 | out += identity 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | 65 | class Bottleneck(nn.Module): 66 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 67 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 68 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 69 | # This variant is also known as ResNet V1.5 and improves accuracy according to 70 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
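    # e.g. with the defaults (groups=1, base_width=64) and planes=64:
    # the 1x1 conv reduces to 64 channels, the 3x3 conv stays at 64, and the final 1x1 expands to 256 (= planes * expansion)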
71 | 72 | expansion = 4 73 | 74 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 75 | base_width=64, dilation=1, norm_layer=None): 76 | super(Bottleneck, self).__init__() 77 | if norm_layer is None: 78 | norm_layer = nn.BatchNorm2d 79 | width = int(planes * (base_width / 64.)) * groups 80 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 81 | self.conv1 = conv1x1(inplanes, width) 82 | self.bn1 = norm_layer(width) 83 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 84 | self.bn2 = norm_layer(width) 85 | self.conv3 = conv1x1(width, planes * self.expansion) 86 | self.bn3 = norm_layer(planes * self.expansion) 87 | self.relu = nn.ReLU(inplace=True) 88 | self.downsample = downsample 89 | self.stride = stride 90 | 91 | def forward(self, x): 92 | identity = x 93 | 94 | out = self.conv1(x) 95 | out = self.bn1(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv2(out) 99 | out = self.bn2(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv3(out) 103 | out = self.bn3(out) 104 | 105 | if self.downsample is not None: 106 | identity = self.downsample(x) 107 | 108 | out += identity 109 | out = self.relu(out) 110 | 111 | return out 112 | 113 | 114 | class ResNet(nn.Module): 115 | 116 | def __init__(self, block, layers, zero_init_residual=False, 117 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 118 | norm_layer=None, width_multiplier=1, with_avgpool=True): 119 | super(ResNet, self).__init__() 120 | if norm_layer is None: 121 | norm_layer = nn.BatchNorm2d 122 | self._norm_layer = norm_layer 123 | 124 | self.inplanes = 64 * width_multiplier 125 | self.dilation = 1 126 | if replace_stride_with_dilation is None: 127 | # each element in the tuple indicates if we should replace 128 | # the 2x2 stride with a dilated convolution instead 129 | replace_stride_with_dilation = [False, False, False] 130 | if len(replace_stride_with_dilation) != 3: 131 | raise ValueError("replace_stride_with_dilation should be None " 132 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 133 | self.groups = groups 134 | self.base_width = width_per_group 135 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 136 | bias=False) 137 | self.bn1 = norm_layer(self.inplanes) 138 | self.relu = nn.ReLU(inplace=True) 139 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 140 | self.layer1 = self._make_layer(block, 64 * width_multiplier, layers[0]) 141 | self.layer2 = self._make_layer(block, 128 * width_multiplier, layers[1], stride=2, 142 | dilate=replace_stride_with_dilation[0]) 143 | self.layer3 = self._make_layer(block, 256 * width_multiplier, layers[2], stride=2, 144 | dilate=replace_stride_with_dilation[1]) 145 | self.layer4 = self._make_layer(block, 512 * width_multiplier, layers[3], stride=2, 146 | dilate=replace_stride_with_dilation[2]) 147 | self.with_avgpool = with_avgpool 148 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) if with_avgpool else nn.Identity() 149 | 150 | self.out_channels = 512 * width_multiplier * block.expansion 151 | 152 | for m in self.modules(): 153 | if isinstance(m, nn.Conv2d): 154 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 155 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 156 | nn.init.constant_(m.weight, 1) 157 | nn.init.constant_(m.bias, 0) 158 | 159 | # Zero-initialize the last BN in each residual branch, 160 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
161 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 162 | if zero_init_residual: 163 | for m in self.modules(): 164 | if isinstance(m, Bottleneck): 165 | nn.init.constant_(m.bn3.weight, 0) 166 | elif isinstance(m, BasicBlock): 167 | nn.init.constant_(m.bn2.weight, 0) 168 | 169 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 170 | norm_layer = self._norm_layer 171 | downsample = None 172 | previous_dilation = self.dilation 173 | if dilate: 174 | self.dilation *= stride 175 | stride = 1 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | conv1x1(self.inplanes, planes * block.expansion, stride), 179 | norm_layer(planes * block.expansion), 180 | ) 181 | 182 | layers = [] 183 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 184 | self.base_width, previous_dilation, norm_layer)) 185 | self.inplanes = planes * block.expansion 186 | for _ in range(1, blocks): 187 | layers.append(block(self.inplanes, planes, groups=self.groups, 188 | base_width=self.base_width, dilation=self.dilation, 189 | norm_layer=norm_layer)) 190 | 191 | return nn.Sequential(*layers) 192 | 193 | def _forward_impl(self, x): 194 | # See note [TorchScript super()] 195 | x = self.conv1(x) 196 | x = self.bn1(x) 197 | x = self.relu(x) 198 | x = self.maxpool(x) 199 | 200 | x = self.layer1(x) 201 | x = self.layer2(x) 202 | x = self.layer3(x) 203 | x = self.layer4(x) 204 | 205 | x = self.avgpool(x) 206 | x = torch.flatten(x, 1) if self.with_avgpool else x 207 | 208 | return x 209 | 210 | def forward(self, x): 211 | return self._forward_impl(x) 212 | 213 | 214 | class ResNetCls(ResNet): 215 | 216 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 217 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 218 | norm_layer=None, width_multiplier=1, normalize=False): 219 | super(ResNetCls, self).__init__( 220 | block, layers, 221 | zero_init_residual=zero_init_residual, 222 | groups=groups, 223 | width_per_group=width_per_group, 224 | replace_stride_with_dilation=replace_stride_with_dilation, 225 | norm_layer=norm_layer, 226 | width_multiplier=width_multiplier, 227 | ) 228 | self.fc = nn.Linear(self.out_channels, num_classes) 229 | self.normalize = normalize 230 | 231 | def _forward_impl(self, x): 232 | # See note [TorchScript super()] 233 | x = self.conv1(x) 234 | x = self.bn1(x) 235 | x = self.relu(x) 236 | x = self.maxpool(x) 237 | 238 | x = self.layer1(x) 239 | x = self.layer2(x) 240 | x = self.layer3(x) 241 | x = self.layer4(x) 242 | 243 | x = self.avgpool(x) 244 | x = torch.flatten(x, 1) 245 | if self.normalize: 246 | x = nn.functional.normalize(x, dim=1) 247 | x = self.fc(x) 248 | 249 | return x 250 | 251 | 252 | def resnet18_encoder(**kwargs): 253 | r"""ResNet-18 model from 254 | `"Deep Residual Learning for Image Recognition" `_ 255 | """ 256 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 257 | 258 | 259 | def resnet34_encoder(**kwargs): 260 | r"""ResNet-34 model from 261 | `"Deep Residual Learning for Image Recognition" `_ 262 | """ 263 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 264 | 265 | 266 | def resnet50_encoder(**kwargs): 267 | r"""ResNet-50 model from 268 | `"Deep Residual Learning for Image Recognition" `_ 269 | """ 270 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 271 | 272 | 273 | def resnet50w2x_encoder(**kwargs): 274 | return ResNet(Bottleneck, [3, 4, 6, 3], width_multiplier=2, **kwargs) 275 | 276 | 277 | def 
resnet50w2x_cls(**kwargs): 278 | model = ResNetCls(Bottleneck, [3, 4, 6, 3], width_multiplier=2, **kwargs) 279 | return model 280 | 281 | 282 | def resnet101_encoder(**kwargs): 283 | r"""ResNet-101 model from 284 | `"Deep Residual Learning for Image Recognition" `_ 285 | """ 286 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 287 | -------------------------------------------------------------------------------- /models/transformers/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py 3 | """ 4 | Transformer class. 5 | 6 | Copy-paste from torch.nn.Transformer with modifications: 7 | * positional encodings are passed in MHattention 8 | * extra LN at the end of encoder is removed 9 | * decoder returns a stack of activations from all decoding layers 10 | """ 11 | import copy 12 | from typing import List, Optional 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import Tensor, nn 17 | 18 | 19 | class Transformer(nn.Module): 20 | def __init__( 21 | self, 22 | d_model=512, 23 | nhead=8, 24 | num_encoder_layers=6, 25 | num_decoder_layers=6, 26 | dim_feedforward=2048, 27 | dropout=0.1, 28 | activation="relu", 29 | normalize_before=False, 30 | return_intermediate_dec=False, 31 | ): 32 | super().__init__() 33 | 34 | encoder_layer = TransformerEncoderLayer( 35 | d_model, nhead, dim_feedforward, dropout, activation, normalize_before 36 | ) 37 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 38 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 39 | 40 | decoder_layer = TransformerDecoderLayer( 41 | d_model, nhead, dim_feedforward, dropout, activation, normalize_before 42 | ) 43 | decoder_norm = nn.LayerNorm(d_model) 44 | self.decoder = TransformerDecoder( 45 | decoder_layer, 46 | num_decoder_layers, 47 | decoder_norm, 48 | return_intermediate=return_intermediate_dec, 49 | ) 50 | 51 | self._reset_parameters() 52 | 53 | self.d_model = d_model 54 | self.nhead = nhead 55 | 56 | def _reset_parameters(self): 57 | for p in self.parameters(): 58 | if p.dim() > 1: 59 | nn.init.xavier_uniform_(p) 60 | 61 | def forward(self, src, mask, query_embed, pos_embed): 62 | # flatten NxCxHxW to HWxNxC 63 | bs, c, h, w = src.shape 64 | src = src.flatten(2).permute(2, 0, 1) 65 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 66 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 67 | if mask is not None: 68 | mask = mask.flatten(1) 69 | 70 | tgt = torch.zeros_like(query_embed) 71 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 72 | hs = self.decoder( 73 | tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed 74 | ) 75 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 76 | 77 | 78 | class TransformerEncoder(nn.Module): 79 | def __init__(self, encoder_layer, num_layers, norm=None): 80 | super().__init__() 81 | self.layers = _get_clones(encoder_layer, num_layers) 82 | self.num_layers = num_layers 83 | self.norm = norm 84 | 85 | def forward( 86 | self, 87 | src, 88 | mask: Optional[Tensor] = None, 89 | src_key_padding_mask: Optional[Tensor] = None, 90 | pos: Optional[Tensor] = None, 91 | ): 92 | output = src 93 | 94 | for layer in self.layers: 95 | output = layer( 96 | output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos 97 | ) 98 | 99 | 
if self.norm is not None: 100 | output = self.norm(output) 101 | 102 | return output 103 | 104 | 105 | class TransformerDecoder(nn.Module): 106 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 107 | super().__init__() 108 | self.layers = _get_clones(decoder_layer, num_layers) 109 | self.num_layers = num_layers 110 | self.norm = norm 111 | self.return_intermediate = return_intermediate 112 | 113 | def forward( 114 | self, 115 | tgt, 116 | memory, 117 | tgt_mask: Optional[Tensor] = None, 118 | memory_mask: Optional[Tensor] = None, 119 | tgt_key_padding_mask: Optional[Tensor] = None, 120 | memory_key_padding_mask: Optional[Tensor] = None, 121 | pos: Optional[Tensor] = None, 122 | query_pos: Optional[Tensor] = None, 123 | ): 124 | output = tgt 125 | 126 | intermediate = [] 127 | 128 | for layer in self.layers: 129 | output = layer( 130 | output, 131 | memory, 132 | tgt_mask=tgt_mask, 133 | memory_mask=memory_mask, 134 | tgt_key_padding_mask=tgt_key_padding_mask, 135 | memory_key_padding_mask=memory_key_padding_mask, 136 | pos=pos, 137 | query_pos=query_pos, 138 | ) 139 | if self.return_intermediate: 140 | intermediate.append(self.norm(output)) 141 | 142 | if self.norm is not None: 143 | output = self.norm(output) 144 | if self.return_intermediate: 145 | intermediate.pop() 146 | intermediate.append(output) 147 | 148 | if self.return_intermediate: 149 | return torch.stack(intermediate) 150 | 151 | return output.unsqueeze(0) 152 | 153 | 154 | class TransformerEncoderLayer(nn.Module): 155 | def __init__( 156 | self, 157 | d_model, 158 | nhead, 159 | dim_feedforward=2048, 160 | dropout=0.1, 161 | activation="relu", 162 | normalize_before=False, 163 | ): 164 | super().__init__() 165 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 166 | # Implementation of Feedforward model 167 | self.linear1 = nn.Linear(d_model, dim_feedforward) 168 | self.dropout = nn.Dropout(dropout) 169 | self.linear2 = nn.Linear(dim_feedforward, d_model) 170 | 171 | self.norm1 = nn.LayerNorm(d_model) 172 | self.norm2 = nn.LayerNorm(d_model) 173 | self.dropout1 = nn.Dropout(dropout) 174 | self.dropout2 = nn.Dropout(dropout) 175 | 176 | self.activation = _get_activation_fn(activation) 177 | self.normalize_before = normalize_before 178 | 179 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 180 | return tensor if pos is None else tensor + pos 181 | 182 | def forward_post( 183 | self, 184 | src, 185 | src_mask: Optional[Tensor] = None, 186 | src_key_padding_mask: Optional[Tensor] = None, 187 | pos: Optional[Tensor] = None, 188 | ): 189 | q = k = self.with_pos_embed(src, pos) 190 | src2 = self.self_attn( 191 | q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask 192 | )[0] 193 | src = src + self.dropout1(src2) 194 | src = self.norm1(src) 195 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 196 | src = src + self.dropout2(src2) 197 | src = self.norm2(src) 198 | return src 199 | 200 | def forward_pre( 201 | self, 202 | src, 203 | src_mask: Optional[Tensor] = None, 204 | src_key_padding_mask: Optional[Tensor] = None, 205 | pos: Optional[Tensor] = None, 206 | ): 207 | src2 = self.norm1(src) 208 | q = k = self.with_pos_embed(src2, pos) 209 | src2 = self.self_attn( 210 | q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask 211 | )[0] 212 | src = src + self.dropout1(src2) 213 | src2 = self.norm2(src) 214 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 215 | src = src + 
self.dropout2(src2) 216 | return src 217 | 218 | def forward( 219 | self, 220 | src, 221 | src_mask: Optional[Tensor] = None, 222 | src_key_padding_mask: Optional[Tensor] = None, 223 | pos: Optional[Tensor] = None, 224 | ): 225 | if self.normalize_before: 226 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 227 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 228 | 229 | 230 | class TransformerDecoderLayer(nn.Module): 231 | def __init__( 232 | self, 233 | d_model, 234 | nhead, 235 | dim_feedforward=2048, 236 | dropout=0.1, 237 | activation="relu", 238 | normalize_before=False, 239 | ): 240 | super().__init__() 241 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 242 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 243 | # Implementation of Feedforward model 244 | self.linear1 = nn.Linear(d_model, dim_feedforward) 245 | self.dropout = nn.Dropout(dropout) 246 | self.linear2 = nn.Linear(dim_feedforward, d_model) 247 | 248 | self.norm1 = nn.LayerNorm(d_model) 249 | self.norm2 = nn.LayerNorm(d_model) 250 | self.norm3 = nn.LayerNorm(d_model) 251 | self.dropout1 = nn.Dropout(dropout) 252 | self.dropout2 = nn.Dropout(dropout) 253 | self.dropout3 = nn.Dropout(dropout) 254 | 255 | self.activation = _get_activation_fn(activation) 256 | self.normalize_before = normalize_before 257 | 258 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 259 | return tensor if pos is None else tensor + pos 260 | 261 | def forward_post( 262 | self, 263 | tgt, 264 | memory, 265 | tgt_mask: Optional[Tensor] = None, 266 | memory_mask: Optional[Tensor] = None, 267 | tgt_key_padding_mask: Optional[Tensor] = None, 268 | memory_key_padding_mask: Optional[Tensor] = None, 269 | pos: Optional[Tensor] = None, 270 | query_pos: Optional[Tensor] = None, 271 | ): 272 | q = k = self.with_pos_embed(tgt, query_pos) 273 | tgt2 = self.self_attn( 274 | q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 275 | )[0] 276 | tgt = tgt + self.dropout1(tgt2) 277 | tgt = self.norm1(tgt) 278 | tgt2 = self.multihead_attn( 279 | query=self.with_pos_embed(tgt, query_pos), 280 | key=self.with_pos_embed(memory, pos), 281 | value=memory, 282 | attn_mask=memory_mask, 283 | key_padding_mask=memory_key_padding_mask, 284 | )[0] 285 | tgt = tgt + self.dropout2(tgt2) 286 | tgt = self.norm2(tgt) 287 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 288 | tgt = tgt + self.dropout3(tgt2) 289 | tgt = self.norm3(tgt) 290 | return tgt 291 | 292 | def forward_pre( 293 | self, 294 | tgt, 295 | memory, 296 | tgt_mask: Optional[Tensor] = None, 297 | memory_mask: Optional[Tensor] = None, 298 | tgt_key_padding_mask: Optional[Tensor] = None, 299 | memory_key_padding_mask: Optional[Tensor] = None, 300 | pos: Optional[Tensor] = None, 301 | query_pos: Optional[Tensor] = None, 302 | ): 303 | tgt2 = self.norm1(tgt) 304 | q = k = self.with_pos_embed(tgt2, query_pos) 305 | tgt2 = self.self_attn( 306 | q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 307 | )[0] 308 | tgt = tgt + self.dropout1(tgt2) 309 | tgt2 = self.norm2(tgt) 310 | tgt2 = self.multihead_attn( 311 | query=self.with_pos_embed(tgt2, query_pos), 312 | key=self.with_pos_embed(memory, pos), 313 | value=memory, 314 | attn_mask=memory_mask, 315 | key_padding_mask=memory_key_padding_mask, 316 | )[0] 317 | tgt = tgt + self.dropout2(tgt2) 318 | tgt2 = self.norm3(tgt) 319 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 320 | tgt = 
tgt + self.dropout3(tgt2) 321 | return tgt 322 | 323 | def forward( 324 | self, 325 | tgt, 326 | memory, 327 | tgt_mask: Optional[Tensor] = None, 328 | memory_mask: Optional[Tensor] = None, 329 | tgt_key_padding_mask: Optional[Tensor] = None, 330 | memory_key_padding_mask: Optional[Tensor] = None, 331 | pos: Optional[Tensor] = None, 332 | query_pos: Optional[Tensor] = None, 333 | ): 334 | if self.normalize_before: 335 | return self.forward_pre( 336 | tgt, 337 | memory, 338 | tgt_mask, 339 | memory_mask, 340 | tgt_key_padding_mask, 341 | memory_key_padding_mask, 342 | pos, 343 | query_pos, 344 | ) 345 | return self.forward_post( 346 | tgt, 347 | memory, 348 | tgt_mask, 349 | memory_mask, 350 | tgt_key_padding_mask, 351 | memory_key_padding_mask, 352 | pos, 353 | query_pos, 354 | ) 355 | 356 | 357 | def _get_clones(module, N): 358 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 359 | 360 | 361 | def _get_activation_fn(activation): 362 | """Return an activation function given a string""" 363 | if activation == "relu": 364 | return F.relu 365 | if activation == "gelu": 366 | return F.gelu 367 | if activation == "glu": 368 | return F.glu 369 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 370 | -------------------------------------------------------------------------------- /models/lewel.py: -------------------------------------------------------------------------------- 1 | # Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | import sys 4 | from math import cos, pi 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | from torch.nn.modules import loss 9 | import torch.distributed as dist 10 | from classy_vision.generic.distributed_util import is_distributed_training_run 11 | 12 | from models.transformers.transformer_predictor import TransformerPredictor 13 | from utils import init 14 | 15 | 16 | class MLP1D(nn.Module): 17 | """ 18 | The non-linear neck in byol: fc-bn-relu-fc 19 | """ 20 | def __init__(self, in_channels, hid_channels, out_channels, 21 | norm_layer=None, bias=False, num_mlp=2): 22 | super(MLP1D, self).__init__() 23 | if norm_layer is None: 24 | norm_layer = nn.BatchNorm1d 25 | mlps = [] 26 | for _ in range(num_mlp-1): 27 | mlps.append(nn.Conv1d(in_channels, hid_channels, 1, bias=bias)) 28 | mlps.append(norm_layer(hid_channels)) 29 | mlps.append(nn.ReLU(inplace=True)) 30 | in_channels = hid_channels 31 | mlps.append(nn.Conv1d(hid_channels, out_channels, 1, bias=bias)) 32 | self.mlp = nn.Sequential(*mlps) 33 | 34 | def init_weights(self, init_linear='normal'): 35 | init.init_weights(self, init_linear) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | return x 40 | 41 | 42 | class ObjectNeck(nn.Module): 43 | def __init__(self, 44 | in_channels, 45 | out_channels, 46 | hid_channels=None, 47 | num_layers=1, 48 | scale=1., 49 | l2_norm=True, 50 | num_heads=8, 51 | norm_layer=None, 52 | mask_type="group", 53 | num_proto=64, 54 | temp=0.07, 55 | **kwargs): 56 | super(ObjectNeck, self).__init__() 57 | 58 | self.scale = scale 59 | self.l2_norm = l2_norm 60 | assert l2_norm 61 | self.num_heads = num_heads 62 | self.mask_type = mask_type 63 | self.temp = temp 64 | self.eps = 1e-7 65 | 66 | hid_channels = hid_channels or in_channels 67 | self.proj = MLP1D(in_channels, hid_channels, out_channels, norm_layer, num_mlp=num_layers) 68 | self.proj_obj = MLP1D(in_channels, hid_channels, out_channels, norm_layer, num_mlp=num_layers) 69 | 70 | if 
mask_type == "attn": 71 | # self.slot_embed = nn.Embedding(num_proto, out_channels) 72 | # self.proj_obj = MLP1D(out_channels, hid_channels, out_channels, norm_layer, num_mlp=num_layers) 73 | self.proj_attn = TransformerPredictor(in_channels=out_channels, hidden_dim=out_channels, num_queries=num_proto, 74 | nheads=8, dropout=0.1, dim_feedforward=out_channels, enc_layers=0, 75 | dec_layers=1, pre_norm=False, deep_supervision=False, 76 | mask_dim=out_channels, enforce_input_project=False, 77 | mask_classification=False, num_classes=0) 78 | 79 | def init_weights(self, init_linear='kaiming'): 80 | self.proj.init_weights(init_linear) 81 | self.proj_obj.init_weights(init_linear) 82 | 83 | def forward(self, x): 84 | out = {} 85 | 86 | b, c, h, w = x.shape 87 | 88 | # flatten and projection 89 | x_pool = F.adaptive_avg_pool2d(x, 1).flatten(2) 90 | x = x.flatten(2) # (bs, c, h*w) 91 | z = self.proj(torch.cat([x_pool, x], dim=2)) # (bs, d, 1+h*w) 92 | z_g, z_feat = torch.split(z, [1, x.shape[2]], dim=2) # (bs, d, 1), (bs, d, h*w) 93 | 94 | z_feat = z_feat.contiguous() 95 | 96 | if self.mask_type == "attn": 97 | z_feat = z_feat.view(b, -1, h, w) 98 | x = x.view(b, c, h, w) 99 | attn_out = self.proj_attn(z_feat, None) 100 | mask_embed = attn_out["mask_embed"] # (bs, q, c) 101 | out["mask_embed"] = mask_embed 102 | 103 | dots = torch.einsum('bqc,bchw->bqhw', F.normalize(mask_embed, dim=2), F.normalize(z_feat, dim=1)) 104 | obj_attn = (dots / self.temp).softmax(dim=1) + self.eps 105 | # obj_attn = (dots / 1.0).softmax(dim=1) + self.eps 106 | slots = torch.einsum('bchw,bqhw->bqc', x, obj_attn / obj_attn.sum(dim=(2, 3), keepdim=True)) 107 | # slots = torch.einsum('bchw,bqhw->bqc', z_feat, obj_attn / obj_attn.sum(dim=(2, 3), keepdim=True)) 108 | obj_attn = obj_attn.view(b, -1, h * w) 109 | out["dots"] = dots 110 | else: 111 | # do attention according to obj attention map 112 | obj_attn = F.normalize(z_feat, dim=1) if self.l2_norm else z_feat 113 | obj_attn /= self.scale 114 | obj_attn = obj_attn.view(b, self.num_heads, -1, h * w) # (bs, h, d/h, h*w) 115 | obj_attn_raw = F.softmax(obj_attn, dim=-1) 116 | 117 | if self.mask_type == "group": 118 | obj_attn = F.softmax(obj_attn, dim=-1) 119 | x = x.view(b, self.num_heads, -1, h*w) # (bs, h, c/h, h*w) 120 | obj_val = torch.matmul(x, obj_attn.transpose(3, 2)) # (bs, h, c//h, d/h) 121 | obj_val = obj_val.view(b, c, obj_attn.shape[-2]) # (bs, c, d/h) 122 | elif self.mask_type == "max": 123 | obj_attn, _ = torch.max(obj_attn, dim=1) # (bs, d/h, h*w) 124 | # obj_attn = torch.mean(obj_attn, dim=1) 125 | obj_attn = F.softmax(obj_attn, dim=-1) 126 | obj_val = torch.matmul(x, obj_attn.transpose(2, 1)) # (bs, c, d/h) 127 | elif self.mask_type == "attn": 128 | obj_val = slots.transpose(2, 1) # (bs, c, q) 129 | 130 | # projection 131 | obj_val = self.proj_obj(obj_val) # (bs, d, d/h) 132 | 133 | out["obj_attn"] = obj_attn 134 | out["obj_attn_raw"] = obj_attn_raw 135 | 136 | return z_g, obj_val, out # (bs, d, 1), (bs, d, d//h), where the second dim is channel 137 | 138 | def extra_repr(self) -> str: 139 | parts = [] 140 | for name in ["scale", "l2_norm", "num_heads"]: 141 | parts.append(f"{name}={getattr(self, name)}") 142 | return ", ".join(parts) 143 | 144 | 145 | class EncoderObj(nn.Module): 146 | def __init__(self, base_encoder, hid_dim, out_dim, norm_layer=None, num_mlp=2, 147 | scale=1., l2_norm=True, num_heads=8, mask_type="group", num_proto=64, temp=0.07): 148 | super(EncoderObj, self).__init__() 149 | self.backbone = base_encoder(norm_layer=norm_layer, 
with_avgpool=False) 150 | in_dim = self.backbone.out_channels 151 | self.neck = ObjectNeck(in_channels=in_dim, hid_channels=hid_dim, out_channels=out_dim, 152 | norm_layer=norm_layer, num_layers=num_mlp, 153 | scale=scale, l2_norm=l2_norm, num_heads=num_heads, mask_type=mask_type, 154 | num_proto=num_proto, temp=temp) 155 | # self.neck.init_weights(init_linear='kaiming') 156 | 157 | def forward(self, im): 158 | out = self.backbone(im) 159 | out = self.neck(out) 160 | return out 161 | 162 | 163 | class LEWELB_EMAN(nn.Module): 164 | def __init__(self, base_encoder, dim=256, m=0.996, hid_dim=4096, norm_layer=None, num_neck_mlp=2, 165 | scale=1., l2_norm=True, num_heads=8, loss_weight=0.5, mask_type="group", num_proto=64, 166 | teacher_temp=0.07, student_temp=0.1, loss_w_cluster=0.5, **kwargs): 167 | super().__init__() 168 | 169 | self.base_m = m 170 | self.curr_m = m 171 | self.loss_weight = loss_weight 172 | self.loss_w_cluster = loss_w_cluster 173 | self.mask_type = mask_type 174 | assert mask_type in ["group", "max", "attn"] 175 | self.num_proto = num_proto 176 | self.student_temp = student_temp # 0.1 177 | self.teacher_temp = teacher_temp # 0.07 178 | 179 | # create the encoders 180 | # num_classes is the output fc dimension 181 | self.online_net = EncoderObj(base_encoder, hid_dim, dim, norm_layer, num_neck_mlp, 182 | scale=scale, l2_norm=l2_norm, num_heads=num_heads, mask_type=mask_type, 183 | num_proto=num_proto, temp=self.teacher_temp) 184 | 185 | # checkpoint = torch.load("./checkpoints/lewel_b_400ep.pth", map_location="cpu") 186 | # msg = self.online_net.backbone.load_state_dict(checkpoint) 187 | # assert set(msg.missing_keys) == set() 188 | # state_dict = checkpoint['state_dict'] 189 | # state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 190 | # state_dict = {k.replace("online_net.backbone.", ""): v for k, v in state_dict.items()} 191 | # self.online_net.backbone.load_state_dict(state_dict) 192 | 193 | self.target_net = EncoderObj(base_encoder, hid_dim, dim, norm_layer, num_neck_mlp, 194 | scale=scale, l2_norm=l2_norm, num_heads=num_heads, mask_type=mask_type, 195 | num_proto=num_proto, temp=self.teacher_temp) 196 | self.predictor = MLP1D(dim, hid_dim, dim, norm_layer=norm_layer) 197 | # self.predictor.init_weights() 198 | self.predictor_obj = MLP1D(dim, hid_dim, dim, norm_layer=norm_layer) 199 | # self.predictor_obj.init_weights() 200 | self.encoder_q = self.online_net.backbone 201 | 202 | # copy params from online model to target model 203 | for param_ol, param_tgt in zip(self.online_net.parameters(), self.target_net.parameters()): 204 | param_tgt.data.copy_(param_ol.data) # initialize 205 | param_tgt.requires_grad = False # not update by gradient 206 | 207 | self.center_momentum = 0.9 208 | self.register_buffer("center", torch.zeros(1, self.num_proto)) 209 | 210 | def mse_loss(self, pred, target): 211 | """ 212 | Args: 213 | pred (Tensor): NxC input features. 214 | target (Tensor): NxC target features. 
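Returns: a scalar BYOL-style regression loss, 2 - 2 * mean cosine similarity between pred and target.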
215 | """ 216 | N = pred.size(0) 217 | pred_norm = nn.functional.normalize(pred, dim=1) 218 | target_norm = nn.functional.normalize(target, dim=1) 219 | loss = 2 - 2 * (pred_norm * target_norm).sum() / N 220 | return loss 221 | 222 | def self_distill(self, q, k): 223 | q = F.log_softmax(q / self.student_temp, dim=-1) 224 | k = F.softmax((k - self.center) / self.teacher_temp, dim=-1) 225 | return torch.sum(-k * q, dim=-1).mean() 226 | 227 | def loss_func(self, online, target): 228 | z_o, obj_o, res_o = online 229 | z_t, obj_t, res_t = target 230 | # instance-level loss 231 | z_o_pred = self.predictor(z_o).squeeze(-1) 232 | z_t = z_t.squeeze(-1) 233 | loss_inst = self.mse_loss(z_o_pred, z_t) 234 | # object-level loss 235 | b, c, n = obj_o.shape 236 | obj_o_pred = self.predictor_obj(obj_o).transpose(2, 1).reshape(b*n, c) 237 | obj_t = obj_t.transpose(2, 1).reshape(b*n, c) 238 | loss_obj = self.mse_loss(obj_o_pred, obj_t) 239 | 240 | # score_q = torch.einsum('bnc,bc->bn', F.normalize(obj_o_pred, dim=2), F.normalize(z_o_pred, dim=1)) 241 | # score_k = torch.einsum('bnc,bc->bn', F.normalize(obj_t, dim=2), F.normalize(z_t, dim=1)) 242 | # score_q = torch.einsum('bnc,bc->bn', F.normalize(obj_o.transpose(2, 1), dim=2), F.normalize(z_o.squeeze(-1), dim=1)) 243 | # # score_q = torch.einsum('bnc,bc->bn', F.normalize(obj_t, dim=2), F.normalize(z_o.squeeze(-1), dim=1)) 244 | # score_k = torch.einsum('bnc,bc->bn', F.normalize(obj_t, dim=2), F.normalize(z_t, dim=1)) 245 | 246 | # score_q = torch.einsum('bnc,bc->bn', F.normalize(res_o["mask_embed"], dim=2), F.normalize(z_o.squeeze(-1), dim=1)) 247 | # score_q = torch.einsum('bnc,bc->bn', F.normalize(res_o["mask_embed"], dim=2), F.normalize(z_t, dim=1)) 248 | # score_q = torch.einsum('bnc,bc->bn', F.normalize(res_t["mask_embed"], dim=2), F.normalize(z_o.squeeze(-1), dim=1)) 249 | # score_k = torch.einsum('bnc,bc->bn', F.normalize(res_t["mask_embed"], dim=2), F.normalize(z_t, dim=1)) 250 | # loss_relation = self.self_distill(score_q, score_k) 251 | 252 | # score_q_1 = torch.einsum('bnc,bc->bn', F.normalize(res_o["mask_embed"], dim=2), F.normalize(z_t, dim=1)) 253 | # score_q_2 = torch.einsum('bnc,bc->bn', F.normalize(res_t["mask_embed"], dim=2), F.normalize(z_o.squeeze(-1), dim=1)) 254 | # score_k = torch.einsum('bnc,bc->bn', F.normalize(res_t["mask_embed"], dim=2), F.normalize(z_t, dim=1)) 255 | # loss_relation = 0.5 * (self.self_distill(score_q_1, score_k) + self.self_distill(score_q_2, score_k)) 256 | 257 | loss_base = loss_inst * self.loss_weight + loss_obj * (1 - self.loss_weight) 258 | 259 | # sum 260 | return loss_base, loss_inst, loss_obj 261 | 262 | @torch.no_grad() 263 | def momentum_update(self, cur_iter, max_iter): 264 | """ 265 | Momentum update of the target network. 266 | """ 267 | # momentum annealing 268 | momentum = 1. - (1. - self.base_m) * (cos(pi * cur_iter / float(max_iter)) + 1) / 2.0 269 | self.curr_m = momentum 270 | # parameter update for target network 271 | state_dict_ol = self.online_net.state_dict() 272 | state_dict_tgt = self.target_net.state_dict() 273 | for (k_ol, v_ol), (k_tgt, v_tgt) in zip(state_dict_ol.items(), state_dict_tgt.items()): 274 | assert k_tgt == k_ol, "state_dict names are different!" 275 | assert v_ol.shape == v_tgt.shape, "state_dict shapes are different!" 276 | if 'num_batches_tracked' in k_tgt: 277 | v_tgt.copy_(v_ol) 278 | else: 279 | v_tgt.copy_(v_tgt * momentum + (1. 
- momentum) * v_ol) 280 | 281 | @torch.no_grad() 282 | def update_center(self, teacher_output): 283 | """ 284 | Update center used for teacher output. 285 | """ 286 | batch_center = torch.mean(teacher_output, dim=0, keepdim=True) 287 | if is_distributed_training_run(): 288 | dist.all_reduce(batch_center) 289 | batch_center = batch_center / dist.get_world_size() 290 | 291 | # ema update 292 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) 293 | 294 | def get_heatmap(self, x): 295 | _, _, out = self.online_net(x) 296 | return out 297 | 298 | def ctr_loss(self, online_1, online_2, target_1, target_2): 299 | z_o_1, obj_o_1, res_o_1 = online_1 300 | z_o_2, obj_o_2, res_o_2 = online_2 301 | z_t_1, obj_t_1, res_t_1 = target_1 302 | z_t_2, obj_t_2, res_t_2 = target_2 303 | 304 | # corre_o = torch.matmul(F.normalize(res_o_1["mask_embed"], dim=2), 305 | # F.normalize(res_o_2["mask_embed"], dim=2).transpose(2, 1)) # b, q, c 306 | # corre_t = torch.matmul(F.normalize(res_t_1["mask_embed"], dim=2), 307 | # F.normalize(res_t_2["mask_embed"], dim=2).transpose(2, 1)) # b, q, c 308 | # loss = self.self_distill(corre_o.flatten(0, 1), corre_t.flatten(0, 1)) 309 | # score = corre_t.flatten(0, 1) 310 | 311 | loss = 0.5 * (self.self_distill(res_o_1["dots"].permute(0, 2, 3, 1).flatten(0, 2), 312 | res_t_1["dots"].permute(0, 2, 3, 1).flatten(0, 2)) 313 | + self.self_distill(res_o_2["dots"].permute(0, 2, 3, 1).flatten(0, 2), 314 | res_t_2["dots"].permute(0, 2, 3, 1).flatten(0, 2))) 315 | score_k1 = res_t_1["dots"] 316 | score_k2 = res_t_2["dots"] 317 | score = torch.cat([score_k1, score_k2]).permute(0, 2, 3, 1).flatten(0, 2) 318 | 319 | return loss, score 320 | 321 | def forward(self, im_v1, im_v2=None, **kwargs): 322 | """ 323 | Input: 324 | im_v1: a batch of view1 images 325 | im_v2: a batch of view2 images 326 | Output: 327 | loss 328 | """ 329 | # for inference, online_net.backbone model only 330 | if im_v2 is None: 331 | feats = self.online_net.backbone(im_v1) 332 | return F.adaptive_avg_pool2d(feats, 1).flatten(1) 333 | 334 | # compute online_net features 335 | proj_online_v1 = self.online_net(im_v1) 336 | proj_online_v2 = self.online_net(im_v2) 337 | 338 | # compute target_net features 339 | with torch.no_grad(): # no gradient to keys 340 | proj_target_v1 = [x.clone().detach() if isinstance(x, torch.Tensor) else x for x in self.target_net(im_v1)] 341 | proj_target_v2 = [x.clone().detach() if isinstance(x, torch.Tensor) else x for x in self.target_net(im_v2)] 342 | 343 | # loss. 
NOTE: the prediction step is moved into loss_func 344 | loss_base1, loss_inst1, loss_obj1 = self.loss_func(proj_online_v1, proj_target_v2) 345 | loss_base2, loss_inst2, loss_obj2 = self.loss_func(proj_online_v2, proj_target_v1) 346 | loss_base = loss_base1 + loss_base2 347 | 348 | loss_relation, score = self.ctr_loss(proj_online_v1, proj_online_v2, proj_target_v1, proj_target_v2) 349 | loss = loss_base + loss_relation * self.loss_w_cluster 350 | 351 | loss_pack = {} 352 | loss_pack["base"] = loss_base 353 | loss_pack["inst"] = (loss_inst1 + loss_inst2) * self.loss_weight 354 | loss_pack["obj"] = (loss_obj1 + loss_obj2) * (1 - self.loss_weight) 355 | loss_pack["relation"] = loss_relation 356 | 357 | self.update_center(score) 358 | 359 | return loss, loss_pack 360 | 361 | def extra_repr(self) -> str: 362 | parts = [] 363 | for name in ["loss_weight", "mask_type", "num_proto", "teacher_temp", "loss_w_cluster"]: 364 | parts.append(f"{name}={getattr(self, name)}") 365 | return ", ".join(parts) 366 | 367 | 368 | class LEWELB(LEWELB_EMAN): 369 | @torch.no_grad() 370 | def momentum_update(self, cur_iter, max_iter): 371 | """ 372 | Momentum update of the target network. 373 | """ 374 | # momentum annealing 375 | momentum = 1. - (1. - self.base_m) * (cos(pi * cur_iter / float(max_iter)) + 1) / 2.0 376 | self.curr_m = momentum 377 | # parameter update for target network 378 | for param_ol, param_tgt in zip(self.online_net.parameters(), self.target_net.parameters()): 379 | param_tgt.data = param_tgt.data * momentum + param_ol.data * (1. - momentum) 380 | 381 | 382 | if __name__ == '__main__': 383 | from models import get_model 384 | import backbone as backbone_models 385 | 386 | model_func = get_model("LEWELB_EMAN") 387 | norm_layer = None 388 | model = model_func( 389 | backbone_models.__dict__["resnet50_encoder"], 390 | dim=256, 391 | m=0.996, 392 | hid_dim=4096, 393 | norm_layer=norm_layer, 394 | num_neck_mlp=2, 395 | scale=1., 396 | l2_norm=True, 397 | num_heads=4, 398 | loss_weight=0.5, 399 | mask_type="attn" 400 | ) 401 | print(model) 402 | 403 | x1 = torch.randn(16, 3, 224, 224) 404 | x2 = torch.randn(16, 3, 224, 224) 405 | out = model(x1, x2) 406 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. 
The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. 
Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. 
produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. 
identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. 
Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /models/fra.py: -------------------------------------------------------------------------------- 1 | # Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | import sys 4 | import math 5 | from math import cos, pi 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | from torch.nn.modules import loss 10 | import torch.distributed as dist 11 | from classy_vision.generic.distributed_util import is_distributed_training_run 12 | 13 | from models.transformers.transformer_predictor import TransformerPredictor 14 | from utils import init 15 | 16 | 17 | 18 | @torch.no_grad() 19 | def distributed_sinkhorn(Q, num_itr=3, use_dist=True, epsilon=0.05): 20 | _got_dist = use_dist and torch.distributed.is_available() \ 21 | and torch.distributed.is_initialized() \ 22 | and (torch.distributed.get_world_size() > 1) 23 | 24 | if _got_dist: 25 | world_size = torch.distributed.get_world_size() 26 | else: 27 | world_size = 1 28 | 29 | Q = Q.T 30 | # Q = torch.exp(Q / epsilon).t() 31 | B = Q.shape[1] * world_size # number of samples to assign 32 | K = Q.shape[0] # how many prototypes 33 | 34 | # make the matrix sum to 1 35 | sum_Q = torch.sum(Q) 36 | if _got_dist: 37 | torch.distributed.all_reduce(sum_Q) 38 | Q /= sum_Q 39 | 40 | for it in range(num_itr): 41 | # normalize each row: total weight per prototype must be 1/K 42 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 43 | if _got_dist: 44 | torch.distributed.all_reduce(sum_of_rows) 45 | Q /= sum_of_rows 46 | Q /= K 47 | 48 | # normalize each column: total weight per sample must be 1/B 49 | Q /= torch.sum(Q, dim=0, keepdim=True) 50 | Q /= B 51 | 52 | Q *= B # the columns must sum to 1 so that Q is an assignment 53 | return Q.T 54 | 55 | 56 | class MLP1D(nn.Module): 57 | """ 58 | The non-linear neck in byol: fc-bn-relu-fc 59 | """ 60 | def __init__(self, in_channels, hid_channels, out_channels, 61 
| norm_layer=None, bias=False, num_mlp=2): 62 | super(MLP1D, self).__init__() 63 | if norm_layer is None: 64 | norm_layer = nn.BatchNorm1d 65 | mlps = [] 66 | for _ in range(num_mlp-1): 67 | mlps.append(nn.Conv1d(in_channels, hid_channels, 1, bias=bias)) 68 | mlps.append(norm_layer(hid_channels)) 69 | mlps.append(nn.ReLU(inplace=True)) 70 | in_channels = hid_channels 71 | mlps.append(nn.Conv1d(hid_channels, out_channels, 1, bias=bias)) 72 | self.mlp = nn.Sequential(*mlps) 73 | 74 | def init_weights(self, init_linear='normal'): 75 | init.init_weights(self, init_linear) 76 | 77 | def forward(self, x): 78 | x = self.mlp(x) 79 | return x 80 | 81 | 82 | class ObjectNeck(nn.Module): 83 | def __init__(self, 84 | in_channels, 85 | out_channels, 86 | hid_channels=None, 87 | num_layers=1, 88 | scale=1., 89 | l2_norm=True, 90 | num_heads=8, 91 | norm_layer=None, 92 | mask_type="group", 93 | num_proto=64, 94 | temp=0.07, 95 | **kwargs): 96 | super(ObjectNeck, self).__init__() 97 | 98 | self.scale = scale 99 | self.l2_norm = l2_norm 100 | assert l2_norm 101 | self.num_heads = num_heads 102 | self.mask_type = mask_type 103 | self.temp = temp 104 | self.eps = 1e-7 105 | 106 | hid_channels = hid_channels or in_channels 107 | self.proj = MLP1D(in_channels, hid_channels, out_channels, norm_layer, num_mlp=num_layers) 108 | self.proj_pixel = MLP1D(in_channels, hid_channels, out_channels, norm_layer, num_mlp=num_layers) 109 | self.proj_obj = MLP1D(in_channels, hid_channels, out_channels, norm_layer, num_mlp=num_layers) 110 | 111 | if mask_type == "attn": 112 | self.proj_attn = TransformerPredictor(in_channels=in_channels, hidden_dim=out_channels, num_queries=num_proto, 113 | nheads=8, dropout=0.1, dim_feedforward=out_channels, enc_layers=0, 114 | dec_layers=2, pre_norm=False, deep_supervision=False, 115 | mask_dim=out_channels, enforce_input_project=False, 116 | mask_classification=False, num_classes=0) 117 | 118 | self.proto_momentum = 0.9 119 | self.register_buffer("proto", torch.randn(num_proto, out_channels)) 120 | # self.proto = nn.Embedding(num_proto, out_channels) 121 | 122 | def init_weights(self, init_linear='kaiming'): 123 | self.proj.init_weights(init_linear) 124 | self.proj_pixel.init_weights(init_linear) 125 | self.proj_obj.init_weights(init_linear) 126 | 127 | @torch.no_grad() 128 | def update_proto(self, mask_embed): 129 | """ 130 | Update center used for teacher output. 
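(For this method, the EMA-updated quantity is the prototype buffer self.proto, refreshed below from the batch-mean mask embeddings.)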
131 | """ 132 | batch_center = torch.mean(mask_embed, dim=0) 133 | if is_distributed_training_run(): 134 | dist.all_reduce(batch_center) 135 | batch_center = batch_center / dist.get_world_size() 136 | 137 | # ema update 138 | self.proto = self.proto * self.proto_momentum + batch_center * (1 - self.proto_momentum) 139 | 140 | def forward(self, x, isTrain=True): 141 | out = {} 142 | 143 | b, c, h, w = x.shape 144 | 145 | # flatten and projection 146 | x_pool = F.adaptive_avg_pool2d(x, 1).flatten(2) 147 | x = x.flatten(2) # (bs, c, h*w) 148 | z_g = self.proj(x_pool) 149 | z_feat = self.proj_pixel(x) 150 | 151 | if self.mask_type == "attn": 152 | z_feat = z_feat.view(b, -1, h, w) 153 | x = x.view(b, c, h, w) 154 | # attn_out = self.proj_attn(z_feat, None) 155 | attn_out = self.proj_attn(x, None) 156 | mask_embed = attn_out["mask_embed"] # (bs, q, c) 157 | 158 | if isTrain: 159 | # mask_embed = AllReduce.apply(torch.mean(mask_embed, dim=0, keepdim=True)) 160 | mask_embed_avg = torch.mean(mask_embed, dim=0, keepdim=True) 161 | if is_distributed_training_run(): 162 | dist.all_reduce(mask_embed_avg) 163 | mask_embed_avg = mask_embed_avg / dist.get_world_size() 164 | mask_embed_avg = mask_embed_avg.repeat(x.size(0), 1, 1) 165 | if z_feat.requires_grad: 166 | assert mask_embed_avg.requires_grad 167 | 168 | dots = torch.einsum('bqc,bchw->bqhw', F.normalize(mask_embed_avg, dim=2), F.normalize(z_feat, dim=1)) 169 | else: 170 | dots = torch.einsum('qc,bchw->bqhw', F.normalize(self.proto, dim=1), F.normalize(z_feat, dim=1)) 171 | 172 | obj_attn = (dots / self.scale).softmax(dim=1) + self.eps 173 | 174 | slots = torch.einsum('bchw,bqhw->bqc', x, obj_attn / obj_attn.sum(dim=(2, 3), keepdim=True)) 175 | 176 | out["dots"] = dots 177 | out["feat"] = z_feat 178 | out["obj_attn"] = obj_attn 179 | else: 180 | # do attention according to obj attention map 181 | obj_attn = F.normalize(z_feat, dim=1) if self.l2_norm else z_feat 182 | obj_attn /= self.scale 183 | obj_attn = obj_attn.view(b, self.num_heads, -1, h * w) # (bs, h, d/h, h*w) 184 | 185 | if self.mask_type == "group": 186 | obj_attn = F.softmax(obj_attn, dim=-1) 187 | x = x.view(b, self.num_heads, -1, h*w) # (bs, h, c/h, h*w) 188 | obj_val = torch.matmul(x, obj_attn.transpose(3, 2)) # (bs, h, c//h, d/h) 189 | obj_val = obj_val.view(b, c, obj_attn.shape[-2]) # (bs, c, d/h) 190 | elif self.mask_type == "max": 191 | obj_attn, _ = torch.max(obj_attn, dim=1) # (bs, d/h, h*w) 192 | # obj_attn = torch.mean(obj_attn, dim=1) 193 | out["obj_attn"] = obj_attn 194 | obj_attn = F.softmax(obj_attn, dim=-1) 195 | obj_val = torch.matmul(x, obj_attn.transpose(2, 1)) # (bs, c, d/h) 196 | elif self.mask_type == "attn": 197 | obj_val = slots.transpose(2, 1) # (bs, c, q) 198 | 199 | # projection 200 | obj_val = self.proj_obj(obj_val) # (bs, d, q) 201 | 202 | if isTrain: 203 | self.update_proto(mask_embed) 204 | 205 | return z_g, obj_val, out # (bs, d, 1), (bs, d, d//h), where the second dim is channel 206 | 207 | def extra_repr(self) -> str: 208 | parts = [] 209 | for name in ["scale", "l2_norm", "num_heads"]: 210 | parts.append(f"{name}={getattr(self, name)}") 211 | return ", ".join(parts) 212 | 213 | 214 | class EncoderObj(nn.Module): 215 | def __init__(self, base_encoder, hid_dim, out_dim, norm_layer=None, num_mlp=2, 216 | scale=1., l2_norm=True, num_heads=8, mask_type="group", num_proto=64, temp=0.07): 217 | super(EncoderObj, self).__init__() 218 | self.backbone = base_encoder(norm_layer=norm_layer, with_avgpool=False) 219 | in_dim = self.backbone.out_channels 220 | 
self.neck = ObjectNeck(in_channels=in_dim, hid_channels=hid_dim, out_channels=out_dim, 221 | norm_layer=norm_layer, num_layers=num_mlp, 222 | scale=scale, l2_norm=l2_norm, num_heads=num_heads, mask_type=mask_type, 223 | num_proto=num_proto, temp=temp) 224 | self.neck.init_weights(init_linear='kaiming') 225 | 226 | def forward(self, im, isTrain=True): 227 | out = self.backbone(im) 228 | out = self.neck(out, isTrain) 229 | return out 230 | 231 | 232 | class FRAB_EMAN(nn.Module): 233 | def __init__(self, base_encoder, dim=256, m=0.996, hid_dim=4096, norm_layer=None, num_neck_mlp=2, 234 | scale=1., l2_norm=True, num_heads=8, loss_weight=0.5, mask_type="group", num_proto=8, 235 | teacher_temp=0.04, student_temp=0.1, loss_w_cluster=0.1, **kwargs): 236 | super().__init__() 237 | 238 | self.base_m = m 239 | self.curr_m = m 240 | self.loss_weight = loss_weight 241 | self.loss_w_cluster = loss_w_cluster 242 | self.loss_w_obj = 0.02 243 | self.mask_type = mask_type 244 | assert mask_type in ["group", "max", "attn"] 245 | self.num_proto = num_proto 246 | self.student_temp = student_temp # 0.1 247 | self.teacher_temp = teacher_temp # 0.04 248 | 249 | # create the encoders 250 | # num_classes is the output fc dimension 251 | self.online_net = EncoderObj(base_encoder, hid_dim, dim, norm_layer, num_neck_mlp, 252 | scale=scale, l2_norm=l2_norm, num_heads=num_heads, mask_type=mask_type, 253 | num_proto=num_proto, temp=self.teacher_temp) 254 | 255 | self.target_net = EncoderObj(base_encoder, hid_dim, dim, norm_layer, num_neck_mlp, 256 | scale=scale, l2_norm=l2_norm, num_heads=num_heads, mask_type=mask_type, 257 | num_proto=num_proto, temp=self.teacher_temp) 258 | self.predictor = MLP1D(dim, hid_dim, dim, norm_layer=norm_layer) 259 | self.predictor.init_weights() 260 | self.predictor_obj = MLP1D(dim, hid_dim, dim, norm_layer=norm_layer) 261 | self.predictor_obj.init_weights() 262 | self.encoder_q = self.online_net.backbone 263 | 264 | # copy params from online model to target model 265 | for param_ol, param_tgt in zip(self.online_net.parameters(), self.target_net.parameters()): 266 | param_tgt.data.copy_(param_ol.data) # initialize 267 | param_tgt.requires_grad = False # not update by gradient 268 | 269 | self.center_momentum = 0.9 270 | self.register_buffer("center", torch.zeros(1, self.num_proto)) 271 | 272 | def mse_loss(self, pred, target): 273 | """ 274 | Args: 275 | pred (Tensor): NxC input features. 276 | target (Tensor): NxC target features. 277 | """ 278 | N = pred.size(0) 279 | pred_norm = nn.functional.normalize(pred, dim=1) 280 | target_norm = nn.functional.normalize(target, dim=1) 281 | loss = 2 - 2 * (pred_norm * target_norm).sum() / N 282 | return loss 283 | 284 | def self_distill(self, q, k, use_sinkhorn=True, me_max=True): 285 | q_probs = F.log_softmax(q / self.student_temp, dim=-1) 286 | k_probs = F.softmax((k - self.center) / self.teacher_temp, dim=-1) 287 | 288 | if use_sinkhorn: 289 | k_probs = distributed_sinkhorn(k_probs) 290 | 291 | ce_loss = torch.sum(-k_probs * q_probs, dim=-1).mean() 292 | 293 | rloss = 0. 
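        # "me-max" (mean-entropy maximisation) regulariser: rloss below equals
        # log(K) - H(avg_probs), where avg_probs is the prototype-assignment
        # distribution averaged over the (all-reduced) batch and K = num_proto.
        # It is >= 0 and reaches zero only for a uniform average assignment,
        # discouraging collapse of all samples onto a few prototypes.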
294 | if me_max: 295 | probs = F.softmax(q / self.student_temp, dim=-1) 296 | 297 | avg_probs = torch.mean(probs, dim=0) 298 | if is_distributed_training_run(): 299 | dist.all_reduce(avg_probs) 300 | avg_probs = avg_probs / dist.get_world_size() 301 | # avg_probs = AllReduce.apply(torch.mean(probs, dim=0)) 302 | rloss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs))) 303 | 304 | loss = ce_loss + 1.0 * rloss 305 | 306 | return loss 307 | 308 | def assign_loss(self, online_1, online_2, target_1, target_2): 309 | z_o1, obj_o1, res_o1 = online_1 310 | z_o2, obj_o2, res_o2 = online_2 311 | z_t1, obj_t1, res_t1 = target_1 312 | z_t2, obj_t2, res_t2 = target_2 313 | 314 | loss = 0.5 * (self.self_distill(res_o1["dots"].permute(0, 2, 3, 1).flatten(0, 2), 315 | res_t1["dots"].permute(0, 2, 3, 1).flatten(0, 2)) 316 | + self.self_distill(res_o2["dots"].permute(0, 2, 3, 1).flatten(0, 2), 317 | res_t2["dots"].permute(0, 2, 3, 1).flatten(0, 2))) 318 | score_k1 = res_t1["dots"] 319 | score_k2 = res_t2["dots"] 320 | score = torch.cat([score_k1, score_k2]).permute(0, 2, 3, 1).flatten(0, 2) 321 | 322 | return loss, score 323 | 324 | def compute_unigrad_loss(self, pred, target, idxs=None): 325 | pred = F.normalize(pred, dim=-1) 326 | target = F.normalize(target, dim=-1) 327 | 328 | dense_pred = pred.reshape(-1, pred.shape[-1]) 329 | dense_target = target.reshape(-1, target.shape[-1]) 330 | 331 | # compute pos term 332 | if idxs is not None: 333 | pos_term = self.mse_loss(dense_pred[idxs], dense_target[idxs]) 334 | else: 335 | pos_term = self.mse_loss(dense_pred, dense_target) 336 | 337 | # compute neg term 338 | mask = torch.eye(pred.shape[1], device=pred.device).unsqueeze(0).repeat(pred.size(0), 1, 1) 339 | correlation = torch.matmul(pred, target.transpose(2, 1)) # b,c,c 340 | correlation = correlation * (1.0 - mask) 341 | neg_term = ((correlation**2).sum(-1) / target.shape[1]).reshape(-1) 342 | 343 | if idxs is not None: 344 | neg_term = torch.mean(neg_term[idxs]) 345 | else: 346 | neg_term = torch.mean(neg_term) 347 | 348 | # # correlation = (dense_target.T @ dense_target) / dense_target.shape[0] 349 | # correlation = torch.matmul(target.transpose(2, 1), target) / target.shape[1] # b,c,c 350 | # # if is_distributed_training_run(): 351 | # # torch.distributed.all_reduce(correlation) 352 | # # correlation = correlation / torch.distributed.get_world_size() 353 | # 354 | # # neg_term = torch.diagonal(dense_pred @ correlation @ dense_pred.T).mean() 355 | # neg_term = torch.matmul(torch.matmul(pred, correlation), pred.transpose(2, 1)) 356 | # neg_term = torch.diagonal(neg_term, dim1=-2, dim2=-1).mean() 357 | 358 | loss = pos_term + self.loss_w_obj * neg_term 359 | 360 | return loss 361 | 362 | def loss_func(self, online, target): 363 | z_o, obj_o, res_o = online 364 | z_t, obj_t, res_t = target 365 | 366 | # instance-level loss 367 | z_o_pred = self.predictor(z_o).squeeze(-1) 368 | z_t = z_t.squeeze(-1) 369 | loss_inst = self.mse_loss(z_o_pred, z_t) 370 | 371 | # object-level loss 372 | b, c, n = obj_o.shape 373 | obj_o_pred = self.predictor_obj(obj_o).transpose(2, 1) 374 | obj_t = obj_t.transpose(2, 1) 375 | 376 | score_q = res_o["dots"] 377 | score_k = res_t["dots"] 378 | mask_q = (torch.zeros_like(score_q).scatter_(1, score_q.argmax(1, keepdim=True), 1).sum(-1).sum( 379 | -1) > 0).long().detach() 380 | mask_k = (torch.zeros_like(score_k).scatter_(1, score_k.argmax(1, keepdim=True), 1).sum(-1).sum( 381 | -1) > 0).long().detach() 382 | mask_intersection = (mask_q * 
mask_k).view(-1) 383 | idxs_q = mask_intersection.nonzero().squeeze(-1) 384 | 385 | # loss_obj = self.mse_loss(obj_o_pred.reshape(b*n, c)[idxs_q], obj_t.reshape(b*n, c)[idxs_q]) 386 | # loss_obj = self.compute_unigrad_loss(obj_o_pred, obj_t, idxs_q) 387 | loss_obj = self.compute_unigrad_loss(obj_o_pred, obj_t) 388 | 389 | loss_base = loss_inst * self.loss_weight + loss_obj * (1 - self.loss_weight) 390 | 391 | # sum 392 | return loss_base, loss_inst, loss_obj 393 | 394 | @torch.no_grad() 395 | def momentum_update(self, cur_iter, max_iter): 396 | """ 397 | Momentum update of the target network. 398 | """ 399 | # momentum anneling 400 | momentum = 1. - (1. - self.base_m) * (cos(pi * cur_iter / float(max_iter)) + 1) / 2.0 401 | self.curr_m = momentum 402 | # parameter update for target network 403 | state_dict_ol = self.online_net.state_dict() 404 | state_dict_tgt = self.target_net.state_dict() 405 | for (k_ol, v_ol), (k_tgt, v_tgt) in zip(state_dict_ol.items(), state_dict_tgt.items()): 406 | assert k_tgt == k_ol, "state_dict names are different!" 407 | assert v_ol.shape == v_tgt.shape, "state_dict shapes are different!" 408 | if 'num_batches_tracked' in k_tgt: 409 | v_tgt.copy_(v_ol) 410 | else: 411 | v_tgt.copy_(v_tgt * momentum + (1. - momentum) * v_ol) 412 | 413 | @torch.no_grad() 414 | def update_center(self, teacher_output): 415 | """ 416 | Update center used for teacher output. 417 | """ 418 | batch_center = torch.mean(teacher_output, dim=0, keepdim=True) 419 | if is_distributed_training_run(): 420 | dist.all_reduce(batch_center) 421 | batch_center = batch_center / dist.get_world_size() 422 | 423 | # ema update 424 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) 425 | 426 | def forward(self, im_v1, im_v2=None, **kwargs): 427 | """ 428 | Input: 429 | im_v1: a batch of view1 images 430 | im_v2: a batch of view2 images 431 | Output: 432 | loss 433 | """ 434 | # for inference, online_net.backbone model only 435 | if im_v2 is None: 436 | feats = self.online_net.backbone(im_v1) 437 | return F.adaptive_avg_pool2d(feats, 1).flatten(1) 438 | 439 | # compute online_net features 440 | proj_online_v1 = self.online_net(im_v1) 441 | proj_online_v2 = self.online_net(im_v2) 442 | 443 | # compute target_net features 444 | with torch.no_grad(): # no gradient to keys 445 | proj_target_v1 = [x.clone().detach() if isinstance(x, torch.Tensor) else x for x in self.target_net(im_v1)] 446 | proj_target_v2 = [x.clone().detach() if isinstance(x, torch.Tensor) else x for x in self.target_net(im_v2)] 447 | 448 | # loss. 
NOTE: the prediction is moved to loss_func
449 |         loss_base1, loss_inst1, loss_obj1 = self.loss_func(proj_online_v1, proj_target_v2)
450 |         loss_base2, loss_inst2, loss_obj2 = self.loss_func(proj_online_v2, proj_target_v1)
451 |         loss_base = loss_base1 + loss_base2
452 | 
453 |         loss_cluster, score = self.assign_loss(proj_online_v1, proj_online_v2, proj_target_v1, proj_target_v2)
454 |         loss = loss_base + loss_cluster * self.loss_w_cluster
455 | 
456 |         loss_pack = {}
457 |         loss_pack["base"] = loss_base
458 |         loss_pack["inst"] = (loss_inst1 + loss_inst2) * self.loss_weight
459 |         loss_pack["obj"] = (loss_obj1 + loss_obj2) * (1 - self.loss_weight)
460 |         loss_pack["clu"] = loss_cluster
461 | 
462 |         # self.update_center(score)
463 | 
464 |         return loss, loss_pack
465 | 
466 |     def extra_repr(self) -> str:
467 |         parts = []
468 |         for name in ["loss_weight", "mask_type", "num_proto", "teacher_temp", "loss_w_obj", "loss_w_cluster"]:
469 |             parts.append(f"{name}={getattr(self, name)}")
470 |         return ", ".join(parts)
471 | 
472 | 
473 | class FRAB(FRAB_EMAN):
474 |     @torch.no_grad()
475 |     def momentum_update(self, cur_iter, max_iter):
476 |         """
477 |         Momentum update of the target network.
478 |         """
479 |         # momentum annealing
480 |         momentum = 1. - (1. - self.base_m) * (cos(pi * cur_iter / float(max_iter)) + 1) / 2.0
481 |         self.curr_m = momentum
482 |         # parameter update for target network
483 |         for param_ol, param_tgt in zip(self.online_net.parameters(), self.target_net.parameters()):
484 |             param_tgt.data = param_tgt.data * momentum + param_ol.data * (1. - momentum)
485 | 
486 | 
487 | if __name__ == '__main__':
488 |     from models import get_model
489 |     import backbone as backbone_models
490 | 
491 |     checkpoint = torch.load("./checkpoints/flr_r50_vgg_face.pth", map_location="cpu")
492 |     state_dict = checkpoint['state_dict'] if "state_dict" in checkpoint else checkpoint
493 | 
494 |     model_func = get_model("FRAB")
495 |     norm_layer = None
496 |     model = model_func(
497 |         backbone_models.__dict__["resnet50_encoder"],
498 |         dim=256,
499 |         m=0.996,
500 |         hid_dim=4096,
501 |         norm_layer=norm_layer,
502 |         num_neck_mlp=2,
503 |         scale=1.,
504 |         l2_norm=True,
505 |         num_heads=4,
506 |         loss_weight=0.5,
507 |         mask_type="attn",
508 |         num_proto=8,
509 |         teacher_temp=0.04,
510 |     )
511 |     print(model)
512 | 
513 |     x1 = torch.randn(16, 3, 224, 224)
514 |     x2 = torch.randn(16, 3, 224, 224)
515 |     out = model(x1, x2)
516 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # some code in this file is adapted from
2 | # https://github.com/pytorch/examples
3 | # Original Copyright 2017. Licensed under the BSD 3-Clause License.
4 | # Modifications Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved.
5 | # SPDX-License-Identifier: CC-BY-NC-4.0 6 | 7 | import argparse 8 | import builtins 9 | from logging import root 10 | import os 11 | import time 12 | 13 | import torch 14 | import torch.nn.parallel 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | import torch.optim 19 | import torch.utils.data 20 | import torch.utils.data.distributed 21 | import torchvision 22 | import torchvision.transforms as transforms 23 | from classy_vision.generic.distributed_util import is_distributed_training_run 24 | 25 | import backbone as backbone_models 26 | from models import get_model 27 | from utils import utils, lr_schedule, LARS, get_norm, init_distributed_mode 28 | import data.transforms as data_transforms 29 | from engine import ss_validate, ss_face_validate 30 | from data.base_dataset import get_dataset 31 | 32 | backbone_model_names = sorted(name for name in backbone_models.__dict__ 33 | if name.islower() and not name.startswith("__") 34 | and callable(backbone_models.__dict__[name])) 35 | 36 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 37 | parser.add_argument('--dataset', default="in1k", 38 | help='name of dataset', choices=['in1k', 'in100', 'im_folder', 'in1k_idx', "vggface2"]) 39 | parser.add_argument('--data-root', default="", 40 | help='root of dataset folder') 41 | parser.add_argument('--arch', metavar='ARCH', default='LEWEL', 42 | help='model architecture') 43 | parser.add_argument('--backbone', default='resnet50_encoder', 44 | choices=backbone_model_names, 45 | help='model architecture: ' + 46 | ' | '.join(backbone_model_names) + 47 | ' (default: resnet50_encoder)') 48 | parser.add_argument('-j', '--workers', default=64, type=int, metavar='N', 49 | help='number of data loading workers (default: 64)') 50 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 51 | help='number of total epochs to run') 52 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 53 | help='manual epoch number (useful on restarts)') 54 | parser.add_argument('--warmup-epoch', default=0, type=int, metavar='N', 55 | help='number of epochs for learning warmup') 56 | parser.add_argument('-b', '--batch-size', default=256, type=int, 57 | metavar='N', 58 | help='mini-batch size (default: 256), this is the total ' 59 | 'batch size of all GPUs on the current node when ' 60 | 'using Data Parallel or Distributed Data Parallel') 61 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float, 62 | metavar='LR', help='initial learning rate', dest='lr') 63 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int, 64 | help='learning rate schedule (when to drop lr by 10x)') 65 | parser.add_argument('--cos', action='store_true', help='use cosine lr schedule') 66 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 67 | help='momentum of SGD solver') 68 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 69 | metavar='W', help='weight decay (default: 1e-4)', 70 | dest='weight_decay') 71 | parser.add_argument('--save-dir', default="ckpts", 72 | help='checkpoint directory') 73 | parser.add_argument('-p', '--print-freq', default=50, type=int, 74 | metavar='N', help='print frequency (default: 10)') 75 | parser.add_argument('--save-freq', default=10, type=int, 76 | metavar='N', help='checkpoint save frequency (default: 10)') 77 | parser.add_argument('--eval-freq', default=5, type=int, 78 | metavar='N', help='evaluation epoch 
frequency (default: 5)') 79 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 80 | help='path to latest checkpoint (default: none)') 81 | parser.add_argument('--pretrained', default='', type=str, metavar='PATH', 82 | help='path to pretrained model (default: none)') 83 | parser.add_argument('--super-pretrained', default='', type=str, metavar='PATH', 84 | help='path to MoCo pretrained model (default: none)') 85 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 86 | help='evaluate model on validation set') 87 | parser.add_argument('--seed', default=23456, type=int, 88 | help='seed for initializing training. ') 89 | 90 | # dist 91 | parser.add_argument('--world_size', default=-1, type=int, help='number of nodes for distributed training') 92 | parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training') 93 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') 94 | parser.add_argument('--dist_backend', default='nccl', type=str, help='distributed backend') 95 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 96 | distributed training; """) 97 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 98 | parser.add_argument('--multiprocessing_distributed', action='store_true', 99 | help='Use multi-processing distributed training to launch ' 100 | 'N processes per node, which has N GPUs. This is the ' 101 | 'fastest way to use PyTorch for either single node or ' 102 | 'multi node data parallel training') 103 | 104 | # ssl specific configs: 105 | parser.add_argument('--proj-dim', default=256, type=int, 106 | help='feature dimension (default: 256)') 107 | parser.add_argument('--enc-m', default=0.996, type=float, 108 | help='momentum of updating key encoder (default: 0.996)') 109 | parser.add_argument('--norm', default='None', type=str, 110 | help='the normalization for network (default: None)') 111 | parser.add_argument('--num-neck-mlp', default=2, type=int, 112 | help='number of neck mlp (default: 2)') 113 | parser.add_argument('--hid-dim', default=4096, type=int, 114 | help='hidden dimension of mlp (default: 4096)') 115 | parser.add_argument('--amp', action='store_true', 116 | help='use automatic mixed precision training') 117 | 118 | # options for LEWEL 119 | parser.add_argument('--lewel-l2-norm', action='store_true', 120 | help='use l2-norm before applying softmax on attention map') 121 | parser.add_argument('--lewel-scale', default=1., type=float, 122 | help='Scale factor of attention map (default: 1.)') 123 | parser.add_argument('--lewel-num-heads', default=8, type=int, 124 | help='Number of heads in lewel (default: 8)') 125 | parser.add_argument('--lewel-loss-weight', default=0.5, type=float, 126 | help='loss weight for aligned branch (default: 0.5)') 127 | 128 | parser.add_argument('--train-percent', default=1.0, type=float, help='percentage of training set') 129 | parser.add_argument('--mask_type', default="group", type=str, help='type of masks') 130 | parser.add_argument('--num_proto', default=64, type=int, 131 | help='Number of heatmaps') 132 | parser.add_argument('--teacher_temp', default=0.07, type=float, 133 | help='temperature of the teacher') 134 | parser.add_argument('--loss_w_cluster', default=0.5, type=float, 135 | help='loss weight for cluster assignments (default: 0.5)') 136 | 137 | 138 | # options for KNN search 139 | parser.add_argument('--num-nn', default=20, type=int, 140 | 
                    help='Number of nearest neighbors (default: 20)')
141 | parser.add_argument('--nn-mem-percent', type=float, default=0.1,
142 |                     help='fraction of the validation set used as the KNN memory bank')
143 | parser.add_argument('--nn-query-percent', type=float, default=0.5,
144 |                     help='fraction of the validation set used as KNN query data')
145 | 
146 | 
147 | best_acc1 = 0
148 | 
149 | 
150 | def main(args):
151 |     global best_acc1
152 |     # args.gpu = args.local_rank
153 | 
154 |     # create model
155 |     print("=> creating model '{}' with backbone '{}'".format(args.arch, args.backbone))
156 |     model_func = get_model(args.arch)
157 |     norm_layer = get_norm(args.norm)
158 |     model = model_func(
159 |         backbone_models.__dict__[args.backbone],
160 |         dim=args.proj_dim,
161 |         m=args.enc_m,
162 |         hid_dim=args.hid_dim,
163 |         norm_layer=norm_layer,
164 |         num_neck_mlp=args.num_neck_mlp,
165 |         scale=args.lewel_scale,
166 |         l2_norm=args.lewel_l2_norm,
167 |         num_heads=args.lewel_num_heads,
168 |         loss_weight=args.lewel_loss_weight,
169 |         mask_type=args.mask_type,
170 |         num_proto=args.num_proto,
171 |         teacher_temp=args.teacher_temp,
172 |         loss_w_cluster=args.loss_w_cluster
173 |     )
174 |     print(model)
175 |     print(args)
176 | 
177 |     if args.pretrained:
178 |         if os.path.isfile(args.pretrained):
179 |             print("=> loading pretrained model from '{}'".format(args.pretrained))
180 |             state_dict = torch.load(args.pretrained, map_location="cpu")['state_dict']
181 |             # rename state_dict keys
182 |             for k in list(state_dict.keys()):
183 |                 new_key = k.replace("module.", "")
184 |                 state_dict[new_key] = state_dict[k]
185 |                 del state_dict[k]
186 |             msg = model.load_state_dict(state_dict, strict=False)
187 |             print("=> loaded pretrained model from '{}'".format(args.pretrained))
188 |             if len(msg.missing_keys) > 0:
189 |                 print("missing keys: {}".format(msg.missing_keys))
190 |             if len(msg.unexpected_keys) > 0:
191 |                 print("unexpected keys: {}".format(msg.unexpected_keys))
192 |         else:
193 |             print("=> no pretrained model found at '{}'".format(args.pretrained))
194 | 
195 | 
196 |     model.cuda()
197 |     args.batch_size = int(args.batch_size / args.world_size)
198 |     args.workers = int((args.workers + args.world_size - 1) / args.world_size)
199 |     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
200 | 
201 |     # define optimizer
202 |     # args.lr = args.batch_size * args.world_size / 1024 * args.lr
203 |     if args.dataset == 'in100':
204 |         args.lr *= 2
205 | 
206 |     # params = collect_params(model, exclude_bias_and_bn=True, sync_bn='EMAN' in args.arch)
207 |     params = collect_params(model, exclude_bias_and_bn=True)
208 |     optimizer = LARS(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
209 |     scaler = torch.cuda.amp.GradScaler() if args.amp else None
210 | 
211 |     # optionally resume from a checkpoint
212 |     if args.resume:
213 |         if os.path.isfile(args.resume):
214 |             print("=> loading checkpoint '{}'".format(args.resume))
215 |             if args.gpu is None:
216 |                 checkpoint = torch.load(args.resume)
217 |             else:
218 |                 # Map model to be loaded to specified single gpu.
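                # (map_location below loads the checkpoint tensors onto this process's
                # GPU instead of the device they were originally saved from)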
219 | loc = 'cuda:{}'.format(args.gpu) 220 | checkpoint = torch.load(args.resume, map_location=loc) 221 | args.start_epoch = checkpoint['epoch'] 222 | if 'best_acc1' in checkpoint: 223 | best_acc1 = checkpoint['best_acc1'] 224 | model.load_state_dict(checkpoint['state_dict']) 225 | optimizer.load_state_dict(checkpoint['optimizer']) 226 | if 'scaler' in checkpoint: 227 | scaler.load_state_dict(checkpoint['scaler']) 228 | else: 229 | print("no scaler checkpoint") 230 | print("=> loaded checkpoint '{}' (epoch {})" 231 | .format(args.resume, checkpoint['epoch'])) 232 | else: 233 | print("=> no checkpoint found at '{}'".format(args.resume)) 234 | 235 | cudnn.benchmark = True 236 | 237 | # Data loading code 238 | if args.dataset.lower() == "vggface2": 239 | transform1, transform2 = data_transforms.get_vggface_tranforms(image_size=224) 240 | val_split = "test" 241 | else: 242 | transform1, transform2 = data_transforms.get_byol_tranforms() 243 | val_split = "val" 244 | 245 | train_dataset = get_dataset( 246 | args.dataset, 247 | mode='train', 248 | transform=data_transforms.TwoCropsTransform(transform1, transform2), 249 | data_root=args.data_root) 250 | print("train_dataset:\n{}".format(train_dataset)) 251 | 252 | if args.train_percent < 1.0: 253 | num_subset = int(len(train_dataset) * args.train_percent) 254 | indices = torch.randperm(len(train_dataset))[:num_subset] 255 | indices = indices.tolist() 256 | train_dataset = torch.utils.data.Subset(train_dataset, indices) 257 | print("Sub train_dataset:\n{}".format(len(train_dataset))) 258 | 259 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 260 | train_loader = torch.utils.data.DataLoader( 261 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 262 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True, 263 | persistent_workers=True) 264 | 265 | if args.dataset.lower() == "vggface2": 266 | normalize = transforms.Normalize(mean=data_transforms.IMG_MEAN["vggface2"], 267 | std=data_transforms.IMG_STD["vggface2"]) 268 | transform_test = transforms.Compose([ 269 | transforms.Resize((224, 224)), 270 | # transforms.CenterCrop(args.image_size), 271 | transforms.ToTensor(), 272 | normalize, 273 | ]) 274 | val_dataset = torchvision.datasets.LFWPairs(root="../data/lfw", split="test", 275 | transform=transform_test, download=True) 276 | val_loader = torch.utils.data.DataLoader( 277 | val_dataset, 278 | batch_size=args.batch_size, shuffle=False, 279 | num_workers=args.workers//2, pin_memory=True, 280 | persistent_workers=True) 281 | 282 | else: 283 | val_loader_base = torch.utils.data.DataLoader( 284 | get_dataset( 285 | args.dataset, 286 | mode=val_split, 287 | transform=data_transforms.get_transforms("DefaultVal", args.dataset), 288 | data_root=args.data_root, 289 | percent=args.nn_mem_percent 290 | ), 291 | batch_size=args.batch_size, shuffle=False, 292 | num_workers=args.workers//2, pin_memory=True, 293 | persistent_workers=True) 294 | 295 | val_loader_query = torch.utils.data.DataLoader( 296 | get_dataset( 297 | args.dataset, 298 | mode=val_split, 299 | transform=data_transforms.get_transforms("DefaultVal", args.dataset), 300 | data_root=args.data_root, 301 | percent=args.nn_query_percent, 302 | ), 303 | batch_size=args.batch_size, shuffle=False, 304 | num_workers=args.workers//2, pin_memory=True, 305 | persistent_workers=True) 306 | 307 | if args.evaluate: 308 | # ss_validate(val_loader_base, val_loader_query, model, args) 309 | ss_face_validate(val_loader, model, args) 
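        # evaluation-only run: validate once and skip the training loop below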
310 | return 311 | 312 | best_epoch = args.start_epoch 313 | for epoch in range(args.start_epoch, args.epochs): 314 | train_sampler.set_epoch(epoch) 315 | if epoch >= args.warmup_epoch: 316 | lr_schedule.adjust_learning_rate(optimizer, epoch, args) 317 | 318 | # train for one epoch 319 | train(train_loader, model, optimizer, scaler, epoch, args) 320 | 321 | is_best = False 322 | if (epoch + 1) % args.eval_freq == 0: 323 | # acc1 = ss_validate(val_loader_base, val_loader_query, model, args) 324 | acc1 = ss_face_validate(val_loader, model, args) 325 | # remember best acc@1 and save checkpoint 326 | is_best = acc1 > best_acc1 327 | best_acc1 = max(acc1, best_acc1) 328 | if is_best: 329 | best_epoch = epoch 330 | 331 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 332 | and args.local_rank % args.world_size == 0): 333 | utils.save_checkpoint({ 334 | 'epoch': epoch + 1, 335 | 'arch': args.arch, 336 | 'state_dict': model.state_dict(), 337 | 'best_acc1': best_acc1, 338 | 'optimizer': optimizer.state_dict(), 339 | 'scaler': None if scaler is None else scaler.state_dict(), 340 | }, is_best=is_best, epoch=epoch, args=args) 341 | 342 | print('Best Acc@1 {0} @ epoch {1}'.format(best_acc1, best_epoch + 1)) 343 | 344 | 345 | def train(train_loader, model, optimizer, scaler, epoch, args): 346 | batch_time = utils.AverageMeter('Time', ':6.3f') 347 | data_time = utils.AverageMeter('Data', ':6.3f') 348 | losses = utils.AverageMeter('Loss', ':.4e') 349 | losses_base = utils.AverageMeter('Loss_base', ':.4e') 350 | losses_inst = utils.AverageMeter('Loss_inst', ':.4e') 351 | losses_obj = utils.AverageMeter('Loss_obj', ':.4e') 352 | losses_clu = utils.AverageMeter('Loss_clu', ':.4e') 353 | curr_lr = utils.InstantMeter('LR', ':.7f') 354 | curr_mom = utils.InstantMeter('MOM', ':.7f') 355 | progress = utils.ProgressMeter( 356 | len(train_loader), 357 | [curr_lr, curr_mom, batch_time, data_time, losses, losses_base, losses_inst, losses_obj, losses_clu], 358 | prefix="Epoch: [{}/{}]\t".format(epoch, args.epochs)) 359 | 360 | # iter info 361 | batch_iter = len(train_loader) 362 | max_iter = float(batch_iter * args.epochs) 363 | 364 | # switch to train mode 365 | model.train() 366 | if "EMAN" in args.arch: 367 | print("setting the key model to eval mode when using EMAN") 368 | if hasattr(model, 'module'): 369 | model.module.target_net.eval() 370 | else: 371 | model.target_net.eval() 372 | 373 | end = time.time() 374 | for i, (images, _, idx) in enumerate(train_loader): 375 | # update model momentum 376 | curr_iter = float(epoch * batch_iter + i) 377 | 378 | # measure data loading time 379 | data_time.update(time.time() - end) 380 | 381 | if args.gpu is not None: 382 | images[0] = images[0].cuda(args.gpu, non_blocking=True) 383 | images[1] = images[1].cuda(args.gpu, non_blocking=True) 384 | idx = idx.cuda(args.gpu, non_blocking=True) 385 | 386 | # warmup learning rate 387 | if epoch < args.warmup_epoch: 388 | warmup_step = args.warmup_epoch * batch_iter 389 | curr_step = epoch * batch_iter + i + 1 390 | lr_schedule.warmup_learning_rate(optimizer, curr_step, warmup_step, args) 391 | curr_lr.update(optimizer.param_groups[0]['lr']) 392 | 393 | if scaler is None: 394 | # compute loss 395 | loss, loss_pack = model(im_v1=images[0], im_v2=images[1], idx=idx) 396 | 397 | # compute gradient and do SGD step 398 | optimizer.zero_grad() 399 | loss.backward() 400 | optimizer.step() 401 | else: # AMP 402 | optimizer.zero_grad() 403 | with torch.cuda.amp.autocast(): 404 | loss, loss_pack = 
model(im_v1=images[0], im_v2=images[1], idx=idx) 405 | 406 | scaler.scale(loss).backward() 407 | scaler.step(optimizer) 408 | scaler.update() 409 | 410 | # measure accuracy and record loss 411 | losses.update(loss.item(), images[0].size(0)) 412 | losses_base.update(loss_pack["base"].item(), images[0].size(0)) 413 | losses_inst.update(loss_pack["inst"].item(), images[0].size(0)) 414 | losses_obj.update(loss_pack["obj"].item(), images[0].size(0)) 415 | losses_clu.update(loss_pack["clu"].item(), images[0].size(0)) 416 | 417 | if hasattr(model, 'module'): 418 | model.module.momentum_update(curr_iter, max_iter) 419 | curr_mom.update(model.module.curr_m) 420 | else: 421 | model.momentum_update(curr_iter, max_iter) 422 | curr_mom.update(model.curr_m) 423 | 424 | # measure elapsed time 425 | batch_time.update(time.time() - end) 426 | end = time.time() 427 | 428 | if i % args.print_freq == 0: 429 | progress.display(i) 430 | 431 | 432 | def collect_params(model, exclude_bias_and_bn=True, sync_bn=True): 433 | """ 434 | exclude_bias_and bn: exclude bias and bn from both weight decay and LARS adaptation 435 | in the PyTorch implementation of ResNet, `downsample.1` are bn layers 436 | """ 437 | weight_param_list, bn_and_bias_param_list = [], [] 438 | weight_param_names, bn_and_bias_param_names = [], [] 439 | for name, param in model.named_parameters(): 440 | if exclude_bias_and_bn and ('bn' in name or 'downsample.1' in name or 'bias' in name): 441 | bn_and_bias_param_list.append(param) 442 | bn_and_bias_param_names.append(name) 443 | else: 444 | weight_param_list.append(param) 445 | weight_param_names.append(name) 446 | print("weight params:\n{}".format('\n'.join(weight_param_names))) 447 | print("bn and bias params:\n{}".format('\n'.join(bn_and_bias_param_names))) 448 | param_list = [{'params': bn_and_bias_param_list, 'weight_decay': 0., 'lars_exclude': True}, 449 | {'params': weight_param_list}] 450 | return param_list 451 | 452 | 453 | if __name__ == '__main__': 454 | opt = parser.parse_args() 455 | opt.distributed = True 456 | opt.multiprocessing_distributed = True 457 | 458 | # _, opt.local_rank, opt.world_size = dist_init(opt.port) 459 | # cudnn.benchmark = True 460 | # 461 | # # suppress printing if not master 462 | # if dist.get_rank() != 0: 463 | # def print_pass(*args, **kwargs): 464 | # pass 465 | # builtins.print = print_pass 466 | 467 | init_distributed_mode(opt) 468 | 469 | main(opt) 470 | --------------------------------------------------------------------------------
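# Illustrative sketch (assumed standalone helper, not a file of this repository):
# the cosine momentum-annealing schedule implemented by momentum_update above.
# The EMA momentum starts at base_m (--enc-m, default 0.996) and is annealed to 1.0
# by the end of training, so the target network is updated ever more slowly.
from math import cos, pi


def ema_momentum(cur_iter, max_iter, base_m=0.996):
    """Same formula as FRAB_EMAN.momentum_update / FRAB.momentum_update."""
    return 1. - (1. - base_m) * (cos(pi * cur_iter / float(max_iter)) + 1) / 2.0


if __name__ == "__main__":
    max_iter = 1000.
    for cur_iter in (0., 250., 500., 750., 1000.):
        # prints 0.996 at the start, 0.998 at mid-training and 1.0 at the end
        print(cur_iter, round(ema_momentum(cur_iter, max_iter), 6))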