├── LICENSE ├── M2TR ├── .DS_Store ├── __init__.py ├── datasets │ ├── CelebDF.py │ ├── FFDF.py │ ├── ForgeryNet.py │ ├── __init__.py │ ├── dataset.py │ └── utils.py ├── models │ ├── .DS_Store │ ├── __init__.py │ ├── base.py │ ├── efficientnet.py │ ├── m2tr.py │ ├── modules │ │ ├── conv_block.py │ │ ├── gram_block.py │ │ ├── head.py │ │ └── transformer_block.py │ └── xception.py └── utils │ ├── __init__.py │ ├── build_helper.py │ ├── checkpoint.py │ ├── distributed.py │ ├── env.py │ ├── logging.py │ ├── loss.py │ ├── meters.py │ ├── optimizer.py │ ├── registries.py │ ├── scheduler.py │ └── visualization.py ├── README.md ├── configs ├── .DS_Store ├── default.yaml └── m2tr.yaml ├── imgs └── network.png ├── requirements.txt ├── run.py ├── setup.py └── tools ├── __init__.py ├── test.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Junke Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /M2TR/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdrink/M2TR-Multi-modal-Multi-scale-Transformers-for-Deepfake-Detection/6471c0fbcd84ecb5231ba4a4713fa749eab2ec07/M2TR/.DS_Store -------------------------------------------------------------------------------- /M2TR/__init__.py: -------------------------------------------------------------------------------- 1 | from M2TR.utils.env import setup_environment 2 | 3 | setup_environment() 4 | -------------------------------------------------------------------------------- /M2TR/datasets/CelebDF.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from M2TR.datasets.dataset import DeepFakeDataset 6 | from M2TR.utils.registries import DATASET_REGISTRY 7 | 8 | from .utils import get_image_from_path 9 | 10 | ''' 11 | DATASET: 12 | DATASET_NAME: CelebDF 13 | ROOT_DIR: /some_where/celeb-df-v2 14 | TRAIN_INFO_TXT: '/some_where/celeb-df-v2/splits/train.txt' 15 | VAL_INFO_TXT: '/some_where/celeb-df-v2/splits/eval.txt' 16 | TEST_INFO_TXT: '/some_where/celeb-df-v2/splits/eval.txt' 17 | IMG_SIZE: 380 18 | SCALE_RATE: 1.0 19 | ''' 20 | 21 | 22 | @DATASET_REGISTRY.register() 23 | class CelebDF(DeepFakeDataset): 24 | def __getitem__(self, idx): 25 | info_line = self.info_list[idx] 26 | image_info = info_line.strip('\n').split() 27 | image_path = image_info[0] 28 | image_abs_path = os.path.join(self.root_dir, image_path) 29 | 30 | img, _ = get_image_from_path( 31 | image_abs_path, None, self.mode, self.dataset_cfg 32 | ) 33 | img_label_binary = int(image_info[1]) 34 | 35 | sample = { 36 | 'img': img, 37 | 'bin_label': [int(img_label_binary)], 38 | } 39 | 40 | sample['img'] = torch.FloatTensor(sample['img']) 41 | sample['bin_label'] = torch.FloatTensor(sample['bin_label']) 42 | sample['bin_label_onehot'] = self.label_to_one_hot( 43 | sample['bin_label'], 2 44 | ).squeeze() 45 | return sample 46 | -------------------------------------------------------------------------------- /M2TR/datasets/FFDF.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from M2TR.datasets.dataset import DeepFakeDataset 6 | from M2TR.datasets.utils import ( 7 | get_image_from_path, 8 | get_mask_path_from_img_path, 9 | ) 10 | from M2TR.utils.registries import DATASET_REGISTRY 11 | 12 | ''' 13 | DATASET: 14 | DATASET_NAME: FFDF 15 | ROOT_DIR: /some_where/FF++/face 16 | TRAIN_INFO_TXT: '/some_where/train_face_c23.txt' 17 | VAL_INFO_TXT: '/some_where/val_face_c23.txt' 18 | TEST_INFO_TXT: '/some_where/test_face_c23.txt' 19 | IMG_SIZE: 380 20 | SCALE_RATE: 1.0 21 | ROTATE_ANGLE: 10 22 | CUTOUT_H: 10 23 | CUTOUT_W: 10 24 | COMPRESSION_LOW: 65 25 | COMPRESSION_HIGH: 80 26 | ''' 27 | 28 | 29 | @DATASET_REGISTRY.register() 30 | class FFDF(DeepFakeDataset): 31 | def __getitem__(self, idx): 32 | info_line = self.info_list[idx] 33 | image_info = info_line.strip('\n').split() 34 | image_path = image_info[0] 35 | image_abs_path = os.path.join(self.root_dir, image_path) 36 | 37 | mask_abs_path = get_mask_path_from_img_path( 38 | self.dataset_name, self.root_dir, image_path 39 | ) 40 | img, mask = get_image_from_path( 41 | image_abs_path, mask_abs_path, self.mode, self.dataset_cfg 42 | ) 43 | img_label_binary = int(image_info[1]) 44 | 45 | sample = { 46 | 'img': img, 47 | 'bin_label': 
[int(img_label_binary)], 48 | } 49 | 50 | sample['img'] = torch.FloatTensor(sample['img']) 51 | sample['bin_label'] = torch.FloatTensor(sample['bin_label']) 52 | sample['bin_label_onehot'] = self.label_to_one_hot( 53 | sample['bin_label'], 2 54 | ).squeeze() 55 | sample['mask'] = torch.FloatTensor(mask) 56 | 57 | return sample 58 | -------------------------------------------------------------------------------- /M2TR/datasets/ForgeryNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from M2TR.datasets.dataset import DeepFakeDataset 6 | from M2TR.datasets.utils import ( 7 | get_image_from_path, 8 | get_mask_path_from_img_path, 9 | ) 10 | from M2TR.utils.registries import DATASET_REGISTRY 11 | 12 | ''' 13 | DATASET: 14 | DATASET_NAME: ForgeryNet 15 | ROOT_DIR: /some_where/ForgeryNet 16 | IMG_SIZE: 380 17 | SCALE_RATE: 1.0 18 | ROTATE_ANGLE: 10 19 | CUTOUT_H: 10 20 | CUTOUT_W: 10 21 | COMPRESSION_LOW: 65 22 | COMPRESSION_HIGH: 80 23 | ''' 24 | 25 | 26 | @DATASET_REGISTRY.register() 27 | class ForgeryNet(DeepFakeDataset): 28 | def __init__(self, dataset_cfg, mode='train'): 29 | self.dataset_name = dataset_cfg['DATASET_NAME'] 30 | self.mode = mode 31 | self.dataset_cfg = dataset_cfg 32 | info_txt_tag = mode.upper() + '_INFO_TXT' 33 | if mode == 'train': 34 | self.root_dir = os.path.join(dataset_cfg['ROOT_DIR'], 'Training') 35 | if dataset_cfg[info_txt_tag] != '': 36 | self.info_txt = dataset_cfg[info_txt_tag] 37 | else: 38 | self.info_txt = os.path.join( 39 | self.root_dir, 'image_list_train_retina2.txt' 40 | ) 41 | else: 42 | self.root_dir = os.path.join(dataset_cfg['ROOT_DIR'], 'Validation') 43 | if dataset_cfg[info_txt_tag] != '': 44 | self.info_txt = dataset_cfg[info_txt_tag] 45 | else: 46 | self.info_txt = os.path.join(self.root_dir, 'image_list.txt') 47 | 48 | info_list = open(self.info_txt).readlines() 49 | self.info_list = info_list 50 | 51 | def __getitem__(self, idx): 52 | info_line = self.info_list[idx] 53 | image_info = info_line.strip('\n').split() 54 | image_path = image_info[0] 55 | image_abs_path = os.path.join(self.root_dir, 'image2', image_path) 56 | mask_abs_path = get_mask_path_from_img_path( 57 | self.dataset_name, self.root_dir, image_path 58 | ) 59 | img, mask = get_image_from_path( 60 | image_abs_path, mask_abs_path, self.mode, self.dataset_cfg 61 | ) 62 | img_label_binary = int(image_info[1]) 63 | img_label_triple = int(image_info[2]) 64 | img_label_mul = int(image_info[3]) 65 | 66 | sample = { 67 | 'img': img, 68 | 'bin_label': [int(img_label_binary)], 69 | 'tri_label': [int(img_label_triple)], 70 | 'mul_label': [int(img_label_mul)], 71 | } 72 | 73 | sample['img'] = torch.FloatTensor(sample['img']) 74 | sample['bin_label'] = torch.FloatTensor(sample['bin_label']) 75 | sample['bin_label_onehot'] = self.label_to_one_hot( 76 | sample['bin_label'], 2 77 | ).squeeze() 78 | sample['tri_label'] = torch.FloatTensor(sample['tri_label']) 79 | sample['mul_label'] = torch.FloatTensor(sample['mul_label']) 80 | sample['mask'] = torch.FloatTensor(mask) 81 | 82 | return sample 83 | -------------------------------------------------------------------------------- /M2TR/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from M2TR.datasets.CelebDF import CelebDF 2 | from M2TR.datasets.dataset import DeepFakeDataset 3 | from M2TR.datasets.FFDF import FFDF 4 | from M2TR.datasets.ForgeryNet import ForgeryNet 5 | 
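# --- Usage sketch (illustrative, not part of the original sources) ---------
# A minimal example of how the datasets registered above are constructed
# elsewhere in this repo: M2TR/utils/build_helper.py builds them via
# DATASET_REGISTRY.get(name)(dataset_cfg, mode). The paths below are
# placeholders, and the TRAIN_AUGMENTATIONS / TEST_AUGMENTATIONS entries are
# assumptions for the sake of the example; datasets/utils.py expects these
# dicts to contain a COMPOSE list of albumentations transform names, with the
# exact transforms normally coming from the YAML configs.
from M2TR.utils.registries import DATASET_REGISTRY

dataset_cfg = {
    'DATASET_NAME': 'CelebDF',
    'ROOT_DIR': '/path/to/celeb-df-v2',                          # placeholder
    'TRAIN_INFO_TXT': '/path/to/celeb-df-v2/splits/train.txt',   # placeholder
    'VAL_INFO_TXT': '/path/to/celeb-df-v2/splits/eval.txt',      # placeholder
    'TEST_INFO_TXT': '/path/to/celeb-df-v2/splits/eval.txt',     # placeholder
    'IMG_SIZE': 380,
    'SCALE_RATE': 1.0,
    'TRAIN_AUGMENTATIONS': {'COMPOSE': ['HorizontalFlip']},      # example transform
    'TEST_AUGMENTATIONS': {'COMPOSE': []},
}

# Each line of the *_INFO_TXT split files is "<relative_image_path> <binary_label>".
train_set = DATASET_REGISTRY.get(dataset_cfg['DATASET_NAME'])(dataset_cfg, mode='train')
sample = train_set[0]  # dict with 'img', 'bin_label', and 'bin_label_onehot' tensors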
-------------------------------------------------------------------------------- /M2TR/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class DeepFakeDataset(Dataset): 8 | def __init__( 9 | self, 10 | dataset_cfg, 11 | mode='train', 12 | ): 13 | dataset_name = dataset_cfg['DATASET_NAME'] 14 | assert dataset_name in [ 15 | 'ForgeryNet', 16 | 'FFDF', 17 | 'CelebDF', 18 | ], 'no dataset' 19 | assert mode in [ 20 | 'train', 21 | 'val', 22 | 'test', 23 | ], 'wrong mode' 24 | self.dataset_name = dataset_name 25 | self.mode = mode 26 | self.dataset_cfg = dataset_cfg 27 | self.root_dir = dataset_cfg['ROOT_DIR'] 28 | info_txt_tag = mode.upper() + '_INFO_TXT' 29 | if dataset_cfg[info_txt_tag] != '': 30 | self.info_txt = dataset_cfg[info_txt_tag] 31 | else: 32 | self.info_txt = os.path.join( 33 | self.root_dir, 34 | self.dataset_name + '_splits_' + mode + '.txt', 35 | ) 36 | self.info_list = open(self.info_txt).readlines() 37 | 38 | def __len__(self): 39 | return len(self.info_list) 40 | 41 | def label_to_one_hot(self, x, class_count): 42 | return torch.eye(class_count)[x.long(), :] 43 | -------------------------------------------------------------------------------- /M2TR/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import albumentations 5 | import albumentations.pytorch 6 | import numpy as np 7 | import torchvision 8 | from PIL import Image 9 | 10 | 11 | class ResizeRandomCrop: 12 | def __init__(self, img_size=320, scale_rate=8 / 7, p=0.5): 13 | self.img_size = img_size 14 | self.scale_rate = scale_rate 15 | self.p = p 16 | 17 | def __call__(self, image, mask=None): 18 | if random.uniform(0, 1) < self.p: 19 | S1 = int(self.img_size * self.scale_rate) 20 | S2 = S1 21 | resize_func = torchvision.transforms.Resize((S1, S2)) 22 | image = resize_func(image) 23 | crop_params = torchvision.transforms.RandomCrop.get_params( 24 | image, (self.img_size, self.img_size) 25 | ) 26 | image = torchvision.transforms.functional.crop(image, *crop_params) 27 | if mask is not None: 28 | mask = resize_func(mask) 29 | mask = torchvision.transforms.functional.crop( 30 | mask, *crop_params 31 | ) 32 | 33 | else: 34 | resize_func = torchvision.transforms.Resize( 35 | (self.img_size, self.img_size) 36 | ) 37 | image = resize_func(image) 38 | if mask is not None: 39 | mask = resize_func(mask) 40 | 41 | return image, mask 42 | 43 | 44 | def transforms_mask(mask_size): 45 | return albumentations.Compose( 46 | [ 47 | albumentations.Resize(mask_size, mask_size), 48 | albumentations.pytorch.transforms.ToTensorV2(), 49 | ] 50 | ) 51 | 52 | 53 | def get_augmentations_from_list(augs: list, aug_cfg, one_of_p=1): 54 | ops = [] 55 | for aug in augs: 56 | if isinstance(aug, list): 57 | op = albumentations.OneOf 58 | param = get_augmentations_from_list(aug, aug_cfg) 59 | param = [param, one_of_p] 60 | else: 61 | op = getattr(albumentations, aug) 62 | param = ( 63 | aug_cfg[aug.upper() + '_PARAMS'] 64 | if aug.upper() + '_PARAMS' in aug_cfg 65 | else [] 66 | ) 67 | ops.append(op(*tuple(param))) 68 | return ops 69 | 70 | 71 | def get_transformations( 72 | mode, 73 | dataset_cfg, 74 | ): 75 | if mode == 'train': 76 | aug_cfg = dataset_cfg['TRAIN_AUGMENTATIONS'] 77 | else: 78 | aug_cfg = dataset_cfg['TEST_AUGMENTATIONS'] 79 | ops = get_augmentations_from_list(aug_cfg['COMPOSE'], aug_cfg) 80 | 
ops.append(albumentations.pytorch.transforms.ToTensorV2()) 81 | augmentations = albumentations.Compose(ops, p=1) 82 | return augmentations 83 | 84 | 85 | def get_image_from_path(img_path, mask_path, mode, dataset_cfg): 86 | img_size = dataset_cfg['IMG_SIZE'] 87 | scale_rate = dataset_cfg['SCALE_RATE'] 88 | 89 | img = Image.open(img_path) 90 | if mask_path is not None and os.path.exists(mask_path): 91 | mask = Image.open(mask_path).convert('L') 92 | else: 93 | mask = Image.fromarray(np.zeros((img_size, img_size))) 94 | 95 | trans_list = get_transformations( 96 | mode, 97 | dataset_cfg, 98 | ) 99 | if mode == 'train': 100 | crop = ResizeRandomCrop(img_size=img_size, scale_rate=scale_rate) 101 | img, mask = crop(image=img, mask=mask) 102 | 103 | img = np.asarray(img) 104 | img = trans_list(image=img)['image'] 105 | 106 | mask = np.asarray(mask) 107 | mask = transforms_mask(img_size)(image=mask)['image'] 108 | 109 | else: 110 | img = np.asarray(img) 111 | img = trans_list(image=img)['image'] 112 | mask = np.asarray(mask) 113 | mask = transforms_mask(img_size)(image=mask)['image'] 114 | 115 | return img, mask.float() 116 | 117 | 118 | def get_mask_path_from_img_path(dataset_name, root_dir, img_info): 119 | if dataset_name == 'ForgeryNet': 120 | root_dir = os.path.join(root_dir, 'spatial_localize') 121 | fore_path = img_info.split('/')[0] 122 | if 'train' in fore_path: 123 | img_info = img_info.replace('train_release', 'train_mask_release') 124 | else: 125 | img_info = img_info[20:] 126 | 127 | mask_complete_path = os.path.join(root_dir, img_info) 128 | 129 | elif 'FFDF' in dataset_name: 130 | mask_info = img_info.replace('images', 'masks') 131 | mask_complete_path = os.path.join(root_dir, mask_info) 132 | 133 | return mask_complete_path 134 | -------------------------------------------------------------------------------- /M2TR/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdrink/M2TR-Multi-modal-Multi-scale-Transformers-for-Deepfake-Detection/6471c0fbcd84ecb5231ba4a4713fa749eab2ec07/M2TR/models/.DS_Store -------------------------------------------------------------------------------- /M2TR/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseNetwork 2 | from .efficientnet import EfficientNet 3 | from .m2tr import M2TR 4 | -------------------------------------------------------------------------------- /M2TR/models/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BaseNetwork(nn.Module): 5 | def __init__(self): 6 | super(BaseNetwork, self).__init__() 7 | 8 | def print_network(self): 9 | if isinstance(self, list): 10 | self = self[0] 11 | num_params = 0 12 | for param in self.parameters(): 13 | num_params += param.numel() 14 | print( 15 | 'Network [%s] was created. Total number of parameters: %.1f million. ' 16 | 'To see the architecture, do print(network).' 
17 | % (type(self).__name__, num_params / 1000000) 18 | ) 19 | 20 | def init_weights(self, init_type='normal', gain=0.02): 21 | ''' 22 | initialize network's weights 23 | init_type: normal | xavier | kaiming | orthogonal 24 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 25 | ''' 26 | 27 | def init_func(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('InstanceNorm2d') != -1: 30 | if hasattr(m, 'weight') and m.weight is not None: 31 | nn.init.constant_(m.weight.data, 1.0) 32 | if hasattr(m, 'bias') and m.bias is not None: 33 | nn.init.constant_(m.bias.data, 0.0) 34 | elif hasattr(m, 'weight') and ( 35 | classname.find('Conv') != -1 or classname.find('Linear') != -1 36 | ): 37 | if init_type == 'normal': 38 | nn.init.normal_(m.weight.data, 0.0, gain) 39 | elif init_type == 'xavier': 40 | nn.init.xavier_normal_(m.weight.data, gain=gain) 41 | elif init_type == 'xavier_uniform': 42 | nn.init.xavier_uniform_(m.weight.data, gain=1.0) 43 | elif init_type == 'kaiming': 44 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 45 | elif init_type == 'orthogonal': 46 | nn.init.orthogonal_(m.weight.data, gain=gain) 47 | elif init_type == 'none': # uses pytorch's default init method 48 | m.reset_parameters() 49 | else: 50 | raise NotImplementedError( 51 | 'initialization method [%s] is not implemented' 52 | % init_type 53 | ) 54 | if hasattr(m, 'bias') and m.bias is not None: 55 | nn.init.constant_(m.bias.data, 0.0) 56 | 57 | self.apply(init_func) 58 | 59 | for m in self.children(): 60 | if hasattr(m, 'init_weights'): 61 | m.init_weights(init_type, gain) 62 | -------------------------------------------------------------------------------- /M2TR/models/efficientnet.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import math 3 | import re 4 | from functools import partial 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.utils import model_zoo 10 | 11 | from M2TR.utils.registries import MODEL_REGISTRY 12 | 13 | ''' 14 | MODEL: 15 | MODEL_NAME: efficientnet 16 | NAME: efficientnet-b4 17 | PRETRAINED: True 18 | ''' 19 | 20 | 21 | # Parameters for the entire model (stem, all blocks, and head) 22 | GlobalParams = collections.namedtuple( 23 | 'GlobalParams', 24 | [ 25 | 'batch_norm_momentum', 26 | 'batch_norm_epsilon', 27 | 'dropout_rate', 28 | 'num_classes', 29 | 'width_coefficient', 30 | 'depth_coefficient', 31 | 'depth_divisor', 32 | 'min_depth', 33 | 'drop_connect_rate', 34 | 'image_size', 35 | ], 36 | ) 37 | 38 | # Parameters for an individual model block 39 | BlockArgs = collections.namedtuple( 40 | 'BlockArgs', 41 | [ 42 | 'kernel_size', 43 | 'num_repeat', 44 | 'input_filters', 45 | 'output_filters', 46 | 'expand_ratio', 47 | 'id_skip', 48 | 'stride', 49 | 'se_ratio', 50 | ], 51 | ) 52 | 53 | # Change namedtuple defaults 54 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) 55 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 56 | 57 | 58 | efficientnet_params = { 59 | # Coefficients: width,depth,res,dropout 60 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 61 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 62 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 63 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 64 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 65 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 66 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 67 | 'efficientnet-b7': 
(2.0, 3.1, 600, 0.5), 68 | 'efficientnet-b8': (2.2, 3.6, 672, 0.5), 69 | 'efficientnet-l2': (4.3, 5.3, 800, 0.5), 70 | } 71 | 72 | 73 | blocks_args_str = [ 74 | 'r1_k3_s11_e1_i32_o16_se0.25', 75 | 'r2_k3_s22_e6_i16_o24_se0.25', 76 | 'r2_k5_s22_e6_i24_o40_se0.25', 77 | 'r3_k3_s22_e6_i40_o80_se0.25', 78 | 'r3_k5_s11_e6_i80_o112_se0.25', 79 | 'r4_k5_s22_e6_i112_o192_se0.25', 80 | 'r1_k3_s11_e6_i192_o320_se0.25', 81 | ] 82 | 83 | 84 | url_map = { 85 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth', 86 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth', 87 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth', 88 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth', 89 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth', 90 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth', 91 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth', 92 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth', 93 | } 94 | 95 | 96 | url_map_advprop = { 97 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth', 98 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth', 99 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth', 100 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth', 101 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth', 102 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth', 103 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth', 104 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth', 105 | 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth', 106 | } 107 | 108 | 109 | class BlockDecoder(object): 110 | @staticmethod 111 | def _decode_block_string(block_string): 112 | """Gets a block through a string notation of arguments.""" 113 | assert isinstance(block_string, str) 114 | 115 | ops = block_string.split('_') 116 | options = {} 117 | for op in ops: 118 | splits = re.split(r'(\d.*)', op) 119 | if len(splits) >= 2: 120 | key, value = splits[:2] 121 | options[key] = value 122 | 123 | # Check stride 124 | assert ('s' in options and len(options['s']) == 1) or ( 125 | len(options['s']) == 2 and options['s'][0] == options['s'][1] 126 | ) 127 | 128 | return BlockArgs( 129 | kernel_size=int(options['k']), 130 | num_repeat=int(options['r']), 131 | input_filters=int(options['i']), 132 | output_filters=int(options['o']), 133 | expand_ratio=int(options['e']), 134 | id_skip=('noskip' not 
in block_string), 135 | se_ratio=float(options['se']) if 'se' in options else None, 136 | stride=[int(options['s'][0])], 137 | ) 138 | 139 | @staticmethod 140 | def decode(string_list): 141 | assert isinstance(string_list, list) 142 | blocks_args = [] 143 | for block_string in string_list: 144 | blocks_args.append(BlockDecoder._decode_block_string(block_string)) 145 | return blocks_args 146 | 147 | 148 | class SwishImplementation(torch.autograd.Function): 149 | @staticmethod 150 | def forward(ctx, i): 151 | result = i * torch.sigmoid(i) 152 | ctx.save_for_backward(i) 153 | return result 154 | 155 | @staticmethod 156 | def backward(ctx, grad_output): 157 | i = ctx.saved_variables[0] 158 | sigmoid_i = torch.sigmoid(i) 159 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) 160 | 161 | 162 | class MemoryEfficientSwish(nn.Module): 163 | def forward(self, x): 164 | return SwishImplementation.apply(x) 165 | 166 | 167 | class Swish(nn.Module): 168 | def forward(self, x): 169 | return x * torch.sigmoid(x) 170 | 171 | 172 | def round_filters(filters, global_params): 173 | """Calculate and round number of filters based on depth multiplier.""" 174 | multiplier = global_params.width_coefficient 175 | if not multiplier: 176 | return filters 177 | divisor = global_params.depth_divisor 178 | min_depth = global_params.min_depth 179 | filters *= multiplier 180 | min_depth = min_depth or divisor 181 | new_filters = max( 182 | min_depth, int(filters + divisor / 2) // divisor * divisor 183 | ) 184 | if new_filters < 0.9 * filters: # prevent rounding by more than 10% 185 | new_filters += divisor 186 | return int(new_filters) 187 | 188 | 189 | def round_repeats(repeats, global_params): 190 | """Round number of filters based on depth multiplier.""" 191 | multiplier = global_params.depth_coefficient 192 | if not multiplier: 193 | return repeats 194 | return int(math.ceil(multiplier * repeats)) 195 | 196 | 197 | def drop_connect(inputs, p, training): 198 | """Drop connect.""" 199 | if not training: 200 | return inputs 201 | batch_size = inputs.shape[0] 202 | keep_prob = 1 - p 203 | random_tensor = keep_prob 204 | random_tensor += torch.rand( 205 | [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device 206 | ) 207 | binary_tensor = torch.floor(random_tensor) 208 | output = inputs / keep_prob * binary_tensor 209 | return output 210 | 211 | 212 | class Identity(nn.Module): 213 | def __init__( 214 | self, 215 | ): 216 | super(Identity, self).__init__() 217 | 218 | def forward(self, input): 219 | return input 220 | 221 | 222 | class MBConvBlock(nn.Module): 223 | """ 224 | Mobile Inverted Residual Bottleneck Block 225 | Args: 226 | block_args (namedtuple): BlockArgs, see above 227 | global_params (namedtuple): GlobalParam, see above 228 | Attributes: 229 | has_se (bool): Whether the block contains a Squeeze and Excitation layer. 
230 | """ 231 | 232 | def __init__(self, block_args, global_params): 233 | super().__init__() 234 | self._block_args = block_args 235 | self._bn_mom = 1 - global_params.batch_norm_momentum 236 | self._bn_eps = global_params.batch_norm_epsilon 237 | self.has_se = (self._block_args.se_ratio is not None) and ( 238 | 0 < self._block_args.se_ratio <= 1 239 | ) 240 | self.id_skip = block_args.id_skip # skip connection and drop connect 241 | 242 | # Get static or dynamic convolution depending on image size 243 | Conv2d = partial( 244 | Conv2dStaticSamePadding, image_size=global_params.image_size 245 | ) 246 | 247 | # Expansion phase 248 | inp = self._block_args.input_filters # number of input channels 249 | oup = ( 250 | self._block_args.input_filters * self._block_args.expand_ratio 251 | ) # number of output channels 252 | if self._block_args.expand_ratio != 1: 253 | self._expand_conv = Conv2d( 254 | in_channels=inp, out_channels=oup, kernel_size=1, bias=False 255 | ) 256 | self._bn0 = nn.BatchNorm2d( 257 | num_features=oup, momentum=self._bn_mom, eps=self._bn_eps 258 | ) 259 | 260 | # Depthwise convolution phase 261 | k = self._block_args.kernel_size 262 | s = self._block_args.stride 263 | self._depthwise_conv = Conv2d( 264 | in_channels=oup, 265 | out_channels=oup, 266 | groups=oup, # groups makes it depthwise 267 | kernel_size=k, 268 | stride=s, 269 | bias=False, 270 | ) 271 | self._bn1 = nn.BatchNorm2d( 272 | num_features=oup, momentum=self._bn_mom, eps=self._bn_eps 273 | ) 274 | 275 | # Squeeze and Excitation layer, if desired 276 | if self.has_se: 277 | num_squeezed_channels = max( 278 | 1, 279 | int(self._block_args.input_filters * self._block_args.se_ratio), 280 | ) 281 | self._se_reduce = Conv2d( 282 | in_channels=oup, 283 | out_channels=num_squeezed_channels, 284 | kernel_size=1, 285 | ) 286 | self._se_expand = Conv2d( 287 | in_channels=num_squeezed_channels, 288 | out_channels=oup, 289 | kernel_size=1, 290 | ) 291 | 292 | # Output phase 293 | final_oup = self._block_args.output_filters 294 | self._project_conv = Conv2d( 295 | in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False 296 | ) 297 | self._bn2 = nn.BatchNorm2d( 298 | num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps 299 | ) 300 | self._swish = MemoryEfficientSwish() 301 | 302 | def forward(self, inputs, drop_connect_rate=None): 303 | # Expansion and Depthwise Convolution 304 | x = inputs 305 | if self._block_args.expand_ratio != 1: 306 | x = self._swish(self._bn0(self._expand_conv(inputs))) 307 | x = self._swish(self._bn1(self._depthwise_conv(x))) 308 | 309 | # Squeeze and Excitation 310 | if self.has_se: 311 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 312 | x_squeezed = self._se_expand( 313 | self._swish(self._se_reduce(x_squeezed)) 314 | ) 315 | x = torch.sigmoid(x_squeezed) * x 316 | 317 | x = self._bn2(self._project_conv(x)) 318 | 319 | # Skip connection and drop connect 320 | input_filters, output_filters = ( 321 | self._block_args.input_filters, 322 | self._block_args.output_filters, 323 | ) 324 | if ( 325 | self.id_skip 326 | and self._block_args.stride == 1 327 | and input_filters == output_filters 328 | ): 329 | if drop_connect_rate: 330 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 331 | x = x + inputs # skip connection 332 | return x 333 | 334 | def set_swish(self, memory_efficient=True): 335 | """Sets swish function as memory efficient (for training) or standard (for export)""" 336 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 337 | 338 | 339 
| class Conv2dStaticSamePadding(nn.Conv2d): 340 | """2D Convolutions like TensorFlow, for a fixed image size""" 341 | 342 | def __init__( 343 | self, in_channels, out_channels, kernel_size, image_size=None, **kwargs 344 | ): 345 | super().__init__(in_channels, out_channels, kernel_size, **kwargs) 346 | self.stride = ( 347 | self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 348 | ) 349 | 350 | # Calculate padding based on image size and save it 351 | assert image_size is not None 352 | ih, iw = ( 353 | image_size if type(image_size) == list else [image_size, image_size] 354 | ) 355 | kh, kw = self.weight.size()[-2:] 356 | sh, sw = self.stride 357 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 358 | pad_h = max( 359 | (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0 360 | ) 361 | pad_w = max( 362 | (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0 363 | ) 364 | if pad_h > 0 or pad_w > 0: 365 | self.static_padding = nn.ZeroPad2d( 366 | (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) 367 | ) 368 | else: 369 | self.static_padding = Identity() 370 | 371 | def forward(self, x): 372 | x = self.static_padding(x) 373 | x = F.conv2d( 374 | x, 375 | self.weight, 376 | self.bias, 377 | self.stride, 378 | self.padding, 379 | self.dilation, 380 | self.groups, 381 | ) 382 | return x 383 | 384 | 385 | @MODEL_REGISTRY.register() 386 | class EfficientNet(nn.Module): 387 | def __init__(self, model_cfg): 388 | super().__init__() 389 | model_name = model_cfg['NAME'] 390 | self.check_model_name_is_valid(model_name) 391 | blocks_args = BlockDecoder.decode(blocks_args_str) 392 | w, d, s, p = efficientnet_params[model_name] 393 | # note: all models have drop connect rate = 0.2 394 | global_params = GlobalParams( 395 | batch_norm_momentum=0.99, 396 | batch_norm_epsilon=1e-3, 397 | dropout_rate=p, 398 | drop_connect_rate=0.2, 399 | num_classes=2, 400 | width_coefficient=w, 401 | depth_coefficient=d, 402 | depth_divisor=8, 403 | min_depth=None, 404 | image_size=s, 405 | ) 406 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 407 | assert len(blocks_args) > 0, 'block args must be greater than 0' 408 | self.escape = '' 409 | self._global_params = global_params 410 | self._blocks_args = blocks_args 411 | Conv2d = partial( 412 | Conv2dStaticSamePadding, image_size=global_params.image_size 413 | ) 414 | 415 | # Batch norm parameters 416 | bn_mom = 1 - self._global_params.batch_norm_momentum 417 | bn_eps = self._global_params.batch_norm_epsilon 418 | 419 | # Stem 420 | in_channels = 3 # rgb 421 | out_channels = round_filters( 422 | 32, self._global_params 423 | ) # number of output channels 424 | self._conv_stem = Conv2d( 425 | in_channels, out_channels, kernel_size=3, stride=2, bias=False 426 | ) 427 | self._bn0 = nn.BatchNorm2d( 428 | num_features=out_channels, momentum=bn_mom, eps=bn_eps 429 | ) 430 | 431 | # Build blocks 432 | self._blocks = nn.ModuleList([]) 433 | self.stage_map = [] 434 | stage_count = 0 435 | for block_args in self._blocks_args: 436 | 437 | # Update block input and output filters based on depth multiplier. 
438 | block_args = block_args._replace( 439 | input_filters=round_filters( 440 | block_args.input_filters, self._global_params 441 | ), 442 | output_filters=round_filters( 443 | block_args.output_filters, self._global_params 444 | ), 445 | num_repeat=round_repeats( 446 | block_args.num_repeat, self._global_params 447 | ), 448 | ) 449 | stage_count += 1 450 | self.stage_map += [''] * (block_args.num_repeat - 1) 451 | self.stage_map.append('b%s' % stage_count) 452 | # The first block needs to take care of stride and filter size increase. 453 | self._blocks.append(MBConvBlock(block_args, self._global_params)) 454 | 455 | if block_args.num_repeat > 1: 456 | block_args = block_args._replace( 457 | input_filters=block_args.output_filters, stride=1 458 | ) 459 | for _ in range(block_args.num_repeat - 1): 460 | self._blocks.append( 461 | MBConvBlock(block_args, self._global_params) 462 | ) 463 | 464 | # Head 465 | in_channels = block_args.output_filters # output of final block 466 | out_channels = round_filters(1280, self._global_params) 467 | self._conv_head = Conv2d( 468 | in_channels, out_channels, kernel_size=1, bias=False 469 | ) 470 | self._bn1 = nn.BatchNorm2d( 471 | num_features=out_channels, momentum=bn_mom, eps=bn_eps 472 | ) 473 | 474 | # Final linear layer 475 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 476 | self._dropout = nn.Dropout(self._global_params.dropout_rate) 477 | self._fc = nn.Linear(out_channels, self._global_params.num_classes) 478 | self._swish = MemoryEfficientSwish() 479 | 480 | if model_cfg['PRETRAINED']: 481 | self.load_pretrained_weights(model_name, advprop=True) 482 | 483 | def set_swish(self, memory_efficient=True): 484 | """Sets swish function as memory efficient (for training) or standard (for export)""" 485 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 486 | for block in self._blocks: 487 | block.set_swish(memory_efficient) 488 | 489 | def extract_textures(self, inputs, layers): 490 | # Stem 491 | x = self._swish(self._bn0(self._conv_stem(inputs))) 492 | layers['b0'] = x 493 | # Blocks 494 | for idx, block in enumerate(self._blocks[:6]): 495 | drop_connect_rate = self._global_params.drop_connect_rate 496 | if drop_connect_rate: 497 | drop_connect_rate *= float(idx) / len(self._blocks) 498 | x = block(x, drop_connect_rate=drop_connect_rate) 499 | stage = self.stage_map[idx] 500 | if stage: 501 | layers[stage] = x 502 | if stage == self.escape: 503 | return None 504 | 505 | return x 506 | 507 | def extract_features(self, x, layers): 508 | # Blocks 509 | for idx, block in enumerate(self._blocks[6:]): 510 | idx += 6 511 | drop_connect_rate = self._global_params.drop_connect_rate 512 | if drop_connect_rate: 513 | drop_connect_rate *= float(idx) / len(self._blocks) 514 | x = block(x, drop_connect_rate=drop_connect_rate) 515 | stage = self.stage_map[idx] 516 | if stage: 517 | layers[stage] = x 518 | if stage == self.escape: 519 | return None 520 | # Head 521 | x = self._bn1(self._conv_head(x)) 522 | x = self._swish(x) 523 | return x 524 | 525 | def forward(self, samples): 526 | x = samples['img'] 527 | bs = x.size(0) 528 | layers = {} 529 | x = self.extract_textures(x, layers) 530 | x = self.extract_features(x, layers) 531 | if x is None: 532 | return layers 533 | layers['final'] = x 534 | x = self._avg_pooling(x) 535 | x = x.view(bs, -1) 536 | x = self._dropout(x) 537 | x = self._fc(x) 538 | layers['logits'] = x 539 | return layers 540 | 541 | def load_pretrained_weights(self, model_name, advprop=False): 542 | url_map_ = url_map_advprop if 
advprop else url_map 543 | state_dict = model_zoo.load_url(url_map_[model_name]) 544 | state_dict.pop('_fc.weight') 545 | state_dict.pop('_fc.bias') 546 | res = self.load_state_dict(state_dict, strict=False) 547 | assert set(res.missing_keys) == set( 548 | ['_fc.weight', '_fc.bias'] 549 | ), 'issue loading pretrained weights' 550 | print('Loaded pretrained weights for {}'.format(model_name)) 551 | 552 | def check_model_name_is_valid(self, model_name): 553 | valid_models = ['efficientnet-b' + str(i) for i in range(9)] 554 | if model_name not in valid_models: 555 | raise ValueError( 556 | 'model_name should be one of: ' + ', '.join(valid_models) 557 | ) 558 | -------------------------------------------------------------------------------- /M2TR/models/m2tr.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.fft 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from M2TR.utils.registries import MODEL_REGISTRY 9 | 10 | from .base import BaseNetwork 11 | from .xception import Xception 12 | from .efficientnet import EfficientNet 13 | from .modules.head import Classifier2D, Localizer 14 | from .modules.transformer_block import FeedForward2D 15 | 16 | 17 | 18 | class GlobalFilter(nn.Module): 19 | def __init__(self, dim=32, h=80, w=41, fp32fft=True): 20 | super().__init__() 21 | self.complex_weight = nn.Parameter( 22 | torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02 23 | ) 24 | self.w = w 25 | self.h = h 26 | self.fp32fft = fp32fft 27 | 28 | def forward(self, x): 29 | b, _, a, b = x.size() 30 | x = x.permute(0, 2, 3, 1).contiguous() 31 | 32 | if self.fp32fft: 33 | dtype = x.dtype 34 | x = x.to(torch.float32) 35 | 36 | x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho") 37 | weight = torch.view_as_complex(self.complex_weight) 38 | x = x * weight 39 | x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm="ortho") 40 | 41 | if self.fp32fft: 42 | x = x.to(dtype) 43 | 44 | x = x.permute(0, 3, 1, 2).contiguous() 45 | 46 | return x 47 | 48 | 49 | class FreqBlock(nn.Module): 50 | def __init__(self, dim, h=80, w=41, fp32fft=True): 51 | super().__init__() 52 | self.filter = GlobalFilter(dim, h=h, w=w, fp32fft=fp32fft) 53 | self.feed_forward = FeedForward2D(in_channel=dim, out_channel=dim) 54 | 55 | def forward(self, x): 56 | x = x + self.feed_forward(self.filter(x)) 57 | return x 58 | 59 | 60 | def attention(query, key, value): 61 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt( 62 | query.size(-1) 63 | ) 64 | p_attn = F.softmax(scores, dim=-1) 65 | p_val = torch.matmul(p_attn, value) 66 | return p_val, p_attn 67 | 68 | 69 | class MultiHeadedAttention(nn.Module): 70 | """ 71 | Take in model size and number of heads. 
72 | """ 73 | 74 | def __init__(self, patchsize, d_model): 75 | super().__init__() 76 | self.patchsize = patchsize 77 | self.query_embedding = nn.Conv2d( 78 | d_model, d_model, kernel_size=1, padding=0 79 | ) 80 | self.value_embedding = nn.Conv2d( 81 | d_model, d_model, kernel_size=1, padding=0 82 | ) 83 | self.key_embedding = nn.Conv2d( 84 | d_model, d_model, kernel_size=1, padding=0 85 | ) 86 | self.output_linear = nn.Sequential( 87 | nn.Conv2d(d_model, d_model, kernel_size=3, padding=1), 88 | nn.BatchNorm2d(d_model), 89 | nn.LeakyReLU(0.2, inplace=True), 90 | ) 91 | 92 | def forward(self, x): 93 | b, c, h, w = x.size() 94 | d_k = c // len(self.patchsize) 95 | output = [] 96 | _query = self.query_embedding(x) 97 | _key = self.key_embedding(x) 98 | _value = self.value_embedding(x) 99 | attentions = [] 100 | for (width, height), query, key, value in zip( 101 | self.patchsize, 102 | torch.chunk(_query, len(self.patchsize), dim=1), 103 | torch.chunk(_key, len(self.patchsize), dim=1), 104 | torch.chunk(_value, len(self.patchsize), dim=1), 105 | ): 106 | out_w, out_h = w // width, h // height 107 | 108 | # 1) embedding and reshape 109 | query = query.view(b, d_k, out_h, height, out_w, width) 110 | query = ( 111 | query.permute(0, 2, 4, 1, 3, 5) 112 | .contiguous() 113 | .view(b, out_h * out_w, d_k * height * width) 114 | ) 115 | key = key.view(b, d_k, out_h, height, out_w, width) 116 | key = ( 117 | key.permute(0, 2, 4, 1, 3, 5) 118 | .contiguous() 119 | .view(b, out_h * out_w, d_k * height * width) 120 | ) 121 | value = value.view(b, d_k, out_h, height, out_w, width) 122 | value = ( 123 | value.permute(0, 2, 4, 1, 3, 5) 124 | .contiguous() 125 | .view(b, out_h * out_w, d_k * height * width) 126 | ) 127 | 128 | y, _ = attention(query, key, value) 129 | 130 | # 3) "Concat" using a view and apply a final linear. 
131 | y = y.view(b, out_h, out_w, d_k, height, width) 132 | y = y.permute(0, 3, 1, 4, 2, 5).contiguous().view(b, d_k, h, w) 133 | attentions.append(y) 134 | output.append(y) 135 | 136 | output = torch.cat(output, 1) 137 | self_attention = self.output_linear(output) 138 | 139 | return self_attention 140 | 141 | 142 | class TransformerBlock(nn.Module): 143 | """ 144 | Transformer = MultiHead_Attention + Feed_Forward with sublayer connection 145 | """ 146 | 147 | def __init__(self, patchsize, in_channel=256): 148 | super().__init__() 149 | self.attention = MultiHeadedAttention(patchsize, d_model=in_channel) 150 | self.feed_forward = FeedForward2D( 151 | in_channel=in_channel, out_channel=in_channel 152 | ) 153 | 154 | def forward(self, rgb): 155 | self_attention = self.attention(rgb) 156 | output = rgb + self_attention 157 | output = output + self.feed_forward(output) 158 | return output 159 | 160 | 161 | class CMA_Block(nn.Module): 162 | def __init__(self, in_channel, hidden_channel, out_channel): 163 | super(CMA_Block, self).__init__() 164 | 165 | self.conv1 = nn.Conv2d( 166 | in_channel, hidden_channel, kernel_size=1, stride=1, padding=0 167 | ) 168 | self.conv2 = nn.Conv2d( 169 | in_channel, hidden_channel, kernel_size=1, stride=1, padding=0 170 | ) 171 | self.conv3 = nn.Conv2d( 172 | in_channel, hidden_channel, kernel_size=1, stride=1, padding=0 173 | ) 174 | 175 | self.scale = hidden_channel ** -0.5 176 | 177 | self.conv4 = nn.Sequential( 178 | nn.Conv2d( 179 | hidden_channel, out_channel, kernel_size=1, stride=1, padding=0 180 | ), 181 | nn.BatchNorm2d(out_channel), 182 | nn.LeakyReLU(0.2, inplace=True), 183 | ) 184 | 185 | def forward(self, rgb, freq): 186 | _, _, h, w = rgb.size() 187 | 188 | q = self.conv1(rgb) 189 | k = self.conv2(freq) 190 | v = self.conv3(freq) 191 | 192 | q = q.view(q.size(0), q.size(1), q.size(2) * q.size(3)).transpose( 193 | -2, -1 194 | ) 195 | k = k.view(k.size(0), k.size(1), k.size(2) * k.size(3)) 196 | 197 | attn = torch.matmul(q, k) * self.scale 198 | m = attn.softmax(dim=-1) 199 | 200 | v = v.view(v.size(0), v.size(1), v.size(2) * v.size(3)).transpose( 201 | -2, -1 202 | ) 203 | z = torch.matmul(m, v) 204 | z = z.view(z.size(0), h, w, -1) 205 | z = z.permute(0, 3, 1, 2).contiguous() 206 | 207 | output = rgb + self.conv4(z) 208 | 209 | return output 210 | 211 | 212 | class PatchTrans(BaseNetwork): 213 | def __init__(self, in_channel, in_size): 214 | super(PatchTrans, self).__init__() 215 | self.in_size = in_size 216 | 217 | patchsize = [ 218 | (in_size, in_size), 219 | (in_size // 2, in_size // 2), 220 | (in_size // 4, in_size // 4), 221 | (in_size // 8, in_size // 8), 222 | ] 223 | 224 | self.t = TransformerBlock(patchsize, in_channel=in_channel) 225 | 226 | def forward(self, enc_feat): 227 | output = self.t(enc_feat) 228 | return output 229 | 230 | 231 | @MODEL_REGISTRY.register() 232 | class M2TR(BaseNetwork): 233 | def __init__(self, model_cfg): 234 | super(M2TR, self).__init__() 235 | img_size = model_cfg["IMG_SIZE"] 236 | backbone = model_cfg["BACKBONE"] 237 | texture_layer = model_cfg["TEXTURE_LAYER"] 238 | feature_layer = model_cfg["FEATURE_LAYER"] 239 | depth = model_cfg["DEPTH"] 240 | num_classes = model_cfg["NUM_CLASSES"] 241 | drop_ratio = model_cfg["DROP_RATIO"] 242 | has_decoder = model_cfg["HAS_DECODER"] 243 | 244 | freq_h = img_size // 4 245 | freq_w = freq_h // 2 + 1 246 | 247 | if "xception" in backbone: 248 | self.model = Xception(num_classes) 249 | elif backbone.split("-")[0] == "efficientnet": 250 | self.model = EfficientNet({'NAME': 
backbone, 'PRETRAINED': True}) 251 | 252 | self.texture_layer = texture_layer 253 | self.feature_layer = feature_layer 254 | 255 | with torch.no_grad(): 256 | input = {"img": torch.zeros(1, 3, img_size, img_size)} 257 | layers = self.model(input) 258 | texture_dim = layers[self.texture_layer].shape[1] 259 | feature_dim = layers[self.feature_layer].shape[1] 260 | 261 | self.layers = nn.ModuleList([]) 262 | for _ in range(depth): 263 | self.layers.append( 264 | nn.ModuleList( 265 | [ 266 | PatchTrans(in_channel=texture_dim, in_size=freq_h), 267 | FreqBlock(dim=texture_dim, h=freq_h, w=freq_w), 268 | CMA_Block( 269 | in_channel=texture_dim, 270 | hidden_channel=texture_dim, 271 | out_channel=texture_dim, 272 | ), 273 | ] 274 | ) 275 | ) 276 | 277 | self.classifier = Classifier2D( 278 | feature_dim, num_classes, drop_ratio, "sigmoid" 279 | ) 280 | 281 | self.has_decoder = has_decoder 282 | if self.has_decoder: 283 | self.decoder = Localizer(texture_dim, 1) 284 | 285 | def forward(self, x): 286 | rgb = x["img"] 287 | B = rgb.size(0) 288 | 289 | layers = {} 290 | rgb = self.model.extract_textures(rgb, layers) 291 | 292 | for attn, filter, cma in self.layers: 293 | rgb = attn(rgb) 294 | freq = filter(rgb) 295 | rgb = cma(rgb, freq) 296 | 297 | features = self.model.extract_features(rgb, layers) 298 | features = F.adaptive_avg_pool2d(features, (1, 1)) 299 | features = features.view(B, features.size(1)) 300 | 301 | logits = self.classifier(features) 302 | 303 | if self.has_decoder: 304 | mask = self.decoder(rgb) 305 | mask = mask.squeeze(-1) 306 | 307 | else: 308 | mask = None 309 | 310 | output = {"logits": logits, "mask": mask, "features:": features} 311 | return output 312 | 313 | 314 | if __name__ == "__main__": 315 | from torchsummary import summary 316 | 317 | model = M2TR(num_classes=1, has_decoder=False) 318 | model.cuda() 319 | summary(model, input_size=(3, 320, 320), batch_size=12, device="cuda") 320 | -------------------------------------------------------------------------------- /M2TR/models/modules/conv_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class Deconv(nn.Module): 6 | def __init__(self, input_channel, output_channel, kernel_size=3, padding=0): 7 | super().__init__() 8 | self.conv = nn.Conv2d( 9 | input_channel, 10 | output_channel, 11 | kernel_size=kernel_size, 12 | stride=1, 13 | padding=padding, 14 | ) 15 | 16 | self.leaky_relu = nn.LeakyReLU(0.2, inplace=True) 17 | 18 | def forward(self, x): 19 | x = F.interpolate( 20 | x, scale_factor=2, mode='bilinear', align_corners=True 21 | ) 22 | out = self.conv(x) 23 | out = self.leaky_relu(out) 24 | return out 25 | 26 | 27 | class ConvBN(nn.Module): 28 | def __init__(self, in_features, out_features): 29 | self.conv = nn.Conv2d(in_features, out_features, 3, padding=1) 30 | self.bn = nn.BatchNorm2d(out_features) 31 | 32 | def forward(self, x): 33 | out = self.conv(x) 34 | out = self.bn(out) 35 | return out 36 | -------------------------------------------------------------------------------- /M2TR/models/modules/gram_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class GramMatrix(nn.Module): 5 | def __init__(self): 6 | super(GramMatrix, self).__init__() 7 | 8 | def forward(self, x): 9 | b, c, h, w = x.size() 10 | feature = x.view(b, c, h * w) 11 | feature_t = feature.transpose(1, 2) 12 | gram = feature.bmm(feature_t) 13 | b, h, w = 
gram.size() 14 | gram = gram.view(b, 1, h, w) 15 | return gram 16 | 17 | 18 | class GramBlock(nn.Module): 19 | def __init__(self, in_channels): 20 | super(GramBlock, self).__init__() 21 | self.conv1 = nn.Conv2d( 22 | in_channels, 32, kernel_size=3, stride=1, padding=2 23 | ) 24 | self.gramMatrix = GramMatrix() 25 | self.conv2 = nn.Sequential( 26 | nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=2), 27 | nn.BatchNorm2d(16), 28 | nn.ReLU(inplace=True), 29 | ) 30 | self.conv3 = nn.Sequential( 31 | nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=2), 32 | nn.BatchNorm2d(32), 33 | nn.ReLU(inplace=True), 34 | ) 35 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) 36 | 37 | def forward(self, x): 38 | x = self.conv1(x) 39 | x = self.gramMatrix(x) 40 | x = self.conv2(x) 41 | x = self.conv3(x) 42 | x = self.pool(x) 43 | return x 44 | -------------------------------------------------------------------------------- /M2TR/models/modules/head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from M2TR.models.modules.conv_block import Deconv 4 | 5 | 6 | class Classifier2D(nn.Module): 7 | def __init__( 8 | self, 9 | dim_in, 10 | num_classes, 11 | dropout_rate=0.0, 12 | act_func="softmax", 13 | ): 14 | """ 15 | Perform linear projection and activation as head for tranformers. 16 | Args: 17 | dim_in (int): the channel dimension of the input to the head. 18 | num_classes (int): the channel dimensions of the output to the head. 19 | dropout_rate (float): dropout rate. If equal to 0.0, perform no 20 | dropout. 21 | act_func (string): activation function to use. 'softmax': applies 22 | softmax on the output. 'sigmoid': applies sigmoid on the output. 23 | """ 24 | super(Classifier2D, self).__init__() 25 | if dropout_rate > 0.0: 26 | self.dropout = nn.Dropout(dropout_rate) 27 | self.projection = nn.Linear(dim_in, num_classes, bias=True) 28 | 29 | # Softmax for evaluation and testing. 
30 | if act_func == "softmax": 31 | self.act = nn.Softmax(dim=1) 32 | elif act_func == "sigmoid": 33 | self.act = nn.Sigmoid() 34 | else: 35 | raise NotImplementedError( 36 | "{} is not supported as an activation" 37 | "function.".format(act_func) 38 | ) 39 | 40 | def forward(self, x): 41 | if hasattr(self, "dropout"): 42 | x = self.dropout(x) 43 | x = self.projection(x) 44 | 45 | if not self.training: 46 | x = self.act(x) 47 | return x 48 | 49 | 50 | class Localizer(nn.Module): 51 | def __init__(self, in_channel, output_channel): 52 | super(self, Localizer).__init__() 53 | self.deconv1 = Deconv(in_channel, in_channel) 54 | hidden_dim = in_channel // 2 55 | self.conv1 = nn.Sequential( 56 | nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1), 57 | nn.LeakyReLU(0.2, inplace=True), 58 | ) 59 | self.deconv2 = Deconv(hidden_dim, hidden_dim, kernel_size=3, padding=1) 60 | self.conv2 = nn.Sequential( 61 | nn.LeakyReLU(0.2, inplace=True), 62 | nn.Conv2d( 63 | hidden_dim, output_channel, kernel_size=3, stride=1, padding=1 64 | ), 65 | ) 66 | self.sigmoid = nn.Sigmoid() 67 | 68 | def forward(self, x): 69 | out = self.deconv1(x) 70 | out = self.conv1(out) 71 | out = self.deconv2(out) 72 | out = self.conv2(out) 73 | return self.sigmoid(out) 74 | -------------------------------------------------------------------------------- /M2TR/models/modules/transformer_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Mlp(nn.Module): 5 | def __init__( 6 | self, 7 | in_features, 8 | hidden_features=None, 9 | out_features=None, 10 | act_layer=nn.GELU, 11 | drop=0.0, 12 | ): 13 | super().__init__() 14 | out_features = out_features or in_features 15 | hidden_features = hidden_features or in_features 16 | self.fc1 = nn.Linear(in_features, hidden_features) 17 | self.act = act_layer() 18 | self.fc2 = nn.Linear(hidden_features, out_features) 19 | self.drop = nn.Dropout(drop) 20 | 21 | def forward(self, x): 22 | x = self.fc1(x) 23 | x = self.act(x) 24 | x = self.drop(x) 25 | x = self.fc2(x) 26 | x = self.drop(x) 27 | return x 28 | 29 | 30 | class FeedForward1D(nn.Module): 31 | def __init__(self, dim, hidden_dim, dropout=0.0): 32 | super(FeedForward1D, self).__init__() 33 | self.net = nn.Sequential( 34 | nn.Linear(dim, hidden_dim), 35 | nn.GELU(), 36 | nn.Dropout(dropout), 37 | nn.Linear(hidden_dim, dim), 38 | nn.Dropout(dropout), 39 | ) 40 | 41 | def forward(self, x): 42 | return self.net(x) 43 | 44 | 45 | class FeedForward2D(nn.Module): 46 | def __init__(self, in_channel, out_channel): 47 | super(FeedForward2D, self).__init__() 48 | self.conv = nn.Sequential( 49 | nn.Conv2d( 50 | in_channel, out_channel, kernel_size=3, padding=2, dilation=2 51 | ), 52 | nn.BatchNorm2d(out_channel), 53 | nn.LeakyReLU(0.2, inplace=True), 54 | nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1), 55 | nn.BatchNorm2d(out_channel), 56 | nn.LeakyReLU(0.2, inplace=True), 57 | ) 58 | 59 | def forward(self, x): 60 | x = self.conv(x) 61 | return x 62 | -------------------------------------------------------------------------------- /M2TR/models/xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from M2TR.utils.registries import MODEL_REGISTRY 8 | 9 | ''' 10 | MODEL: 11 | MODEL_NAME: Xception 12 | PRETRAINED: imagenet 13 | ESCAPE: '' 14 | ''' 15 | 16 | 17 | class SeparableConv2d(nn.Module): 18 | def 
__init__( 19 | self, 20 | in_channels, 21 | out_channels, 22 | kernel_size=1, 23 | stride=1, 24 | padding=0, 25 | dilation=1, 26 | bias=False, 27 | ): 28 | super(SeparableConv2d, self).__init__() 29 | self.conv1 = nn.Conv2d( 30 | in_channels, 31 | in_channels, 32 | kernel_size, 33 | stride, 34 | padding, 35 | dilation, 36 | groups=in_channels, 37 | bias=bias, 38 | ) 39 | self.pointwise = nn.Conv2d( 40 | in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias 41 | ) 42 | 43 | def forward(self, x): 44 | x = self.conv1(x) 45 | x = self.pointwise(x) 46 | return x 47 | 48 | 49 | class Block(nn.Module): 50 | def __init__( 51 | self, 52 | in_filters, 53 | out_filters, 54 | reps, 55 | strides=1, 56 | start_with_relu=True, 57 | grow_first=True, 58 | ): 59 | super(Block, self).__init__() 60 | 61 | if out_filters != in_filters or strides != 1: 62 | self.skip = nn.Conv2d( 63 | in_filters, out_filters, 1, stride=strides, bias=False 64 | ) 65 | self.skipbn = nn.BatchNorm2d(out_filters) 66 | else: 67 | self.skip = None 68 | 69 | rep = [] 70 | 71 | filters = in_filters 72 | if grow_first: 73 | rep.append(nn.ReLU(inplace=True)) 74 | rep.append( 75 | SeparableConv2d( 76 | in_filters, out_filters, 3, stride=1, padding=1, bias=False 77 | ) 78 | ) 79 | rep.append(nn.BatchNorm2d(out_filters)) 80 | filters = out_filters 81 | 82 | for i in range(reps - 1): 83 | rep.append(nn.ReLU(inplace=True)) 84 | rep.append( 85 | SeparableConv2d( 86 | filters, filters, 3, stride=1, padding=1, bias=False 87 | ) 88 | ) 89 | rep.append(nn.BatchNorm2d(filters)) 90 | 91 | if not grow_first: 92 | rep.append(nn.ReLU(inplace=True)) 93 | rep.append( 94 | SeparableConv2d( 95 | in_filters, out_filters, 3, stride=1, padding=1, bias=False 96 | ) 97 | ) 98 | rep.append(nn.BatchNorm2d(out_filters)) 99 | 100 | if not start_with_relu: 101 | rep = rep[1:] 102 | else: 103 | rep[0] = nn.ReLU(inplace=False) 104 | 105 | if strides != 1: 106 | rep.append(nn.MaxPool2d(3, strides, 1)) 107 | self.rep = nn.Sequential(*rep) 108 | 109 | def forward(self, inp): 110 | x = self.rep(inp) 111 | 112 | if self.skip is not None: 113 | skip = self.skip(inp) 114 | skip = self.skipbn(skip) 115 | else: 116 | skip = inp 117 | 118 | x += skip 119 | return x 120 | 121 | 122 | @MODEL_REGISTRY.register() 123 | class Xception(nn.Module): 124 | def __init__(self, model_cfg): 125 | super(Xception, self).__init__() 126 | num_classes = 2 127 | pretrained = model_cfg['PRETRAINED'] 128 | self.escape = model_cfg['ESCAPE'] 129 | self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False) 130 | self.bn1 = nn.BatchNorm2d(32) 131 | self.relu1 = nn.ReLU(inplace=True) 132 | self.conv2 = nn.Conv2d(32, 64, 3, bias=False) 133 | self.bn2 = nn.BatchNorm2d(64) 134 | self.relu2 = nn.ReLU(inplace=True) 135 | self.block1 = Block( 136 | 64, 128, 2, 2, start_with_relu=False, grow_first=True 137 | ) 138 | self.block2 = Block( 139 | 128, 256, 2, 2, start_with_relu=True, grow_first=True 140 | ) 141 | self.block3 = Block( 142 | 256, 728, 2, 2, start_with_relu=True, grow_first=True 143 | ) 144 | 145 | self.block4 = Block( 146 | 728, 728, 3, 1, start_with_relu=True, grow_first=True 147 | ) 148 | self.block5 = Block( 149 | 728, 728, 3, 1, start_with_relu=True, grow_first=True 150 | ) 151 | self.block6 = Block( 152 | 728, 728, 3, 1, start_with_relu=True, grow_first=True 153 | ) 154 | self.block7 = Block( 155 | 728, 728, 3, 1, start_with_relu=True, grow_first=True 156 | ) 157 | 158 | self.block8 = Block( 159 | 728, 728, 3, 1, start_with_relu=True, grow_first=True 160 | ) 161 | self.block9 = Block( 162 | 728, 728, 3, 
1, start_with_relu=True, grow_first=True 163 | ) 164 | self.block10 = Block( 165 | 728, 728, 3, 1, start_with_relu=True, grow_first=True 166 | ) 167 | self.block11 = Block( 168 | 728, 728, 3, 1, start_with_relu=True, grow_first=True 169 | ) 170 | self.block12 = Block( 171 | 728, 1024, 2, 2, start_with_relu=True, grow_first=False 172 | ) 173 | self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) 174 | self.bn3 = nn.BatchNorm2d(1536) 175 | self.relu3 = nn.ReLU(inplace=True) 176 | self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) 177 | self.bn4 = nn.BatchNorm2d(2048) 178 | self.relu4 = nn.ReLU(inplace=True) 179 | self.last_linear = nn.Linear(2048, num_classes) 180 | self.seq = [] 181 | self.seq.append( 182 | ( 183 | 'b0', 184 | [ 185 | self.conv1, 186 | lambda x: self.bn1(x), 187 | self.relu1, 188 | self.conv2, 189 | lambda x: self.bn2(x), 190 | ], 191 | ) 192 | ) 193 | self.seq.append(('b1', [self.relu2, self.block1])) 194 | self.seq.append(('b2', [self.block2])) 195 | self.seq.append(('b3', [self.block3])) 196 | self.seq.append(('b4', [self.block4])) 197 | self.seq.append(('b5', [self.block5])) 198 | self.seq.append(('b6', [self.block6])) 199 | self.seq.append(('b7', [self.block7])) 200 | self.seq.append(('b8', [self.block8])) 201 | self.seq.append(('b9', [self.block9])) 202 | self.seq.append(('b10', [self.block10])) 203 | self.seq.append(('b11', [self.block11])) 204 | self.seq.append(('b12', [self.block12])) 205 | self.seq.append( 206 | ( 207 | 'final', 208 | [ 209 | self.conv3, 210 | lambda x: self.bn3(x), 211 | self.relu3, 212 | self.conv4, 213 | lambda x: self.bn4(x), 214 | ], 215 | ) 216 | ) 217 | self.seq.append( 218 | ( 219 | 'logits', 220 | [ 221 | self.relu4, 222 | lambda x: F.adaptive_avg_pool2d(x, (1, 1)), 223 | lambda x: x.view(x.size(0), -1), 224 | self.last_linear, 225 | ], 226 | ) 227 | ) 228 | if pretrained == 'imagenet': 229 | self.load_state_dict( 230 | torch.hub.load_state_dict_from_url( 231 | 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth' 232 | ), 233 | strict=False, 234 | ) 235 | elif pretrained: 236 | ckpt = torch.load(pretrained, map_location='cpu') 237 | self.load_state_dict(ckpt['state_dict']) 238 | else: 239 | for m in self.modules(): 240 | if isinstance(m, nn.Conv2d): 241 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 242 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 243 | elif isinstance(m, nn.BatchNorm2d): 244 | m.weight.data.fill_(1) 245 | m.bias.data.zero_() 246 | 247 | def forward(self, samples): 248 | x = samples['img'] 249 | layers = {} 250 | for stage in self.seq: 251 | for f in stage[1]: 252 | x = f(x) 253 | layers[stage[0]] = x 254 | if stage[0] == self.escape: 255 | break 256 | return layers 257 | -------------------------------------------------------------------------------- /M2TR/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import * 2 | from .optimizer import build_optimizer 3 | from .scheduler import build_scheduler 4 | -------------------------------------------------------------------------------- /M2TR/utils/build_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | 5 | import M2TR.models 6 | import M2TR.datasets 7 | import M2TR.utils.distributed as du 8 | import M2TR.utils.logging as logging 9 | from M2TR.utils.registries import ( 10 | DATASET_REGISTRY, 11 | LOSS_REGISTRY, 12 | MODEL_REGISTRY, 13 | ) 14 | 15 | logger = 
logging.get_logger(__name__) 16 | 17 | 18 | def build_model(cfg, gpu_id=None): 19 | # Construct the model 20 | model_cfg = cfg['MODEL'] 21 | name = model_cfg['MODEL_NAME'] 22 | logger.info('MODEL_NAME: ' + name) 23 | model = MODEL_REGISTRY.get(name)(model_cfg) 24 | 25 | assert torch.cuda.is_available(), "Cuda is not available." 26 | assert ( 27 | cfg['NUM_GPUS'] <= torch.cuda.device_count() 28 | ), "Cannot use more GPU devices than available" 29 | 30 | if gpu_id is None: 31 | # Determine the GPU used by the current process 32 | cur_device = torch.cuda.current_device() 33 | else: 34 | cur_device = gpu_id 35 | # Transfer the model to the current GPU device 36 | model = model.cuda(device=cur_device) 37 | # Use multi-process data parallel model in the multi-gpu setting 38 | if cfg['NUM_GPUS'] > 1: 39 | # Make model replica operate on the current device 40 | model = torch.nn.parallel.DistributedDataParallel( 41 | module=model, device_ids=[cur_device], output_device=cur_device, find_unused_parameters=True 42 | ) 43 | 44 | return model 45 | 46 | 47 | def build_loss_fun(cfg): 48 | loss_cfg = cfg['LOSS'] 49 | name = loss_cfg['LOSS_FUN'] 50 | logger.info('LOSS_FUN: ' + name) 51 | loss_fun = LOSS_REGISTRY.get(name)(loss_cfg) 52 | return loss_fun 53 | 54 | 55 | def build_dataset(mode, cfg): 56 | dataset_cfg = cfg['DATASET'] 57 | name = dataset_cfg['DATASET_NAME'] 58 | logger.info('DATASET_NAME: ' + name + ' ' + mode) 59 | return DATASET_REGISTRY.get(name)(dataset_cfg, mode) 60 | 61 | 62 | def build_dataloader(dataset, mode, cfg): 63 | dataloader_cfg = cfg['DATALOADER'] 64 | num_tasks = du.get_world_size() 65 | global_rank = du.get_rank() 66 | 67 | sampler = torch.utils.data.DistributedSampler( 68 | dataset, 69 | num_replicas=num_tasks, 70 | rank=global_rank, 71 | shuffle=True if mode == 'train' else False, 72 | ) 73 | 74 | return DataLoader( 75 | dataset, 76 | batch_size=dataloader_cfg['BATCH_SIZE'], 77 | sampler=sampler, 78 | num_workers=dataloader_cfg['NUM_WORKERS'], 79 | pin_memory=dataloader_cfg['PIN_MEM'], 80 | drop_last=True if mode == 'train' else False, 81 | ) 82 | -------------------------------------------------------------------------------- /M2TR/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import torch 5 | 6 | import M2TR.utils.distributed as du 7 | import M2TR.utils.logging as logging 8 | from M2TR.utils.env import pathmgr 9 | 10 | logger = logging.get_logger(__name__) 11 | 12 | 13 | def make_checkpoint_dir(path_to_job): 14 | """ 15 | Creates the checkpoint directory (if not present already). 16 | Args: 17 | path_to_job (string): the path to the folder of the current job. 18 | """ 19 | checkpoint_dir = os.path.join(path_to_job, "checkpoints") 20 | # Create the checkpoint dir from the master process 21 | if du.is_master_proc() and not pathmgr.exists(checkpoint_dir): 22 | try: 23 | pathmgr.mkdirs(checkpoint_dir) 24 | except Exception: 25 | pass 26 | return checkpoint_dir 27 | 28 | 29 | def get_checkpoint_dir(path_to_job): 30 | """ 31 | Get path for storing checkpoints. 32 | Args: 33 | path_to_job (string): the path to the folder of the current job. 34 | """ 35 | return os.path.join(path_to_job, "checkpoints") 36 | 37 | 38 | def get_path_to_checkpoint(path_to_job, epoch, cfg): 39 | """ 40 | Get the full path to a checkpoint file. 41 | Args: 42 | path_to_job (string): the path to the folder of the current job. 43 | epoch (int): the number of epoch for the checkpoint. 
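cfg (dict): configs; cfg['MODEL']['MODEL_NAME'] and cfg['DATASET']['DATASET_NAME'] are used to build the file name.
Example (illustrative; the job path and the 'M2TR'/'FFDF' names below are placeholders, not values from a real run):
    get_path_to_checkpoint('/path/to/job', 5, cfg)
    # -> '/path/to/job/checkpoints/M2TR_FFDF_epoch_00005.pyth'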
44 | """ 45 | file_name = ( 46 | cfg['MODEL']['MODEL_NAME'] 47 | + '_' 48 | + cfg['DATASET']['DATASET_NAME'] 49 | + '_' 50 | + 'epoch_{:05d}' 51 | + '.pyth' 52 | ) 53 | file_name = file_name.format(epoch) 54 | return os.path.join(get_checkpoint_dir(path_to_job), file_name) 55 | 56 | 57 | def get_last_checkpoint(path_to_job): 58 | """ 59 | Get the last checkpoint from the checkpointing folder. 60 | Args: 61 | path_to_job (string): the path to the folder of the current job. 62 | """ 63 | 64 | d = get_checkpoint_dir(path_to_job) 65 | names = pathmgr.ls(d) if pathmgr.exists(d) else [] 66 | names = [f for f in names if "checkpoint" in f] 67 | assert len(names), "No checkpoints found in '{}'.".format(d) 68 | # Sort the checkpoints by epoch. 69 | name = sorted(names)[-1] 70 | return os.path.join(d, name) 71 | 72 | 73 | def has_checkpoint(path_to_job): 74 | """ 75 | Determines if the given directory contains a checkpoint. 76 | Args: 77 | path_to_job (string): the path to the folder of the current job. 78 | """ 79 | d = get_checkpoint_dir(path_to_job) 80 | files = pathmgr.ls(d) if pathmgr.exists(d) else [] 81 | return any("checkpoint" in f for f in files) 82 | 83 | 84 | def is_checkpoint_epoch(cfg, cur_epoch, multigrid_schedule=None): 85 | """ 86 | Determine if a checkpoint should be saved on current epoch. 87 | Args: 88 | cfg (dict): configs to save. 89 | cur_epoch (int): current number of epoch of the model. 90 | multigrid_schedule (List): schedule for multigrid training. 91 | """ 92 | if cur_epoch + 1 == cfg['TRAIN']['MAX_EPOCH']: 93 | return True 94 | if multigrid_schedule is not None: # TODO remove multigrid_schedule? 95 | prev_epoch = 0 96 | for s in multigrid_schedule: 97 | if cur_epoch < s[-1]: 98 | period = max( 99 | (s[-1] - prev_epoch) // cfg.MULTIGRID.EVAL_FREQ + 1, 1 100 | ) 101 | return (s[-1] - 1 - cur_epoch) % period == 0 102 | prev_epoch = s[-1] 103 | 104 | return (cur_epoch + 1) % cfg['TRAIN']['CHECKPOINT_PERIOD'] == 0 105 | 106 | 107 | def save_checkpoint(model, optimizer, scheduler, epoch, cfg): 108 | """ 109 | Save a checkpoint. 110 | Args: 111 | model (model): model to save the weight to the checkpoint. 112 | optimizer (optim): optimizer to save the historical state. 113 | epoch (int): current number of epoch of the model. 114 | cfg (dict): configs to save. 115 | """ 116 | path_to_job = cfg['TRAIN']['CHECKPOINT_SAVE_PATH'] 117 | # Save checkpoints only from the master process. 118 | if not du.is_master_proc(cfg['NUM_GPUS'] * cfg['NUM_SHARDS']): 119 | return 120 | # Ensure that the checkpoint dir exists. 121 | pathmgr.mkdirs(get_checkpoint_dir(path_to_job)) 122 | # Omit the DDP wrapper in the multi-gpu setting. 123 | sd = ( 124 | model.module.state_dict() if cfg['NUM_GPUS'] > 1 else model.state_dict() 125 | ) 126 | normalized_sd = sub_to_normal_bn(sd) 127 | 128 | # Record the state. 129 | checkpoint = { 130 | "epoch": epoch, 131 | "model_state": normalized_sd, 132 | "optimizer_state": optimizer.state_dict(), 133 | "scheduler_state": scheduler.state_dict() 134 | if scheduler is not None 135 | else None, # TODO 136 | "cfg": cfg, 137 | } 138 | 139 | # Write the checkpoint. 
140 | path_to_checkpoint = get_path_to_checkpoint(path_to_job, epoch + 1, cfg) 141 | with pathmgr.open(path_to_checkpoint, "wb") as f: 142 | torch.save(checkpoint, f) 143 | return path_to_checkpoint 144 | 145 | 146 | def load_checkpoint( 147 | path_to_checkpoint, 148 | model, 149 | data_parallel=True, 150 | optimizer=None, 151 | scheduler=None, 152 | epoch_reset=False, 153 | ): 154 | """ 155 | Load the checkpoint from the given file. 156 | Args: 157 | path_to_checkpoint (string): path to the checkpoint to load. 158 | model (model): model to load the weights from the checkpoint. 159 | data_parallel (bool): if true, model is wrapped by 160 | torch.nn.parallel.DistributedDataParallel. 161 | optimizer (optim): optimizer to load the historical state. 162 | 163 | epoch_reset (bool): if True, reset #train iterations from the checkpoint. 164 | 165 | Returns: 166 | (int): the number of training epoch of the checkpoint. 167 | """ 168 | assert pathmgr.exists( 169 | path_to_checkpoint 170 | ), "Checkpoint '{}' not found".format(path_to_checkpoint) 171 | logger.info("Loading network weights from {}.".format(path_to_checkpoint)) 172 | 173 | # Account for the DDP wrapper in the multi-gpu setting. 174 | ms = model.module if data_parallel else model 175 | 176 | # Load the checkpoint on CPU to avoid GPU mem spike. 177 | with pathmgr.open(path_to_checkpoint, "rb") as f: 178 | checkpoint = torch.load(f, map_location="cpu") 179 | 180 | model_state_dict = ( 181 | model.module.state_dict() if data_parallel else model.state_dict() 182 | ) 183 | checkpoint["model_state"] = normal_to_sub_bn( 184 | checkpoint["model_state"], model_state_dict 185 | ) 186 | 187 | pre_train_dict = checkpoint["model_state"] 188 | model_dict = ms.state_dict() 189 | # Match pre-trained weights that have same shape as current model. 190 | pre_train_dict_match = { 191 | k: v 192 | for k, v in pre_train_dict.items() 193 | if k in model_dict and v.size() == model_dict[k].size() 194 | } 195 | 196 | # Weights that do not have match from the pre-trained model. 197 | not_load_layers = [ 198 | k for k in model_dict.keys() if k not in pre_train_dict_match.keys() 199 | ] 200 | 201 | # Log weights that are not loaded with the pre-trained weights. 202 | if not_load_layers: 203 | for k in not_load_layers: 204 | logger.info("Network weights {} not loaded.".format(k)) 205 | 206 | # Load pre-trained weights. 207 | ms.load_state_dict(pre_train_dict_match, strict=False) 208 | epoch = -1 209 | 210 | # Load the optimizer state (commonly not done when fine-tuning) 211 | if "epoch" in checkpoint.keys() and not epoch_reset: 212 | epoch = checkpoint["epoch"] 213 | if optimizer: 214 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 215 | if scheduler: 216 | scheduler.load_state_dict(checkpoint["scheduler_state"]) 217 | 218 | else: 219 | epoch = -1 220 | 221 | return epoch 222 | 223 | 224 | def sub_to_normal_bn(sd): 225 | """ 226 | Convert the Sub-BN paprameters to normal BN parameters in a state dict. 227 | There are two copies of BN layers in a Sub-BN implementation: `bn.bn` and 228 | `bn.split_bn`. `bn.split_bn` is used during training and 229 | "compute_precise_bn". Before saving or evaluation, its stats are copied to 230 | `bn.bn`. We rename `bn.bn` to `bn` and store it to be consistent with normal 231 | BN layers. 232 | Args: 233 | sd (OrderedDict): a dict of parameters whitch might contain Sub-BN 234 | parameters. 235 | Returns: 236 | new_sd (OrderedDict): a dict with Sub-BN parameters reshaped to 237 | normal parameters. 
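For example (illustrative key name), 'block1.bn.bn.running_mean' is renamed to 'block1.bn.running_mean', and any key still containing 'bn.bn.' or '.split_bn.' is dropped from the returned dict.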
238 | """ 239 | new_sd = copy.deepcopy(sd) 240 | modifications = [ 241 | ("bn.bn.running_mean", "bn.running_mean"), 242 | ("bn.bn.running_var", "bn.running_var"), 243 | ("bn.split_bn.num_batches_tracked", "bn.num_batches_tracked"), 244 | ] 245 | to_remove = ["bn.bn.", ".split_bn."] 246 | for key in sd: 247 | for before, after in modifications: 248 | if key.endswith(before): 249 | new_key = key.split(before)[0] + after 250 | new_sd[new_key] = new_sd.pop(key) 251 | 252 | for rm in to_remove: 253 | if rm in key and key in new_sd: 254 | del new_sd[key] 255 | 256 | for key in new_sd: 257 | if key.endswith("bn.weight") or key.endswith("bn.bias"): 258 | if len(new_sd[key].size()) == 4: 259 | assert all(d == 1 for d in new_sd[key].size()[1:]) 260 | new_sd[key] = new_sd[key][:, 0, 0, 0] 261 | 262 | return new_sd 263 | 264 | 265 | def c2_normal_to_sub_bn(key, model_keys): 266 | """ 267 | Convert BN parameters to Sub-BN parameters if model contains Sub-BNs. 268 | Args: 269 | key (OrderedDict): source dict of parameters. 270 | mdoel_key (OrderedDict): target dict of parameters. 271 | Returns: 272 | new_sd (OrderedDict): converted dict of parameters. 273 | """ 274 | if "bn.running_" in key: 275 | if key in model_keys: 276 | return key 277 | 278 | new_key = key.replace("bn.running_", "bn.split_bn.running_") 279 | if new_key in model_keys: 280 | return new_key 281 | else: 282 | return key 283 | 284 | 285 | def normal_to_sub_bn(checkpoint_sd, model_sd): 286 | """ 287 | Convert BN parameters to Sub-BN parameters if model contains Sub-BNs. 288 | Args: 289 | checkpoint_sd (OrderedDict): source dict of parameters. 290 | model_sd (OrderedDict): target dict of parameters. 291 | Returns: 292 | new_sd (OrderedDict): converted dict of parameters. 293 | """ 294 | for key in model_sd: 295 | if key not in checkpoint_sd: 296 | if "bn.split_bn." in key: 297 | load_key = key.replace("bn.split_bn.", "bn.") 298 | bn_key = key.replace("bn.split_bn.", "bn.bn.") 299 | checkpoint_sd[key] = checkpoint_sd.pop(load_key) 300 | checkpoint_sd[bn_key] = checkpoint_sd[key] 301 | 302 | for key in model_sd: 303 | if key in checkpoint_sd: 304 | model_blob_shape = model_sd[key].shape 305 | c2_blob_shape = checkpoint_sd[key].shape 306 | 307 | if ( 308 | len(model_blob_shape) == 1 309 | and len(c2_blob_shape) == 1 310 | and model_blob_shape[0] > c2_blob_shape[0] 311 | and model_blob_shape[0] % c2_blob_shape[0] == 0 312 | ): 313 | before_shape = checkpoint_sd[key].shape 314 | checkpoint_sd[key] = torch.cat( 315 | [checkpoint_sd[key]] 316 | * (model_blob_shape[0] // c2_blob_shape[0]) 317 | ) 318 | logger.info( 319 | "{} {} -> {}".format( 320 | key, before_shape, checkpoint_sd[key].shape 321 | ) 322 | ) 323 | return checkpoint_sd 324 | 325 | 326 | def load_test_checkpoint(cfg, model): 327 | """ 328 | Loading checkpoint logic for testing. 329 | """ 330 | # Load a checkpoint to test if applicable. 331 | if cfg['TEST']['CHECKPOINT_TEST_PATH'] != "": 332 | load_checkpoint( 333 | cfg['TEST']['CHECKPOINT_TEST_PATH'], 334 | model, 335 | cfg['NUM_GPUS'] > 1, 336 | None, 337 | None, 338 | ) 339 | 340 | else: 341 | logger.info( 342 | "Unknown way of loading checkpoint. Using with random initialization, only for debugging." 343 | ) 344 | 345 | 346 | def load_train_checkpoint(model, optimizer, scheduler, cfg): 347 | """ 348 | Loading checkpoint logic for training. 
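Returns the epoch index to resume from: the checkpoint epoch + 1 when TRAIN.CHECKPOINT_LOAD_PATH points to a checkpoint, otherwise 0.
Example (illustrative; assumes cfg, model, optimizer and scheduler are already built):
    start_epoch = load_train_checkpoint(model, optimizer, scheduler, cfg)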
349 | """ 350 | if cfg['TRAIN']['CHECKPOINT_LOAD_PATH'] != "": 351 | print('Load from given checkpoint file.') 352 | logger.info("Load from given checkpoint file.") 353 | checkpoint_epoch = load_checkpoint( 354 | cfg['TRAIN']['CHECKPOINT_LOAD_PATH'], 355 | model, 356 | cfg['NUM_GPUS'] > 1, 357 | optimizer, 358 | scheduler, 359 | epoch_reset=cfg['TRAIN']['CHECKPOINT_EPOCH_RESET'], 360 | ) 361 | 362 | start_epoch = checkpoint_epoch + 1 363 | 364 | else: 365 | start_epoch = 0 366 | 367 | return start_epoch 368 | -------------------------------------------------------------------------------- /M2TR/utils/distributed.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import pickle 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | _LOCAL_PROCESS_GROUP = None 9 | 10 | 11 | def all_gather(tensors): 12 | """ 13 | All gathers the provided tensors from all processes across machines. 14 | Args: 15 | tensors (list): tensors to perform all gather across all processes in 16 | all machines. 17 | """ 18 | 19 | gather_list = [] 20 | output_tensor = [] 21 | world_size = dist.get_world_size() 22 | for tensor in tensors: 23 | tensor_placeholder = [ 24 | torch.ones_like(tensor) for _ in range(world_size) 25 | ] 26 | dist.all_gather(tensor_placeholder, tensor, async_op=False) 27 | gather_list.append(tensor_placeholder) 28 | for gathered_tensor in gather_list: 29 | output_tensor.append(torch.cat(gathered_tensor, dim=0)) 30 | return output_tensor 31 | 32 | 33 | def all_reduce(tensors, average=True): 34 | """ 35 | All reduce the provided tensors from all processes across machines. 36 | Args: 37 | tensors (list): tensors to perform all reduce across all processes in 38 | all machines. 39 | average (bool): scales the reduced tensor by the number of overall 40 | processes across all machines. 41 | """ 42 | 43 | for tensor in tensors: 44 | dist.all_reduce(tensor, async_op=False) 45 | if average: 46 | world_size = dist.get_world_size() 47 | for tensor in tensors: 48 | tensor.mul_(1.0 / world_size) 49 | return tensors 50 | 51 | 52 | def init_process_group( 53 | local_rank, 54 | local_world_size, 55 | shard_id, 56 | num_shards, 57 | init_method, 58 | dist_backend="nccl", 59 | ): 60 | """ 61 | Initializes the default process group. 62 | Args: 63 | local_rank (int): the rank on the current local machine. 64 | local_world_size (int): the world size (number of processes running) on 65 | the current local machine. 66 | shard_id (int): the shard index (machine rank) of the current machine. 67 | num_shards (int): number of shards for distributed training. 68 | init_method (string): supporting three different methods for 69 | initializing process groups: 70 | "file": use shared file system to initialize the groups across 71 | different processes. 72 | "tcp": use tcp address to initialize the groups across different 73 | dist_backend (string): backend to use for distributed training. Options 74 | includes gloo, mpi and nccl, the details can be found here: 75 | https://pytorch.org/docs/stable/distributed.html 76 | """ 77 | # Sets the GPU to use. 78 | torch.cuda.set_device(local_rank) 79 | # Initialize the process group. 
80 | proc_rank = local_rank + shard_id * local_world_size 81 | world_size = local_world_size * num_shards 82 | dist.init_process_group( 83 | backend=dist_backend, 84 | init_method=init_method, 85 | world_size=world_size, 86 | rank=proc_rank, 87 | ) 88 | 89 | 90 | def is_master_proc(num_gpus=8): 91 | """ 92 | Determines if the current process is the master process. 93 | """ 94 | if torch.distributed.is_initialized(): 95 | return dist.get_rank() % num_gpus == 0 96 | else: 97 | return True 98 | 99 | 100 | def is_root_proc(): 101 | """ 102 | Determines if the current process is the root process. 103 | """ 104 | if torch.distributed.is_initialized(): 105 | return dist.get_rank() == 0 106 | else: 107 | return True 108 | 109 | 110 | def get_world_size(): 111 | """ 112 | Get the size of the world. 113 | """ 114 | if not dist.is_available(): 115 | return 1 116 | if not dist.is_initialized(): 117 | return 1 118 | return dist.get_world_size() 119 | 120 | 121 | def get_rank(): 122 | """ 123 | Get the rank of the current process. 124 | """ 125 | if not dist.is_available(): 126 | return 0 127 | if not dist.is_initialized(): 128 | return 0 129 | return dist.get_rank() 130 | 131 | 132 | def synchronize(): 133 | """ 134 | Helper function to synchronize (barrier) among all processes when 135 | using distributed training 136 | """ 137 | if not dist.is_available(): 138 | return 139 | if not dist.is_initialized(): 140 | return 141 | world_size = dist.get_world_size() 142 | if world_size == 1: 143 | return 144 | dist.barrier() 145 | 146 | 147 | def is_dist_avail_and_initialized(): 148 | if not dist.is_available(): 149 | return False 150 | if not dist.is_initialized(): 151 | return False 152 | return True 153 | 154 | 155 | @functools.lru_cache() 156 | def _get_global_gloo_group(): 157 | """ 158 | Return a process group based on gloo backend, containing all the ranks 159 | The result is cached. 160 | Returns: 161 | (group): pytorch dist group. 162 | """ 163 | if dist.get_backend() == "nccl": 164 | return dist.new_group(backend="gloo") 165 | else: 166 | return dist.group.WORLD 167 | 168 | 169 | def _serialize_to_tensor(data, group): 170 | """ 171 | Seriialize the tensor to ByteTensor. Note that only `gloo` and `nccl` 172 | backend is supported. 173 | Args: 174 | data (data): data to be serialized. 175 | group (group): pytorch dist group. 176 | Returns: 177 | tensor (ByteTensor): tensor that serialized. 178 | """ 179 | 180 | backend = dist.get_backend(group) 181 | assert backend in ["gloo", "nccl"] 182 | device = torch.device("cpu" if backend == "gloo" else "cuda") 183 | 184 | buffer = pickle.dumps(data) 185 | if len(buffer) > 1024 ** 3: 186 | logger = logging.getLogger(__name__) 187 | logger.warning( 188 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 189 | get_rank(), len(buffer) / (1024 ** 3), device 190 | ) 191 | ) 192 | storage = torch.ByteStorage.from_buffer(buffer) 193 | tensor = torch.ByteTensor(storage).to(device=device) 194 | return tensor 195 | 196 | 197 | def _pad_to_largest_tensor(tensor, group): 198 | """ 199 | Padding all the tensors from different GPUs to the largest ones. 200 | Args: 201 | tensor (tensor): tensor to pad. 202 | group (group): pytorch dist group. 203 | Returns: 204 | list[int]: size of the tensor, on each rank 205 | Tensor: padded tensor that has the max size 206 | """ 207 | world_size = dist.get_world_size(group=group) 208 | assert ( 209 | world_size >= 1 210 | ), "comm.gather/all_gather must be called from ranks within the given group!" 
211 | local_size = torch.tensor( 212 | [tensor.numel()], dtype=torch.int64, device=tensor.device 213 | ) 214 | size_list = [ 215 | torch.zeros([1], dtype=torch.int64, device=tensor.device) 216 | for _ in range(world_size) 217 | ] 218 | dist.all_gather(size_list, local_size, group=group) 219 | size_list = [int(size.item()) for size in size_list] 220 | 221 | max_size = max(size_list) 222 | 223 | # we pad the tensor because torch all_gather does not support 224 | # gathering tensors of different shapes 225 | if local_size != max_size: 226 | padding = torch.zeros( 227 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device 228 | ) 229 | tensor = torch.cat((tensor, padding), dim=0) 230 | return size_list, tensor 231 | 232 | 233 | def all_gather_unaligned(data, group=None): 234 | """ 235 | Run all_gather on arbitrary picklable data (not necessarily tensors). 236 | Args: 237 | data: any picklable object 238 | group: a torch process group. By default, will use a group which 239 | contains all ranks on gloo backend. 240 | Returns: 241 | list[data]: list of data gathered from each rank 242 | """ 243 | if get_world_size() == 1: 244 | return [data] 245 | if group is None: 246 | group = _get_global_gloo_group() 247 | if dist.get_world_size(group) == 1: 248 | return [data] 249 | 250 | tensor = _serialize_to_tensor(data, group) 251 | 252 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 253 | max_size = max(size_list) 254 | 255 | # receiving Tensor from all ranks 256 | tensor_list = [ 257 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 258 | for _ in size_list 259 | ] 260 | dist.all_gather(tensor_list, tensor, group=group) 261 | 262 | data_list = [] 263 | for size, tensor in zip(size_list, tensor_list): 264 | buffer = tensor.cpu().numpy().tobytes()[:size] 265 | data_list.append(pickle.loads(buffer)) 266 | 267 | return data_list 268 | 269 | 270 | def init_distributed_training(cfg): 271 | """ 272 | Initialize variables needed for distributed training. 273 | """ 274 | if cfg['NUM_GPUS'] <= 1: 275 | return 276 | num_gpus_per_machine = cfg['NUM_GPUS'] 277 | num_machines = cfg['NUM_SHARDS'] 278 | for i in range(num_machines): 279 | ranks_on_i = list( 280 | range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine) 281 | ) 282 | pg = dist.new_group(ranks_on_i) 283 | if i == cfg['SHARD_ID']: 284 | global _LOCAL_PROCESS_GROUP 285 | _LOCAL_PROCESS_GROUP = pg 286 | -------------------------------------------------------------------------------- /M2TR/utils/env.py: -------------------------------------------------------------------------------- 1 | from iopath.common.file_io import PathManagerFactory 2 | 3 | _ENV_SETUP_DONE = False 4 | pathmgr = PathManagerFactory.get(key="M2TR") 5 | 6 | 7 | def setup_environment(): 8 | global _ENV_SETUP_DONE 9 | if _ENV_SETUP_DONE: 10 | return 11 | _ENV_SETUP_DONE = True 12 | -------------------------------------------------------------------------------- /M2TR/utils/logging.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import builtins 3 | import decimal 4 | import functools 5 | import logging 6 | import os 7 | import sys 8 | import time 9 | 10 | import simplejson 11 | 12 | import M2TR.utils.distributed as du 13 | from M2TR.utils.env import pathmgr 14 | 15 | 16 | def _suppress_print(): 17 | """ 18 | Suppresses printing from the current process. 
19 | """ 20 | 21 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 22 | pass 23 | 24 | builtins.print = print_pass 25 | 26 | 27 | @functools.lru_cache(maxsize=None) 28 | def _cached_log_stream(filename): 29 | # Use 1K buffer if writing to cloud storage. 30 | io = pathmgr.open( 31 | filename, "a", buffering=1024 if "://" in filename else -1 32 | ) 33 | atexit.register(io.close) 34 | return io 35 | 36 | 37 | def setup_logging(cfg, mode='train'): 38 | """ 39 | Sets up the logging for multiple processes. Only enable the logging for the 40 | master process, and suppress logging for the non-master processes. 41 | """ 42 | output_dir = cfg['LOG_FILE_PATH'] 43 | cur_time = time.strftime('%Y-%m-%d_%H:%M:%S', time.localtime(time.time())) 44 | file_name = ( 45 | cur_time 46 | + '_' 47 | + cfg['MODEL']['MODEL_NAME'] 48 | + '_' 49 | + cfg['DATASET']['DATASET_NAME'] 50 | + '_' 51 | + str(cfg['OPTIMIZER']['BASE_LR']) 52 | + '_' 53 | + mode 54 | + '.log' 55 | ) 56 | # Set up logging format. 57 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" 58 | 59 | if du.is_master_proc(): 60 | # Enable logging for the master process. 61 | logging.root.handlers = [] 62 | else: 63 | # Suppress logging for non-master processes. 64 | _suppress_print() 65 | 66 | logger = logging.getLogger() 67 | logger.setLevel(logging.DEBUG) 68 | logger.propagate = False 69 | plain_formatter = logging.Formatter( 70 | "[%(asctime)s][%(levelname)s] %(filename)s: %(lineno)3d: %(message)s", 71 | datefmt="%m/%d %H:%M:%S", 72 | ) 73 | 74 | if du.is_master_proc(): 75 | ch = logging.StreamHandler(stream=sys.stdout) 76 | ch.setLevel(logging.DEBUG) 77 | ch.setFormatter(plain_formatter) 78 | logger.addHandler(ch) 79 | 80 | if output_dir is not None and du.is_master_proc(du.get_world_size()): 81 | if not os.path.exists(output_dir): 82 | os.makedirs(output_dir) 83 | filename = os.path.join(output_dir, file_name) 84 | fh = logging.StreamHandler(_cached_log_stream(filename)) 85 | fh.setLevel(logging.DEBUG) 86 | fh.setFormatter(plain_formatter) 87 | logger.addHandler(fh) 88 | 89 | 90 | def get_logger(name): 91 | """ 92 | Retrieve the logger with the specified name or, if name is None, return a 93 | logger which is the root logger of the hierarchy. 94 | Args: 95 | name (string): name of the logger. 96 | """ 97 | return logging.getLogger(name) 98 | 99 | 100 | def log_json_stats(stats): 101 | """ 102 | Logs json stats. 103 | Args: 104 | stats (dict): a dictionary of statistical information to log. 105 | """ 106 | stats = { 107 | k: decimal.Decimal("{:.5f}".format(v)) if isinstance(v, float) else v 108 | for k, v in stats.items() 109 | } 110 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 111 | logger = get_logger(__name__) 112 | logger.info("json_stats: {:s}".format(json_stats)) 113 | -------------------------------------------------------------------------------- /M2TR/utils/loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from fvcore.common.registry import Registry 7 | 8 | from M2TR.utils.registries import LOSS_REGISTRY 9 | 10 | from .build_helper import LOSS_REGISTRY 11 | 12 | 13 | class BaseWeightedLoss(nn.Module, metaclass=ABCMeta): 14 | """Base class for loss. 15 | All subclass should overwrite the ``_forward()`` method which returns the 16 | normal loss without loss weights. 
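A minimal subclass sketch (illustrative only; 'L1Loss' and the formula it computes are not part of this repo) showing how the weighting applied in ``forward()`` is inherited:

    @LOSS_REGISTRY.register()
    class L1Loss(BaseWeightedLoss):
        def __init__(self, loss_cfg):
            super().__init__(loss_weight=loss_cfg['LOSS_WEIGHT'])

        def _forward(self, outputs, samples, **kwargs):
            # mean absolute error between logits and one-hot labels
            return (outputs['logits'] - samples['bin_label_onehot']).abs().mean()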
17 | Args: 18 | loss_weight (float): Factor scalar multiplied on the loss. 19 | Default: 1.0. 20 | """ 21 | 22 | def __init__(self, loss_weight=1.0): 23 | super().__init__() 24 | self.loss_weight = loss_weight 25 | 26 | @abstractmethod 27 | def _forward(self, *args, **kwargs): 28 | pass 29 | 30 | def forward(self, *args, **kwargs): 31 | """Defines the computation performed at every call. 32 | Args: 33 | *args: The positional arguments for the corresponding 34 | loss. 35 | **kwargs: The keyword arguments for the corresponding 36 | loss. 37 | Returns: 38 | torch.Tensor: The calculated loss. 39 | """ 40 | ret = self._forward(*args, **kwargs) 41 | if isinstance(ret, dict): 42 | for k in ret: 43 | if 'loss' in k: 44 | ret[k] *= self.loss_weight 45 | else: 46 | ret *= self.loss_weight 47 | return ret 48 | 49 | 50 | @LOSS_REGISTRY.register() 51 | class CrossEntropyLoss(BaseWeightedLoss): 52 | """Cross Entropy Loss. 53 | Support two kinds of labels and their corresponding loss type. It's worth 54 | mentioning that loss type will be detected by the shape of ``cls_score`` 55 | and ``label``. 56 | 1) Hard label: This label is an integer array and all of the elements are 57 | in the range [0, num_classes - 1]. This label's shape should be 58 | ``cls_score``'s shape with the `num_classes` dimension removed. 59 | 2) Soft label(probablity distribution over classes): This label is a 60 | probability distribution and all of the elements are in the range 61 | [0, 1]. This label's shape must be the same as ``cls_score``. For now, 62 | only 2-dim soft label is supported. 63 | Args: 64 | loss_weight (float): Factor scalar multiplied on the loss. 65 | Default: 1.0. 66 | class_weight (list[float] | None): Loss weight for each class. If set 67 | as None, use the same weight 1 for all classes. Only applies 68 | to CrossEntropyLoss and BCELossWithLogits (should not be set when 69 | using other losses). Default: None. 70 | """ 71 | 72 | def __init__(self, loss_cfg): 73 | super().__init__(loss_weight=loss_cfg['LOSS_WEIGHT']) 74 | self.class_weight = ( 75 | torch.Tensor(loss_cfg['CLASS_WEIGHT']) 76 | if 'CLASS_WEIGHT' in loss_cfg 77 | else None 78 | ) 79 | 80 | def _forward(self, outputs, samples, **kwargs): 81 | """Forward function. 82 | Args: 83 | cls_score (torch.Tensor): The class score. 84 | samples (dict): The ground truth labels. 85 | kwargs: Any keyword argument to be used to calculate 86 | CrossEntropy loss. 87 | Returns: 88 | torch.Tensor: The returned CrossEntropy loss. 89 | """ 90 | cls_score = outputs['logits'] 91 | label = samples['bin_label_onehot'] 92 | if cls_score.size() == label.size(): 93 | # calculate loss for soft labels 94 | 95 | assert cls_score.dim() == 2, 'Only support 2-dim soft label' 96 | assert len(kwargs) == 0, ( 97 | 'For now, no extra args are supported for soft label, ' 98 | f'but get {kwargs}' 99 | ) 100 | 101 | lsm = F.log_softmax(cls_score, 1) 102 | if self.class_weight is not None: 103 | lsm = lsm * self.class_weight.unsqueeze(0).to(cls_score.device) 104 | loss_cls = -(label * lsm).sum(1) 105 | 106 | # default reduction 'mean' 107 | if self.class_weight is not None: 108 | # Use weighted average as pytorch CrossEntropyLoss does. 
109 | # For more information, please visit https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html # noqa 110 | loss_cls = loss_cls.sum() / torch.sum( 111 | self.class_weight.unsqueeze(0).to(cls_score.device) * label 112 | ) 113 | else: 114 | loss_cls = loss_cls.mean() 115 | else: 116 | # calculate loss for hard label 117 | 118 | if self.class_weight is not None: 119 | assert ( 120 | 'weight' not in kwargs 121 | ), "The key 'weight' already exists." 122 | kwargs['weight'] = self.class_weight.to(cls_score.device) 123 | loss_cls = F.cross_entropy(cls_score, label, **kwargs) 124 | 125 | return loss_cls 126 | 127 | 128 | @LOSS_REGISTRY.register() 129 | class BCELossWithLogits(BaseWeightedLoss): 130 | """Binary Cross Entropy Loss with logits. 131 | Args: 132 | loss_weight (float): Factor scalar multiplied on the loss. 133 | Default: 1.0. 134 | class_weight (list[float] | None): Loss weight for each class. If set 135 | as None, use the same weight 1 for all classes. Only applies 136 | to CrossEntropyLoss and BCELossWithLogits (should not be set when 137 | using other losses). Default: None. 138 | """ 139 | 140 | def __init__(self, loss_cfg): 141 | super().__init__(loss_weight=loss_cfg['LOSS_WEIGHT']) 142 | self.class_weight = ( 143 | torch.Tensor(loss_cfg['CLASS_WEIGHT']) 144 | if 'CLASS_WEIGHT' in loss_cfg 145 | else None 146 | ) 147 | 148 | def _forward(self, outputs, samples, **kwargs): 149 | """Forward function. 150 | Args: 151 | cls_score (torch.Tensor): The class score. 152 | samples (dict): The ground truth labels. 153 | kwargs: Any keyword argument to be used to calculate 154 | bce loss with logits. 155 | Returns: 156 | torch.Tensor: The returned bce loss with logits. 157 | """ 158 | cls_score = outputs['logits'] 159 | label = samples['bin_label_onehot'] 160 | if self.class_weight is not None: 161 | assert ( 162 | 'weight' not in kwargs 163 | ), "The key 'weight' already exists." 164 | kwargs['weight'] = self.class_weight.to(cls_score.device) 165 | loss_cls = F.binary_cross_entropy_with_logits( 166 | cls_score, label, **kwargs 167 | ) 168 | return loss_cls 169 | 170 | 171 | @LOSS_REGISTRY.register() 172 | class MSELoss(BaseWeightedLoss): 173 | """MSE Loss 174 | Args: 175 | loss_weight (float): Factor scalar multiplied on the loss. 176 | Default: 1.0. 177 | 178 | """ 179 | 180 | def __init__(self, loss_cfg): 181 | super().__init__(loss_weight=loss_cfg['LOSS_WEIGHT']) 182 | self.mse = nn.MSELoss() 183 | 184 | def _forward(self, pred_mask, gt_mask, **kwargs): # TODO samples 185 | loss = self.mse(pred_mask, gt_mask) 186 | return loss 187 | 188 | 189 | @LOSS_REGISTRY.register() 190 | class ICCLoss(BaseWeightedLoss): 191 | """Contrastive Loss 192 | Args: 193 | loss_weight (float): Factor scalar multiplied on the loss. 194 | Default: 1.0. 
195 | """ 196 | 197 | def __init__(self, loss_cfg): 198 | super().__init__(loss_weight=loss_cfg['LOSS_WEIGHT']) 199 | 200 | def _forward(self, feature, label, **kwargs): # TODO samples 201 | # size of feature is (b, 1024) 202 | # size of label is (b) 203 | C = feature.size(1) 204 | label = label.unsqueeze(1) 205 | label = label.repeat(1, C) 206 | # print(label.device) 207 | label = label.type(torch.BoolTensor).cuda() 208 | 209 | res_label = torch.zeros(label.size(), dtype=label.dtype) 210 | res_label = torch.where(label == 1, 0, 1) 211 | res_label = res_label.type(torch.BoolTensor).cuda() 212 | 213 | # print(label, res_label) 214 | pos_feature = torch.masked_select(feature, label) 215 | neg_feature = torch.masked_select(feature, res_label) 216 | 217 | # print('pos_fea: ', pos_feature.device) 218 | # print('nge_fea: ', neg_feature.device) 219 | pos_feature = pos_feature.view(-1, C) 220 | neg_feature = neg_feature.view(-1, C) 221 | 222 | pos_center = torch.mean(pos_feature, dim=0, keepdim=True) 223 | 224 | # dis_pos = torch.sum((pos_feature - pos_center)**2) / torch.norm(pos_feature, p=1) 225 | # dis_neg = torch.sum((neg_feature - pos_center)**2) / torch.norm(neg_feature, p=1) 226 | num_p = pos_feature.size(0) 227 | num_n = neg_feature.size(0) 228 | pos_center1 = pos_center.repeat(num_p, 1) 229 | pos_center2 = pos_center.repeat(num_n, 1) 230 | dis_pos = F.cosine_similarity(pos_feature, pos_center1, eps=1e-6) 231 | dis_pos = torch.mean(dis_pos, dim=0) 232 | dis_neg = F.cosine_similarity(neg_feature, pos_center2, eps=1e-6) 233 | dis_neg = torch.mean(dis_neg, dim=0) 234 | 235 | loss = dis_pos - dis_neg 236 | 237 | return loss 238 | 239 | 240 | @LOSS_REGISTRY.register() 241 | class FocalLoss(BaseWeightedLoss): 242 | def __init__(self, loss_cfg): 243 | super().__init__(loss_weight=loss_cfg['LOSS_WEIGHT']) 244 | super(FocalLoss, self).__init__() 245 | self.alpha = loss_cfg['ALPHA'] if 'ALPHA' in loss_cfg.keys() else 1 246 | self.gamma = loss_cfg['GAMMA'] if 'GAMMA' in loss_cfg.keys() else 2 247 | self.logits = ( 248 | loss_cfg['LOGITS'] if 'LOGITS' in loss_cfg.keys() else True 249 | ) 250 | self.reduce = ( 251 | loss_cfg['REDUCE'] if 'REDUCE' in loss_cfg.keys() else True 252 | ) 253 | 254 | def _forward(self, outputs, samples, **kwargs): 255 | cls_score = outputs['logits'] 256 | label = samples['bin_label_onehot'] 257 | if self.logits: # TODO 258 | BCE_loss = F.binary_cross_entropy_with_logits( 259 | cls_score, label, reduce=False 260 | ) 261 | else: 262 | BCE_loss = F.binary_cross_entropy(cls_score, label, reduce=False) 263 | pt = torch.exp(-BCE_loss) 264 | F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss 265 | 266 | if self.reduce: 267 | return torch.mean(F_loss) 268 | else: 269 | return F_loss 270 | 271 | 272 | @LOSS_REGISTRY.register() 273 | class Auxiliary_Loss_v2(BaseWeightedLoss): 274 | def __init__(self, loss_cfg): 275 | super().__init__(loss_weight=loss_cfg['AUX_LOSS_WEIGHT']) 276 | M = loss_cfg['M'] if 'M' in loss_cfg.keys() else 1 277 | N = loss_cfg['N'] if 'N' in loss_cfg.keys() else 1 278 | C = loss_cfg['C'] if 'C' in loss_cfg.keys() else 1 279 | alpha = loss_cfg['ALPHA'] if 'ALPHA' in loss_cfg.keys() else 0.05 280 | margin = loss_cfg['MARGIN'] if 'MARGIN' in loss_cfg.keys() else 1 281 | inner_margin = ( 282 | loss_cfg['INNER_MARGIN'] 283 | if 'INNER_MARGIN' in loss_cfg.keys() 284 | else [0.1, 5] 285 | ) 286 | 287 | self.register_buffer('feature_centers', torch.zeros(M, N)) 288 | self.register_buffer('alpha', torch.tensor(alpha)) 289 | self.num_classes = C 290 | self.margin = 
margin 291 | from M2TR.models.matdd import AttentionPooling 292 | 293 | self.atp = AttentionPooling() 294 | self.register_buffer('inner_margin', torch.Tensor(inner_margin)) 295 | 296 | def _forward(self, feature_map_d, attentions, y): 297 | B, N, H, W = feature_map_d.size() 298 | B, M, AH, AW = attentions.size() 299 | if AH != H or AW != W: 300 | attentions = F.interpolate( 301 | attentions, (H, W), mode='bilinear', align_corners=True 302 | ) 303 | feature_matrix = self.atp(feature_map_d, attentions) 304 | feature_centers = self.feature_centers 305 | center_momentum = feature_matrix - feature_centers 306 | real_mask = (y == 0).view(-1, 1, 1) 307 | fcts = ( 308 | self.alpha * torch.mean(center_momentum * real_mask, dim=0) 309 | + feature_centers 310 | ) 311 | fctsd = fcts.detach() 312 | if self.training: 313 | with torch.no_grad(): 314 | if torch.distributed.is_initialized(): 315 | torch.distributed.all_reduce( 316 | fctsd, torch.distributed.ReduceOp.SUM 317 | ) 318 | fctsd /= torch.distributed.get_world_size() 319 | self.feature_centers = fctsd 320 | inner_margin = self.inner_margin[y] 321 | intra_class_loss = F.relu( 322 | torch.norm(feature_matrix - fcts, dim=[1, 2]) 323 | * torch.sign(inner_margin) 324 | - inner_margin 325 | ) 326 | intra_class_loss = torch.mean(intra_class_loss) 327 | inter_class_loss = 0 328 | for j in range(M): 329 | for k in range(j + 1, M): 330 | inter_class_loss += F.relu( 331 | self.margin - torch.dist(fcts[j], fcts[k]), inplace=False 332 | ) 333 | inter_class_loss = inter_class_loss / M / self.alpha 334 | # fmd=attentions.flatten(2) 335 | # diverse_loss=torch.mean(F.relu(F.cosine_similarity(fmd.unsqueeze(1),fmd.unsqueeze(2),dim=3)-self.margin,inplace=True)*(1-torch.eye(M,device=attentions.device))) 336 | return intra_class_loss + inter_class_loss, feature_matrix 337 | 338 | 339 | class Auxiliary_Loss_v1(nn.Module): 340 | def __init__(self, loss_cfg): 341 | super().__init__(loss_weight=loss_cfg['loss_weight']) 342 | M = loss_cfg['M'] if 'M' in loss_cfg.keys() else 1 343 | N = loss_cfg['N'] if 'N' in loss_cfg.keys() else 1 344 | C = loss_cfg['N'] if 'N' in loss_cfg.keys() else 1 345 | alpha = loss_cfg['N'] if 'N' in loss_cfg.keys() else 0.05 346 | margin = loss_cfg['N'] if 'N' in loss_cfg.keys() else 1 347 | inner_margin = ( 348 | loss_cfg['inner_margin'] 349 | if 'inner_margin' in loss_cfg.keys() 350 | else [0.01, 0.02] 351 | ) 352 | self.register_buffer('feature_centers', torch.zeros(M, N)) 353 | self.register_buffer('alpha', torch.tensor(alpha)) 354 | self.num_classes = C 355 | self.margin = margin 356 | from M2TR.models.matdd import AttentionPooling 357 | 358 | self.atp = AttentionPooling() 359 | self.register_buffer('inner_margin', torch.Tensor(inner_margin)) 360 | 361 | def forward(self, feature_map_d, attentions, y): 362 | B, N, H, W = feature_map_d.size() 363 | B, M, AH, AW = attentions.size() 364 | if AH != H or AW != W: 365 | attentions = F.interpolate( 366 | attentions, (H, W), mode='bilinear', align_corners=True 367 | ) 368 | feature_matrix = self.atp(feature_map_d, attentions) 369 | feature_centers = self.feature_centers.detach() 370 | center_momentum = feature_matrix - feature_centers 371 | fcts = self.alpha * torch.mean(center_momentum, dim=0) + feature_centers 372 | fctsd = fcts.detach() 373 | if self.training: 374 | with torch.no_grad(): 375 | if torch.distributed.is_initialized(): 376 | torch.distributed.all_reduce( 377 | fctsd, torch.distributed.ReduceOp.SUM 378 | ) 379 | fctsd /= torch.distributed.get_world_size() 380 | self.feature_centers = 
fctsd 381 | inner_margin = torch.gather( 382 | self.inner_margin.repeat(B, 1), 1, y.unsqueeze(1) 383 | ) 384 | intra_class_loss = F.relu( 385 | torch.norm(feature_matrix - fcts, dim=-1) - inner_margin 386 | ) 387 | intra_class_loss = torch.mean(intra_class_loss) 388 | inter_class_loss = 0 389 | for j in range(M): 390 | for k in range(j + 1, M): 391 | inter_class_loss += F.relu( 392 | self.margin - torch.dist(fcts[j], fcts[k]), inplace=False 393 | ) 394 | inter_calss_loss = inter_class_loss / M / self.alpha 395 | # fmd=attentions.flatten(2) 396 | # inter_class_loss=torch.mean(F.relu(F.cosine_similarity(fmd.unsqueeze(1),fmd.unsqueeze(2),dim=3)-self.margin,inplace=True)*(1-torch.eye(M,device=attentions.device))) 397 | return intra_class_loss + inter_class_loss, feature_matrix 398 | 399 | 400 | @LOSS_REGISTRY.register() 401 | class Multi_attentional_Deepfake_Detection_loss(nn.Module): 402 | def __init__(self, loss_cfg) -> None: 403 | super().__init__() 404 | self.loss_cfg = loss_cfg 405 | 406 | def forward(self, loss_pack, label): 407 | if 'loss' in loss_pack: 408 | return loss_pack['loss'] 409 | loss = ( 410 | self.loss_cfg['ENSEMBLE_LOSS_WEIGHT'] * loss_pack['ensemble_loss'] 411 | + self.loss_cfg['AUX_LOSS_WEIGHT'] * loss_pack['aux_loss'] 412 | ) 413 | if self.loss_cfg['AGDA_LOSS_WEIGHT'] != 0: 414 | loss += ( 415 | self.loss_cfg['AGDA_LOSS_WEIGHT'] 416 | * loss_pack['AGDA_ensemble_loss'] 417 | + self.loss_cfg['MATCH_LOSS_WEIGHT'] * loss_pack['match_loss'] 418 | ) 419 | return loss 420 | -------------------------------------------------------------------------------- /M2TR/utils/meters.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | from collections import defaultdict, deque 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | from fvcore.common.timer import Timer 9 | 10 | import M2TR.utils.distributed as du 11 | import M2TR.utils.logging as logging 12 | from sklearn.metrics import roc_auc_score 13 | 14 | logger = logging.get_logger(__name__) 15 | 16 | class SmoothedValue(object): 17 | """Track a series of values and provide access to smoothed values over a 18 | window or the global series average. 19 | """ 20 | 21 | def __init__(self, window_size=20, fmt=None): 22 | if fmt is None: 23 | fmt = "{median:.4f} ({global_avg:.4f})" 24 | self.deque = deque(maxlen=window_size) 25 | self.total = 0.0 26 | self.count = 0 27 | self.fmt = fmt 28 | 29 | def update(self, value, n=1): 30 | self.deque.append(value) 31 | self.count += n 32 | self.total += value * n 33 | 34 | def synchronize_between_processes(self): 35 | """ 36 | Warning: does not synchronize the deque! 
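Only ``count`` and ``total`` (and therefore ``global_avg``) are all-reduced across processes; ``median``, ``avg``, ``max`` and ``value`` still reflect the local window only.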
37 | """ 38 | if not du.is_dist_avail_and_initialized(): 39 | return 40 | t = torch.tensor( 41 | [self.count, self.total], dtype=torch.float64, device='cuda' 42 | ) 43 | dist.barrier() 44 | dist.all_reduce(t) 45 | t = t.tolist() 46 | self.count = int(t[0]) 47 | self.total = t[1] 48 | 49 | @property 50 | def median(self): 51 | d = torch.tensor(list(self.deque)) 52 | return d.median().item() 53 | 54 | @property 55 | def avg(self): 56 | d = torch.tensor(list(self.deque), dtype=torch.float32) 57 | return d.mean().item() 58 | 59 | @property 60 | def global_avg(self): 61 | return self.total / self.count 62 | 63 | @property 64 | def max(self): 65 | return max(self.deque) 66 | 67 | @property 68 | def value(self): 69 | return self.deque[-1] 70 | 71 | def __str__(self): 72 | return self.fmt.format( 73 | median=self.median, 74 | avg=self.avg, 75 | global_avg=self.global_avg, 76 | max=self.max, 77 | value=self.value, 78 | ) 79 | 80 | 81 | class AucMetric(): 82 | """ 83 | 84 | """ 85 | 86 | def __init__(self, num_gpus): 87 | self.labels = torch.Tensor().cuda() 88 | self.preds = torch.Tensor().cuda() 89 | self.num_gpus = num_gpus 90 | 91 | def update(self, labels, preds): 92 | self.labels = torch.cat([self.labels, labels], dim=0) 93 | self.preds = torch.cat([self.preds, preds], dim=0) 94 | 95 | def synchronize_between_processes(self): 96 | if not du.is_dist_avail_and_initialized(): 97 | return 98 | labels = [torch.zeros(len(self.labels), dtype=self.labels[0].dtype).cuda() for _ in range(self.num_gpus)] 99 | preds = [torch.zeros(len(self.preds), dtype=self.preds[0].dtype).cuda() for _ in range(self.num_gpus)] 100 | dist.all_gather(labels, self.labels) 101 | dist.all_gather(preds, self.preds) 102 | labels = torch.cat(labels, dim=0).cpu() 103 | preds = torch.cat(preds, dim=0).cpu() 104 | self.auc = roc_auc_score(labels, preds) 105 | 106 | # def __str__(self): 107 | # return str(self.auc) 108 | 109 | 110 | class MetricLogger(object): 111 | def __init__(self, delimiter="\t"): 112 | self.meters = defaultdict(SmoothedValue) 113 | self.delimiter = delimiter 114 | 115 | def update(self, **kwargs): 116 | for k, v in kwargs.items(): 117 | if isinstance(v, torch.Tensor): 118 | v = v.item() 119 | assert isinstance(v, (float, int)) 120 | self.meters[k].update(v) 121 | 122 | def __getattr__(self, attr): 123 | if attr in self.meters: 124 | return self.meters[attr] 125 | if attr in self.__dict__: 126 | return self.__dict__[attr] 127 | raise AttributeError( 128 | "'{}' object has no attribute '{}'".format( 129 | type(self).__name__, attr 130 | ) 131 | ) 132 | 133 | def __str__(self): 134 | loss_str = [] 135 | for name, meter in self.meters.items(): 136 | loss_str.append("{}: {}".format(name, str(meter))) 137 | return self.delimiter.join(loss_str) 138 | 139 | def synchronize_between_processes(self): 140 | for meter in self.meters.values(): 141 | meter.synchronize_between_processes() 142 | 143 | def add_meter(self, name, meter): 144 | self.meters[name] = meter 145 | 146 | def log_every(self, iterable, print_freq, header=None): 147 | i = 0 148 | if not header: 149 | header = '' 150 | start_time = time.time() 151 | end = time.time() 152 | iter_time = SmoothedValue(fmt='{avg:.4f}') 153 | data_time = SmoothedValue(fmt='{avg:.4f}') 154 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 155 | log_msg = [ 156 | header, 157 | '[{0' + space_fmt + '}/{1}]', 158 | 'eta: {eta}', 159 | '{meters}', 160 | 'time: {time}', 161 | 'data: {data}', 162 | ] 163 | if torch.cuda.is_available(): 164 | log_msg.append('max mem: {memory:.0f}') 
165 | log_msg = self.delimiter.join(log_msg) 166 | MB = 1024.0 * 1024.0 167 | for obj in iterable: 168 | data_time.update(time.time() - end) 169 | yield obj 170 | iter_time.update(time.time() - end) 171 | if i % print_freq == 0 or i == len(iterable) - 1: 172 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 173 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 174 | if torch.cuda.is_available(): 175 | logger.info( 176 | log_msg.format( 177 | i, 178 | len(iterable), 179 | eta=eta_string, 180 | meters=str(self), 181 | time=str(iter_time), 182 | data=str(data_time), 183 | memory=torch.cuda.max_memory_allocated() / MB, 184 | ) 185 | ) 186 | else: 187 | logger.info( 188 | log_msg.format( 189 | i, 190 | len(iterable), 191 | eta=eta_string, 192 | meters=str(self), 193 | time=str(iter_time), 194 | data=str(data_time), 195 | ) 196 | ) 197 | i += 1 198 | end = time.time() 199 | total_time = time.time() - start_time 200 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 201 | logger.info( 202 | '{} Total time: {} ({:.4f} s / it)'.format( 203 | header, total_time_str, total_time / len(iterable) 204 | ) 205 | ) 206 | 207 | 208 | class EpochTimer: 209 | """ 210 | A timer which computes the epoch time. 211 | """ 212 | 213 | def __init__(self) -> None: 214 | self.timer = Timer() 215 | self.timer.reset() 216 | self.epoch_times = [] 217 | 218 | def reset(self) -> None: 219 | """ 220 | Reset the epoch timer. 221 | """ 222 | self.timer.reset() 223 | self.epoch_times = [] 224 | 225 | def epoch_tic(self): 226 | """ 227 | Start to record time. 228 | """ 229 | self.timer.reset() 230 | 231 | def epoch_toc(self): 232 | """ 233 | Stop to record time. 234 | """ 235 | self.timer.pause() 236 | self.epoch_times.append(self.timer.seconds()) 237 | 238 | def last_epoch_time(self): 239 | """ 240 | Get the time for the last epoch. 241 | """ 242 | assert len(self.epoch_times) > 0, "No epoch time has been recorded!" 243 | 244 | return self.epoch_times[-1] 245 | 246 | def avg_epoch_time(self): 247 | """ 248 | Calculate the average epoch time among the recorded epochs. 249 | """ 250 | assert len(self.epoch_times) > 0, "No epoch time has been recorded!" 251 | 252 | return np.mean(self.epoch_times) 253 | 254 | def median_epoch_time(self): 255 | """ 256 | Calculate the median epoch time among the recorded epochs. 257 | """ 258 | assert len(self.epoch_times) > 0, "No epoch time has been recorded!" 259 | 260 | return np.median(self.epoch_times) 261 | -------------------------------------------------------------------------------- /M2TR/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def build_optimizer(optim_params, cfg): 4 | """ 5 | Construct a stochastic gradient descent or ADAM optimizer with momentum. 6 | Details can be found in: 7 | Herbert Robbins, and Sutton Monro. "A stochastic approximation method." 8 | and 9 | Diederik P.Kingma, and Jimmy Ba. 10 | "Adam: A Method for Stochastic Optimization." 11 | Args: 12 | model (model): model to perform stochastic gradient descent 13 | optimization or ADAM optimization. 14 | cfg (dict): configs of hyper-parameters of SGD or ADAM, includes base 15 | learning rate, momentum, weight_decay, dampening, and etc. 
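Example (illustrative values; only the keys read by the chosen branch are shown):
    cfg = {'OPTIMIZER': {'OPTIMIZER_METHOD': 'sgd', 'BASE_LR': 0.001, 'MOMENTUM': 0.9}}
    optimizer = build_optimizer(model.parameters(), cfg)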
16 | """ 17 | optimizer_cfg = cfg['OPTIMIZER'] 18 | if optimizer_cfg['OPTIMIZER_METHOD'] == "sgd": 19 | return torch.optim.SGD( 20 | optim_params, 21 | lr=optimizer_cfg['BASE_LR'], 22 | momentum=optimizer_cfg['MOMENTUM'], 23 | # dampening=optimizer_cfg['DAMPENING'], 24 | # weight_decay=optimizer_cfg['WEIGHT_DECAY'], 25 | # nesterov=optimizer_cfg['NESTEROV'], 26 | ) 27 | elif optimizer_cfg['OPTIMIZER_METHOD'] == "rmsprop": 28 | return torch.optim.RMSprop( 29 | optim_params, 30 | lr=optimizer_cfg['BASE_LR'], 31 | alpha=optimizer_cfg['ALPHA'], 32 | eps=optimizer_cfg['EPS'], 33 | weight_decay=optimizer_cfg['WEIGHT_DECAY'], 34 | momentum=optimizer_cfg['MOMENTUM'], 35 | ) 36 | elif optimizer_cfg['OPTIMIZER_METHOD'] == "adam": 37 | return torch.optim.Adam( 38 | optim_params, 39 | lr=optimizer_cfg['BASE_LR'], 40 | betas=optimizer_cfg['ADAM_BETAS'], 41 | eps=optimizer_cfg['EPS'], 42 | weight_decay=optimizer_cfg['WEIGHT_DECAY'], 43 | amsgrad=optimizer_cfg['AMSGRAD'], 44 | ) 45 | elif optimizer_cfg['OPTIMIZER_METHOD'] == "adamw": 46 | return torch.optim.AdamW( 47 | optim_params, 48 | lr=optimizer_cfg['BASE_LR'], 49 | betas=optimizer_cfg['ADAM_BETAS'], 50 | eps=optimizer_cfg['EPS'], 51 | weight_decay=optimizer_cfg['WEIGHT_DECAY'], 52 | amsgrad=optimizer_cfg['AMSGRAD'], 53 | ) 54 | else: 55 | raise NotImplementedError( 56 | "Does not support {} optimizer".format( 57 | optimizer_cfg['OPTIMIZER_METHOD'] 58 | ) 59 | ) 60 | -------------------------------------------------------------------------------- /M2TR/utils/registries.py: -------------------------------------------------------------------------------- 1 | from fvcore.common.registry import Registry 2 | 3 | MODEL_REGISTRY = Registry("MODEL") 4 | MODEL_REGISTRY.__doc__ = """ 5 | Registry for model. 6 | 7 | The registered object will be called with `obj(cfg)`. 8 | The call should return a `torch.nn.Module` object. 9 | """ 10 | 11 | LOSS_REGISTRY = Registry("LOSS") 12 | LOSS_REGISTRY.__doc__ = """ 13 | Registry for loss functions. 14 | The registered object will be called with `obj(cfg)`. 15 | """ 16 | 17 | 18 | DATASET_REGISTRY = Registry("DATASET") 19 | DATASET_REGISTRY.__doc__ = """ 20 | Registry for datasets. 21 | The registered object will be called with `obj(cfg)`. 
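Example (illustrative; 'MyDataset' is a placeholder, real datasets live in M2TR/datasets and are built via build_helper.build_dataset):

    @DATASET_REGISTRY.register()
    class MyDataset(torch.utils.data.Dataset):
        def __init__(self, dataset_cfg, mode='train'):
            ...

    dataset = DATASET_REGISTRY.get('MyDataset')(dataset_cfg, 'train')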
22 | """ 23 | -------------------------------------------------------------------------------- /M2TR/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | 2 | from timm.scheduler.cosine_lr import CosineLRScheduler 3 | from timm.scheduler.multistep_lr import MultiStepLRScheduler 4 | from timm.scheduler.step_lr import StepLRScheduler 5 | from timm.scheduler.tanh_lr import TanhLRScheduler 6 | 7 | 8 | def build_scheduler(optimizer, cfg): 9 | num_epochs = cfg['TRAIN']['MAX_EPOCH'] 10 | scheduler_cfg = cfg['SCHEDULER'] 11 | 12 | if 'LR_NOISE' in scheduler_cfg: 13 | lr_noise = scheduler_cfg['LR_NOISE'] 14 | if isinstance(lr_noise, (list, tuple)): 15 | noise_range = [n * num_epochs for n in lr_noise] 16 | if len(noise_range) == 1: 17 | noise_range = noise_range[0] 18 | else: 19 | noise_range = lr_noise * num_epochs 20 | else: 21 | noise_range = None 22 | noise_args = dict( 23 | noise_range_t=noise_range, 24 | noise_pct=scheduler_cfg['LR_NOISE_PCT'] 25 | if 'LR_NOISE_PCT' in scheduler_cfg 26 | else 0.67, 27 | noise_std=scheduler_cfg['LR_NOISE_STD'] 28 | if 'LR_NOISE_STD' in scheduler_cfg 29 | else 1.0, 30 | noise_seed=scheduler_cfg['SEED'] 31 | if 'SEED' in scheduler_cfg 32 | else 42, 33 | ) 34 | cycle_args = dict( 35 | cycle_mul=scheduler_cfg['LR_CYCLE_MUL'] 36 | if 'LR_CYCLE_MUL' in scheduler_cfg 37 | else 1.0, 38 | cycle_decay=scheduler_cfg['LR_CYCLE_DECAY'] 39 | if 'LR_CYCLE_DECAY' in scheduler_cfg 40 | else 0.1, 41 | cycle_limit=scheduler_cfg['LR_CYCLE_LIMIT'] 42 | if 'LR_CYCLE_LIMIT' in scheduler_cfg 43 | else 1, 44 | ) 45 | 46 | lr_scheduler = None 47 | 48 | if scheduler_cfg['SCHEDULER_TYPE'] == 'cosine': 49 | lr_scheduler = CosineLRScheduler( 50 | optimizer, 51 | t_initial=num_epochs, 52 | lr_min=scheduler_cfg['MIN_LR'], 53 | warmup_lr_init=scheduler_cfg['WARMUP_LR'], 54 | warmup_t=scheduler_cfg['WARMUP_EPOCHS'], 55 | k_decay=scheduler_cfg['LR_K_DECAY'] 56 | if 'LR_K_DECAY' in scheduler_cfg 57 | else 1.0, 58 | **cycle_args, 59 | **noise_args, 60 | ) 61 | num_epochs = ( 62 | lr_scheduler.get_cycle_length() + scheduler_cfg['COOLDOWN_EPOCHS'] 63 | ) 64 | 65 | elif scheduler_cfg['SCHEDULER_TYPE'] == 'tanh': 66 | lr_scheduler = TanhLRScheduler( 67 | optimizer, 68 | t_initial=num_epochs, 69 | lr_min=scheduler_cfg['MIN_LR'], 70 | warmup_lr_init=scheduler_cfg['WARMUP_LR'], 71 | warmup_t=scheduler_cfg['WARMUP_EPOCHS'], 72 | t_in_epochs=True, 73 | **cycle_args, 74 | **noise_args, 75 | ) 76 | num_epochs = ( 77 | lr_scheduler.get_cycle_length() + scheduler_cfg['COOLDOWN_EPOCHS'] 78 | ) 79 | elif scheduler_cfg['SCHEDULER_TYPE'] == 'step': 80 | lr_scheduler = StepLRScheduler( 81 | optimizer, 82 | decay_t=scheduler_cfg['DECAY_EPOCHS'], 83 | decay_rate=scheduler_cfg['DECAY_RATE'], 84 | warmup_lr_init=scheduler_cfg['WARMUP_LR'], 85 | warmup_t=scheduler_cfg['WARMUP_EPOCHS'], 86 | **noise_args, 87 | ) 88 | elif scheduler_cfg['SCHEDULER_TYPE'] == 'multistep': 89 | lr_scheduler = MultiStepLRScheduler( 90 | optimizer, 91 | decay_t=scheduler_cfg['DECAY_EPOCHS'], 92 | decay_rate=scheduler_cfg['DECAY_RATE'], 93 | warmup_lr_init=scheduler_cfg['WARMUP_LR'], 94 | warmup_t=scheduler_cfg['WARMUP_EPOCHS'], 95 | **noise_args, 96 | ) 97 | 98 | return lr_scheduler, num_epochs 99 | -------------------------------------------------------------------------------- /M2TR/utils/visualization.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | 3 | import M2TR.utils.distributed as du 4 | 5 | 
class TensorBoardWriter(SummaryWriter): 6 | def __init__(self, log_dir=None, comment='', purge_step=None, max_queue=10, flush_secs=120, filename_suffix=''): 7 | super().__init__(log_dir, comment, purge_step, max_queue, flush_secs, filename_suffix) 8 | print('h') 9 | self.is_master_proc = du.is_master_proc(du.get_world_size()) 10 | print('hereeee') 11 | print(self.is_master_proc) 12 | 13 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False, double_precision=False): 14 | if self.is_master_proc: 15 | super().add_scalar(tag, scalar_value, global_step, walltime, new_style, double_precision) 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # M2TR: Multi-modal Multi-scale Transformers for Deepfake Detection 2 | 3 | ## Introduction 4 | 5 | This is the official PyTorch implementation of [M2TR: Multi-modal Multi-scale Transformers for Deepfake Detection](https://arxiv.org/abs/2104.09770), which was accepted by ICMR 2022. 6 | 7 | 8 |
9 |
10 |