├── tools
│   ├── __init__.py
│   ├── utils.py
│   ├── test.py
│   └── train.py
├── M2TR
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── .DS_Store
│   │   ├── modules
│   │   │   ├── conv_block.py
│   │   │   ├── gram_block.py
│   │   │   ├── transformer_block.py
│   │   │   └── head.py
│   │   ├── base.py
│   │   ├── xception.py
│   │   ├── m2tr.py
│   │   └── efficientnet.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── env.py
│   │   ├── registries.py
│   │   ├── visualization.py
│   │   ├── optimizer.py
│   │   ├── build_helper.py
│   │   ├── logging.py
│   │   ├── scheduler.py
│   │   ├── meters.py
│   │   ├── distributed.py
│   │   ├── checkpoint.py
│   │   └── loss.py
│   ├── .DS_Store
│   └── datasets
│       ├── __init__.py
│       ├── dataset.py
│       ├── CelebDF.py
│       ├── FFDF.py
│       ├── ForgeryNet.py
│       └── utils.py
├── configs
│   ├── .DS_Store
│   ├── m2tr.yaml
│   └── default.yaml
├── imgs
│   └── network.png
├── setup.py
├── run.py
├── LICENSE
├── requirements.txt
└── README.md
/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/M2TR/__init__.py:
--------------------------------------------------------------------------------
1 | from M2TR.utils.env import setup_environment
2 |
3 | setup_environment()
4 |
--------------------------------------------------------------------------------
/M2TR/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseNetwork
2 | from .efficientnet import EfficientNet
3 | from .m2tr import M2TR
4 |
--------------------------------------------------------------------------------
/M2TR/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .loss import *
2 | from .optimizer import build_optimizer
3 | from .scheduler import build_scheduler
4 |
--------------------------------------------------------------------------------
/imgs/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wdrink/M2TR-Multi-modal-Multi-scale-Transformers-for-Deepfake-Detection/HEAD/imgs/network.png
--------------------------------------------------------------------------------
/M2TR/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from M2TR.datasets.CelebDF import CelebDF
2 | from M2TR.datasets.dataset import DeepFakeDataset
3 | from M2TR.datasets.FFDF import FFDF
4 | from M2TR.datasets.ForgeryNet import ForgeryNet
5 |
--------------------------------------------------------------------------------
/M2TR/utils/env.py:
--------------------------------------------------------------------------------
1 | from iopath.common.file_io import PathManagerFactory
2 |
3 | _ENV_SETUP_DONE = False
4 | pathmgr = PathManagerFactory.get(key="M2TR")
5 |
6 |
7 | def setup_environment():
8 | global _ENV_SETUP_DONE
9 | if _ENV_SETUP_DONE:
10 | return
11 | _ENV_SETUP_DONE = True
12 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name='M2TR',
5 | version='0.1.1',
6 |     install_requires=['timm',
7 | 'fvcore',
8 | 'albumentations',
9 | 'kornia',
10 | 'simplejson',
11 | 'tensorboard',
12 | ],
13 | packages=find_packages(),
14 |     license="MIT")
15 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | from tools.test import test
2 | from tools.train import train
3 | from tools.utils import launch_func, load_config, parse_args
4 |
5 |
6 | def main():
7 | args = parse_args()
8 | cfg = load_config(args)
9 | if cfg['TRAIN']['ENABLE']:
10 | launch_func(cfg=cfg, func=train)
11 | if cfg['TEST']['ENABLE']:
12 | launch_func(cfg=cfg, func=test)
13 |
14 |
15 | if __name__ == '__main__':
16 | main()
17 |
--------------------------------------------------------------------------------
/M2TR/utils/registries.py:
--------------------------------------------------------------------------------
1 | from fvcore.common.registry import Registry
2 |
3 | MODEL_REGISTRY = Registry("MODEL")
4 | MODEL_REGISTRY.__doc__ = """
5 | Registry for models.
6 |
7 | The registered object will be called with `obj(cfg)`.
8 | The call should return a `torch.nn.Module` object.
9 | """
10 |
11 | LOSS_REGISTRY = Registry("LOSS")
12 | LOSS_REGISTRY.__doc__ = """
13 | Registry for loss functions.
14 | The registered object will be called with `obj(cfg)`.
15 | """
16 |
17 |
18 | DATASET_REGISTRY = Registry("DATASET")
19 | DATASET_REGISTRY.__doc__ = """
20 | Registry for datasets.
21 | The registered object will be called with `obj(cfg)`.
22 | """
23 |
--------------------------------------------------------------------------------
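The registries follow the usual fvcore pattern: a class is registered under its own name with a decorator and later instantiated from the config via `REGISTRY.get(name)(cfg)`, which is how `build_helper.py` resolves models, losses, and datasets. A minimal sketch with a hypothetical loss (the class name and its internals are illustrative, not part of the repository):

```python
import torch.nn as nn

from M2TR.utils.registries import LOSS_REGISTRY


@LOSS_REGISTRY.register()  # registered under the class name "MyBCELoss"
class MyBCELoss(nn.Module):
    def __init__(self, loss_cfg):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, outputs, samples):
        # mirrors how criterion(outputs, samples) is called in tools/train.py and tools/test.py
        return self.bce(outputs['logits'], samples['bin_label_onehot'])


# With LOSS_FUN: MyBCELoss in the config, build_loss_fun would resolve it via:
# loss_fun = LOSS_REGISTRY.get('MyBCELoss')(loss_cfg)
```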
/M2TR/utils/visualization.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 |
3 | import M2TR.utils.distributed as du
4 |
5 | class TensorBoardWriter(SummaryWriter):
6 | def __init__(self, log_dir=None, comment='', purge_step=None, max_queue=10, flush_secs=120, filename_suffix=''):
7 | super().__init__(log_dir, comment, purge_step, max_queue, flush_secs, filename_suffix)
8 |         # only the master process actually writes TensorBoard events
9 |         self.is_master_proc = du.is_master_proc(du.get_world_size())
12 |
13 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False, double_precision=False):
14 | if self.is_master_proc:
15 | super().add_scalar(tag, scalar_value, global_step, walltime, new_style, double_precision)
16 |
--------------------------------------------------------------------------------
/configs/m2tr.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 8
2 | TRAIN:
3 | ENABLE: True
4 | MAX_EPOCH: 20
5 | CHECKPOINT_PERIOD: 1
6 | CHECKPOINT_EPOCH_RESET: True
7 | TEST:
8 | ENABLE: True
9 | DATASET:
10 | DATASET_NAME: FFDF
11 | ROOT_DIR: path/to/your/dataset
12 | TRAIN_INFO_TXT: 'path/to/your/train.txt'
13 | VAL_INFO_TXT: 'path/to/your/val.txt'
14 | TEST_INFO_TXT: 'path/to/your/test.txt'
15 | IMG_SIZE: 320
16 | SCALE_RATE: 1.142857  # 8/7 written as a float; YAML would read "8/7" as a string
17 | ROTATE_ANGLE: 10
18 | CUTOUT_H: 10
19 | CUTOUT_W: 10
20 | COMPRESSION_LOW: 65
21 | COMPRESSION_HIGH: 80
22 | DATALOADER:
23 | BATCH_SIZE: 8
24 | NUM_WORKERS: 24
25 | LOSS:
26 | LOSS_FUN: FocalLoss
27 | LOSS_WEIGHT: 1
28 | logits: True
29 | MODEL:
30 | MODEL_NAME: M2TR
31 | PRETRAINED: 'imagenet'
32 | ESCAPE: ''
33 | IMG_SIZE: 320
34 | BACKBONE: efficientnet-b4
35 | DEPTH: 4
36 | TEXTURE_LAYER: b2
37 | FEATURE_LAYER: final
38 | NUM_CLASSES: 2
39 | DROP_RATIO: 0.5
40 | HAS_DECODER: False
41 | OPTIMIZER:
42 | OPTIMIZER_METHOD: sgd
43 | BASE_LR: 0.0005
44 | EPS: 0.00000001
45 | MOMENTUM: 0.9
46 |
--------------------------------------------------------------------------------
/M2TR/models/modules/conv_block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class Deconv(nn.Module):
6 | def __init__(self, input_channel, output_channel, kernel_size=3, padding=0):
7 | super().__init__()
8 | self.conv = nn.Conv2d(
9 | input_channel,
10 | output_channel,
11 | kernel_size=kernel_size,
12 | stride=1,
13 | padding=padding,
14 | )
15 |
16 | self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)
17 |
18 | def forward(self, x):
19 | x = F.interpolate(
20 | x, scale_factor=2, mode='bilinear', align_corners=True
21 | )
22 | out = self.conv(x)
23 | out = self.leaky_relu(out)
24 | return out
25 |
26 |
27 | class ConvBN(nn.Module):
28 |     def __init__(self, in_features, out_features):
29 |         super().__init__()  # required so nn.Module registers the submodules below
30 |         self.conv = nn.Conv2d(in_features, out_features, 3, padding=1)
31 |         self.bn = nn.BatchNorm2d(out_features)
31 |
32 | def forward(self, x):
33 | out = self.conv(x)
34 | out = self.bn(out)
35 | return out
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Junke Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | albumentations==1.1.0
3 | cachetools==5.0.0
4 | certifi==2021.10.8
5 | charset-normalizer==2.0.12
6 | fvcore==0.1.5.post20220414
7 | google-auth==2.6.5
8 | google-auth-oauthlib==0.4.6
9 | grpcio==1.44.0
10 | idna==3.3
11 | imageio==2.16.2
12 | importlib-metadata==4.11.3
13 | iopath==0.1.9
14 | joblib==1.1.0
15 | kornia==0.6.4
16 | Markdown==3.3.6
17 | networkx==2.8
18 | numpy==1.22.3
19 | oauthlib==3.2.0
20 | opencv-python-headless==4.5.5.64
21 | packaging==21.3
22 | Pillow==9.1.0
23 | portalocker==2.4.0
24 | protobuf==3.20.0
25 | pyasn1==0.4.8
26 | pyasn1-modules==0.2.8
27 | pyparsing==3.0.8
28 | PyWavelets==1.3.0
29 | PyYAML==6.0
30 | qudida==0.0.4
31 | requests==2.27.1
32 | requests-oauthlib==1.3.1
33 | rsa==4.8
34 | scikit-image==0.19.2
35 | scikit-learn==1.0.2
36 | scipy==1.8.0
37 | simplejson==3.17.6
38 | six==1.16.0
39 | tabulate==0.8.9
40 | tensorboard==2.8.0
41 | tensorboard-data-server==0.6.1
42 | tensorboard-plugin-wit==1.8.1
43 | termcolor==1.1.0
44 | threadpoolctl==3.1.0
45 | tifffile==2022.4.8
46 | timm==0.5.4
47 | torch==1.11.0+cu113
48 | torchaudio==0.11.0+cu113
49 | torchvision==0.12.0+cu113
50 | tqdm==4.64.0
51 | typing_extensions==4.1.1
52 | urllib3==1.26.9
53 | Werkzeug==2.1.1
54 | yacs==0.1.8
55 | zipp==3.8.0
56 |
--------------------------------------------------------------------------------
/M2TR/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class DeepFakeDataset(Dataset):
8 | def __init__(
9 | self,
10 | dataset_cfg,
11 | mode='train',
12 | ):
13 | dataset_name = dataset_cfg['DATASET_NAME']
14 | assert dataset_name in [
15 | 'ForgeryNet',
16 | 'FFDF',
17 | 'CelebDF',
18 |         ], 'Unsupported dataset: ' + dataset_name
19 | assert mode in [
20 | 'train',
21 | 'val',
22 | 'test',
23 |         ], 'Unsupported mode: ' + mode
24 | self.dataset_name = dataset_name
25 | self.mode = mode
26 | self.dataset_cfg = dataset_cfg
27 | self.root_dir = dataset_cfg['ROOT_DIR']
28 | info_txt_tag = mode.upper() + '_INFO_TXT'
29 | if dataset_cfg[info_txt_tag] != '':
30 | self.info_txt = dataset_cfg[info_txt_tag]
31 | else:
32 | self.info_txt = os.path.join(
33 | self.root_dir,
34 | self.dataset_name + '_splits_' + mode + '.txt',
35 | )
36 | self.info_list = open(self.info_txt).readlines()
37 |
38 | def __len__(self):
39 | return len(self.info_list)
40 |
41 | def label_to_one_hot(self, x, class_count):
42 | return torch.eye(class_count)[x.long(), :]
43 |
--------------------------------------------------------------------------------
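`label_to_one_hot` simply indexes the rows of an identity matrix with the (float) label tensor. A tiny sketch of what it produces for a binary label:

```python
import torch

label = torch.FloatTensor([1])             # a positive ("fake") binary label
one_hot = torch.eye(2)[label.long(), :]    # same expression as label_to_one_hot(label, 2)
print(one_hot)                             # tensor([[0., 1.]])
print(one_hot.squeeze())                   # tensor([0., 1.]) -- the form stored in the samples
```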
/M2TR/models/modules/gram_block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class GramMatrix(nn.Module):
5 | def __init__(self):
6 | super(GramMatrix, self).__init__()
7 |
8 | def forward(self, x):
9 | b, c, h, w = x.size()
10 | feature = x.view(b, c, h * w)
11 | feature_t = feature.transpose(1, 2)
12 | gram = feature.bmm(feature_t)
13 | b, h, w = gram.size()
14 | gram = gram.view(b, 1, h, w)
15 | return gram
16 |
17 |
18 | class GramBlock(nn.Module):
19 | def __init__(self, in_channels):
20 | super(GramBlock, self).__init__()
21 | self.conv1 = nn.Conv2d(
22 | in_channels, 32, kernel_size=3, stride=1, padding=2
23 | )
24 | self.gramMatrix = GramMatrix()
25 | self.conv2 = nn.Sequential(
26 | nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=2),
27 | nn.BatchNorm2d(16),
28 | nn.ReLU(inplace=True),
29 | )
30 | self.conv3 = nn.Sequential(
31 | nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=2),
32 | nn.BatchNorm2d(32),
33 | nn.ReLU(inplace=True),
34 | )
35 | self.pool = nn.AdaptiveAvgPool2d((1, 1))
36 |
37 | def forward(self, x):
38 | x = self.conv1(x)
39 | x = self.gramMatrix(x)
40 | x = self.conv2(x)
41 | x = self.conv3(x)
42 | x = self.pool(x)
43 | return x
44 |
--------------------------------------------------------------------------------
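`GramMatrix` computes the channel-wise correlation matrix G = F F^T, where F is the (C, H·W) flattened feature map, so a `(B, C, H, W)` input becomes a `(B, 1, C, C)` map that the following convolutions then downsample. A quick shape trace with illustrative sizes:

```python
import torch

from M2TR.models.modules.gram_block import GramBlock

x = torch.randn(2, 3, 64, 64)     # e.g. a batch of two RGB patches (sizes are illustrative)
block = GramBlock(in_channels=3)
out = block(x)
print(out.shape)                  # torch.Size([2, 32, 1, 1]) after the adaptive average pool
```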
/M2TR/datasets/CelebDF.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | from M2TR.datasets.dataset import DeepFakeDataset
6 | from M2TR.utils.registries import DATASET_REGISTRY
7 |
8 | from .utils import get_image_from_path
9 |
10 | '''
11 | DATASET:
12 | DATASET_NAME: CelebDF
13 | ROOT_DIR: /some_where/celeb-df-v2
14 | TRAIN_INFO_TXT: '/some_where/celeb-df-v2/splits/train.txt'
15 | VAL_INFO_TXT: '/some_where/celeb-df-v2/splits/eval.txt'
16 | TEST_INFO_TXT: '/some_where/celeb-df-v2/splits/eval.txt'
17 | IMG_SIZE: 380
18 | SCALE_RATE: 1.0
19 | '''
20 |
21 |
22 | @DATASET_REGISTRY.register()
23 | class CelebDF(DeepFakeDataset):
24 | def __getitem__(self, idx):
25 | info_line = self.info_list[idx]
26 | image_info = info_line.strip('\n').split()
27 | image_path = image_info[0]
28 | image_abs_path = os.path.join(self.root_dir, image_path)
29 |
30 | img, _ = get_image_from_path(
31 | image_abs_path, None, self.mode, self.dataset_cfg
32 | )
33 | img_label_binary = int(image_info[1])
34 |
35 | sample = {
36 | 'img': img,
37 | 'bin_label': [int(img_label_binary)],
38 | }
39 |
40 | sample['img'] = torch.FloatTensor(sample['img'])
41 | sample['bin_label'] = torch.FloatTensor(sample['bin_label'])
42 | sample['bin_label_onehot'] = self.label_to_one_hot(
43 | sample['bin_label'], 2
44 | ).squeeze()
45 | return sample
46 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # M2TR: Multi-modal Multi-scale Transformers for Deepfake Detection
2 |
3 | ## Introduction
4 |
5 | This is the official PyTorch implementation of [M2TR: Multi-modal Multi-scale Transformers for Deepfake Detection](https://arxiv.org/abs/2104.09770), which was accepted by ICMR 2022.
6 |
7 |
8 | ![M2TR network architecture](imgs/network.png)
9 |
13 | ## Model Zoo
14 |
15 | Baseline models trained on three compression versions of the [FF-DF](https://github.com/ondyari/FaceForensics) dataset (Raw, C23, C40) and on [Celeb-DF](https://github.com/yuezunli/celeb-deepfakeforensics) are provided.
16 |
17 | | Dataset | Accuracy | model |
18 | | --- | --- | --- |
19 | | FF++ (Raw) | 99.50 | [FF-RAW](https://drive.google.com/file/d/1_HaPE6r7Zzof2mmLmmc4fbIbqyWs17S0/view?usp=sharing) |
20 | | FF++ (C23) | 97.93 | [FF-C23](https://drive.google.com/file/d/1XRIllA6p5YnITztl1burwcr5l7LAcpqv/view?usp=sharing) |
21 | | FF++ (C40) | 92.89 | [FF-C40](https://drive.google.com/file/d/1xhclIjoh8GkVvoVefjDY-itdaV0VaMxY/view?usp=sharing) |
22 | | CelebDF |99.76 |[CelebDF](https://drive.google.com/file/d/1_HaPE6r7Zzof2mmLmmc4fbIbqyWs17S0/view?usp=sharing) |
23 |
24 | ## Training and Evaluation
25 |
26 | ```
27 | python run.py --cfg m2tr.yaml
28 | ```
29 |
30 | ## License
31 |
32 | This project is released under the [MIT license](https://opensource.org/licenses/MIT).
33 |
34 |
35 | ## Citations
36 |
37 | ```bibtex
38 | @inproceedings{wang2021m2tr,
39 |   title={M2TR: Multi-modal Multi-scale Transformers for Deepfake Detection},
40 |   author={Wang, Junke and Wu, Zuxuan and Chen, Jingjing and Jiang, Yu-Gang},
41 |   booktitle={ICMR},
42 |   year={2022}
43 | }
44 | ```
45 |
--------------------------------------------------------------------------------
/M2TR/datasets/FFDF.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | from M2TR.datasets.dataset import DeepFakeDataset
6 | from M2TR.datasets.utils import (
7 | get_image_from_path,
8 | get_mask_path_from_img_path,
9 | )
10 | from M2TR.utils.registries import DATASET_REGISTRY
11 |
12 | '''
13 | DATASET:
14 | DATASET_NAME: FFDF
15 | ROOT_DIR: /some_where/FF++/face
16 | TRAIN_INFO_TXT: '/some_where/train_face_c23.txt'
17 | VAL_INFO_TXT: '/some_where/val_face_c23.txt'
18 | TEST_INFO_TXT: '/some_where/test_face_c23.txt'
19 | IMG_SIZE: 380
20 | SCALE_RATE: 1.0
21 | ROTATE_ANGLE: 10
22 | CUTOUT_H: 10
23 | CUTOUT_W: 10
24 | COMPRESSION_LOW: 65
25 | COMPRESSION_HIGH: 80
26 | '''
27 |
28 |
29 | @DATASET_REGISTRY.register()
30 | class FFDF(DeepFakeDataset):
31 | def __getitem__(self, idx):
32 | info_line = self.info_list[idx]
33 | image_info = info_line.strip('\n').split()
34 | image_path = image_info[0]
35 | image_abs_path = os.path.join(self.root_dir, image_path)
36 |
37 | mask_abs_path = get_mask_path_from_img_path(
38 | self.dataset_name, self.root_dir, image_path
39 | )
40 | img, mask = get_image_from_path(
41 | image_abs_path, mask_abs_path, self.mode, self.dataset_cfg
42 | )
43 | img_label_binary = int(image_info[1])
44 |
45 | sample = {
46 | 'img': img,
47 | 'bin_label': [int(img_label_binary)],
48 | }
49 |
50 | sample['img'] = torch.FloatTensor(sample['img'])
51 | sample['bin_label'] = torch.FloatTensor(sample['bin_label'])
52 | sample['bin_label_onehot'] = self.label_to_one_hot(
53 | sample['bin_label'], 2
54 | ).squeeze()
55 | sample['mask'] = torch.FloatTensor(mask)
56 |
57 | return sample
58 |
--------------------------------------------------------------------------------
/M2TR/models/modules/transformer_block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Mlp(nn.Module):
5 | def __init__(
6 | self,
7 | in_features,
8 | hidden_features=None,
9 | out_features=None,
10 | act_layer=nn.GELU,
11 | drop=0.0,
12 | ):
13 | super().__init__()
14 | out_features = out_features or in_features
15 | hidden_features = hidden_features or in_features
16 | self.fc1 = nn.Linear(in_features, hidden_features)
17 | self.act = act_layer()
18 | self.fc2 = nn.Linear(hidden_features, out_features)
19 | self.drop = nn.Dropout(drop)
20 |
21 | def forward(self, x):
22 | x = self.fc1(x)
23 | x = self.act(x)
24 | x = self.drop(x)
25 | x = self.fc2(x)
26 | x = self.drop(x)
27 | return x
28 |
29 |
30 | class FeedForward1D(nn.Module):
31 | def __init__(self, dim, hidden_dim, dropout=0.0):
32 | super(FeedForward1D, self).__init__()
33 | self.net = nn.Sequential(
34 | nn.Linear(dim, hidden_dim),
35 | nn.GELU(),
36 | nn.Dropout(dropout),
37 | nn.Linear(hidden_dim, dim),
38 | nn.Dropout(dropout),
39 | )
40 |
41 | def forward(self, x):
42 | return self.net(x)
43 |
44 |
45 | class FeedForward2D(nn.Module):
46 | def __init__(self, in_channel, out_channel):
47 | super(FeedForward2D, self).__init__()
48 | self.conv = nn.Sequential(
49 | nn.Conv2d(
50 | in_channel, out_channel, kernel_size=3, padding=2, dilation=2
51 | ),
52 | nn.BatchNorm2d(out_channel),
53 | nn.LeakyReLU(0.2, inplace=True),
54 | nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
55 | nn.BatchNorm2d(out_channel),
56 | nn.LeakyReLU(0.2, inplace=True),
57 | )
58 |
59 | def forward(self, x):
60 | x = self.conv(x)
61 | return x
62 |
--------------------------------------------------------------------------------
/M2TR/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def build_optimizer(optim_params, cfg):
4 | """
5 | Construct a stochastic gradient descent or ADAM optimizer with momentum.
6 | Details can be found in:
7 | Herbert Robbins, and Sutton Monro. "A stochastic approximation method."
8 | and
9 |     Diederik P. Kingma, and Jimmy Ba.
10 |     "Adam: A Method for Stochastic Optimization."
11 |     Args:
12 |         optim_params (iterable): model parameters (or parameter groups) to
13 |             optimize with SGD, RMSprop, Adam, or AdamW.
14 |         cfg (dict): configs of hyper-parameters of the optimizer, including the
15 |             base learning rate, momentum, weight decay, dampening, etc.
16 | """
17 | optimizer_cfg = cfg['OPTIMIZER']
18 | if optimizer_cfg['OPTIMIZER_METHOD'] == "sgd":
19 | return torch.optim.SGD(
20 | optim_params,
21 | lr=optimizer_cfg['BASE_LR'],
22 | momentum=optimizer_cfg['MOMENTUM'],
23 | # dampening=optimizer_cfg['DAMPENING'],
24 | # weight_decay=optimizer_cfg['WEIGHT_DECAY'],
25 | # nesterov=optimizer_cfg['NESTEROV'],
26 | )
27 | elif optimizer_cfg['OPTIMIZER_METHOD'] == "rmsprop":
28 | return torch.optim.RMSprop(
29 | optim_params,
30 | lr=optimizer_cfg['BASE_LR'],
31 | alpha=optimizer_cfg['ALPHA'],
32 | eps=optimizer_cfg['EPS'],
33 | weight_decay=optimizer_cfg['WEIGHT_DECAY'],
34 | momentum=optimizer_cfg['MOMENTUM'],
35 | )
36 | elif optimizer_cfg['OPTIMIZER_METHOD'] == "adam":
37 | return torch.optim.Adam(
38 | optim_params,
39 | lr=optimizer_cfg['BASE_LR'],
40 | betas=optimizer_cfg['ADAM_BETAS'],
41 | eps=optimizer_cfg['EPS'],
42 | weight_decay=optimizer_cfg['WEIGHT_DECAY'],
43 | amsgrad=optimizer_cfg['AMSGRAD'],
44 | )
45 | elif optimizer_cfg['OPTIMIZER_METHOD'] == "adamw":
46 | return torch.optim.AdamW(
47 | optim_params,
48 | lr=optimizer_cfg['BASE_LR'],
49 | betas=optimizer_cfg['ADAM_BETAS'],
50 | eps=optimizer_cfg['EPS'],
51 | weight_decay=optimizer_cfg['WEIGHT_DECAY'],
52 | amsgrad=optimizer_cfg['AMSGRAD'],
53 | )
54 | else:
55 | raise NotImplementedError(
56 | "Does not support {} optimizer".format(
57 | optimizer_cfg['OPTIMIZER_METHOD']
58 | )
59 | )
60 |
--------------------------------------------------------------------------------
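`build_optimizer` only looks at the `OPTIMIZER` sub-dictionary of the config; for the SGD branch used by `configs/m2tr.yaml` just the method name, base learning rate, and momentum are read. A minimal sketch with a stand-in model:

```python
import torch.nn as nn

from M2TR.utils.optimizer import build_optimizer

model = nn.Linear(10, 2)   # stand-in for the real network
cfg = {
    'OPTIMIZER': {
        'OPTIMIZER_METHOD': 'sgd',
        'BASE_LR': 0.0005,   # values taken from configs/m2tr.yaml
        'MOMENTUM': 0.9,
    }
}
optimizer = build_optimizer(model.parameters(), cfg)
```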
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 4
2 | NUM_SHARDS: 1
3 | SHARD_ID: 0
4 | DIST_BACKEND: nccl
5 | RNG_SEED: 0
6 | LOG_FILE_PATH: './logs'
7 | INIT_METHOD: 'tcp://localhost:9999'
8 | AMP_ENABLE: False
9 | TRAIN:
10 | ENABLE: True
11 | MAX_EPOCH: 20
12 | EVAL_PERIOD: 1
13 | CHECKPOINT_PERIOD: 5
14 | CHECKPOINT_EPOCH_RESET: FALSE
15 | CHECKPOINT_LOAD_PATH: ''
16 | CHECKPOINT_SAVE_PATH: '.'
17 | TEST:
18 | ENABLE: True
19 | CHECKPOINT_TEST_PATH: ''
20 | DATASET:
21 | DATASET_NAME: FFDF
22 | ROOT_DIR: /share/test/ouyang/FF++/face
23 | TRAIN_INFO_TXT: '/share/test/ouyang/FF++/splits/train_face_raw.txt'
24 | VAL_INFO_TXT: '/share/test/ouyang/FF++/splits/test_face_raw.txt'
25 | TEST_INFO_TXT: '/share/test/ouyang/FF++/splits/test_face_raw.txt'
26 | IMG_SIZE: 380
27 | SCALE_RATE: 1.0
28 | TRAIN_AUGMENTATIONS:
29 | COMPOSE:
30 | [
31 | [
32 | ChannelDropout,
33 | ToGray
34 | ],
35 | [
36 | ColorJitter,
37 | RandomBrightnessContrast,
38 | HueSaturationValue,
39 | CLAHE,
40 | RandomGamma,
41 | Sharpen
42 | ],
43 | [
44 | Blur,
45 | MotionBlur,
46 | GaussianBlur,
47 | GlassBlur
48 | ],
49 | GaussNoise,
50 | HorizontalFlip,
51 | Rotate,
52 | [
53 | RandomFog,
54 | RandomRain,
55 | RandomSnow,
56 | RandomSunFlare,
57 | RandomToneCurve
58 | ],
59 | CoarseDropout,
60 | ImageCompression,
61 | Normalize
62 | ]
63 | ROTATE_PARAMS: [10]
64 | COARSEDROPOUT_PARAMS: [10, 10]
65 | IMAGECOMPRESSION_PARAMS: [65, 80]
66 | NORMALIZE_PARAMS: [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]
67 | TEST_AUGMENTATIONS:
68 | COMPOSE: [Resize, Normalize]
69 | NORMALIZE_PARAMS: [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]
70 | DATALOADER:
71 | BATCH_SIZE: 16
72 | NUM_WORKERS: 8
73 | PIN_MEM: True
74 | OPTIMIZER:
75 | OPTIMIZER_METHOD: adamw
76 | BASE_LR: 0.001
77 | ADAM_BETAS: [0.9, 0.999]
78 | EPS: 0.00000001
79 | WEIGHT_DECAY: 0.01
80 | AMSGRAD: False
81 | SCHEDULER:
82 | LR_NOISE_PCT: 0.67
83 | LR_NOISE_STD: 1.0
84 | SEED: 42
85 | LR_CYCLE_MUL: 1.0
86 | LR_CYCLE_DECAY: 0.1
87 | LR_CYCLE_LIMIT: 1
88 | SCHEDULER_TYPE: cosine
89 | MIN_LR: 0.000001
90 | WARMUP_LR: 0
91 | WARMUP_EPOCHS: 0
92 | COOLDOWN_EPOCHS: 10
93 | LR_K_DECAY: 1.0
94 | SCHEDULER_STEP: 5
95 | SCHEDULER_GAMMA: 0.5
--------------------------------------------------------------------------------
/M2TR/utils/build_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 |
4 |
5 | import M2TR.models
6 | import M2TR.datasets
7 | import M2TR.utils.distributed as du
8 | import M2TR.utils.logging as logging
9 | from M2TR.utils.registries import (
10 | DATASET_REGISTRY,
11 | LOSS_REGISTRY,
12 | MODEL_REGISTRY,
13 | )
14 |
15 | logger = logging.get_logger(__name__)
16 |
17 |
18 | def build_model(cfg, gpu_id=None):
19 | # Construct the model
20 | model_cfg = cfg['MODEL']
21 | name = model_cfg['MODEL_NAME']
22 | logger.info('MODEL_NAME: ' + name)
23 | model = MODEL_REGISTRY.get(name)(model_cfg)
24 |
25 | assert torch.cuda.is_available(), "Cuda is not available."
26 | assert (
27 | cfg['NUM_GPUS'] <= torch.cuda.device_count()
28 | ), "Cannot use more GPU devices than available"
29 |
30 | if gpu_id is None:
31 | # Determine the GPU used by the current process
32 | cur_device = torch.cuda.current_device()
33 | else:
34 | cur_device = gpu_id
35 | # Transfer the model to the current GPU device
36 | model = model.cuda(device=cur_device)
37 | # Use multi-process data parallel model in the multi-gpu setting
38 | if cfg['NUM_GPUS'] > 1:
39 | # Make model replica operate on the current device
40 | model = torch.nn.parallel.DistributedDataParallel(
41 | module=model, device_ids=[cur_device], output_device=cur_device, find_unused_parameters=True
42 | )
43 |
44 | return model
45 |
46 |
47 | def build_loss_fun(cfg):
48 | loss_cfg = cfg['LOSS']
49 | name = loss_cfg['LOSS_FUN']
50 | logger.info('LOSS_FUN: ' + name)
51 | loss_fun = LOSS_REGISTRY.get(name)(loss_cfg)
52 | return loss_fun
53 |
54 |
55 | def build_dataset(mode, cfg):
56 | dataset_cfg = cfg['DATASET']
57 | name = dataset_cfg['DATASET_NAME']
58 | logger.info('DATASET_NAME: ' + name + ' ' + mode)
59 | return DATASET_REGISTRY.get(name)(dataset_cfg, mode)
60 |
61 |
62 | def build_dataloader(dataset, mode, cfg):
63 | dataloader_cfg = cfg['DATALOADER']
64 | num_tasks = du.get_world_size()
65 | global_rank = du.get_rank()
66 |
67 | sampler = torch.utils.data.DistributedSampler(
68 | dataset,
69 | num_replicas=num_tasks,
70 | rank=global_rank,
71 | shuffle=True if mode == 'train' else False,
72 | )
73 |
74 | return DataLoader(
75 | dataset,
76 | batch_size=dataloader_cfg['BATCH_SIZE'],
77 | sampler=sampler,
78 | num_workers=dataloader_cfg['NUM_WORKERS'],
79 | pin_memory=dataloader_cfg['PIN_MEM'],
80 | drop_last=True if mode == 'train' else False,
81 | )
82 |
--------------------------------------------------------------------------------
/M2TR/models/modules/head.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from M2TR.models.modules.conv_block import Deconv
4 |
5 |
6 | class Classifier2D(nn.Module):
7 | def __init__(
8 | self,
9 | dim_in,
10 | num_classes,
11 | dropout_rate=0.0,
12 | act_func="softmax",
13 | ):
14 | """
15 | Perform linear projection and activation as head for tranformers.
16 | Args:
17 | dim_in (int): the channel dimension of the input to the head.
18 | num_classes (int): the channel dimensions of the output to the head.
19 | dropout_rate (float): dropout rate. If equal to 0.0, perform no
20 | dropout.
21 | act_func (string): activation function to use. 'softmax': applies
22 | softmax on the output. 'sigmoid': applies sigmoid on the output.
23 | """
24 | super(Classifier2D, self).__init__()
25 | if dropout_rate > 0.0:
26 | self.dropout = nn.Dropout(dropout_rate)
27 | self.projection = nn.Linear(dim_in, num_classes, bias=True)
28 |
29 | # Softmax for evaluation and testing.
30 | if act_func == "softmax":
31 | self.act = nn.Softmax(dim=1)
32 | elif act_func == "sigmoid":
33 | self.act = nn.Sigmoid()
34 | else:
35 | raise NotImplementedError(
36 | "{} is not supported as an activation"
37 | "function.".format(act_func)
38 | )
39 |
40 | def forward(self, x):
41 | if hasattr(self, "dropout"):
42 | x = self.dropout(x)
43 | x = self.projection(x)
44 |
45 | if not self.training:
46 | x = self.act(x)
47 | return x
48 |
49 |
50 | class Localizer(nn.Module):
51 | def __init__(self, in_channel, output_channel):
52 |         super(Localizer, self).__init__()
53 | self.deconv1 = Deconv(in_channel, in_channel)
54 | hidden_dim = in_channel // 2
55 | self.conv1 = nn.Sequential(
56 | nn.Conv2d(32, hidden_dim, kernel_size=3, stride=1, padding=1),
57 | nn.LeakyReLU(0.2, inplace=True),
58 | )
59 | self.deconv2 = Deconv(hidden_dim, hidden_dim, kernel_size=3, padding=1)
60 | self.conv2 = nn.Sequential(
61 | nn.LeakyReLU(0.2, inplace=True),
62 | nn.Conv2d(
63 | hidden_dim, output_channel, kernel_size=3, stride=1, padding=1
64 | ),
65 | )
66 | self.sigmoid = nn.Sigmoid()
67 |
68 | def forward(self, x):
69 | out = self.deconv1(x)
70 | out = self.conv1(out)
71 | out = self.deconv2(out)
72 | out = self.conv2(out)
73 | return self.sigmoid(out)
74 |
--------------------------------------------------------------------------------
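`Classifier2D` is a plain linear head: dropout (if enabled), a linear projection, and the activation only outside training. A short usage sketch; `dim_in` is illustrative, while the class count and dropout mirror `configs/m2tr.yaml`:

```python
import torch

from M2TR.models.modules.head import Classifier2D

head = Classifier2D(dim_in=1792, num_classes=2, dropout_rate=0.5, act_func='softmax')
head.eval()                            # softmax is applied only when not training
probs = head(torch.randn(4, 1792))     # shape (4, 2); each row sums to 1 in eval mode
```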
/M2TR/models/base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class BaseNetwork(nn.Module):
5 | def __init__(self):
6 | super(BaseNetwork, self).__init__()
7 |
8 | def print_network(self):
9 | if isinstance(self, list):
10 | self = self[0]
11 | num_params = 0
12 | for param in self.parameters():
13 | num_params += param.numel()
14 | print(
15 | 'Network [%s] was created. Total number of parameters: %.1f million. '
16 | 'To see the architecture, do print(network).'
17 | % (type(self).__name__, num_params / 1000000)
18 | )
19 |
20 | def init_weights(self, init_type='normal', gain=0.02):
21 | '''
22 | initialize network's weights
23 | init_type: normal | xavier | kaiming | orthogonal
24 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
25 | '''
26 |
27 | def init_func(m):
28 | classname = m.__class__.__name__
29 | if classname.find('InstanceNorm2d') != -1:
30 | if hasattr(m, 'weight') and m.weight is not None:
31 | nn.init.constant_(m.weight.data, 1.0)
32 | if hasattr(m, 'bias') and m.bias is not None:
33 | nn.init.constant_(m.bias.data, 0.0)
34 | elif hasattr(m, 'weight') and (
35 | classname.find('Conv') != -1 or classname.find('Linear') != -1
36 | ):
37 | if init_type == 'normal':
38 | nn.init.normal_(m.weight.data, 0.0, gain)
39 | elif init_type == 'xavier':
40 | nn.init.xavier_normal_(m.weight.data, gain=gain)
41 | elif init_type == 'xavier_uniform':
42 | nn.init.xavier_uniform_(m.weight.data, gain=1.0)
43 | elif init_type == 'kaiming':
44 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
45 | elif init_type == 'orthogonal':
46 | nn.init.orthogonal_(m.weight.data, gain=gain)
47 | elif init_type == 'none': # uses pytorch's default init method
48 | m.reset_parameters()
49 | else:
50 | raise NotImplementedError(
51 | 'initialization method [%s] is not implemented'
52 | % init_type
53 | )
54 | if hasattr(m, 'bias') and m.bias is not None:
55 | nn.init.constant_(m.bias.data, 0.0)
56 |
57 | self.apply(init_func)
58 |
59 | for m in self.children():
60 | if hasattr(m, 'init_weights'):
61 | m.init_weights(init_type, gain)
62 |
--------------------------------------------------------------------------------
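`init_weights` applies the chosen scheme to every `Conv*`/`Linear` module (and constant-initializes `InstanceNorm2d`), then recurses into children that define their own `init_weights`. A hedged usage sketch with a hypothetical subclass:

```python
import torch.nn as nn

from M2TR.models.base import BaseNetwork


class TinyNet(BaseNetwork):   # hypothetical subclass, for illustration only
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8, 2)


net = TinyNet()
net.init_weights(init_type='xavier', gain=0.02)
net.print_network()   # reports the parameter count in millions
```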
/M2TR/datasets/ForgeryNet.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | from M2TR.datasets.dataset import DeepFakeDataset
6 | from M2TR.datasets.utils import (
7 | get_image_from_path,
8 | get_mask_path_from_img_path,
9 | )
10 | from M2TR.utils.registries import DATASET_REGISTRY
11 |
12 | '''
13 | DATASET:
14 | DATASET_NAME: ForgeryNet
15 | ROOT_DIR: /some_where/ForgeryNet
16 | IMG_SIZE: 380
17 | SCALE_RATE: 1.0
18 | ROTATE_ANGLE: 10
19 | CUTOUT_H: 10
20 | CUTOUT_W: 10
21 | COMPRESSION_LOW: 65
22 | COMPRESSION_HIGH: 80
23 | '''
24 |
25 |
26 | @DATASET_REGISTRY.register()
27 | class ForgeryNet(DeepFakeDataset):
28 | def __init__(self, dataset_cfg, mode='train'):
29 | self.dataset_name = dataset_cfg['DATASET_NAME']
30 | self.mode = mode
31 | self.dataset_cfg = dataset_cfg
32 | info_txt_tag = mode.upper() + '_INFO_TXT'
33 | if mode == 'train':
34 | self.root_dir = os.path.join(dataset_cfg['ROOT_DIR'], 'Training')
35 | if dataset_cfg[info_txt_tag] != '':
36 | self.info_txt = dataset_cfg[info_txt_tag]
37 | else:
38 | self.info_txt = os.path.join(
39 | self.root_dir, 'image_list_train_retina2.txt'
40 | )
41 | else:
42 | self.root_dir = os.path.join(dataset_cfg['ROOT_DIR'], 'Validation')
43 | if dataset_cfg[info_txt_tag] != '':
44 | self.info_txt = dataset_cfg[info_txt_tag]
45 | else:
46 | self.info_txt = os.path.join(self.root_dir, 'image_list.txt')
47 |
48 | info_list = open(self.info_txt).readlines()
49 | self.info_list = info_list
50 |
51 | def __getitem__(self, idx):
52 | info_line = self.info_list[idx]
53 | image_info = info_line.strip('\n').split()
54 | image_path = image_info[0]
55 | image_abs_path = os.path.join(self.root_dir, 'image2', image_path)
56 | mask_abs_path = get_mask_path_from_img_path(
57 | self.dataset_name, self.root_dir, image_path
58 | )
59 | img, mask = get_image_from_path(
60 | image_abs_path, mask_abs_path, self.mode, self.dataset_cfg
61 | )
62 | img_label_binary = int(image_info[1])
63 | img_label_triple = int(image_info[2])
64 | img_label_mul = int(image_info[3])
65 |
66 | sample = {
67 | 'img': img,
68 | 'bin_label': [int(img_label_binary)],
69 | 'tri_label': [int(img_label_triple)],
70 | 'mul_label': [int(img_label_mul)],
71 | }
72 |
73 | sample['img'] = torch.FloatTensor(sample['img'])
74 | sample['bin_label'] = torch.FloatTensor(sample['bin_label'])
75 | sample['bin_label_onehot'] = self.label_to_one_hot(
76 | sample['bin_label'], 2
77 | ).squeeze()
78 | sample['tri_label'] = torch.FloatTensor(sample['tri_label'])
79 | sample['mul_label'] = torch.FloatTensor(sample['mul_label'])
80 | sample['mask'] = torch.FloatTensor(mask)
81 |
82 | return sample
83 |
--------------------------------------------------------------------------------
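The annotation files are read as plain text, one sample per line, with whitespace-separated fields: `ForgeryNet` expects a relative image path followed by the binary, three-way, and multi-class labels, while `FFDF`/`CelebDF` expect just a path and a binary label. The paths below are made-up placeholders, inferred only from how `__getitem__` splits each line:

```python
# Hypothetical annotation lines (paths are placeholders, not real files):
forgerynet_line = 'train_release/0/abc/frame_0001.jpg 1 2 7'   # path bin_label tri_label mul_label
ffdf_line = 'manipulated_sequences/Deepfakes/images/000_003/0000.png 1'  # path bin_label

path, bin_label, tri_label, mul_label = forgerynet_line.strip('\n').split()
```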
/tools/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import yaml
5 |
6 | import M2TR.utils.logging as logging
7 | from M2TR.utils.checkpoint import get_path_to_checkpoint
8 |
9 | logger = logging.get_logger(__name__)
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser(
14 | description="Provide training and testing pipeline."
15 | )
16 | parser.add_argument(
17 | "--cfg",
18 | dest="cfg_file",
19 | help="Path to the config file",
20 | required=True,
21 | type=str,
22 | )
23 | parser.add_argument(
24 | "--shard_id",
25 | dest="shard_id",
26 | help="The shard id of current node, Starts from 0 to NUM_SHARDS - 1",
27 | default=0,
28 | type=int,
29 | )
30 | parser.add_argument(
31 | "--lr",
32 | dest="base_lr",
33 | help="The base learning rate",
34 | type=float,
35 | )
36 | return parser.parse_args()
37 |
38 |
39 | def merge_a_into_b(a, b):
40 | for k, v in a.items():
41 | if isinstance(v, dict) and k in b:
42 | assert isinstance(
43 | b[k], dict
44 | ), "Cannot inherit key '{}' from base!".format(k)
45 | merge_a_into_b(v, b[k])
46 | else:
47 | b[k] = v
48 |
49 |
50 | def load_config(args):
51 | with open('./configs/default.yaml', 'r') as file:
52 | cfg = yaml.safe_load(file)
53 |     logger.info('Use cfg_file: ' + './configs/' + args.cfg_file)
54 | with open('./configs/' + args.cfg_file, 'r') as file: # TODO use PATH.join?
55 | custom_cfg = yaml.safe_load(file)
56 | merge_a_into_b(custom_cfg, cfg)
57 | if args.shard_id is not None:
58 | cfg['SHARD_ID'] = args.shard_id
59 | if args.base_lr is not None:
60 | cfg['OPTIMIZER']['BASE_LR'] = args.base_lr
61 |
62 | if cfg['TRAIN']['ENABLE']:
63 | cfg['TEST']['CHECKPOINT_TEST_PATH'] = get_path_to_checkpoint(
64 | cfg['TRAIN']['CHECKPOINT_SAVE_PATH'], cfg['TRAIN']['MAX_EPOCH'], cfg
65 | )
66 | cfg['DATASET']['TRAIN_AUGMENTATIONS']['RESIZE_PARAMS'] = [
67 | cfg['DATASET']['IMG_SIZE'],
68 | cfg['DATASET']['IMG_SIZE'],
69 | ]
70 | cfg['DATASET']['TEST_AUGMENTATIONS']['RESIZE_PARAMS'] = [
71 | cfg['DATASET']['IMG_SIZE'],
72 | cfg['DATASET']['IMG_SIZE'],
73 | ]
74 |
75 | logger.info(cfg)
76 | return cfg
77 |
78 |
79 | def launch_func(cfg, func, daemon=False):
80 | if cfg['NUM_GPUS'] > 1:
81 | torch.multiprocessing.spawn(
82 | func,
83 | nprocs=cfg['NUM_GPUS'],
84 | args=(
85 | cfg['NUM_GPUS'],
86 | cfg['INIT_METHOD'],
87 | cfg['SHARD_ID'],
88 | cfg['NUM_SHARDS'],
89 | cfg['DIST_BACKEND'],
90 | cfg,
91 | ),
92 | daemon=daemon,
93 | )
94 | else:
95 | func(
96 | local_rank=0,
97 | num_proc=1,
98 | init_method=cfg['INIT_METHOD'],
99 | shard_id=0,
100 | num_shards=1,
101 | backend=cfg['DIST_BACKEND'],
102 | cfg=cfg,
103 | )
104 |
--------------------------------------------------------------------------------
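`load_config` always reads `configs/default.yaml` first and then recursively overlays the file passed via `--cfg`; `merge_a_into_b` overwrites leaf values while keeping every default that the custom file does not mention. A small sketch of the merge behaviour (values borrowed from the two YAML files):

```python
from tools.utils import merge_a_into_b

base = {'NUM_GPUS': 4, 'OPTIMIZER': {'OPTIMIZER_METHOD': 'adamw', 'BASE_LR': 0.001}}
custom = {'NUM_GPUS': 8, 'OPTIMIZER': {'BASE_LR': 0.0005}}

merge_a_into_b(custom, base)
print(base['NUM_GPUS'])      # 8
print(base['OPTIMIZER'])     # {'OPTIMIZER_METHOD': 'adamw', 'BASE_LR': 0.0005}
```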
/M2TR/utils/logging.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import builtins
3 | import decimal
4 | import functools
5 | import logging
6 | import os
7 | import sys
8 | import time
9 |
10 | import simplejson
11 |
12 | import M2TR.utils.distributed as du
13 | from M2TR.utils.env import pathmgr
14 |
15 |
16 | def _suppress_print():
17 | """
18 | Suppresses printing from the current process.
19 | """
20 |
21 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
22 | pass
23 |
24 | builtins.print = print_pass
25 |
26 |
27 | @functools.lru_cache(maxsize=None)
28 | def _cached_log_stream(filename):
29 | # Use 1K buffer if writing to cloud storage.
30 | io = pathmgr.open(
31 | filename, "a", buffering=1024 if "://" in filename else -1
32 | )
33 | atexit.register(io.close)
34 | return io
35 |
36 |
37 | def setup_logging(cfg, mode='train'):
38 | """
39 | Sets up the logging for multiple processes. Only enable the logging for the
40 | master process, and suppress logging for the non-master processes.
41 | """
42 | output_dir = cfg['LOG_FILE_PATH']
43 | cur_time = time.strftime('%Y-%m-%d_%H:%M:%S', time.localtime(time.time()))
44 | file_name = (
45 | cur_time
46 | + '_'
47 | + cfg['MODEL']['MODEL_NAME']
48 | + '_'
49 | + cfg['DATASET']['DATASET_NAME']
50 | + '_'
51 | + str(cfg['OPTIMIZER']['BASE_LR'])
52 | + '_'
53 | + mode
54 | + '.log'
55 | )
56 | # Set up logging format.
57 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s"
58 |
59 | if du.is_master_proc():
60 | # Enable logging for the master process.
61 | logging.root.handlers = []
62 | else:
63 | # Suppress logging for non-master processes.
64 | _suppress_print()
65 |
66 | logger = logging.getLogger()
67 | logger.setLevel(logging.DEBUG)
68 | logger.propagate = False
69 | plain_formatter = logging.Formatter(
70 | "[%(asctime)s][%(levelname)s] %(filename)s: %(lineno)3d: %(message)s",
71 | datefmt="%m/%d %H:%M:%S",
72 | )
73 |
74 | if du.is_master_proc():
75 | ch = logging.StreamHandler(stream=sys.stdout)
76 | ch.setLevel(logging.DEBUG)
77 | ch.setFormatter(plain_formatter)
78 | logger.addHandler(ch)
79 |
80 | if output_dir is not None and du.is_master_proc(du.get_world_size()):
81 | if not os.path.exists(output_dir):
82 | os.makedirs(output_dir)
83 | filename = os.path.join(output_dir, file_name)
84 | fh = logging.StreamHandler(_cached_log_stream(filename))
85 | fh.setLevel(logging.DEBUG)
86 | fh.setFormatter(plain_formatter)
87 | logger.addHandler(fh)
88 |
89 |
90 | def get_logger(name):
91 | """
92 | Retrieve the logger with the specified name or, if name is None, return a
93 | logger which is the root logger of the hierarchy.
94 | Args:
95 | name (string): name of the logger.
96 | """
97 | return logging.getLogger(name)
98 |
99 |
100 | def log_json_stats(stats):
101 | """
102 | Logs json stats.
103 | Args:
104 | stats (dict): a dictionary of statistical information to log.
105 | """
106 | stats = {
107 | k: decimal.Decimal("{:.5f}".format(v)) if isinstance(v, float) else v
108 | for k, v in stats.items()
109 | }
110 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
111 | logger = get_logger(__name__)
112 | logger.info("json_stats: {:s}".format(json_stats))
113 |
--------------------------------------------------------------------------------
/tools/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from timm.utils import accuracy
5 |
6 | import M2TR.utils.checkpoint as cu
7 | import M2TR.utils.distributed as du
8 | import M2TR.utils.logging as logging
9 | from M2TR.utils.build_helper import (
10 | build_dataloader,
11 | build_dataset,
12 | build_loss_fun,
13 | build_model,
14 | )
15 | from M2TR.utils.meters import AucMetric, MetricLogger
16 |
17 | logger = logging.get_logger(__name__)
18 |
19 |
20 | @torch.no_grad()
21 | def perform_test(
22 | test_loader, model, cfg, cur_epoch=None, writer=None, mode='Test'
23 | ):
24 | criterion = build_loss_fun(cfg)
25 |
26 | metric_logger = MetricLogger(delimiter=" ")
27 | auc_metrics = AucMetric(cfg['NUM_GPUS'])
28 | header = mode + ':'
29 |
30 | model.eval()
31 |
32 | for samples in metric_logger.log_every(test_loader, 10, header):
33 | samples = dict(
34 | zip(
35 | samples,
36 | map(
37 | lambda sample: sample.cuda(non_blocking=True),
38 | samples.values(),
39 | ),
40 | )
41 | )
42 | with torch.cuda.amp.autocast(enabled=cfg['AMP_ENABLE']):
43 | outputs = model(samples)
44 | loss = criterion(outputs, samples)
45 | preds = F.softmax(outputs['logits'], dim=1)[:, 1]
46 | auc_metrics.update(samples['bin_label'].squeeze(dim=1), preds)
47 | acc1 = accuracy(
48 | outputs['logits'], samples['bin_label'], topk=(1,)
49 | )
50 | batch_size = samples['img'].shape[0]
51 | metric_logger.update(loss=loss.item())
52 | metric_logger.meters['acc1'].update(acc1[0].item(), n=batch_size)
53 |
54 | auc_metrics.synchronize_between_processes()
55 | metric_logger.synchronize_between_processes()
56 | if writer and cur_epoch is not None:
57 | writer.add_scalar(
58 | 'acc1', metric_logger.acc1.global_avg, global_step=cur_epoch
59 | )
60 | writer.add_scalar('auc', auc_metrics.auc, global_step=cur_epoch)
61 |
62 | logger.info(
63 | '* Acc@1 {top1.global_avg:.3f} Auc {auc:.3f} loss {losses.global_avg:.3f}'.format(
64 | top1=metric_logger.acc1,
65 | auc=auc_metrics.auc,
66 | losses=metric_logger.loss,
67 | )
68 | )
69 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
70 |
71 |
72 | def test(local_rank, num_proc, init_method, shard_id, num_shards, backend, cfg):
73 | world_size = num_proc * num_shards
74 | rank = shard_id * num_proc + local_rank
75 | try:
76 | torch.distributed.init_process_group(
77 | backend=backend,
78 | init_method=init_method,
79 | world_size=world_size,
80 | rank=rank,
81 | )
82 |     except Exception:  # the process group may already be initialized
83 |         pass
84 | torch.cuda.set_device(local_rank)
85 | du.init_distributed_training(cfg)
86 | np.random.seed(cfg['RNG_SEED'])
87 | torch.manual_seed(cfg['RNG_SEED'])
88 |
89 | logging.setup_logging(cfg, mode='test')
90 |
91 | model = build_model(cfg)
92 | cu.load_test_checkpoint(cfg, model)
93 |
94 | test_dataset = build_dataset('test', cfg)
95 | test_loader = build_dataloader(test_dataset, 'test', cfg)
96 |
97 | logger.info("Testing model for {} iterations".format(len(test_loader)))
98 |
99 | test_stats = perform_test(test_loader, model, cfg)
100 | logger.info(
101 | f"Accuracy of the network on the {len(test_dataset)} test images: {test_stats['acc1']:.1f}%"
102 | )
103 |
--------------------------------------------------------------------------------
/M2TR/utils/scheduler.py:
--------------------------------------------------------------------------------
1 |
2 | from timm.scheduler.cosine_lr import CosineLRScheduler
3 | from timm.scheduler.multistep_lr import MultiStepLRScheduler
4 | from timm.scheduler.step_lr import StepLRScheduler
5 | from timm.scheduler.tanh_lr import TanhLRScheduler
6 |
7 |
8 | def build_scheduler(optimizer, cfg):
9 | num_epochs = cfg['TRAIN']['MAX_EPOCH']
10 | scheduler_cfg = cfg['SCHEDULER']
11 |
12 | if 'LR_NOISE' in scheduler_cfg:
13 | lr_noise = scheduler_cfg['LR_NOISE']
14 | if isinstance(lr_noise, (list, tuple)):
15 | noise_range = [n * num_epochs for n in lr_noise]
16 | if len(noise_range) == 1:
17 | noise_range = noise_range[0]
18 | else:
19 | noise_range = lr_noise * num_epochs
20 | else:
21 | noise_range = None
22 | noise_args = dict(
23 | noise_range_t=noise_range,
24 | noise_pct=scheduler_cfg['LR_NOISE_PCT']
25 | if 'LR_NOISE_PCT' in scheduler_cfg
26 | else 0.67,
27 | noise_std=scheduler_cfg['LR_NOISE_STD']
28 | if 'LR_NOISE_STD' in scheduler_cfg
29 | else 1.0,
30 | noise_seed=scheduler_cfg['SEED']
31 | if 'SEED' in scheduler_cfg
32 | else 42,
33 | )
34 | cycle_args = dict(
35 | cycle_mul=scheduler_cfg['LR_CYCLE_MUL']
36 | if 'LR_CYCLE_MUL' in scheduler_cfg
37 | else 1.0,
38 | cycle_decay=scheduler_cfg['LR_CYCLE_DECAY']
39 | if 'LR_CYCLE_DECAY' in scheduler_cfg
40 | else 0.1,
41 | cycle_limit=scheduler_cfg['LR_CYCLE_LIMIT']
42 | if 'LR_CYCLE_LIMIT' in scheduler_cfg
43 | else 1,
44 | )
45 |
46 | lr_scheduler = None
47 |
48 | if scheduler_cfg['SCHEDULER_TYPE'] == 'cosine':
49 | lr_scheduler = CosineLRScheduler(
50 | optimizer,
51 | t_initial=num_epochs,
52 | lr_min=scheduler_cfg['MIN_LR'],
53 | warmup_lr_init=scheduler_cfg['WARMUP_LR'],
54 | warmup_t=scheduler_cfg['WARMUP_EPOCHS'],
55 | k_decay=scheduler_cfg['LR_K_DECAY']
56 | if 'LR_K_DECAY' in scheduler_cfg
57 | else 1.0,
58 | **cycle_args,
59 | **noise_args,
60 | )
61 | num_epochs = (
62 | lr_scheduler.get_cycle_length() + scheduler_cfg['COOLDOWN_EPOCHS']
63 | )
64 |
65 | elif scheduler_cfg['SCHEDULER_TYPE'] == 'tanh':
66 | lr_scheduler = TanhLRScheduler(
67 | optimizer,
68 | t_initial=num_epochs,
69 | lr_min=scheduler_cfg['MIN_LR'],
70 | warmup_lr_init=scheduler_cfg['WARMUP_LR'],
71 | warmup_t=scheduler_cfg['WARMUP_EPOCHS'],
72 | t_in_epochs=True,
73 | **cycle_args,
74 | **noise_args,
75 | )
76 | num_epochs = (
77 | lr_scheduler.get_cycle_length() + scheduler_cfg['COOLDOWN_EPOCHS']
78 | )
79 | elif scheduler_cfg['SCHEDULER_TYPE'] == 'step':
80 | lr_scheduler = StepLRScheduler(
81 | optimizer,
82 | decay_t=scheduler_cfg['DECAY_EPOCHS'],
83 | decay_rate=scheduler_cfg['DECAY_RATE'],
84 | warmup_lr_init=scheduler_cfg['WARMUP_LR'],
85 | warmup_t=scheduler_cfg['WARMUP_EPOCHS'],
86 | **noise_args,
87 | )
88 | elif scheduler_cfg['SCHEDULER_TYPE'] == 'multistep':
89 | lr_scheduler = MultiStepLRScheduler(
90 | optimizer,
91 | decay_t=scheduler_cfg['DECAY_EPOCHS'],
92 | decay_rate=scheduler_cfg['DECAY_RATE'],
93 | warmup_lr_init=scheduler_cfg['WARMUP_LR'],
94 | warmup_t=scheduler_cfg['WARMUP_EPOCHS'],
95 | **noise_args,
96 | )
97 |
98 | return lr_scheduler, num_epochs
99 |
--------------------------------------------------------------------------------
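`build_scheduler` wraps timm's schedulers and, for the cosine and tanh types, also returns the effective epoch count (cycle length plus cooldown). A minimal sketch using the cosine settings from `configs/default.yaml`:

```python
import torch

from M2TR.utils.scheduler import build_scheduler

params = [torch.zeros(1, requires_grad=True)]   # stand-in parameters
optimizer = torch.optim.SGD(params, lr=0.001)
cfg = {
    'TRAIN': {'MAX_EPOCH': 20},
    'SCHEDULER': {
        'SCHEDULER_TYPE': 'cosine',
        'MIN_LR': 0.000001,
        'WARMUP_LR': 0,
        'WARMUP_EPOCHS': 0,
        'COOLDOWN_EPOCHS': 10,
    },
}
scheduler, num_epochs = build_scheduler(optimizer, cfg)
scheduler.step(5)   # timm schedulers are stepped with the epoch index, as tools/train.py does
```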
/M2TR/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import albumentations
5 | import albumentations.pytorch
6 | import numpy as np
7 | import torchvision
8 | from PIL import Image
9 |
10 |
11 | class ResizeRandomCrop:
12 | def __init__(self, img_size=320, scale_rate=8 / 7, p=0.5):
13 | self.img_size = img_size
14 | self.scale_rate = scale_rate
15 | self.p = p
16 |
17 | def __call__(self, image, mask=None):
18 | if random.uniform(0, 1) < self.p:
19 | S1 = int(self.img_size * self.scale_rate)
20 | S2 = S1
21 | resize_func = torchvision.transforms.Resize((S1, S2))
22 | image = resize_func(image)
23 | crop_params = torchvision.transforms.RandomCrop.get_params(
24 | image, (self.img_size, self.img_size)
25 | )
26 | image = torchvision.transforms.functional.crop(image, *crop_params)
27 | if mask is not None:
28 | mask = resize_func(mask)
29 | mask = torchvision.transforms.functional.crop(
30 | mask, *crop_params
31 | )
32 |
33 | else:
34 | resize_func = torchvision.transforms.Resize(
35 | (self.img_size, self.img_size)
36 | )
37 | image = resize_func(image)
38 | if mask is not None:
39 | mask = resize_func(mask)
40 |
41 | return image, mask
42 |
43 |
44 | def transforms_mask(mask_size):
45 | return albumentations.Compose(
46 | [
47 | albumentations.Resize(mask_size, mask_size),
48 | albumentations.pytorch.transforms.ToTensorV2(),
49 | ]
50 | )
51 |
52 |
53 | def get_augmentations_from_list(augs: list, aug_cfg, one_of_p=1):
54 | ops = []
55 | for aug in augs:
56 | if isinstance(aug, list):
57 | op = albumentations.OneOf
58 | param = get_augmentations_from_list(aug, aug_cfg)
59 | param = [param, one_of_p]
60 | else:
61 | op = getattr(albumentations, aug)
62 | param = (
63 | aug_cfg[aug.upper() + '_PARAMS']
64 | if aug.upper() + '_PARAMS' in aug_cfg
65 | else []
66 | )
67 | ops.append(op(*tuple(param)))
68 | return ops
69 |
70 |
71 | def get_transformations(
72 | mode,
73 | dataset_cfg,
74 | ):
75 | if mode == 'train':
76 | aug_cfg = dataset_cfg['TRAIN_AUGMENTATIONS']
77 | else:
78 | aug_cfg = dataset_cfg['TEST_AUGMENTATIONS']
79 | ops = get_augmentations_from_list(aug_cfg['COMPOSE'], aug_cfg)
80 | ops.append(albumentations.pytorch.transforms.ToTensorV2())
81 | augmentations = albumentations.Compose(ops, p=1)
82 | return augmentations
83 |
84 |
85 | def get_image_from_path(img_path, mask_path, mode, dataset_cfg):
86 | img_size = dataset_cfg['IMG_SIZE']
87 | scale_rate = dataset_cfg['SCALE_RATE']
88 |
89 | img = Image.open(img_path)
90 | if mask_path is not None and os.path.exists(mask_path):
91 | mask = Image.open(mask_path).convert('L')
92 | else:
93 | mask = Image.fromarray(np.zeros((img_size, img_size)))
94 |
95 | trans_list = get_transformations(
96 | mode,
97 | dataset_cfg,
98 | )
99 | if mode == 'train':
100 | crop = ResizeRandomCrop(img_size=img_size, scale_rate=scale_rate)
101 | img, mask = crop(image=img, mask=mask)
102 |
103 | img = np.asarray(img)
104 | img = trans_list(image=img)['image']
105 |
106 | mask = np.asarray(mask)
107 | mask = transforms_mask(img_size)(image=mask)['image']
108 |
109 | else:
110 | img = np.asarray(img)
111 | img = trans_list(image=img)['image']
112 | mask = np.asarray(mask)
113 | mask = transforms_mask(img_size)(image=mask)['image']
114 |
115 | return img, mask.float()
116 |
117 |
118 | def get_mask_path_from_img_path(dataset_name, root_dir, img_info):
119 | if dataset_name == 'ForgeryNet':
120 | root_dir = os.path.join(root_dir, 'spatial_localize')
121 | fore_path = img_info.split('/')[0]
122 | if 'train' in fore_path:
123 | img_info = img_info.replace('train_release', 'train_mask_release')
124 | else:
125 | img_info = img_info[20:]
126 |
127 | mask_complete_path = os.path.join(root_dir, img_info)
128 |
129 | elif 'FFDF' in dataset_name:
130 | mask_info = img_info.replace('images', 'masks')
131 | mask_complete_path = os.path.join(root_dir, mask_info)
132 |
133 | return mask_complete_path
134 |
--------------------------------------------------------------------------------
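`get_augmentations_from_list` maps each name in the config's `COMPOSE` list onto the albumentations class of the same name, turns a nested list into a `OneOf` group, and passes any `<NAME>_PARAMS` entry as positional arguments. A small sketch with a trimmed-down augmentation config:

```python
import albumentations

from M2TR.datasets.utils import get_augmentations_from_list

aug_cfg = {
    'COMPOSE': [['Blur', 'MotionBlur'], 'HorizontalFlip', 'Rotate'],
    'ROTATE_PARAMS': [10],   # forwarded as Rotate(10), i.e. a rotation limit of 10 degrees
}
ops = get_augmentations_from_list(aug_cfg['COMPOSE'], aug_cfg)
pipeline = albumentations.Compose(ops, p=1)
# ops -> [OneOf([Blur(), MotionBlur()], 1), HorizontalFlip(), Rotate(10)]
```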
/tools/train.py:
--------------------------------------------------------------------------------
1 | import pprint
2 |
3 | import numpy as np
4 | import torch
5 | from torch.utils.tensorboard import SummaryWriter
6 |
7 | import M2TR.utils.checkpoint as cu
8 | import M2TR.utils.distributed as du
9 | import M2TR.utils.logging as logging
10 | from M2TR.utils.build_helper import (
11 | build_dataloader,
12 | build_dataset,
13 | build_loss_fun,
14 | build_model,
15 | )
16 | from M2TR.utils.meters import EpochTimer, MetricLogger, SmoothedValue
17 | from M2TR.utils.optimizer import build_optimizer
18 | from M2TR.utils.scheduler import build_scheduler
19 | from tools.test import perform_test
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 |
24 | def train_epoch(
25 | train_loader, model, criterion, optimizer, cfg, cur_epoch, cur_iter, writer
26 | ):
27 | model.train()
28 | train_meter = MetricLogger(delimiter=" ")
29 | train_meter.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.8f}'))
30 | header = 'Epoch: [{}]'.format(cur_epoch)
31 | print_freq = 10
32 |
33 | for samples in train_meter.log_every(train_loader, print_freq, header):
34 | samples = dict(
35 | zip(
36 | samples,
37 | map(
38 | lambda sample: sample.cuda(non_blocking=True),
39 | samples.values(),
40 | ),
41 | )
42 | )
43 |
44 | with torch.cuda.amp.autocast(enabled=cfg['AMP_ENABLE']):
45 | outputs = model(samples)
46 | loss = criterion(outputs, samples)
47 |
48 | loss_value = loss.item()
49 |
50 |         optimizer.zero_grad()
51 |         loss.backward()
52 |         if 'CLIP_GRAD_L2NORM' in cfg['TRAIN']:
53 |             # clip gradients after backward and before the step so the clipping takes effect
54 |             torch.nn.utils.clip_grad_norm_(
55 |                 model.parameters(), cfg['TRAIN']['CLIP_GRAD_L2NORM']
56 |             )
57 |         optimizer.step()
58 |
59 |         torch.cuda.synchronize()
60 |
61 | if writer:
62 | writer.add_scalar('train loss', loss_value, global_step=cur_iter)
63 | writer.add_scalar(
64 | 'lr', optimizer.param_groups[0]["lr"], global_step=cur_iter
65 | )
66 | train_meter.update(loss=loss_value)
67 | train_meter.update(lr=optimizer.param_groups[0]["lr"])
68 | cur_iter = cur_iter + 1
69 |
70 | train_meter.synchronize_between_processes()
71 | logger.info("Averaged stats:" + str(train_meter))
72 | return {
73 | k: meter.global_avg for k, meter in train_meter.meters.items()
74 | }, cur_iter
75 |
76 |
77 | def train(
78 | local_rank, num_proc, init_method, shard_id, num_shards, backend, cfg
79 | ):
80 | world_size = num_proc * num_shards
81 | rank = shard_id * num_proc + local_rank
82 | torch.distributed.init_process_group(
83 | backend=backend,
84 | init_method=init_method,
85 | world_size=world_size,
86 | rank=rank,
87 | )
88 | torch.cuda.set_device(local_rank)
89 | du.init_distributed_training(cfg)
90 | np.random.seed(cfg['RNG_SEED'])
91 | torch.manual_seed(cfg['RNG_SEED'])
92 |
93 | logging.setup_logging(cfg)
94 | logger.info(pprint.pformat(cfg))
95 | if du.is_master_proc(du.get_world_size()):
96 | writer = SummaryWriter(cfg['LOG_FILE_PATH'])
97 | else:
98 | writer = None
99 |
100 | model = build_model(cfg)
101 | optimizer = build_optimizer(model.parameters(), cfg)
102 | scheduler, _ = build_scheduler(optimizer, cfg) # TODO _?
103 | loss_fun = build_loss_fun(cfg)
104 | train_dataset = build_dataset('train', cfg)
105 | train_loader = build_dataloader(train_dataset, 'train', cfg)
106 | val_dataset = build_dataset('val', cfg)
107 | val_loader = build_dataloader(val_dataset, 'val', cfg)
108 |
109 | start_epoch = cu.load_train_checkpoint(model, optimizer, scheduler, cfg)
110 |
111 | logger.info("Start epoch: {}".format(start_epoch + 1))
112 | epoch_timer = EpochTimer()
113 |
114 | cur_iter = 0
115 |
116 | for cur_epoch in range(start_epoch, cfg['TRAIN']['MAX_EPOCH']):
117 | logger.info('========================================================')
118 | train_loader.sampler.set_epoch(cur_epoch)
119 | val_loader.sampler.set_epoch(cur_epoch)
120 | epoch_timer.epoch_tic()
121 | _, cur_iter = train_epoch(
122 | train_loader,
123 | model,
124 | loss_fun,
125 | optimizer,
126 | cfg,
127 | cur_epoch,
128 | cur_iter,
129 | writer,
130 | )
131 | epoch_timer.epoch_toc()
132 | perform_test(val_loader, model, cfg, cur_epoch, writer, mode='Val')
133 | logger.info(
134 |             f"Epoch {cur_epoch} takes {epoch_timer.last_epoch_time():.2f}s. Epochs "
135 |             f"from {start_epoch} to {cur_epoch} take "
136 |             f"{epoch_timer.avg_epoch_time():.2f}s on average and "
137 |             f"{epoch_timer.median_epoch_time():.2f}s at the median."
138 | )
139 | logger.info(
140 |             f"For epoch {cur_epoch}, each iteration takes "
141 |             f"{epoch_timer.last_epoch_time()/len(train_loader):.2f}s on average. "
142 |             f"From epoch {start_epoch} to {cur_epoch}, each iteration takes "
143 |             f"{epoch_timer.avg_epoch_time()/len(train_loader):.2f}s on average."
144 | )
145 |
146 | scheduler.step(cur_epoch)
147 |
148 | is_checkp_epoch = cu.is_checkpoint_epoch(cfg, cur_epoch)
149 |
150 | if is_checkp_epoch:
151 | cu.save_checkpoint(model, optimizer, scheduler, cur_epoch, cfg)
152 |
153 | if writer:
154 | writer.flush()
155 | writer.close()
156 |
--------------------------------------------------------------------------------
/M2TR/models/xception.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from M2TR.utils.registries import MODEL_REGISTRY
8 |
9 | '''
10 | MODEL:
11 | MODEL_NAME: Xception
12 | PRETRAINED: imagenet
13 | ESCAPE: ''
14 | '''
15 |
16 |
17 | class SeparableConv2d(nn.Module):
18 | def __init__(
19 | self,
20 | in_channels,
21 | out_channels,
22 | kernel_size=1,
23 | stride=1,
24 | padding=0,
25 | dilation=1,
26 | bias=False,
27 | ):
28 | super(SeparableConv2d, self).__init__()
29 | self.conv1 = nn.Conv2d(
30 | in_channels,
31 | in_channels,
32 | kernel_size,
33 | stride,
34 | padding,
35 | dilation,
36 | groups=in_channels,
37 | bias=bias,
38 | )
39 | self.pointwise = nn.Conv2d(
40 | in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias
41 | )
42 |
43 | def forward(self, x):
44 | x = self.conv1(x)
45 | x = self.pointwise(x)
46 | return x
47 |
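   | # SeparableConv2d above is a depthwise separable convolution: a depthwise conv
   | # (groups=in_channels) followed by a 1x1 pointwise conv that mixes channels.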
48 |
49 | class Block(nn.Module):
50 | def __init__(
51 | self,
52 | in_filters,
53 | out_filters,
54 | reps,
55 | strides=1,
56 | start_with_relu=True,
57 | grow_first=True,
58 | ):
59 | super(Block, self).__init__()
60 |
61 | if out_filters != in_filters or strides != 1:
62 | self.skip = nn.Conv2d(
63 | in_filters, out_filters, 1, stride=strides, bias=False
64 | )
65 | self.skipbn = nn.BatchNorm2d(out_filters)
66 | else:
67 | self.skip = None
68 |
69 | rep = []
70 |
71 | filters = in_filters
72 | if grow_first:
73 | rep.append(nn.ReLU(inplace=True))
74 | rep.append(
75 | SeparableConv2d(
76 | in_filters, out_filters, 3, stride=1, padding=1, bias=False
77 | )
78 | )
79 | rep.append(nn.BatchNorm2d(out_filters))
80 | filters = out_filters
81 |
82 | for i in range(reps - 1):
83 | rep.append(nn.ReLU(inplace=True))
84 | rep.append(
85 | SeparableConv2d(
86 | filters, filters, 3, stride=1, padding=1, bias=False
87 | )
88 | )
89 | rep.append(nn.BatchNorm2d(filters))
90 |
91 | if not grow_first:
92 | rep.append(nn.ReLU(inplace=True))
93 | rep.append(
94 | SeparableConv2d(
95 | in_filters, out_filters, 3, stride=1, padding=1, bias=False
96 | )
97 | )
98 | rep.append(nn.BatchNorm2d(out_filters))
99 |
100 | if not start_with_relu:
101 | rep = rep[1:]
102 | else:
103 | rep[0] = nn.ReLU(inplace=False)
104 |
105 | if strides != 1:
106 | rep.append(nn.MaxPool2d(3, strides, 1))
107 | self.rep = nn.Sequential(*rep)
108 |
109 | def forward(self, inp):
110 | x = self.rep(inp)
111 |
112 | if self.skip is not None:
113 | skip = self.skip(inp)
114 | skip = self.skipbn(skip)
115 | else:
116 | skip = inp
117 |
118 | x += skip
119 | return x
120 |
121 |
122 | @MODEL_REGISTRY.register()
123 | class Xception(nn.Module):
124 | def __init__(self, model_cfg):
125 | super(Xception, self).__init__()
126 | num_classes = 2
127 | pretrained = model_cfg['PRETRAINED']
128 | self.escape = model_cfg['ESCAPE']
129 | self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
130 | self.bn1 = nn.BatchNorm2d(32)
131 | self.relu1 = nn.ReLU(inplace=True)
132 | self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
133 | self.bn2 = nn.BatchNorm2d(64)
134 | self.relu2 = nn.ReLU(inplace=True)
135 | self.block1 = Block(
136 | 64, 128, 2, 2, start_with_relu=False, grow_first=True
137 | )
138 | self.block2 = Block(
139 | 128, 256, 2, 2, start_with_relu=True, grow_first=True
140 | )
141 | self.block3 = Block(
142 | 256, 728, 2, 2, start_with_relu=True, grow_first=True
143 | )
144 |
145 | self.block4 = Block(
146 | 728, 728, 3, 1, start_with_relu=True, grow_first=True
147 | )
148 | self.block5 = Block(
149 | 728, 728, 3, 1, start_with_relu=True, grow_first=True
150 | )
151 | self.block6 = Block(
152 | 728, 728, 3, 1, start_with_relu=True, grow_first=True
153 | )
154 | self.block7 = Block(
155 | 728, 728, 3, 1, start_with_relu=True, grow_first=True
156 | )
157 |
158 | self.block8 = Block(
159 | 728, 728, 3, 1, start_with_relu=True, grow_first=True
160 | )
161 | self.block9 = Block(
162 | 728, 728, 3, 1, start_with_relu=True, grow_first=True
163 | )
164 | self.block10 = Block(
165 | 728, 728, 3, 1, start_with_relu=True, grow_first=True
166 | )
167 | self.block11 = Block(
168 | 728, 728, 3, 1, start_with_relu=True, grow_first=True
169 | )
170 | self.block12 = Block(
171 | 728, 1024, 2, 2, start_with_relu=True, grow_first=False
172 | )
173 | self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
174 | self.bn3 = nn.BatchNorm2d(1536)
175 | self.relu3 = nn.ReLU(inplace=True)
176 | self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
177 | self.bn4 = nn.BatchNorm2d(2048)
178 | self.relu4 = nn.ReLU(inplace=True)
179 | self.last_linear = nn.Linear(2048, num_classes)
180 | self.seq = []
181 | self.seq.append(
182 | (
183 | 'b0',
184 | [
185 | self.conv1,
186 | lambda x: self.bn1(x),
187 | self.relu1,
188 | self.conv2,
189 | lambda x: self.bn2(x),
190 | ],
191 | )
192 | )
193 | self.seq.append(('b1', [self.relu2, self.block1]))
194 | self.seq.append(('b2', [self.block2]))
195 | self.seq.append(('b3', [self.block3]))
196 | self.seq.append(('b4', [self.block4]))
197 | self.seq.append(('b5', [self.block5]))
198 | self.seq.append(('b6', [self.block6]))
199 | self.seq.append(('b7', [self.block7]))
200 | self.seq.append(('b8', [self.block8]))
201 | self.seq.append(('b9', [self.block9]))
202 | self.seq.append(('b10', [self.block10]))
203 | self.seq.append(('b11', [self.block11]))
204 | self.seq.append(('b12', [self.block12]))
205 | self.seq.append(
206 | (
207 | 'final',
208 | [
209 | self.conv3,
210 | lambda x: self.bn3(x),
211 | self.relu3,
212 | self.conv4,
213 | lambda x: self.bn4(x),
214 | ],
215 | )
216 | )
217 | self.seq.append(
218 | (
219 | 'logits',
220 | [
221 | self.relu4,
222 | lambda x: F.adaptive_avg_pool2d(x, (1, 1)),
223 | lambda x: x.view(x.size(0), -1),
224 | self.last_linear,
225 | ],
226 | )
227 | )
228 | if pretrained == 'imagenet':
229 | self.load_state_dict(
230 | torch.hub.load_state_dict_from_url(
231 | 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth'
232 | ),
233 | strict=False,
234 | )
235 | elif pretrained:
236 | ckpt = torch.load(pretrained, map_location='cpu')
237 | self.load_state_dict(ckpt['state_dict'])
238 | else:
239 | for m in self.modules():
240 | if isinstance(m, nn.Conv2d):
241 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
242 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
243 | elif isinstance(m, nn.BatchNorm2d):
244 | m.weight.data.fill_(1)
245 | m.bias.data.zero_()
246 |
247 | def forward(self, samples):
248 | x = samples['img']
249 | layers = {}
250 | for stage in self.seq:
251 | for f in stage[1]:
252 | x = f(x)
253 | layers[stage[0]] = x
254 | if stage[0] == self.escape:
255 | break
256 | return layers
257 |
--------------------------------------------------------------------------------
/M2TR/utils/meters.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import time
3 | from collections import defaultdict, deque
4 |
5 | import numpy as np
6 | import torch
7 | import torch.distributed as dist
8 | from fvcore.common.timer import Timer
9 |
10 | import M2TR.utils.distributed as du
11 | import M2TR.utils.logging as logging
12 | from sklearn.metrics import roc_auc_score
13 |
14 | logger = logging.get_logger(__name__)
15 |
16 | class SmoothedValue(object):
17 | """Track a series of values and provide access to smoothed values over a
18 | window or the global series average.
19 | """
20 |
21 | def __init__(self, window_size=20, fmt=None):
22 | if fmt is None:
23 | fmt = "{median:.4f} ({global_avg:.4f})"
24 | self.deque = deque(maxlen=window_size)
25 | self.total = 0.0
26 | self.count = 0
27 | self.fmt = fmt
28 |
29 | def update(self, value, n=1):
30 | self.deque.append(value)
31 | self.count += n
32 | self.total += value * n
33 |
34 | def synchronize_between_processes(self):
35 | """
36 | Warning: does not synchronize the deque!
37 | """
38 | if not du.is_dist_avail_and_initialized():
39 | return
40 | t = torch.tensor(
41 | [self.count, self.total], dtype=torch.float64, device='cuda'
42 | )
43 | dist.barrier()
44 | dist.all_reduce(t)
45 | t = t.tolist()
46 | self.count = int(t[0])
47 | self.total = t[1]
48 |
49 | @property
50 | def median(self):
51 | d = torch.tensor(list(self.deque))
52 | return d.median().item()
53 |
54 | @property
55 | def avg(self):
56 | d = torch.tensor(list(self.deque), dtype=torch.float32)
57 | return d.mean().item()
58 |
59 | @property
60 | def global_avg(self):
61 | return self.total / self.count
62 |
63 | @property
64 | def max(self):
65 | return max(self.deque)
66 |
67 | @property
68 | def value(self):
69 | return self.deque[-1]
70 |
71 | def __str__(self):
72 | return self.fmt.format(
73 | median=self.median,
74 | avg=self.avg,
75 | global_avg=self.global_avg,
76 | max=self.max,
77 | value=self.value,
78 | )
79 |
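   | # Illustrative usage sketch (the values below are examples, not part of this module):
   | #   meter = SmoothedValue(window_size=20)
   | #   for loss in (0.92, 0.85, 0.78):
   | #       meter.update(loss)
   | #   meter.median, meter.avg, meter.global_avg  # window median, window mean, running mean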
80 |
81 | class AucMetric:
82 |     """Accumulate binary labels and predictions across iterations and compute
83 |     the ROC AUC after gathering results from all processes.
84 |     """
85 |
86 | def __init__(self, num_gpus):
87 | self.labels = torch.Tensor().cuda()
88 | self.preds = torch.Tensor().cuda()
89 | self.num_gpus = num_gpus
90 |
91 | def update(self, labels, preds):
92 | self.labels = torch.cat([self.labels, labels], dim=0)
93 | self.preds = torch.cat([self.preds, preds], dim=0)
94 |
95 | def synchronize_between_processes(self):
96 | if not du.is_dist_avail_and_initialized():
97 | return
98 | labels = [torch.zeros(len(self.labels), dtype=self.labels[0].dtype).cuda() for _ in range(self.num_gpus)]
99 | preds = [torch.zeros(len(self.preds), dtype=self.preds[0].dtype).cuda() for _ in range(self.num_gpus)]
100 | dist.all_gather(labels, self.labels)
101 | dist.all_gather(preds, self.preds)
102 | labels = torch.cat(labels, dim=0).cpu()
103 | preds = torch.cat(preds, dim=0).cpu()
104 | self.auc = roc_auc_score(labels, preds)
105 |
106 | # def __str__(self):
107 | # return str(self.auc)
108 |
109 |
110 | class MetricLogger(object):
111 | def __init__(self, delimiter="\t"):
112 | self.meters = defaultdict(SmoothedValue)
113 | self.delimiter = delimiter
114 |
115 | def update(self, **kwargs):
116 | for k, v in kwargs.items():
117 | if isinstance(v, torch.Tensor):
118 | v = v.item()
119 | assert isinstance(v, (float, int))
120 | self.meters[k].update(v)
121 |
122 | def __getattr__(self, attr):
123 | if attr in self.meters:
124 | return self.meters[attr]
125 | if attr in self.__dict__:
126 | return self.__dict__[attr]
127 | raise AttributeError(
128 | "'{}' object has no attribute '{}'".format(
129 | type(self).__name__, attr
130 | )
131 | )
132 |
133 | def __str__(self):
134 | loss_str = []
135 | for name, meter in self.meters.items():
136 | loss_str.append("{}: {}".format(name, str(meter)))
137 | return self.delimiter.join(loss_str)
138 |
139 | def synchronize_between_processes(self):
140 | for meter in self.meters.values():
141 | meter.synchronize_between_processes()
142 |
143 | def add_meter(self, name, meter):
144 | self.meters[name] = meter
145 |
146 | def log_every(self, iterable, print_freq, header=None):
147 | i = 0
148 | if not header:
149 | header = ''
150 | start_time = time.time()
151 | end = time.time()
152 | iter_time = SmoothedValue(fmt='{avg:.4f}')
153 | data_time = SmoothedValue(fmt='{avg:.4f}')
154 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
155 | log_msg = [
156 | header,
157 | '[{0' + space_fmt + '}/{1}]',
158 | 'eta: {eta}',
159 | '{meters}',
160 | 'time: {time}',
161 | 'data: {data}',
162 | ]
163 | if torch.cuda.is_available():
164 | log_msg.append('max mem: {memory:.0f}')
165 | log_msg = self.delimiter.join(log_msg)
166 | MB = 1024.0 * 1024.0
167 | for obj in iterable:
168 | data_time.update(time.time() - end)
169 | yield obj
170 | iter_time.update(time.time() - end)
171 | if i % print_freq == 0 or i == len(iterable) - 1:
172 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
173 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
174 | if torch.cuda.is_available():
175 | logger.info(
176 | log_msg.format(
177 | i,
178 | len(iterable),
179 | eta=eta_string,
180 | meters=str(self),
181 | time=str(iter_time),
182 | data=str(data_time),
183 | memory=torch.cuda.max_memory_allocated() / MB,
184 | )
185 | )
186 | else:
187 | logger.info(
188 | log_msg.format(
189 | i,
190 | len(iterable),
191 | eta=eta_string,
192 | meters=str(self),
193 | time=str(iter_time),
194 | data=str(data_time),
195 | )
196 | )
197 | i += 1
198 | end = time.time()
199 | total_time = time.time() - start_time
200 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
201 | logger.info(
202 | '{} Total time: {} ({:.4f} s / it)'.format(
203 | header, total_time_str, total_time / len(iterable)
204 | )
205 | )
206 |
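   | # Illustrative usage sketch, mirroring how train_epoch in tools/train.py drives the
   | # logger (loader, loss_value and current_lr are placeholders):
   | #   meter = MetricLogger(delimiter="  ")
   | #   meter.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.8f}'))
   | #   for batch in meter.log_every(loader, print_freq=10, header='Epoch: [0]'):
   | #       ...  # forward / backward / optimizer step
   | #       meter.update(loss=loss_value, lr=current_lr)
   | #   meter.synchronize_between_processes()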
207 |
208 | class EpochTimer:
209 | """
210 | A timer which computes the epoch time.
211 | """
212 |
213 | def __init__(self) -> None:
214 | self.timer = Timer()
215 | self.timer.reset()
216 | self.epoch_times = []
217 |
218 | def reset(self) -> None:
219 | """
220 | Reset the epoch timer.
221 | """
222 | self.timer.reset()
223 | self.epoch_times = []
224 |
225 | def epoch_tic(self):
226 | """
227 | Start to record time.
228 | """
229 | self.timer.reset()
230 |
231 | def epoch_toc(self):
232 | """
233 | Stop to record time.
234 | """
235 | self.timer.pause()
236 | self.epoch_times.append(self.timer.seconds())
237 |
238 | def last_epoch_time(self):
239 | """
240 | Get the time for the last epoch.
241 | """
242 | assert len(self.epoch_times) > 0, "No epoch time has been recorded!"
243 |
244 | return self.epoch_times[-1]
245 |
246 | def avg_epoch_time(self):
247 | """
248 | Calculate the average epoch time among the recorded epochs.
249 | """
250 | assert len(self.epoch_times) > 0, "No epoch time has been recorded!"
251 |
252 | return np.mean(self.epoch_times)
253 |
254 | def median_epoch_time(self):
255 | """
256 | Calculate the median epoch time among the recorded epochs.
257 | """
258 | assert len(self.epoch_times) > 0, "No epoch time has been recorded!"
259 |
260 | return np.median(self.epoch_times)
261 |
--------------------------------------------------------------------------------
/M2TR/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 | import pickle
4 |
5 | import torch
6 | import torch.distributed as dist
7 |
8 | _LOCAL_PROCESS_GROUP = None
9 |
10 |
11 | def all_gather(tensors):
12 | """
13 | All gathers the provided tensors from all processes across machines.
14 | Args:
15 | tensors (list): tensors to perform all gather across all processes in
16 | all machines.
17 | """
18 |
19 | gather_list = []
20 | output_tensor = []
21 | world_size = dist.get_world_size()
22 | for tensor in tensors:
23 | tensor_placeholder = [
24 | torch.ones_like(tensor) for _ in range(world_size)
25 | ]
26 | dist.all_gather(tensor_placeholder, tensor, async_op=False)
27 | gather_list.append(tensor_placeholder)
28 | for gathered_tensor in gather_list:
29 | output_tensor.append(torch.cat(gathered_tensor, dim=0))
30 | return output_tensor
31 |
32 |
33 | def all_reduce(tensors, average=True):
34 | """
35 | All reduce the provided tensors from all processes across machines.
36 | Args:
37 | tensors (list): tensors to perform all reduce across all processes in
38 | all machines.
39 | average (bool): scales the reduced tensor by the number of overall
40 | processes across all machines.
41 | """
42 |
43 | for tensor in tensors:
44 | dist.all_reduce(tensor, async_op=False)
45 | if average:
46 | world_size = dist.get_world_size()
47 | for tensor in tensors:
48 | tensor.mul_(1.0 / world_size)
49 | return tensors
50 |
51 |
52 | def init_process_group(
53 | local_rank,
54 | local_world_size,
55 | shard_id,
56 | num_shards,
57 | init_method,
58 | dist_backend="nccl",
59 | ):
60 | """
61 | Initializes the default process group.
62 | Args:
63 | local_rank (int): the rank on the current local machine.
64 | local_world_size (int): the world size (number of processes running) on
65 | the current local machine.
66 | shard_id (int): the shard index (machine rank) of the current machine.
67 | num_shards (int): number of shards for distributed training.
68 |         init_method (string): supporting two different methods for
69 | initializing process groups:
70 | "file": use shared file system to initialize the groups across
71 | different processes.
72 |             "tcp": use a tcp address to initialize the groups across different processes.
73 | dist_backend (string): backend to use for distributed training. Options
74 |             include gloo, mpi and nccl; details can be found here:
75 | https://pytorch.org/docs/stable/distributed.html
76 | """
77 | # Sets the GPU to use.
78 | torch.cuda.set_device(local_rank)
79 | # Initialize the process group.
80 | proc_rank = local_rank + shard_id * local_world_size
81 | world_size = local_world_size * num_shards
82 | dist.init_process_group(
83 | backend=dist_backend,
84 | init_method=init_method,
85 | world_size=world_size,
86 | rank=proc_rank,
87 | )
88 |
89 |
90 | def is_master_proc(num_gpus=8):
91 | """
92 | Determines if the current process is the master process.
93 | """
94 | if torch.distributed.is_initialized():
95 | return dist.get_rank() % num_gpus == 0
96 | else:
97 | return True
98 |
99 |
100 | def is_root_proc():
101 | """
102 | Determines if the current process is the root process.
103 | """
104 | if torch.distributed.is_initialized():
105 | return dist.get_rank() == 0
106 | else:
107 | return True
108 |
109 |
110 | def get_world_size():
111 | """
112 | Get the size of the world.
113 | """
114 | if not dist.is_available():
115 | return 1
116 | if not dist.is_initialized():
117 | return 1
118 | return dist.get_world_size()
119 |
120 |
121 | def get_rank():
122 | """
123 | Get the rank of the current process.
124 | """
125 | if not dist.is_available():
126 | return 0
127 | if not dist.is_initialized():
128 | return 0
129 | return dist.get_rank()
130 |
131 |
132 | def synchronize():
133 | """
134 | Helper function to synchronize (barrier) among all processes when
135 | using distributed training
136 | """
137 | if not dist.is_available():
138 | return
139 | if not dist.is_initialized():
140 | return
141 | world_size = dist.get_world_size()
142 | if world_size == 1:
143 | return
144 | dist.barrier()
145 |
146 |
147 | def is_dist_avail_and_initialized():
148 | if not dist.is_available():
149 | return False
150 | if not dist.is_initialized():
151 | return False
152 | return True
153 |
154 |
155 | @functools.lru_cache()
156 | def _get_global_gloo_group():
157 | """
158 | Return a process group based on gloo backend, containing all the ranks
159 | The result is cached.
160 | Returns:
161 | (group): pytorch dist group.
162 | """
163 | if dist.get_backend() == "nccl":
164 | return dist.new_group(backend="gloo")
165 | else:
166 | return dist.group.WORLD
167 |
168 |
169 | def _serialize_to_tensor(data, group):
170 | """
171 |     Serialize arbitrary data to a ByteTensor. Note that only the `gloo` and `nccl`
172 |     backends are supported.
173 | Args:
174 | data (data): data to be serialized.
175 | group (group): pytorch dist group.
176 | Returns:
177 |         tensor (ByteTensor): the serialized tensor.
178 | """
179 |
180 | backend = dist.get_backend(group)
181 | assert backend in ["gloo", "nccl"]
182 | device = torch.device("cpu" if backend == "gloo" else "cuda")
183 |
184 | buffer = pickle.dumps(data)
185 | if len(buffer) > 1024 ** 3:
186 | logger = logging.getLogger(__name__)
187 | logger.warning(
188 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
189 | get_rank(), len(buffer) / (1024 ** 3), device
190 | )
191 | )
192 | storage = torch.ByteStorage.from_buffer(buffer)
193 | tensor = torch.ByteTensor(storage).to(device=device)
194 | return tensor
195 |
196 |
197 | def _pad_to_largest_tensor(tensor, group):
198 | """
199 |     Pad the tensors from different GPUs to the size of the largest one.
200 | Args:
201 | tensor (tensor): tensor to pad.
202 | group (group): pytorch dist group.
203 | Returns:
204 | list[int]: size of the tensor, on each rank
205 | Tensor: padded tensor that has the max size
206 | """
207 | world_size = dist.get_world_size(group=group)
208 | assert (
209 | world_size >= 1
210 | ), "comm.gather/all_gather must be called from ranks within the given group!"
211 | local_size = torch.tensor(
212 | [tensor.numel()], dtype=torch.int64, device=tensor.device
213 | )
214 | size_list = [
215 | torch.zeros([1], dtype=torch.int64, device=tensor.device)
216 | for _ in range(world_size)
217 | ]
218 | dist.all_gather(size_list, local_size, group=group)
219 | size_list = [int(size.item()) for size in size_list]
220 |
221 | max_size = max(size_list)
222 |
223 | # we pad the tensor because torch all_gather does not support
224 | # gathering tensors of different shapes
225 | if local_size != max_size:
226 | padding = torch.zeros(
227 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device
228 | )
229 | tensor = torch.cat((tensor, padding), dim=0)
230 | return size_list, tensor
231 |
232 |
233 | def all_gather_unaligned(data, group=None):
234 | """
235 | Run all_gather on arbitrary picklable data (not necessarily tensors).
236 | Args:
237 | data: any picklable object
238 | group: a torch process group. By default, will use a group which
239 | contains all ranks on gloo backend.
240 | Returns:
241 | list[data]: list of data gathered from each rank
242 | """
243 | if get_world_size() == 1:
244 | return [data]
245 | if group is None:
246 | group = _get_global_gloo_group()
247 | if dist.get_world_size(group) == 1:
248 | return [data]
249 |
250 | tensor = _serialize_to_tensor(data, group)
251 |
252 | size_list, tensor = _pad_to_largest_tensor(tensor, group)
253 | max_size = max(size_list)
254 |
255 | # receiving Tensor from all ranks
256 | tensor_list = [
257 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
258 | for _ in size_list
259 | ]
260 | dist.all_gather(tensor_list, tensor, group=group)
261 |
262 | data_list = []
263 | for size, tensor in zip(size_list, tensor_list):
264 | buffer = tensor.cpu().numpy().tobytes()[:size]
265 | data_list.append(pickle.loads(buffer))
266 |
267 | return data_list
268 |
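   | # Illustrative usage sketch: unlike all_gather above, this gathers arbitrary picklable
   | # objects (the dict below is a placeholder, not part of this module):
   | #   local_stats = {'rank': get_rank(), 'num_samples': 123}
   | #   all_stats = all_gather_unaligned(local_stats)  # list with one entry per rank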
269 |
270 | def init_distributed_training(cfg):
271 | """
272 | Initialize variables needed for distributed training.
273 | """
274 | if cfg['NUM_GPUS'] <= 1:
275 | return
276 | num_gpus_per_machine = cfg['NUM_GPUS']
277 | num_machines = cfg['NUM_SHARDS']
278 | for i in range(num_machines):
279 | ranks_on_i = list(
280 | range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
281 | )
282 | pg = dist.new_group(ranks_on_i)
283 | if i == cfg['SHARD_ID']:
284 | global _LOCAL_PROCESS_GROUP
285 | _LOCAL_PROCESS_GROUP = pg
286 |
--------------------------------------------------------------------------------
/M2TR/models/m2tr.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.fft
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from M2TR.utils.registries import MODEL_REGISTRY
9 |
10 | from .base import BaseNetwork
11 | from .xception import Xception
12 | from .efficientnet import EfficientNet
13 | from .modules.head import Classifier2D, Localizer
14 | from .modules.transformer_block import FeedForward2D
15 |
16 |
17 |
18 | class GlobalFilter(nn.Module):
19 | def __init__(self, dim=32, h=80, w=41, fp32fft=True):
20 | super().__init__()
21 | self.complex_weight = nn.Parameter(
22 | torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
23 | )
24 | self.w = w
25 | self.h = h
26 | self.fp32fft = fp32fft
27 |
28 | def forward(self, x):
29 |         _, _, a, b = x.size()  # a = height, b = width of the feature map
30 | x = x.permute(0, 2, 3, 1).contiguous()
31 |
32 | if self.fp32fft:
33 | dtype = x.dtype
34 | x = x.to(torch.float32)
35 |
36 | x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
37 | weight = torch.view_as_complex(self.complex_weight)
38 | x = x * weight
39 | x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm="ortho")
40 |
41 | if self.fp32fft:
42 | x = x.to(dtype)
43 |
44 | x = x.permute(0, 3, 1, 2).contiguous()
45 |
46 | return x
47 |
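   | # GlobalFilter in brief: the (B, C, H, W) feature map is transformed with a 2D real FFT
   | # over the spatial dimensions, multiplied elementwise by a learned complex-valued weight
   | # of shape (h, w, dim), and transformed back with the inverse FFT; this acts as a learned
   | # global spatial filter applied per channel in the frequency domain.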
48 |
49 | class FreqBlock(nn.Module):
50 | def __init__(self, dim, h=80, w=41, fp32fft=True):
51 | super().__init__()
52 | self.filter = GlobalFilter(dim, h=h, w=w, fp32fft=fp32fft)
53 | self.feed_forward = FeedForward2D(in_channel=dim, out_channel=dim)
54 |
55 | def forward(self, x):
56 | x = x + self.feed_forward(self.filter(x))
57 | return x
58 |
59 |
60 | def attention(query, key, value):
61 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
62 | query.size(-1)
63 | )
64 | p_attn = F.softmax(scores, dim=-1)
65 | p_val = torch.matmul(p_attn, value)
66 | return p_val, p_attn
67 |
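   | # attention() above is standard scaled dot-product attention:
   | #   Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
   | # where d_k is the per-chunk feature dimension (query.size(-1)).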
68 |
69 | class MultiHeadedAttention(nn.Module):
70 | """
71 | Take in model size and number of heads.
72 | """
73 |
74 | def __init__(self, patchsize, d_model):
75 | super().__init__()
76 | self.patchsize = patchsize
77 | self.query_embedding = nn.Conv2d(
78 | d_model, d_model, kernel_size=1, padding=0
79 | )
80 | self.value_embedding = nn.Conv2d(
81 | d_model, d_model, kernel_size=1, padding=0
82 | )
83 | self.key_embedding = nn.Conv2d(
84 | d_model, d_model, kernel_size=1, padding=0
85 | )
86 | self.output_linear = nn.Sequential(
87 | nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
88 | nn.BatchNorm2d(d_model),
89 | nn.LeakyReLU(0.2, inplace=True),
90 | )
91 |
92 | def forward(self, x):
93 | b, c, h, w = x.size()
94 | d_k = c // len(self.patchsize)
95 | output = []
96 | _query = self.query_embedding(x)
97 | _key = self.key_embedding(x)
98 | _value = self.value_embedding(x)
99 | attentions = []
100 | for (width, height), query, key, value in zip(
101 | self.patchsize,
102 | torch.chunk(_query, len(self.patchsize), dim=1),
103 | torch.chunk(_key, len(self.patchsize), dim=1),
104 | torch.chunk(_value, len(self.patchsize), dim=1),
105 | ):
106 | out_w, out_h = w // width, h // height
107 |
108 | # 1) embedding and reshape
109 | query = query.view(b, d_k, out_h, height, out_w, width)
110 | query = (
111 | query.permute(0, 2, 4, 1, 3, 5)
112 | .contiguous()
113 | .view(b, out_h * out_w, d_k * height * width)
114 | )
115 | key = key.view(b, d_k, out_h, height, out_w, width)
116 | key = (
117 | key.permute(0, 2, 4, 1, 3, 5)
118 | .contiguous()
119 | .view(b, out_h * out_w, d_k * height * width)
120 | )
121 | value = value.view(b, d_k, out_h, height, out_w, width)
122 | value = (
123 | value.permute(0, 2, 4, 1, 3, 5)
124 | .contiguous()
125 | .view(b, out_h * out_w, d_k * height * width)
126 | )
127 |
128 |             y, _ = attention(query, key, value)  # 2) apply scaled dot-product attention
129 |
130 | # 3) "Concat" using a view and apply a final linear.
131 | y = y.view(b, out_h, out_w, d_k, height, width)
132 | y = y.permute(0, 3, 1, 4, 2, 5).contiguous().view(b, d_k, h, w)
133 | attentions.append(y)
134 | output.append(y)
135 |
136 | output = torch.cat(output, 1)
137 | self_attention = self.output_linear(output)
138 |
139 | return self_attention
140 |
141 |
142 | class TransformerBlock(nn.Module):
143 | """
144 | Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
145 | """
146 |
147 | def __init__(self, patchsize, in_channel=256):
148 | super().__init__()
149 | self.attention = MultiHeadedAttention(patchsize, d_model=in_channel)
150 | self.feed_forward = FeedForward2D(
151 | in_channel=in_channel, out_channel=in_channel
152 | )
153 |
154 | def forward(self, rgb):
155 | self_attention = self.attention(rgb)
156 | output = rgb + self_attention
157 | output = output + self.feed_forward(output)
158 | return output
159 |
160 |
161 | class CMA_Block(nn.Module):
162 | def __init__(self, in_channel, hidden_channel, out_channel):
163 | super(CMA_Block, self).__init__()
164 |
165 | self.conv1 = nn.Conv2d(
166 | in_channel, hidden_channel, kernel_size=1, stride=1, padding=0
167 | )
168 | self.conv2 = nn.Conv2d(
169 | in_channel, hidden_channel, kernel_size=1, stride=1, padding=0
170 | )
171 | self.conv3 = nn.Conv2d(
172 | in_channel, hidden_channel, kernel_size=1, stride=1, padding=0
173 | )
174 |
175 | self.scale = hidden_channel ** -0.5
176 |
177 | self.conv4 = nn.Sequential(
178 | nn.Conv2d(
179 | hidden_channel, out_channel, kernel_size=1, stride=1, padding=0
180 | ),
181 | nn.BatchNorm2d(out_channel),
182 | nn.LeakyReLU(0.2, inplace=True),
183 | )
184 |
185 | def forward(self, rgb, freq):
186 | _, _, h, w = rgb.size()
187 |
188 | q = self.conv1(rgb)
189 | k = self.conv2(freq)
190 | v = self.conv3(freq)
191 |
192 | q = q.view(q.size(0), q.size(1), q.size(2) * q.size(3)).transpose(
193 | -2, -1
194 | )
195 | k = k.view(k.size(0), k.size(1), k.size(2) * k.size(3))
196 |
197 | attn = torch.matmul(q, k) * self.scale
198 | m = attn.softmax(dim=-1)
199 |
200 | v = v.view(v.size(0), v.size(1), v.size(2) * v.size(3)).transpose(
201 | -2, -1
202 | )
203 | z = torch.matmul(m, v)
204 | z = z.view(z.size(0), h, w, -1)
205 | z = z.permute(0, 3, 1, 2).contiguous()
206 |
207 | output = rgb + self.conv4(z)
208 |
209 | return output
210 |
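   | # CMA_Block in brief: cross-modality attention in which the queries come from the RGB
   | # features and the keys/values come from the frequency features; spatial positions are
   | # flattened, attention is computed as softmax(q k / sqrt(hidden_channel)) v, and the
   | # result is projected by conv4 and added back to the RGB stream as a residual.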
211 |
212 | class PatchTrans(BaseNetwork):
213 | def __init__(self, in_channel, in_size):
214 | super(PatchTrans, self).__init__()
215 | self.in_size = in_size
216 |
217 | patchsize = [
218 | (in_size, in_size),
219 | (in_size // 2, in_size // 2),
220 | (in_size // 4, in_size // 4),
221 | (in_size // 8, in_size // 8),
222 | ]
223 |
224 | self.t = TransformerBlock(patchsize, in_channel=in_channel)
225 |
226 | def forward(self, enc_feat):
227 | output = self.t(enc_feat)
228 | return output
229 |
230 |
231 | @MODEL_REGISTRY.register()
232 | class M2TR(BaseNetwork):
233 | def __init__(self, model_cfg):
234 | super(M2TR, self).__init__()
235 | img_size = model_cfg["IMG_SIZE"]
236 | backbone = model_cfg["BACKBONE"]
237 | texture_layer = model_cfg["TEXTURE_LAYER"]
238 | feature_layer = model_cfg["FEATURE_LAYER"]
239 | depth = model_cfg["DEPTH"]
240 | num_classes = model_cfg["NUM_CLASSES"]
241 | drop_ratio = model_cfg["DROP_RATIO"]
242 | has_decoder = model_cfg["HAS_DECODER"]
243 |
244 | freq_h = img_size // 4
245 | freq_w = freq_h // 2 + 1
246 |
247 | if "xception" in backbone:
248 |             self.model = Xception(model_cfg)  # Xception reads PRETRAINED / ESCAPE from the cfg
249 | elif backbone.split("-")[0] == "efficientnet":
250 | self.model = EfficientNet({'NAME': backbone, 'PRETRAINED': True})
251 |
252 | self.texture_layer = texture_layer
253 | self.feature_layer = feature_layer
254 |
255 | with torch.no_grad():
256 |             dummy = {"img": torch.zeros(1, 3, img_size, img_size)}
257 |             layers = self.model(dummy)
258 | texture_dim = layers[self.texture_layer].shape[1]
259 | feature_dim = layers[self.feature_layer].shape[1]
260 |
261 | self.layers = nn.ModuleList([])
262 | for _ in range(depth):
263 | self.layers.append(
264 | nn.ModuleList(
265 | [
266 | PatchTrans(in_channel=texture_dim, in_size=freq_h),
267 | FreqBlock(dim=texture_dim, h=freq_h, w=freq_w),
268 | CMA_Block(
269 | in_channel=texture_dim,
270 | hidden_channel=texture_dim,
271 | out_channel=texture_dim,
272 | ),
273 | ]
274 | )
275 | )
276 |
277 | self.classifier = Classifier2D(
278 | feature_dim, num_classes, drop_ratio, "sigmoid"
279 | )
280 |
281 | self.has_decoder = has_decoder
282 | if self.has_decoder:
283 | self.decoder = Localizer(texture_dim, 1)
284 |
285 | def forward(self, x):
286 | rgb = x["img"]
287 | B = rgb.size(0)
288 |
289 | layers = {}
290 | rgb = self.model.extract_textures(rgb, layers)
291 |
292 | for attn, filter, cma in self.layers:
293 | rgb = attn(rgb)
294 | freq = filter(rgb)
295 | rgb = cma(rgb, freq)
296 |
297 | features = self.model.extract_features(rgb, layers)
298 | features = F.adaptive_avg_pool2d(features, (1, 1))
299 | features = features.view(B, features.size(1))
300 |
301 | logits = self.classifier(features)
302 |
303 | if self.has_decoder:
304 | mask = self.decoder(rgb)
305 | mask = mask.squeeze(-1)
306 |
307 | else:
308 | mask = None
309 |
310 |         output = {"logits": logits, "mask": mask, "features": features}
311 | return output
312 |
313 |
314 | if __name__ == "__main__":
315 | from torchsummary import summary
316 |
317 | model = M2TR(num_classes=1, has_decoder=False)
318 | model.cuda()
319 | summary(model, input_size=(3, 320, 320), batch_size=12, device="cuda")
320 |
--------------------------------------------------------------------------------
/M2TR/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 |
4 | import torch
5 |
6 | import M2TR.utils.distributed as du
7 | import M2TR.utils.logging as logging
8 | from M2TR.utils.env import pathmgr
9 |
10 | logger = logging.get_logger(__name__)
11 |
12 |
13 | def make_checkpoint_dir(path_to_job):
14 | """
15 | Creates the checkpoint directory (if not present already).
16 | Args:
17 | path_to_job (string): the path to the folder of the current job.
18 | """
19 | checkpoint_dir = os.path.join(path_to_job, "checkpoints")
20 | # Create the checkpoint dir from the master process
21 | if du.is_master_proc() and not pathmgr.exists(checkpoint_dir):
22 | try:
23 | pathmgr.mkdirs(checkpoint_dir)
24 | except Exception:
25 | pass
26 | return checkpoint_dir
27 |
28 |
29 | def get_checkpoint_dir(path_to_job):
30 | """
31 | Get path for storing checkpoints.
32 | Args:
33 | path_to_job (string): the path to the folder of the current job.
34 | """
35 | return os.path.join(path_to_job, "checkpoints")
36 |
37 |
38 | def get_path_to_checkpoint(path_to_job, epoch, cfg):
39 | """
40 | Get the full path to a checkpoint file.
41 | Args:
42 | path_to_job (string): the path to the folder of the current job.
43 | epoch (int): the number of epoch for the checkpoint.
44 | """
45 | file_name = (
46 | cfg['MODEL']['MODEL_NAME']
47 | + '_'
48 | + cfg['DATASET']['DATASET_NAME']
49 | + '_'
50 | + 'epoch_{:05d}'
51 | + '.pyth'
52 | )
53 | file_name = file_name.format(epoch)
54 | return os.path.join(get_checkpoint_dir(path_to_job), file_name)
55 |
56 |
57 | def get_last_checkpoint(path_to_job):
58 | """
59 | Get the last checkpoint from the checkpointing folder.
60 | Args:
61 | path_to_job (string): the path to the folder of the current job.
62 | """
63 |
64 | d = get_checkpoint_dir(path_to_job)
65 | names = pathmgr.ls(d) if pathmgr.exists(d) else []
66 | names = [f for f in names if "checkpoint" in f]
67 | assert len(names), "No checkpoints found in '{}'.".format(d)
68 | # Sort the checkpoints by epoch.
69 | name = sorted(names)[-1]
70 | return os.path.join(d, name)
71 |
72 |
73 | def has_checkpoint(path_to_job):
74 | """
75 | Determines if the given directory contains a checkpoint.
76 | Args:
77 | path_to_job (string): the path to the folder of the current job.
78 | """
79 | d = get_checkpoint_dir(path_to_job)
80 | files = pathmgr.ls(d) if pathmgr.exists(d) else []
81 | return any("checkpoint" in f for f in files)
82 |
83 |
84 | def is_checkpoint_epoch(cfg, cur_epoch, multigrid_schedule=None):
85 | """
86 | Determine if a checkpoint should be saved on current epoch.
87 | Args:
88 | cfg (dict): configs to save.
89 | cur_epoch (int): current number of epoch of the model.
90 | multigrid_schedule (List): schedule for multigrid training.
91 | """
92 | if cur_epoch + 1 == cfg['TRAIN']['MAX_EPOCH']:
93 | return True
94 | if multigrid_schedule is not None: # TODO remove multigrid_schedule?
95 | prev_epoch = 0
96 | for s in multigrid_schedule:
97 | if cur_epoch < s[-1]:
98 | period = max(
99 | (s[-1] - prev_epoch) // cfg.MULTIGRID.EVAL_FREQ + 1, 1
100 | )
101 | return (s[-1] - 1 - cur_epoch) % period == 0
102 | prev_epoch = s[-1]
103 |
104 | return (cur_epoch + 1) % cfg['TRAIN']['CHECKPOINT_PERIOD'] == 0
105 |
106 |
107 | def save_checkpoint(model, optimizer, scheduler, epoch, cfg):
108 | """
109 | Save a checkpoint.
110 | Args:
111 | model (model): model to save the weight to the checkpoint.
112 | optimizer (optim): optimizer to save the historical state.
113 | epoch (int): current number of epoch of the model.
114 | cfg (dict): configs to save.
115 | """
116 | path_to_job = cfg['TRAIN']['CHECKPOINT_SAVE_PATH']
117 | # Save checkpoints only from the master process.
118 | if not du.is_master_proc(cfg['NUM_GPUS'] * cfg['NUM_SHARDS']):
119 | return
120 | # Ensure that the checkpoint dir exists.
121 | pathmgr.mkdirs(get_checkpoint_dir(path_to_job))
122 | # Omit the DDP wrapper in the multi-gpu setting.
123 | sd = (
124 | model.module.state_dict() if cfg['NUM_GPUS'] > 1 else model.state_dict()
125 | )
126 | normalized_sd = sub_to_normal_bn(sd)
127 |
128 | # Record the state.
129 | checkpoint = {
130 | "epoch": epoch,
131 | "model_state": normalized_sd,
132 | "optimizer_state": optimizer.state_dict(),
133 | "scheduler_state": scheduler.state_dict()
134 | if scheduler is not None
135 | else None, # TODO
136 | "cfg": cfg,
137 | }
138 |
139 | # Write the checkpoint.
140 | path_to_checkpoint = get_path_to_checkpoint(path_to_job, epoch + 1, cfg)
141 | with pathmgr.open(path_to_checkpoint, "wb") as f:
142 | torch.save(checkpoint, f)
143 | return path_to_checkpoint
144 |
145 |
146 | def load_checkpoint(
147 | path_to_checkpoint,
148 | model,
149 | data_parallel=True,
150 | optimizer=None,
151 | scheduler=None,
152 | epoch_reset=False,
153 | ):
154 | """
155 | Load the checkpoint from the given file.
156 | Args:
157 | path_to_checkpoint (string): path to the checkpoint to load.
158 | model (model): model to load the weights from the checkpoint.
159 | data_parallel (bool): if true, model is wrapped by
160 | torch.nn.parallel.DistributedDataParallel.
161 | optimizer (optim): optimizer to load the historical state.
162 |
163 | epoch_reset (bool): if True, reset #train iterations from the checkpoint.
164 |
165 | Returns:
166 | (int): the number of training epoch of the checkpoint.
167 | """
168 | assert pathmgr.exists(
169 | path_to_checkpoint
170 | ), "Checkpoint '{}' not found".format(path_to_checkpoint)
171 | logger.info("Loading network weights from {}.".format(path_to_checkpoint))
172 |
173 | # Account for the DDP wrapper in the multi-gpu setting.
174 | ms = model.module if data_parallel else model
175 |
176 | # Load the checkpoint on CPU to avoid GPU mem spike.
177 | with pathmgr.open(path_to_checkpoint, "rb") as f:
178 | checkpoint = torch.load(f, map_location="cpu")
179 |
180 | model_state_dict = (
181 | model.module.state_dict() if data_parallel else model.state_dict()
182 | )
183 | checkpoint["model_state"] = normal_to_sub_bn(
184 | checkpoint["model_state"], model_state_dict
185 | )
186 |
187 | pre_train_dict = checkpoint["model_state"]
188 | model_dict = ms.state_dict()
189 | # Match pre-trained weights that have same shape as current model.
190 | pre_train_dict_match = {
191 | k: v
192 | for k, v in pre_train_dict.items()
193 | if k in model_dict and v.size() == model_dict[k].size()
194 | }
195 |
196 | # Weights that do not have match from the pre-trained model.
197 | not_load_layers = [
198 | k for k in model_dict.keys() if k not in pre_train_dict_match.keys()
199 | ]
200 |
201 | # Log weights that are not loaded with the pre-trained weights.
202 | if not_load_layers:
203 | for k in not_load_layers:
204 | logger.info("Network weights {} not loaded.".format(k))
205 |
206 | # Load pre-trained weights.
207 | ms.load_state_dict(pre_train_dict_match, strict=False)
208 | epoch = -1
209 |
210 | # Load the optimizer state (commonly not done when fine-tuning)
211 | if "epoch" in checkpoint.keys() and not epoch_reset:
212 | epoch = checkpoint["epoch"]
213 | if optimizer:
214 | optimizer.load_state_dict(checkpoint["optimizer_state"])
215 | if scheduler:
216 | scheduler.load_state_dict(checkpoint["scheduler_state"])
217 |
218 | else:
219 | epoch = -1
220 |
221 | return epoch
222 |
223 |
224 | def sub_to_normal_bn(sd):
225 | """
226 |     Convert the Sub-BN parameters to normal BN parameters in a state dict.
227 | There are two copies of BN layers in a Sub-BN implementation: `bn.bn` and
228 | `bn.split_bn`. `bn.split_bn` is used during training and
229 | "compute_precise_bn". Before saving or evaluation, its stats are copied to
230 | `bn.bn`. We rename `bn.bn` to `bn` and store it to be consistent with normal
231 | BN layers.
232 | Args:
233 |         sd (OrderedDict): a dict of parameters which might contain Sub-BN
234 | parameters.
235 | Returns:
236 | new_sd (OrderedDict): a dict with Sub-BN parameters reshaped to
237 | normal parameters.
238 | """
239 | new_sd = copy.deepcopy(sd)
240 | modifications = [
241 | ("bn.bn.running_mean", "bn.running_mean"),
242 | ("bn.bn.running_var", "bn.running_var"),
243 | ("bn.split_bn.num_batches_tracked", "bn.num_batches_tracked"),
244 | ]
245 | to_remove = ["bn.bn.", ".split_bn."]
246 | for key in sd:
247 | for before, after in modifications:
248 | if key.endswith(before):
249 | new_key = key.split(before)[0] + after
250 | new_sd[new_key] = new_sd.pop(key)
251 |
252 | for rm in to_remove:
253 | if rm in key and key in new_sd:
254 | del new_sd[key]
255 |
256 | for key in new_sd:
257 | if key.endswith("bn.weight") or key.endswith("bn.bias"):
258 | if len(new_sd[key].size()) == 4:
259 | assert all(d == 1 for d in new_sd[key].size()[1:])
260 | new_sd[key] = new_sd[key][:, 0, 0, 0]
261 |
262 | return new_sd
263 |
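   | # Example of the renaming performed above (the 'block1.' prefix is illustrative):
   | #   'block1.bn.bn.running_mean'              -> 'block1.bn.running_mean'
   | #   'block1.bn.split_bn.num_batches_tracked' -> 'block1.bn.num_batches_tracked'
   | # Remaining keys containing 'bn.bn.' or '.split_bn.' are dropped from the result.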
264 |
265 | def c2_normal_to_sub_bn(key, model_keys):
266 | """
267 | Convert BN parameters to Sub-BN parameters if model contains Sub-BNs.
268 | Args:
269 |         key (string): source parameter key.
270 |         model_keys: keys of the target state dict.
271 |     Returns:
272 |         (string): the (possibly renamed) parameter key.
273 | """
274 | if "bn.running_" in key:
275 | if key in model_keys:
276 | return key
277 |
278 | new_key = key.replace("bn.running_", "bn.split_bn.running_")
279 | if new_key in model_keys:
280 | return new_key
281 | else:
282 | return key
283 |
284 |
285 | def normal_to_sub_bn(checkpoint_sd, model_sd):
286 | """
287 | Convert BN parameters to Sub-BN parameters if model contains Sub-BNs.
288 | Args:
289 | checkpoint_sd (OrderedDict): source dict of parameters.
290 | model_sd (OrderedDict): target dict of parameters.
291 | Returns:
292 | new_sd (OrderedDict): converted dict of parameters.
293 | """
294 | for key in model_sd:
295 | if key not in checkpoint_sd:
296 | if "bn.split_bn." in key:
297 | load_key = key.replace("bn.split_bn.", "bn.")
298 | bn_key = key.replace("bn.split_bn.", "bn.bn.")
299 | checkpoint_sd[key] = checkpoint_sd.pop(load_key)
300 | checkpoint_sd[bn_key] = checkpoint_sd[key]
301 |
302 | for key in model_sd:
303 | if key in checkpoint_sd:
304 | model_blob_shape = model_sd[key].shape
305 | c2_blob_shape = checkpoint_sd[key].shape
306 |
307 | if (
308 | len(model_blob_shape) == 1
309 | and len(c2_blob_shape) == 1
310 | and model_blob_shape[0] > c2_blob_shape[0]
311 | and model_blob_shape[0] % c2_blob_shape[0] == 0
312 | ):
313 | before_shape = checkpoint_sd[key].shape
314 | checkpoint_sd[key] = torch.cat(
315 | [checkpoint_sd[key]]
316 | * (model_blob_shape[0] // c2_blob_shape[0])
317 | )
318 | logger.info(
319 | "{} {} -> {}".format(
320 | key, before_shape, checkpoint_sd[key].shape
321 | )
322 | )
323 | return checkpoint_sd
324 |
325 |
326 | def load_test_checkpoint(cfg, model):
327 | """
328 | Loading checkpoint logic for testing.
329 | """
330 | # Load a checkpoint to test if applicable.
331 | if cfg['TEST']['CHECKPOINT_TEST_PATH'] != "":
332 | load_checkpoint(
333 | cfg['TEST']['CHECKPOINT_TEST_PATH'],
334 | model,
335 | cfg['NUM_GPUS'] > 1,
336 | None,
337 | None,
338 | )
339 |
340 | else:
341 | logger.info(
342 |             "No checkpoint specified. Using random initialization; this should only be used for debugging."
343 | )
344 |
345 |
346 | def load_train_checkpoint(model, optimizer, scheduler, cfg):
347 | """
348 | Loading checkpoint logic for training.
349 | """
350 | if cfg['TRAIN']['CHECKPOINT_LOAD_PATH'] != "":
351 | print('Load from given checkpoint file.')
352 | logger.info("Load from given checkpoint file.")
353 | checkpoint_epoch = load_checkpoint(
354 | cfg['TRAIN']['CHECKPOINT_LOAD_PATH'],
355 | model,
356 | cfg['NUM_GPUS'] > 1,
357 | optimizer,
358 | scheduler,
359 | epoch_reset=cfg['TRAIN']['CHECKPOINT_EPOCH_RESET'],
360 | )
361 |
362 | start_epoch = checkpoint_epoch + 1
363 |
364 | else:
365 | start_epoch = 0
366 |
367 | return start_epoch
368 |
--------------------------------------------------------------------------------
/M2TR/utils/loss.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | 
8 | from M2TR.utils.registries import LOSS_REGISTRY
9 | 
10 | 
11 |
12 |
13 | class BaseWeightedLoss(nn.Module, metaclass=ABCMeta):
14 | """Base class for loss.
15 |     All subclasses should override the ``_forward()`` method, which returns the
16 |     raw loss before the loss weight is applied.
17 | Args:
18 | loss_weight (float): Factor scalar multiplied on the loss.
19 | Default: 1.0.
20 | """
21 |
22 | def __init__(self, loss_weight=1.0):
23 | super().__init__()
24 | self.loss_weight = loss_weight
25 |
26 | @abstractmethod
27 | def _forward(self, *args, **kwargs):
28 | pass
29 |
30 | def forward(self, *args, **kwargs):
31 | """Defines the computation performed at every call.
32 | Args:
33 | *args: The positional arguments for the corresponding
34 | loss.
35 | **kwargs: The keyword arguments for the corresponding
36 | loss.
37 | Returns:
38 | torch.Tensor: The calculated loss.
39 | """
40 | ret = self._forward(*args, **kwargs)
41 | if isinstance(ret, dict):
42 | for k in ret:
43 | if 'loss' in k:
44 | ret[k] *= self.loss_weight
45 | else:
46 | ret *= self.loss_weight
47 | return ret
48 |
49 |
50 | @LOSS_REGISTRY.register()
51 | class CrossEntropyLoss(BaseWeightedLoss):
52 | """Cross Entropy Loss.
53 | Support two kinds of labels and their corresponding loss type. It's worth
54 | mentioning that loss type will be detected by the shape of ``cls_score``
55 | and ``label``.
56 | 1) Hard label: This label is an integer array and all of the elements are
57 | in the range [0, num_classes - 1]. This label's shape should be
58 | ``cls_score``'s shape with the `num_classes` dimension removed.
59 |     2) Soft label (probability distribution over classes): This label is a
60 | probability distribution and all of the elements are in the range
61 | [0, 1]. This label's shape must be the same as ``cls_score``. For now,
62 | only 2-dim soft label is supported.
63 | Args:
64 | loss_weight (float): Factor scalar multiplied on the loss.
65 | Default: 1.0.
66 | class_weight (list[float] | None): Loss weight for each class. If set
67 | as None, use the same weight 1 for all classes. Only applies
68 | to CrossEntropyLoss and BCELossWithLogits (should not be set when
69 | using other losses). Default: None.
70 | """
71 |
72 | def __init__(self, loss_cfg):
73 | super().__init__(loss_weight=loss_cfg['LOSS_WEIGHT'])
74 | self.class_weight = (
75 | torch.Tensor(loss_cfg['CLASS_WEIGHT'])
76 | if 'CLASS_WEIGHT' in loss_cfg
77 | else None
78 | )
79 |
80 | def _forward(self, outputs, samples, **kwargs):
81 | """Forward function.
82 | Args:
83 |             outputs (dict): model outputs; ``outputs['logits']`` is the class score.
84 | samples (dict): The ground truth labels.
85 | kwargs: Any keyword argument to be used to calculate
86 | CrossEntropy loss.
87 | Returns:
88 | torch.Tensor: The returned CrossEntropy loss.
89 | """
90 | cls_score = outputs['logits']
91 | label = samples['bin_label_onehot']
92 | if cls_score.size() == label.size():
93 | # calculate loss for soft labels
94 |
95 | assert cls_score.dim() == 2, 'Only support 2-dim soft label'
96 | assert len(kwargs) == 0, (
97 | 'For now, no extra args are supported for soft label, '
98 |                 f'but got {kwargs}'
99 | )
100 |
101 | lsm = F.log_softmax(cls_score, 1)
102 | if self.class_weight is not None:
103 | lsm = lsm * self.class_weight.unsqueeze(0).to(cls_score.device)
104 | loss_cls = -(label * lsm).sum(1)
105 |
106 | # default reduction 'mean'
107 | if self.class_weight is not None:
108 | # Use weighted average as pytorch CrossEntropyLoss does.
109 | # For more information, please visit https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html # noqa
110 | loss_cls = loss_cls.sum() / torch.sum(
111 | self.class_weight.unsqueeze(0).to(cls_score.device) * label
112 | )
113 | else:
114 | loss_cls = loss_cls.mean()
115 | else:
116 | # calculate loss for hard label
117 |
118 | if self.class_weight is not None:
119 | assert (
120 | 'weight' not in kwargs
121 | ), "The key 'weight' already exists."
122 | kwargs['weight'] = self.class_weight.to(cls_score.device)
123 | loss_cls = F.cross_entropy(cls_score, label, **kwargs)
124 |
125 | return loss_cls
126 |
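   | # The branch taken above depends on the label shape: if samples['bin_label_onehot'] has
   | # the same shape as the logits (e.g. a one-hot / soft (B, 2) tensor), the soft-label
   | # branch (log-softmax plus optional class weights) is used; an integer label of shape
   | # (B,) falls through to F.cross_entropy instead.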
127 |
128 | @LOSS_REGISTRY.register()
129 | class BCELossWithLogits(BaseWeightedLoss):
130 | """Binary Cross Entropy Loss with logits.
131 | Args:
132 | loss_weight (float): Factor scalar multiplied on the loss.
133 | Default: 1.0.
134 | class_weight (list[float] | None): Loss weight for each class. If set
135 | as None, use the same weight 1 for all classes. Only applies
136 | to CrossEntropyLoss and BCELossWithLogits (should not be set when
137 | using other losses). Default: None.
138 | """
139 |
140 | def __init__(self, loss_cfg):
141 | super().__init__(loss_weight=loss_cfg['LOSS_WEIGHT'])
142 | self.class_weight = (
143 | torch.Tensor(loss_cfg['CLASS_WEIGHT'])
144 | if 'CLASS_WEIGHT' in loss_cfg
145 | else None
146 | )
147 |
148 | def _forward(self, outputs, samples, **kwargs):
149 | """Forward function.
150 | Args:
151 |             outputs (dict): model outputs; ``outputs['logits']`` is the class score.
152 | samples (dict): The ground truth labels.
153 | kwargs: Any keyword argument to be used to calculate
154 | bce loss with logits.
155 | Returns:
156 | torch.Tensor: The returned bce loss with logits.
157 | """
158 | cls_score = outputs['logits']
159 | label = samples['bin_label_onehot']
160 | if self.class_weight is not None:
161 | assert (
162 | 'weight' not in kwargs
163 | ), "The key 'weight' already exists."
164 | kwargs['weight'] = self.class_weight.to(cls_score.device)
165 | loss_cls = F.binary_cross_entropy_with_logits(
166 | cls_score, label, **kwargs
167 | )
168 | return loss_cls
169 |
170 |
171 | @LOSS_REGISTRY.register()
172 | class MSELoss(BaseWeightedLoss):
173 | """MSE Loss
174 | Args:
175 | loss_weight (float): Factor scalar multiplied on the loss.
176 | Default: 1.0.
177 |
178 | """
179 |
180 | def __init__(self, loss_cfg):
181 | super().__init__(loss_weight=loss_cfg['LOSS_WEIGHT'])
182 | self.mse = nn.MSELoss()
183 |
184 | def _forward(self, pred_mask, gt_mask, **kwargs): # TODO samples
185 | loss = self.mse(pred_mask, gt_mask)
186 | return loss
187 |
188 |
189 | @LOSS_REGISTRY.register()
190 | class ICCLoss(BaseWeightedLoss):
191 | """Contrastive Loss
192 | Args:
193 | loss_weight (float): Factor scalar multiplied on the loss.
194 | Default: 1.0.
195 | """
196 |
197 | def __init__(self, loss_cfg):
198 | super().__init__(loss_weight=loss_cfg['LOSS_WEIGHT'])
199 |
200 | def _forward(self, feature, label, **kwargs): # TODO samples
201 | # size of feature is (b, 1024)
202 | # size of label is (b)
203 | C = feature.size(1)
204 | label = label.unsqueeze(1)
205 | label = label.repeat(1, C)
206 | # print(label.device)
207 | label = label.type(torch.BoolTensor).cuda()
208 |
209 | res_label = torch.zeros(label.size(), dtype=label.dtype)
210 | res_label = torch.where(label == 1, 0, 1)
211 | res_label = res_label.type(torch.BoolTensor).cuda()
212 |
213 | # print(label, res_label)
214 | pos_feature = torch.masked_select(feature, label)
215 | neg_feature = torch.masked_select(feature, res_label)
216 |
217 | # print('pos_fea: ', pos_feature.device)
218 | # print('nge_fea: ', neg_feature.device)
219 | pos_feature = pos_feature.view(-1, C)
220 | neg_feature = neg_feature.view(-1, C)
221 |
222 | pos_center = torch.mean(pos_feature, dim=0, keepdim=True)
223 |
224 | # dis_pos = torch.sum((pos_feature - pos_center)**2) / torch.norm(pos_feature, p=1)
225 | # dis_neg = torch.sum((neg_feature - pos_center)**2) / torch.norm(neg_feature, p=1)
226 | num_p = pos_feature.size(0)
227 | num_n = neg_feature.size(0)
228 | pos_center1 = pos_center.repeat(num_p, 1)
229 | pos_center2 = pos_center.repeat(num_n, 1)
230 | dis_pos = F.cosine_similarity(pos_feature, pos_center1, eps=1e-6)
231 | dis_pos = torch.mean(dis_pos, dim=0)
232 | dis_neg = F.cosine_similarity(neg_feature, pos_center2, eps=1e-6)
233 | dis_neg = torch.mean(dis_neg, dim=0)
234 |
235 | loss = dis_pos - dis_neg
236 |
237 | return loss
238 |
239 |
240 | @LOSS_REGISTRY.register()
241 | class FocalLoss(BaseWeightedLoss):
242 | def __init__(self, loss_cfg):
243 | super().__init__(loss_weight=loss_cfg['LOSS_WEIGHT'])
244 | 
245 | self.alpha = loss_cfg['ALPHA'] if 'ALPHA' in loss_cfg.keys() else 1
246 | self.gamma = loss_cfg['GAMMA'] if 'GAMMA' in loss_cfg.keys() else 2
247 | self.logits = (
248 | loss_cfg['LOGITS'] if 'LOGITS' in loss_cfg.keys() else True
249 | )
250 | self.reduce = (
251 | loss_cfg['REDUCE'] if 'REDUCE' in loss_cfg.keys() else True
252 | )
253 |
254 | def _forward(self, outputs, samples, **kwargs):
255 | cls_score = outputs['logits']
256 | label = samples['bin_label_onehot']
257 | if self.logits: # TODO
258 | BCE_loss = F.binary_cross_entropy_with_logits(
259 | cls_score, label, reduce=False
260 | )
261 | else:
262 | BCE_loss = F.binary_cross_entropy(cls_score, label, reduce=False)
263 | pt = torch.exp(-BCE_loss)
264 | F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
265 |
266 | if self.reduce:
267 | return torch.mean(F_loss)
268 | else:
269 | return F_loss
270 |
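   | # FocalLoss above follows the usual formulation FL = alpha * (1 - p_t)^gamma * BCE,
   | # where p_t is recovered from the per-element binary cross-entropy as p_t = exp(-BCE);
   | # well-classified examples (p_t close to 1) are therefore down-weighted.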
271 |
272 | @LOSS_REGISTRY.register()
273 | class Auxiliary_Loss_v2(BaseWeightedLoss):
274 | def __init__(self, loss_cfg):
275 | super().__init__(loss_weight=loss_cfg['AUX_LOSS_WEIGHT'])
276 | M = loss_cfg['M'] if 'M' in loss_cfg.keys() else 1
277 | N = loss_cfg['N'] if 'N' in loss_cfg.keys() else 1
278 | C = loss_cfg['C'] if 'C' in loss_cfg.keys() else 1
279 | alpha = loss_cfg['ALPHA'] if 'ALPHA' in loss_cfg.keys() else 0.05
280 | margin = loss_cfg['MARGIN'] if 'MARGIN' in loss_cfg.keys() else 1
281 | inner_margin = (
282 | loss_cfg['INNER_MARGIN']
283 | if 'INNER_MARGIN' in loss_cfg.keys()
284 | else [0.1, 5]
285 | )
286 |
287 | self.register_buffer('feature_centers', torch.zeros(M, N))
288 | self.register_buffer('alpha', torch.tensor(alpha))
289 | self.num_classes = C
290 | self.margin = margin
291 | from M2TR.models.matdd import AttentionPooling
292 |
293 | self.atp = AttentionPooling()
294 | self.register_buffer('inner_margin', torch.Tensor(inner_margin))
295 |
296 | def _forward(self, feature_map_d, attentions, y):
297 | B, N, H, W = feature_map_d.size()
298 | B, M, AH, AW = attentions.size()
299 | if AH != H or AW != W:
300 | attentions = F.interpolate(
301 | attentions, (H, W), mode='bilinear', align_corners=True
302 | )
303 | feature_matrix = self.atp(feature_map_d, attentions)
304 | feature_centers = self.feature_centers
305 | center_momentum = feature_matrix - feature_centers
306 | real_mask = (y == 0).view(-1, 1, 1)
307 | fcts = (
308 | self.alpha * torch.mean(center_momentum * real_mask, dim=0)
309 | + feature_centers
310 | )
311 | fctsd = fcts.detach()
312 | if self.training:
313 | with torch.no_grad():
314 | if torch.distributed.is_initialized():
315 | torch.distributed.all_reduce(
316 | fctsd, torch.distributed.ReduceOp.SUM
317 | )
318 | fctsd /= torch.distributed.get_world_size()
319 | self.feature_centers = fctsd
320 | inner_margin = self.inner_margin[y]
321 | intra_class_loss = F.relu(
322 | torch.norm(feature_matrix - fcts, dim=[1, 2])
323 | * torch.sign(inner_margin)
324 | - inner_margin
325 | )
326 | intra_class_loss = torch.mean(intra_class_loss)
327 | inter_class_loss = 0
328 | for j in range(M):
329 | for k in range(j + 1, M):
330 | inter_class_loss += F.relu(
331 | self.margin - torch.dist(fcts[j], fcts[k]), inplace=False
332 | )
333 | inter_class_loss = inter_class_loss / M / self.alpha
334 | # fmd=attentions.flatten(2)
335 | # diverse_loss=torch.mean(F.relu(F.cosine_similarity(fmd.unsqueeze(1),fmd.unsqueeze(2),dim=3)-self.margin,inplace=True)*(1-torch.eye(M,device=attentions.device)))
336 | return intra_class_loss + inter_class_loss, feature_matrix
337 |
338 |
339 | class Auxiliary_Loss_v1(nn.Module):
340 | def __init__(self, loss_cfg):
341 | super().__init__()  # nn.Module.__init__ takes no loss_weight argument
342 | M = loss_cfg['M'] if 'M' in loss_cfg.keys() else 1
343 | N = loss_cfg['N'] if 'N' in loss_cfg.keys() else 1
344 | C = loss_cfg['C'] if 'C' in loss_cfg.keys() else 1
345 | alpha = loss_cfg['ALPHA'] if 'ALPHA' in loss_cfg.keys() else 0.05
346 | margin = loss_cfg['MARGIN'] if 'MARGIN' in loss_cfg.keys() else 1
347 | inner_margin = (
348 | loss_cfg['inner_margin']
349 | if 'inner_margin' in loss_cfg.keys()
350 | else [0.01, 0.02]
351 | )
352 | self.register_buffer('feature_centers', torch.zeros(M, N))
353 | self.register_buffer('alpha', torch.tensor(alpha))
354 | self.num_classes = C
355 | self.margin = margin
356 | from M2TR.models.matdd import AttentionPooling
357 |
358 | self.atp = AttentionPooling()
359 | self.register_buffer('inner_margin', torch.Tensor(inner_margin))
360 |
361 | def forward(self, feature_map_d, attentions, y):
362 | B, N, H, W = feature_map_d.size()
363 | B, M, AH, AW = attentions.size()
364 | if AH != H or AW != W:
365 | attentions = F.interpolate(
366 | attentions, (H, W), mode='bilinear', align_corners=True
367 | )
368 | feature_matrix = self.atp(feature_map_d, attentions)
369 | feature_centers = self.feature_centers.detach()
370 | center_momentum = feature_matrix - feature_centers
371 | fcts = self.alpha * torch.mean(center_momentum, dim=0) + feature_centers
372 | fctsd = fcts.detach()
373 | if self.training:
374 | with torch.no_grad():
375 | if torch.distributed.is_initialized():
376 | torch.distributed.all_reduce(
377 | fctsd, torch.distributed.ReduceOp.SUM
378 | )
379 | fctsd /= torch.distributed.get_world_size()
380 | self.feature_centers = fctsd
381 | inner_margin = torch.gather(
382 | self.inner_margin.repeat(B, 1), 1, y.unsqueeze(1)
383 | )
384 | intra_class_loss = F.relu(
385 | torch.norm(feature_matrix - fcts, dim=-1) - inner_margin
386 | )
387 | intra_class_loss = torch.mean(intra_class_loss)
388 | inter_class_loss = 0
389 | for j in range(M):
390 | for k in range(j + 1, M):
391 | inter_class_loss += F.relu(
392 | self.margin - torch.dist(fcts[j], fcts[k]), inplace=False
393 | )
394 | inter_class_loss = inter_class_loss / M / self.alpha
395 | # fmd=attentions.flatten(2)
396 | # inter_class_loss=torch.mean(F.relu(F.cosine_similarity(fmd.unsqueeze(1),fmd.unsqueeze(2),dim=3)-self.margin,inplace=True)*(1-torch.eye(M,device=attentions.device)))
397 | return intra_class_loss + inter_class_loss, feature_matrix
398 |
399 |
400 | @LOSS_REGISTRY.register()
401 | class Multi_attentional_Deepfake_Detection_loss(nn.Module):
402 | def __init__(self, loss_cfg) -> None:
403 | super().__init__()
404 | self.loss_cfg = loss_cfg
405 |
406 | def forward(self, loss_pack, label):
407 | if 'loss' in loss_pack:
408 | return loss_pack['loss']
409 | loss = (
410 | self.loss_cfg['ENSEMBLE_LOSS_WEIGHT'] * loss_pack['ensemble_loss']
411 | + self.loss_cfg['AUX_LOSS_WEIGHT'] * loss_pack['aux_loss']
412 | )
413 | if self.loss_cfg['AGDA_LOSS_WEIGHT'] != 0:
414 | loss += (
415 | self.loss_cfg['AGDA_LOSS_WEIGHT']
416 | * loss_pack['AGDA_ensemble_loss']
417 | + self.loss_cfg['MATCH_LOSS_WEIGHT'] * loss_pack['match_loss']
418 | )
419 | return loss
420 |
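421 |
422 | if __name__ == '__main__':
423 |     # Hedged usage sketch, not part of the original module: it only shows how a
424 |     # registered loss consumes the (outputs, samples) dicts used throughout this
425 |     # file, assuming BaseWeightedLoss.forward dispatches to _forward and scales
426 |     # the result by loss_weight.
427 |     dummy_cfg = {'LOSS_WEIGHT': 1.0, 'ALPHA': 1, 'GAMMA': 2}
428 |     criterion = FocalLoss(dummy_cfg)
429 |     outputs = {'logits': torch.randn(4, 2)}
430 |     samples = {'bin_label_onehot': torch.randint(0, 2, (4, 2)).float()}
431 |     print(criterion(outputs, samples))  # scalar focal loss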
--------------------------------------------------------------------------------
/M2TR/models/efficientnet.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import math
3 | import re
4 | from functools import partial
5 |
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 | from torch.utils import model_zoo
10 |
11 | from M2TR.utils.registries import MODEL_REGISTRY
12 |
13 | '''
14 | MODEL:
15 | MODEL_NAME: efficientnet
16 | NAME: efficientnet-b4
17 | PRETRAINED: True
18 | '''
19 |
20 |
21 | # Parameters for the entire model (stem, all blocks, and head)
22 | GlobalParams = collections.namedtuple(
23 | 'GlobalParams',
24 | [
25 | 'batch_norm_momentum',
26 | 'batch_norm_epsilon',
27 | 'dropout_rate',
28 | 'num_classes',
29 | 'width_coefficient',
30 | 'depth_coefficient',
31 | 'depth_divisor',
32 | 'min_depth',
33 | 'drop_connect_rate',
34 | 'image_size',
35 | ],
36 | )
37 |
38 | # Parameters for an individual model block
39 | BlockArgs = collections.namedtuple(
40 | 'BlockArgs',
41 | [
42 | 'kernel_size',
43 | 'num_repeat',
44 | 'input_filters',
45 | 'output_filters',
46 | 'expand_ratio',
47 | 'id_skip',
48 | 'stride',
49 | 'se_ratio',
50 | ],
51 | )
52 |
53 | # Change namedtuple defaults
54 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
55 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
56 |
57 |
58 | efficientnet_params = {
59 | # Coefficients: width,depth,res,dropout
60 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
61 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
62 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
63 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
64 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
65 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
66 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
67 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
68 | 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
69 | 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
70 | }
71 |
72 |
73 | blocks_args_str = [
74 | 'r1_k3_s11_e1_i32_o16_se0.25',
75 | 'r2_k3_s22_e6_i16_o24_se0.25',
76 | 'r2_k5_s22_e6_i24_o40_se0.25',
77 | 'r3_k3_s22_e6_i40_o80_se0.25',
78 | 'r3_k5_s11_e6_i80_o112_se0.25',
79 | 'r4_k5_s22_e6_i112_o192_se0.25',
80 | 'r1_k3_s11_e6_i192_o320_se0.25',
81 | ]
82 |
83 |
84 | url_map = {
85 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
86 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
87 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
88 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
89 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
90 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
91 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
92 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
93 | }
94 |
95 |
96 | url_map_advprop = {
97 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
98 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
99 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
100 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
101 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
102 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
103 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
104 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
105 | 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
106 | }
107 |
108 |
109 | class BlockDecoder(object):
110 | @staticmethod
111 | def _decode_block_string(block_string):
112 | """Gets a block through a string notation of arguments."""
113 | assert isinstance(block_string, str)
114 |
115 | ops = block_string.split('_')
116 | options = {}
117 | for op in ops:
118 | splits = re.split(r'(\d.*)', op)
119 | if len(splits) >= 2:
120 | key, value = splits[:2]
121 | options[key] = value
122 |
123 | # Check stride
124 | assert ('s' in options and len(options['s']) == 1) or (
125 | len(options['s']) == 2 and options['s'][0] == options['s'][1]
126 | )
127 |
128 | return BlockArgs(
129 | kernel_size=int(options['k']),
130 | num_repeat=int(options['r']),
131 | input_filters=int(options['i']),
132 | output_filters=int(options['o']),
133 | expand_ratio=int(options['e']),
134 | id_skip=('noskip' not in block_string),
135 | se_ratio=float(options['se']) if 'se' in options else None,
136 | stride=[int(options['s'][0])],
137 | )
138 |
139 | @staticmethod
140 | def decode(string_list):
141 | assert isinstance(string_list, list)
142 | blocks_args = []
143 | for block_string in string_list:
144 | blocks_args.append(BlockDecoder._decode_block_string(block_string))
145 | return blocks_args
146 |
147 |
148 | class SwishImplementation(torch.autograd.Function):
149 | @staticmethod
150 | def forward(ctx, i):
151 | result = i * torch.sigmoid(i)
152 | ctx.save_for_backward(i)
153 | return result
154 |
155 | @staticmethod
156 | def backward(ctx, grad_output):
157 | i = ctx.saved_tensors[0]
158 | sigmoid_i = torch.sigmoid(i)
159 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
160 |
161 |
162 | class MemoryEfficientSwish(nn.Module):
163 | def forward(self, x):
164 | return SwishImplementation.apply(x)
165 |
166 |
167 | class Swish(nn.Module):
168 | def forward(self, x):
169 | return x * torch.sigmoid(x)
170 |
171 |
172 | def round_filters(filters, global_params):
173 | """Calculate and round number of filters based on depth multiplier."""
174 | multiplier = global_params.width_coefficient
175 | if not multiplier:
176 | return filters
177 | divisor = global_params.depth_divisor
178 | min_depth = global_params.min_depth
179 | filters *= multiplier
180 | min_depth = min_depth or divisor
181 | new_filters = max(
182 | min_depth, int(filters + divisor / 2) // divisor * divisor
183 | )
184 | if new_filters < 0.9 * filters: # prevent rounding by more than 10%
185 | new_filters += divisor
186 | return int(new_filters)
187 |
188 |
189 | def round_repeats(repeats, global_params):
190 | """Round number of filters based on depth multiplier."""
191 | multiplier = global_params.depth_coefficient
192 | if not multiplier:
193 | return repeats
194 | return int(math.ceil(multiplier * repeats))
195 |
196 |
197 | def drop_connect(inputs, p, training):
198 | """Drop connect."""
199 | if not training:
200 | return inputs
201 | batch_size = inputs.shape[0]
202 | keep_prob = 1 - p
203 | random_tensor = keep_prob
204 | random_tensor += torch.rand(
205 | [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
206 | )
207 | binary_tensor = torch.floor(random_tensor)
208 | output = inputs / keep_prob * binary_tensor
209 | return output
210 |
211 |
212 | class Identity(nn.Module):
213 | def __init__(
214 | self,
215 | ):
216 | super(Identity, self).__init__()
217 |
218 | def forward(self, input):
219 | return input
220 |
221 |
222 | class MBConvBlock(nn.Module):
223 | """
224 | Mobile Inverted Residual Bottleneck Block
225 | Args:
226 | block_args (namedtuple): BlockArgs, see above
227 | global_params (namedtuple): GlobalParams, see above
228 | Attributes:
229 | has_se (bool): Whether the block contains a Squeeze and Excitation layer.
230 | """
231 |
232 | def __init__(self, block_args, global_params):
233 | super().__init__()
234 | self._block_args = block_args
235 | self._bn_mom = 1 - global_params.batch_norm_momentum
236 | self._bn_eps = global_params.batch_norm_epsilon
237 | self.has_se = (self._block_args.se_ratio is not None) and (
238 | 0 < self._block_args.se_ratio <= 1
239 | )
240 | self.id_skip = block_args.id_skip # skip connection and drop connect
241 |
242 | # Get static or dynamic convolution depending on image size
243 | Conv2d = partial(
244 | Conv2dStaticSamePadding, image_size=global_params.image_size
245 | )
246 |
247 | # Expansion phase
248 | inp = self._block_args.input_filters # number of input channels
249 | oup = (
250 | self._block_args.input_filters * self._block_args.expand_ratio
251 | ) # number of output channels
252 | if self._block_args.expand_ratio != 1:
253 | self._expand_conv = Conv2d(
254 | in_channels=inp, out_channels=oup, kernel_size=1, bias=False
255 | )
256 | self._bn0 = nn.BatchNorm2d(
257 | num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
258 | )
259 |
260 | # Depthwise convolution phase
261 | k = self._block_args.kernel_size
262 | s = self._block_args.stride
263 | self._depthwise_conv = Conv2d(
264 | in_channels=oup,
265 | out_channels=oup,
266 | groups=oup, # groups makes it depthwise
267 | kernel_size=k,
268 | stride=s,
269 | bias=False,
270 | )
271 | self._bn1 = nn.BatchNorm2d(
272 | num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
273 | )
274 |
275 | # Squeeze and Excitation layer, if desired
276 | if self.has_se:
277 | num_squeezed_channels = max(
278 | 1,
279 | int(self._block_args.input_filters * self._block_args.se_ratio),
280 | )
281 | self._se_reduce = Conv2d(
282 | in_channels=oup,
283 | out_channels=num_squeezed_channels,
284 | kernel_size=1,
285 | )
286 | self._se_expand = Conv2d(
287 | in_channels=num_squeezed_channels,
288 | out_channels=oup,
289 | kernel_size=1,
290 | )
291 |
292 | # Output phase
293 | final_oup = self._block_args.output_filters
294 | self._project_conv = Conv2d(
295 | in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
296 | )
297 | self._bn2 = nn.BatchNorm2d(
298 | num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
299 | )
300 | self._swish = MemoryEfficientSwish()
301 |
302 | def forward(self, inputs, drop_connect_rate=None):
303 | # Expansion and Depthwise Convolution
304 | x = inputs
305 | if self._block_args.expand_ratio != 1:
306 | x = self._swish(self._bn0(self._expand_conv(inputs)))
307 | x = self._swish(self._bn1(self._depthwise_conv(x)))
308 |
309 | # Squeeze and Excitation
310 | if self.has_se:
311 | x_squeezed = F.adaptive_avg_pool2d(x, 1)
312 | x_squeezed = self._se_expand(
313 | self._swish(self._se_reduce(x_squeezed))
314 | )
315 | x = torch.sigmoid(x_squeezed) * x
316 |
317 | x = self._bn2(self._project_conv(x))
318 |
319 | # Skip connection and drop connect
320 | input_filters, output_filters = (
321 | self._block_args.input_filters,
322 | self._block_args.output_filters,
323 | )
324 | if (
325 | self.id_skip
326 | and self._block_args.stride == 1
327 | and input_filters == output_filters
328 | ):
329 | if drop_connect_rate:
330 | x = drop_connect(x, p=drop_connect_rate, training=self.training)
331 | x = x + inputs # skip connection
332 | return x
333 |
334 | def set_swish(self, memory_efficient=True):
335 | """Sets swish function as memory efficient (for training) or standard (for export)"""
336 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
337 |
338 |
339 | class Conv2dStaticSamePadding(nn.Conv2d):
340 | """2D Convolutions like TensorFlow, for a fixed image size"""
341 |
342 | def __init__(
343 | self, in_channels, out_channels, kernel_size, image_size=None, **kwargs
344 | ):
345 | super().__init__(in_channels, out_channels, kernel_size, **kwargs)
346 | self.stride = (
347 | self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
348 | )
349 |
350 | # Calculate padding based on image size and save it
351 | assert image_size is not None
352 | ih, iw = (
353 | image_size if type(image_size) == list else [image_size, image_size]
354 | )
355 | kh, kw = self.weight.size()[-2:]
356 | sh, sw = self.stride
357 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
358 | pad_h = max(
359 | (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0
360 | )
361 | pad_w = max(
362 | (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0
363 | )
364 | if pad_h > 0 or pad_w > 0:
365 | self.static_padding = nn.ZeroPad2d(
366 | (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
367 | )
368 | else:
369 | self.static_padding = Identity()
370 |
371 | def forward(self, x):
372 | x = self.static_padding(x)
373 | x = F.conv2d(
374 | x,
375 | self.weight,
376 | self.bias,
377 | self.stride,
378 | self.padding,
379 | self.dilation,
380 | self.groups,
381 | )
382 | return x
383 |
384 |
385 | @MODEL_REGISTRY.register()
386 | class EfficientNet(nn.Module):
387 | def __init__(self, model_cfg):
388 | super().__init__()
389 | model_name = model_cfg['NAME']
390 | self.check_model_name_is_valid(model_name)
391 | blocks_args = BlockDecoder.decode(blocks_args_str)
392 | w, d, s, p = efficientnet_params[model_name]
393 | # note: all models have drop connect rate = 0.2
394 | global_params = GlobalParams(
395 | batch_norm_momentum=0.99,
396 | batch_norm_epsilon=1e-3,
397 | dropout_rate=p,
398 | drop_connect_rate=0.2,
399 | num_classes=2,
400 | width_coefficient=w,
401 | depth_coefficient=d,
402 | depth_divisor=8,
403 | min_depth=None,
404 | image_size=s,
405 | )
406 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
407 | assert len(blocks_args) > 0, 'block args must be greater than 0'
408 | self.escape = ''
409 | self._global_params = global_params
410 | self._blocks_args = blocks_args
411 | Conv2d = partial(
412 | Conv2dStaticSamePadding, image_size=global_params.image_size
413 | )
414 |
415 | # Batch norm parameters
416 | bn_mom = 1 - self._global_params.batch_norm_momentum
417 | bn_eps = self._global_params.batch_norm_epsilon
418 |
419 | # Stem
420 | in_channels = 3 # rgb
421 | out_channels = round_filters(
422 | 32, self._global_params
423 | ) # number of output channels
424 | self._conv_stem = Conv2d(
425 | in_channels, out_channels, kernel_size=3, stride=2, bias=False
426 | )
427 | self._bn0 = nn.BatchNorm2d(
428 | num_features=out_channels, momentum=bn_mom, eps=bn_eps
429 | )
430 |
431 | # Build blocks
432 | self._blocks = nn.ModuleList([])
433 | self.stage_map = []
434 | stage_count = 0
435 | for block_args in self._blocks_args:
436 |
437 | # Update block input and output filters based on depth multiplier.
438 | block_args = block_args._replace(
439 | input_filters=round_filters(
440 | block_args.input_filters, self._global_params
441 | ),
442 | output_filters=round_filters(
443 | block_args.output_filters, self._global_params
444 | ),
445 | num_repeat=round_repeats(
446 | block_args.num_repeat, self._global_params
447 | ),
448 | )
449 | stage_count += 1
450 | self.stage_map += [''] * (block_args.num_repeat - 1)
451 | self.stage_map.append('b%s' % stage_count)
452 | # The first block needs to take care of stride and filter size increase.
453 | self._blocks.append(MBConvBlock(block_args, self._global_params))
454 |
455 | if block_args.num_repeat > 1:
456 | block_args = block_args._replace(
457 | input_filters=block_args.output_filters, stride=1
458 | )
459 | for _ in range(block_args.num_repeat - 1):
460 | self._blocks.append(
461 | MBConvBlock(block_args, self._global_params)
462 | )
463 |
464 | # Head
465 | in_channels = block_args.output_filters # output of final block
466 | out_channels = round_filters(1280, self._global_params)
467 | self._conv_head = Conv2d(
468 | in_channels, out_channels, kernel_size=1, bias=False
469 | )
470 | self._bn1 = nn.BatchNorm2d(
471 | num_features=out_channels, momentum=bn_mom, eps=bn_eps
472 | )
473 |
474 | # Final linear layer
475 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
476 | self._dropout = nn.Dropout(self._global_params.dropout_rate)
477 | self._fc = nn.Linear(out_channels, self._global_params.num_classes)
478 | self._swish = MemoryEfficientSwish()
479 |
480 | if model_cfg['PRETRAINED']:
481 | self.load_pretrained_weights(model_name, advprop=True)
482 |
483 | def set_swish(self, memory_efficient=True):
484 | """Sets swish function as memory efficient (for training) or standard (for export)"""
485 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
486 | for block in self._blocks:
487 | block.set_swish(memory_efficient)
488 |
489 | def extract_textures(self, inputs, layers):
490 | # Stem
491 | x = self._swish(self._bn0(self._conv_stem(inputs)))
492 | layers['b0'] = x
493 | # Blocks
494 | for idx, block in enumerate(self._blocks[:6]):
495 | drop_connect_rate = self._global_params.drop_connect_rate
496 | if drop_connect_rate:
497 | drop_connect_rate *= float(idx) / len(self._blocks)
498 | x = block(x, drop_connect_rate=drop_connect_rate)
499 | stage = self.stage_map[idx]
500 | if stage:
501 | layers[stage] = x
502 | if stage == self.escape:
503 | return None
504 |
505 | return x
506 |
507 | def extract_features(self, x, layers):
508 | # Blocks
509 | for idx, block in enumerate(self._blocks[6:]):
510 | idx += 6
511 | drop_connect_rate = self._global_params.drop_connect_rate
512 | if drop_connect_rate:
513 | drop_connect_rate *= float(idx) / len(self._blocks)
514 | x = block(x, drop_connect_rate=drop_connect_rate)
515 | stage = self.stage_map[idx]
516 | if stage:
517 | layers[stage] = x
518 | if stage == self.escape:
519 | return None
520 | # Head
521 | x = self._bn1(self._conv_head(x))
522 | x = self._swish(x)
523 | return x
524 |
525 | def forward(self, samples):
526 | x = samples['img']
527 | bs = x.size(0)
528 | layers = {}
529 | x = self.extract_textures(x, layers)
530 | x = self.extract_features(x, layers) if x is not None else None
531 | if x is None:
532 | return layers
533 | layers['final'] = x
534 | x = self._avg_pooling(x)
535 | x = x.view(bs, -1)
536 | x = self._dropout(x)
537 | x = self._fc(x)
538 | layers['logits'] = x
539 | return layers
540 |
541 | def load_pretrained_weights(self, model_name, advprop=False):
542 | url_map_ = url_map_advprop if advprop else url_map
543 | state_dict = model_zoo.load_url(url_map_[model_name])
544 | state_dict.pop('_fc.weight')
545 | state_dict.pop('_fc.bias')
546 | res = self.load_state_dict(state_dict, strict=False)
547 | assert set(res.missing_keys) == set(
548 | ['_fc.weight', '_fc.bias']
549 | ), 'issue loading pretrained weights'
550 | print('Loaded pretrained weights for {}'.format(model_name))
551 |
552 | def check_model_name_is_valid(self, model_name):
553 | valid_models = ['efficientnet-b' + str(i) for i in range(9)]
554 | if model_name not in valid_models:
555 | raise ValueError(
556 | 'model_name should be one of: ' + ', '.join(valid_models)
557 | )
558 |
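559 |
560 | if __name__ == '__main__':
561 |     # Hedged usage sketch, not part of the original module: it instantiates the
562 |     # backbone from a minimal dict mirroring the YAML snippet at the top of this
563 |     # file, with PRETRAINED disabled so no checkpoint download is triggered.
564 |     cfg = {'NAME': 'efficientnet-b4', 'PRETRAINED': False}
565 |     model = EfficientNet(cfg).eval()
566 |     with torch.no_grad():
567 |         out = model({'img': torch.randn(1, 3, 380, 380)})
568 |     print(out['logits'].shape)  # expected: torch.Size([1, 2])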
--------------------------------------------------------------------------------