├── data
│   ├── __init__.py
│   ├── cvt_idx.py
│   ├── rafdb.py
│   ├── affectnet.py
│   ├── randaugment.py
│   ├── fer2013.py
│   ├── transforms.py
│   ├── base_dataset.py
│   └── celeba.py
├── models
│   ├── transformers
│   │   ├── __init__.py
│   │   ├── position_encoding.py
│   │   ├── transformer_predictor.py
│   │   └── transformer.py
│   ├── __init__.py
│   ├── lewel.py
│   └── fra.py
├── docs
│   └── face-framework.png
├── requirements.txt
├── backbone
│   ├── __init__.py
│   └── resnet.py
├── utils
│   ├── __init__.py
│   ├── batch_norm.py
│   ├── extract_backbone.py
│   ├── lr_schedule.py
│   ├── init.py
│   ├── LARS.py
│   ├── utils.py
│   └── dist_utils.py
├── .gitignore
├── launch.py
├── README.md
├── engine.py
├── LICENSE
└── main.py
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/transformers/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/face-framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zaczgao/Facial_Region_Awareness/HEAD/docs/face-framework.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.10.2+cu111
2 | torchvision==0.11.3+cu111
3 | tensorboard
4 | classy_vision
5 | pandas
--------------------------------------------------------------------------------
/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0
3 |
4 | from .resnet import *
5 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0
3 |
4 | from .batch_norm import get_norm
5 | from .LARS import LARS
6 | from .dist_utils import init_distributed_mode
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0
3 |
4 | from .lewel import LEWELB, LEWELB_EMAN
5 | from .fra import FRAB, FRAB_EMAN
6 |
7 |
8 |
9 | def get_model(model):
10 | """
11 | Args:
12 |         model (str or callable): model name, one of "LEWELB", "LEWELB_EMAN", "FRAB" or "FRAB_EMAN", or the model class itself
13 | 
14 |     Returns:
15 |         callable: the model class
16 | """
17 | if isinstance(model, str):
18 | model = {
19 | "LEWELB": LEWELB,
20 | "LEWELB_EMAN": LEWELB_EMAN,
21 | "FRAB": FRAB,
22 | "FRAB_EMAN": FRAB_EMAN,
23 | }[model]
24 | return model
--------------------------------------------------------------------------------
/utils/batch_norm.py:
--------------------------------------------------------------------------------
1 | # Original copyright Amazon.com, Inc. or its affiliates, under CC-BY-NC-4.0 License.
2 | # Modifications Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved.
3 | # SPDX-License-Identifier: CC-BY-NC-4.0
4 |
5 | from torch import nn
6 |
7 |
8 | def get_norm(norm):
9 | """
10 | Args:
11 |         norm (str or callable): one of "BN", "BN1d", "SyncBN", "GN", "IN" or "None", or a callable mapping a channel count to a normalization layer
12 |
13 | Returns:
14 | nn.Module or None: the normalization layer
15 | """
16 | if isinstance(norm, str):
17 | if len(norm) == 0:
18 | return None
19 | norm = {
20 | "BN": nn.BatchNorm2d,
21 | "BN1d": nn.BatchNorm1d,
22 | "SyncBN": nn.SyncBatchNorm,
23 | "GN": lambda channels: nn.GroupNorm(32, channels),
24 | "IN": lambda channels: nn.InstanceNorm2d(channels, affine=True),
25 | "None": None,
26 | }[norm]
27 | return norm
28 |
--------------------------------------------------------------------------------
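A minimal usage sketch (editor's note, not part of the repository): `get_norm` returns either a layer class or a small factory, so the result is called with a channel count, while an empty string yields no normalization layer.

```python
# Editor's sketch of how get_norm is typically consumed (the channel count 256 is arbitrary).
from torch import nn
from utils import get_norm

gn = get_norm("GN")(256)     # GroupNorm(32, 256)
bn = get_norm("BN")(256)     # BatchNorm2d(256)
assert isinstance(gn, nn.GroupNorm) and isinstance(bn, nn.BatchNorm2d)
assert get_norm("") is None  # empty string disables normalization
```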
/utils/extract_backbone.py:
--------------------------------------------------------------------------------
1 | # Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0
3 |
4 | import sys
5 | import torch
6 |
7 | if __name__ == "__main__":
8 | input = sys.argv[1]
9 |
10 | obj = torch.load(input, map_location="cpu")
11 | print("Loading {} (epoch {})".format(input, obj['epoch']))
12 | obj = obj["state_dict"]
13 |
14 | newmodel = {}
15 | for k, v in obj.items():
16 | if not (k.startswith("module.encoder_q.backbone") or k.startswith("module.online_net.backbone")) or 'fc' in k:
17 | continue
18 | old_k = k
19 | k = k.replace("backbone.", "")
20 | k = k.replace("module.encoder_q.", "")
21 | k = k.replace("module.online_net.", "")
22 | print(old_k, "->", k)
23 | newmodel[k] = v
24 |
25 | with open(sys.argv[2], "wb") as f:
26 | torch.save(newmodel, f, _use_new_zipfile_serialization=False)
27 |
--------------------------------------------------------------------------------
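As a usage note (editor's addition): `extract_backbone.py` takes the checkpoint path and the output path from the command line, e.g. `python3 utils/extract_backbone.py ckpts/checkpoint.pth.tar backbone.pth`, where both file names are purely illustrative; it keeps only the online/query backbone weights (excluding `fc`) and strips their module prefixes.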
/data/cvt_idx.py:
--------------------------------------------------------------------------------
1 | # Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0
3 |
4 | import os
5 |
6 | if __name__ == "__main__":
7 | train_file = "/mnt/lustre/share/data/images/meta/train.txt"
8 | idx_file = "data/10percent.txt"
9 | out_file = idx_file + ".ext"
10 | max_class = 1000
11 |
12 | with open(idx_file, "r") as fin, open(train_file, "r") as f_train:
13 | all_samples = {}
14 | idx_samples = []
15 | selected_samples = []
16 | for line in f_train.readlines():
17 | name, label = line.strip().split()
18 | label = int(label)
19 | if label < max_class:
20 | base_name = name.split("/")[1]
21 | all_samples[base_name] = (label, name)
22 | print(f"len of all samples: {len(all_samples)}")
23 |
24 | for line in fin.readlines():
25 | nm = line.strip()
26 | selected_samples.append(all_samples[nm])
27 |
28 | print(f"Len of selected samples {len(selected_samples)}")
29 |
30 | with open(out_file, "w") as fout:
31 | for (lb, nm) in selected_samples:
32 | fout.write(f"{lb} {nm}\n")
33 |
--------------------------------------------------------------------------------
/utils/lr_schedule.py:
--------------------------------------------------------------------------------
1 | # Original copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0
3 |
4 | import math
5 |
6 |
7 | def warmup_learning_rate(optimizer, curr_step, warmup_step, args):
8 | """linearly warm up learning rate"""
9 | lr = args.lr
10 | scalar = float(curr_step) / float(max(1, warmup_step))
11 | scalar = min(1., max(0., scalar))
12 | lr *= scalar
13 | for param_group in optimizer.param_groups:
14 | param_group['lr'] = lr
15 |
16 |
17 | def adjust_learning_rate(optimizer, epoch, args):
18 | """Decay the learning rate based on schedule"""
19 | lr = args.lr
20 | if args.cos: # cosine lr schedule
21 | progress = float(epoch - args.warmup_epoch) / float(args.epochs - args.warmup_epoch)
22 | lr *= 0.5 * (1. + math.cos(math.pi * progress))
23 | else: # stepwise lr schedule
24 | for milestone in args.schedule:
25 | lr *= 0.1 if epoch >= milestone else 1.
26 | for param_group in optimizer.param_groups:
27 | param_group['lr'] = lr
28 |
29 |
30 | def adjust_learning_rate_with_min(optimizer, epoch, args):
31 | """Decay the learning rate based on schedule"""
32 | lr = args.lr
33 | if args.cos: # cosine lr schedule
34 | min_lr = args.cos_min_lr
35 | progress = float(epoch - args.warmup_epoch) / float(args.epochs - args.warmup_epoch)
36 | lr = min_lr + 0.5 * (lr - min_lr) * (1. + math.cos(math.pi * progress))
37 | else: # stepwise lr schedule
38 | for milestone in args.schedule:
39 | lr *= 0.1 if epoch >= milestone else 1.
40 | for param_group in optimizer.param_groups:
41 | param_group['lr'] = lr
42 |
--------------------------------------------------------------------------------
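A small worked example (editor's sketch) of the cosine branch of `adjust_learning_rate`, using the hyperparameters of the README pre-training command (`--lr 0.9 --epochs 50 --warmup-epoch 10 --cos`): at epoch 30 the progress is (30 - 10) / (50 - 10) = 0.5, so the learning rate becomes 0.9 * 0.5 * (1 + cos(0.5π)) ≈ 0.45.

```python
# Editor's sketch: drive adjust_learning_rate with a dummy optimizer and argparse-like args.
from types import SimpleNamespace

import torch
from utils.lr_schedule import adjust_learning_rate

args = SimpleNamespace(lr=0.9, cos=True, warmup_epoch=10, epochs=50, schedule=[])
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

adjust_learning_rate(optimizer, epoch=30, args=args)
print(optimizer.param_groups[0]["lr"])  # ~0.45 = 0.9 * 0.5 * (1 + cos(pi * 0.5))
```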
/utils/init.py:
--------------------------------------------------------------------------------
1 | # Original copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0
3 |
4 | import torch.nn as nn
5 |
6 |
7 | def c2_xavier_fill(module: nn.Module) -> None:
8 | """
9 | Initialize `module.weight` using the "XavierFill" implemented in Caffe2.
10 | Also initializes `module.bias` to 0.
11 |
12 | Args:
13 | module (torch.nn.Module): module to initialize.
14 | """
15 | # Caffe2 implementation of XavierFill in fact
16 | # corresponds to kaiming_uniform_ in PyTorch
17 | nn.init.kaiming_uniform_(module.weight, a=1) # pyre-ignore
18 | if module.bias is not None: # pyre-ignore
19 | nn.init.constant_(module.bias, 0)
20 |
21 |
22 | def c2_msra_fill(module: nn.Module) -> None:
23 | """
24 | Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
25 | Also initializes `module.bias` to 0.
26 |
27 | Args:
28 | module (torch.nn.Module): module to initialize.
29 | """
30 | # pyre-ignore
31 | nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
32 | if module.bias is not None: # pyre-ignore
33 | nn.init.constant_(module.bias, 0)
34 |
35 |
36 | def normal_init(module: nn.Module, std=0.01):
37 | nn.init.normal_(module.weight, std=std)
38 | if module.bias is not None:
39 | nn.init.constant_(module.bias, 0)
40 |
41 |
42 | def init_weights(module, init_linear='normal'):
43 | assert init_linear in ['normal', 'kaiming'], \
44 | "Undefined init_linear: {}".format(init_linear)
45 | for m in module.modules():
46 | if isinstance(m, nn.Linear):
47 | if init_linear == 'normal':
48 | normal_init(m, std=0.01)
49 | else:
50 | c2_msra_fill(m)
51 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
52 | if m.weight is not None:
53 | nn.init.constant_(m.weight, 1)
54 | if m.bias is not None:
55 | nn.init.constant_(m.bias, 0)
56 | elif isinstance(m, nn.Conv1d):
57 | c2_msra_fill(m)
58 |
--------------------------------------------------------------------------------
/models/transformers/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
3 | """
4 | Various positional encodings for the transformer.
5 | """
6 | import math
7 |
8 | import torch
9 | from torch import nn
10 |
11 |
12 | class PositionEmbeddingSine(nn.Module):
13 | """
14 | This is a more standard version of the position embedding, very similar to the one
15 | used by the Attention is all you need paper, generalized to work on images.
16 | """
17 |
18 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
19 | super().__init__()
20 | self.num_pos_feats = num_pos_feats
21 | self.temperature = temperature
22 | self.normalize = normalize
23 | if scale is not None and normalize is False:
24 | raise ValueError("normalize should be True if scale is passed")
25 | if scale is None:
26 | scale = 2 * math.pi
27 | self.scale = scale
28 |
29 | def forward(self, x, mask=None):
30 | if mask is None:
31 | mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
32 | not_mask = ~mask
33 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
34 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
35 | if self.normalize:
36 | eps = 1e-6
37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
39 |
40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
42 |
43 | pos_x = x_embed[:, :, :, None] / dim_t
44 | pos_y = y_embed[:, :, :, None] / dim_t
45 | pos_x = torch.stack(
46 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
47 | ).flatten(3)
48 | pos_y = torch.stack(
49 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
50 | ).flatten(3)
51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
52 | return pos
53 |
--------------------------------------------------------------------------------
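A quick shape check (editor's sketch, assuming it is run from the repository root): with `num_pos_feats = hidden_dim // 2`, the sine embedding returns a positional map with `hidden_dim` channels at the spatial size of the input feature map, which is how `TransformerPredictor` instantiates it.

```python
# Editor's sketch: 128 positional features per axis give a 256-channel positional map.
import torch
from models.transformers.position_encoding import PositionEmbeddingSine

pe_layer = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
feat = torch.randn(2, 256, 7, 7)  # [B, C, H, W] backbone feature map
pos = pe_layer(feat)
print(pos.shape)                  # torch.Size([2, 256, 7, 7])
```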
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | #
132 | ckpts/
133 |
134 |
135 |
136 | # PyCharm
137 | /.idea
138 |
139 | # Sphinx
140 | /doc/build
141 |
142 | # Python
143 | __pycache__
144 | *.pyc
145 | *.egg-info
146 |
147 | # macOS
148 | .DS_Store
149 | */.DS_Store
--------------------------------------------------------------------------------
/data/rafdb.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | """
6 |
7 | import os
8 | import sys
9 | import numpy as np
10 | from tqdm import tqdm
11 | from PIL import Image
12 | import matplotlib.pyplot as plt
13 |
14 | if sys.platform == 'win32':
15 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
16 |
17 | import torch
18 | import torch.utils.data as data
19 | import torchvision.transforms as transforms
20 |
21 | # Root directory of the project
22 | try:
23 | abspath = os.path.abspath(__file__)
24 | except NameError:
25 | abspath = os.getcwd()
26 | ROOT_DIR = os.path.dirname(abspath)
27 |
28 |
29 | IMG_EXTENSIONS = [
30 | '.jpg', '.JPG', '.jpeg', '.JPEG',
31 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
32 | ]
33 |
34 |
35 | def is_image_file(filename):
36 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
37 |
38 |
39 | class RAFDB(data.Dataset):
40 | def __init__(self, root='/media/jiaren/DataSet/basic/', split='train', transform=None):
41 | super().__init__()
42 |
43 | self.root = root
44 |
45 | image_list_file = os.path.join(root, "EmoLabel", "list_patition_label.txt")
46 | self.image_list_file = image_list_file
47 | self.split = split
48 | self.transform = transform
49 |
50 | self.samples = []
51 | self.targets = []
52 | with open(self.image_list_file, 'r') as f:
53 | for i, img_file in enumerate(f):
54 | img_file = img_file.strip()
55 | img_file = img_file.split(' ')
56 | if split in img_file[0]:
57 | self.samples.append(os.path.join(root, "Image", "aligned", img_file[0][:-4]+'_aligned.jpg'))
58 | self.targets.append(int(img_file[1]) - 1)
59 |
60 | def __getitem__(self, index):
61 | img_file = self.samples[index]
62 | image = Image.open(img_file)
63 |
64 | if image.mode != 'RGB':
65 | image = image.convert("RGB")
66 |
67 | target = self.targets[index]
68 |
69 | if self.transform is not None:
70 | image = self.transform(image)
71 |
72 | return image, target, index
73 |
74 | def __len__(self):
75 |         return len(self.samples)  # 12271 samples in the train split
76 |
77 |
78 | if __name__ == '__main__':
79 | display_transform = transforms.Compose([
80 | transforms.Resize((224, 224)),
81 | transforms.ToTensor()
82 | ])
83 |
84 | split = "train"
85 | dataset = RAFDB(root="../data/RAFDB/basic", split=split, transform=display_transform)
86 | print(len(dataset))
87 | print(set(dataset.targets))
88 |
89 | loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True,
90 | drop_last=False)
91 |
92 | with torch.no_grad():
93 | for i, (images, target, _) in enumerate(tqdm(loader)):
94 | img = np.clip(images.cpu().numpy(), 0, 1) # [0, 1]
95 | img = img.transpose(0, 2, 3, 1)
96 | img = (img * 255).astype(np.uint8)
97 | img = img.squeeze()
98 |
99 | fig, axs = plt.subplots(1, 1, figsize=(8, 8))
100 | axs.imshow(img)
101 | axs.axis("off")
102 | plt.show()
103 |
--------------------------------------------------------------------------------
/launch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 |
3 | import os
4 | import sys
5 | import socket
6 | import random
7 | import argparse
8 | import subprocess
9 | import torch
10 |
11 |
12 | def _find_free_port():
13 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
14 | sock.bind(("", 0))
15 | port = sock.getsockname()[1]
16 | sock.close()
17 | return port
18 |
19 |
20 | def _get_rand_port():
21 | return random.randrange(20000, 60000)
22 |
23 |
24 | def init_workdir():
25 | ROOT = os.path.dirname(os.path.abspath(__file__))
26 | os.chdir(ROOT)
27 | sys.path.insert(0, ROOT)
28 |
29 | if __name__ == '__main__':
30 | parser = argparse.ArgumentParser(description='Launcher')
31 | parser.add_argument('--launch', type=str, default='tools/train.py',
32 | help='Specify launcher script.')
33 | parser.add_argument('--dist', type=int, default=1,
34 |                         help='Whether to launch via torch.distributed.launch.')
35 | parser.add_argument('--np', type=int, default=-1,
36 | help='number of processes per node.')
37 | parser.add_argument('--nn', type=int, default=1,
38 |                         help='total number of nodes.')
39 | parser.add_argument('--port', type=int, default=-1,
40 | help='master port for communication')
41 | parser.add_argument('--nr', type=int, default=0,
42 | help='node rank.')
43 | parser.add_argument('--master_address', '-ma', type=str, default="127.0.0.1")
44 | parser.add_argument('--device', default=None, type=str,
45 | help='indices of GPUs to enable (default: all)')
46 | args, other_args = parser.parse_known_args()
47 |
48 | if args.device:
49 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device
50 | cmd = f"CUDA_VISIBLE_DEVICES={args.device} "
51 | else:
52 | cmd = f""
53 |
54 | init_workdir()
55 | master_address = args.master_address
56 | num_processes_per_worker = torch.cuda.device_count() if args.np < 0 else args.np
57 | num_workers = args.nn
58 | node_rank = args.nr
59 |
60 | if args.port > 0:
61 | master_port = args.port
62 | elif num_workers == 1:
63 | master_port = _find_free_port()
64 | else:
65 | master_port = _get_rand_port()
66 |
67 | if args.dist >= 1:
68 | print(f'Start {args.launch} by torch.distributed.launch with port {master_port}!', flush=True)
69 | os.environ['NPROC_PER_NODE'] = str(num_processes_per_worker)
70 | cmd += f'python3 -m torch.distributed.launch \
71 | --nproc_per_node={num_processes_per_worker} \
72 | --nnodes={num_workers} \
73 | --node_rank={node_rank} \
74 | --master_addr={master_address} \
75 | --master_port={master_port} \
76 | {args.launch}'
77 | else:
78 | print(f'Start {args.launch}!', flush=True)
79 | cmd += f'python3 -u {args.launch}'
80 |
81 | for argv in other_args:
82 | cmd += f' {argv}'
83 |
84 | with open('./log.txt', 'wb') as f:
85 | proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
86 | while True:
87 | text = proc.stdout.readline()
88 | f.write(text)
89 | f.flush()
90 | sys.stdout.buffer.write(text)
91 | sys.stdout.buffer.flush()
92 | exit_code = proc.poll()
93 | if exit_code is not None:
94 | break
95 | sys.exit(exit_code)
96 |
--------------------------------------------------------------------------------
/data/affectnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | https://github.com/yaoing/dan
6 | https://github.com/ElenaRyumina/EMO-AffectNetModel
7 | https://github.com/PanosAntoniadis/emotion-gcn
8 | """
9 |
10 | __author__ = "GZ"
11 |
12 | import os
13 | import sys
14 | from shutil import copy
15 | import pandas as pd
16 | import numpy as np
17 | from tqdm import tqdm
18 | import matplotlib.pyplot as plt
19 |
20 | if sys.platform == 'win32':
21 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
22 |
23 | # Root directory of the project
24 | try:
25 | abspath = os.path.abspath(__file__)
26 | except NameError:
27 | abspath = os.getcwd()
28 | ROOT_DIR = os.path.dirname(abspath)
29 |
30 |
31 | # def convert_affectnet(label_file, save_dir):
32 | # df = pd.read_csv(label_file)
33 | #
34 | # for i in range(8):
35 | # for j in ['train','val']:
36 | # os.makedirs(os.path.join(save_dir, "AffectNet", j, i), exist_ok=True)
37 | #
38 | # for i, row in df.iterrows():
39 | # p = row['phase']
40 | # l = row['label']
41 | # copy(row['img_path'], os.path.join(save_dir, "AffectNet", p, l))
42 | #
43 | # print('convert done.')
44 | #
45 | #
46 | # def get_AffectNet(root, split, transform, num_class=7):
47 | # data_dir = os.path.join(root, split)
48 | # dataset = datasets.ImageFolder(data_dir, transform=transform)
49 | # if num_class == 7: # ignore the 8-th class
50 | # idx = [i for i in range(len(dataset)) if dataset.imgs[i][1] != 7]
51 | # dataset = data.Subset(dataset, idx)
52 | # return dataset
53 |
54 |
55 | def generate_affectnet(img_dir, label_dir, split, save_dir, num_class=7):
56 | assert split in ["train", "val"]
57 | label_file = "training.csv" if split == "train" else "validation.csv"
58 | head_list = ['subDirectory_filePath', 'face_x', 'face_y', 'face_width', 'face_height', 'facial_landmarks',
59 | 'expression', 'valence', 'arousal']
60 | dict_name_labels = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}
61 |
62 | df_data_raw = pd.read_csv(os.path.join(label_dir, label_file))
63 | df_data_raw.expression = pd.to_numeric(df_data_raw.expression, errors='coerce').fillna(100).astype('int64')
64 |
65 | df_data = df_data_raw[df_data_raw['expression'] < num_class]
66 |
67 | for label in range(num_class):
68 | os.makedirs(os.path.join(save_dir, split, str(label)), exist_ok=True)
69 |
70 | file_notfound = []
71 | for i, row in tqdm(df_data.iterrows(), total=df_data.shape[0]):
72 | label = row['expression']
73 | img_file = os.path.join(img_dir, row['subDirectory_filePath'])
74 |
75 | if os.path.isfile(img_file):
76 | copy(img_file, os.path.join(save_dir, split, str(label)))
77 | else:
78 | file_notfound.append(img_file)
79 |
80 | # 2/9db2af5a1da8bd77355e8c6a655da519a899ecc42641bf254107bfc0.jpg
81 | print(file_notfound)
82 |
83 |
84 | if __name__ == '__main__':
85 | import torch
86 | import torch.utils.data as data
87 | from torchvision import transforms, datasets
88 | from data.base_dataset import ImageFolderInstance
89 | from data.sampler import DistributedImbalancedSampler, DistributedSamplerWrapper, ImbalancedDatasetSampler
90 |
91 | # label_file = "../data/FER/AffectNet/affectnet.csv"
92 | # save_dir = "../data/FER"
93 | # convert_affectnet(label_file, save_dir)
94 |
95 | img_dir = "../data/FER/AffectNet/Manually_Annotated_Images"
96 | label_dir = '../data/FER/AffectNet/Manually_Annotated_file_lists'
97 | split = "train"
98 | save_dir = "../data/FER/AffectNet_subset"
99 | # generate_affectnet(img_dir, label_dir, split, save_dir)
100 |
101 | data_root = save_dir
102 | display_transform = transforms.Compose([
103 | transforms.Resize((224, 224)),
104 | transforms.ToTensor()
105 | ])
106 |
107 | # dataset = get_AffectNet(data_root, split, display_transform, num_class=7)
108 | data_dir = os.path.join(data_root, split)
109 | dataset = ImageFolderInstance(data_dir, transform=display_transform)
110 | print(dataset)
111 |
112 | train_percent = 0.1
113 | if train_percent < 1.0:
114 | num_subset = int(len(dataset) * train_percent)
115 | indices = torch.randperm(len(dataset))[:num_subset]
116 | indices = indices.tolist()
117 | dataset = torch.utils.data.Subset(dataset, indices)
118 | print("Sub train_dataset:\n{}".format(len(dataset)))
119 |
120 | sampler = ImbalancedDatasetSampler(dataset)
121 | loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True,
122 | drop_last=False)
123 |
124 | with torch.no_grad():
125 | for i, (images, target, _) in enumerate(tqdm(loader)):
126 | img = np.clip(images.cpu().numpy(), 0, 1) # [0, 1]
127 | img = img.transpose(0, 2, 3, 1)
128 | img = (img * 255).astype(np.uint8)
129 | img = img.squeeze()
130 |
131 | fig, axs = plt.subplots(1, 1, figsize=(8, 8))
132 | axs.imshow(img)
133 | axs.axis("off")
134 | plt.show()
135 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Supervised Facial Representation Learning with Facial Region Awareness
2 |
3 |
4 |
5 |
6 |
7 | Self-Supervised Facial Representation Learning with Facial Region Awareness (CVPR 2024)
8 | By
9 | Zheng Gao and
10 | Ioannis Patras.
11 |
12 |
13 | ## Introduction
14 |
15 | > **Abstract**: Self-supervised pre-training has been proved to be effective in learning transferable representations that benefit various visual tasks. This paper asks this question: can self-supervised pre-training learn general facial representations for various facial analysis tasks? Recent efforts toward this goal are limited to treating each face image as a whole, i.e., learning consistent facial representations at the image-level, which overlooks the **consistency of local facial representations** (i.e., facial regions like eyes, nose, etc). In this work, we make a **first attempt** to propose a novel self-supervised facial representation learning framework to learn consistent global and local facial representations, Facial Region Awareness (FRA). Specifically, we explicitly enforce the consistency of facial regions by matching the local facial representations across views, which are extracted with learned heatmaps highlighting the facial regions. Inspired by the mask prediction in supervised semantic segmentation, we obtain the heatmaps via cosine similarity between the per-pixel projection of feature maps and facial mask embeddings computed from learnable positional embeddings, which leverage the attention mechanism to globally look up the facial image for facial regions. To learn such heatmaps, we formulate the learning of facial mask embeddings as a deep clustering problem by assigning the pixel features from the feature maps to them. The transfer learning results on facial classification and regression tasks show that our FRA outperforms previous pre-trained models and more importantly, using ResNet as the unified backbone for various tasks, our FRA achieves comparable or even better performance compared with SOTA methods in facial analysis tasks.
16 |
17 | 
18 |
19 |
20 | ## Installation
21 | Please refer to `requirements.txt` for the dependencies. Alternatively, you can install the dependencies using the following command:
22 | ```
23 | pip3 install -r requirements.txt
24 | ```
25 | The repository works with `PyTorch 1.10.2` or higher and `CUDA 11.1`.
26 |
27 | ## Get started
28 |
29 | We provide basic usage of the implementation in the following sections:
30 |
31 | ### Pre-training on VGGFace2
32 |
33 | Download the [VGGFace2](https://academictorrents.com/details/535113b8395832f09121bc53ac85d7bc8ef6fa5b) dataset and specify its path via `DATA_ROOT="./data/VGG-Face2-crop"`.
34 |
35 | To pre-train the model with a ResNet-50 backbone on VGGFace2 using multiple GPUs, run:
36 | ```
37 | python3 launch.py --device=${DEVICES} --launch main.py \
38 | --arch FRAB --backbone resnet50_encoder \
39 | --dataset vggface2 --data-root ${DATA_ROOT} \
40 | --lr 0.9 -b 512 --wd 0.000001 --epochs 50 --cos --warmup-epoch 10 --workers 16 \
41 | --enc-m 0.996 \
42 | --norm SyncBN \
43 | --lewel-loss-weight 0.5 \
44 | --mask_type="attn" --num_proto 8 --teacher_temp 0.04 --loss_w_cluster 0.1 \
45 | --amp \
46 | --save-dir ./ckpts --save-freq 50 --print-freq 100
47 | ```
48 | `DEVICES` denotes the GPU indices.
49 |
50 | ### Evaluation: Facial expression recognition (FER)
51 | The following is an example of evaluating the pre-trained model on the RAF-DB dataset, with both the encoder backbone and the linear classifier fine-tuned:
52 | ```
53 | python3 launch.py --device=${DEVICES} --launch main_fer.py \
54 | -a resnet50 \
55 | --dataset rafdb --data-root ${FER_DATA_ROOT} \
56 | --lr 0.0002 --lr_head 0.0002 --optimizer adamw --weight-decay 0.05 --scheduler cos \
57 | --finetune \
58 | --epochs 100 --batch-size 256 \
59 | --amp \
60 | --workers 16 \
61 | --eval-freq 5 \
62 | --model-prefix online_net.backbone \
63 | --pretrained ${PRETRAINED} \
64 | --image_size 224 \
65 | --multiprocessing_distributed
66 | ```
67 | `PRETRAINED` denotes the path to the pre-trained checkpoint and `FER_DATA_ROOT=/path/to/datasets` is the location for FER datasets.
68 |
69 | ### Evaluation: Face alignment
70 | For evaluation on face alignment, we use [STAR Loss](https://github.com/ZhenglinZhou/STAR) as the downstream framework. Please refer to the [STAR Loss](https://github.com/ZhenglinZhou/STAR) repository for details.
71 |
72 |
73 | ## Citation
74 |
75 | If you find this repository useful, please consider giving it a star :star: and citing our paper:
76 |
77 | ```bibtex
78 | @inproceedings{gao2024self,
79 |   title={Self-Supervised Facial Representation Learning with Facial Region Awareness},
80 |   author={Gao, Zheng and Patras, Ioannis},
81 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
82 |   year={2024}
83 | }
84 | ```
85 |
86 | ## Acknowledgment
87 | Our project is based on [LEWEL](https://github.com/LayneH/LEWEL). Thanks for their wonderful work.
88 |
89 |
90 | ## License
91 |
92 | This project is released under the [CC-BY-NC 4.0 license](LICENSE).
--------------------------------------------------------------------------------
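As a rough, editor-added illustration of the mechanism described in the README abstract above (region heatmaps from cosine similarity between per-pixel projections and facial mask embeddings, followed by heatmap-weighted pooling into local representations). All tensor names, dimensions and the temperature below are assumptions for the sketch, not code from `models/fra.py`.

```python
# Editor's sketch of the heatmap idea; none of these names come from the repository.
import torch
import torch.nn.functional as F

B, C, H, W, K = 4, 256, 7, 7, 8          # batch, channels, feature map size, num_proto
pixel_proj = torch.randn(B, C, H, W)      # per-pixel projection of the backbone feature map
mask_embed = torch.randn(K, C)            # facial mask embeddings (e.g. from a transformer decoder)

# cosine similarity between every pixel feature and every mask embedding
pix = F.normalize(pixel_proj.flatten(2), dim=1)   # [B, C, H*W]
emb = F.normalize(mask_embed, dim=1)              # [K, C]
sim = torch.einsum("kc,bcn->bkn", emb, pix)       # [B, K, H*W]

# soft region heatmaps; the temperature 0.1 is an arbitrary choice for this sketch
heatmaps = F.softmax(sim / 0.1, dim=1).view(B, K, H, W)

# local facial representations: heatmap-weighted pooling of the projected features
weights = heatmaps.flatten(2)                                             # [B, K, H*W]
local_repr = torch.einsum("bkn,bcn->bkc", weights, pixel_proj.flatten(2))
local_repr = local_repr / (weights.sum(-1, keepdim=True) + 1e-6)

print(heatmaps.shape, local_repr.shape)  # [4, 8, 7, 7] and [4, 8, 256]
```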
/utils/LARS.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from
2 | # https://github.com/yaox12/BYOL-PyTorch/blob/master/optimizer/LARSSGD.py
3 | # Copyright 2020 Xin Yao. Licensed under the MIT License.
4 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
5 | # SPDX-License-Identifier: CC-BY-NC-4.0
6 |
7 | """ Layer-wise adaptive rate scaling for SGD in PyTorch! """
8 | import torch
9 | from torch.optim.optimizer import Optimizer, required
10 |
11 |
12 | class LARS(Optimizer):
13 | r"""Implements layer-wise adaptive rate scaling for SGD.
14 | Args:
15 | params (iterable): iterable of parameters to optimize or dicts defining
16 | parameter groups
17 | lr (float): base learning rate (\gamma_0)
18 | momentum (float, optional): momentum factor (default: 0) ("m")
19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
20 | ("\beta")
21 | dampening (float, optional): dampening for momentum (default: 0)
22 | eta (float, optional): LARS coefficient
23 | nesterov (bool, optional): enables Nesterov momentum (default: False)
24 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
25 | Large Batch Training of Convolutional Networks:
26 | https://arxiv.org/abs/1708.03888
27 | Example:
28 | >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9,
29 | >>> weight_decay=1e-4, eta=1e-3)
30 | >>> optimizer.zero_grad()
31 | >>> loss_fn(model(input), target).backward()
32 | >>> optimizer.step()
33 | """
34 |
35 | def __init__(self,
36 | params,
37 | lr=required,
38 | momentum=0,
39 | dampening=0,
40 | weight_decay=0,
41 | eta=0.001,
42 | nesterov=False,
43 | eps=1e-8):
44 | if lr is not required and lr < 0.0:
45 | raise ValueError("Invalid learning rate: {}".format(lr))
46 | if momentum < 0.0:
47 | raise ValueError("Invalid momentum value: {}".format(momentum))
48 | if weight_decay < 0.0:
49 | raise ValueError(
50 | "Invalid weight_decay value: {}".format(weight_decay))
51 | if eta < 0.0:
52 | raise ValueError("Invalid LARS coefficient value: {}".format(eta))
53 |
54 | defaults = dict(
55 | lr=lr, momentum=momentum, dampening=dampening,
56 | weight_decay=weight_decay, nesterov=nesterov, eta=eta)
57 | if nesterov and (momentum <= 0 or dampening != 0):
58 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
59 |
60 | super(LARS, self).__init__(params, defaults)
61 |
62 | self.eps = eps
63 |
64 | def __setstate__(self, state):
65 | super(LARS, self).__setstate__(state)
66 | for group in self.param_groups:
67 | group.setdefault('nesterov', False)
68 |
69 | @torch.no_grad()
70 | def step(self, closure=None):
71 | """Performs a single optimization step.
72 | Arguments:
73 | closure (callable, optional): A closure that reevaluates the model
74 | and returns the loss.
75 | """
76 | loss = None
77 | if closure is not None:
78 | with torch.enable_grad():
79 | loss = closure()
80 |
81 | for group in self.param_groups:
82 | weight_decay = group['weight_decay']
83 | momentum = group['momentum']
84 | dampening = group['dampening']
85 | eta = group['eta']
86 | nesterov = group['nesterov']
87 | lr = group['lr']
88 | lars_exclude = group.get('lars_exclude', False)
89 |
90 | for p in group['params']:
91 | if p.grad is None:
92 | continue
93 |
94 | d_p = p.grad
95 |
96 | if lars_exclude:
97 | local_lr = 1.
98 | else:
99 | weight_norm = torch.norm(p).item()
100 | grad_norm = torch.norm(d_p).item()
101 | # Compute local learning rate for this layer
102 | local_lr = eta * weight_norm / \
103 | (grad_norm + weight_decay * weight_norm + self.eps)
104 |
105 | actual_lr = local_lr * lr
106 | d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr)
107 | if momentum != 0:
108 | param_state = self.state[p]
109 | if 'momentum_buffer' not in param_state:
110 | buf = param_state['momentum_buffer'] = \
111 | torch.clone(d_p).detach()
112 | else:
113 | buf = param_state['momentum_buffer']
114 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
115 | if nesterov:
116 | d_p = d_p.add(buf, alpha=momentum)
117 | else:
118 | d_p = buf
119 | p.add_(-d_p)
120 |
121 | return loss
122 |
--------------------------------------------------------------------------------
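A short sketch (editor's illustration) of the per-group `lars_exclude` flag read in `step()`: parameters placed in a group with `lars_exclude=True` (typically biases and normalization parameters) keep the base learning rate instead of the layer-wise adaptive one. The grouping heuristic below is an assumption, not the repository's own parameter-splitting code.

```python
# Editor's sketch: exclude 1-D parameters (biases, norm weights) from LARS adaptation.
import torch
from torch import nn
from utils import LARS

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
regular, excluded = [], []
for p in model.parameters():
    (excluded if p.ndim <= 1 else regular).append(p)

optimizer = LARS(
    [{"params": regular},
     {"params": excluded, "weight_decay": 0.0, "lars_exclude": True}],
    lr=0.1, momentum=0.9, weight_decay=1e-4, eta=1e-3)
```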
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Original copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0
3 | import os
4 | import numpy as np
5 | import shutil
6 | from sklearn.metrics import accuracy_score
7 | import skimage.io
8 |
9 | import torch
10 |
11 |
12 | def load_netowrk(model, path, checkpoint_key="net"):
13 | if os.path.isfile(path):
14 | print("=> loading checkpoint '{}'".format(path))
15 | checkpoint = torch.load(path, map_location="cpu")
16 |
17 | # rename pre-trained keys
18 | state_dict = checkpoint[checkpoint_key]
19 | state_dict_new = {k.replace("module.", ""): v for k, v in state_dict.items()}
20 |
21 | msg = model.load_state_dict(state_dict_new)
22 | assert set(msg.missing_keys) == set()
23 |
24 | print("=> loaded pre-trained model '{}'".format(path))
25 | else:
26 | print("=> no checkpoint found at '{}'".format(path))
27 |
28 |
29 | def save_checkpoint(state, is_best, epoch, args, filename='checkpoint.pth.tar'):
30 | filename = os.path.join(args.save_dir, filename)
31 | torch.save(state, filename)
32 | # if is_best:
33 | # shutil.copyfile(filename, os.path.join(args.save_dir, 'model_best.pth.tar'))
34 | if args.save_freq > 0 and (epoch + 1) % args.save_freq == 0:
35 | shutil.copyfile(filename, os.path.join(args.save_dir, 'checkpoint_{:04d}.pth.tar'.format(epoch)))
36 | if not args.cos:
37 | if (epoch + 1) in args.schedule:
38 | shutil.copyfile(filename, os.path.join(args.save_dir, 'checkpoint_{:04d}.pth.tar'.format(epoch)))
39 |
40 |
41 | def accuracy(output, target, topk=(1,)):
42 | """Computes the accuracy over the k top predictions for the specified values of k"""
43 | with torch.no_grad():
44 | maxk = max(topk)
45 | batch_size = target.size(0)
46 |
47 | _, pred = output.topk(maxk, 1, True, True)
48 | pred = pred.t()
49 | correct = pred.eq(target.view(1, -1).expand_as(pred))
50 |
51 | res = []
52 | for k in topk:
53 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
54 | res.append(correct_k.mul_(100.0 / batch_size))
55 | return res
56 |
57 |
58 | def accuracy_multilabel(output, target, threshold=0.5):
59 | """
60 | https://www.kaggle.com/code/kmkarakaya/multi-label-model-evaluation
61 | """
62 | with torch.no_grad():
63 | batch_size, n_class = target.shape
64 | pred = (output >= threshold).to(torch.float32)
65 |
66 | acc = (pred == target).float().sum() * 100.0 / (batch_size * n_class)
67 |
68 | # acc = sklearn.metrics.accuracy_score(gt_S,pred_S)
69 | # f1m = sklearn.metrics.f1_score(gt_S,pred_S,average = 'macro', zero_division=1)
70 | # f1mi = sklearn.metrics.f1_score(gt_S,pred_S,average = 'micro', zero_division=1)
71 | # print('f1_Macro_Score{}'.format(f1m))
72 | # print('f1_Micro_Score{}'.format(f1mi))
73 | # print('Accuracy{}'.format(acc))
74 |
75 | return acc
76 |
77 |
78 | class AverageMeter(object):
79 | """Computes and stores the average and current value"""
80 | def __init__(self, name, fmt=':f'):
81 | self.name = name
82 | self.fmt = fmt
83 | self.reset()
84 |
85 | def reset(self):
86 | self.val = 0
87 | self.avg = 0
88 | self.sum = 0
89 | self.count = 0
90 |
91 | def update(self, val, n=1):
92 | self.val = val
93 | self.sum += val * n
94 | self.count += n
95 | self.avg = self.sum / self.count
96 |
97 | def __str__(self):
98 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
99 | return fmtstr.format(**self.__dict__)
100 |
101 |
102 | class ProgressMeter(object):
103 | def __init__(self, num_batches, meters, prefix=""):
104 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
105 | self.meters = meters
106 | self.prefix = prefix
107 |
108 | def display(self, batch):
109 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
110 | entries += [str(meter) for meter in self.meters]
111 | print('\t'.join(entries), flush=True)
112 |
113 | def _get_batch_fmtstr(self, num_batches):
114 | num_digits = len(str(num_batches // 1))
115 | fmt = '{:' + str(num_digits) + 'd}'
116 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
117 |
118 |
119 | class InstantMeter(object):
120 | """Computes and stores the average and current value"""
121 | def __init__(self, name, fmt=':f'):
122 | self.name = name
123 | self.fmt = fmt
124 | self.val = 0
125 |
126 | def update(self, val):
127 | self.val = val
128 |
129 | def __str__(self):
130 | fmtstr = '{name} {val' + self.fmt + '}'
131 | return fmtstr.format(**self.__dict__)
132 |
133 |
134 | def denormalize_batch(batch, mean, std):
135 | """denormalize for visualization"""
136 | dtype = batch.dtype
137 | mean = torch.as_tensor(mean, dtype=dtype, device=batch.device)
138 | std = torch.as_tensor(std, dtype=dtype, device=batch.device)
139 | mean = mean.view(-1, 1, 1)
140 | std = std.view(-1, 1, 1)
141 | batch = batch * std + mean
142 | return batch
143 |
144 | def dump_image(imgNorm, mean, std, filepath=None, verbose=False):
145 |     """Denormalizes a batch of images and optionally saves or displays the result
146 | 
147 |     Args:
148 |         imgNorm (torch.Tensor): normalized image tensor of shape (N, C, H, W) or (C, H, W)
149 |         mean, std: per-channel statistics used for denormalization
150 |         filepath (str, optional): file to write the first image to
151 | Returns:
152 | np.array: uint8 image data stored in numpy format
153 | """
154 | if imgNorm.dim() < 4:
155 | imgNorm = imgNorm.unsqueeze(0)
156 |
157 | img = denormalize_batch(imgNorm, mean, std)
158 | img = np.clip(img.cpu().numpy(), 0, 1)
159 | img = (img.transpose(0, 2, 3, 1) * 255).astype(np.uint8)
160 |
161 | if filepath is not None:
162 | skimage.io.imsave(filepath, img[0])
163 |
164 | if verbose:
165 | num = min(img.shape[0], 9)
166 | show_images(img[:num], 3, 3)
167 | plt.show()
168 | return img
169 |
170 |
171 | def calc_params(net, verbose=False):
172 | num_params = 0
173 | for param in net.parameters():
174 | num_params += param.numel()
175 | if verbose:
176 | print(net)
177 | print('Total number of parameters : %.3f M' % (num_params / 1e6))
178 |
179 | return num_params
180 |
181 |
182 | if __name__ == '__main__':
183 | output = torch.tensor([[0.35,0.4,0.9], [0.2,0.6,0.8]])
184 | target = torch.tensor([[1, 0, 1], [0, 1, 1]])
185 | acc = accuracy_multilabel(output, target)
186 | print(acc)
187 |
--------------------------------------------------------------------------------
/models/transformers/transformer_predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3 | import os
4 | import sys
5 |
6 | import fvcore.nn.weight_init as weight_init
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
12 | sys.path.append(os.path.dirname(SCRIPT_DIR))
13 |
14 | from transformers.position_encoding import PositionEmbeddingSine
15 | from transformers.transformer import Transformer
16 |
17 |
18 | class TransformerPredictor(nn.Module):
19 | def __init__(
20 | self,
21 | in_channels,
22 | mask_classification=True,
23 | *,
24 | num_classes: int,
25 | hidden_dim: int,
26 | num_queries: int,
27 | nheads: int,
28 | dropout: float,
29 | dim_feedforward: int,
30 | enc_layers: int,
31 | dec_layers: int,
32 | pre_norm: bool,
33 | deep_supervision: bool,
34 | mask_dim: int,
35 | enforce_input_project: bool,
36 | ):
37 | """
38 | NOTE: this interface is experimental.
39 | Args:
40 | in_channels: channels of the input features
41 | mask_classification: whether to add mask classifier or not
42 | num_classes: number of classes
43 | hidden_dim: Transformer feature dimension
44 | num_queries: number of queries
45 | nheads: number of heads
46 | dropout: dropout in Transformer
47 | dim_feedforward: feature dimension in feedforward network
48 | enc_layers: number of Transformer encoder layers
49 | dec_layers: number of Transformer decoder layers
50 | pre_norm: whether to use pre-LayerNorm or not
51 |             deep_supervision: whether to add supervision to every decoder layer
52 |             mask_dim: mask feature dimension
53 |             enforce_input_project: add an input 1x1 projection conv even if the input
54 |                 channels and hidden dim are identical
55 | """
56 | super().__init__()
57 |
58 | self.mask_classification = mask_classification
59 |
60 | # positional encoding
61 | N_steps = hidden_dim // 2
62 | self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
63 |
64 | transformer = Transformer(
65 | d_model=hidden_dim,
66 | dropout=dropout,
67 | nhead=nheads,
68 | dim_feedforward=dim_feedforward,
69 | num_encoder_layers=enc_layers,
70 | num_decoder_layers=dec_layers,
71 | normalize_before=pre_norm,
72 | return_intermediate_dec=deep_supervision,
73 | )
74 |
75 | self.num_queries = num_queries
76 | self.transformer = transformer
77 | hidden_dim = transformer.d_model
78 |
79 | self.query_embed = nn.Embedding(num_queries, hidden_dim)
80 |
81 | if in_channels != hidden_dim or enforce_input_project:
82 | # self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1)
83 | self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
84 | weight_init.c2_xavier_fill(self.input_proj)
85 | else:
86 | self.input_proj = nn.Sequential()
87 | self.aux_loss = deep_supervision
88 |
89 | # output FFNs
90 | if self.mask_classification:
91 | self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
92 | self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
93 |
94 | def forward(self, x, mask_features=None):
95 | pos = self.pe_layer(x)
96 |
97 | src = x
98 | mask = None
99 | hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)
100 |
101 | if self.mask_classification:
102 | outputs_class = self.class_embed(hs)
103 | out = {"pred_logits": outputs_class[-1]}
104 | else:
105 | out = {}
106 |
107 | if self.aux_loss:
108 | # [l, bs, queries, embed]
109 | mask_embed = self.mask_embed(hs)
110 | outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
111 | out["pred_masks"] = outputs_seg_masks[-1]
112 | out["aux_outputs"] = self._set_aux_loss(
113 | outputs_class if self.mask_classification else None, outputs_seg_masks
114 | )
115 | else:
116 | # FIXME h_boxes takes the last one computed, keep this in mind
117 | # [bs, queries, embed]
118 | mask_embed = self.mask_embed(hs[-1])
119 | out["mask_embed"] = mask_embed
120 | if mask_features is not None:
121 | outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
122 | out["pred_masks"] = outputs_seg_masks
123 | return out
124 |
125 | @torch.jit.unused
126 | def _set_aux_loss(self, outputs_class, outputs_seg_masks):
127 | # this is a workaround to make torchscript happy, as torchscript
128 | # doesn't support dictionary with non-homogeneous values, such
129 | # as a dict having both a Tensor and a list.
130 | if self.mask_classification:
131 | return [
132 | {"pred_logits": a, "pred_masks": b}
133 | for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
134 | ]
135 | else:
136 | return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
137 |
138 |
139 | class MLP(nn.Module):
140 | """Very simple multi-layer perceptron (also called FFN)"""
141 |
142 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
143 | super().__init__()
144 | self.num_layers = num_layers
145 | h = [hidden_dim] * (num_layers - 1)
146 | self.layers = nn.ModuleList(
147 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
148 | )
149 |
150 | def forward(self, x):
151 | for i, layer in enumerate(self.layers):
152 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
153 | return x
154 |
155 |
156 | if __name__ == '__main__':
157 | from utils.utils import calc_params
158 |
159 | model = TransformerPredictor(in_channels=2048, hidden_dim=256, num_queries=100, nheads=8, dropout=0.1, dim_feedforward=2048,
160 | enc_layers=0, dec_layers=1, pre_norm=False, deep_supervision=False, mask_dim=256,
161 | enforce_input_project=False, mask_classification=False, num_classes=0)
162 | print(model)
163 |
164 | x = torch.randn(16, 2048, 7, 7)
165 | mask_features = torch.randn(16, 256, 7, 7)
166 | out = model(x, mask_features)
167 |
168 | calc_params(model)
169 |
--------------------------------------------------------------------------------
/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | # some code in this file is adapted from
2 | # https://github.com/facebookresearch/moco
3 | # Original Copyright 2020 Facebook, Inc. and its affiliates. Licensed under the CC-BY-NC 4.0 License.
4 | # Modifications Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved.
5 | # SPDX-License-Identifier: CC-BY-NC-4.0
6 |
7 | import os
8 | import sys
9 | import random
10 | import datetime
11 | import torch
12 | import torch.distributed as dist
13 |
14 |
15 | @torch.no_grad()
16 | def batch_shuffle_ddp(x):
17 | """
18 | Batch shuffle, for making use of BatchNorm.
19 | *** Only support DistributedDataParallel (DDP) model. ***
20 | """
21 | # gather from all gpus
22 | batch_size_this = x.shape[0]
23 | x_gather = concat_all_gather(x)
24 | batch_size_all = x_gather.shape[0]
25 |
26 | num_gpus = batch_size_all // batch_size_this
27 |
28 | # random shuffle index
29 | idx_shuffle = torch.randperm(batch_size_all).cuda()
30 |
31 | # broadcast to all gpus
32 | torch.distributed.broadcast(idx_shuffle, src=0)
33 |
34 | # index for restoring
35 | idx_unshuffle = torch.argsort(idx_shuffle)
36 |
37 | # shuffled index for this gpu
38 | gpu_idx = torch.distributed.get_rank()
39 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
40 |
41 | return x_gather[idx_this], idx_unshuffle
42 |
43 |
44 | @torch.no_grad()
45 | def batch_unshuffle_ddp(x, idx_unshuffle):
46 | """
47 | Undo batch shuffle.
48 | *** Only support DistributedDataParallel (DDP) model. ***
49 | """
50 | # gather from all gpus
51 | batch_size_this = x.shape[0]
52 | x_gather = concat_all_gather(x)
53 | batch_size_all = x_gather.shape[0]
54 |
55 | num_gpus = batch_size_all // batch_size_this
56 |
57 | # restored index for this gpu
58 | gpu_idx = torch.distributed.get_rank()
59 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
60 |
61 | return x_gather[idx_this]
62 |
63 |
64 | @torch.no_grad()
65 | def concat_all_gather(tensor):
66 | """
67 | Performs all_gather operation on the provided tensors.
68 | *** Warning ***: torch.distributed.all_gather has no gradient.
69 | """
70 | tensors_gather = [torch.ones_like(tensor)
71 | for _ in range(torch.distributed.get_world_size())]
72 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
73 |
74 | output = torch.cat(tensors_gather, dim=0)
75 | return output
76 |
77 |
78 | def is_dist_avail_and_initialized():
79 | if not dist.is_available():
80 | return False
81 | if not dist.is_initialized():
82 | return False
83 | return True
84 |
85 |
86 | def get_world_size():
87 | if not is_dist_avail_and_initialized():
88 | return 1
89 | return dist.get_world_size()
90 |
91 |
92 | def get_rank():
93 | if not is_dist_avail_and_initialized():
94 | return 0
95 | return dist.get_rank()
96 |
97 | def init_distributed_mode(args):
98 | if is_dist_avail_and_initialized():
99 | return
100 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
101 | args.rank = int(os.environ["RANK"])
102 | args.world_size = int(os.environ['WORLD_SIZE'])
103 | args.gpu = int(os.environ['LOCAL_RANK'])
104 |
105 | elif 'SLURM_PROCID' in os.environ:
106 | args.rank = int(os.environ['SLURM_PROCID'])
107 | args.gpu = args.rank % torch.cuda.device_count()
108 | elif torch.cuda.is_available():
109 | print('Will run the code on one GPU.')
110 | args.rank, args.gpu, args.world_size = 0, 0, 1
111 | os.environ['MASTER_ADDR'] = '127.0.0.1'
112 | os.environ['MASTER_PORT'] = str(random.randint(0, 9999) + 40000)
113 | else:
114 | print('Does not support training without GPU.')
115 | sys.exit(1)
116 |
117 | print("Use GPU: {} ranked {} out of {} gpus for training".format(args.gpu, args.rank, args.world_size))
118 | if args.multiprocessing_distributed:
119 | dist.init_process_group(
120 | backend="nccl",
121 | init_method=args.dist_url,
122 | world_size=args.world_size,
123 | timeout=datetime.timedelta(hours=5),
124 | rank=args.rank,
125 | )
126 | print('| distributed init (rank {}): {}'.format(
127 | args.rank, args.dist_url), flush=True)
128 | dist.barrier()
129 |
130 | torch.cuda.set_device(args.gpu)
131 | setup_for_distributed(args.rank == 0)
132 |
133 |
134 | def setup_for_distributed(is_master):
135 | """
136 | This function disables printing when not in master process
137 | """
138 | import builtins as __builtin__
139 | builtin_print = __builtin__.print
140 |
141 | def print(*args, **kwargs):
142 | force = kwargs.pop('force', False)
143 | if is_master or force:
144 | builtin_print(*args, **kwargs)
145 |
146 | __builtin__.print = print
147 |
148 |
149 | def all_reduce_mean(x):
150 |     # reduce the tensor across DDP workers (mean)
151 | # source: https://raw.githubusercontent.com/NVIDIA/apex/master/examples/imagenet/main_amp.py
152 | world_size = get_world_size()
153 | if world_size > 1:
154 | rt = x.clone()
155 | torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
156 | rt /= world_size
157 | return rt
158 | else:
159 | return x
160 |
161 | # def dist_init(port=23456):
162 | #
163 | # def init_parrots(host_addr, rank, local_rank, world_size, port):
164 | # os.environ['MASTER_ADDR'] = str(host_addr)
165 | # os.environ['MASTER_PORT'] = str(port)
166 | # os.environ['WORLD_SIZE'] = str(world_size)
167 | # os.environ['RANK'] = str(rank)
168 | # torch.distributed.init_process_group(backend="nccl")
169 | # torch.cuda.set_device(local_rank)
170 | #
171 | # def init(host_addr, rank, local_rank, world_size, port):
172 | # host_addr_full = 'tcp://' + host_addr + ':' + str(port)
173 | # torch.distributed.init_process_group("nccl", init_method=host_addr_full,
174 | # rank=rank, world_size=world_size)
175 | # torch.cuda.set_device(local_rank)
176 | # assert torch.distributed.is_initialized()
177 | #
178 | #
179 | # def parse_host_addr(s):
180 | # if '[' in s:
181 | # left_bracket = s.index('[')
182 | # right_bracket = s.index(']')
183 | # prefix = s[:left_bracket]
184 | # first_number = s[left_bracket+1:right_bracket].split(',')[0].split('-')[0]
185 | # return prefix + first_number
186 | # else:
187 | # return s
188 | #
189 | # if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
190 | # rank = int(os.environ["RANK"])
191 | # local_rank = int(os.environ['LOCAL_RANK'])
192 | # world_size = int(os.environ['WORLD_SIZE'])
193 | # ip = 'env://'
194 | #
195 | # elif 'SLURM_PROCID' in os.environ:
196 | # rank = int(os.environ['SLURM_PROCID'])
197 | # local_rank = int(os.environ['SLURM_LOCALID'])
198 | # world_size = int(os.environ['SLURM_NTASKS'])
199 | # ip = parse_host_addr(os.environ['SLURM_STEP_NODELIST'])
200 | # else:
201 | # raise RuntimeError()
202 | #
203 | # if torch.__version__ == 'parrots':
204 | # init_parrots(ip, rank, local_rank, world_size, port)
205 | # else:
206 | # init(ip, rank, local_rank, world_size, port)
207 | #
208 | # return rank, local_rank, world_size
209 |
210 |
211 | # https://github.com/facebookresearch/msn
212 | class AllReduce(torch.autograd.Function):
213 |
214 | @staticmethod
215 | def forward(ctx, x):
216 | if (
217 | dist.is_available()
218 | and dist.is_initialized()
219 | and (dist.get_world_size() > 1)
220 | ):
221 | x = x.contiguous() / dist.get_world_size()
222 | dist.all_reduce(x)
223 | return x
224 |
225 | @staticmethod
226 | def backward(ctx, grads):
227 | return grads
--------------------------------------------------------------------------------
/data/randaugment.py:
--------------------------------------------------------------------------------
1 | # some code in this file is adapted from
2 | # https://github.com/kekmodel/FixMatch-pytorch/blob/master/dataset/randaugment.py
3 | # Original Copyright 2019 Jungdae Kim, Qing Yu. Licensed under the MIT License.
4 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
5 | # SPDX-License-Identifier: CC-BY-NC-4.0
6 |
7 | import logging
8 | import random
9 |
10 | import numpy as np
11 | import PIL
12 | import PIL.ImageOps
13 | import PIL.ImageEnhance
14 | import PIL.ImageDraw
15 | from PIL import Image, ImageFilter
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | PARAMETER_MAX = 10
20 |
21 |
22 | def AutoContrast(img, **kwarg):
23 | return PIL.ImageOps.autocontrast(img)
24 |
25 |
26 | def Brightness(img, v, max_v, bias=0):
27 | v = _float_parameter(v, max_v) + bias
28 | return PIL.ImageEnhance.Brightness(img).enhance(v)
29 |
30 |
31 | def Color(img, v, max_v, bias=0):
32 | v = _float_parameter(v, max_v) + bias
33 | return PIL.ImageEnhance.Color(img).enhance(v)
34 |
35 |
36 | def Contrast(img, v, max_v, bias=0):
37 | v = _float_parameter(v, max_v) + bias
38 | return PIL.ImageEnhance.Contrast(img).enhance(v)
39 |
40 |
41 | def Cutout(img, v, max_v, bias=0):
42 | if v == 0:
43 | return img
44 | v = _float_parameter(v, max_v) + bias
45 | v = int(v * min(img.size))
46 | return CutoutAbs(img, v)
47 |
48 |
49 | def CutoutAbs(img, v, **kwarg):
50 | w, h = img.size
51 | x0 = np.random.uniform(0, w)
52 | y0 = np.random.uniform(0, h)
53 | x0 = int(max(0, x0 - v / 2.))
54 | y0 = int(max(0, y0 - v / 2.))
55 | x1 = int(min(w, x0 + v))
56 | y1 = int(min(h, y0 + v))
57 | xy = (x0, y0, x1, y1)
58 | # gray
59 | color = (127, 127, 127)
60 | img = img.copy()
61 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
62 | return img
63 |
64 |
65 | def Equalize(img, **kwarg):
66 | return PIL.ImageOps.equalize(img)
67 |
68 |
69 | def Identity(img, **kwarg):
70 | return img
71 |
72 |
73 | def Invert(img, **kwarg):
74 | return PIL.ImageOps.invert(img)
75 |
76 |
77 | def Posterize(img, v, max_v, bias=0):
78 | v = _int_parameter(v, max_v) + bias
79 | return PIL.ImageOps.posterize(img, v)
80 |
81 |
82 | def Rotate(img, v, max_v, bias=0):
83 | v = _int_parameter(v, max_v) + bias
84 | if random.random() < 0.5:
85 | v = -v
86 | return img.rotate(v)
87 |
88 |
89 | def Sharpness(img, v, max_v, bias=0):
90 | v = _float_parameter(v, max_v) + bias
91 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
92 |
93 |
94 | def ShearX(img, v, max_v, bias=0):
95 | v = _float_parameter(v, max_v) + bias
96 | if random.random() < 0.5:
97 | v = -v
98 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
99 |
100 |
101 | def ShearY(img, v, max_v, bias=0):
102 | v = _float_parameter(v, max_v) + bias
103 | if random.random() < 0.5:
104 | v = -v
105 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
106 |
107 |
108 | def Solarize(img, v, max_v, bias=0):
109 | v = _int_parameter(v, max_v) + bias
110 | return PIL.ImageOps.solarize(img, 256 - v)
111 |
112 |
113 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
114 | v = _int_parameter(v, max_v) + bias
115 | if random.random() < 0.5:
116 | v = -v
117 | img_np = np.array(img).astype(int)  # np.int is deprecated in recent NumPy; use the builtin int
118 | img_np = img_np + v
119 | img_np = np.clip(img_np, 0, 255)
120 | img_np = img_np.astype(np.uint8)
121 | img = Image.fromarray(img_np)
122 | return PIL.ImageOps.solarize(img, threshold)
123 |
124 |
125 | def TranslateX(img, v, max_v, bias=0):
126 | v = _float_parameter(v, max_v) + bias
127 | if random.random() < 0.5:
128 | v = -v
129 | v = int(v * img.size[0])
130 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
131 |
132 |
133 | def TranslateY(img, v, max_v, bias=0):
134 | v = _float_parameter(v, max_v) + bias
135 | if random.random() < 0.5:
136 | v = -v
137 | v = int(v * img.size[1])
138 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
139 |
140 |
141 | class GaussianBlur(object):
142 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
143 |
144 | def __init__(self, sigma=[.1, 2.]):
145 | self.sigma = sigma
146 |
147 | def __call__(self, x):
148 | sigma = random.uniform(self.sigma[0], self.sigma[1])
149 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
150 | return x
151 |
152 |
153 | class NoOpTransform(object):
154 | """
155 | A transform that does nothing.
156 | """
157 |
158 | def __init__(self):
159 | super().__init__()
160 |
161 | def __call__(self, tensor):
162 | """
163 | Args:
164 | tensor (Tensor): Tensor image of size (C, H, W).
165 |
166 | Returns:
167 | Tensor: Original Tensor image.
168 | """
169 | return tensor
170 |
171 | def __repr__(self):
172 | return self.__class__.__name__ + '()'
173 |
174 |
175 | def _float_parameter(v, max_v):
176 | return float(v) * max_v / PARAMETER_MAX
177 |
178 |
179 | def _int_parameter(v, max_v):
180 | return int(v * max_v / PARAMETER_MAX)
181 |
182 |
183 | def fixmatch_augment_pool():
184 | # FixMatch paper
185 | augs = [(AutoContrast, None, None),
186 | (Brightness, 0.9, 0.05),
187 | (Color, 0.9, 0.05),
188 | (Contrast, 0.9, 0.05),
189 | (Equalize, None, None),
190 | (Identity, None, None),
191 | (Posterize, 4, 4),
192 | (Rotate, 30, 0),
193 | (Sharpness, 0.9, 0.05),
194 | (ShearX, 0.3, 0),
195 | (ShearY, 0.3, 0),
196 | (Solarize, 256, 0),
197 | (TranslateX, 0.3, 0),
198 | (TranslateY, 0.3, 0)]
199 | return augs
200 |
201 |
202 | def my_augment_pool():
203 | # Test
204 | augs = [(AutoContrast, None, None),
205 | (Brightness, 1.8, 0.1),
206 | (Color, 1.8, 0.1),
207 | (Contrast, 1.8, 0.1),
208 | (Cutout, 0.2, 0),
209 | (Equalize, None, None),
210 | (Invert, None, None),
211 | (Posterize, 4, 4),
212 | (Rotate, 30, 0),
213 | (Sharpness, 1.8, 0.1),
214 | (ShearX, 0.3, 0),
215 | (ShearY, 0.3, 0),
216 | (Solarize, 256, 0),
217 | (SolarizeAdd, 110, 0),
218 | (TranslateX, 0.45, 0),
219 | (TranslateY, 0.45, 0)]
220 | return augs
221 |
222 |
223 | def imagenet_augment_pool():
224 | # op, max_v, bias
225 | augs = [(AutoContrast, None, None),
226 | (Brightness, 1.8, 0.1),
227 | (Color, 1.8, 0.1),
228 | (Contrast, 1.8, 0.1),
229 | (Equalize, None, None),
230 | (Identity, None, None),
231 | (Invert, None, None),
232 | (Posterize, 4, 4),
233 | (Rotate, 30, 0),
234 | (Sharpness, 1.8, 0.1),
235 | (ShearX, 0.3, 0),
236 | (ShearY, 0.3, 0),
237 | (Solarize, 256, 0),
238 | (SolarizeAdd, 110, 0),
239 | (TranslateX, 0.45, 0),
240 | (TranslateY, 0.45, 0)]
241 | return augs
242 |
243 |
244 | class RandAugmentPC(object):
245 | def __init__(self, n, m):
246 | assert n >= 1
247 | assert 1 <= m <= 10
248 | self.n = n
249 | self.m = m
250 | self.augment_pool = my_augment_pool()
251 |
252 | def __call__(self, img):
253 | ops = random.choices(self.augment_pool, k=self.n)
254 | for op, max_v, bias in ops:
255 | prob = np.random.uniform(0.2, 0.8)
256 | if random.random() + prob >= 1:
257 | img = op(img, v=self.m, max_v=max_v, bias=bias)
258 | img = CutoutAbs(img, 16)
259 | return img
260 |
261 |
262 | class RandAugmentMC(object):
263 | def __init__(self, n, m):
264 | assert n >= 1
265 | assert 1 <= m <= 10
266 | self.n = n
267 | self.m = m
268 | self.augment_pool = fixmatch_augment_pool()
269 |
270 | def __call__(self, img):
271 | ops = random.choices(self.augment_pool, k=self.n)
272 | for op, max_v, bias in ops:
273 | v = np.random.randint(1, self.m)
274 | if random.random() < 0.5:
275 | img = op(img, v=v, max_v=max_v, bias=bias)
276 | img = CutoutAbs(img, 16)
277 | return img
278 |
279 |
280 | class RandAugment(object):
281 | def __init__(self, n, m, prob=None):
282 | assert n >= 1
283 | assert 1 <= m <= 10
284 | if prob is not None:
285 | assert 0. <= prob <= 1.
286 | self.n = n
287 | self.m = m
288 | self.prob = prob
289 | self.augment_pool = imagenet_augment_pool()
290 |
291 | def __call__(self, img):
292 | ops = random.choices(self.augment_pool, k=self.n)
293 | for op, max_v, bias in ops:
294 | v = np.random.randint(1, self.m)
295 | if self.prob is not None:
296 | if random.random() < self.prob:
297 | img = op(img, v=v, max_v=max_v, bias=bias)
298 | else:
299 | img = op(img, v=v, max_v=max_v, bias=bias)
300 | return img
301 |
302 | def __repr__(self):
303 | return self.__class__.__name__ + '(m={0}, n={1}, prob={2})'.format(self.m, self.n, self.prob)
304 |
--------------------------------------------------------------------------------
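A short usage sketch for the RandAugment policy defined above, composed with standard torchvision preprocessing (the random input image and the n/m/prob settings are purely illustrative):

import numpy as np
from PIL import Image
from torchvision import transforms
from data.randaugment import RandAugment

# A random RGB image stands in for a real face crop.
img = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))

# Apply 2 randomly chosen ops at magnitude 10, each with probability 0.5.
aug = transforms.Compose([
    transforms.RandomResizedCrop(224),
    RandAugment(n=2, m=10, prob=0.5),
    transforms.ToTensor(),
])
x = aug(img)  # tensor of shape (3, 224, 224)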
/data/fer2013.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | """
6 |
7 | __author__ = "GZ"
8 |
9 | import os
10 | import sys
11 | import csv
12 | import pathlib
13 | import numpy as np
14 | from tqdm import tqdm
15 | from typing import Any, Callable, Optional, Tuple
16 | from PIL import Image
17 | import matplotlib.pyplot as plt
18 |
19 | if sys.platform == 'win32':
20 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
21 |
22 | import torch
23 | from torchvision.datasets import VisionDataset
24 | import torchvision.transforms as transforms
25 |
26 | # Root directory of the project
27 | try:
28 | abspath = os.path.abspath(__file__)
29 | except NameError:
30 | abspath = os.getcwd()
31 | ROOT_DIR = os.path.dirname(abspath)
32 |
33 |
34 | # List of folders for training, validation and test.
35 | folder_names = {'Training' : 'FER2013Train',
36 | 'PublicTest' : 'FER2013Valid',
37 | 'PrivateTest': 'FER2013Test'}
38 |
39 |
40 | class FER2013(VisionDataset):
41 | """`FER2013
42 | `_ Dataset.
43 |
44 | Args:
45 | root (string): Root directory of dataset where directory
46 | ``root/fer2013`` exists.
47 | split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
48 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
49 | version. E.g, ``transforms.RandomCrop``
50 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
51 | """
52 |
53 | def __init__(
54 | self,
55 | root: str,
56 | split: str = "train",
57 | transform: Optional[Callable] = None,
58 | target_transform: Optional[Callable] = None,
59 | convert_rgb=False
60 | ) -> None:
61 | self._split = split
62 | assert split in ['Training', 'PublicTest', 'PrivateTest']
63 | super().__init__(root, transform=transform, target_transform=target_transform)
64 |
65 | self.convert_rgb = convert_rgb
66 |
67 | base_folder = pathlib.Path(self.root)
68 | data_file = base_folder / "fer2013.csv"
69 |
70 | self._samples = []
71 | with open(data_file, "r", newline="") as file:
72 | for row in csv.DictReader(file):
73 | if split == row["Usage"]:
74 | data = (
75 | torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
76 | int(row["emotion"]) if "emotion" in row else None,
77 | )
78 | self._samples.append(data)
79 |
80 | def __len__(self) -> int:
81 | return len(self._samples)
82 |
83 | def __getitem__(self, idx: int):
84 | image_tensor, target = self._samples[idx]
85 | image = Image.fromarray(image_tensor.numpy())
86 |
87 | if self.convert_rgb:
88 | image = image.convert("RGB")
89 |
90 | if self.transform is not None:
91 | image = self.transform(image)
92 |
93 | if self.target_transform is not None:
94 | target = self.target_transform(target)
95 |
96 | return image, target, idx
97 |
98 |
99 | def extra_repr(self) -> str:
100 | return f"split={self._split}"
101 |
102 |
103 | class FERplus(VisionDataset):
104 | """
105 | https://github.com/microsoft/FERPlus/blob/master/src/ferplus.py
106 | https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks
107 | """
108 | def __init__(
109 | self,
110 | root: str,
111 | split: str = "Training",
112 | transform: Optional[Callable] = None,
113 | target_transform: Optional[Callable] = None,
114 | convert_rgb=False
115 | ) -> None:
116 | self._split = split
117 | assert split in ['Training', 'PublicTest', 'PrivateTest']
118 | super().__init__(root, transform=transform, target_transform=target_transform)
119 |
120 | self.convert_rgb = convert_rgb
121 | self.per_emotion_count = None
122 |
123 | # Default values
124 | self.emotion_count = 8
125 |
126 | # Load data
127 | self.loaded_data = self._load()
128 | print('Size of the loaded set: {}'.format(self.loaded_data[0].shape[0]))
129 |
130 | def __len__(self):
131 | return self.loaded_data[0].shape[0]
132 |
133 | def __getitem__(self, idx):
134 | image = self.loaded_data[0][idx]
135 | image = Image.fromarray(image)
136 | target = self.loaded_data[1][idx]
137 |
138 | if self.transform is not None:
139 | image = self.transform(image)
140 |
141 | return image, target, idx
142 |
143 | # @staticmethod
144 | # def get_class(idx):
145 | # classes = {
146 | # 0: 'Neutral',
147 | # 1: 'Happy',
148 | # 2: 'Sad',
149 | # 3: 'Surprise',
150 | # 4: 'Fear',
151 | # 5: 'Disgust',
152 | # 6: 'Anger',
153 | # 7: 'Contempt'}
154 | #
155 | # return classes[idx]
156 | #
157 | # @staticmethod
158 | # def _parse_to_label(idx):
159 | # """
160 | # Parse labels to make them compatible with AffectNet.
161 | # :param idx:
162 | # :return:
163 | # """
164 | # emo_to_return = np.argmax(idx)
165 | #
166 | # if emo_to_return == 2:
167 | # emo_to_return = 3
168 | # elif emo_to_return == 3:
169 | # emo_to_return = 2
170 | # elif emo_to_return == 4:
171 | # emo_to_return = 6
172 | # elif emo_to_return == 6:
173 | # emo_to_return = 4
174 | #
175 | # return emo_to_return
176 |
177 | @staticmethod
178 | def _process_data(emotion_raw):
179 | size = len(emotion_raw)
180 | emotion_unknown = [0.0] * size
181 | emotion_unknown[-2] = 1.0
182 |
183 | # remove emotions with a single vote (outlier removal)
184 | for i in range(size):
185 | if emotion_raw[i] < 1.0 + sys.float_info.epsilon:
186 | emotion_raw[i] = 0.0
187 |
188 | sum_list = sum(emotion_raw)
189 | emotion = [0.0] * size
190 |
191 | # find the peak value of the emo_raw list
192 | maxval = max(emotion_raw)
193 | if maxval > 0.5 * sum_list:
194 | emotion[np.argmax(emotion_raw)] = maxval
195 | else:
196 | emotion = emotion_unknown # force setting as unknown
197 |
198 | return [float(i) / sum(emotion) for i in emotion]
199 |
200 | def _load(self):
201 | csv_label = []
202 | data, labels = [], []
203 | self.per_emotion_count = np.zeros(self.emotion_count, dtype=np.int32)
204 |
205 | path_folders_images = os.path.join(self.root, 'Images', folder_names[self._split])
206 | path_folders_labels = os.path.join(self.root, 'Labels', folder_names[self._split])
207 |
208 | with open(os.path.join(path_folders_labels, "label.csv")) as csvfile:
209 | lines = csv.reader(csvfile)
210 | for row in lines:
211 | csv_label.append(row)
212 |
213 | for l in csv_label:
214 | emotion_raw = list(map(float, l[2:len(l)]))
215 | emotion = self._process_data(emotion_raw)
216 | idx = np.argmax(emotion)
217 |
218 | if idx < self.emotion_count: # not unknown or non-face
219 | self.per_emotion_count[idx] += 1
220 |
221 | # emotion = emotion[:-2]
222 | # emotion = [float(i) / sum(emotion) for i in emotion]
223 | # emotion = self._parse_to_label(emotion)
224 |
225 | image = Image.open(os.path.join(path_folders_images, l[0]))
226 | if self.convert_rgb:
227 | image = image.convert("RGB")
228 | image = np.array(image)
229 |
230 | box = list(map(int, l[1][1:-1].split(',')))
231 | if box[-1] != 48:
232 | print("[INFO] Face is not centralized.")
233 | print(os.path.join(path_folders_images, l[0]))
234 | print(box)
235 | exit(-1)
236 |
237 | image = image[box[0]:box[2], box[1]:box[3], :]
238 |
239 | data.append(image)
240 | labels.append(idx)
241 |
242 | return [np.array(data), np.array(labels)]
243 |
244 |
245 | if __name__ == '__main__':
246 | display_transform = transforms.Compose([
247 | transforms.Resize((96, 96)),
248 | transforms.ToTensor()
249 | ])
250 |
251 | split = "PrivateTest"
252 | # dataset = FER2013(root="../data/FER/fer2013", split=split, transform=display_transform)
253 | dataset = FERplus(root="../data/FER/FERPlus/data", split=split, transform=display_transform, convert_rgb=True)
254 | print(dataset)
255 |
256 | loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True,
257 | drop_last=False)
258 |
259 | with torch.no_grad():
260 | for i, (images, target, _) in enumerate(tqdm(loader)):
261 | img = np.clip(images.cpu().numpy(), 0, 1) # [0, 1]
262 | img = img.transpose(0, 2, 3, 1)
263 | img = (img * 255).astype(np.uint8)
264 | img = img.squeeze()
265 |
266 | fig, axs = plt.subplots(1, 1, figsize=(8, 8))
267 | axs.imshow(img, cmap='gray')
268 | axs.axis("off")
269 | plt.show()
270 |
--------------------------------------------------------------------------------
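A minimal sketch of loading one FER2013 split (assuming fer2013.csv sits directly under the given root, as the loader above expects):

import torch
from torchvision import transforms
from data.fer2013 import FER2013

transform = transforms.Compose([transforms.Resize((96, 96)), transforms.ToTensor()])
dataset = FER2013(root="../data/FER/fer2013", split="PrivateTest",
                  transform=transform, convert_rgb=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=4)
image, label, idx = dataset[0]  # every item also carries its dataset index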
/data/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0
3 |
4 | import os
5 | import random
6 | from io import BytesIO
7 | from PIL import Image
8 | from PIL import ImageOps, ImageFilter
9 | import torch
10 | from torchvision import transforms
11 |
12 | from .randaugment import RandAugment
13 |
14 |
15 | # RGB mean tensor([0.5885, 0.4407, 0.3724])
16 | # RGB std tensor([0.2271, 0.1961, 0.1827])
17 |
18 | # RGB mean tensor([0.5231, 0.4044, 0.3489])
19 | # RGB std tensor([0.2536, 0.2194, 0.2070])
20 |
21 | IMG_MEAN = {"vggface2": [0.5231, 0.4044, 0.3489],
22 | "laionface": [0.48145466, 0.4578275, 0.40821073],
23 | "in1k": [0.485, 0.456, 0.406],
24 | "in100": [0.485, 0.456, 0.406]}
25 | IMG_STD = {"vggface2": [0.2536, 0.2194, 0.2070],
26 | "laionface": [0.26862954, 0.26130258, 0.27577711],
27 | "in1k": [0.229, 0.224, 0.225],
28 | "in100": [0.229, 0.224, 0.225]}
29 |
30 |
31 | class Solarize(object):
32 | def __init__(self, threshold=128):
33 | self.threshold = threshold
34 |
35 | def __call__(self, img):
36 | return ImageOps.solarize(img, self.threshold)
37 |
38 | def __repr__(self):
39 | repr_str = self.__class__.__name__
40 | return repr_str
41 |
42 |
43 | class GaussianBlur(object):
44 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
45 |
46 | def __init__(self, sigma=[.1, 2.]):
47 | self.sigma = sigma
48 |
49 | def __call__(self, x):
50 | sigma = random.uniform(self.sigma[0], self.sigma[1])
51 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
52 | return x
53 |
54 |
55 | class AddGaussianNoise(object):
56 | def __init__(self, mean=0., std=1.):
57 | self.std = std
58 | self.mean = mean
59 |
60 | def __call__(self, tensor):
61 | tensor = tensor + torch.randn(tensor.size()) * self.std + self.mean
62 | tensor = torch.clamp(tensor, min=0., max=1.)
63 | return tensor
64 |
65 | def __repr__(self):
66 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
67 |
68 |
69 | class PcaAug(object):
70 | _eigval = torch.Tensor([0.2175, 0.0188, 0.0045])
71 | _eigvec = torch.Tensor([
72 | [-0.5675, 0.7192, 0.4009],
73 | [-0.5808, -0.0045, -0.8140],
74 | [-0.5836, -0.6948, 0.4203],
75 | ])
76 |
77 | def __init__(self, alpha=0.1):
78 | self.alpha = alpha
79 |
80 | def __call__(self, im):
81 | alpha = torch.randn(3) * self.alpha
82 | rgb = (self._eigvec * alpha.expand(3, 3) * self._eigval.expand(3, 3)).sum(1)
83 | return im + rgb.reshape(3, 1, 1)
84 |
85 |
86 | class JPEGNoise(object):
87 | def __init__(self, low=30, high=99):
88 | self.low = low
89 | self.high = high
90 |
91 | def __call__(self, im):
92 | H = im.height
93 | W = im.width
94 | rW = max(int(0.8 * W), int(W * (1 + 0.5 * torch.randn([]))))
95 | im = transforms.functional.resize(im, (rW, rW))
96 | buf = BytesIO()
97 | im.save(buf, format='JPEG', quality=torch.randint(self.low, self.high,
98 | []).item())
99 | im = Image.open(buf)
100 | im = transforms.functional.resize(im, (H, W))
101 | return im
102 |
103 |
104 | def get_augmentations(aug_type, dataset):
105 | normalize = transforms.Normalize(mean=IMG_MEAN[dataset.lower()],
106 | std=IMG_STD[dataset.lower()])
107 |
108 | default_train_augs = [
109 | transforms.RandomResizedCrop(224),
110 | transforms.RandomHorizontalFlip(),
111 | ]
112 | default_val_augs = [
113 | transforms.Resize(256),
114 | transforms.CenterCrop(224),
115 | ]
116 | appendix_augs = [
117 | transforms.ToTensor(),
118 | normalize,
119 | ]
120 | if aug_type == 'DefaultTrain':
121 | augs = default_train_augs + appendix_augs
122 | elif aug_type == 'DefaultVal':
123 | augs = default_val_augs + appendix_augs
124 | elif aug_type == 'RandAugment':
125 | augs = default_train_augs + [RandAugment(n=2, m=10)] + appendix_augs
126 | elif aug_type == 'MoCoV1':
127 | augs = [
128 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
129 | transforms.RandomGrayscale(p=0.2),
130 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
131 | transforms.RandomHorizontalFlip()
132 | ] + appendix_augs
133 | elif aug_type == 'MoCoV2':
134 | augs = [
135 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
136 | transforms.RandomApply([
137 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened
138 | ], p=0.8),
139 | transforms.RandomGrayscale(p=0.2),
140 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
141 | transforms.RandomHorizontalFlip(),
142 | ] + appendix_augs
143 | else:
144 | raise NotImplementedError('augmentation type not found: {}'.format(aug_type))
145 |
146 | return augs
147 |
148 |
149 | def get_transforms(aug_type, dataset="in1k"):
150 | augs = get_augmentations(aug_type, dataset)
151 | return transforms.Compose(augs)
152 |
153 |
154 | def get_byol_tranforms():
155 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
156 | std=[0.229, 0.224, 0.225])
157 | augmentation1 = [
158 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
159 | transforms.RandomHorizontalFlip(),
160 | transforms.RandomApply([
161 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened
162 | ], p=0.8),
163 | transforms.RandomGrayscale(p=0.2),
164 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=1.),
165 | transforms.RandomApply([Solarize()], p=0.),
166 | transforms.ToTensor(),
167 | normalize
168 | ]
169 | augmentation2 = [
170 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
171 | transforms.RandomHorizontalFlip(),
172 | transforms.RandomApply([
173 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened
174 | ], p=0.8),
175 | transforms.RandomGrayscale(p=0.2),
176 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.1),
177 | transforms.RandomApply([Solarize()], p=0.2),
178 | transforms.ToTensor(),
179 | normalize
180 | ]
181 | transform1 = transforms.Compose(augmentation1)
182 | transform2 = transforms.Compose(augmentation2)
183 | return transform1, transform2
184 |
185 |
186 | def get_vggface_tranforms(image_size=128):
187 | normalize = transforms.Normalize(mean=IMG_MEAN["vggface2"],
188 | std=IMG_STD["vggface2"])
189 |
190 | augmentation1 = [
191 | transforms.RandomResizedCrop(image_size, scale=(0.2, 1.)),
192 | # transforms.Resize([image_size, image_size]),
193 | transforms.RandomHorizontalFlip(),
194 | transforms.RandomApply([
195 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened
196 | ], p=0.8),
197 | transforms.RandomGrayscale(p=0.2),
198 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=1.),
199 | transforms.RandomApply([Solarize()], p=0.),
200 | transforms.ToTensor(),
201 | normalize
202 | ]
203 | augmentation2 = [
204 | transforms.RandomResizedCrop(image_size, scale=(0.2, 1.)),
205 | # transforms.Resize([image_size, image_size]),
206 | transforms.RandomHorizontalFlip(),
207 | transforms.RandomApply([
208 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened
209 | ], p=0.8),
210 | transforms.RandomGrayscale(p=0.2),
211 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.1),
212 | transforms.RandomApply([Solarize()], p=0.2),
213 | transforms.ToTensor(),
214 | normalize
215 | ]
216 | transform1 = transforms.Compose(augmentation1)
217 | transform2 = transforms.Compose(augmentation2)
218 | return transform1, transform2
219 |
220 |
221 | class TwoCropsTransform:
222 | """Take two random crops of one image."""
223 |
224 | def __init__(self, transform1, transform2):
225 | self.transform1 = transform1
226 | self.transform2 = transform2
227 |
228 | def __call__(self, x):
229 | out1 = self.transform1(x)
230 | out2 = self.transform2(x)
231 | return out1, out2
232 |
233 | def __repr__(self):
234 | format_string = self.__class__.__name__ + '('
235 | names = ['transform1', 'transform2']
236 | for idx, t in enumerate([self.transform1, self.transform2]):
237 | format_string += '\n'
238 | t_string = '{0}={1}'.format(names[idx], t)
239 | t_string_split = t_string.split('\n')
240 | t_string_split = [' ' + tstr for tstr in t_string_split]
241 | t_string = '\n'.join(t_string_split)
242 | format_string += '{0}'.format(t_string)
243 | format_string += '\n)'
244 | return format_string
245 |
246 |
247 | if __name__ == '__main__':
248 | from utils.utils import dump_image
249 |
250 | # Ryan_Gosling Emily_VanCamp
251 | img = Image.open("./vis_data/0008_01.jpg")
252 |
253 | augment = get_vggface_tranforms(image_size=224)
254 | img1 = augment[0](img)
255 | img2 = augment[1](img)
256 |
257 | save_dir = "./output"
258 | os.makedirs(save_dir, exist_ok=True)
259 |
260 | filepath = os.path.join(save_dir, "{}.png".format("img1"))
261 | dump_image(img1, IMG_MEAN["vggface2"], IMG_STD["vggface2"], filepath=filepath)
262 |
263 | filepath = os.path.join(save_dir, "{}.png".format("img2"))
264 | dump_image(img2, IMG_MEAN["vggface2"], IMG_STD["vggface2"], filepath=filepath)
265 |
--------------------------------------------------------------------------------
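A brief sketch of how the asymmetric BYOL-style augmentation pair and TwoCropsTransform above combine to turn one image into two differently augmented views (the random input image is purely illustrative):

import numpy as np
from PIL import Image
from data.transforms import get_byol_tranforms, TwoCropsTransform

transform1, transform2 = get_byol_tranforms()
two_crop = TwoCropsTransform(transform1, transform2)

img = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))  # stand-in image
view1, view2 = two_crop(img)  # two 3x224x224 tensors with independent augmentations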
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0
3 |
4 | import os
5 | import numpy as np
6 | import torch.utils.data as data
7 | import torchvision.datasets as datasets
8 | import torchvision.transforms as transforms
9 | from torchvision.datasets.folder import ImageFolder, default_loader
10 | from PIL import Image
11 |
12 | try:
13 | import mc
14 | except ImportError:
15 | mc = None
16 | import io
17 |
18 |
19 | # class DatasetCache(data.Dataset):
20 | # def __init__(self):
21 | # super().__init__()
22 | # self.initialized = False
23 | #
24 | #
25 | # def _init_memcached(self):
26 | # if not self.initialized:
27 | # server_list_config_file = "/mnt/cache/share/memcached_client/server_list.conf"
28 | # client_config_file = "/mnt/cache/share/memcached_client/client.conf"
29 | # self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, client_config_file)
30 | # self.initialized = True
31 | #
32 | # def load_image(self, filename):
33 | # self._init_memcached()
34 | # value = mc.pyvector()
35 | # self.mclient.Get(filename, value)
36 | # value_str = mc.ConvertBuffer(value)
37 | #
38 | # buff = io.BytesIO(value_str)
39 | # with Image.open(buff) as img:
40 | # img = img.convert('RGB')
41 | # return img
42 | #
43 | #
44 | #
45 | # class BaseDataset(DatasetCache):
46 | # def __init__(self, mode='train', max_class=1000, aug=None,
47 | # prefix='/mnt/cache/share/images/meta',
48 | # image_folder_prefix='/mnt/cache/share/images/'):
49 | # super().__init__()
50 | # self.initialized = False
51 | #
52 | # if mode == 'train':
53 | # image_list = os.path.join(prefix, 'train.txt')
54 | # self.image_folder = os.path.join(image_folder_prefix, 'train')
55 | # elif mode == 'test':
56 | # image_list = os.path.join(prefix, 'test.txt')
57 | # self.image_folder = os.path.join(image_folder_prefix, 'test')
58 | # elif mode == 'val':
59 | # image_list = os.path.join(prefix, 'val.txt')
60 | # self.image_folder = os.path.join(image_folder_prefix, 'val')
61 | # else:
62 | # raise NotImplementedError('mode: ' + mode + ' does not exist please select from [train, test, val]')
63 | #
64 | #
65 | # self.samples = []
66 | # with open(image_list) as f:
67 | # for line in f:
68 | # name, label = line.split()
69 | # label = int(label)
70 | # if label < max_class:
71 | # self.samples.append((label, name))
72 | #
73 | # if aug is None:
74 | # if mode == 'train':
75 | # self.transform = transforms.Compose([
76 | # transforms.RandomResizedCrop(224),
77 | # transforms.RandomHorizontalFlip(),
78 | # transforms.ToTensor(),
79 | # transforms.Normalize(mean=[0.485, 0.456, 0.406],
80 | # std=[0.229, 0.224, 0.225])
81 | # ])
82 | # else:
83 | # self.transform = transforms.Compose([
84 | # transforms.Resize(256),
85 | # transforms.CenterCrop(224),
86 | # transforms.ToTensor(),
87 | # transforms.Normalize(mean=[0.485, 0.456, 0.406],
88 | # std=[0.229, 0.224, 0.225]),
89 | # ])
90 | #
91 | # else:
92 | # self.transform = aug
93 | #
94 | #
95 | # def get_keep_index(samples, percent, num_classes, shuffle=False):
96 | # labels = np.array([sample[0] for sample in samples])
97 | # keep_indexs = []
98 | # for i in range(num_classes):
99 | # idx = np.where(labels == i)[0]
100 | # num_sample = len(idx)
101 | # label_per_class = min(max(1, round(percent * num_sample)), num_sample)
102 | # if shuffle:
103 | # np.random.shuffle(idx)
104 | # keep_indexs.extend(idx[:label_per_class])
105 | #
106 | # return keep_indexs
107 | #
108 | #
109 | # class ImageNet(BaseDataset):
110 | # def __init__(self, mode='train', max_class=1000, num_classes=1000, transform=None,
111 | # percent=1., shuffle=False, **kwargs):
112 | # super().__init__(mode, max_class, aug=transform, **kwargs)
113 | #
114 | # assert 0 <= percent <= 1
115 | # if percent < 1:
116 | # keep_indexs = get_keep_index(self.samples, percent, num_classes, shuffle)
117 | # self.samples = [self.samples[i] for i in keep_indexs]
118 | #
119 | # def __len__(self):
120 | # return self.samples.__len__()
121 | #
122 | # def __getitem__(self, index):
123 | # label, name = self.samples[index]
124 | # filename = os.path.join(self.image_folder, name)
125 | # img = self.load_image(filename)
126 | # return self.transform(img), label, index
127 | #
128 | #
129 | # class ImageNetWithIdx(BaseDataset):
130 | # def __init__(self, mode='train', max_class=1000, num_classes=1000, transform=None,
131 | # idx=None, shuffle=False, **kwargs):
132 | # super().__init__(mode, max_class, aug=transform, **kwargs)
133 | #
134 | # assert idx is not None
135 | # with open(idx, "r") as fin:
136 | # samples = [line.strip().split(" ") for line in fin.readlines()]
137 | # self.samples = samples
138 | # print(f"Len of training set: {len(self.samples)}")
139 | #
140 | # def __len__(self):
141 | # return self.samples.__len__()
142 | #
143 | # def __getitem__(self, index):
144 | # label, name = self.samples[index]
145 | # filename = os.path.join(self.image_folder, name)
146 | # img = self.load_image(filename)
147 | # return self.transform(img), int(label), index
148 | #
149 | #
150 | # class ImageNet100(ImageNet):
151 | # def __init__(self, **kwargs):
152 | # super().__init__(
153 | # num_classes=100,
154 | # prefix='/mnt/lustre/huanglang/research/selfsup/data/imagenet-100/',
155 | # image_folder_prefix='/mnt/lustre/huanglang/research/selfsup/data/images',
156 | # **kwargs)
157 | #
158 | # class ImageFolderWithPercent(ImageFolder):
159 | #
160 | # def __init__(self, root, transform=None, target_transform=None,
161 | # loader=default_loader, is_valid_file=None, percent=1.0, shuffle=False):
162 | # super().__init__(root, transform=transform, target_transform=target_transform,
163 | # loader=loader, is_valid_file=is_valid_file)
164 | # assert 0 <= percent <= 1
165 | # if percent < 1:
166 | # keep_indexs = get_keep_index(self.targets, percent, len(self.classes), shuffle)
167 | # self.samples = [self.samples[i] for i in keep_indexs]
168 | # self.targets = [self.targets[i] for i in keep_indexs]
169 | # self.imgs = self.samples
170 | #
171 | #
172 | # class ImageFolderWithIndex(ImageFolder):
173 | #
174 | # def __init__(self, root, indexs=None, transform=None, target_transform=None,
175 | # loader=default_loader, is_valid_file=None):
176 | # super().__init__(root, transform=transform, target_transform=target_transform,
177 | # loader=loader, is_valid_file=is_valid_file)
178 | # if indexs is not None:
179 | # self.samples = [self.samples[i] for i in indexs]
180 | # self.targets = [self.targets[i] for i in indexs]
181 | # self.imgs = self.samples
182 |
183 |
184 | class ImageFolderInstance(datasets.ImageFolder):
185 | def __getitem__(self, index):
186 | path, target = self.samples[index]
187 | sample = self.loader(path)
188 | if self.transform is not None:
189 | sample = self.transform(sample)
190 | if self.target_transform is not None:
191 | target = self.target_transform(target)
192 |
193 | return sample, target, index
194 |
195 |
196 | class ImageFolderSubset(datasets.ImageFolder):
197 | """Folder datasets which returns the index of the image (for memory_bank)
198 | """
199 | def __init__(self, class_path, root, transform, **kwargs):
200 | super().__init__(root, transform, **kwargs)
201 | self.class_path = class_path
202 | new_samples, sorted_classes = self.get_class_samples()
203 | self.imgs = self.samples = new_samples # len=126689
204 | self.classes = sorted_classes
205 | self.class_to_idx = {cls_name: i for i, cls_name in enumerate(sorted_classes)}
206 | self.targets = [s[1] for s in self.samples]
207 |
208 | def get_class_samples(self):
209 | classes = open(self.class_path).readlines()
210 | classes = [m.strip() for m in classes]
211 | classes = set(classes)
212 | class_to_sample = [[os.path.basename(os.path.dirname(m[0])), m] for m in self.imgs]
213 | selected_samples = [m[1] for m in class_to_sample if m[0] in classes]
214 |
215 | sorted_classes = sorted(list(classes))
216 | target_mapping = {self.class_to_idx[k]: j for j, k in enumerate(sorted_classes)}
217 |
218 | valid_pairs = [[m[0], target_mapping[m[1]]] for m in selected_samples]
219 | return valid_pairs, sorted_classes
220 |
221 | def __getitem__(self, index):
222 | path, target = self.samples[index]
223 | sample = self.loader(path)
224 | if self.transform is not None:
225 | sample = self.transform(sample)
226 | if self.target_transform is not None:
227 | target = self.target_transform(target)
228 |
229 | return sample, target, index
230 |
231 |
232 | def get_dataset(dataset, mode, transform, data_root=None, **kwargs):
233 | data_dir = os.path.join(data_root, mode)
234 | if mode == "val" and "ImageNet" in data_root and "nobackup_mmv_ioannisp" in data_root:
235 | data_dir = "/import/nobackup_mmv_ioannisp/zg002/data/ImageNet/val"
236 | in100_class_path = "./data/imagenet100.txt"
237 |
238 | if dataset.lower() == 'in1k':
239 | return ImageFolderInstance(data_dir, transform=transform)
240 | elif dataset.lower() == 'in100':
241 | return ImageFolderSubset(in100_class_path, data_dir, transform)
242 | elif dataset.lower() == "vggface2":
243 | return ImageFolderInstance(data_dir, transform=transform)
244 | # elif dataset == 'in1k_idx':
245 | # return ImageNetWithIdx(mode, transform=transform, **kwargs)
246 | # else: # ImageFolder
247 | # data_dir = os.path.join(data_root, mode)
248 | # assert os.path.isdir(data_dir)
249 | # return ImageFolderWithPercent(data_dir, transform, **kwargs)
250 |
251 |
--------------------------------------------------------------------------------
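A minimal sketch of building a training set with the get_dataset helper above (the data_root path is hypothetical; an ImageFolder-style layout <data_root>/train/<class>/<image> is assumed):

from data.base_dataset import get_dataset
from data.transforms import get_transforms

transform = get_transforms("MoCoV2", dataset="vggface2")
train_set = get_dataset("vggface2", "train", transform, data_root="/path/to/VGGFace2")  # hypothetical path
image, target, index = train_set[0]  # ImageFolderInstance also returns the sample index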
/engine.py:
--------------------------------------------------------------------------------
1 | # Original copyright Amazon.com, Inc. or its affiliates, under CC-BY-NC-4.0 License.
2 | # Modifications Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved.
3 | # SPDX-License-Identifier: CC-BY-NC-4.0
4 |
5 | import time
6 | from datetime import timedelta
7 | import numpy as np
8 | try:
9 | import faiss
10 | except ImportError:
11 | pass
12 |
13 | import torch
14 | import torch.nn as nn
15 | from classy_vision.generic.distributed_util import is_distributed_training_run
16 |
17 | from utils import utils
18 | from utils.dist_utils import all_reduce_mean
19 |
20 | def validate(val_loader, model, criterion, args):
21 | batch_time = utils.AverageMeter('Time', ':6.3f')
22 | losses = utils.AverageMeter('Loss', ':.4e')
23 | top1 = utils.AverageMeter('Acc@1', ':6.2f')
24 | top5 = utils.AverageMeter('Acc@5', ':6.2f')
25 | progress = utils.ProgressMeter(
26 | len(val_loader),
27 | [batch_time, losses, top1, top5],
28 | prefix='Test: ')
29 |
30 | # switch to evaluate mode
31 | model.eval()
32 |
33 | with torch.no_grad():
34 | end = time.time()
35 | for i, (images, target, _) in enumerate(val_loader):
36 | if args.gpu is not None:
37 | images = images.cuda(args.gpu, non_blocking=True)
38 | target = target.cuda(args.gpu, non_blocking=True)
39 |
40 | # compute output
41 | output = model(images)
42 | loss = criterion(output, target)
43 |
44 | # measure accuracy and record loss
45 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
46 |
47 | if is_distributed_training_run():
48 | # torch.distributed.barrier()
49 | acc1 = all_reduce_mean(acc1)
50 | acc5 = all_reduce_mean(acc5)
51 |
52 | losses.update(loss.item(), images.size(0))
53 | top1.update(acc1[0], images.size(0))
54 | top5.update(acc5[0], images.size(0))
55 |
56 | # measure elapsed time
57 | batch_time.update(time.time() - end)
58 | end = time.time()
59 |
60 | if i % args.print_freq == 0:
61 | progress.display(i)
62 |
63 | # TODO: this should also be done with the ProgressMeter
64 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} Loss {loss.avg:.4f}'
65 | .format(top1=top1, top5=top5, loss=losses))
66 |
67 | return top1.avg
68 |
69 |
70 | def ss_validate(val_loader_base, val_loader_query, model, args):
71 | print("start KNN evaluation with key size={} and query size={}".format(
72 | len(val_loader_base.dataset.samples), len(val_loader_query.dataset.samples)))
73 | batch_time_key = utils.AverageMeter('Time', ':6.3f')
74 | batch_time_query = utils.AverageMeter('Time', ':6.3f')
75 | # switch to evaluate mode
76 | model.eval()
77 |
78 | feats_base = []
79 | target_base = []
80 | feats_query = []
81 | target_query = []
82 |
83 | with torch.no_grad():
84 | start = time.time()
85 | end = time.time()
86 | # Memory features
87 | for i, (images, target, _) in enumerate(val_loader_base):
88 | if args.gpu is not None:
89 | images = images.cuda(args.gpu, non_blocking=True)
90 | target = target.cuda(args.gpu, non_blocking=True)
91 |
92 | # compute features
93 | feats = model(images)
94 | # L2 normalization
95 | feats = nn.functional.normalize(feats, dim=1)
96 |
97 | feats_base.append(feats)
98 | target_base.append(target)
99 |
100 | # measure elapsed time
101 | batch_time_key.update(time.time() - end)
102 | end = time.time()
103 |
104 | if i % args.print_freq == 0:
105 | print('Extracting key features: [{0}/{1}]\t'
106 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
107 | i, len(val_loader_base), batch_time=batch_time_key))
108 |
109 | end = time.time()
110 | for i, (images, target, _) in enumerate(val_loader_query):
111 | if args.gpu is not None:
112 | images = images.cuda(args.gpu, non_blocking=True)
113 | target = target.cuda(args.gpu, non_blocking=True)
114 |
115 | # compute features
116 | feats = model(images)
117 | # L2 normalization
118 | feats = nn.functional.normalize(feats, dim=1)
119 |
120 | feats_query.append(feats)
121 | target_query.append(target)
122 |
123 | # measure elapsed time
124 | batch_time_query.update(time.time() - end)
125 | end = time.time()
126 |
127 | if i % args.print_freq == 0:
128 | print('Extracting query features: [{0}/{1}]\t'
129 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
130 | i, len(val_loader_query), batch_time=batch_time_query))
131 |
132 | feats_base = torch.cat(feats_base, dim=0)
133 | target_base = torch.cat(target_base, dim=0)
134 | feats_query = torch.cat(feats_query, dim=0)
135 | target_query = torch.cat(target_query, dim=0)
136 | feats_base = feats_base.detach().cpu().numpy()
137 | target_base = target_base.detach().cpu().numpy()
138 | feats_query = feats_query.detach().cpu().numpy()
139 | target_query = target_query.detach().cpu().numpy()
140 | feat_time = time.time() - start
141 |
142 | # KNN search
143 | index = faiss.IndexFlatL2(feats_base.shape[1])
144 | index.add(feats_base)
145 | D, I = index.search(feats_query, args.num_nn)
146 | preds = np.array([np.bincount(target_base[n]).argmax() for n in I])
147 |
148 | NN_acc = (preds == target_query).sum() / len(target_query) * 100.0
149 | knn_time = time.time() - start - feat_time
150 | print("finished KNN evaluation, feature time: {}, knn time: {}".format(
151 | timedelta(seconds=feat_time), timedelta(seconds=knn_time)))
152 | print(' * NN Acc@1 {:.3f}'.format(NN_acc))
153 |
154 | return NN_acc
155 |
156 |
157 |
158 | def ss_face_validate(val_loader, model, args, threshold=0.6):
159 | """
160 | https://github.com/sakshamjindal/Face-Matching
161 | """
162 | batch_time = utils.AverageMeter('Time', ':6.3f')
163 | top1 = utils.AverageMeter('Acc@1', ':6.2f')
164 | progress = utils.ProgressMeter(
165 | len(val_loader),
166 | [batch_time, top1],
167 | prefix='Test: ')
168 |
169 | cos = nn.CosineSimilarity(dim=1, eps=1e-6)
170 |
171 | # switch to evaluate mode
172 | model.eval()
173 | model = model.module if hasattr(model, 'module') else model
174 |
175 | with torch.no_grad():
176 | end = time.time()
177 | for i, (img1, img2, target) in enumerate(val_loader):
178 | img1 = img1.cuda(non_blocking=True)
179 | img2 = img2.cuda(non_blocking=True)
180 | target = target.cuda(non_blocking=True)
181 |
182 | # compute output
183 | embedding1, _, _ = model.online_net(img1)
184 | embedding2, _, _ = model.online_net(img2)
185 |
186 | embedding1 = embedding1.squeeze(-1)
187 | embedding2 = embedding2.squeeze(-1)
188 |
189 | assert embedding1.ndim == 2
190 |
191 | # measure accuracy and record loss
192 | cosine_similarity = cos(embedding1, embedding2)
193 | pred = (cosine_similarity >= threshold).to(torch.float32)
194 | acc1 = (pred == target).float().sum() * 100.0 / (target.shape[0])
195 |
196 | top1.update(acc1.item(), img1.size(0))
197 |
198 | # measure elapsed time
199 | batch_time.update(time.time() - end)
200 | end = time.time()
201 |
202 | if i % args.print_freq == 0:
203 | progress.display(i)
204 |
205 | # TODO: this should also be done with the ProgressMeter
206 | print(' * Acc@1 {top1.avg:.3f}'
207 | .format(top1=top1))
208 |
209 | return top1.avg
210 |
211 |
212 | def validate_multilabel(val_loader, model, criterion, args):
213 | batch_time = utils.AverageMeter('Time', ':6.3f')
214 | losses = utils.AverageMeter('Loss', ':.4e')
215 | top1 = utils.AverageMeter('Acc@1', ':6.2f')
216 | progress = utils.ProgressMeter(
217 | len(val_loader),
218 | [batch_time, losses, top1],
219 | prefix='Test: ')
220 |
221 | # switch to evaluate mode
222 | model.eval()
223 |
224 | with torch.no_grad():
225 | end = time.time()
226 | for i, (images, target, _) in enumerate(val_loader):
227 | if args.gpu is not None:
228 | images = images.cuda(args.gpu, non_blocking=True)
229 | target = target.cuda(args.gpu, non_blocking=True).float()
230 |
231 | # compute output
232 | output = model(images)
233 | loss = criterion(output, target)
234 |
235 | # measure accuracy and record loss
236 | acc1 = utils.accuracy_multilabel(torch.sigmoid(output), target)
237 |
238 | if is_distributed_training_run():
239 | # torch.distributed.barrier()
240 | acc1 = all_reduce_mean(acc1)
241 |
242 | losses.update(loss.item(), images.size(0))
243 | top1.update(acc1.item(), images.size(0))
244 |
245 | # measure elapsed time
246 | batch_time.update(time.time() - end)
247 | end = time.time()
248 |
249 | if i % args.print_freq == 0:
250 | progress.display(i)
251 |
252 | # TODO: this should also be done with the ProgressMeter
253 | print(' * Acc@1 {top1.avg:.3f} Loss {loss.avg:.4f}'
254 | .format(top1=top1, loss=losses))
255 |
256 | return top1.avg
257 |
258 |
259 | if __name__ == '__main__':
260 | import backbone as backbone_models
261 | from models import get_model
262 | import torchvision
263 | import torchvision.transforms as transforms
264 |
265 | model_func = get_model("LEWELB_EMAN")
266 | norm_layer = None
267 | model = model_func(
268 | backbone_models.__dict__["resnet50_encoder"],
269 | dim=256,
270 | m=0.996,
271 | hid_dim=4096,
272 | norm_layer=norm_layer,
273 | num_neck_mlp=2,
274 | scale=1.,
275 | l2_norm=True,
276 | num_heads=4,
277 | loss_weight=0.5,
278 | mask_type="max"
279 | )
280 | print(model)
281 |
282 | model.cuda()
283 |
284 | transform_test = transforms.Compose([
285 | transforms.Resize((224, 224)),
286 | # transforms.CenterCrop(args.image_size),
287 | transforms.ToTensor(),
288 | ])
289 | val_dataset = torchvision.datasets.LFWPairs(root="../data/lfw", split="test",
290 | transform=transform_test, download=True)
291 | print(set(val_dataset.targets))
292 |
293 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=8, pin_memory=True, persistent_workers=True)
294 |
295 | from argparse import Namespace
296 | ss_face_validate(val_loader, model, Namespace(print_freq=10))  # ss_face_validate reads args.print_freq
--------------------------------------------------------------------------------
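A small synthetic sanity check for the validate() routine above, using random tensors and a tiny linear classifier (assumes a single CUDA device; args carries only the fields validate reads):

import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from engine import validate

# Synthetic (image, label, index) triples, matching the loader protocol used throughout.
images = torch.randn(64, 3, 32, 32)
labels = torch.randint(0, 10, (64,))
indices = torch.arange(64)
loader = DataLoader(TensorDataset(images, labels, indices), batch_size=16)

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).cuda()
criterion = nn.CrossEntropyLoss().cuda()
args = argparse.Namespace(gpu=0, print_freq=10)  # minimal hypothetical args
top1 = validate(loader, model, criterion, args)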
/data/celeba.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | """
6 |
7 | __author__ = "GZ"
8 |
9 | import os
10 | import sys
11 | import csv
12 | import pathlib
13 | import numpy as np
14 | from tqdm import tqdm
15 | from collections import namedtuple
16 | import csv
17 | from functools import partial
18 | from typing import Any, Callable, List, Optional, Union, Tuple
19 | import PIL
20 | from PIL import Image
21 | import matplotlib.pyplot as plt
22 |
23 | if sys.platform == 'win32':
24 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
25 |
26 | import torch
27 | from torchvision.datasets.utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg
28 | from torchvision.datasets import VisionDataset
29 | import torchvision.transforms as transforms
30 |
31 | # Root directory of the project
32 | try:
33 | abspath = os.path.abspath(__file__)
34 | except NameError:
35 | abspath = os.getcwd()
36 | ROOT_DIR = os.path.dirname(abspath)
37 |
38 |
39 | CSV = namedtuple("CSV", ["header", "index", "data"])
40 |
41 |
42 | class CelebA(VisionDataset):
43 | """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset.
44 |
45 | Args:
46 | root (string): Root directory where images are downloaded to.
47 | split (string): One of {'train', 'valid', 'test', 'all'}.
48 | Accordingly dataset is selected.
49 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
50 | or ``landmarks``. Can also be a list to output a tuple with all specified target types.
51 | The targets represent:
52 |
53 | - ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
54 | - ``identity`` (int): label for each person (data points with the same identity are the same person)
55 | - ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
56 | - ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
57 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
58 |
59 | Defaults to ``attr``. If empty, ``None`` will be returned as target.
60 |
61 | transform (callable, optional): A function/transform that takes in a PIL image
62 | and returns a transformed version. E.g., ``transforms.ToTensor``
63 | target_transform (callable, optional): A function/transform that takes in the
64 | target and transforms it.
65 | download (bool, optional): If true, downloads the dataset from the internet and
66 | puts it in root directory. If dataset is already downloaded, it is not
67 | downloaded again.
68 | """
69 |
70 | base_folder = "celeba"
71 | # There currently does not appear to be an easy way to extract 7z in Python (without introducing additional
72 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
73 | # right now.
74 | file_list = [
75 | # File ID MD5 Hash Filename
76 | # ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
77 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
78 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
79 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
80 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
81 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
82 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
83 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
84 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
85 | ]
86 |
87 | def __init__(
88 | self,
89 | root: str,
90 | split: str = "train",
91 | target_type: Union[List[str], str] = "attr",
92 | transform: Optional[Callable] = None,
93 | target_transform: Optional[Callable] = None,
94 | download: bool = False,
95 | crop=False
96 | ) -> None:
97 | super(CelebA, self).__init__(root, transform=transform,
98 | target_transform=target_transform)
99 | self.split = split
100 | self.crop = crop
101 | if isinstance(target_type, list):
102 | self.target_type = target_type
103 | else:
104 | self.target_type = [target_type]
105 |
106 | if not self.target_type and self.target_transform is not None:
107 | raise RuntimeError('target_transform is specified but target_type is empty')
108 |
109 | if download:
110 | self.download()
111 |
112 | if not self._check_integrity():
113 | raise RuntimeError('Dataset not found or corrupted.' +
114 | ' You can use download=True to download it')
115 |
116 | split_map = {
117 | "train": 0,
118 | "valid": 1,
119 | "test": 2,
120 | "all": None,
121 | }
122 | split_ = split_map[verify_str_arg(split.lower(), "split",
123 | ("train", "valid", "test", "all"))]
124 | splits = self._load_csv("list_eval_partition.txt")
125 | identity = self._load_csv("identity_CelebA.txt")
126 | bbox = self._load_csv("list_bbox_celeba.txt", header=1)
127 | landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
128 | attr = self._load_csv("list_attr_celeba.txt", header=1)
129 |
130 | mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
131 |
132 | if split_ is None: # if split == "all"
133 | self.filename = splits.index
134 | else:
135 | self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))]
136 | self.identity = identity.data[mask]
137 | self.bbox = bbox.data[mask]
138 | self.landmarks_align = landmarks_align.data[mask]
139 | self.attr = attr.data[mask]
140 | # map from {-1, 1} to {0, 1}
141 | self.attr = torch.div(self.attr + 1, 2, rounding_mode='floor')
142 | self.attr_names = attr.header
143 |
144 | def _load_csv(
145 | self,
146 | filename: str,
147 | header: Optional[int] = None,
148 | ) -> CSV:
149 | data, indices, headers = [], [], []
150 |
151 | fn = partial(os.path.join, self.root, self.base_folder)
152 | with open(fn(filename)) as csv_file:
153 | data = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True))
154 |
155 | if header is not None:
156 | headers = data[header]
157 | data = data[header + 1:]
158 |
159 | indices = [row[0] for row in data]
160 | data = [row[1:] for row in data]
161 | data_int = [list(map(int, i)) for i in data]
162 |
163 | return CSV(headers, indices, torch.tensor(data_int))
164 |
165 | def _check_integrity(self) -> bool:
166 | for (_, md5, filename) in self.file_list:
167 | fpath = os.path.join(self.root, self.base_folder, filename)
168 | _, ext = os.path.splitext(filename)
169 | # Allow original archive to be deleted (zip and 7z)
170 | # Only need the extracted images
171 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
172 | return False
173 |
174 | # Should check a hash of the images
175 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
176 |
177 | def download(self) -> None:
178 | import zipfile
179 |
180 | if self._check_integrity():
181 | print('Files already downloaded and verified')
182 | return
183 |
184 | for (file_id, md5, filename) in self.file_list:
185 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
186 |
187 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
188 | f.extractall(os.path.join(self.root, self.base_folder))
189 |
190 | def __getitem__(self, index: int):
191 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
192 |
193 | target: Any = []
194 | for t in self.target_type:
195 | if t == "attr":
196 | target.append(self.attr[index, :])
197 | elif t == "identity":
198 | target.append(self.identity[index, 0])
199 | elif t == "bbox":
200 | target.append(self.bbox[index, :])
201 | elif t == "landmarks":
202 | target.append(self.landmarks_align[index, :])
203 | else:
204 | # TODO: refactor with utils.verify_str_arg
205 | raise ValueError("Target type \"{}\" is not recognized.".format(t))
206 |
207 | if self.crop:
208 | bbox = self.bbox[index, :]
209 | width, height = X.size
210 | left = bbox[0]
211 | top = bbox[1]
212 | right = bbox[0] + bbox[2]
213 | bottom = bbox[1] + bbox[3]
214 | X = X.crop((left, top, right, bottom))
215 |
216 | if self.transform is not None:
217 | X = self.transform(X)
218 |
219 | if target:
220 | target = tuple(target) if len(target) > 1 else target[0]
221 |
222 | if self.target_transform is not None:
223 | target = self.target_transform(target)
224 | else:
225 | target = None
226 |
227 | return X, target, index
228 |
229 | def __len__(self) -> int:
230 | return len(self.attr)
231 |
232 | def extra_repr(self) -> str:
233 | lines = ["Target type: {target_type}", "Split: {split}"]
234 | return '\n'.join(lines).format(**self.__dict__)
235 |
236 |
237 | if __name__ == '__main__':
238 | # 218, 178
239 | display_transform = transforms.Compose([
240 | transforms.Resize((224, 224)),
241 | transforms.ToTensor()
242 | ])
243 |
244 | split = "train"
245 | dataset = CelebA(root="../data", split=split, transform=display_transform, crop=False)
246 | print(dataset)
247 | print(dataset[0])
248 |
249 | loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True,
250 | drop_last=False)
251 |
252 | with torch.no_grad():
253 | for i, (images, target, _) in enumerate(tqdm(loader)):
254 | img = np.clip(images.cpu().numpy(), 0, 1) # [0, 1]
255 | img = img.transpose(0, 2, 3, 1)
256 | img = (img * 255).astype(np.uint8)
257 | img = img.squeeze()
258 |
259 | fig, axs = plt.subplots(1, 1, figsize=(8, 8))
260 | axs.imshow(img)
261 | axs.axis("off")
262 | plt.show()
263 |
--------------------------------------------------------------------------------
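A short sketch of loading CelebA with multiple target types (assuming the annotation files and img_align_celeba/ already live under root/celeba):

import torchvision.transforms as transforms
from data.celeba import CelebA

transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = CelebA(root="../data", split="train", target_type=["attr", "identity"],
                 transform=transform)
image, (attr, identity), index = dataset[0]  # attr: (40,) tensor in {0, 1}; identity: person-id tensor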
/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | # some code in this file is adapted from
2 | # https://github.com/pytorch/pytorch
3 | # Licensed under a BSD-style license.
4 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
5 | # SPDX-License-Identifier: CC-BY-NC-4.0
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | __all__ = ['resnet18_encoder', 'resnet34_encoder', 'resnet50_encoder', 'resnet101_encoder',
11 | 'resnet50w2x_encoder', 'resnet50w2x_cls']
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=dilation, groups=groups, bias=False, dilation=dilation)
18 |
19 |
20 | def conv1x1(in_planes, out_planes, stride=1):
21 | """1x1 convolution"""
22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | expansion = 1
27 |
28 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
29 | base_width=64, dilation=1, norm_layer=None):
30 | super(BasicBlock, self).__init__()
31 | if norm_layer is None:
32 | norm_layer = nn.BatchNorm2d
33 | if groups != 1 or base_width != 64:
34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
35 | if dilation > 1:
36 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
37 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
38 | self.conv1 = conv3x3(inplanes, planes, stride)
39 | self.bn1 = norm_layer(planes)
40 | self.relu = nn.ReLU(inplace=True)
41 | self.conv2 = conv3x3(planes, planes)
42 | self.bn2 = norm_layer(planes)
43 | self.downsample = downsample
44 | self.stride = stride
45 |
46 | def forward(self, x):
47 | identity = x
48 |
49 | out = self.conv1(x)
50 | out = self.bn1(out)
51 | out = self.relu(out)
52 |
53 | out = self.conv2(out)
54 | out = self.bn2(out)
55 |
56 | if self.downsample is not None:
57 | identity = self.downsample(x)
58 |
59 | out += identity
60 | out = self.relu(out)
61 |
62 | return out
63 |
64 |
65 | class Bottleneck(nn.Module):
66 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
67 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
68 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
69 | # This variant is also known as ResNet V1.5 and improves accuracy according to
70 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
71 |
72 | expansion = 4
73 |
74 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
75 | base_width=64, dilation=1, norm_layer=None):
76 | super(Bottleneck, self).__init__()
77 | if norm_layer is None:
78 | norm_layer = nn.BatchNorm2d
79 | width = int(planes * (base_width / 64.)) * groups
80 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
81 | self.conv1 = conv1x1(inplanes, width)
82 | self.bn1 = norm_layer(width)
83 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
84 | self.bn2 = norm_layer(width)
85 | self.conv3 = conv1x1(width, planes * self.expansion)
86 | self.bn3 = norm_layer(planes * self.expansion)
87 | self.relu = nn.ReLU(inplace=True)
88 | self.downsample = downsample
89 | self.stride = stride
90 |
91 | def forward(self, x):
92 | identity = x
93 |
94 | out = self.conv1(x)
95 | out = self.bn1(out)
96 | out = self.relu(out)
97 |
98 | out = self.conv2(out)
99 | out = self.bn2(out)
100 | out = self.relu(out)
101 |
102 | out = self.conv3(out)
103 | out = self.bn3(out)
104 |
105 | if self.downsample is not None:
106 | identity = self.downsample(x)
107 |
108 | out += identity
109 | out = self.relu(out)
110 |
111 | return out
112 |
113 |
114 | class ResNet(nn.Module):
115 |
116 | def __init__(self, block, layers, zero_init_residual=False,
117 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
118 | norm_layer=None, width_multiplier=1, with_avgpool=True):
119 | super(ResNet, self).__init__()
120 | if norm_layer is None:
121 | norm_layer = nn.BatchNorm2d
122 | self._norm_layer = norm_layer
123 |
124 | self.inplanes = 64 * width_multiplier
125 | self.dilation = 1
126 | if replace_stride_with_dilation is None:
127 | # each element in the tuple indicates if we should replace
128 | # the 2x2 stride with a dilated convolution instead
129 | replace_stride_with_dilation = [False, False, False]
130 | if len(replace_stride_with_dilation) != 3:
131 | raise ValueError("replace_stride_with_dilation should be None "
132 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
133 | self.groups = groups
134 | self.base_width = width_per_group
135 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
136 | bias=False)
137 | self.bn1 = norm_layer(self.inplanes)
138 | self.relu = nn.ReLU(inplace=True)
139 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
140 | self.layer1 = self._make_layer(block, 64 * width_multiplier, layers[0])
141 | self.layer2 = self._make_layer(block, 128 * width_multiplier, layers[1], stride=2,
142 | dilate=replace_stride_with_dilation[0])
143 | self.layer3 = self._make_layer(block, 256 * width_multiplier, layers[2], stride=2,
144 | dilate=replace_stride_with_dilation[1])
145 | self.layer4 = self._make_layer(block, 512 * width_multiplier, layers[3], stride=2,
146 | dilate=replace_stride_with_dilation[2])
147 | self.with_avgpool = with_avgpool
148 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) if with_avgpool else nn.Identity()
149 |
150 | self.out_channels = 512 * width_multiplier * block.expansion
151 |
152 | for m in self.modules():
153 | if isinstance(m, nn.Conv2d):
154 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
155 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
156 | nn.init.constant_(m.weight, 1)
157 | nn.init.constant_(m.bias, 0)
158 |
159 | # Zero-initialize the last BN in each residual branch,
160 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
161 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
162 | if zero_init_residual:
163 | for m in self.modules():
164 | if isinstance(m, Bottleneck):
165 | nn.init.constant_(m.bn3.weight, 0)
166 | elif isinstance(m, BasicBlock):
167 | nn.init.constant_(m.bn2.weight, 0)
168 |
169 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
170 | norm_layer = self._norm_layer
171 | downsample = None
172 | previous_dilation = self.dilation
173 | if dilate:
174 | self.dilation *= stride
175 | stride = 1
176 | if stride != 1 or self.inplanes != planes * block.expansion:
177 | downsample = nn.Sequential(
178 | conv1x1(self.inplanes, planes * block.expansion, stride),
179 | norm_layer(planes * block.expansion),
180 | )
181 |
182 | layers = []
183 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
184 | self.base_width, previous_dilation, norm_layer))
185 | self.inplanes = planes * block.expansion
186 | for _ in range(1, blocks):
187 | layers.append(block(self.inplanes, planes, groups=self.groups,
188 | base_width=self.base_width, dilation=self.dilation,
189 | norm_layer=norm_layer))
190 |
191 | return nn.Sequential(*layers)
192 |
193 | def _forward_impl(self, x):
194 | # See note [TorchScript super()]
195 | x = self.conv1(x)
196 | x = self.bn1(x)
197 | x = self.relu(x)
198 | x = self.maxpool(x)
199 |
200 | x = self.layer1(x)
201 | x = self.layer2(x)
202 | x = self.layer3(x)
203 | x = self.layer4(x)
204 |
205 | x = self.avgpool(x)
206 | x = torch.flatten(x, 1) if self.with_avgpool else x
207 |
208 | return x
209 |
210 | def forward(self, x):
211 | return self._forward_impl(x)
212 |
213 |
214 | class ResNetCls(ResNet):
215 |
216 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
217 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
218 | norm_layer=None, width_multiplier=1, normalize=False):
219 | super(ResNetCls, self).__init__(
220 | block, layers,
221 | zero_init_residual=zero_init_residual,
222 | groups=groups,
223 | width_per_group=width_per_group,
224 | replace_stride_with_dilation=replace_stride_with_dilation,
225 | norm_layer=norm_layer,
226 | width_multiplier=width_multiplier,
227 | )
228 | self.fc = nn.Linear(self.out_channels, num_classes)
229 | self.normalize = normalize
230 |
231 | def _forward_impl(self, x):
232 | # See note [TorchScript super()]
233 | x = self.conv1(x)
234 | x = self.bn1(x)
235 | x = self.relu(x)
236 | x = self.maxpool(x)
237 |
238 | x = self.layer1(x)
239 | x = self.layer2(x)
240 | x = self.layer3(x)
241 | x = self.layer4(x)
242 |
243 | x = self.avgpool(x)
244 | x = torch.flatten(x, 1)
245 | if self.normalize:
246 | x = nn.functional.normalize(x, dim=1)
247 | x = self.fc(x)
248 |
249 | return x
250 |
251 |
252 | def resnet18_encoder(**kwargs):
253 | r"""ResNet-18 model from
254 | `"Deep Residual Learning for Image Recognition" `_
255 | """
256 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
257 |
258 |
259 | def resnet34_encoder(**kwargs):
260 | r"""ResNet-34 model from
261 | `"Deep Residual Learning for Image Recognition" `_
262 | """
263 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
264 |
265 |
266 | def resnet50_encoder(**kwargs):
267 | r"""ResNet-50 model from
268 | `"Deep Residual Learning for Image Recognition" `_
269 | """
270 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
271 |
272 |
273 | def resnet50w2x_encoder(**kwargs):
274 | return ResNet(Bottleneck, [3, 4, 6, 3], width_multiplier=2, **kwargs)
275 |
276 |
277 | def resnet50w2x_cls(**kwargs):
278 | model = ResNetCls(Bottleneck, [3, 4, 6, 3], width_multiplier=2, **kwargs)
279 | return model
280 |
281 |
282 | def resnet101_encoder(**kwargs):
283 | r"""ResNet-101 model from
284 | `"Deep Residual Learning for Image Recognition" `_
285 | """
286 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
287 |
--------------------------------------------------------------------------------
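The encoder factories above return a ResNet that yields either a pooled feature vector or the full spatial map depending on with_avgpool, which is how EncoderObj in models/lewel.py consumes it. A minimal shape-check sketch (assumes it is run from the repository root; the 224x224 input size is only illustrative):

import torch
from backbone.resnet import resnet50_encoder

enc_vec = resnet50_encoder()                    # with_avgpool=True (default): pooled features
enc_map = resnet50_encoder(with_avgpool=False)  # keep the spatial map, as EncoderObj does

x = torch.randn(2, 3, 224, 224)
print(enc_vec(x).shape)  # torch.Size([2, 2048])        (N, out_channels)
print(enc_map(x).shape)  # torch.Size([2, 2048, 7, 7])  (N, out_channels, H/32, W/32)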
/models/transformers/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
3 | """
4 | Transformer class.
5 |
6 | Copy-paste from torch.nn.Transformer with modifications:
7 | * positional encodings are passed in MHattention
8 | * extra LN at the end of encoder is removed
9 | * decoder returns a stack of activations from all decoding layers
10 | """
11 | import copy
12 | from typing import List, Optional
13 |
14 | import torch
15 | import torch.nn.functional as F
16 | from torch import Tensor, nn
17 |
18 |
19 | class Transformer(nn.Module):
20 | def __init__(
21 | self,
22 | d_model=512,
23 | nhead=8,
24 | num_encoder_layers=6,
25 | num_decoder_layers=6,
26 | dim_feedforward=2048,
27 | dropout=0.1,
28 | activation="relu",
29 | normalize_before=False,
30 | return_intermediate_dec=False,
31 | ):
32 | super().__init__()
33 |
34 | encoder_layer = TransformerEncoderLayer(
35 | d_model, nhead, dim_feedforward, dropout, activation, normalize_before
36 | )
37 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
38 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
39 |
40 | decoder_layer = TransformerDecoderLayer(
41 | d_model, nhead, dim_feedforward, dropout, activation, normalize_before
42 | )
43 | decoder_norm = nn.LayerNorm(d_model)
44 | self.decoder = TransformerDecoder(
45 | decoder_layer,
46 | num_decoder_layers,
47 | decoder_norm,
48 | return_intermediate=return_intermediate_dec,
49 | )
50 |
51 | self._reset_parameters()
52 |
53 | self.d_model = d_model
54 | self.nhead = nhead
55 |
56 | def _reset_parameters(self):
57 | for p in self.parameters():
58 | if p.dim() > 1:
59 | nn.init.xavier_uniform_(p)
60 |
61 | def forward(self, src, mask, query_embed, pos_embed):
62 | # flatten NxCxHxW to HWxNxC
63 | bs, c, h, w = src.shape
64 | src = src.flatten(2).permute(2, 0, 1)
65 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
66 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
67 | if mask is not None:
68 | mask = mask.flatten(1)
69 |
70 | tgt = torch.zeros_like(query_embed)
71 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
72 | hs = self.decoder(
73 | tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
74 | )
75 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
76 |
77 |
78 | class TransformerEncoder(nn.Module):
79 | def __init__(self, encoder_layer, num_layers, norm=None):
80 | super().__init__()
81 | self.layers = _get_clones(encoder_layer, num_layers)
82 | self.num_layers = num_layers
83 | self.norm = norm
84 |
85 | def forward(
86 | self,
87 | src,
88 | mask: Optional[Tensor] = None,
89 | src_key_padding_mask: Optional[Tensor] = None,
90 | pos: Optional[Tensor] = None,
91 | ):
92 | output = src
93 |
94 | for layer in self.layers:
95 | output = layer(
96 | output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
97 | )
98 |
99 | if self.norm is not None:
100 | output = self.norm(output)
101 |
102 | return output
103 |
104 |
105 | class TransformerDecoder(nn.Module):
106 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
107 | super().__init__()
108 | self.layers = _get_clones(decoder_layer, num_layers)
109 | self.num_layers = num_layers
110 | self.norm = norm
111 | self.return_intermediate = return_intermediate
112 |
113 | def forward(
114 | self,
115 | tgt,
116 | memory,
117 | tgt_mask: Optional[Tensor] = None,
118 | memory_mask: Optional[Tensor] = None,
119 | tgt_key_padding_mask: Optional[Tensor] = None,
120 | memory_key_padding_mask: Optional[Tensor] = None,
121 | pos: Optional[Tensor] = None,
122 | query_pos: Optional[Tensor] = None,
123 | ):
124 | output = tgt
125 |
126 | intermediate = []
127 |
128 | for layer in self.layers:
129 | output = layer(
130 | output,
131 | memory,
132 | tgt_mask=tgt_mask,
133 | memory_mask=memory_mask,
134 | tgt_key_padding_mask=tgt_key_padding_mask,
135 | memory_key_padding_mask=memory_key_padding_mask,
136 | pos=pos,
137 | query_pos=query_pos,
138 | )
139 | if self.return_intermediate:
140 | intermediate.append(self.norm(output))
141 |
142 | if self.norm is not None:
143 | output = self.norm(output)
144 | if self.return_intermediate:
145 | intermediate.pop()
146 | intermediate.append(output)
147 |
148 | if self.return_intermediate:
149 | return torch.stack(intermediate)
150 |
151 | return output.unsqueeze(0)
152 |
153 |
154 | class TransformerEncoderLayer(nn.Module):
155 | def __init__(
156 | self,
157 | d_model,
158 | nhead,
159 | dim_feedforward=2048,
160 | dropout=0.1,
161 | activation="relu",
162 | normalize_before=False,
163 | ):
164 | super().__init__()
165 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
166 | # Implementation of Feedforward model
167 | self.linear1 = nn.Linear(d_model, dim_feedforward)
168 | self.dropout = nn.Dropout(dropout)
169 | self.linear2 = nn.Linear(dim_feedforward, d_model)
170 |
171 | self.norm1 = nn.LayerNorm(d_model)
172 | self.norm2 = nn.LayerNorm(d_model)
173 | self.dropout1 = nn.Dropout(dropout)
174 | self.dropout2 = nn.Dropout(dropout)
175 |
176 | self.activation = _get_activation_fn(activation)
177 | self.normalize_before = normalize_before
178 |
179 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
180 | return tensor if pos is None else tensor + pos
181 |
182 | def forward_post(
183 | self,
184 | src,
185 | src_mask: Optional[Tensor] = None,
186 | src_key_padding_mask: Optional[Tensor] = None,
187 | pos: Optional[Tensor] = None,
188 | ):
189 | q = k = self.with_pos_embed(src, pos)
190 | src2 = self.self_attn(
191 | q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
192 | )[0]
193 | src = src + self.dropout1(src2)
194 | src = self.norm1(src)
195 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
196 | src = src + self.dropout2(src2)
197 | src = self.norm2(src)
198 | return src
199 |
200 | def forward_pre(
201 | self,
202 | src,
203 | src_mask: Optional[Tensor] = None,
204 | src_key_padding_mask: Optional[Tensor] = None,
205 | pos: Optional[Tensor] = None,
206 | ):
207 | src2 = self.norm1(src)
208 | q = k = self.with_pos_embed(src2, pos)
209 | src2 = self.self_attn(
210 | q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
211 | )[0]
212 | src = src + self.dropout1(src2)
213 | src2 = self.norm2(src)
214 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
215 | src = src + self.dropout2(src2)
216 | return src
217 |
218 | def forward(
219 | self,
220 | src,
221 | src_mask: Optional[Tensor] = None,
222 | src_key_padding_mask: Optional[Tensor] = None,
223 | pos: Optional[Tensor] = None,
224 | ):
225 | if self.normalize_before:
226 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
227 | return self.forward_post(src, src_mask, src_key_padding_mask, pos)
228 |
229 |
230 | class TransformerDecoderLayer(nn.Module):
231 | def __init__(
232 | self,
233 | d_model,
234 | nhead,
235 | dim_feedforward=2048,
236 | dropout=0.1,
237 | activation="relu",
238 | normalize_before=False,
239 | ):
240 | super().__init__()
241 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
242 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
243 | # Implementation of Feedforward model
244 | self.linear1 = nn.Linear(d_model, dim_feedforward)
245 | self.dropout = nn.Dropout(dropout)
246 | self.linear2 = nn.Linear(dim_feedforward, d_model)
247 |
248 | self.norm1 = nn.LayerNorm(d_model)
249 | self.norm2 = nn.LayerNorm(d_model)
250 | self.norm3 = nn.LayerNorm(d_model)
251 | self.dropout1 = nn.Dropout(dropout)
252 | self.dropout2 = nn.Dropout(dropout)
253 | self.dropout3 = nn.Dropout(dropout)
254 |
255 | self.activation = _get_activation_fn(activation)
256 | self.normalize_before = normalize_before
257 |
258 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
259 | return tensor if pos is None else tensor + pos
260 |
261 | def forward_post(
262 | self,
263 | tgt,
264 | memory,
265 | tgt_mask: Optional[Tensor] = None,
266 | memory_mask: Optional[Tensor] = None,
267 | tgt_key_padding_mask: Optional[Tensor] = None,
268 | memory_key_padding_mask: Optional[Tensor] = None,
269 | pos: Optional[Tensor] = None,
270 | query_pos: Optional[Tensor] = None,
271 | ):
272 | q = k = self.with_pos_embed(tgt, query_pos)
273 | tgt2 = self.self_attn(
274 | q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
275 | )[0]
276 | tgt = tgt + self.dropout1(tgt2)
277 | tgt = self.norm1(tgt)
278 | tgt2 = self.multihead_attn(
279 | query=self.with_pos_embed(tgt, query_pos),
280 | key=self.with_pos_embed(memory, pos),
281 | value=memory,
282 | attn_mask=memory_mask,
283 | key_padding_mask=memory_key_padding_mask,
284 | )[0]
285 | tgt = tgt + self.dropout2(tgt2)
286 | tgt = self.norm2(tgt)
287 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
288 | tgt = tgt + self.dropout3(tgt2)
289 | tgt = self.norm3(tgt)
290 | return tgt
291 |
292 | def forward_pre(
293 | self,
294 | tgt,
295 | memory,
296 | tgt_mask: Optional[Tensor] = None,
297 | memory_mask: Optional[Tensor] = None,
298 | tgt_key_padding_mask: Optional[Tensor] = None,
299 | memory_key_padding_mask: Optional[Tensor] = None,
300 | pos: Optional[Tensor] = None,
301 | query_pos: Optional[Tensor] = None,
302 | ):
303 | tgt2 = self.norm1(tgt)
304 | q = k = self.with_pos_embed(tgt2, query_pos)
305 | tgt2 = self.self_attn(
306 | q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
307 | )[0]
308 | tgt = tgt + self.dropout1(tgt2)
309 | tgt2 = self.norm2(tgt)
310 | tgt2 = self.multihead_attn(
311 | query=self.with_pos_embed(tgt2, query_pos),
312 | key=self.with_pos_embed(memory, pos),
313 | value=memory,
314 | attn_mask=memory_mask,
315 | key_padding_mask=memory_key_padding_mask,
316 | )[0]
317 | tgt = tgt + self.dropout2(tgt2)
318 | tgt2 = self.norm3(tgt)
319 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
320 | tgt = tgt + self.dropout3(tgt2)
321 | return tgt
322 |
323 | def forward(
324 | self,
325 | tgt,
326 | memory,
327 | tgt_mask: Optional[Tensor] = None,
328 | memory_mask: Optional[Tensor] = None,
329 | tgt_key_padding_mask: Optional[Tensor] = None,
330 | memory_key_padding_mask: Optional[Tensor] = None,
331 | pos: Optional[Tensor] = None,
332 | query_pos: Optional[Tensor] = None,
333 | ):
334 | if self.normalize_before:
335 | return self.forward_pre(
336 | tgt,
337 | memory,
338 | tgt_mask,
339 | memory_mask,
340 | tgt_key_padding_mask,
341 | memory_key_padding_mask,
342 | pos,
343 | query_pos,
344 | )
345 | return self.forward_post(
346 | tgt,
347 | memory,
348 | tgt_mask,
349 | memory_mask,
350 | tgt_key_padding_mask,
351 | memory_key_padding_mask,
352 | pos,
353 | query_pos,
354 | )
355 |
356 |
357 | def _get_clones(module, N):
358 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
359 |
360 |
361 | def _get_activation_fn(activation):
362 | """Return an activation function given a string"""
363 | if activation == "relu":
364 | return F.relu
365 | if activation == "gelu":
366 | return F.gelu
367 | if activation == "glu":
368 | return F.glu
369 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
370 |
--------------------------------------------------------------------------------
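A minimal shape check for the DETR-style Transformer above (all sizes are arbitrary, chosen only for illustration): the encoder consumes a flattened feature map plus positional embeddings, the decoder attends to it with learned queries, and forward returns the stacked decoder outputs together with the memory reshaped back to (bs, c, h, w).

import torch
from models.transformers.transformer import Transformer

bs, c, h, w, q = 2, 256, 7, 7, 8
model = Transformer(d_model=c, nhead=8, num_encoder_layers=1,
                    num_decoder_layers=1, return_intermediate_dec=True)

src = torch.randn(bs, c, h, w)     # backbone feature map
pos = torch.randn(bs, c, h, w)     # positional encoding, same shape as src
queries = torch.randn(q, c)        # learned region/object queries

hs, memory = model(src, mask=None, query_embed=queries, pos_embed=pos)
print(hs.shape)      # torch.Size([1, 2, 8, 256])  (num_decoder_layers, bs, queries, d_model)
print(memory.shape)  # torch.Size([2, 256, 7, 7])  (bs, c, h, w)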
/models/lewel.py:
--------------------------------------------------------------------------------
1 | # Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0
3 | import sys
4 | from math import cos, pi
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 | from torch.nn.modules import loss
9 | import torch.distributed as dist
10 | from classy_vision.generic.distributed_util import is_distributed_training_run
11 |
12 | from models.transformers.transformer_predictor import TransformerPredictor
13 | from utils import init
14 |
15 |
16 | class MLP1D(nn.Module):
17 | """
18 | The non-linear neck in byol: fc-bn-relu-fc
19 | """
20 | def __init__(self, in_channels, hid_channels, out_channels,
21 | norm_layer=None, bias=False, num_mlp=2):
22 | super(MLP1D, self).__init__()
23 | if norm_layer is None:
24 | norm_layer = nn.BatchNorm1d
25 | mlps = []
26 | for _ in range(num_mlp-1):
27 | mlps.append(nn.Conv1d(in_channels, hid_channels, 1, bias=bias))
28 | mlps.append(norm_layer(hid_channels))
29 | mlps.append(nn.ReLU(inplace=True))
30 | in_channels = hid_channels
31 | mlps.append(nn.Conv1d(hid_channels, out_channels, 1, bias=bias))
32 | self.mlp = nn.Sequential(*mlps)
33 |
34 | def init_weights(self, init_linear='normal'):
35 | init.init_weights(self, init_linear)
36 |
37 | def forward(self, x):
38 | x = self.mlp(x)
39 | return x
40 |
41 |
42 | class ObjectNeck(nn.Module):
43 | def __init__(self,
44 | in_channels,
45 | out_channels,
46 | hid_channels=None,
47 | num_layers=1,
48 | scale=1.,
49 | l2_norm=True,
50 | num_heads=8,
51 | norm_layer=None,
52 | mask_type="group",
53 | num_proto=64,
54 | temp=0.07,
55 | **kwargs):
56 | super(ObjectNeck, self).__init__()
57 |
58 | self.scale = scale
59 | self.l2_norm = l2_norm
60 | assert l2_norm
61 | self.num_heads = num_heads
62 | self.mask_type = mask_type
63 | self.temp = temp
64 | self.eps = 1e-7
65 |
66 | hid_channels = hid_channels or in_channels
67 | self.proj = MLP1D(in_channels, hid_channels, out_channels, norm_layer, num_mlp=num_layers)
68 | self.proj_obj = MLP1D(in_channels, hid_channels, out_channels, norm_layer, num_mlp=num_layers)
69 |
70 | if mask_type == "attn":
71 | # self.slot_embed = nn.Embedding(num_proto, out_channels)
72 | # self.proj_obj = MLP1D(out_channels, hid_channels, out_channels, norm_layer, num_mlp=num_layers)
73 | self.proj_attn = TransformerPredictor(in_channels=out_channels, hidden_dim=out_channels, num_queries=num_proto,
74 | nheads=8, dropout=0.1, dim_feedforward=out_channels, enc_layers=0,
75 | dec_layers=1, pre_norm=False, deep_supervision=False,
76 | mask_dim=out_channels, enforce_input_project=False,
77 | mask_classification=False, num_classes=0)
78 |
79 | def init_weights(self, init_linear='kaiming'):
80 | self.proj.init_weights(init_linear)
81 | self.proj_obj.init_weights(init_linear)
82 |
83 | def forward(self, x):
84 | out = {}
85 |
86 | b, c, h, w = x.shape
87 |
88 | # flatten and projection
89 | x_pool = F.adaptive_avg_pool2d(x, 1).flatten(2)
90 | x = x.flatten(2) # (bs, c, h*w)
91 | z = self.proj(torch.cat([x_pool, x], dim=2)) # (bs, d, 1+h*w)
92 | z_g, z_feat = torch.split(z, [1, x.shape[2]], dim=2) # (bs, d, 1), (bs, d, h*w)
93 |
94 | z_feat = z_feat.contiguous()
95 |
96 | if self.mask_type == "attn":
97 | z_feat = z_feat.view(b, -1, h, w)
98 | x = x.view(b, c, h, w)
99 | attn_out = self.proj_attn(z_feat, None)
100 | mask_embed = attn_out["mask_embed"] # (bs, q, c)
101 | out["mask_embed"] = mask_embed
102 |
103 | dots = torch.einsum('bqc,bchw->bqhw', F.normalize(mask_embed, dim=2), F.normalize(z_feat, dim=1))
104 | obj_attn = (dots / self.temp).softmax(dim=1) + self.eps
105 | # obj_attn = (dots / 1.0).softmax(dim=1) + self.eps
106 | slots = torch.einsum('bchw,bqhw->bqc', x, obj_attn / obj_attn.sum(dim=(2, 3), keepdim=True))
107 | # slots = torch.einsum('bchw,bqhw->bqc', z_feat, obj_attn / obj_attn.sum(dim=(2, 3), keepdim=True))
108 | obj_attn = obj_attn.view(b, -1, h * w)
109 | out["dots"] = dots
110 | else:
111 | # do attention according to obj attention map
112 | obj_attn = F.normalize(z_feat, dim=1) if self.l2_norm else z_feat
113 | obj_attn /= self.scale
114 | obj_attn = obj_attn.view(b, self.num_heads, -1, h * w) # (bs, h, d/h, h*w)
115 | obj_attn_raw = F.softmax(obj_attn, dim=-1)
116 |
117 | if self.mask_type == "group":
118 | obj_attn = F.softmax(obj_attn, dim=-1)
119 | x = x.view(b, self.num_heads, -1, h*w) # (bs, h, c/h, h*w)
120 | obj_val = torch.matmul(x, obj_attn.transpose(3, 2)) # (bs, h, c//h, d/h)
121 | obj_val = obj_val.view(b, c, obj_attn.shape[-2]) # (bs, c, d/h)
122 | elif self.mask_type == "max":
123 | obj_attn, _ = torch.max(obj_attn, dim=1) # (bs, d/h, h*w)
124 | # obj_attn = torch.mean(obj_attn, dim=1)
125 | obj_attn = F.softmax(obj_attn, dim=-1)
126 | obj_val = torch.matmul(x, obj_attn.transpose(2, 1)) # (bs, c, d/h)
127 | elif self.mask_type == "attn":
128 | obj_val = slots.transpose(2, 1) # (bs, c, q)
129 |
130 | # projection
131 | obj_val = self.proj_obj(obj_val) # (bs, d, d/h)
132 |
133 | out["obj_attn"] = obj_attn
134 | out["obj_attn_raw"] = obj_attn_raw
135 |
136 | return z_g, obj_val, out # (bs, d, 1), (bs, d, d//h), where the second dim is channel
137 |
138 | def extra_repr(self) -> str:
139 | parts = []
140 | for name in ["scale", "l2_norm", "num_heads"]:
141 | parts.append(f"{name}={getattr(self, name)}")
142 | return ", ".join(parts)
143 |
144 |
145 | class EncoderObj(nn.Module):
146 | def __init__(self, base_encoder, hid_dim, out_dim, norm_layer=None, num_mlp=2,
147 | scale=1., l2_norm=True, num_heads=8, mask_type="group", num_proto=64, temp=0.07):
148 | super(EncoderObj, self).__init__()
149 | self.backbone = base_encoder(norm_layer=norm_layer, with_avgpool=False)
150 | in_dim = self.backbone.out_channels
151 | self.neck = ObjectNeck(in_channels=in_dim, hid_channels=hid_dim, out_channels=out_dim,
152 | norm_layer=norm_layer, num_layers=num_mlp,
153 | scale=scale, l2_norm=l2_norm, num_heads=num_heads, mask_type=mask_type,
154 | num_proto=num_proto, temp=temp)
155 | # self.neck.init_weights(init_linear='kaiming')
156 |
157 | def forward(self, im):
158 | out = self.backbone(im)
159 | out = self.neck(out)
160 | return out
161 |
162 |
163 | class LEWELB_EMAN(nn.Module):
164 | def __init__(self, base_encoder, dim=256, m=0.996, hid_dim=4096, norm_layer=None, num_neck_mlp=2,
165 | scale=1., l2_norm=True, num_heads=8, loss_weight=0.5, mask_type="group", num_proto=64,
166 | teacher_temp=0.07, student_temp=0.1, loss_w_cluster=0.5, **kwargs):
167 | super().__init__()
168 |
169 | self.base_m = m
170 | self.curr_m = m
171 | self.loss_weight = loss_weight
172 | self.loss_w_cluster = loss_w_cluster
173 | self.mask_type = mask_type
174 | assert mask_type in ["group", "max", "attn"]
175 | self.num_proto = num_proto
176 | self.student_temp = student_temp # 0.1
177 | self.teacher_temp = teacher_temp # 0.07
178 |
179 | # create the encoders
180 | # num_classes is the output fc dimension
181 | self.online_net = EncoderObj(base_encoder, hid_dim, dim, norm_layer, num_neck_mlp,
182 | scale=scale, l2_norm=l2_norm, num_heads=num_heads, mask_type=mask_type,
183 | num_proto=num_proto, temp=self.teacher_temp)
184 |
185 | # checkpoint = torch.load("./checkpoints/lewel_b_400ep.pth", map_location="cpu")
186 | # msg = self.online_net.backbone.load_state_dict(checkpoint)
187 | # assert set(msg.missing_keys) == set()
188 | # state_dict = checkpoint['state_dict']
189 | # state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
190 | # state_dict = {k.replace("online_net.backbone.", ""): v for k, v in state_dict.items()}
191 | # self.online_net.backbone.load_state_dict(state_dict)
192 |
193 | self.target_net = EncoderObj(base_encoder, hid_dim, dim, norm_layer, num_neck_mlp,
194 | scale=scale, l2_norm=l2_norm, num_heads=num_heads, mask_type=mask_type,
195 | num_proto=num_proto, temp=self.teacher_temp)
196 | self.predictor = MLP1D(dim, hid_dim, dim, norm_layer=norm_layer)
197 | # self.predictor.init_weights()
198 | self.predictor_obj = MLP1D(dim, hid_dim, dim, norm_layer=norm_layer)
199 | # self.predictor_obj.init_weights()
200 | self.encoder_q = self.online_net.backbone
201 |
202 | # copy params from online model to target model
203 | for param_ol, param_tgt in zip(self.online_net.parameters(), self.target_net.parameters()):
204 | param_tgt.data.copy_(param_ol.data) # initialize
205 | param_tgt.requires_grad = False # not update by gradient
206 |
207 | self.center_momentum = 0.9
208 | self.register_buffer("center", torch.zeros(1, self.num_proto))
209 |
210 | def mse_loss(self, pred, target):
211 | """
212 | Args:
213 | pred (Tensor): NxC input features.
214 | target (Tensor): NxC target features.
215 | """
216 | N = pred.size(0)
217 | pred_norm = nn.functional.normalize(pred, dim=1)
218 | target_norm = nn.functional.normalize(target, dim=1)
219 | loss = 2 - 2 * (pred_norm * target_norm).sum() / N
220 | return loss
221 |
222 | def self_distill(self, q, k):
223 | q = F.log_softmax(q / self.student_temp, dim=-1)
224 | k = F.softmax((k - self.center) / self.teacher_temp, dim=-1)
225 | return torch.sum(-k * q, dim=-1).mean()
226 |
227 | def loss_func(self, online, target):
228 | z_o, obj_o, res_o = online
229 | z_t, obj_t, res_t = target
230 | # instance-level loss
231 | z_o_pred = self.predictor(z_o).squeeze(-1)
232 | z_t = z_t.squeeze(-1)
233 | loss_inst = self.mse_loss(z_o_pred, z_t)
234 | # object-level loss
235 | b, c, n = obj_o.shape
236 | obj_o_pred = self.predictor_obj(obj_o).transpose(2, 1).reshape(b*n, c)
237 | obj_t = obj_t.transpose(2, 1).reshape(b*n, c)
238 | loss_obj = self.mse_loss(obj_o_pred, obj_t)
239 |
240 | # score_q = torch.einsum('bnc,bc->bn', F.normalize(obj_o_pred, dim=2), F.normalize(z_o_pred, dim=1))
241 | # score_k = torch.einsum('bnc,bc->bn', F.normalize(obj_t, dim=2), F.normalize(z_t, dim=1))
242 | # score_q = torch.einsum('bnc,bc->bn', F.normalize(obj_o.transpose(2, 1), dim=2), F.normalize(z_o.squeeze(-1), dim=1))
243 | # # score_q = torch.einsum('bnc,bc->bn', F.normalize(obj_t, dim=2), F.normalize(z_o.squeeze(-1), dim=1))
244 | # score_k = torch.einsum('bnc,bc->bn', F.normalize(obj_t, dim=2), F.normalize(z_t, dim=1))
245 |
246 | # score_q = torch.einsum('bnc,bc->bn', F.normalize(res_o["mask_embed"], dim=2), F.normalize(z_o.squeeze(-1), dim=1))
247 | # score_q = torch.einsum('bnc,bc->bn', F.normalize(res_o["mask_embed"], dim=2), F.normalize(z_t, dim=1))
248 | # score_q = torch.einsum('bnc,bc->bn', F.normalize(res_t["mask_embed"], dim=2), F.normalize(z_o.squeeze(-1), dim=1))
249 | # score_k = torch.einsum('bnc,bc->bn', F.normalize(res_t["mask_embed"], dim=2), F.normalize(z_t, dim=1))
250 | # loss_relation = self.self_distill(score_q, score_k)
251 |
252 | # score_q_1 = torch.einsum('bnc,bc->bn', F.normalize(res_o["mask_embed"], dim=2), F.normalize(z_t, dim=1))
253 | # score_q_2 = torch.einsum('bnc,bc->bn', F.normalize(res_t["mask_embed"], dim=2), F.normalize(z_o.squeeze(-1), dim=1))
254 | # score_k = torch.einsum('bnc,bc->bn', F.normalize(res_t["mask_embed"], dim=2), F.normalize(z_t, dim=1))
255 | # loss_relation = 0.5 * (self.self_distill(score_q_1, score_k) + self.self_distill(score_q_2, score_k))
256 |
257 | loss_base = loss_inst * self.loss_weight + loss_obj * (1 - self.loss_weight)
258 |
259 | # sum
260 | return loss_base, loss_inst, loss_obj
261 |
262 | @torch.no_grad()
263 | def momentum_update(self, cur_iter, max_iter):
264 | """
265 | Momentum update of the target network.
266 | """
267 | # momentum annealing
268 | momentum = 1. - (1. - self.base_m) * (cos(pi * cur_iter / float(max_iter)) + 1) / 2.0
269 | self.curr_m = momentum
270 | # parameter update for target network
271 | state_dict_ol = self.online_net.state_dict()
272 | state_dict_tgt = self.target_net.state_dict()
273 | for (k_ol, v_ol), (k_tgt, v_tgt) in zip(state_dict_ol.items(), state_dict_tgt.items()):
274 | assert k_tgt == k_ol, "state_dict names are different!"
275 | assert v_ol.shape == v_tgt.shape, "state_dict shapes are different!"
276 | if 'num_batches_tracked' in k_tgt:
277 | v_tgt.copy_(v_ol)
278 | else:
279 | v_tgt.copy_(v_tgt * momentum + (1. - momentum) * v_ol)
280 |
281 | @torch.no_grad()
282 | def update_center(self, teacher_output):
283 | """
284 | Update center used for teacher output.
285 | """
286 | batch_center = torch.mean(teacher_output, dim=0, keepdim=True)
287 | if is_distributed_training_run():
288 | dist.all_reduce(batch_center)
289 | batch_center = batch_center / dist.get_world_size()
290 |
291 | # ema update
292 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
293 |
294 | def get_heatmap(self, x):
295 | _, _, out = self.online_net(x)
296 | return out
297 |
298 | def ctr_loss(self, online_1, online_2, target_1, target_2):
299 | z_o_1, obj_o_1, res_o_1 = online_1
300 | z_o_2, obj_o_2, res_o_2 = online_2
301 | z_t_1, obj_t_1, res_t_1 = target_1
302 | z_t_2, obj_t_2, res_t_2 = target_2
303 |
304 | # corre_o = torch.matmul(F.normalize(res_o_1["mask_embed"], dim=2),
305 | # F.normalize(res_o_2["mask_embed"], dim=2).transpose(2, 1)) # b, q, c
306 | # corre_t = torch.matmul(F.normalize(res_t_1["mask_embed"], dim=2),
307 | # F.normalize(res_t_2["mask_embed"], dim=2).transpose(2, 1)) # b, q, c
308 | # loss = self.self_distill(corre_o.flatten(0, 1), corre_t.flatten(0, 1))
309 | # score = corre_t.flatten(0, 1)
310 |
311 | loss = 0.5 * (self.self_distill(res_o_1["dots"].permute(0, 2, 3, 1).flatten(0, 2),
312 | res_t_1["dots"].permute(0, 2, 3, 1).flatten(0, 2))
313 | + self.self_distill(res_o_2["dots"].permute(0, 2, 3, 1).flatten(0, 2),
314 | res_t_2["dots"].permute(0, 2, 3, 1).flatten(0, 2)))
315 | score_k1 = res_t_1["dots"]
316 | score_k2 = res_t_2["dots"]
317 | score = torch.cat([score_k1, score_k2]).permute(0, 2, 3, 1).flatten(0, 2)
318 |
319 | return loss, score
320 |
321 | def forward(self, im_v1, im_v2=None, **kwargs):
322 | """
323 | Input:
324 | im_v1: a batch of view1 images
325 | im_v2: a batch of view2 images
326 | Output:
327 | loss
328 | """
329 | # for inference, use the online_net backbone only
330 | if im_v2 is None:
331 | feats = self.online_net.backbone(im_v1)
332 | return F.adaptive_avg_pool2d(feats, 1).flatten(1)
333 |
334 | # compute online_net features
335 | proj_online_v1 = self.online_net(im_v1)
336 | proj_online_v2 = self.online_net(im_v2)
337 |
338 | # compute target_net features
339 | with torch.no_grad(): # no gradient to keys
340 | proj_target_v1 = [x.clone().detach() if isinstance(x, torch.Tensor) else x for x in self.target_net(im_v1)]
341 | proj_target_v2 = [x.clone().detach() if isinstance(x, torch.Tensor) else x for x in self.target_net(im_v2)]
342 |
343 | # loss. NOTE: the prediction is moved to loss_func
344 | loss_base1, loss_inst1, loss_obj1 = self.loss_func(proj_online_v1, proj_target_v2)
345 | loss_base2, loss_inst2, loss_obj2 = self.loss_func(proj_online_v2, proj_target_v1)
346 | loss_base = loss_base1 + loss_base2
347 |
348 | loss_relation, score = self.ctr_loss(proj_online_v1, proj_online_v2, proj_target_v1, proj_target_v2)
349 | loss = loss_base + loss_relation * self.loss_w_cluster
350 |
351 | loss_pack = {}
352 | loss_pack["base"] = loss_base
353 | loss_pack["inst"] = (loss_inst1 + loss_inst2) * self.loss_weight
354 | loss_pack["obj"] = (loss_obj1 + loss_obj2) * (1 - self.loss_weight)
355 | loss_pack["relation"] = loss_relation
356 |
357 | self.update_center(score)
358 |
359 | return loss, loss_pack
360 |
361 | def extra_repr(self) -> str:
362 | parts = []
363 | for name in ["loss_weight", "mask_type", "num_proto", "teacher_temp", "loss_w_cluster"]:
364 | parts.append(f"{name}={getattr(self, name)}")
365 | return ", ".join(parts)
366 |
367 |
368 | class LEWELB(LEWELB_EMAN):
369 | @torch.no_grad()
370 | def momentum_update(self, cur_iter, max_iter):
371 | """
372 | Momentum update of the target network.
373 | """
374 | # momentum annealing
375 | momentum = 1. - (1. - self.base_m) * (cos(pi * cur_iter / float(max_iter)) + 1) / 2.0
376 | self.curr_m = momentum
377 | # parameter update for target network
378 | for param_ol, param_tgt in zip(self.online_net.parameters(), self.target_net.parameters()):
379 | param_tgt.data = param_tgt.data * momentum + param_ol.data * (1. - momentum)
380 |
381 |
382 | if __name__ == '__main__':
383 | from models import get_model
384 | import backbone as backbone_models
385 |
386 | model_func = get_model("LEWELB_EMAN")
387 | norm_layer = None
388 | model = model_func(
389 | backbone_models.__dict__["resnet50_encoder"],
390 | dim=256,
391 | m=0.996,
392 | hid_dim=4096,
393 | norm_layer=norm_layer,
394 | num_neck_mlp=2,
395 | scale=1.,
396 | l2_norm=True,
397 | num_heads=4,
398 | loss_weight=0.5,
399 | mask_type="attn"
400 | )
401 | print(model)
402 |
403 | x1 = torch.randn(16, 3, 224, 224)
404 | x2 = torch.randn(16, 3, 224, 224)
405 | out = model(x1, x2)
406 |
--------------------------------------------------------------------------------
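momentum_update above anneals the EMA coefficient from base_m towards 1 over training with a cosine schedule, so the target network is updated ever more slowly as training progresses. A stand-alone sketch of just that schedule (iteration counts are illustrative):

from math import cos, pi

def ema_momentum(cur_iter, max_iter, base_m=0.996):
    # 1 - (1 - base_m) * (cos(pi * t / T) + 1) / 2: equals base_m at t=0 and 1.0 at t=T
    return 1. - (1. - base_m) * (cos(pi * cur_iter / float(max_iter)) + 1) / 2.

for t in (0, 5000, 10000):
    print(t, round(ema_momentum(t, 10000), 6))
# 0 0.996        start of training: the target follows the online network fastest
# 5000 0.998     halfway: the coefficient has moved halfway to 1
# 10000 1.0      end of training: the target is effectively frozen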
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 |
142 | Section 2 -- Scope.
143 |
144 | a. License grant.
145 |
146 | 1. Subject to the terms and conditions of this Public License,
147 | the Licensor hereby grants You a worldwide, royalty-free,
148 | non-sublicensable, non-exclusive, irrevocable license to
149 | exercise the Licensed Rights in the Licensed Material to:
150 |
151 | a. reproduce and Share the Licensed Material, in whole or
152 | in part, for NonCommercial purposes only; and
153 |
154 | b. produce, reproduce, and Share Adapted Material for
155 | NonCommercial purposes only.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. No downstream restrictions. You may not offer or impose
186 | any additional or different terms or conditions on, or
187 | apply any Effective Technological Measures to, the
188 | Licensed Material if doing so restricts exercise of the
189 | Licensed Rights by any recipient of the Licensed
190 | Material.
191 |
192 | 6. No endorsement. Nothing in this Public License constitutes or
193 | may be construed as permission to assert or imply that You
194 | are, or that Your use of the Licensed Material is, connected
195 | with, or sponsored, endorsed, or granted official status by,
196 | the Licensor or others designated to receive attribution as
197 | provided in Section 3(a)(1)(A)(i).
198 |
199 | b. Other rights.
200 |
201 | 1. Moral rights, such as the right of integrity, are not
202 | licensed under this Public License, nor are publicity,
203 | privacy, and/or other similar personality rights; however, to
204 | the extent possible, the Licensor waives and/or agrees not to
205 | assert any such rights held by the Licensor to the limited
206 | extent necessary to allow You to exercise the Licensed
207 | Rights, but not otherwise.
208 |
209 | 2. Patent and trademark rights are not licensed under this
210 | Public License.
211 |
212 | 3. To the extent possible, the Licensor waives any right to
213 | collect royalties from You for the exercise of the Licensed
214 | Rights, whether directly or through a collecting society
215 | under any voluntary or waivable statutory or compulsory
216 | licensing scheme. In all other cases the Licensor expressly
217 | reserves any right to collect such royalties, including when
218 | the Licensed Material is used other than for NonCommercial
219 | purposes.
220 |
221 |
222 | Section 3 -- License Conditions.
223 |
224 | Your exercise of the Licensed Rights is expressly made subject to the
225 | following conditions.
226 |
227 | a. Attribution.
228 |
229 | 1. If You Share the Licensed Material (including in modified
230 | form), You must:
231 |
232 | a. retain the following if it is supplied by the Licensor
233 | with the Licensed Material:
234 |
235 | i. identification of the creator(s) of the Licensed
236 | Material and any others designated to receive
237 | attribution, in any reasonable manner requested by
238 | the Licensor (including by pseudonym if
239 | designated);
240 |
241 | ii. a copyright notice;
242 |
243 | iii. a notice that refers to this Public License;
244 |
245 | iv. a notice that refers to the disclaimer of
246 | warranties;
247 |
248 | v. a URI or hyperlink to the Licensed Material to the
249 | extent reasonably practicable;
250 |
251 | b. indicate if You modified the Licensed Material and
252 | retain an indication of any previous modifications; and
253 |
254 | c. indicate the Licensed Material is licensed under this
255 | Public License, and include the text of, or the URI or
256 | hyperlink to, this Public License.
257 |
258 | 2. You may satisfy the conditions in Section 3(a)(1) in any
259 | reasonable manner based on the medium, means, and context in
260 | which You Share the Licensed Material. For example, it may be
261 | reasonable to satisfy the conditions by providing a URI or
262 | hyperlink to a resource that includes the required
263 | information.
264 |
265 | 3. If requested by the Licensor, You must remove any of the
266 | information required by Section 3(a)(1)(A) to the extent
267 | reasonably practicable.
268 |
269 | 4. If You Share Adapted Material You produce, the Adapter's
270 | License You apply must not prevent recipients of the Adapted
271 | Material from complying with this Public License.
272 |
273 |
274 | Section 4 -- Sui Generis Database Rights.
275 |
276 | Where the Licensed Rights include Sui Generis Database Rights that
277 | apply to Your use of the Licensed Material:
278 |
279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280 | to extract, reuse, reproduce, and Share all or a substantial
281 | portion of the contents of the database for NonCommercial purposes
282 | only;
283 |
284 | b. if You include all or a substantial portion of the database
285 | contents in a database in which You have Sui Generis Database
286 | Rights, then the database in which You have Sui Generis Database
287 | Rights (but not its individual contents) is Adapted Material; and
288 |
289 | c. You must comply with the conditions in Section 3(a) if You Share
290 | all or a substantial portion of the contents of the database.
291 |
292 | For the avoidance of doubt, this Section 4 supplements and does not
293 | replace Your obligations under this Public License where the Licensed
294 | Rights include other Copyright and Similar Rights.
295 |
296 |
297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298 |
299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309 |
310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319 |
320 | c. The disclaimer of warranties and limitation of liability provided
321 | above shall be interpreted in a manner that, to the extent
322 | possible, most closely approximates an absolute disclaimer and
323 | waiver of all liability.
324 |
325 |
326 | Section 6 -- Term and Termination.
327 |
328 | a. This Public License applies for the term of the Copyright and
329 | Similar Rights licensed here. However, if You fail to comply with
330 | this Public License, then Your rights under this Public License
331 | terminate automatically.
332 |
333 | b. Where Your right to use the Licensed Material has terminated under
334 | Section 6(a), it reinstates:
335 |
336 | 1. automatically as of the date the violation is cured, provided
337 | it is cured within 30 days of Your discovery of the
338 | violation; or
339 |
340 | 2. upon express reinstatement by the Licensor.
341 |
342 | For the avoidance of doubt, this Section 6(b) does not affect any
343 | right the Licensor may have to seek remedies for Your violations
344 | of this Public License.
345 |
346 | c. For the avoidance of doubt, the Licensor may also offer the
347 | Licensed Material under separate terms or conditions or stop
348 | distributing the Licensed Material at any time; however, doing so
349 | will not terminate this Public License.
350 |
351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352 | License.
353 |
354 |
355 | Section 7 -- Other Terms and Conditions.
356 |
357 | a. The Licensor shall not be bound by any additional or different
358 | terms or conditions communicated by You unless expressly agreed.
359 |
360 | b. Any arrangements, understandings, or agreements regarding the
361 | Licensed Material not stated herein are separate from and
362 | independent of the terms and conditions of this Public License.
363 |
364 |
365 | Section 8 -- Interpretation.
366 |
367 | a. For the avoidance of doubt, this Public License does not, and
368 | shall not be interpreted to, reduce, limit, restrict, or impose
369 | conditions on any use of the Licensed Material that could lawfully
370 | be made without permission under this Public License.
371 |
372 | b. To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/models/fra.py:
--------------------------------------------------------------------------------
1 | # Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0
3 | import sys
4 | import math
5 | from math import cos, pi
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import functional as F
9 | from torch.nn.modules import loss
10 | import torch.distributed as dist
11 | from classy_vision.generic.distributed_util import is_distributed_training_run
12 |
13 | from models.transformers.transformer_predictor import TransformerPredictor
14 | from utils import init
15 |
16 |
17 |
18 | @torch.no_grad()
19 | def distributed_sinkhorn(Q, num_itr=3, use_dist=True, epsilon=0.05):
20 | _got_dist = use_dist and torch.distributed.is_available() \
21 | and torch.distributed.is_initialized() \
22 | and (torch.distributed.get_world_size() > 1)
23 |
24 | if _got_dist:
25 | world_size = torch.distributed.get_world_size()
26 | else:
27 | world_size = 1
28 |
29 | Q = Q.T
30 | # Q = torch.exp(Q / epsilon).t()
31 | B = Q.shape[1] * world_size # number of samples to assign
32 | K = Q.shape[0] # how many prototypes
33 |
34 | # make the matrix sums to 1
35 | sum_Q = torch.sum(Q)
36 | if _got_dist:
37 | torch.distributed.all_reduce(sum_Q)
38 | Q /= sum_Q
39 |
40 | for it in range(num_itr):
41 | # normalize each row: total weight per prototype must be 1/K
42 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
43 | if _got_dist:
44 | torch.distributed.all_reduce(sum_of_rows)
45 | Q /= sum_of_rows
46 | Q /= K
47 |
48 | # normalize each column: total weight per sample must be 1/B
49 | Q /= torch.sum(Q, dim=0, keepdim=True)
50 | Q /= B
51 |
52 |     Q *= B  # the columns must sum to 1 so that Q is an assignment
53 | return Q.T
54 |
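# Illustrative sketch (not part of the original file): a quick single-process
# sanity check of distributed_sinkhorn. With no process group initialised the
# routine falls back to world_size == 1, and the returned soft assignment
# should give every sample a distribution over the prototypes that sums to one.
# The function name and sizes below are arbitrary.
def _sinkhorn_sanity_check(num_samples=32, num_proto=8):
    scores = torch.randn(num_samples, num_proto).softmax(dim=-1)  # positive scores
    assign = distributed_sinkhorn(scores, num_itr=3)              # (num_samples, num_proto)
    assert torch.allclose(assign.sum(dim=1), torch.ones(num_samples), atol=1e-4)
    return assign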
55 |
56 | class MLP1D(nn.Module):
57 | """
58 | The non-linear neck in byol: fc-bn-relu-fc
59 | """
60 | def __init__(self, in_channels, hid_channels, out_channels,
61 | norm_layer=None, bias=False, num_mlp=2):
62 | super(MLP1D, self).__init__()
63 | if norm_layer is None:
64 | norm_layer = nn.BatchNorm1d
65 | mlps = []
66 | for _ in range(num_mlp-1):
67 | mlps.append(nn.Conv1d(in_channels, hid_channels, 1, bias=bias))
68 | mlps.append(norm_layer(hid_channels))
69 | mlps.append(nn.ReLU(inplace=True))
70 | in_channels = hid_channels
71 | mlps.append(nn.Conv1d(hid_channels, out_channels, 1, bias=bias))
72 | self.mlp = nn.Sequential(*mlps)
73 |
74 | def init_weights(self, init_linear='normal'):
75 | init.init_weights(self, init_linear)
76 |
77 | def forward(self, x):
78 | x = self.mlp(x)
79 | return x
80 |
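# Illustrative sketch (not part of the original file): MLP1D is built from 1x1
# Conv1d layers, so it projects the channel dimension of a (batch, channels,
# tokens) tensor, e.g. the pooled feature (B, C, 1) fed to `proj` or the object
# tokens (B, C, Q) fed to `proj_obj`. The sizes below are arbitrary.
def _mlp1d_shape_example():
    neck = MLP1D(in_channels=2048, hid_channels=4096, out_channels=256)
    pooled = torch.randn(4, 2048, 1)   # e.g. a globally pooled backbone feature
    out = neck(pooled)                 # -> (4, 256, 1): channels projected, tokens kept
    return out.shape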
81 |
82 | class ObjectNeck(nn.Module):
83 | def __init__(self,
84 | in_channels,
85 | out_channels,
86 | hid_channels=None,
87 | num_layers=1,
88 | scale=1.,
89 | l2_norm=True,
90 | num_heads=8,
91 | norm_layer=None,
92 | mask_type="group",
93 | num_proto=64,
94 | temp=0.07,
95 | **kwargs):
96 | super(ObjectNeck, self).__init__()
97 |
98 | self.scale = scale
99 | self.l2_norm = l2_norm
100 | assert l2_norm
101 | self.num_heads = num_heads
102 | self.mask_type = mask_type
103 | self.temp = temp
104 | self.eps = 1e-7
105 |
106 | hid_channels = hid_channels or in_channels
107 | self.proj = MLP1D(in_channels, hid_channels, out_channels, norm_layer, num_mlp=num_layers)
108 | self.proj_pixel = MLP1D(in_channels, hid_channels, out_channels, norm_layer, num_mlp=num_layers)
109 | self.proj_obj = MLP1D(in_channels, hid_channels, out_channels, norm_layer, num_mlp=num_layers)
110 |
111 | if mask_type == "attn":
112 | self.proj_attn = TransformerPredictor(in_channels=in_channels, hidden_dim=out_channels, num_queries=num_proto,
113 | nheads=8, dropout=0.1, dim_feedforward=out_channels, enc_layers=0,
114 | dec_layers=2, pre_norm=False, deep_supervision=False,
115 | mask_dim=out_channels, enforce_input_project=False,
116 | mask_classification=False, num_classes=0)
117 |
118 | self.proto_momentum = 0.9
119 | self.register_buffer("proto", torch.randn(num_proto, out_channels))
120 | # self.proto = nn.Embedding(num_proto, out_channels)
121 |
122 | def init_weights(self, init_linear='kaiming'):
123 | self.proj.init_weights(init_linear)
124 | self.proj_pixel.init_weights(init_linear)
125 | self.proj_obj.init_weights(init_linear)
126 |
127 | @torch.no_grad()
128 | def update_proto(self, mask_embed):
129 |         """
130 |         EMA update of the prototype buffer with the batch-averaged mask embeddings.
131 |         """
132 | batch_center = torch.mean(mask_embed, dim=0)
133 | if is_distributed_training_run():
134 | dist.all_reduce(batch_center)
135 | batch_center = batch_center / dist.get_world_size()
136 |
137 | # ema update
138 | self.proto = self.proto * self.proto_momentum + batch_center * (1 - self.proto_momentum)
139 |
140 | def forward(self, x, isTrain=True):
141 | out = {}
142 |
143 | b, c, h, w = x.shape
144 |
145 | # flatten and projection
146 | x_pool = F.adaptive_avg_pool2d(x, 1).flatten(2)
147 | x = x.flatten(2) # (bs, c, h*w)
148 | z_g = self.proj(x_pool)
149 | z_feat = self.proj_pixel(x)
150 |
151 | if self.mask_type == "attn":
152 | z_feat = z_feat.view(b, -1, h, w)
153 | x = x.view(b, c, h, w)
154 | # attn_out = self.proj_attn(z_feat, None)
155 | attn_out = self.proj_attn(x, None)
156 | mask_embed = attn_out["mask_embed"] # (bs, q, c)
157 |
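            # "dots" below are cosine similarities between the prototype embeddings and
            # every spatial feature: during training the prototypes are the batch-averaged
            # (and, if distributed, all-reduced) mask embeddings, while at inference the
            # EMA prototype buffer is used instead. Either way the result is a
            # (bs, q, h, w) stack of per-prototype response maps (heatmaps).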
158 | if isTrain:
159 | # mask_embed = AllReduce.apply(torch.mean(mask_embed, dim=0, keepdim=True))
160 | mask_embed_avg = torch.mean(mask_embed, dim=0, keepdim=True)
161 | if is_distributed_training_run():
162 | dist.all_reduce(mask_embed_avg)
163 | mask_embed_avg = mask_embed_avg / dist.get_world_size()
164 | mask_embed_avg = mask_embed_avg.repeat(x.size(0), 1, 1)
165 | if z_feat.requires_grad:
166 | assert mask_embed_avg.requires_grad
167 |
168 | dots = torch.einsum('bqc,bchw->bqhw', F.normalize(mask_embed_avg, dim=2), F.normalize(z_feat, dim=1))
169 | else:
170 | dots = torch.einsum('qc,bchw->bqhw', F.normalize(self.proto, dim=1), F.normalize(z_feat, dim=1))
171 |
172 | obj_attn = (dots / self.scale).softmax(dim=1) + self.eps
173 |
174 | slots = torch.einsum('bchw,bqhw->bqc', x, obj_attn / obj_attn.sum(dim=(2, 3), keepdim=True))
175 |
176 | out["dots"] = dots
177 | out["feat"] = z_feat
178 | out["obj_attn"] = obj_attn
179 | else:
180 | # do attention according to obj attention map
181 | obj_attn = F.normalize(z_feat, dim=1) if self.l2_norm else z_feat
182 | obj_attn /= self.scale
183 | obj_attn = obj_attn.view(b, self.num_heads, -1, h * w) # (bs, h, d/h, h*w)
184 |
185 | if self.mask_type == "group":
186 | obj_attn = F.softmax(obj_attn, dim=-1)
187 | x = x.view(b, self.num_heads, -1, h*w) # (bs, h, c/h, h*w)
188 | obj_val = torch.matmul(x, obj_attn.transpose(3, 2)) # (bs, h, c//h, d/h)
189 | obj_val = obj_val.view(b, c, obj_attn.shape[-2]) # (bs, c, d/h)
190 | elif self.mask_type == "max":
191 | obj_attn, _ = torch.max(obj_attn, dim=1) # (bs, d/h, h*w)
192 | # obj_attn = torch.mean(obj_attn, dim=1)
193 | out["obj_attn"] = obj_attn
194 | obj_attn = F.softmax(obj_attn, dim=-1)
195 | obj_val = torch.matmul(x, obj_attn.transpose(2, 1)) # (bs, c, d/h)
196 | elif self.mask_type == "attn":
197 | obj_val = slots.transpose(2, 1) # (bs, c, q)
198 |
199 | # projection
200 | obj_val = self.proj_obj(obj_val) # (bs, d, q)
201 |
202 |         if isTrain and self.mask_type == "attn":
203 |             self.update_proto(mask_embed)
204 |
205 | return z_g, obj_val, out # (bs, d, 1), (bs, d, d//h), where the second dim is channel
206 |
207 | def extra_repr(self) -> str:
208 | parts = []
209 | for name in ["scale", "l2_norm", "num_heads"]:
210 | parts.append(f"{name}={getattr(self, name)}")
211 | return ", ".join(parts)
212 |
213 |
214 | class EncoderObj(nn.Module):
215 | def __init__(self, base_encoder, hid_dim, out_dim, norm_layer=None, num_mlp=2,
216 | scale=1., l2_norm=True, num_heads=8, mask_type="group", num_proto=64, temp=0.07):
217 | super(EncoderObj, self).__init__()
218 | self.backbone = base_encoder(norm_layer=norm_layer, with_avgpool=False)
219 | in_dim = self.backbone.out_channels
220 | self.neck = ObjectNeck(in_channels=in_dim, hid_channels=hid_dim, out_channels=out_dim,
221 | norm_layer=norm_layer, num_layers=num_mlp,
222 | scale=scale, l2_norm=l2_norm, num_heads=num_heads, mask_type=mask_type,
223 | num_proto=num_proto, temp=temp)
224 | self.neck.init_weights(init_linear='kaiming')
225 |
226 | def forward(self, im, isTrain=True):
227 | out = self.backbone(im)
228 | out = self.neck(out, isTrain)
229 | return out
230 |
231 |
232 | class FRAB_EMAN(nn.Module):
233 | def __init__(self, base_encoder, dim=256, m=0.996, hid_dim=4096, norm_layer=None, num_neck_mlp=2,
234 | scale=1., l2_norm=True, num_heads=8, loss_weight=0.5, mask_type="group", num_proto=8,
235 | teacher_temp=0.04, student_temp=0.1, loss_w_cluster=0.1, **kwargs):
236 | super().__init__()
237 |
238 | self.base_m = m
239 | self.curr_m = m
240 | self.loss_weight = loss_weight
241 | self.loss_w_cluster = loss_w_cluster
242 | self.loss_w_obj = 0.02
243 | self.mask_type = mask_type
244 | assert mask_type in ["group", "max", "attn"]
245 | self.num_proto = num_proto
246 | self.student_temp = student_temp # 0.1
247 | self.teacher_temp = teacher_temp # 0.04
248 |
249 | # create the encoders
250 | # num_classes is the output fc dimension
251 | self.online_net = EncoderObj(base_encoder, hid_dim, dim, norm_layer, num_neck_mlp,
252 | scale=scale, l2_norm=l2_norm, num_heads=num_heads, mask_type=mask_type,
253 | num_proto=num_proto, temp=self.teacher_temp)
254 |
255 | self.target_net = EncoderObj(base_encoder, hid_dim, dim, norm_layer, num_neck_mlp,
256 | scale=scale, l2_norm=l2_norm, num_heads=num_heads, mask_type=mask_type,
257 | num_proto=num_proto, temp=self.teacher_temp)
258 | self.predictor = MLP1D(dim, hid_dim, dim, norm_layer=norm_layer)
259 | self.predictor.init_weights()
260 | self.predictor_obj = MLP1D(dim, hid_dim, dim, norm_layer=norm_layer)
261 | self.predictor_obj.init_weights()
262 | self.encoder_q = self.online_net.backbone
263 |
264 | # copy params from online model to target model
265 | for param_ol, param_tgt in zip(self.online_net.parameters(), self.target_net.parameters()):
266 | param_tgt.data.copy_(param_ol.data) # initialize
267 | param_tgt.requires_grad = False # not update by gradient
268 |
269 | self.center_momentum = 0.9
270 | self.register_buffer("center", torch.zeros(1, self.num_proto))
271 |
272 | def mse_loss(self, pred, target):
273 | """
274 | Args:
275 | pred (Tensor): NxC input features.
276 | target (Tensor): NxC target features.
277 | """
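        # For unit-norm vectors, 2 - 2 * <p, t> equals ||p - t||^2, so this is the
        # BYOL-style normalized MSE averaged over the batch.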
278 | N = pred.size(0)
279 | pred_norm = nn.functional.normalize(pred, dim=1)
280 | target_norm = nn.functional.normalize(target, dim=1)
281 | loss = 2 - 2 * (pred_norm * target_norm).sum() / N
282 | return loss
283 |
284 | def self_distill(self, q, k, use_sinkhorn=True, me_max=True):
285 | q_probs = F.log_softmax(q / self.student_temp, dim=-1)
286 | k_probs = F.softmax((k - self.center) / self.teacher_temp, dim=-1)
287 |
288 | if use_sinkhorn:
289 | k_probs = distributed_sinkhorn(k_probs)
290 |
291 | ce_loss = torch.sum(-k_probs * q_probs, dim=-1).mean()
292 |
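        # me_max regularisation: rloss below equals log(K) - H(avg_probs), i.e. the
        # entropy gap of the batch-averaged student assignment; minimising it pushes
        # the average assignment towards uniform and discourages collapse onto a few
        # prototypes.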
293 | rloss = 0.
294 | if me_max:
295 | probs = F.softmax(q / self.student_temp, dim=-1)
296 |
297 | avg_probs = torch.mean(probs, dim=0)
298 | if is_distributed_training_run():
299 | dist.all_reduce(avg_probs)
300 | avg_probs = avg_probs / dist.get_world_size()
301 | # avg_probs = AllReduce.apply(torch.mean(probs, dim=0))
302 | rloss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
303 |
304 | loss = ce_loss + 1.0 * rloss
305 |
306 | return loss
307 |
308 | def assign_loss(self, online_1, online_2, target_1, target_2):
309 | z_o1, obj_o1, res_o1 = online_1
310 | z_o2, obj_o2, res_o2 = online_2
311 | z_t1, obj_t1, res_t1 = target_1
312 | z_t2, obj_t2, res_t2 = target_2
313 |
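        # Each spatial location of the (bs, q, h, w) "dots" maps is flattened into a
        # row of prototype scores; the online network's assignments for a view are then
        # distilled towards the target network's (Sinkhorn-balanced, by default)
        # assignments for the same view.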
314 | loss = 0.5 * (self.self_distill(res_o1["dots"].permute(0, 2, 3, 1).flatten(0, 2),
315 | res_t1["dots"].permute(0, 2, 3, 1).flatten(0, 2))
316 | + self.self_distill(res_o2["dots"].permute(0, 2, 3, 1).flatten(0, 2),
317 | res_t2["dots"].permute(0, 2, 3, 1).flatten(0, 2)))
318 | score_k1 = res_t1["dots"]
319 | score_k2 = res_t2["dots"]
320 | score = torch.cat([score_k1, score_k2]).permute(0, 2, 3, 1).flatten(0, 2)
321 |
322 | return loss, score
323 |
324 | def compute_unigrad_loss(self, pred, target, idxs=None):
325 | pred = F.normalize(pred, dim=-1)
326 | target = F.normalize(target, dim=-1)
327 |
328 | dense_pred = pred.reshape(-1, pred.shape[-1])
329 | dense_target = target.reshape(-1, target.shape[-1])
330 |
331 | # compute pos term
332 | if idxs is not None:
333 | pos_term = self.mse_loss(dense_pred[idxs], dense_target[idxs])
334 | else:
335 | pos_term = self.mse_loss(dense_pred, dense_target)
336 |
337 | # compute neg term
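        # The negative term penalises off-diagonal similarity between the predicted and
        # target object tokens of the same image (a per-sample q x q correlation with
        # its diagonal masked out), pushing different prototypes apart.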
338 | mask = torch.eye(pred.shape[1], device=pred.device).unsqueeze(0).repeat(pred.size(0), 1, 1)
339 |         correlation = torch.matmul(pred, target.transpose(2, 1))  # (b, q, q) token-to-token similarity
340 | correlation = correlation * (1.0 - mask)
341 | neg_term = ((correlation**2).sum(-1) / target.shape[1]).reshape(-1)
342 |
343 | if idxs is not None:
344 | neg_term = torch.mean(neg_term[idxs])
345 | else:
346 | neg_term = torch.mean(neg_term)
347 |
348 | # # correlation = (dense_target.T @ dense_target) / dense_target.shape[0]
349 | # correlation = torch.matmul(target.transpose(2, 1), target) / target.shape[1] # b,c,c
350 | # # if is_distributed_training_run():
351 | # # torch.distributed.all_reduce(correlation)
352 | # # correlation = correlation / torch.distributed.get_world_size()
353 | #
354 | # # neg_term = torch.diagonal(dense_pred @ correlation @ dense_pred.T).mean()
355 | # neg_term = torch.matmul(torch.matmul(pred, correlation), pred.transpose(2, 1))
356 | # neg_term = torch.diagonal(neg_term, dim1=-2, dim2=-1).mean()
357 |
358 | loss = pos_term + self.loss_w_obj * neg_term
359 |
360 | return loss
361 |
362 | def loss_func(self, online, target):
363 | z_o, obj_o, res_o = online
364 | z_t, obj_t, res_t = target
365 |
366 | # instance-level loss
367 | z_o_pred = self.predictor(z_o).squeeze(-1)
368 | z_t = z_t.squeeze(-1)
369 | loss_inst = self.mse_loss(z_o_pred, z_t)
370 |
371 | # object-level loss
372 | b, c, n = obj_o.shape
373 | obj_o_pred = self.predictor_obj(obj_o).transpose(2, 1)
374 | obj_t = obj_t.transpose(2, 1)
375 |
376 | score_q = res_o["dots"]
377 | score_k = res_t["dots"]
378 | mask_q = (torch.zeros_like(score_q).scatter_(1, score_q.argmax(1, keepdim=True), 1).sum(-1).sum(
379 | -1) > 0).long().detach()
380 | mask_k = (torch.zeros_like(score_k).scatter_(1, score_k.argmax(1, keepdim=True), 1).sum(-1).sum(
381 | -1) > 0).long().detach()
382 | mask_intersection = (mask_q * mask_k).view(-1)
383 | idxs_q = mask_intersection.nonzero().squeeze(-1)
384 |
385 | # loss_obj = self.mse_loss(obj_o_pred.reshape(b*n, c)[idxs_q], obj_t.reshape(b*n, c)[idxs_q])
386 | # loss_obj = self.compute_unigrad_loss(obj_o_pred, obj_t, idxs_q)
387 | loss_obj = self.compute_unigrad_loss(obj_o_pred, obj_t)
388 |
389 | loss_base = loss_inst * self.loss_weight + loss_obj * (1 - self.loss_weight)
390 |
391 | # sum
392 | return loss_base, loss_inst, loss_obj
393 |
394 | @torch.no_grad()
395 | def momentum_update(self, cur_iter, max_iter):
396 | """
397 | Momentum update of the target network.
398 | """
399 |         # momentum annealing
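        # cosine schedule: the EMA coefficient starts at base_m (e.g. 0.996) when
        # cur_iter is 0 and anneals towards 1.0 as cur_iter approaches max_iter.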
400 | momentum = 1. - (1. - self.base_m) * (cos(pi * cur_iter / float(max_iter)) + 1) / 2.0
401 | self.curr_m = momentum
402 | # parameter update for target network
403 | state_dict_ol = self.online_net.state_dict()
404 | state_dict_tgt = self.target_net.state_dict()
405 | for (k_ol, v_ol), (k_tgt, v_tgt) in zip(state_dict_ol.items(), state_dict_tgt.items()):
406 | assert k_tgt == k_ol, "state_dict names are different!"
407 | assert v_ol.shape == v_tgt.shape, "state_dict shapes are different!"
408 | if 'num_batches_tracked' in k_tgt:
409 | v_tgt.copy_(v_ol)
410 | else:
411 | v_tgt.copy_(v_tgt * momentum + (1. - momentum) * v_ol)
412 |
413 | @torch.no_grad()
414 | def update_center(self, teacher_output):
415 | """
416 | Update center used for teacher output.
417 | """
418 | batch_center = torch.mean(teacher_output, dim=0, keepdim=True)
419 | if is_distributed_training_run():
420 | dist.all_reduce(batch_center)
421 | batch_center = batch_center / dist.get_world_size()
422 |
423 | # ema update
424 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
425 |
426 | def forward(self, im_v1, im_v2=None, **kwargs):
427 | """
428 | Input:
429 | im_v1: a batch of view1 images
430 | im_v2: a batch of view2 images
431 | Output:
432 | loss
433 | """
434 | # for inference, online_net.backbone model only
435 | if im_v2 is None:
436 | feats = self.online_net.backbone(im_v1)
437 | return F.adaptive_avg_pool2d(feats, 1).flatten(1)
438 |
439 | # compute online_net features
440 | proj_online_v1 = self.online_net(im_v1)
441 | proj_online_v2 = self.online_net(im_v2)
442 |
443 | # compute target_net features
444 | with torch.no_grad(): # no gradient to keys
445 | proj_target_v1 = [x.clone().detach() if isinstance(x, torch.Tensor) else x for x in self.target_net(im_v1)]
446 | proj_target_v2 = [x.clone().detach() if isinstance(x, torch.Tensor) else x for x in self.target_net(im_v2)]
447 |
448 |         # loss. NOTE: the prediction step is moved into loss_func
449 | loss_base1, loss_inst1, loss_obj1 = self.loss_func(proj_online_v1, proj_target_v2)
450 | loss_base2, loss_inst2, loss_obj2 = self.loss_func(proj_online_v2, proj_target_v1)
451 | loss_base = loss_base1 + loss_base2
452 |
453 | loss_cluster, score = self.assign_loss(proj_online_v1, proj_online_v2, proj_target_v1, proj_target_v2)
454 | loss = loss_base + loss_cluster * self.loss_w_cluster
455 |
456 | loss_pack = {}
457 | loss_pack["base"] = loss_base
458 | loss_pack["inst"] = (loss_inst1 + loss_inst2) * self.loss_weight
459 | loss_pack["obj"] = (loss_obj1 + loss_obj2) * (1 - self.loss_weight)
460 | loss_pack["clu"] = loss_cluster
461 |
462 | # self.update_center(score)
463 |
464 | return loss, loss_pack
465 |
466 | def extra_repr(self) -> str:
467 | parts = []
468 | for name in ["loss_weight", "mask_type", "num_proto", "teacher_temp", "loss_w_obj", "loss_w_cluster"]:
469 | parts.append(f"{name}={getattr(self, name)}")
470 | return ", ".join(parts)
471 |
472 |
473 | class FRAB(FRAB_EMAN):
474 | @torch.no_grad()
475 | def momentum_update(self, cur_iter, max_iter):
476 | """
477 | Momentum update of the target network.
478 | """
479 |         # momentum annealing
480 | momentum = 1. - (1. - self.base_m) * (cos(pi * cur_iter / float(max_iter)) + 1) / 2.0
481 | self.curr_m = momentum
482 | # parameter update for target network
483 | for param_ol, param_tgt in zip(self.online_net.parameters(), self.target_net.parameters()):
484 | param_tgt.data = param_tgt.data * momentum + param_ol.data * (1. - momentum)
485 |
486 |
487 | if __name__ == '__main__':
488 | from models import get_model
489 | import backbone as backbone_models
490 |
491 | checkpoint = torch.load("./checkpoints/flr_r50_vgg_face.pth", map_location="cpu")
492 | state_dict = checkpoint['state_dict'] if "state_dict" in checkpoint else checkpoint
493 |
494 | model_func = get_model("FRAB")
495 | norm_layer = None
496 | model = model_func(
497 | backbone_models.__dict__["resnet50_encoder"],
498 | dim=256,
499 | m=0.996,
500 | hid_dim=4096,
501 | norm_layer=norm_layer,
502 | num_neck_mlp=2,
503 | scale=1.,
504 | l2_norm=True,
505 | num_heads=4,
506 | loss_weight=0.5,
507 | mask_type="attn",
508 | num_proto=8,
509 | teacher_temp=0.04,
510 | )
511 | print(model)
512 |
513 | x1 = torch.randn(16, 3, 224, 224)
514 | x2 = torch.randn(16, 3, 224, 224)
515 | out = model(x1, x2)
516 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # some code in this file is adapted from
2 | # https://github.com/pytorch/examples
3 | # Original Copyright 2017. Licensed under the BSD 3-Clause License.
4 | # Modifications Copyright Lang Huang (laynehuang@outlook.com). All Rights Reserved.
5 | # SPDX-License-Identifier: CC-BY-NC-4.0
6 |
7 | import argparse
8 | import builtins
9 | from logging import root
10 | import os
11 | import time
12 |
13 | import torch
14 | import torch.nn.parallel
15 | import torch.nn.functional as F
16 | import torch.backends.cudnn as cudnn
17 | import torch.distributed as dist
18 | import torch.optim
19 | import torch.utils.data
20 | import torch.utils.data.distributed
21 | import torchvision
22 | import torchvision.transforms as transforms
23 | from classy_vision.generic.distributed_util import is_distributed_training_run
24 |
25 | import backbone as backbone_models
26 | from models import get_model
27 | from utils import utils, lr_schedule, LARS, get_norm, init_distributed_mode
28 | import data.transforms as data_transforms
29 | from engine import ss_validate, ss_face_validate
30 | from data.base_dataset import get_dataset
31 |
32 | backbone_model_names = sorted(name for name in backbone_models.__dict__
33 | if name.islower() and not name.startswith("__")
34 | and callable(backbone_models.__dict__[name]))
35 |
36 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
37 | parser.add_argument('--dataset', default="in1k",
38 | help='name of dataset', choices=['in1k', 'in100', 'im_folder', 'in1k_idx', "vggface2"])
39 | parser.add_argument('--data-root', default="",
40 | help='root of dataset folder')
41 | parser.add_argument('--arch', metavar='ARCH', default='LEWEL',
42 | help='model architecture')
43 | parser.add_argument('--backbone', default='resnet50_encoder',
44 | choices=backbone_model_names,
45 | help='model architecture: ' +
46 | ' | '.join(backbone_model_names) +
47 | ' (default: resnet50_encoder)')
48 | parser.add_argument('-j', '--workers', default=64, type=int, metavar='N',
49 | help='number of data loading workers (default: 64)')
50 | parser.add_argument('--epochs', default=200, type=int, metavar='N',
51 | help='number of total epochs to run')
52 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
53 | help='manual epoch number (useful on restarts)')
54 | parser.add_argument('--warmup-epoch', default=0, type=int, metavar='N',
55 | help='number of epochs for learning warmup')
56 | parser.add_argument('-b', '--batch-size', default=256, type=int,
57 | metavar='N',
58 | help='mini-batch size (default: 256), this is the total '
59 | 'batch size of all GPUs on the current node when '
60 | 'using Data Parallel or Distributed Data Parallel')
61 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
62 | metavar='LR', help='initial learning rate', dest='lr')
63 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
64 | help='learning rate schedule (when to drop lr by 10x)')
65 | parser.add_argument('--cos', action='store_true', help='use cosine lr schedule')
66 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
67 | help='momentum of SGD solver')
68 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
69 | metavar='W', help='weight decay (default: 1e-4)',
70 | dest='weight_decay')
71 | parser.add_argument('--save-dir', default="ckpts",
72 | help='checkpoint directory')
73 | parser.add_argument('-p', '--print-freq', default=50, type=int,
74 |                     metavar='N', help='print frequency (default: 50)')
75 | parser.add_argument('--save-freq', default=10, type=int,
76 | metavar='N', help='checkpoint save frequency (default: 10)')
77 | parser.add_argument('--eval-freq', default=5, type=int,
78 | metavar='N', help='evaluation epoch frequency (default: 5)')
79 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
80 | help='path to latest checkpoint (default: none)')
81 | parser.add_argument('--pretrained', default='', type=str, metavar='PATH',
82 | help='path to pretrained model (default: none)')
83 | parser.add_argument('--super-pretrained', default='', type=str, metavar='PATH',
84 | help='path to MoCo pretrained model (default: none)')
85 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
86 | help='evaluate model on validation set')
87 | parser.add_argument('--seed', default=23456, type=int,
88 | help='seed for initializing training. ')
89 |
90 | # dist
91 | parser.add_argument('--world_size', default=-1, type=int, help='number of nodes for distributed training')
92 | parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
93 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
94 | parser.add_argument('--dist_backend', default='nccl', type=str, help='distributed backend')
95 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
96 | distributed training; """)
97 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
98 | parser.add_argument('--multiprocessing_distributed', action='store_true',
99 | help='Use multi-processing distributed training to launch '
100 | 'N processes per node, which has N GPUs. This is the '
101 | 'fastest way to use PyTorch for either single node or '
102 | 'multi node data parallel training')
103 |
104 | # ssl specific configs:
105 | parser.add_argument('--proj-dim', default=256, type=int,
106 | help='feature dimension (default: 256)')
107 | parser.add_argument('--enc-m', default=0.996, type=float,
108 | help='momentum of updating key encoder (default: 0.996)')
109 | parser.add_argument('--norm', default='None', type=str,
110 | help='the normalization for network (default: None)')
111 | parser.add_argument('--num-neck-mlp', default=2, type=int,
112 | help='number of neck mlp (default: 2)')
113 | parser.add_argument('--hid-dim', default=4096, type=int,
114 | help='hidden dimension of mlp (default: 4096)')
115 | parser.add_argument('--amp', action='store_true',
116 | help='use automatic mixed precision training')
117 |
118 | # options for LEWEL
119 | parser.add_argument('--lewel-l2-norm', action='store_true',
120 | help='use l2-norm before applying softmax on attention map')
121 | parser.add_argument('--lewel-scale', default=1., type=float,
122 | help='Scale factor of attention map (default: 1.)')
123 | parser.add_argument('--lewel-num-heads', default=8, type=int,
124 | help='Number of heads in lewel (default: 8)')
125 | parser.add_argument('--lewel-loss-weight', default=0.5, type=float,
126 | help='loss weight for aligned branch (default: 0.5)')
127 |
128 | parser.add_argument('--train-percent', default=1.0, type=float, help='percentage of training set')
129 | parser.add_argument('--mask_type', default="group", type=str, help='type of masks')
130 | parser.add_argument('--num_proto', default=64, type=int,
131 | help='Number of heatmaps')
132 | parser.add_argument('--teacher_temp', default=0.07, type=float,
133 | help='temperature of the teacher')
134 | parser.add_argument('--loss_w_cluster', default=0.5, type=float,
135 | help='loss weight for cluster assignments (default: 0.5)')
136 |
137 |
138 | # options for KNN search
139 | parser.add_argument('--num-nn', default=20, type=int,
140 | help='Number of nearest neighbors (default: 20)')
141 | parser.add_argument('--nn-mem-percent', type=float, default=0.1,
142 |                     help='percentage of memory data for KNN evaluation')
143 | parser.add_argument('--nn-query-percent', type=float, default=0.5,
144 |                     help='percentage of query data for KNN evaluation')
145 |
146 |
147 | best_acc1 = 0
148 |
149 |
150 | def main(args):
151 | global best_acc1
152 | # args.gpu = args.local_rank
153 |
154 | # create model
155 | print("=> creating model '{}' with backbone '{}'".format(args.arch, args.backbone))
156 | model_func = get_model(args.arch)
157 | norm_layer = get_norm(args.norm)
158 | model = model_func(
159 | backbone_models.__dict__[args.backbone],
160 | dim=args.proj_dim,
161 | m=args.enc_m,
162 | hid_dim=args.hid_dim,
163 | norm_layer=norm_layer,
164 | num_neck_mlp=args.num_neck_mlp,
165 | scale=args.lewel_scale,
166 | l2_norm=args.lewel_l2_norm,
167 | num_heads=args.lewel_num_heads,
168 | loss_weight=args.lewel_loss_weight,
169 | mask_type=args.mask_type,
170 | num_proto=args.num_proto,
171 | teacher_temp=args.teacher_temp,
172 | loss_w_cluster=args.loss_w_cluster
173 | )
174 | print(model)
175 | print(args)
176 |
177 | if args.pretrained:
178 | if os.path.isfile(args.pretrained):
179 | print("=> loading pretrained model from '{}'".format(args.pretrained))
180 | state_dict = torch.load(args.pretrained, map_location="cpu")['state_dict']
181 | # rename state_dict keys
182 | for k in list(state_dict.keys()):
183 | new_key = k.replace("module.", "")
184 | state_dict[new_key] = state_dict[k]
185 | del state_dict[k]
186 | msg = model.load_state_dict(state_dict, strict=False)
187 | print("=> loaded pretrained model from '{}'".format(args.pretrained))
188 | if len(msg.missing_keys) > 0:
189 | print("missing keys: {}".format(msg.missing_keys))
190 | if len(msg.unexpected_keys) > 0:
191 | print("unexpected keys: {}".format(msg.unexpected_keys))
192 | else:
193 | print("=> no pretrained model found at '{}'".format(args.pretrained))
194 |
195 |
196 | model.cuda()
197 | args.batch_size = int(args.batch_size / args.world_size)
198 | args.workers = int((args.workers + args.world_size - 1) / args.world_size)
199 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
200 |
201 | # define optimizer
202 | # args.lr = args.batch_size * args.world_size / 1024 * args.lr
203 | if args.dataset == 'in100':
204 | args.lr *= 2
205 |
206 | # params = collect_params(model, exclude_bias_and_bn=True, sync_bn='EMAN' in args.arch)
207 | params = collect_params(model, exclude_bias_and_bn=True)
208 | optimizer = LARS(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
209 | scaler = torch.cuda.amp.GradScaler() if args.amp else None
210 |
211 | # optionally resume from a checkpoint
212 | if args.resume:
213 | if os.path.isfile(args.resume):
214 | print("=> loading checkpoint '{}'".format(args.resume))
215 | if args.gpu is None:
216 | checkpoint = torch.load(args.resume)
217 | else:
218 | # Map model to be loaded to specified single gpu.
219 | loc = 'cuda:{}'.format(args.gpu)
220 | checkpoint = torch.load(args.resume, map_location=loc)
221 | args.start_epoch = checkpoint['epoch']
222 | if 'best_acc1' in checkpoint:
223 | best_acc1 = checkpoint['best_acc1']
224 | model.load_state_dict(checkpoint['state_dict'])
225 | optimizer.load_state_dict(checkpoint['optimizer'])
226 |             if scaler is not None and checkpoint.get('scaler') is not None:
227 |                 scaler.load_state_dict(checkpoint['scaler'])
228 |             else:
229 |                 print("no scaler state found in checkpoint")
230 | print("=> loaded checkpoint '{}' (epoch {})"
231 | .format(args.resume, checkpoint['epoch']))
232 | else:
233 | print("=> no checkpoint found at '{}'".format(args.resume))
234 |
235 | cudnn.benchmark = True
236 |
237 | # Data loading code
238 | if args.dataset.lower() == "vggface2":
239 | transform1, transform2 = data_transforms.get_vggface_tranforms(image_size=224)
240 | val_split = "test"
241 | else:
242 | transform1, transform2 = data_transforms.get_byol_tranforms()
243 | val_split = "val"
244 |
245 | train_dataset = get_dataset(
246 | args.dataset,
247 | mode='train',
248 | transform=data_transforms.TwoCropsTransform(transform1, transform2),
249 | data_root=args.data_root)
250 | print("train_dataset:\n{}".format(train_dataset))
251 |
252 | if args.train_percent < 1.0:
253 | num_subset = int(len(train_dataset) * args.train_percent)
254 | indices = torch.randperm(len(train_dataset))[:num_subset]
255 | indices = indices.tolist()
256 | train_dataset = torch.utils.data.Subset(train_dataset, indices)
257 | print("Sub train_dataset:\n{}".format(len(train_dataset)))
258 |
259 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
260 | train_loader = torch.utils.data.DataLoader(
261 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
262 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True,
263 | persistent_workers=True)
264 |
265 | if args.dataset.lower() == "vggface2":
266 | normalize = transforms.Normalize(mean=data_transforms.IMG_MEAN["vggface2"],
267 | std=data_transforms.IMG_STD["vggface2"])
268 | transform_test = transforms.Compose([
269 | transforms.Resize((224, 224)),
270 | # transforms.CenterCrop(args.image_size),
271 | transforms.ToTensor(),
272 | normalize,
273 | ])
274 | val_dataset = torchvision.datasets.LFWPairs(root="../data/lfw", split="test",
275 | transform=transform_test, download=True)
276 | val_loader = torch.utils.data.DataLoader(
277 | val_dataset,
278 | batch_size=args.batch_size, shuffle=False,
279 | num_workers=args.workers//2, pin_memory=True,
280 | persistent_workers=True)
281 |
282 | else:
283 | val_loader_base = torch.utils.data.DataLoader(
284 | get_dataset(
285 | args.dataset,
286 | mode=val_split,
287 | transform=data_transforms.get_transforms("DefaultVal", args.dataset),
288 | data_root=args.data_root,
289 | percent=args.nn_mem_percent
290 | ),
291 | batch_size=args.batch_size, shuffle=False,
292 | num_workers=args.workers//2, pin_memory=True,
293 | persistent_workers=True)
294 |
295 | val_loader_query = torch.utils.data.DataLoader(
296 | get_dataset(
297 | args.dataset,
298 | mode=val_split,
299 | transform=data_transforms.get_transforms("DefaultVal", args.dataset),
300 | data_root=args.data_root,
301 | percent=args.nn_query_percent,
302 | ),
303 | batch_size=args.batch_size, shuffle=False,
304 | num_workers=args.workers//2, pin_memory=True,
305 | persistent_workers=True)
306 |
307 | if args.evaluate:
308 | # ss_validate(val_loader_base, val_loader_query, model, args)
309 | ss_face_validate(val_loader, model, args)
310 | return
311 |
312 | best_epoch = args.start_epoch
313 | for epoch in range(args.start_epoch, args.epochs):
314 | train_sampler.set_epoch(epoch)
315 | if epoch >= args.warmup_epoch:
316 | lr_schedule.adjust_learning_rate(optimizer, epoch, args)
317 |
318 | # train for one epoch
319 | train(train_loader, model, optimizer, scaler, epoch, args)
320 |
321 | is_best = False
322 | if (epoch + 1) % args.eval_freq == 0:
323 | # acc1 = ss_validate(val_loader_base, val_loader_query, model, args)
324 | acc1 = ss_face_validate(val_loader, model, args)
325 | # remember best acc@1 and save checkpoint
326 | is_best = acc1 > best_acc1
327 | best_acc1 = max(acc1, best_acc1)
328 | if is_best:
329 | best_epoch = epoch
330 |
331 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
332 | and args.local_rank % args.world_size == 0):
333 | utils.save_checkpoint({
334 | 'epoch': epoch + 1,
335 | 'arch': args.arch,
336 | 'state_dict': model.state_dict(),
337 | 'best_acc1': best_acc1,
338 | 'optimizer': optimizer.state_dict(),
339 | 'scaler': None if scaler is None else scaler.state_dict(),
340 | }, is_best=is_best, epoch=epoch, args=args)
341 |
342 | print('Best Acc@1 {0} @ epoch {1}'.format(best_acc1, best_epoch + 1))
343 |
344 |
345 | def train(train_loader, model, optimizer, scaler, epoch, args):
346 | batch_time = utils.AverageMeter('Time', ':6.3f')
347 | data_time = utils.AverageMeter('Data', ':6.3f')
348 | losses = utils.AverageMeter('Loss', ':.4e')
349 | losses_base = utils.AverageMeter('Loss_base', ':.4e')
350 | losses_inst = utils.AverageMeter('Loss_inst', ':.4e')
351 | losses_obj = utils.AverageMeter('Loss_obj', ':.4e')
352 | losses_clu = utils.AverageMeter('Loss_clu', ':.4e')
353 | curr_lr = utils.InstantMeter('LR', ':.7f')
354 | curr_mom = utils.InstantMeter('MOM', ':.7f')
355 | progress = utils.ProgressMeter(
356 | len(train_loader),
357 | [curr_lr, curr_mom, batch_time, data_time, losses, losses_base, losses_inst, losses_obj, losses_clu],
358 | prefix="Epoch: [{}/{}]\t".format(epoch, args.epochs))
359 |
360 | # iter info
361 | batch_iter = len(train_loader)
362 | max_iter = float(batch_iter * args.epochs)
363 |
364 | # switch to train mode
365 | model.train()
366 | if "EMAN" in args.arch:
367 | print("setting the key model to eval mode when using EMAN")
368 | if hasattr(model, 'module'):
369 | model.module.target_net.eval()
370 | else:
371 | model.target_net.eval()
372 |
373 | end = time.time()
374 | for i, (images, _, idx) in enumerate(train_loader):
375 | # update model momentum
376 | curr_iter = float(epoch * batch_iter + i)
377 |
378 | # measure data loading time
379 | data_time.update(time.time() - end)
380 |
381 | if args.gpu is not None:
382 | images[0] = images[0].cuda(args.gpu, non_blocking=True)
383 | images[1] = images[1].cuda(args.gpu, non_blocking=True)
384 | idx = idx.cuda(args.gpu, non_blocking=True)
385 |
386 | # warmup learning rate
387 | if epoch < args.warmup_epoch:
388 | warmup_step = args.warmup_epoch * batch_iter
389 | curr_step = epoch * batch_iter + i + 1
390 | lr_schedule.warmup_learning_rate(optimizer, curr_step, warmup_step, args)
391 | curr_lr.update(optimizer.param_groups[0]['lr'])
392 |
393 | if scaler is None:
394 | # compute loss
395 | loss, loss_pack = model(im_v1=images[0], im_v2=images[1], idx=idx)
396 |
397 | # compute gradient and do SGD step
398 | optimizer.zero_grad()
399 | loss.backward()
400 | optimizer.step()
401 | else: # AMP
402 | optimizer.zero_grad()
403 | with torch.cuda.amp.autocast():
404 | loss, loss_pack = model(im_v1=images[0], im_v2=images[1], idx=idx)
405 |
406 | scaler.scale(loss).backward()
407 | scaler.step(optimizer)
408 | scaler.update()
409 |
410 | # measure accuracy and record loss
411 | losses.update(loss.item(), images[0].size(0))
412 | losses_base.update(loss_pack["base"].item(), images[0].size(0))
413 | losses_inst.update(loss_pack["inst"].item(), images[0].size(0))
414 | losses_obj.update(loss_pack["obj"].item(), images[0].size(0))
415 | losses_clu.update(loss_pack["clu"].item(), images[0].size(0))
416 |
417 | if hasattr(model, 'module'):
418 | model.module.momentum_update(curr_iter, max_iter)
419 | curr_mom.update(model.module.curr_m)
420 | else:
421 | model.momentum_update(curr_iter, max_iter)
422 | curr_mom.update(model.curr_m)
423 |
424 | # measure elapsed time
425 | batch_time.update(time.time() - end)
426 | end = time.time()
427 |
428 | if i % args.print_freq == 0:
429 | progress.display(i)
430 |
431 |
432 | def collect_params(model, exclude_bias_and_bn=True, sync_bn=True):
433 | """
434 |     exclude_bias_and_bn: exclude bias and bn from both weight decay and LARS adaptation
435 | in the PyTorch implementation of ResNet, `downsample.1` are bn layers
436 | """
437 | weight_param_list, bn_and_bias_param_list = [], []
438 | weight_param_names, bn_and_bias_param_names = [], []
439 | for name, param in model.named_parameters():
440 | if exclude_bias_and_bn and ('bn' in name or 'downsample.1' in name or 'bias' in name):
441 | bn_and_bias_param_list.append(param)
442 | bn_and_bias_param_names.append(name)
443 | else:
444 | weight_param_list.append(param)
445 | weight_param_names.append(name)
446 | print("weight params:\n{}".format('\n'.join(weight_param_names)))
447 | print("bn and bias params:\n{}".format('\n'.join(bn_and_bias_param_names)))
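    # 'weight_decay': 0. turns off weight decay for the bn/bias group; the
    # 'lars_exclude' flag is presumably read by the LARS optimizer (utils/LARS.py)
    # to skip the layer-wise trust-ratio adaptation for those parameters.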
448 | param_list = [{'params': bn_and_bias_param_list, 'weight_decay': 0., 'lars_exclude': True},
449 | {'params': weight_param_list}]
450 | return param_list
451 |
452 |
453 | if __name__ == '__main__':
454 | opt = parser.parse_args()
455 | opt.distributed = True
456 | opt.multiprocessing_distributed = True
457 |
458 | # _, opt.local_rank, opt.world_size = dist_init(opt.port)
459 | # cudnn.benchmark = True
460 | #
461 | # # suppress printing if not master
462 | # if dist.get_rank() != 0:
463 | # def print_pass(*args, **kwargs):
464 | # pass
465 | # builtins.print = print_pass
466 |
467 | init_distributed_mode(opt)
468 |
469 | main(opt)
470 |
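# Hypothetical launch example (not part of the original script): a single-node,
# 8-GPU run with torchrun, assuming a local VGGFace2 copy. All flags shown are
# defined by the argparse options above; adjust paths and values to your setup.
#   torchrun --nproc_per_node=8 main.py \
#       --dataset vggface2 --data-root /path/to/vggface2 \
#       --arch FRAB_EMAN --backbone resnet50_encoder \
#       --mask_type attn --num_proto 8 --lewel-l2-norm --cos --amp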
--------------------------------------------------------------------------------