├── README.md
├── __init__.py
├── checkpoints
│   ├── .DS_Store
│   └── .empty
├── configs
│   ├── compnode
│   │   ├── 1-node-cluster.yaml
│   │   ├── 2-node-cluster.yaml
│   │   ├── 4-node-cluster.yaml
│   │   ├── cpu.yaml
│   │   ├── mps.yaml
│   │   ├── single-gpu.yaml
│   │   └── two-gpus.yaml
│   ├── config.yaml
│   ├── dataset
│   │   ├── ADE20K.yaml
│   │   ├── CelebAHQ.yaml
│   │   ├── idesigner.yaml
│   │   ├── train_aug
│   │   │   ├── basic.yaml
│   │   │   ├── basicpad.yaml
│   │   │   ├── crop.yaml
│   │   │   ├── none.yaml
│   │   │   └── nonepad.yaml
│   │   └── val_aug
│   │       ├── crop.yaml
│   │       ├── none.yaml
│   │       └── nonepad.yaml
│   └── model
│       ├── clade.yaml
│       ├── disc_augments
│       │   └── basic.yaml
│       ├── discriminator
│       │   ├── mcad.yaml
│       │   ├── oasis.yaml
│       │   ├── patchgan.yaml
│       │   └── stylegan2.yaml
│       ├── encoder
│       │   ├── groupdnet.yaml
│       │   ├── inade.yaml
│       │   ├── sat.yaml
│       │   ├── sean.yaml
│       │   └── spade.yaml
│       ├── generator
│       │   ├── clade.yaml
│       │   ├── groupdnet.yaml
│       │   ├── inade.yaml
│       │   ├── scam.yaml
│       │   ├── sean.yaml
│       │   ├── sean_clade.yaml
│       │   └── spade.yaml
│       ├── groupdnet.yaml
│       ├── inade.yaml
│       ├── optim
│       │   ├── adam.yaml
│       │   ├── adamw.yaml
│       │   └── sam.yaml
│       ├── scam.yaml
│       ├── sean.yaml
│       ├── sean_clade.yaml
│       └── spade.yaml
├── data
│   ├── __init__.py
│   ├── datamodule.py
│   └── image_dataset.py
├── datasets
│   └── .empty
├── environment.yml
├── gan.py
├── medias
│   ├── .DS_Store
│   └── teaser.png
├── metrics
│   ├── __init__.py
│   ├── fid.py
│   ├── models
│   │   └── fastreid
│   │       ├── __init__.py
│   │       ├── layers
│   │       │   ├── __init__.py
│   │       │   ├── activation.py
│   │       │   ├── am_softmax.py
│   │       │   ├── arc_softmax.py
│   │       │   ├── batch_drop.py
│   │       │   ├── batch_norm.py
│   │       │   ├── circle_softmax.py
│   │       │   ├── context_block.py
│   │       │   ├── frn.py
│   │       │   ├── gather_layer.py
│   │       │   ├── non_local.py
│   │       │   ├── pooling.py
│   │       │   ├── se_layer.py
│   │       │   └── splat.py
│   │       └── model.py
│   └── reid.py
├── models
│   ├── __init__.py
│   ├── diffaugment.py
│   ├── discriminators
│   │   ├── __init__.py
│   │   ├── mcad.py
│   │   ├── oasis.py
│   │   ├── patchgan.py
│   │   └── stylegan2.py
│   ├── encoders
│   │   ├── __init__.py
│   │   ├── groupdnet.py
│   │   ├── inade.py
│   │   ├── sat.py
│   │   ├── sean.py
│   │   └── spade.py
│   ├── generators
│   │   ├── __init__.py
│   │   ├── clade.py
│   │   ├── groupdnet.py
│   │   ├── inade.py
│   │   ├── scam.py
│   │   ├── sean.py
│   │   ├── sean_clade.py
│   │   └── spade.py
│   ├── loss.py
│   └── utils_blocks
│       ├── EMA.py
│       ├── SAM.py
│       ├── __init__.py
│       ├── attention.py
│       ├── base.py
│       └── equallr.py
├── preprocess_fid_features.py
├── train.py
└── utils
    ├── .DS_Store
    ├── __init__.py
    ├── callbacks.py
    ├── partial_conv.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # SCAM! Transferring humans between images with Semantic Cross Attention Modulation
2 |
3 | Official PyTorch implementation of "SCAM! Transferring humans between images with Semantic Cross Attention Modulation", ECCV 2022.
4 |
5 | Arxiv | Website
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | ## Bibtex
15 |
16 | If you find this code or method useful in your research, please cite this paper with:
17 |
18 | ```
19 | @inproceedings{dufour2022scam,
20 |     title={SCAM! Transferring humans between images with Semantic Cross Attention Modulation},
21 |     author={Dufour, Nicolas and Picard, David and Kalogeiton, Vicky},
22 |     booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
23 |     year={2022}
24 | }
25 | ```
26 |
27 | ## Installation guide
28 | To install this repo, follow these steps:
29 |
30 | 1. Install Anaconda or MiniConda
31 | 2. Run `conda create -n scam python=3.9.12`
32 | 3. Activate scam: `conda activate scam`
33 | 4. Install PyTorch 1.12 and torchvision 0.13 builds that match your device:
34 | - For cpu:
35 | ```bash
36 | conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cpuonly -c pytorch
37 | ```
38 | - For cuda 11.6:
39 | ```bash
40 | conda install pytorch==1.12.0 torchvision==0.13.0 cudatoolkit=11.6 -c pytorch -c conda-forge
41 | ```
42 | 5. Install dependencies:
43 | ```
44 | conda env update --file environment.yml
45 | ```
46 |
47 | ## Prepare the data
48 |
49 | ### Download CelebAHQ
50 |
51 | Download link: dataset provided and processed by the SEAN authors.
52 | Unzip it in the `datasets` folder.
53 | Then run:
54 |
55 | ```bash
56 | python preprocess_fid_features.py dataset=CelebAHQ
57 | ```
58 | To run on GPU, run:
59 |
60 | ```bash
61 | python preprocess_fid_features.py dataset=CelebAHQ compnode=single-gpu
62 | ```
63 |
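After unzipping and running the command above, the `datasets` folder is expected to roughly follow the layout described in the next section; assuming the folder name used in `configs/dataset/CelebAHQ.yaml`, that would be:

```
datasets/CelebA-HQ
├── train
│   ├── images
│   ├── labels
│   ├── vis
│   └── stats
└── test
    ├── images
    ├── labels
    ├── vis
    └── stats
```
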
64 | ## Add a new dataset
65 | If you want to add a new dataset, you need to follow the same directory structure as the other datasets:
66 |
67 | 2 folders: train and test.
68 | Each folder contains 4 sub-folders: images, labels, vis and stats.
69 | - images: contains the RGB images; can be jpg or png.
70 | - labels: the segmentation labels; must be png files where each pixel takes a value between 0 and num_labels - 1.
71 | - vis: RGB visualization of the labels.
72 | - stats: Inception statistics of the train dataset, computed by the preprocessing script below (which iterates over the train dataloader).
73 |
74 | Create a config file for the dataset in configs/dataset (the config file should be named after the dataset).
75 |
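For instance, a hypothetical `mydataset.yaml` could mirror the fields of the existing configs such as `configs/dataset/CelebAHQ.yaml` (the values below are placeholders to adapt to your data):

```yaml
defaults:
  - train_aug: none
  - val_aug: none

path: ${data_dir}/mydataset      # expects mydataset/train and mydataset/test
name: mydataset
height: 256
width: 256
num_labels_orig: 19              # number of classes in the png label maps
num_labels: ${dataset.num_labels_orig}
num_channels: 3
num_workers: 10
image_extension: jpg             # jpg or png
label_merge_strat: none
```
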
76 | and run:
77 |
78 | ```bash
79 | python preprocess_fid_features.py dataset=dataset_name compnode=(cpu or single-gpu)
80 | ```
81 |
82 | ## Run SCAM
83 | Our code logs to wandb. Make sure you are logged in to wandb before starting.
84 | To run our code, you just need to do:
85 |
86 | ```bash
87 | python train.py
88 | ```
89 |
90 | By default, the code will run on CPU. If you want to run it on GPU, you can do:
91 |
92 | ```bash
93 | python train.py compnode=single-gpu
94 | ```
95 |
96 | Other compute configs exist in `configs/compnode`. If none suits your needs, you can easily create one thanks to the modularity of Hydra.
97 |
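For instance, a hypothetical single-node 8-GPU config could reuse the same fields as the existing files (illustrative values only):

```yaml
devices: 8
progress_bar_refresh_rate: 2
sync_batchnorm: True
accelerator: gpu
precision: 16
strategy: ddp
batch_size: 4
num_nodes: 1
```
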
98 | In the paper, the main compute config we use is the 4-GPU config called `1-node-cluster` (NVIDIA V100 32GB, total batch size of 32, i.e. 8 per GPU).
99 |
100 | By default, our method SCAM will be run. If you want to try one of the baselines we compare against, you can run:
101 |
102 | ```bash
103 | python train.py model=sean
104 | ```
105 |
106 | We include implementations of SEAN, SPADE, INADE, CLADE and GroupDNet.
107 |
108 | To swap datasets, you can do:
109 |
110 | ```bash
111 | python train.py dataset=CelebAHQ
112 | ```
113 |
114 | We support 3 datasets: `CelebAHQ`, `idesigner` and `ADE20K`.
115 |
116 | If you want to run different experiments, make sure to change `experiment_name_comp` or `experiment_name` when running. Otherwise our code will resume training from the checkpointed weights of the experiment with the same name.
117 |
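For instance, a separate run could be launched with its own name like this (reusing the Hydra overrides shown above; `my_first_run` is just a placeholder):

```bash
python train.py model=sean dataset=idesigner compnode=single-gpu experiment_name_comp=my_first_run
```
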
118 | ## Pretrained models
119 | Coming soon
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/__init__.py
--------------------------------------------------------------------------------
/checkpoints/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/checkpoints/.DS_Store
--------------------------------------------------------------------------------
/checkpoints/.empty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/checkpoints/.empty
--------------------------------------------------------------------------------
/configs/compnode/1-node-cluster.yaml:
--------------------------------------------------------------------------------
1 | devices: 4
2 | progress_bar_refresh_rate: 2
3 | sync_batchnorm: True
4 | accelerator: gpu
5 | precision: 16
6 | strategy: ddp
7 | batch_size: 8
8 | num_nodes: 1
9 |
--------------------------------------------------------------------------------
/configs/compnode/2-node-cluster.yaml:
--------------------------------------------------------------------------------
1 | devices: 4
2 | progress_bar_refresh_rate: 2
3 | sync_batchnorm: True
4 | accelerator: gpu
5 | precision: 16
6 | strategy: ddp
7 | batch_size: 8
8 | num_nodes: 2
9 |
--------------------------------------------------------------------------------
/configs/compnode/4-node-cluster.yaml:
--------------------------------------------------------------------------------
1 | devices: 4
2 | progress_bar_refresh_rate: 2
3 | sync_batchnorm: True
4 | accelerator: gpu
5 | precision: 16
6 | strategy: ddp
7 | batch_size: 8
8 | num_nodes: 4
9 |
--------------------------------------------------------------------------------
/configs/compnode/cpu.yaml:
--------------------------------------------------------------------------------
1 | devices: 1
2 | progress_bar_refresh_rate: 2
3 | sync_batchnorm: False
4 | accelerator: cpu
5 | precision: 32
6 | strategy: null
7 | batch_size: 4
8 | num_nodes: 1
--------------------------------------------------------------------------------
/configs/compnode/mps.yaml:
--------------------------------------------------------------------------------
1 | devices: 1
2 | progress_bar_refresh_rate: 2
3 | sync_batchnorm: False
4 | accelerator: mps
5 | precision: 32
6 | strategy: null
7 | batch_size: 4
8 | num_nodes: 1
--------------------------------------------------------------------------------
/configs/compnode/single-gpu.yaml:
--------------------------------------------------------------------------------
1 | devices: 1
2 | progress_bar_refresh_rate: 2
3 | sync_batchnorm: False
4 | accelerator: gpu
5 | precision: 16
6 | strategy: null
7 | batch_size: 4
8 | num_nodes: 1
--------------------------------------------------------------------------------
/configs/compnode/two-gpus.yaml:
--------------------------------------------------------------------------------
1 | devices: 2
2 | progress_bar_refresh_rate: 2
3 | sync_batchnorm: False
4 | accelerator: gpu
5 | precision: 16
6 | strategy: ddp
7 | batch_size: 4
8 | num_nodes: 1
--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model: scam
3 | - compnode: cpu
4 | - dataset: CelebAHQ
5 |
6 | trainer:
7 | _target_: pytorch_lightning.Trainer
8 | max_steps: 100000
9 | devices: ${compnode.devices}
10 | accelerator: ${compnode.accelerator}
11 | sync_batchnorm: ${compnode.sync_batchnorm}
12 | strategy: ${compnode.strategy}
13 | log_every_n_steps: 1
14 | num_nodes: ${compnode.num_nodes}
15 | precision: ${compnode.precision}
16 | dataset:
17 | batch_size: ${compnode.batch_size}
18 |
19 | logger:
20 | _target_: pytorch_lightning.loggers.WandbLogger
21 | save_dir: ${root_dir}/wandb
22 | name: ${experiment_name}
23 | project: Pose_Transfer
24 | log_model: False
25 | offline: True
26 |
27 | checkpoints:
28 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
29 | dirpath: ${root_dir}/checkpoints/${experiment_name}
30 | monitor: val/reco_fid
31 | save_last: True
32 | every_n_epochs: 1
33 |
34 | progress_bar:
35 | _target_: pytorch_lightning.callbacks.TQDMProgressBar
36 | refresh_rate: ${compnode.progress_bar_refresh_rate}
37 |
38 | hydra:
39 | run:
40 | dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name}
41 |
42 | data_dir: ${root_dir}/datasets
43 | root_dir: ${hydra:runtime.cwd}
44 | experiment_name_comp: base
45 | experiment_name: ${dataset.name}_${model.name}_${experiment_name_comp}
46 |
--------------------------------------------------------------------------------
/configs/dataset/ADE20K.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_aug: crop
3 | - val_aug: crop
4 |
5 | path: ${data_dir}/ADE20K
6 | name: ADE20K
7 | height: 256
8 | width: 256
9 | num_labels_orig: 151
10 | num_labels: ${dataset.num_labels_orig}
11 | num_channels: 3
12 | num_workers: 4
13 | image_extension: jpg
14 | label_merge_strat: none
--------------------------------------------------------------------------------
/configs/dataset/CelebAHQ.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_aug: none
3 | - val_aug: none
4 |
5 | path: ${data_dir}/CelebA-HQ
6 | name: CelebA-HQ
7 | height: 256
8 | width: 256
9 | num_labels: ${dataset.num_labels_orig}
10 | num_labels_orig: 19
11 | num_channels: 3
12 | num_workers: 10
13 | image_extension: jpg
14 | label_merge_strat: none
--------------------------------------------------------------------------------
/configs/dataset/idesigner.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_aug: none
3 | - val_aug: none
4 |
5 | path: ${data_dir}/idesigner
6 | name: idesigner
7 | height: 256
8 | width: 170
9 | num_labels_orig: 20
10 | num_labels: 3
11 | num_channels: 3
12 | num_workers: 10
13 | image_extension: png
14 | label_merge_strat: body_face_background
--------------------------------------------------------------------------------
/configs/dataset/train_aug/basic.yaml:
--------------------------------------------------------------------------------
1 | _target_: albumentations.Compose
2 | transforms:
3 | - _target_: albumentations.ColorJitter
4 | p: 0.4
5 | - _target_: albumentations.HorizontalFlip
6 | p: 0.5
7 | - _target_: albumentations.OneOf
8 | p: 1
9 | transforms:
10 | - _target_: albumentations.RandomResizedCrop
11 | height: ${dataset.height}
12 | width: ${dataset.width}
13 | scale: [0.5, 1]
14 | ratio: [1,1]
15 | p: 0.7
16 | - _target_: albumentations.Resize
17 | height: ${dataset.height}
18 | width: ${dataset.width}
19 | p: 0.3
20 |
21 | - _target_: albumentations.Normalize
22 | mean: 0.5
23 | std: 0.5
24 | - _target_: albumentations.pytorch.ToTensorV2
--------------------------------------------------------------------------------
/configs/dataset/train_aug/basicpad.yaml:
--------------------------------------------------------------------------------
1 | _target_: albumentations.Compose
2 | transforms:
3 | - _target_: albumentations.ColorJitter
4 | p: 0.4
5 | - _target_: albumentations.HorizontalFlip
6 | p: 0.5
7 | - _target_: albumentations.LongestMaxSize
8 | p: 1
9 | max_size: ${dataset.height}
10 | - _target_: albumentations.PadIfNeeded
11 | p: 1
12 | border_mode: 0
13 | min_height: ${dataset.height}
14 | min_width: ${dataset.width}
15 | value: [0,0,0]
16 | mask_value: 0
17 |
18 | - _target_: albumentations.Normalize
19 | mean: 0.5
20 | std: 0.5
21 | - _target_: albumentations.pytorch.ToTensorV2
--------------------------------------------------------------------------------
/configs/dataset/train_aug/crop.yaml:
--------------------------------------------------------------------------------
1 | _target_: albumentations.Compose
2 | transforms:
3 | - _target_: albumentations.Resize
4 | height: ${dataset.height}
5 | width: ${dataset.width}
6 | - _target_: albumentations.ColorJitter
7 | p: 0.4
8 | - _target_: albumentations.HorizontalFlip
9 | p: 0.5
10 | - _target_: albumentations.OneOf
11 | transforms:
12 | - _target_: albumentations.RandomSizedCrop
13 | p: 0.7
14 | w2h_ratio: 0.664
15 | min_max_height: [400, 512]
16 | height: ${dataset.height}
17 | width: ${dataset.width}
18 | - _target_: albumentations.RandomSizedCrop
19 | p: 0.3
20 | w2h_ratio: 0.664
21 | min_max_height: [128, 400]
22 | height: ${dataset.height}
23 | width: ${dataset.width}
24 | p: 0.6
25 | - _target_: albumentations.Normalize
26 | mean: 0.5
27 | std: 0.5
28 | - _target_: albumentations.pytorch.ToTensorV2
--------------------------------------------------------------------------------
/configs/dataset/train_aug/none.yaml:
--------------------------------------------------------------------------------
1 | _target_: albumentations.Compose
2 | transforms:
3 | - _target_: albumentations.Resize
4 | height: ${dataset.height}
5 | width: ${dataset.width}
6 | - _target_: albumentations.Normalize
7 | mean: 0.5
8 | std: 0.5
9 | - _target_: albumentations.pytorch.ToTensorV2
--------------------------------------------------------------------------------
/configs/dataset/train_aug/nonepad.yaml:
--------------------------------------------------------------------------------
1 | _target_: albumentations.Compose
2 | transforms:
3 | - _target_: albumentations.LongestMaxSize
4 | p: 1
5 | max_size: ${dataset.height}
6 | - _target_: albumentations.PadIfNeeded
7 | p: 1
8 | border_mode: 0
9 | min_height: ${dataset.height}
10 | min_width: ${dataset.width}
11 | value: [0,0,0]
12 | mask_value: 0
13 | - _target_: albumentations.HorizontalFlip
14 | p: 0.5
15 | - _target_: albumentations.Normalize
16 | mean: 0.5
17 | std: 0.5
18 | - _target_: albumentations.pytorch.ToTensorV2
--------------------------------------------------------------------------------
/configs/dataset/val_aug/crop.yaml:
--------------------------------------------------------------------------------
1 | _target_: albumentations.Compose
2 | transforms:
3 | - _target_: albumentations.SmallestMaxSize
4 | p: 1
5 | max_size: ${dataset.height}
6 | - _target_: albumentations.CenterCrop
7 | p: 1
8 | height: ${dataset.height}
9 | width: ${dataset.width}
10 | - _target_: albumentations.HorizontalFlip
11 | p: 0.5
12 | - _target_: albumentations.Normalize
13 | mean: 0.5
14 | std: 0.5
15 | - _target_: albumentations.pytorch.ToTensorV2
--------------------------------------------------------------------------------
/configs/dataset/val_aug/none.yaml:
--------------------------------------------------------------------------------
1 | _target_: albumentations.Compose
2 | transforms:
3 | - _target_: albumentations.Resize
4 | height: ${dataset.height}
5 | width: ${dataset.width}
6 |
7 | - _target_: albumentations.Normalize
8 | mean: 0.5
9 | std: 0.5
10 | - _target_: albumentations.pytorch.ToTensorV2
--------------------------------------------------------------------------------
/configs/dataset/val_aug/nonepad.yaml:
--------------------------------------------------------------------------------
1 | _target_: albumentations.Compose
2 | transforms:
3 | - _target_: albumentations.LongestMaxSize
4 | p: 1
5 | max_size: ${dataset.height}
6 | - _target_: albumentations.PadIfNeeded
7 | p: 1
8 | border_mode: 0
9 | min_height: ${dataset.height}
10 | min_width: ${dataset.width}
11 | value: [0,0,0]
12 | mask_value: 0
13 | - _target_: albumentations.Normalize
14 | mean: 0.5
15 | std: 0.5
16 | - _target_: albumentations.pytorch.ToTensorV2
--------------------------------------------------------------------------------
/configs/model/clade.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - disc_augments: null
3 | - generator: clade
4 | - discriminator: patchgan
5 | - encoder: spade
6 | - optim: adam
7 |
8 | change_init: True
9 | gradient_clip_val: 0
10 | name: CLADE
11 | dataset_path: ${dataset.path}
12 | losses:
13 | lambda_gan: 1
14 | lambda_gan_end: ${model.losses.lambda_gan}
15 | lambda_gan_decay_steps: 1
16 |
17 | lambda_fm: 10.0
18 | lambda_fm_end: ${model.losses.lambda_fm}
19 | lambda_fm_decay_steps: 1
20 |
21 | lambda_label_mix: 0
22 | lambda_label_mix_end: ${model.losses.lambda_label_mix}
23 | lambda_label_mix_decay_steps: 1
24 |
25 | lambda_l1: 0
26 | lambda_l1_end: ${model.losses.lambda_l1}
27 | lambda_l1_decay_steps: 1
28 |
29 | lambda_perceptual: 10.0
30 | lambda_perceptual_end: ${model.losses.lambda_perceptual}
31 | lambda_perceptual_decay_steps: 1
32 |
33 | lambda_r1: 0
34 | lambda_r1_end: ${model.losses.lambda_r1}
35 | lambda_r1_decay_steps: 1
36 |
37 | lambda_kld: 0.05
38 | lambda_kld_end: ${model.losses.lambda_kld}
39 | lambda_kld_decay_steps: 1
40 |
41 | lazy_r1_step: 16
42 | gan_loss_type: hinge
43 | gan_loss_on_swaps: False
44 | use_adaptive_lambda : False
--------------------------------------------------------------------------------
/configs/model/disc_augments/basic.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.nn.Sequential
2 | _args_:
3 | - _target_: kornia.augmentation.augmentation.Denormalize
4 | mean: 0.5
5 | std: 0.5
6 | - _target_: kornia.augmentation.augmentation.ColorJitter
7 | p: 0.8
8 | brightness : 0.2
9 | contrast: 0.3
10 | hue: 0.2
11 | - _target_: kornia.augmentation.augmentation.RandomErasing
12 | p: 0.5
13 | - _target_: kornia.augmentation.augmentation.Normalize
14 | mean: 0.5
15 | std: 0.5
16 |
--------------------------------------------------------------------------------
/configs/model/discriminator/mcad.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.discriminators.mcad.MultiScaleMCADiscriminator
2 | num_discriminator: 2
3 | image_num_channels: ${dataset.num_channels}
4 | segmap_num_channels: ${dataset.num_labels}
5 | positional_embedding_dim: 40
6 | num_labels: ${dataset.num_labels}
7 | num_latent_per_labels: 8
8 | latent_dim: 64
9 | num_blocks: 6
10 | attention_latent_dim: 64
11 | num_cross_heads: 1
12 | num_self_heads: 1
13 | apply_spectral_norm: False
14 | concat_segmaps: True
15 | output_type: attention_pool
16 | keep_intermediate_results: True
--------------------------------------------------------------------------------
/configs/model/discriminator/oasis.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.discriminators.oasis.OASISDiscriminator
2 | image_num_channels: ${dataset.num_channels}
3 | segmap_num_channels: ${dataset.num_labels}
4 | apply_spectral_norm: True
5 | apply_grad_norm: False
--------------------------------------------------------------------------------
/configs/model/discriminator/patchgan.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.discriminators.patchgan.MultiScalePatchGanDiscriminator
2 | num_discriminator: 2
3 | image_num_channels: ${dataset.num_channels}
4 | segmap_num_channels: ${dataset.num_labels}
5 | num_features_fst_conv: 64
6 | num_layers: 4
7 | apply_spectral_norm: True
8 | apply_grad_norm: False
9 | keep_intermediate_results: True
10 | use_equalized_lr: False
11 | lr_mul: 1.0
--------------------------------------------------------------------------------
/configs/model/discriminator/stylegan2.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.discriminators.stylegan2.MultiStyleGan2Discriminator
2 | num_discriminator: 1
3 | image_num_channels: ${dataset.num_channels}
4 | segmap_num_channels: ${dataset.num_labels}
5 | num_features_fst_conv: 64
6 | num_layers: 7
7 | fmap_max: 512
8 | apply_spectral_norm: False
9 | apply_grad_norm: False
10 | keep_intermediate_results: True
11 | use_equalized_lr: False
12 | lr_mul: 1.0
--------------------------------------------------------------------------------
/configs/model/encoder/groupdnet.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.encoders.groupdnet.GroupDNetStyleEncoder
2 | num_labels: ${dataset.num_labels}
3 |
--------------------------------------------------------------------------------
/configs/model/encoder/inade.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.encoders.inade.InstanceAdaptiveEncoder
2 | num_labels: ${dataset.num_labels}
3 | noise_dim: ${model.generator.noise_dim}
4 | use_vae: True
--------------------------------------------------------------------------------
/configs/model/encoder/sat.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.encoders.sat.SemanticAttentionTransformerEncoder
2 | num_input_channels: 3
3 | positional_embedding_dim: 40
4 | num_labels: ${dataset.num_labels}
5 | num_latent_per_labels: 8
6 | num_latents_bg: 8
7 | latent_dim: 256
8 | type_of_initial_latents: learned
9 | attention_latent_dim: 256
10 | num_blocks: 7
11 | num_cross_heads: 1
12 | num_self_heads: 1
13 | image_conv: True
14 | reverse_conv: False
15 | conv_features_dim_first: 16
16 | use_self_attention: True
17 | use_equalized_lr: False
18 | lr_mul: 1.0
19 | use_vae: False
--------------------------------------------------------------------------------
/configs/model/encoder/sean.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.encoders.sean.RegionalAveragePoolingStyleEncoder
2 | num_input_channels: ${dataset.num_channels}
3 | latent_dim: 512
4 | num_features_fst_conv: 32
5 | use_vae: False
--------------------------------------------------------------------------------
/configs/model/encoder/spade.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.encoders.spade.SPADEStyleEncoder
2 | use_vae: True
--------------------------------------------------------------------------------
/configs/model/generator/clade.yaml:
--------------------------------------------------------------------------------
1 | _target_ : models.generators.clade.CLADEGenerator
2 | num_filters_last_layer: 64
3 | num_up_layers: 7
4 | height: ${dataset.height}
5 | width: ${dataset.width}
6 | num_labels: ${dataset.num_labels}
7 | kernel_size: 3
8 | num_output_channels: ${dataset.num_channels}
9 | apply_spectral_norm: True
10 | use_vae : True
11 | use_dists: False
--------------------------------------------------------------------------------
/configs/model/generator/groupdnet.yaml:
--------------------------------------------------------------------------------
1 | _target_ : models.generators.groupdnet.GroupDNetGenerator
2 | num_filters_last_layer: 64
3 | num_up_layers: 7
4 | height: ${dataset.height}
5 | width: ${dataset.width}
6 | num_labels: ${dataset.num_labels}
7 | kernel_size: 3
8 | num_output_channels: ${dataset.num_channels}
9 | apply_spectral_norm: True
10 | use_vae : True
--------------------------------------------------------------------------------
/configs/model/generator/inade.yaml:
--------------------------------------------------------------------------------
1 | _target_ : models.generators.inade.INADEGenerator
2 | num_filters_last_layer: 64
3 | num_up_layers: 7
4 | height: ${dataset.height}
5 | width: ${dataset.width}
6 | num_labels: ${dataset.num_labels}
7 | noise_dim: 64
8 | kernel_size: 3
9 | num_output_channels: ${dataset.num_channels}
10 | apply_spectral_norm: True
11 | use_vae: True
--------------------------------------------------------------------------------
/configs/model/generator/scam.yaml:
--------------------------------------------------------------------------------
1 | _target_ : models.generators.scam.SCAMGenerator
2 | num_filters_last_layer: 64
3 | num_up_layers: 6
4 | height: ${dataset.height}
5 | width: ${dataset.width}
6 | num_labels: ${dataset.num_labels}
7 | num_labels_split: 8
8 | num_labels_bg: 8
9 | style_dim: ${model.encoder.latent_dim}
10 | kernel_size: 3
11 | attention_latent_dim: 256
12 | num_heads: 1
13 | attention_type: duplex
14 | num_up_layers_with_mask_adain: 0
15 | num_output_channels: ${dataset.num_channels}
16 | latent_pos_emb: none
17 | apply_spectral_norm: False
18 | split_latents: False
19 | norm_type: InstanceNorm
20 | architecture: skip
21 | add_noise: True
22 | modulate: True
23 | use_equalized_lr: False
24 | lr_mul: 1.0
25 | use_vae: False
--------------------------------------------------------------------------------
/configs/model/generator/sean.yaml:
--------------------------------------------------------------------------------
1 | _target_ : models.generators.sean.SEANGenerator
2 | num_filters_last_layer: 64
3 | num_up_layers: 7
4 | height: ${dataset.height}
5 | width: ${dataset.width}
6 | num_labels: ${dataset.num_labels}
7 | style_dim: ${model.encoder.latent_dim}
8 | kernel_size: 3
9 | num_output_channels: ${dataset.num_channels}
10 | apply_spectral_norm: True
11 | use_vae: False
--------------------------------------------------------------------------------
/configs/model/generator/sean_clade.yaml:
--------------------------------------------------------------------------------
1 | _target_ : models.generators.sean_clade.SEANCLADEGenerator
2 | num_filters_last_layer: 64
3 | num_up_layers: 7
4 | height: ${dataset.height}
5 | width: ${dataset.width}
6 | num_labels: ${dataset.num_labels}
7 | style_dim: ${model.encoder.latent_dim}
8 | kernel_size: 3
9 | num_output_channels: ${dataset.num_channels}
10 | apply_spectral_norm: True
11 | use_dists: False
--------------------------------------------------------------------------------
/configs/model/generator/spade.yaml:
--------------------------------------------------------------------------------
1 | _target_ : models.generators.spade.SPADEGenerator
2 | num_filters_last_layer: 64
3 | num_up_layers: 7
4 | height: ${dataset.height}
5 | width: ${dataset.width}
6 | num_labels: ${dataset.num_labels}
7 | kernel_size: 3
8 | num_output_channels: ${dataset.num_channels}
9 | apply_spectral_norm: True
10 | use_vae : True
--------------------------------------------------------------------------------
/configs/model/groupdnet.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - disc_augments: null
3 | - generator: groupdnet
4 | - discriminator: patchgan
5 | - encoder: groupdnet
6 | - optim: adam
7 |
8 | change_init: True
9 | gradient_clip_val: 0
10 | name: GroupDNet
11 | dataset_path: ${dataset.path}
12 | losses:
13 | lambda_gan: 1
14 | lambda_gan_end: ${model.losses.lambda_gan}
15 | lambda_gan_decay_steps: 1
16 |
17 | lambda_fm: 10.0
18 | lambda_fm_end: ${model.losses.lambda_fm}
19 | lambda_fm_decay_steps: 1
20 |
21 | lambda_label_mix: 0
22 | lambda_label_mix_end: ${model.losses.lambda_label_mix}
23 | lambda_label_mix_decay_steps: 1
24 |
25 | lambda_l1: 0
26 | lambda_l1_end: ${model.losses.lambda_l1}
27 | lambda_l1_decay_steps: 1
28 |
29 | lambda_perceptual: 10.0
30 | lambda_perceptual_end: ${model.losses.lambda_perceptual}
31 | lambda_perceptual_decay_steps: 1
32 |
33 | lambda_r1: 0
34 | lambda_r1_end: ${model.losses.lambda_r1}
35 | lambda_r1_decay_steps: 1
36 |
37 | lambda_kld: 0.05
38 | lambda_kld_end: ${model.losses.lambda_kld}
39 | lambda_kld_decay_steps: 1
40 | lazy_r1_step: 16
41 | gan_loss_type: hinge
42 | gan_loss_on_swaps: False
43 | use_adaptive_lambda : False
--------------------------------------------------------------------------------
/configs/model/inade.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - disc_augments: null
3 | - generator: inade
4 | - discriminator: patchgan
5 | - encoder: inade
6 | - optim: adam
7 |
8 | change_init: True
9 | gradient_clip_val: 0
10 | name: INADE
11 | dataset_path: ${dataset.path}
12 | losses:
13 | lambda_gan: 1
14 | lambda_gan_end: ${model.losses.lambda_gan}
15 | lambda_gan_decay_steps: 1
16 |
17 | lambda_fm: 10.0
18 | lambda_fm_end: ${model.losses.lambda_fm}
19 | lambda_fm_decay_steps: 1
20 |
21 | lambda_label_mix: 0
22 | lambda_label_mix_end: ${model.losses.lambda_label_mix}
23 | lambda_label_mix_decay_steps: 1
24 |
25 | lambda_l1: 0
26 | lambda_l1_end: ${model.losses.lambda_l1}
27 | lambda_l1_decay_steps: 1
28 |
29 | lambda_perceptual: 10.0
30 | lambda_perceptual_end: ${model.losses.lambda_perceptual}
31 | lambda_perceptual_decay_steps: 1
32 |
33 | lambda_r1: 0
34 | lambda_r1_end: ${model.losses.lambda_r1}
35 | lambda_r1_decay_steps: 1
36 |
37 | lambda_kld: 0.05
38 | lambda_kld_end: ${model.losses.lambda_kld}
39 | lambda_kld_decay_steps: 1
40 | lazy_r1_step: 16
41 | gan_loss_type: hinge
42 | gan_loss_on_swaps: False
43 | use_adaptive_lambda : False
--------------------------------------------------------------------------------
/configs/model/optim/adam.yaml:
--------------------------------------------------------------------------------
1 | disc_optim:
2 | _target_: torch.optim.Adam
3 | lr: ${model.optim.disc_lr}
4 | betas: ${model.optim.betas}
5 | weight_decay: ${model.optim.weight_decay}
6 | gen_optim:
7 | _target_: torch.optim.Adam
8 | lr: ${model.optim.gen_lr}
9 | betas: ${model.optim.betas}
10 | weight_decay: ${model.optim.weight_decay}
11 | SAM: False
12 | betas : [0, 0.999]
13 | weight_decay: 0
14 | disc_lr: 4e-4
15 | gen_lr: 1e-4
--------------------------------------------------------------------------------
/configs/model/optim/adamw.yaml:
--------------------------------------------------------------------------------
1 | disc_optim:
2 | _target_: torch.optim.AdamW
3 | lr: ${model.optim.disc_lr}
4 | betas: ${model.optim.betas}
5 | weight_decay: ${model.optim.weight_decay}
6 | gen_optim:
7 | _target_: torch.optim.AdamW
8 | lr: ${model.optim.gen_lr}
9 | betas: ${model.optim.betas}
10 | weight_decay: ${model.optim.weight_decay}
11 | SAM: False
12 | betas : [0.9, 0.999]
13 | weight_decay: 0.01
14 | disc_lr: 4e-4
15 | gen_lr: 1e-4
--------------------------------------------------------------------------------
/configs/model/optim/sam.yaml:
--------------------------------------------------------------------------------
1 | disc_optim:
2 | _target_: models.utils_blocks.SAM.SAM
3 | rho: 0.5
4 | adaptive: True
5 | lr: 0.0004
6 | betas: [0.9, 0.999]
7 | weight_decay: 0.01
8 | gen_optim:
9 | _target_: models.utils_blocks.SAM.SAM
10 | rho: 0.5
11 | adaptive: True
12 | lr: 0.0001
13 | betas: [0.9, 0.999]
14 | weight_decay: 0.01
15 | SAM: True
16 |
--------------------------------------------------------------------------------
/configs/model/scam.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - disc_augments: null
3 | - generator: scam
4 | - discriminator: patchgan
5 | - encoder: sat
6 | - optim: adamw
7 |
8 | change_init: False
9 | gradient_clip_val: 0
10 | name: SCAM
11 | dataset_path: ${dataset.path}
12 |
13 |
14 | discriminator:
15 | apply_spectral_norm: False
16 | apply_grad_norm: True
17 |
18 | generator:
19 | apply_spectral_norm: False
20 | losses:
21 | lambda_gan: 1
22 | lambda_gan_end: ${model.losses.lambda_gan}
23 | lambda_gan_decay_steps: 1
24 |
25 | lambda_fm: 0
26 | lambda_fm_end: ${model.losses.lambda_fm}
27 | lambda_fm_decay_steps: 1
28 |
29 | lambda_label_mix: 0
30 | lambda_label_mix_end: ${model.losses.lambda_label_mix}
31 | lambda_label_mix_decay_steps: 1
32 |
33 | lambda_l1: 10.0
34 | lambda_l1_end: ${model.losses.lambda_l1}
35 | lambda_l1_decay_steps: 1
36 |
37 | lambda_perceptual: 10.0
38 | lambda_perceptual_end: ${model.losses.lambda_perceptual}
39 | lambda_perceptual_decay_steps: 1
40 |
41 | lambda_r1: 0
42 | lambda_r1_end: ${model.losses.lambda_r1}
43 | lambda_r1_decay_steps: 1
44 |
45 | lambda_kld: 0
46 | lambda_kld_end: ${model.losses.lambda_kld}
47 | lambda_kld_decay_steps: 1
48 |
49 | lazy_r1_step: 16
50 | gan_loss_type: hinge
51 | gan_loss_on_swaps: False
52 | use_adaptive_lambda: False
--------------------------------------------------------------------------------
/configs/model/sean.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - disc_augments: null
3 | - generator: sean
4 | - discriminator: patchgan
5 | - encoder: sean
6 | - optim: adam
7 |
8 | change_init: True
9 | gradient_clip_val: 0
10 | name: SEAN
11 | dataset_path: ${dataset.path}
12 | losses:
13 | lambda_gan: 1
14 | lambda_gan_end: ${model.losses.lambda_gan}
15 | lambda_gan_decay_steps: 1
16 |
17 | lambda_fm: 10.0
18 | lambda_fm_end: ${model.losses.lambda_fm}
19 | lambda_fm_decay_steps: 1
20 |
21 | lambda_label_mix: 0
22 | lambda_label_mix_end: ${model.losses.lambda_label_mix}
23 | lambda_label_mix_decay_steps: 1
24 |
25 | lambda_l1: 0
26 | lambda_l1_end: ${model.losses.lambda_l1}
27 | lambda_l1_decay_steps: 1
28 |
29 | lambda_perceptual: 10.0
30 | lambda_perceptual_end: ${model.losses.lambda_perceptual}
31 | lambda_perceptual_decay_steps: 1
32 |
33 | lambda_r1: 0
34 | lambda_r1_end: ${model.losses.lambda_r1}
35 | lambda_r1_decay_steps: 1
36 |
37 | lambda_kld: 0
38 | lambda_kld_end: ${model.losses.lambda_kld}
39 | lambda_kld_decay_steps: 1
40 |
41 | lazy_r1_step: 16
42 | gan_loss_type: hinge
43 | gan_loss_on_swaps: False
44 | use_adaptive_lambda: False
--------------------------------------------------------------------------------
/configs/model/sean_clade.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - disc_augments: null
3 | - generator: sean_clade
4 | - discriminator: patchgan
5 | - encoder: sean
6 | - optim: adam
7 |
8 | change_init: True
9 | gradient_clip_val: 0
10 | name: SEAN_CLADE
11 | dataset_path: ${dataset.path}
12 | losses:
13 | lambda_gan: 1
14 | lambda_gan_end: ${model.losses.lambda_gan}
15 | lambda_gan_decay_steps: 1
16 |
17 | lambda_fm: 10.0
18 | lambda_fm_end: ${model.losses.lambda_fm}
19 | lambda_fm_decay_steps: 1
20 |
21 | lambda_label_mix: 0
22 | lambda_label_mix_end: ${model.losses.lambda_label_mix}
23 | lambda_label_mix_decay_steps: 1
24 |
25 | lambda_l1: 0
26 | lambda_l1_end: ${model.losses.lambda_l1}
27 | lambda_l1_decay_steps: 1
28 |
29 | lambda_perceptual: 10.0
30 | lambda_perceptual_end: ${model.losses.lambda_perceptual}
31 | lambda_perceptual_decay_steps: 1
32 |
33 | lambda_r1: 0
34 | lambda_r1_end: ${model.losses.lambda_r1}
35 | lambda_r1_decay_steps: 1
36 |
37 | lambda_kld: 0
38 | lambda_kld_end: ${model.losses.lambda_kld}
39 | lambda_kld_decay_steps: 1
40 | lazy_r1_step: 16
41 | gan_loss_type: hinge
42 | gan_loss_on_swaps: False
43 | use_adaptive_lambda : False
--------------------------------------------------------------------------------
/configs/model/spade.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - disc_augments: null
3 | - generator: spade
4 | - discriminator: patchgan
5 | - encoder: spade
6 | - optim: adam
7 |
8 | change_init: True
9 | gradient_clip_val: 0
10 | name: SPADE
11 | dataset_path: ${dataset.path}
12 | losses:
13 | lambda_gan: 1
14 | lambda_gan_end: ${model.losses.lambda_gan}
15 | lambda_gan_decay_steps: 1
16 |
17 | lambda_fm: 10.0
18 | lambda_fm_end: ${model.losses.lambda_fm}
19 | lambda_fm_decay_steps: 1
20 |
21 | lambda_label_mix: 0
22 | lambda_label_mix_end: ${model.losses.lambda_label_mix}
23 | lambda_label_mix_decay_steps: 1
24 |
25 | lambda_l1: 0
26 | lambda_l1_end: ${model.losses.lambda_l1}
27 | lambda_l1_decay_steps: 1
28 |
29 | lambda_perceptual: 10.0
30 | lambda_perceptual_end: ${model.losses.lambda_perceptual}
31 | lambda_perceptual_decay_steps: 1
32 |
33 | lambda_r1: 0
34 | lambda_r1_end: ${model.losses.lambda_r1}
35 | lambda_r1_decay_steps: 1
36 |
37 | lambda_kld: 0.05
38 | lambda_kld_end: ${model.losses.lambda_kld}
39 | lambda_kld_decay_steps: 1
40 | lazy_r1_step: 16
41 | gan_loss_type: hinge
42 | gan_loss_on_swaps: False
43 | use_adaptive_lambda: False
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/data/__init__.py
--------------------------------------------------------------------------------
/data/datamodule.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torchvision.transforms as transforms
3 | from torch.utils.data import DataLoader
4 | from pathlib import Path
5 | from .image_dataset import ImageDataset
6 | import albumentations as A
7 | from hydra.utils import instantiate
8 |
9 |
10 | class ImageDataModule(pl.LightningDataModule):
11 | """
12 | Module to load image data
13 | """
14 |
15 | def __init__(self, cfg):
16 | super().__init__()
17 | self.cfg = cfg
18 |
19 | def setup(self, stage=None):
20 |
21 | train_transforms = instantiate(self.cfg.train_aug)
22 | val_transforms = instantiate(self.cfg.val_aug)
23 | train_path = Path(self.cfg.path) / Path("train")
24 | test_path = Path(self.cfg.path) / Path("test")
25 |
26 | self.train_dataset = ImageDataset(
27 | train_path,
28 | self.cfg.num_labels_orig,
29 | train_transforms,
30 | image_extension=self.cfg.image_extension,
31 | label_merge_strat=self.cfg.label_merge_strat,
32 | )
33 | self.test_dataset = ImageDataset(
34 | test_path,
35 | self.cfg.num_labels_orig,
36 | val_transforms,
37 | image_extension=self.cfg.image_extension,
38 | label_merge_strat=self.cfg.label_merge_strat,
39 | )
40 |
41 | def train_dataloader(self):
42 | return DataLoader(
43 | self.train_dataset,
44 | batch_size=self.cfg.batch_size,
45 | shuffle=True,
46 | pin_memory=True,
47 | num_workers=self.cfg.num_workers,
48 | )
49 |
50 | def val_dataloader(self):
51 | return DataLoader(
52 | self.test_dataset,
53 | batch_size=self.cfg.batch_size,
54 | shuffle=False,
55 | pin_memory=True,
56 | num_workers=self.cfg.num_workers,
57 | )
58 |
59 | def test_dataloader(self):
60 | return DataLoader(
61 | self.test_dataset,
62 | batch_size=self.cfg.batch_size,
63 | shuffle=False,
64 | pin_memory=True,
65 | num_workers=self.cfg.num_workers,
66 | )
67 |
--------------------------------------------------------------------------------
/data/image_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 | import os
4 | from pathlib import Path
5 | import PIL
6 | from PIL import Image
7 | import numpy as np
8 | import torch.nn.functional as F
9 |
10 |
11 | class ImageDataset(Dataset):
12 | """
13 |     Image dataset with paired segmentation label maps (images can be jpg or png).
14 | """
15 |
16 | def __init__(
17 | self,
18 | path,
19 | num_labels,
20 | transforms=None,
21 | image_extension="jpg",
22 | label_merge_strat="none",
23 | ):
24 | super().__init__()
25 | self.image_dir = Path(path) / Path("images")
26 | self.labels = Path(path) / Path("labels")
27 |
28 | self.image_extension = image_extension
29 |
30 | images_list = sorted(os.listdir(self.image_dir))
31 |
32 | self.image_list_filtered = []
33 |
34 | for image_name in images_list:
35 | if image_name.endswith(self.image_extension):
36 | self.image_list_filtered.append(image_name.split(".")[0])
37 |
38 | self.transforms = transforms
39 | self.num_labels = num_labels
40 | self.label_merge_strat = label_merge_strat
41 |
42 | def __getitem__(self, index):
43 | image_name = self.image_list_filtered[index]
44 | image = np.array(
45 | Image.open(
46 | self.image_dir / Path(f"{image_name}.{self.image_extension}")
47 | ).convert("RGB")
48 | )
49 | segmentation_mask = np.array(
50 | Image.open(self.labels / Path(f"{image_name}.png")).resize(
51 | image.shape[:-1][::-1], resample=Image.NEAREST
52 | )
53 | )
54 | if self.transforms:
55 | augmented = self.transforms(image=image, mask=segmentation_mask)
56 | image = augmented["image"]
57 | segmentation_mask = augmented["mask"]
58 | segmentation_mask = self.onehot_encode_labels(segmentation_mask)
59 | if self.label_merge_strat == "body_background":
60 | background = segmentation_mask[0]
61 | body = segmentation_mask[1:].max(dim=0).values
62 | segmentation_mask = torch.stack([background, body], dim=0)
63 |
64 | elif self.label_merge_strat == "body_face_background":
65 | background = segmentation_mask[0]
66 | face = (
67 | torch.stack(
68 | [
69 | segmentation_mask[1],
70 | segmentation_mask[2],
71 | segmentation_mask[4],
72 | segmentation_mask[13],
73 | ],
74 | dim=0,
75 | )
76 | .max(dim=0)
77 | .values
78 | )
79 | segmentation_mask[4] = torch.zeros_like(segmentation_mask[4])
80 | segmentation_mask[13] = torch.zeros_like(segmentation_mask[13])
81 | body = segmentation_mask[3:].max(dim=0).values
82 | segmentation_mask = torch.stack([background, face, body], dim=0)
83 | return image, segmentation_mask
84 |
85 | def __len__(self):
86 | return len(self.image_list_filtered)
87 |
88 | def onehot_encode_labels(self, labels):
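        # Turn an (H, W) integer label map into a one-hot tensor of shape (num_labels, H, W).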
89 | height, width = labels.shape
90 | labels = labels.unsqueeze(0).long()
91 | input_label = torch.FloatTensor(self.num_labels, height, width).zero_()
92 | input_semantics = input_label.scatter_(0, labels, 1.0)
93 | return input_semantics
94 |
--------------------------------------------------------------------------------
/datasets/.empty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/datasets/.empty
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: scam
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - numpy=1.21.2
7 | - pillow=9.0.1
8 | - python=3.9.12
9 | - pytorch=1.12
10 | - pip=21.2.4
11 | - pip:
12 | - torchvision==0.13
13 | - albumentations==1.1.0
14 | - einops==0.4.1
15 | - hydra-core==1.1.1
16 | - kornia==0.6.4
17 | - pytorch-lightning==1.6.0
18 | - torch-fidelity==0.3.0
19 | - torchmetrics==0.7.3
20 | - torchtyping==0.1.4
21 | - tqdm==4.64.0
22 | - wandb==0.12.4
--------------------------------------------------------------------------------
/medias/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/medias/.DS_Store
--------------------------------------------------------------------------------
/medias/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/medias/teaser.png
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/metrics/__init__.py
--------------------------------------------------------------------------------
/metrics/fid.py:
--------------------------------------------------------------------------------
1 | from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3
2 | from typing import Any, Callable, List, Optional, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from torch import Tensor
8 | from torch.autograd import Function
9 | from torchvision.transforms import Resize
10 | from pathlib import Path
11 | from torchmetrics.metric import Metric
12 | from tqdm import tqdm
13 |
14 | from utils.utils import remap_image_torch
15 |
16 |
17 | class MatrixSquareRoot(Function):
18 | """Square root of a positive definite matrix.
19 | All credit to:
20 | https://github.com/steveli/pytorch-sqrtm/blob/master/sqrtm.py
21 | """
22 |
23 | @staticmethod
24 | def forward(ctx: Any, input: Tensor) -> Tensor:
25 | import scipy
26 |
27 | m = input.detach().cpu().numpy().astype(np.float_)
28 | scipy_res, _ = scipy.linalg.sqrtm(m, disp=False)
29 | sqrtm = torch.from_numpy(scipy_res.real).to(input)
30 | ctx.save_for_backward(sqrtm)
31 | return sqrtm
32 |
33 | @staticmethod
34 | def backward(ctx: Any, grad_output: Tensor) -> Tensor:
35 | import scipy
36 |
37 | grad_input = None
38 | if ctx.needs_input_grad[0]:
39 | (sqrtm,) = ctx.saved_tensors
40 | sqrtm = sqrtm.data.cpu().numpy().astype(np.float_)
41 | gm = grad_output.data.cpu().numpy().astype(np.float_)
42 |
43 | # Given a positive semi-definite matrix X,
44 | # since X = X^{1/2}X^{1/2}, we can compute the gradient of the
45 | # matrix square root dX^{1/2} by solving the Sylvester equation:
46 | # dX = (d(X^{1/2})X^{1/2} + X^{1/2}(dX^{1/2}).
47 | grad_sqrtm = scipy.linalg.solve_sylvester(sqrtm, sqrtm, gm)
48 |
49 | grad_input = torch.from_numpy(grad_sqrtm).to(grad_output)
50 | return grad_input
51 |
52 |
53 | sqrtm = MatrixSquareRoot.apply
54 |
55 |
56 | class NoTrainInceptionV3(FeatureExtractorInceptionV3):
57 | def __init__(
58 | self,
59 | name: str,
60 | features_list: List[str],
61 | feature_extractor_weights_path: Optional[str] = None,
62 | ) -> None:
63 | super().__init__(name, features_list, feature_extractor_weights_path)
64 | self.eval()
65 | self.num_out_feat = int(features_list[0])
66 |
67 | def train(self, mode):
68 | return super().train(False)
69 |
70 | def forward(self, x: Tensor) -> Tensor:
71 | x = remap_image_torch(x)
72 | out = super().forward(x)
73 | return out[0].reshape(x.shape[0], -1)
74 |
75 |
76 | def compute_fid(mu_1, mu_2, sigma_1, sigma_2, eps=1e-6):
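    # Fréchet distance between Gaussians N(mu_1, sigma_1) and N(mu_2, sigma_2):
    # ||mu_1 - mu_2||^2 + Tr(sigma_1 + sigma_2 - 2 * (sigma_1 @ sigma_2)^(1/2))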
77 | mean_diff = mu_1 - mu_2
78 |
79 | mean_dist = mean_diff.dot(mean_diff)
80 |
81 | covmean = sqrtm(sigma_1.mm(sigma_2))
82 |
83 | if not torch.isfinite(covmean).all():
84 | offset = torch.eye(sigma_1.size(0), device=mu_1.device, dtype=mu_1.dtype) * eps
85 | covmean = sqrtm((sigma_1 + offset).mm(sigma_2 + offset))
86 | return (
87 | mean_dist
88 | + torch.trace(sigma_1)
89 | + torch.trace(sigma_2)
90 | - 2 * torch.trace(covmean)
91 | )
92 |
93 |
94 | class FID(Metric):
95 | def __init__(
96 | self,
97 | feature_extractor,
98 | real_features_path,
99 | compute_on_step: bool = False,
100 | dist_sync_on_step: bool = False,
101 | process_group: Optional[Any] = None,
102 | dist_sync_fn: Callable = None,
103 | ):
104 |
105 | super().__init__(
106 | compute_on_step=compute_on_step,
107 | dist_sync_on_step=dist_sync_on_step,
108 | process_group=process_group,
109 | dist_sync_fn=dist_sync_fn,
110 | )
111 |
112 | self.feature_extractor = feature_extractor
113 |
114 | mean_real = torch.load(f"{real_features_path}/mean.pt")
115 | sigma_real = torch.load(f"{real_features_path}/sigma.pt")
116 |
117 | self.add_state("mean_real", default=mean_real, dist_reduce_fx="mean")
118 | self.add_state("sigma_real", default=sigma_real, dist_reduce_fx="mean")
119 |
120 | self.add_state(
121 | "generated_features_sum",
122 | torch.zeros(self.feature_extractor.num_out_feat, dtype=torch.double),
123 | dist_reduce_fx="sum",
124 | )
125 | self.add_state(
126 | "generated_features_cov_sum",
127 | torch.zeros(
128 | (
129 | self.feature_extractor.num_out_feat,
130 | self.feature_extractor.num_out_feat,
131 | ),
132 | dtype=torch.double,
133 | ),
134 | dist_reduce_fx="sum",
135 | )
136 | self.add_state(
137 | "generated_features_num_samples",
138 | torch.tensor(0).long(),
139 | dist_reduce_fx="sum",
140 | )
141 |
142 | def update(self, images):
143 | features = self.feature_extractor(images).double()
144 | self.generated_features_sum += features.sum(dim=0)
145 | self.generated_features_cov_sum += features.t().mm(features)
146 | self.generated_features_num_samples += images.shape[0]
147 |
148 | def compute(self):
149 | mean_real = self.mean_real
150 | mean_generated = (
151 | self.generated_features_sum / self.generated_features_num_samples
152 | ).unsqueeze(dim=0)
153 |
154 | sigma_real = self.sigma_real
155 | sigma_generated = (
156 | self.generated_features_cov_sum
157 | - self.generated_features_num_samples
158 | * mean_generated.t().mm(mean_generated)
159 | ) / (self.generated_features_num_samples - 1)
160 | return compute_fid(mean_real, mean_generated[0], sigma_real, sigma_generated)
161 |
162 |
163 | def compute_fid_features(dataloader, save_path, device="cpu"):
164 | feature_extractor = NoTrainInceptionV3(
165 | name="inception-v3-compat", features_list=[str(2048)]
166 | ).to(device)
167 | real_features_sum = torch.zeros(
168 | feature_extractor.num_out_feat, dtype=torch.double, device=device
169 | )
170 | real_features_cov_sum = torch.zeros(
171 | (
172 | feature_extractor.num_out_feat,
173 | feature_extractor.num_out_feat,
174 | ),
175 | dtype=torch.double,
176 | device=device,
177 | )
178 | real_features_num_samples = 0
179 | with torch.no_grad():
180 | for i, batch in enumerate(tqdm(dataloader)):
181 | images, _ = batch
182 | images = images.to(device)
183 | features = feature_extractor(images).double()
184 | real_features_num_samples += features.shape[0]
185 | real_features_sum += features.sum(dim=0)
186 | real_features_cov_sum += features.t().mm(features)
187 | mean = (real_features_sum / real_features_num_samples).unsqueeze(0)
188 | sigma = (real_features_cov_sum - real_features_num_samples * mean.t().mm(mean)) / (
189 | real_features_num_samples - 1
190 | )
191 | mean = mean.squeeze(0)
192 | torch.save(mean.cpu(), Path(save_path) / Path("mean.pt"))
193 | torch.save(sigma.cpu(), Path(save_path) / Path("sigma.pt"))
194 |
--------------------------------------------------------------------------------
/metrics/models/fastreid/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 |
8 | __version__ = "0.1.0"
--------------------------------------------------------------------------------
/metrics/models/fastreid/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .activation import *
8 | from .arc_softmax import ArcSoftmax
9 | from .circle_softmax import CircleSoftmax
10 | from .am_softmax import AMSoftmax
11 | from .batch_drop import BatchDrop
12 | from .batch_norm import *
13 | from .context_block import ContextBlock
14 | from .frn import FRN, TLU
15 | from .non_local import Non_local
16 | from .pooling import *
17 | from .se_layer import SELayer
18 | from .splat import SplAtConv2d
19 | from .gather_layer import GatherLayer
20 |
--------------------------------------------------------------------------------
/metrics/models/fastreid/layers/activation.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: xingyu liao
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import math
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | __all__ = [
14 | 'Mish',
15 | 'Swish',
16 | 'MemoryEfficientSwish',
17 | 'GELU']
18 |
19 |
20 | class Mish(nn.Module):
21 | def __init__(self):
22 | super().__init__()
23 |
24 | def forward(self, x):
25 | # inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!)
26 | return x * (torch.tanh(F.softplus(x)))
27 |
28 |
29 | class Swish(nn.Module):
30 | def forward(self, x):
31 | return x * torch.sigmoid(x)
32 |
33 |
34 | class SwishImplementation(torch.autograd.Function):
35 | @staticmethod
36 | def forward(ctx, i):
37 | result = i * torch.sigmoid(i)
38 | ctx.save_for_backward(i)
39 | return result
40 |
41 | @staticmethod
42 | def backward(ctx, grad_output):
43 | i = ctx.saved_variables[0]
44 | sigmoid_i = torch.sigmoid(i)
45 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
46 |
47 |
48 | class MemoryEfficientSwish(nn.Module):
49 | def forward(self, x):
50 | return SwishImplementation.apply(x)
51 |
52 |
53 | class GELU(nn.Module):
54 | """
55 |     Paper Section 3.4, last paragraph: note that BERT uses GELU instead of ReLU.
56 | """
57 |
58 | def forward(self, x):
59 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
60 |
--------------------------------------------------------------------------------
/metrics/models/fastreid/layers/am_softmax.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: xingyu liao
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 | from torch.nn import Parameter
11 |
12 |
13 | class AMSoftmax(nn.Module):
14 |     r"""Implementation of large margin cosine distance.
15 | Args:
16 | in_feat: size of each input sample
17 | num_classes: size of each output sample
18 | """
19 |
20 | def __init__(self, cfg, in_feat, num_classes):
21 | super().__init__()
22 | self.in_features = in_feat
23 | self._num_classes = num_classes
24 | self.s = cfg.MODEL.HEADS.SCALE
25 | self.m = cfg.MODEL.HEADS.MARGIN
26 | self.weight = Parameter(torch.Tensor(num_classes, in_feat))
27 | nn.init.xavier_uniform_(self.weight)
28 |
29 | def forward(self, features, targets):
30 | # --------------------------- cos(theta) & phi(theta) ---------------------------
31 | cosine = F.linear(F.normalize(features), F.normalize(self.weight))
32 | phi = cosine - self.m
33 | # --------------------------- convert label to one-hot ---------------------------
34 | targets = F.one_hot(targets, num_classes=self._num_classes)
35 | output = (targets * phi) + ((1.0 - targets) * cosine)
36 | output *= self.s
37 |
38 | return output
39 |
40 | def extra_repr(self):
41 | return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
42 |             self.in_features, self._num_classes, self.s, self.m
43 | )
44 |
--------------------------------------------------------------------------------
/metrics/models/fastreid/layers/arc_softmax.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import math
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn import Parameter
13 |
14 |
15 | class ArcSoftmax(nn.Module):
16 | def __init__(self, cfg, in_feat, num_classes):
17 | super().__init__()
18 | self.in_feat = in_feat
19 | self._num_classes = num_classes
20 | self.s = cfg.MODEL.HEADS.SCALE
21 | self.m = cfg.MODEL.HEADS.MARGIN
22 |
23 | self.cos_m = math.cos(self.m)
24 | self.sin_m = math.sin(self.m)
25 | self.threshold = math.cos(math.pi - self.m)
26 | self.mm = math.sin(math.pi - self.m) * self.m
27 |
28 | self.weight = Parameter(torch.Tensor(num_classes, in_feat))
29 | nn.init.xavier_uniform_(self.weight)
30 | self.register_buffer('t', torch.zeros(1))
31 |
32 | def forward(self, features, targets):
33 | # get cos(theta)
34 | cos_theta = F.linear(F.normalize(features), F.normalize(self.weight))
35 | cos_theta = cos_theta.clamp(-1, 1) # for numerical stability
36 |
37 | target_logit = cos_theta[torch.arange(0, features.size(0)), targets].view(-1, 1)
38 |
39 | sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
40 | cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin)
41 | mask = cos_theta > cos_theta_m
42 | final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm)
43 |
44 | hard_example = cos_theta[mask]
45 | with torch.no_grad():
46 | self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
47 | cos_theta[mask] = hard_example * (self.t + hard_example)
48 | cos_theta.scatter_(1, targets.view(-1, 1).long(), final_target_logit)
49 | pred_class_logits = cos_theta * self.s
50 | return pred_class_logits
51 |
52 | def extra_repr(self):
53 | return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
54 | self.in_feat, self._num_classes, self.s, self.m
55 | )
56 |
--------------------------------------------------------------------------------
/metrics/models/fastreid/layers/batch_drop.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import random
8 |
9 | from torch import nn
10 |
11 |
12 | class BatchDrop(nn.Module):
13 | """ref: https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
14 | batch drop mask
15 | """
16 |
17 | def __init__(self, h_ratio, w_ratio):
18 | super(BatchDrop, self).__init__()
19 | self.h_ratio = h_ratio
20 | self.w_ratio = w_ratio
21 |
22 | def forward(self, x):
23 | if self.training:
24 | h, w = x.size()[-2:]
25 | rh = round(self.h_ratio * h)
26 | rw = round(self.w_ratio * w)
27 | sx = random.randint(0, h - rh)
28 | sy = random.randint(0, w - rw)
29 | mask = x.new_ones(x.size())
30 | mask[:, :, sx:sx + rh, sy:sy + rw] = 0
31 | x = x * mask
32 | return x
33 |
--------------------------------------------------------------------------------
/metrics/models/fastreid/layers/batch_norm.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import logging
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn
12 |
13 | __all__ = [
14 | "BatchNorm",
15 | "IBN",
16 | "GhostBatchNorm",
17 | "FrozenBatchNorm",
18 | "SyncBatchNorm",
19 | "get_norm",
20 | ]
21 |
22 |
23 | class BatchNorm(nn.BatchNorm2d):
24 | def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
25 | bias_init=0.0, **kwargs):
26 | super().__init__(num_features, eps=eps, momentum=momentum)
27 | if weight_init is not None: nn.init.constant_(self.weight, weight_init)
28 | if bias_init is not None: nn.init.constant_(self.bias, bias_init)
29 | self.weight.requires_grad_(not weight_freeze)
30 | self.bias.requires_grad_(not bias_freeze)
31 |
32 |
33 | class SyncBatchNorm(nn.SyncBatchNorm):
34 | def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
35 | bias_init=0.0):
36 | super().__init__(num_features, eps=eps, momentum=momentum)
37 | if weight_init is not None: nn.init.constant_(self.weight, weight_init)
38 | if bias_init is not None: nn.init.constant_(self.bias, bias_init)
39 | self.weight.requires_grad_(not weight_freeze)
40 | self.bias.requires_grad_(not bias_freeze)
41 |
42 |
43 | class IBN(nn.Module):
44 | def __init__(self, planes, bn_norm, **kwargs):
45 | super(IBN, self).__init__()
46 | half1 = int(planes / 2)
47 | self.half = half1
48 | half2 = planes - half1
49 | self.IN = nn.InstanceNorm2d(half1, affine=True)
50 | self.BN = get_norm(bn_norm, half2, **kwargs)
51 |
52 | def forward(self, x):
53 | split = torch.split(x, self.half, 1)
54 | out1 = self.IN(split[0].contiguous())
55 | out2 = self.BN(split[1].contiguous())
56 | out = torch.cat((out1, out2), 1)
57 | return out
58 |
59 |
60 | class GhostBatchNorm(BatchNorm):
61 | def __init__(self, num_features, num_splits=1, **kwargs):
62 | super().__init__(num_features, **kwargs)
63 | self.num_splits = num_splits
64 | self.register_buffer('running_mean', torch.zeros(num_features))
65 | self.register_buffer('running_var', torch.ones(num_features))
66 |
67 | def forward(self, input):
68 | N, C, H, W = input.shape
69 | if self.training or not self.track_running_stats:
70 | self.running_mean = self.running_mean.repeat(self.num_splits)
71 | self.running_var = self.running_var.repeat(self.num_splits)
72 | outputs = F.batch_norm(
73 | input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
74 | self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
75 | True, self.momentum, self.eps).view(N, C, H, W)
76 | self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
77 | self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
78 | return outputs
79 | else:
80 | return F.batch_norm(
81 | input, self.running_mean, self.running_var,
82 | self.weight, self.bias, False, self.momentum, self.eps)
83 |
84 |
85 | class FrozenBatchNorm(BatchNorm):
86 | """
87 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
88 | It contains non-trainable buffers called
89 | "weight" and "bias", "running_mean", "running_var",
90 | initialized to perform identity transformation.
91 | The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
92 | which are computed from the original four parameters of BN.
93 | The affine transform `x * weight + bias` will perform the equivalent
94 | computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
95 | When loading a backbone model from Caffe2, "running_mean" and "running_var"
96 | will be left unchanged as identity transformation.
97 | Other pre-trained backbone models may contain all 4 parameters.
98 | The forward is implemented by `F.batch_norm(..., training=False)`.
99 | """
100 |
101 | _version = 3
102 |
103 | def __init__(self, num_features, eps=1e-5, **kwargs):
104 | super().__init__(num_features, weight_freeze=True, bias_freeze=True, **kwargs)
105 | self.num_features = num_features
106 | self.eps = eps
107 |
108 | def forward(self, x):
109 | if x.requires_grad:
110 | # When gradients are needed, F.batch_norm will use extra memory
111 | # because its backward op computes gradients for weight/bias as well.
112 | scale = self.weight * (self.running_var + self.eps).rsqrt()
113 | bias = self.bias - self.running_mean * scale
114 | scale = scale.reshape(1, -1, 1, 1)
115 | bias = bias.reshape(1, -1, 1, 1)
116 | return x * scale + bias
117 | else:
118 | # When gradients are not needed, F.batch_norm is a single fused op
119 | # and provide more optimization opportunities.
120 | return F.batch_norm(
121 | x,
122 | self.running_mean,
123 | self.running_var,
124 | self.weight,
125 | self.bias,
126 | training=False,
127 | eps=self.eps,
128 | )
129 |
130 | def _load_from_state_dict(
131 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
132 | ):
133 | version = local_metadata.get("version", None)
134 |
135 | if version is None or version < 2:
136 | # No running_mean/var in early versions
137 |                 # This silences the warnings
138 | if prefix + "running_mean" not in state_dict:
139 | state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
140 | if prefix + "running_var" not in state_dict:
141 | state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)
142 |
143 | if version is not None and version < 3:
144 | logger = logging.getLogger(__name__)
145 | logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
146 |             # In version < 3, running_var is used without +eps.
147 | state_dict[prefix + "running_var"] -= self.eps
148 |
149 | super()._load_from_state_dict(
150 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
151 | )
152 |
153 | def __repr__(self):
154 | return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
155 |
156 | @classmethod
157 | def convert_frozen_batchnorm(cls, module):
158 | """
159 | Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
160 | Args:
161 | module (torch.nn.Module):
162 | Returns:
163 | If module is BatchNorm/SyncBatchNorm, returns a new module.
164 | Otherwise, in-place convert module and return it.
165 | Similar to convert_sync_batchnorm in
166 | https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
167 | """
168 | bn_module = nn.modules.batchnorm
169 | bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
170 | res = module
171 | if isinstance(module, bn_module):
172 | res = cls(module.num_features)
173 | if module.affine:
174 | res.weight.data = module.weight.data.clone().detach()
175 | res.bias.data = module.bias.data.clone().detach()
176 | res.running_mean.data = module.running_mean.data
177 | res.running_var.data = module.running_var.data
178 | res.eps = module.eps
179 | else:
180 | for name, child in module.named_children():
181 | new_child = cls.convert_frozen_batchnorm(child)
182 | if new_child is not child:
183 | res.add_module(name, new_child)
184 | return res
185 |
186 |
187 | def get_norm(norm, out_channels, **kwargs):
188 | """
189 | Args:
190 | norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN;
191 |             or a callable that takes a channel number and returns
192 | the normalization layer as a nn.Module
193 | out_channels: number of channels for normalization layer
194 |
195 | Returns:
196 | nn.Module or None: the normalization layer
197 | """
198 | if isinstance(norm, str):
199 | if len(norm) == 0:
200 | return None
201 | norm = {
202 | "BN": BatchNorm,
203 | "GhostBN": GhostBatchNorm,
204 | "FrozenBN": FrozenBatchNorm,
205 | "GN": lambda channels, **args: nn.GroupNorm(32, channels),
206 | "syncBN": SyncBatchNorm,
207 | }[norm]
208 | return norm(out_channels, **kwargs)
209 |
--------------------------------------------------------------------------------
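A small sketch (not part of the repository) of the two entry points above: the `get_norm` factory the backbone uses to build normalization layers from config strings, and `FrozenBatchNorm.convert_frozen_batchnorm` for freezing the BatchNorm layers of an existing network:

```python
import torch
from torch import nn

from metrics.models.fastreid.layers.batch_norm import get_norm, FrozenBatchNorm

# The factory maps config strings to norm layers.
bn = get_norm("BN", 64)          # BatchNorm with trainable affine parameters
gn = get_norm("GN", 64)          # GroupNorm with 32 groups
frozen = get_norm("FrozenBN", 64)

x = torch.randn(2, 64, 8, 8)
print(bn(x).shape, gn(x).shape, frozen(x).shape)

# Freeze every BatchNorm/SyncBatchNorm in an existing module, e.g. before fine-tuning.
net = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU())
net = FrozenBatchNorm.convert_frozen_batchnorm(net)
print(net[1])  # FrozenBatchNorm2d(num_features=64, eps=...)
```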
/metrics/models/fastreid/layers/circle_softmax.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import math
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn import Parameter
13 |
14 |
15 | class CircleSoftmax(nn.Module):
16 | def __init__(self, cfg, in_feat, num_classes):
17 | super().__init__()
18 | self.in_feat = in_feat
19 | self._num_classes = num_classes
20 | self.s = cfg.MODEL.HEADS.SCALE
21 | self.m = cfg.MODEL.HEADS.MARGIN
22 |
23 | self.weight = Parameter(torch.Tensor(num_classes, in_feat))
24 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
25 |
26 | def forward(self, features, targets):
27 | sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))
28 | alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.)
29 | alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.)
30 | delta_p = 1 - self.m
31 | delta_n = self.m
32 |
33 | s_p = self.s * alpha_p * (sim_mat - delta_p)
34 | s_n = self.s * alpha_n * (sim_mat - delta_n)
35 |
36 | targets = F.one_hot(targets, num_classes=self._num_classes)
37 |
38 | pred_class_logits = targets * s_p + (1.0 - targets) * s_n
39 |
40 | return pred_class_logits
41 |
42 | def extra_repr(self):
43 | return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
44 | self.in_feat, self._num_classes, self.s, self.m
45 | )
46 |
--------------------------------------------------------------------------------
/metrics/models/fastreid/layers/context_block.py:
--------------------------------------------------------------------------------
1 | # copy from https://github.com/xvjiarui/GCNet/blob/master/mmdet/ops/gcb/context_block.py
2 |
3 | import torch
4 | from torch import nn
5 |
6 | __all__ = ['ContextBlock']
7 |
8 |
9 | def last_zero_init(m):
10 | if isinstance(m, nn.Sequential):
11 | nn.init.constant_(m[-1].weight, val=0)
12 | if hasattr(m[-1], 'bias') and m[-1].bias is not None:
13 | nn.init.constant_(m[-1].bias, 0)
14 | else:
15 | nn.init.constant_(m.weight, val=0)
16 | if hasattr(m, 'bias') and m.bias is not None:
17 | nn.init.constant_(m.bias, 0)
18 |
19 |
20 | class ContextBlock(nn.Module):
21 |
22 | def __init__(self,
23 | inplanes,
24 | ratio,
25 | pooling_type='att',
26 | fusion_types=('channel_add',)):
27 | super(ContextBlock, self).__init__()
28 | assert pooling_type in ['avg', 'att']
29 | assert isinstance(fusion_types, (list, tuple))
30 | valid_fusion_types = ['channel_add', 'channel_mul']
31 | assert all([f in valid_fusion_types for f in fusion_types])
32 | assert len(fusion_types) > 0, 'at least one fusion should be used'
33 | self.inplanes = inplanes
34 | self.ratio = ratio
35 | self.planes = int(inplanes * ratio)
36 | self.pooling_type = pooling_type
37 | self.fusion_types = fusion_types
38 | if pooling_type == 'att':
39 | self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
40 | self.softmax = nn.Softmax(dim=2)
41 | else:
42 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
43 | if 'channel_add' in fusion_types:
44 | self.channel_add_conv = nn.Sequential(
45 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
46 | nn.LayerNorm([self.planes, 1, 1]),
47 | nn.ReLU(inplace=True), # yapf: disable
48 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
49 | else:
50 | self.channel_add_conv = None
51 | if 'channel_mul' in fusion_types:
52 | self.channel_mul_conv = nn.Sequential(
53 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
54 | nn.LayerNorm([self.planes, 1, 1]),
55 | nn.ReLU(inplace=True), # yapf: disable
56 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
57 | else:
58 | self.channel_mul_conv = None
59 | self.reset_parameters()
60 |
61 | def reset_parameters(self):
62 | if self.pooling_type == 'att':
63 | nn.init.kaiming_normal_(self.conv_mask.weight, a=0, mode='fan_in', nonlinearity='relu')
64 | if hasattr(self.conv_mask, 'bias') and self.conv_mask.bias is not None:
65 | nn.init.constant_(self.conv_mask.bias, 0)
66 | self.conv_mask.inited = True
67 |
68 | if self.channel_add_conv is not None:
69 | last_zero_init(self.channel_add_conv)
70 | if self.channel_mul_conv is not None:
71 | last_zero_init(self.channel_mul_conv)
72 |
73 | def spatial_pool(self, x):
74 | batch, channel, height, width = x.size()
75 | if self.pooling_type == 'att':
76 | input_x = x
77 | # [N, C, H * W]
78 | input_x = input_x.view(batch, channel, height * width)
79 | # [N, 1, C, H * W]
80 | input_x = input_x.unsqueeze(1)
81 | # [N, 1, H, W]
82 | context_mask = self.conv_mask(x)
83 | # [N, 1, H * W]
84 | context_mask = context_mask.view(batch, 1, height * width)
85 | # [N, 1, H * W]
86 | context_mask = self.softmax(context_mask)
87 | # [N, 1, H * W, 1]
88 | context_mask = context_mask.unsqueeze(-1)
89 | # [N, 1, C, 1]
90 | context = torch.matmul(input_x, context_mask)
91 | # [N, C, 1, 1]
92 | context = context.view(batch, channel, 1, 1)
93 | else:
94 | # [N, C, 1, 1]
95 | context = self.avg_pool(x)
96 |
97 | return context
98 |
99 | def forward(self, x):
100 | # [N, C, 1, 1]
101 | context = self.spatial_pool(x)
102 |
103 | out = x
104 | if self.channel_mul_conv is not None:
105 | # [N, C, 1, 1]
106 | channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
107 | out = out * channel_mul_term
108 | if self.channel_add_conv is not None:
109 | # [N, C, 1, 1]
110 | channel_add_term = self.channel_add_conv(context)
111 | out = out + channel_add_term
112 |
113 | return out
114 |
--------------------------------------------------------------------------------
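A quick sketch (not in the repository) showing that the global-context block above is shape-preserving and, because its last convolution is zero-initialized, starts out as an identity mapping:

```python
import torch

from metrics.models.fastreid.layers.context_block import ContextBlock

gcb = ContextBlock(inplanes=64, ratio=0.25)  # default: attention pooling + channel_add fusion
x = torch.randn(2, 64, 16, 8)
y = gcb(x)
print(y.shape)               # torch.Size([2, 64, 16, 8])
print(torch.allclose(y, x))  # expected: True at initialization (last conv is zero-initialized)
```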
/metrics/models/fastreid/layers/frn.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn.modules.batchnorm import BatchNorm2d
10 | from torch.nn import ReLU, LeakyReLU
11 | from torch.nn.parameter import Parameter
12 |
13 |
14 | class TLU(nn.Module):
15 | def __init__(self, num_features):
16 | """max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau"""
17 | super(TLU, self).__init__()
18 | self.num_features = num_features
19 | self.tau = Parameter(torch.Tensor(num_features))
20 | self.reset_parameters()
21 |
22 | def reset_parameters(self):
23 | nn.init.zeros_(self.tau)
24 |
25 | def extra_repr(self):
26 | return 'num_features={num_features}'.format(**self.__dict__)
27 |
28 | def forward(self, x):
29 | return torch.max(x, self.tau.view(1, self.num_features, 1, 1))
30 |
31 |
32 | class FRN(nn.Module):
33 | def __init__(self, num_features, eps=1e-6, is_eps_leanable=False):
34 | """
35 | weight = gamma, bias = beta
36 | beta, gamma:
37 |             Variables of shape [1, 1, 1, C] in TensorFlow,
38 |             and of shape [1, C, 1, 1] in PyTorch.
39 | eps: A scalar constant or learnable variable.
40 | """
41 | super(FRN, self).__init__()
42 |
43 | self.num_features = num_features
44 | self.init_eps = eps
45 | self.is_eps_leanable = is_eps_leanable
46 |
47 | self.weight = Parameter(torch.Tensor(num_features))
48 | self.bias = Parameter(torch.Tensor(num_features))
49 | if is_eps_leanable:
50 | self.eps = Parameter(torch.Tensor(1))
51 | else:
52 | self.register_buffer('eps', torch.Tensor([eps]))
53 | self.reset_parameters()
54 |
55 | def reset_parameters(self):
56 | nn.init.ones_(self.weight)
57 | nn.init.zeros_(self.bias)
58 | if self.is_eps_leanable:
59 | nn.init.constant_(self.eps, self.init_eps)
60 |
61 | def extra_repr(self):
62 | return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__)
63 |
64 | def forward(self, x):
65 | """
66 | 0, 1, 2, 3 -> (B, H, W, C) in TensorFlow
67 | 0, 1, 2, 3 -> (B, C, H, W) in PyTorch
68 | TensorFlow code
69 | nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
70 | x = x * tf.rsqrt(nu2 + tf.abs(eps))
71 | # This Code include TLU function max(y, tau)
72 | return tf.maximum(gamma * x + beta, tau)
73 | """
74 | # Compute the mean norm of activations per channel.
75 | nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
76 |
77 | # Perform FRN.
78 | x = x * torch.rsqrt(nu2 + self.eps.abs())
79 |
80 | # Scale and Bias
81 | x = self.weight.view(1, self.num_features, 1, 1) * x + self.bias.view(1, self.num_features, 1, 1)
82 | # x = self.weight * x + self.bias
83 | return x
84 |
85 |
86 | def bnrelu_to_frn(module):
87 | """
88 | Convert 'BatchNorm2d + ReLU' to 'FRN + TLU'
89 | """
90 | mod = module
91 | before_name = None
92 | before_child = None
93 | is_before_bn = False
94 |
95 | for name, child in module.named_children():
96 | if is_before_bn and isinstance(child, (ReLU, LeakyReLU)):
97 | # Convert BN to FRN
98 | if isinstance(before_child, BatchNorm2d):
99 | mod.add_module(
100 | before_name, FRN(num_features=before_child.num_features))
101 | else:
102 | raise NotImplementedError()
103 |
104 | # Convert ReLU to TLU
105 | mod.add_module(name, TLU(num_features=before_child.num_features))
106 | else:
107 | mod.add_module(name, bnrelu_to_frn(child))
108 |
109 | before_name = name
110 | before_child = child
111 | is_before_bn = isinstance(child, BatchNorm2d)
112 | return mod
113 |
114 |
115 | def convert(module, flag_name):
116 | mod = module
117 | before_ch = None
118 | for name, child in module.named_children():
119 | if hasattr(child, flag_name) and getattr(child, flag_name):
120 | if isinstance(child, BatchNorm2d):
121 | before_ch = child.num_features
122 | mod.add_module(name, FRN(num_features=child.num_features))
123 | # TODO bn is no good...
124 | if isinstance(child, (ReLU, LeakyReLU)):
125 | mod.add_module(name, TLU(num_features=before_ch))
126 | else:
127 | mod.add_module(name, convert(child, flag_name))
128 | return mod
129 |
130 |
131 | def remove_flags(module, flag_name):
132 | mod = module
133 | for name, child in module.named_children():
134 | if hasattr(child, 'is_convert_frn'):
135 | delattr(child, flag_name)
136 | mod.add_module(name, remove_flags(child, flag_name))
137 | else:
138 | mod.add_module(name, remove_flags(child, flag_name))
139 | return mod
140 |
141 |
142 | def bnrelu_to_frn2(model, input_size=(3, 128, 128), batch_size=2, flag_name='is_convert_frn'):
143 |     forward_hooks = list()
144 | backward_hooks = list()
145 |
146 | is_before_bn = [False]
147 |
148 | def register_forward_hook(module):
149 | def hook(self, input, output):
150 | if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
151 | is_before_bn.append(False)
152 | return
153 |
154 |         # input and output are required by the hook signature
155 | is_converted = is_before_bn[-1] and isinstance(self, (ReLU, LeakyReLU))
156 | if is_converted:
157 | setattr(self, flag_name, True)
158 | is_before_bn.append(isinstance(self, BatchNorm2d))
159 |
160 |         forward_hooks.append(module.register_forward_hook(hook))
161 |
162 | is_before_relu = [False]
163 |
164 | def register_backward_hook(module):
165 | def hook(self, input, output):
166 | if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
167 | is_before_relu.append(False)
168 | return
169 | is_converted = is_before_relu[-1] and isinstance(self, BatchNorm2d)
170 | if is_converted:
171 | setattr(self, flag_name, True)
172 | is_before_relu.append(isinstance(self, (ReLU, LeakyReLU)))
173 |
174 | backward_hooks.append(module.register_backward_hook(hook))
175 |
176 | # multiple inputs to the network
177 | if isinstance(input_size, tuple):
178 | input_size = [input_size]
179 |
180 | # batch_size of 2 for batchnorm
181 | x = [torch.rand(batch_size, *in_size) for in_size in input_size]
182 |
183 | # register hook
184 | model.apply(register_forward_hook)
185 | model.apply(register_backward_hook)
186 |
187 | # make a forward pass
188 | output = model(*x)
189 |     output.sum().backward()  # backward() needs a scalar, so reduce the raw output first
190 |
191 | # remove these hooks
192 |     for h in forward_hooks:
193 | h.remove()
194 | for h in backward_hooks:
195 | h.remove()
196 |
197 | model = convert(model, flag_name=flag_name)
198 | model = remove_flags(model, flag_name=flag_name)
199 | return model
200 |
--------------------------------------------------------------------------------
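A short sketch (not in the repository) of the intended conversion: `bnrelu_to_frn` walks a module and replaces each `BatchNorm2d + ReLU` pair with `FRN + TLU` (filter response normalization followed by the learned-threshold activation):

```python
import torch
from torch import nn

from metrics.models.fastreid.layers.frn import bnrelu_to_frn

block = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
block = bnrelu_to_frn(block)
print(block)  # Conv2d -> FRN -> TLU after the conversion

x = torch.randn(2, 3, 32, 32)
print(block(x).shape)  # torch.Size([2, 16, 32, 32])
```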
/metrics/models/fastreid/layers/gather_layer.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: xingyu liao
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | # based on: https://github.com/open-mmlab/OpenSelfSup/blob/master/openselfsup/models/utils/gather_layer.py
8 |
9 | import torch
10 | import torch.distributed as dist
11 |
12 |
13 | class GatherLayer(torch.autograd.Function):
14 |     """Gather tensors from all processes, supporting backward propagation.
15 | """
16 |
17 | @staticmethod
18 | def forward(ctx, input):
19 | ctx.save_for_backward(input)
20 | output = [torch.zeros_like(input) \
21 | for _ in range(dist.get_world_size())]
22 | dist.all_gather(output, input)
23 | return tuple(output)
24 |
25 | @staticmethod
26 | def backward(ctx, *grads):
27 | input, = ctx.saved_tensors
28 | grad_out = torch.zeros_like(input)
29 | grad_out[:] = grads[dist.get_rank()]
30 | return grad_out
31 |
--------------------------------------------------------------------------------
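A single-process sketch (not in the repository) just to exercise `GatherLayer`; in real runs the process group is set up by the training launcher, and the gloo backend, address and port below are only placeholders:

```python
import os

import torch
import torch.distributed as dist

from metrics.models.fastreid.layers.gather_layer import GatherLayer

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.randn(4, 8, requires_grad=True)
# Unlike dist.all_gather, GatherLayer keeps the autograd graph alive, so a loss
# computed on the gathered batch still back-propagates to the local tensor.
gathered = torch.cat(GatherLayer.apply(x), dim=0)
gathered.sum().backward()
print(gathered.shape, x.grad.shape)

dist.destroy_process_group()
```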
/metrics/models/fastreid/layers/non_local.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 |
3 |
4 | import torch
5 | from torch import nn
6 | from .batch_norm import get_norm
7 |
8 |
9 | class Non_local(nn.Module):
10 | def __init__(self, in_channels, bn_norm, reduc_ratio=2):
11 | super(Non_local, self).__init__()
12 |
13 | self.in_channels = in_channels
14 | self.inter_channels = in_channels // reduc_ratio
15 |
16 | self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
17 | kernel_size=1, stride=1, padding=0)
18 |
19 | self.W = nn.Sequential(
20 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
21 | kernel_size=1, stride=1, padding=0),
22 | get_norm(bn_norm, self.in_channels),
23 | )
24 | nn.init.constant_(self.W[1].weight, 0.0)
25 | nn.init.constant_(self.W[1].bias, 0.0)
26 |
27 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
28 | kernel_size=1, stride=1, padding=0)
29 |
30 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
31 | kernel_size=1, stride=1, padding=0)
32 |
33 | def forward(self, x):
34 | """
35 |         :param x: (b, c, h, w)
36 |         :return x: (b, c, h, w)
37 | """
38 | batch_size = x.size(0)
39 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
40 | g_x = g_x.permute(0, 2, 1)
41 |
42 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
43 | theta_x = theta_x.permute(0, 2, 1)
44 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
45 | f = torch.matmul(theta_x, phi_x)
46 | N = f.size(-1)
47 | f_div_C = f / N
48 |
49 | y = torch.matmul(f_div_C, g_x)
50 | y = y.permute(0, 2, 1).contiguous()
51 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
52 | W_y = self.W(y)
53 | z = W_y + x
54 | return z
55 |
--------------------------------------------------------------------------------
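A brief sketch (not in the repository) of the non-local block above: dot-product attention over all spatial positions whose output projection is zero-initialized, so the block starts as an identity residual:

```python
import torch

from metrics.models.fastreid.layers.non_local import Non_local

nl = Non_local(in_channels=64, bn_norm="BN", reduc_ratio=2)
x = torch.randn(2, 64, 16, 8)
y = nl(x)
print(y.shape)               # same shape as the input
print(torch.allclose(y, x))  # expected: True at initialization (W is zero-initialized)
```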
/metrics/models/fastreid/layers/pooling.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: l1aoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import nn
10 |
11 | __all__ = ["Flatten",
12 | "GeneralizedMeanPooling",
13 | "GeneralizedMeanPoolingP",
14 | "FastGlobalAvgPool2d",
15 | "AdaptiveAvgMaxPool2d",
16 | "ClipGlobalAvgPool2d",
17 | ]
18 |
19 |
20 | class Flatten(nn.Module):
21 | def forward(self, input):
22 | return input.view(input.size(0), -1)
23 |
24 |
25 | class GeneralizedMeanPooling(nn.Module):
26 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
27 |     The function computed is: :math:`f(X) = pow(mean(pow(X, p)), 1/p)`
28 | - At p = infinity, one gets Max Pooling
29 | - At p = 1, one gets Average Pooling
30 | The output is of size H x W, for any input size.
31 | The number of output features is equal to the number of input planes.
32 | Args:
33 | output_size: the target output size of the image of the form H x W.
34 | Can be a tuple (H, W) or a single H for a square image H x H
35 | H and W can be either a ``int``, or ``None`` which means the size will
36 | be the same as that of the input.
37 | """
38 |
39 | def __init__(self, norm=3, output_size=1, eps=1e-6):
40 | super(GeneralizedMeanPooling, self).__init__()
41 | assert norm > 0
42 | self.p = float(norm)
43 | self.output_size = output_size
44 | self.eps = eps
45 |
46 | def forward(self, x):
47 | x = x.clamp(min=self.eps).pow(self.p)
48 | return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)
49 |
50 | def __repr__(self):
51 | return self.__class__.__name__ + '(' \
52 | + str(self.p) + ', ' \
53 | + 'output_size=' + str(self.output_size) + ')'
54 |
55 |
56 | class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
57 | """ Same, but norm is trainable
58 | """
59 |
60 | def __init__(self, norm=3, output_size=1, eps=1e-6):
61 | super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
62 | self.p = nn.Parameter(torch.ones(1) * norm)
63 |
64 |
65 | class AdaptiveAvgMaxPool2d(nn.Module):
66 | def __init__(self):
67 | super(AdaptiveAvgMaxPool2d, self).__init__()
68 | self.gap = FastGlobalAvgPool2d()
69 | self.gmp = nn.AdaptiveMaxPool2d(1)
70 |
71 | def forward(self, x):
72 | avg_feat = self.gap(x)
73 | max_feat = self.gmp(x)
74 | feat = avg_feat + max_feat
75 | return feat
76 |
77 |
78 | class FastGlobalAvgPool2d(nn.Module):
79 | def __init__(self, flatten=False):
80 | super(FastGlobalAvgPool2d, self).__init__()
81 | self.flatten = flatten
82 |
83 | def forward(self, x):
84 | if self.flatten:
85 | in_size = x.size()
86 | return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
87 | else:
88 | return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
89 |
90 |
91 | class ClipGlobalAvgPool2d(nn.Module):
92 | def __init__(self):
93 | super().__init__()
94 | self.avgpool = FastGlobalAvgPool2d()
95 |
96 | def forward(self, x):
97 | x = self.avgpool(x)
98 | x = torch.clamp(x, min=0., max=1.)
99 | return x
100 |
--------------------------------------------------------------------------------
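A short sketch (not in the repository) of GeM pooling: with `norm=1` it reduces to plain average pooling, while `GeneralizedMeanPoolingP` turns the exponent into a learnable parameter (initialized to 3):

```python
import torch

from metrics.models.fastreid.layers.pooling import (
    FastGlobalAvgPool2d,
    GeneralizedMeanPooling,
    GeneralizedMeanPoolingP,
)

x = torch.rand(2, 512, 16, 8) + 0.1  # strictly positive, as in post-ReLU feature maps

gem_p1 = GeneralizedMeanPooling(norm=1)
avg = FastGlobalAvgPool2d()
print(torch.allclose(gem_p1(x), avg(x), atol=1e-5))  # expected: True (p=1 is average pooling)

gem = GeneralizedMeanPoolingP()  # p is an nn.Parameter, learned with the rest of the model
print(gem(x).shape, float(gem.p))
```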
/metrics/models/fastreid/layers/se_layer.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from torch import nn
8 |
9 |
10 | class SELayer(nn.Module):
11 | def __init__(self, channel, reduction=16):
12 | super(SELayer, self).__init__()
13 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
14 | self.fc = nn.Sequential(
15 | nn.Linear(channel, int(channel / reduction), bias=False),
16 | nn.ReLU(inplace=True),
17 | nn.Linear(int(channel / reduction), channel, bias=False),
18 | nn.Sigmoid()
19 | )
20 |
21 | def forward(self, x):
22 | b, c, _, _ = x.size()
23 | y = self.avg_pool(x).view(b, c)
24 | y = self.fc(y).view(b, c, 1, 1)
25 | return x * y.expand_as(x)
26 |
--------------------------------------------------------------------------------
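A one-line sketch (not in the repository): the SE layer squeezes the input to per-channel statistics, gates them through a small bottleneck MLP with a sigmoid, and rescales the input channels:

```python
import torch

from metrics.models.fastreid.layers.se_layer import SELayer

se = SELayer(channel=64, reduction=16)
x = torch.randn(2, 64, 16, 8)
print(se(x).shape)  # torch.Size([2, 64, 16, 8]), each channel rescaled by its gate
```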
/metrics/models/fastreid/layers/splat.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: xingyu liao
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import nn
10 | from torch.nn import Conv2d, ReLU
11 | from torch.nn.modules.utils import _pair
12 | from metrics.models.fastreid.layers import get_norm
13 |
14 |
15 | class SplAtConv2d(nn.Module):
16 | """Split-Attention Conv2d
17 | """
18 |
19 | def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
20 | dilation=(1, 1), groups=1, bias=True,
21 | radix=2, reduction_factor=4,
22 | rectify=False, rectify_avg=False, norm_layer=None, num_splits=1,
23 | dropblock_prob=0.0, **kwargs):
24 | super(SplAtConv2d, self).__init__()
25 | padding = _pair(padding)
26 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
27 | self.rectify_avg = rectify_avg
28 | inter_channels = max(in_channels * radix // reduction_factor, 32)
29 | self.radix = radix
30 | self.cardinality = groups
31 | self.channels = channels
32 | self.dropblock_prob = dropblock_prob
33 | if self.rectify:
34 | from rfconv import RFConv2d
35 | self.conv = RFConv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
36 | groups=groups * radix, bias=bias, average_mode=rectify_avg, **kwargs)
37 | else:
38 | self.conv = Conv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
39 | groups=groups * radix, bias=bias, **kwargs)
40 | self.use_bn = norm_layer is not None
41 | if self.use_bn:
42 | self.bn0 = get_norm(norm_layer, channels * radix)
43 | self.relu = ReLU(inplace=True)
44 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
45 | if self.use_bn:
46 | self.bn1 = get_norm(norm_layer, inter_channels)
47 | self.fc2 = Conv2d(inter_channels, channels * radix, 1, groups=self.cardinality)
48 |
49 | self.rsoftmax = rSoftMax(radix, groups)
50 |
51 | def forward(self, x):
52 | x = self.conv(x)
53 | if self.use_bn:
54 | x = self.bn0(x)
55 | if self.dropblock_prob > 0.0:
56 | x = self.dropblock(x)
57 | x = self.relu(x)
58 |
59 | batch, rchannel = x.shape[:2]
60 | if self.radix > 1:
61 | splited = torch.split(x, rchannel // self.radix, dim=1)
62 | gap = sum(splited)
63 | else:
64 | gap = x
65 | gap = F.adaptive_avg_pool2d(gap, 1)
66 | gap = self.fc1(gap)
67 |
68 | if self.use_bn:
69 | gap = self.bn1(gap)
70 | gap = self.relu(gap)
71 |
72 | atten = self.fc2(gap)
73 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
74 |
75 | if self.radix > 1:
76 | attens = torch.split(atten, rchannel // self.radix, dim=1)
77 | out = sum([att * split for (att, split) in zip(attens, splited)])
78 | else:
79 | out = atten * x
80 | return out.contiguous()
81 |
82 |
83 | class rSoftMax(nn.Module):
84 | def __init__(self, radix, cardinality):
85 | super().__init__()
86 | self.radix = radix
87 | self.cardinality = cardinality
88 |
89 | def forward(self, x):
90 | batch = x.size(0)
91 | if self.radix > 1:
92 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
93 | x = F.softmax(x, dim=1)
94 | x = x.reshape(batch, -1)
95 | else:
96 | x = torch.sigmoid(x)
97 | return x
98 |
--------------------------------------------------------------------------------
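A minimal sketch (not in the repository) of the split-attention convolution: with `radix=2` the output is produced by two grouped branches whose features are mixed by a learned soft attention over channels:

```python
import torch

from metrics.models.fastreid.layers.splat import SplAtConv2d

conv = SplAtConv2d(64, 64, kernel_size=3, padding=1, radix=2, norm_layer="BN")
x = torch.randn(2, 64, 16, 8)
print(conv(x).shape)  # torch.Size([2, 64, 16, 8])
```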
/metrics/models/fastreid/model.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import logging
8 | import os
9 | import math
10 |
11 | import torch
12 | from torch import nn
13 |
14 | from metrics.models.fastreid.layers import (
15 | IBN,
16 | SELayer,
17 | Non_local,
18 | get_norm,
19 | )
20 |
21 |
22 | logger = logging.getLogger(__name__)
23 | model_urls = {
24 | '18x': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
25 | '34x': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
26 | '50x': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
27 | '101x': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
28 | '152x': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
29 | 'ibn_18x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth',
30 | 'ibn_34x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth',
31 | 'ibn_50x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth',
32 | 'ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth',
33 | 'se_ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/se_resnet101_ibn_a-fabed4e2.pth',
34 | }
35 |
36 |
37 | class BasicBlock(nn.Module):
38 | expansion = 1
39 |
40 | def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False,
41 | stride=1, downsample=None, reduction=16):
42 | super(BasicBlock, self).__init__()
43 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
44 | if with_ibn:
45 | self.bn1 = IBN(planes, bn_norm)
46 | else:
47 | self.bn1 = get_norm(bn_norm, planes)
48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
49 | self.bn2 = get_norm(bn_norm, planes)
50 | self.relu = nn.ReLU(inplace=True)
51 | if with_se:
52 | self.se = SELayer(planes, reduction)
53 | else:
54 | self.se = nn.Identity()
55 | self.downsample = downsample
56 | self.stride = stride
57 |
58 | def forward(self, x):
59 | identity = x
60 |
61 | out = self.conv1(x)
62 | out = self.bn1(out)
63 | out = self.relu(out)
64 |
65 | out = self.conv2(out)
66 | out = self.bn2(out)
67 |
68 | if self.downsample is not None:
69 | identity = self.downsample(x)
70 |
71 | out += identity
72 | out = self.relu(out)
73 |
74 | return out
75 |
76 |
77 | class Bottleneck(nn.Module):
78 | expansion = 4
79 |
80 | def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False,
81 | stride=1, downsample=None, reduction=16):
82 | super(Bottleneck, self).__init__()
83 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
84 | if with_ibn:
85 | self.bn1 = IBN(planes, bn_norm)
86 | else:
87 | self.bn1 = get_norm(bn_norm, planes)
88 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
89 | padding=1, bias=False)
90 | self.bn2 = get_norm(bn_norm, planes)
91 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
92 | self.bn3 = get_norm(bn_norm, planes * self.expansion)
93 | self.relu = nn.ReLU(inplace=True)
94 | if with_se:
95 | self.se = SELayer(planes * self.expansion, reduction)
96 | else:
97 | self.se = nn.Identity()
98 | self.downsample = downsample
99 | self.stride = stride
100 |
101 | def forward(self, x):
102 | residual = x
103 |
104 | out = self.conv1(x)
105 | out = self.bn1(out)
106 | out = self.relu(out)
107 |
108 | out = self.conv2(out)
109 | out = self.bn2(out)
110 | out = self.relu(out)
111 |
112 | out = self.conv3(out)
113 | out = self.bn3(out)
114 | out = self.se(out)
115 |
116 | if self.downsample is not None:
117 | residual = self.downsample(x)
118 |
119 | out += residual
120 | out = self.relu(out)
121 |
122 | return out
123 |
124 |
125 | class ResNet(nn.Module):
126 | def __init__(self, last_stride, bn_norm, with_ibn, with_se, with_nl, block, layers, non_layers):
127 | self.inplanes = 64
128 | super().__init__()
129 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
130 | bias=False)
131 | self.bn1 = get_norm(bn_norm, 64)
132 | self.relu = nn.ReLU(inplace=True)
133 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
134 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
135 | self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn, with_se)
136 | self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn, with_se)
137 | self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn, with_se)
138 | self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_se=with_se)
139 |
140 | self.random_init()
141 |
142 | # fmt: off
143 | if with_nl: self._build_nonlocal(layers, non_layers, bn_norm)
144 | else: self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
145 | # fmt: on
146 |
147 | def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", with_ibn=False, with_se=False):
148 | downsample = None
149 | if stride != 1 or self.inplanes != planes * block.expansion:
150 | downsample = nn.Sequential(
151 | nn.Conv2d(self.inplanes, planes * block.expansion,
152 | kernel_size=1, stride=stride, bias=False),
153 | get_norm(bn_norm, planes * block.expansion),
154 | )
155 |
156 | layers = []
157 | layers.append(block(self.inplanes, planes, bn_norm, with_ibn, with_se, stride, downsample))
158 | self.inplanes = planes * block.expansion
159 | for i in range(1, blocks):
160 | layers.append(block(self.inplanes, planes, bn_norm, with_ibn, with_se))
161 |
162 | return nn.Sequential(*layers)
163 |
164 | def _build_nonlocal(self, layers, non_layers, bn_norm):
165 | self.NL_1 = nn.ModuleList(
166 | [Non_local(256, bn_norm) for _ in range(non_layers[0])])
167 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
168 | self.NL_2 = nn.ModuleList(
169 | [Non_local(512, bn_norm) for _ in range(non_layers[1])])
170 | self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
171 | self.NL_3 = nn.ModuleList(
172 | [Non_local(1024, bn_norm) for _ in range(non_layers[2])])
173 | self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
174 | self.NL_4 = nn.ModuleList(
175 | [Non_local(2048, bn_norm) for _ in range(non_layers[3])])
176 | self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
177 |
178 | def forward(self, x):
179 | x = self.conv1(x)
180 | x = self.bn1(x)
181 | x = self.relu(x)
182 | x = self.maxpool(x)
183 |
184 | NL1_counter = 0
185 | if len(self.NL_1_idx) == 0:
186 | self.NL_1_idx = [-1]
187 | for i in range(len(self.layer1)):
188 | x = self.layer1[i](x)
189 | if i == self.NL_1_idx[NL1_counter]:
190 | _, C, H, W = x.shape
191 | x = self.NL_1[NL1_counter](x)
192 | NL1_counter += 1
193 | # Layer 2
194 | NL2_counter = 0
195 | if len(self.NL_2_idx) == 0:
196 | self.NL_2_idx = [-1]
197 | for i in range(len(self.layer2)):
198 | x = self.layer2[i](x)
199 | if i == self.NL_2_idx[NL2_counter]:
200 | _, C, H, W = x.shape
201 | x = self.NL_2[NL2_counter](x)
202 | NL2_counter += 1
203 | # Layer 3
204 | NL3_counter = 0
205 | if len(self.NL_3_idx) == 0:
206 | self.NL_3_idx = [-1]
207 | for i in range(len(self.layer3)):
208 | x = self.layer3[i](x)
209 | if i == self.NL_3_idx[NL3_counter]:
210 | _, C, H, W = x.shape
211 | x = self.NL_3[NL3_counter](x)
212 | NL3_counter += 1
213 | # Layer 4
214 | NL4_counter = 0
215 | if len(self.NL_4_idx) == 0:
216 | self.NL_4_idx = [-1]
217 | for i in range(len(self.layer4)):
218 | x = self.layer4[i](x)
219 | if i == self.NL_4_idx[NL4_counter]:
220 | _, C, H, W = x.shape
221 | x = self.NL_4[NL4_counter](x)
222 | NL4_counter += 1
223 |
224 | return x
225 |
226 | def random_init(self):
227 | for m in self.modules():
228 | if isinstance(m, nn.Conv2d):
229 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
230 | nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
231 | elif isinstance(m, nn.BatchNorm2d):
232 | nn.init.constant_(m.weight, 1)
233 | nn.init.constant_(m.bias, 0)
234 |
235 |
236 | def init_pretrained_weights(key):
237 | """Initializes model with pretrained weights.
238 | Layers that don't match with pretrained layers in name or size are kept unchanged.
239 | """
240 | import os
241 | import errno
242 | import gdown
243 |
244 | def _get_torch_home():
245 | ENV_TORCH_HOME = 'TORCH_HOME'
246 | ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
247 | DEFAULT_CACHE_DIR = '~/.cache'
248 | torch_home = os.path.expanduser(
249 | os.getenv(
250 | ENV_TORCH_HOME,
251 | os.path.join(
252 | os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
253 | )
254 | )
255 | )
256 | return torch_home
257 |
258 | torch_home = _get_torch_home()
259 | model_dir = os.path.join(torch_home, 'checkpoints')
260 | try:
261 | os.makedirs(model_dir)
262 | except OSError as e:
263 | if e.errno == errno.EEXIST:
264 | # Directory already exists, ignore.
265 | pass
266 | else:
267 | # Unexpected OSError, re-raise.
268 | raise
269 |
270 | filename = model_urls[key].split('/')[-1]
271 |
272 | cached_file = os.path.join(model_dir, filename)
273 |
274 |     if not os.path.exists(cached_file):
275 |         # The original fastreid guards this download with its distributed `comm`
276 |         # helpers, which are not vendored here, so download on the current process.
277 |         gdown.download(model_urls[key], cached_file, quiet=False)
278 | 
279 |
280 | logger.info(f"Loading pretrained model from {cached_file}")
281 | state_dict = torch.load(cached_file, map_location=torch.device('cpu'))
282 |
283 | return state_dict
284 |
285 |
286 | def build_resnet_backbone():
287 | """
288 | Create a ResNet instance from config.
289 | Returns:
290 | ResNet: a :class:`ResNet` instance.
291 | """
292 |
293 | # fmt: off
294 | pretrain = True
295 | pretrain_path = "metrics/models/weights/lup_moco_r50.pth"
296 | if not os.path.exists(pretrain_path):
297 | raise ValueError("Import model weights first")
298 | last_stride = 1
299 | bn_norm = "BN"
300 | with_ibn = False
301 | with_se = False
302 | with_nl = False
303 | depth = "50x"
304 | # fmt: on
305 |
306 | num_blocks_per_stage = {
307 | '18x': [2, 2, 2, 2],
308 | '34x': [3, 4, 6, 3],
309 | '50x': [3, 4, 6, 3],
310 | '101x': [3, 4, 23, 3],
311 | '152x': [3, 8, 36, 3],
312 | }[depth]
313 |
314 | nl_layers_per_stage = {
315 | '18x': [0, 0, 0, 0],
316 | '34x': [0, 0, 0, 0],
317 | '50x': [0, 2, 3, 0],
318 | '101x': [0, 2, 9, 0],
319 | '152x': [0, 4, 12, 0]
320 | }[depth]
321 |
322 | block = {
323 | '18x': BasicBlock,
324 | '34x': BasicBlock,
325 | '50x': Bottleneck,
326 | '101x': Bottleneck,
327 | '152x': Bottleneck,
328 | }[depth]
329 |
330 | model = ResNet(last_stride, bn_norm, with_ibn, with_se, with_nl, block,
331 | num_blocks_per_stage, nl_layers_per_stage)
332 | if pretrain:
333 |         # Load weights from pretrain_path when one is specified
334 | if pretrain_path:
335 | try:
336 | state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
337 | logger.info(f"Loading pretrained model from {pretrain_path}")
338 | except FileNotFoundError as e:
339 | logger.info(f'{pretrain_path} is not found! Please check this path.')
340 | raise e
341 | except KeyError as e:
342 | logger.info("State dict keys error! Please check the state dict.")
343 | raise e
344 | else:
345 | key = depth
346 | if with_ibn: key = 'ibn_' + key
347 | if with_se: key = 'se_' + key
348 |
349 | state_dict = init_pretrained_weights(key)
350 |
351 |         incompatible = model.load_state_dict(state_dict, strict=False)
352 |         # fastreid's get_missing/unexpected_parameters_message helpers are not
353 |         # vendored here, so log the raw key lists instead.
354 |         if incompatible.missing_keys:
355 |             logger.info(f"Missing keys in pretrained weights: {incompatible.missing_keys}")
356 |         if incompatible.unexpected_keys:
357 |             logger.info(
358 |                 f"Unexpected keys in pretrained weights: {incompatible.unexpected_keys}"
359 |             )
360 |
361 | return model
--------------------------------------------------------------------------------
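A usage sketch (not in the repository) for the backbone factory above. `build_resnet_backbone()` is hard-wired to load the LUPerson MoCo ResNet-50 weights from `metrics/models/weights/lup_moco_r50.pth` and raises if they are missing, so the snippet assumes that file has been downloaded first:

```python
import torch

from metrics.models.fastreid.model import build_resnet_backbone

model = build_resnet_backbone().eval()  # requires metrics/models/weights/lup_moco_r50.pth

with torch.no_grad():
    # last_stride=1, so a 256x128 crop gives a 16x8 spatial map with 2048 channels.
    feats = model(torch.randn(2, 3, 256, 128))
    embeddings = feats.mean(dim=(2, 3))  # the ReID metrics average the map into one vector
print(feats.shape, embeddings.shape)
```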
/metrics/reid.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, List, Optional, Union
2 | from torchmetrics.metric import Metric
3 | import torch
4 | from tqdm import tqdm
5 | from metrics.models.fastreid.model import build_resnet_backbone
6 | import torch.nn.functional as F
7 |
8 |
9 | class ReidSubjectSuperiority(Metric):
10 | def __init__(
11 | self,
12 | compute_on_step: bool = False,
13 | dist_sync_on_step: bool = False,
14 | process_group: Optional[Any] = None,
15 | dist_sync_fn: Callable = None,
16 | ):
17 |
18 | super().__init__(
19 | compute_on_step=compute_on_step,
20 | dist_sync_on_step=dist_sync_on_step,
21 | process_group=process_group,
22 | dist_sync_fn=dist_sync_fn,
23 | )
24 | self.reid_model = build_resnet_backbone()
25 |
26 | self.add_state("more_similar_to_subject", torch.tensor(0), dist_reduce_fx="sum")
27 |
28 | self.add_state(
29 | "num_samples",
30 | torch.tensor(0).long(),
31 | dist_reduce_fx="sum",
32 | )
33 |
34 | def update(
35 | self,
36 | swap_images,
37 | background_images,
38 | background_segmask,
39 | subject_images,
40 | subject_segmask,
41 | ):
42 | self.num_samples += swap_images.shape[0]
43 | _, _, h, w = swap_images.shape
44 | background_background_segmask = 1 - background_segmask[:, 0, :, :]
45 | h_max = (background_background_segmask.sum(dim=2) > 0).long().argmax(dim=1)
46 | w_max = (background_background_segmask.sum(dim=1) > 0).long().argmax(dim=1)
47 |
48 | h_rev = [i for i in range(background_background_segmask.shape[1])]
49 |
50 | h_min = h - (background_background_segmask.sum(dim=2) > 0)[
51 | :, h_rev[::-1]
52 | ].long().argmax(dim=1)
53 | w_rev = [i for i in range(background_background_segmask.shape[2])]
54 | w_min = w - (background_background_segmask.sum(dim=1) > 0)[
55 | :, w_rev[::-1]
56 | ].long().argmax(dim=1)
57 | background_bbox = torch.stack([h_max, h_min, w_max, w_min], dim=1)
58 |
59 | subject_segmask = 1 - subject_segmask[:, 0, :, :]
60 | h_max = (subject_segmask.sum(dim=2) > 0).long().argmax(dim=1).long()
61 | w_max = (subject_segmask.sum(dim=1) > 0).long().argmax(dim=1).long()
62 | h_min = h - (subject_segmask.sum(dim=2) > 0)[:, h_rev[::-1]].long().argmax(
63 | dim=1
64 | )
65 | w_min = w - (subject_segmask.sum(dim=1) > 0)[:, w_rev[::-1]].long().argmax(
66 | dim=1
67 | )
68 | subject_bbox = torch.stack([h_max, h_min, w_max, w_min], dim=1)
69 |
70 | swap_images_cropped = []
71 |
72 | subject_images_cropped = []
73 |
74 | background_images_cropped = []
75 |
76 | for i in range(swap_images.shape[0]):
77 | swap_image = swap_images[i]
78 | background_image = background_images[i]
79 | subject_image = subject_images[i]
80 | swap_image_bbox = background_bbox[i]
81 | subject_image_bbox = subject_bbox[i]
82 |
83 | swap_images_cropped.append(
84 | F.interpolate(
85 | swap_image[
86 | :,
87 | swap_image_bbox[0] : swap_image_bbox[1],
88 | swap_image_bbox[2] : swap_image_bbox[3],
89 | ].unsqueeze(0),
90 | size=(256, 128),
91 | mode="bilinear",
92 | )
93 | )
94 |
95 | subject_images_cropped.append(
96 | F.interpolate(
97 | subject_image[
98 | :,
99 | subject_image_bbox[0] : subject_image_bbox[1],
100 | subject_image_bbox[2] : subject_image_bbox[3],
101 | ].unsqueeze(0),
102 | size=(256, 128),
103 | mode="bilinear",
104 | )
105 | )
106 | background_images_cropped.append(
107 | F.interpolate(
108 | background_image[
109 | :,
110 | swap_image_bbox[0] : swap_image_bbox[1],
111 | swap_image_bbox[2] : swap_image_bbox[3],
112 | ].unsqueeze(0),
113 | size=(256, 128),
114 | mode="bilinear",
115 | )
116 | )
117 | swap_images_cropped = torch.cat(swap_images_cropped, dim=0)
118 | subject_images_cropped = torch.cat(subject_images_cropped, dim=0)
119 | background_images_cropped = torch.cat(background_images_cropped, dim=0)
120 |
121 | swap_images_features = self.reid_model(swap_images_cropped).mean(dim=(2, 3))
122 | subject_images_features = self.reid_model(subject_images_cropped).mean(
123 | dim=(2, 3)
124 | )
125 | background_images_features = self.reid_model(background_images_cropped).mean(
126 | dim=(2, 3)
127 | )
128 |
129 | background_cos_sim = F.cosine_similarity(
130 | swap_images_features, background_images_features, dim=1
131 | )
132 | subject_cos_sim = F.cosine_similarity(
133 | swap_images_features, subject_images_features, dim=1
134 | )
135 |
136 | self.more_similar_to_subject += (subject_cos_sim > background_cos_sim).sum()
137 |
138 | def compute(self):
139 | return self.more_similar_to_subject / self.num_samples
140 |
141 |
142 | class ReidSubjectSimilarity(Metric):
143 | def __init__(
144 | self,
145 | compute_on_step: bool = False,
146 | dist_sync_on_step: bool = False,
147 | process_group: Optional[Any] = None,
148 | dist_sync_fn: Callable = None,
149 | ):
150 |
151 | super().__init__(
152 | compute_on_step=compute_on_step,
153 | dist_sync_on_step=dist_sync_on_step,
154 | process_group=process_group,
155 | dist_sync_fn=dist_sync_fn,
156 | )
157 | self.reid_model = build_resnet_backbone()
158 |
159 | self.add_state(
160 | "subject_cosine_sim", torch.tensor(0).float(), dist_reduce_fx="sum"
161 | )
162 |
163 | self.add_state(
164 | "num_samples",
165 | torch.tensor(0).long(),
166 | dist_reduce_fx="sum",
167 | )
168 |
169 | def update(
170 | self,
171 | swap_images,
172 | background_images,
173 | background_segmask,
174 | subject_images,
175 | subject_segmask,
176 | ):
177 | self.num_samples += swap_images.shape[0]
178 | _, _, h, w = swap_images.shape
179 | background_background_segmask = 1 - background_segmask[:, 0, :, :]
180 | h_max = (background_background_segmask.sum(dim=2) > 0).long().argmax(dim=1)
181 | w_max = (background_background_segmask.sum(dim=1) > 0).long().argmax(dim=1)
182 |
183 | h_rev = [i for i in range(background_background_segmask.shape[1])]
184 |
185 | h_min = h - (background_background_segmask.sum(dim=2) > 0)[
186 | :, h_rev[::-1]
187 | ].long().argmax(dim=1)
188 | w_rev = [i for i in range(background_background_segmask.shape[2])]
189 | w_min = w - (background_background_segmask.sum(dim=1) > 0)[
190 | :, w_rev[::-1]
191 | ].long().argmax(dim=1)
192 | background_bbox = torch.stack([h_max, h_min, w_max, w_min], dim=1)
193 |
194 | subject_segmask = 1 - subject_segmask[:, 0, :, :]
195 | h_max = (subject_segmask.sum(dim=2) > 0).long().argmax(dim=1)
196 | w_max = (subject_segmask.sum(dim=1) > 0).long().argmax(dim=1)
197 | h_min = h - (subject_segmask.sum(dim=2) > 0)[:, h_rev[::-1]].long().argmax(
198 | dim=1
199 | )
200 | w_min = w - (subject_segmask.sum(dim=1) > 0)[:, w_rev[::-1]].long().argmax(
201 | dim=1
202 | )
203 | subject_bbox = torch.stack([h_max, h_min, w_max, w_min], dim=1)
204 |
205 | swap_images_cropped = []
206 |
207 | subject_images_cropped = []
208 |
209 | for i in range(swap_images.shape[0]):
210 | swap_image = swap_images[i]
211 | subject_image = subject_images[i]
212 | swap_image_bbox = background_bbox[i]
213 | subject_image_bbox = subject_bbox[i]
214 |
215 | swap_images_cropped.append(
216 | F.interpolate(
217 | swap_image[
218 | :,
219 | swap_image_bbox[0] : swap_image_bbox[1],
220 | swap_image_bbox[2] : swap_image_bbox[3],
221 | ].unsqueeze(0),
222 | size=(256, 128),
223 | mode="bilinear",
224 | )
225 | )
226 | subject_images_cropped.append(
227 | F.interpolate(
228 | subject_image[
229 | :,
230 | subject_image_bbox[0] : subject_image_bbox[1],
231 | subject_image_bbox[2] : subject_image_bbox[3],
232 | ].unsqueeze(0),
233 | size=(256, 128),
234 | mode="bilinear",
235 | )
236 | )
237 | swap_images_cropped = torch.cat(swap_images_cropped, dim=0)
238 | subject_images_cropped = torch.cat(subject_images_cropped, dim=0)
239 |
240 | swap_images_features = self.reid_model(swap_images_cropped).mean(dim=(2, 3))
241 | subject_images_features = self.reid_model(subject_images_cropped).mean(
242 | dim=(2, 3)
243 | )
244 |
245 | subject_cos_sim = F.cosine_similarity(
246 | swap_images_features, subject_images_features, dim=1
247 | ).sum()
248 |
249 | self.subject_cosine_sim += subject_cos_sim
250 |
251 | def compute(self):
252 | return self.subject_cosine_sim / self.num_samples
253 |
254 |
255 | class ReidBackgroundSuperiority(Metric):
256 | def __init__(
257 | self,
258 | compute_on_step: bool = False,
259 | dist_sync_on_step: bool = False,
260 | process_group: Optional[Any] = None,
261 | dist_sync_fn: Callable = None,
262 | ):
263 |
264 | super().__init__(
265 | compute_on_step=compute_on_step,
266 | dist_sync_on_step=dist_sync_on_step,
267 | process_group=process_group,
268 | dist_sync_fn=dist_sync_fn,
269 | )
270 | self.reid_model = build_resnet_backbone()
271 |
272 | self.add_state(
273 | "background_cosine_sim", torch.tensor(0).float(), dist_reduce_fx="sum"
274 | )
275 |
276 | self.add_state(
277 | "num_samples",
278 | torch.tensor(0).long(),
279 | dist_reduce_fx="sum",
280 | )
281 |
282 | def update(
283 | self,
284 | swap_images,
285 | background_images,
286 | background_segmask,
287 | subject_images,
288 | subject_segmask,
289 | ):
290 | self.num_samples += swap_images.shape[0]
291 | _, _, h, w = swap_images.shape
292 | background_background_segmask = 1 - background_segmask[:, 0, :, :]
293 | h_max = (background_background_segmask.sum(dim=2) > 0).long().argmax(dim=1)
294 | w_max = (background_background_segmask.sum(dim=1) > 0).long().argmax(dim=1)
295 |
296 | h_rev = [i for i in range(background_background_segmask.shape[1])]
297 |
298 | h_min = h - (background_background_segmask.sum(dim=2) > 0)[
299 | :, h_rev[::-1]
300 | ].long().argmax(dim=1)
301 | w_rev = [i for i in range(background_background_segmask.shape[2])]
302 | w_min = w - (background_background_segmask.sum(dim=1) > 0)[
303 | :, w_rev[::-1]
304 | ].long().argmax(dim=1)
305 | background_bbox = torch.stack([h_max, h_min, w_max, w_min], dim=1)
306 |
307 | swap_images_cropped = []
308 |
309 | background_images_cropped = []
310 |
311 | for i in range(swap_images.shape[0]):
312 | swap_image = swap_images[i]
313 | background_image = background_images[i]
314 | swap_image_bbox = background_bbox[i]
315 |
316 | swap_images_cropped.append(
317 | F.interpolate(
318 | swap_image[
319 | :,
320 | swap_image_bbox[0] : swap_image_bbox[1],
321 | swap_image_bbox[2] : swap_image_bbox[3],
322 | ].unsqueeze(0),
323 | size=(256, 128),
324 | mode="bilinear",
325 | )
326 | )
327 | background_images_cropped.append(
328 | F.interpolate(
329 | background_image[
330 | :,
331 | swap_image_bbox[0] : swap_image_bbox[1],
332 | swap_image_bbox[2] : swap_image_bbox[3],
333 | ].unsqueeze(0),
334 | size=(256, 128),
335 | mode="bilinear",
336 | )
337 | )
338 | swap_images_cropped = torch.cat(swap_images_cropped, dim=0)
339 | background_images_cropped = torch.cat(background_images_cropped, dim=0)
340 |
341 | swap_images_features = self.reid_model(swap_images_cropped).mean(dim=(2, 3))
342 | background_images_features = self.reid_model(background_images_cropped).mean(
343 | dim=(2, 3)
344 | )
345 |
346 | background_cos_sim = F.cosine_similarity(
347 | swap_images_features, background_images_features, dim=1
348 | ).sum()
349 |
350 | self.background_cosine_sim += background_cos_sim
351 |
352 | def compute(self):
353 | return self.background_cosine_sim / self.num_samples
354 |
--------------------------------------------------------------------------------
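
The two ReID metrics above follow the standard `torchmetrics` pattern: state tensors registered with `add_state` are accumulated batch by batch in `update` and normalised once in `compute`. A minimal toy metric sketching that same accumulate-then-average structure (the class below is illustrative only, it is not part of the repository, and the feature tensors are random stand-ins):

```python
import torch
import torch.nn.functional as F
from torchmetrics import Metric


class MeanCosineSimilarity(Metric):
    def __init__(self):
        super().__init__()
        # sums are reduced across processes, exactly like in the metrics above
        self.add_state("cosine_sim", torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_samples", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, features_a, features_b):
        self.num_samples += features_a.shape[0]
        self.cosine_sim += F.cosine_similarity(features_a, features_b, dim=1).sum()

    def compute(self):
        return self.cosine_sim / self.num_samples


metric = MeanCosineSimilarity()
for _ in range(3):  # one call per batch
    metric.update(torch.randn(4, 512), torch.randn(4, 512))
print(metric.compute())  # scalar averaged over every sample seen
```
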
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/models/__init__.py
--------------------------------------------------------------------------------
/models/diffaugment.py:
--------------------------------------------------------------------------------
1 | import kornia.augmentation as K
2 | import torch.nn as nn
3 |
4 |
5 | class SimpleAugmentation(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 | self.aug = nn.Sequential(
9 | K.Denormalize(mean=0.5, std=0.5),
10 | K.ColorJitter(p=0.8, brightness=0.2, contrast=0.3, hue=0.2),
11 | K.RandomErasing(p=0.5),
12 | K.Normalize(mean=0.5, std=0.5),
13 | )
14 |
15 | def forward(self, x):
16 | return self.aug(x)
17 |
--------------------------------------------------------------------------------
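
A hypothetical usage sketch for the augmentation module above (not from the repo, requires `kornia`): the pipeline de-normalizes images from [-1, 1], applies color jitter and random erasing, then re-normalizes, so it can be dropped on generator outputs without changing their range.

```python
import torch
from models.diffaugment import SimpleAugmentation

aug = SimpleAugmentation()
fake_images = torch.rand(4, 3, 256, 256) * 2 - 1  # stand-in for generator output in [-1, 1]
augmented = aug(fake_images)                       # same shape, jittered / partially erased
assert augmented.shape == fake_images.shape
```
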
/models/discriminators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/models/discriminators/__init__.py
--------------------------------------------------------------------------------
/models/discriminators/mcad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchtyping import TensorType
5 | from typing import Union, Sequence
6 |
7 | import torch.nn.utils.spectral_norm as spectral_norm
8 | from collections import OrderedDict
9 |
10 | from einops import repeat, rearrange
11 |
12 | from models.utils_blocks.base import BaseNetwork
13 |
14 | from models.utils_blocks.attention import (
15 | MaskedTransformer,
16 | SinusoidalPositionalEmbedding,
17 | )
18 |
19 |
20 | class MultiScaleMCADiscriminator(BaseNetwork):
21 | """
22 | This is the Multi Scale Discriminator. We use multiple discriminators at different scales
23 |
24 | Parameters:
25 | -----------
26 | num_discriminator: int,
27 | How many discriminators do we use
28 | image_num_channels: int,
29 | Number of input images channels
30 | segmap_num_channels: int,
31 | Number of segmentation map channels
32 | num_features_fst_conv: int,
33 | How many kernels at the first convolution
34 | num_layers: int,
35 | How many layers per discriminator
36 | apply_spectral_norm: bool = True,
37 |         Whether or not to apply spectral normalization
38 | keep_intermediate_results: bool = True
39 | Whether or not to keep intermediate discriminators feature maps
40 |
41 | """
42 |
43 | def __init__(
44 | self,
45 | num_discriminator: int,
46 | image_num_channels: int,
47 | segmap_num_channels: int,
48 | positional_embedding_dim: int,
49 | num_labels: int,
50 | num_latent_per_labels: int,
51 | latent_dim: int,
52 | num_blocks: int,
53 | attention_latent_dim: int,
54 | num_cross_heads: int,
55 | num_self_heads: int,
56 | apply_spectral_norm: bool = True,
57 | concat_segmaps: bool = False,
58 | output_type: str = "patchgan",
59 | keep_intermediate_results: bool = True,
60 | ):
61 | super().__init__()
62 |
63 | self.keep_intermediate_results = keep_intermediate_results
64 |
65 | self.discriminators = nn.ModuleDict(OrderedDict())
66 |
67 | for i in range(num_discriminator):
68 | self.discriminators.update(
69 | {
70 | f"discriminator_{i}": MaskedCrossAttentionDiscriminator(
71 | image_num_channels=image_num_channels,
72 | segmap_num_channels=segmap_num_channels,
73 | positional_embedding_dim=positional_embedding_dim,
74 | num_labels=num_labels,
75 | num_latent_per_labels=num_latent_per_labels,
76 | latent_dim=latent_dim,
77 | num_blocks=num_blocks,
78 | attention_latent_dim=attention_latent_dim,
79 | num_cross_heads=num_cross_heads,
80 | num_self_heads=num_self_heads,
81 | apply_spectral_norm=apply_spectral_norm,
82 | concat_segmaps=concat_segmaps,
83 | output_type=output_type,
84 | keep_intermediate_results=keep_intermediate_results,
85 | ),
86 | }
87 | )
88 | self.downsample = nn.AvgPool2d(
89 | kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False
90 | )
91 |
92 | def forward(
93 | self,
94 | input: TensorType["batch_size", "input_channels", "height", "width"],
95 | segmentation_map: TensorType[
96 | "batch_size", "num_labels", "height", "width"
97 | ] = None,
98 | ) -> Union[
99 | Sequence[
100 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]]
101 | ],
102 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]],
103 | ]:
104 | results = []
105 | for disc_name in self.discriminators:
106 | result = self.discriminators[disc_name](input, segmentation_map)
107 | if not self.keep_intermediate_results:
108 | result = [result]
109 | results.append(result)
110 | input = self.downsample(input)
111 | segmentation_map = F.interpolate(
112 | segmentation_map, size=input.size()[2:], mode="nearest"
113 | )
114 |
115 | return results
116 |
117 |
118 | class MaskedCrossAttentionDiscriminator(BaseNetwork):
119 | """
120 |     Discriminator that cross-attends a set of per-label latent tokens to the image
121 |     features, with the attention masked by the segmentation map so that each latent
122 |     token only sees its own semantic region.
123 |
124 |     Parameters:
125 |     -----------
126 |     image_num_channels: int,
127 |         Number of image input channels
128 |     latent_dim: int,
129 |         Dimension of the latent tokens
130 |
131 | """
132 |
133 | def __init__(
134 | self,
135 | image_num_channels: int,
136 | segmap_num_channels: int,
137 | positional_embedding_dim: int,
138 | num_labels: int,
139 | num_latent_per_labels: int,
140 | latent_dim: int,
141 | num_blocks: int,
142 | attention_latent_dim: int,
143 | num_cross_heads: int,
144 | num_self_heads: int,
145 | apply_spectral_norm: bool = True,
146 | concat_segmaps: bool = False,
147 | output_type: str = "patchgan",
148 | keep_intermediate_results: bool = True,
149 | ):
150 | super().__init__()
151 |
152 | self.concat_segmaps = concat_segmaps
153 | self.output_type = output_type
154 | self.keep_intermediate_results = keep_intermediate_results
155 | self.image_pos_embs = nn.ModuleList(
156 | [SinusoidalPositionalEmbedding(positional_embedding_dim, emb_type="concat")]
157 | )
158 | self.convs = nn.ModuleList([nn.Identity()])
159 |
160 | num_input_channels = image_num_channels + segmap_num_channels
161 |
162 | image_emb_dim = num_input_channels + positional_embedding_dim
163 |
164 | num_latents = num_labels * num_latent_per_labels
165 | self.num_latent_per_labels = num_latent_per_labels
166 |
167 | latents_mask = torch.block_diag(
168 | *[
169 | torch.FloatTensor(
170 | [
171 | [1.0 for _ in range(num_latent_per_labels)]
172 | for _ in range(num_latent_per_labels)
173 | ]
174 | )
175 | for _ in range(num_labels)
176 | ]
177 | ).unsqueeze(0)
178 | self.register_buffer("latents_mask", latents_mask)
179 |
180 | self.latents = nn.Parameter(torch.Tensor(num_latents, latent_dim))
181 |
182 | self.backbone = nn.ModuleDict(OrderedDict())
183 |
184 | for i in range(num_blocks):
185 | module = nn.ModuleDict(OrderedDict())
186 | module.update(
187 | {
188 | "cross_attention": MaskedTransformer(
189 | latent_dim,
190 | num_latents,
191 | image_emb_dim if i == 0 else 32 * 2 ** i,
192 | attention_latent_dim,
193 | num_cross_heads,
194 | ),
195 | }
196 | )
197 | module.update(
198 | {
199 | "self_attention": MaskedTransformer(
200 | latent_dim,
201 | num_latents,
202 | latent_dim,
203 | attention_latent_dim,
204 | num_self_heads,
205 | ),
206 | }
207 | )
208 | self.backbone.update({f"block_{i}": module})
209 | if i > 0:
210 | conv = nn.Conv2d(
211 | num_input_channels if i == 1 else 32 * 2 ** (i - 1),
212 | 32 * 2 ** i,
213 | kernel_size=3,
214 | padding=1,
215 | stride=2,
216 | )
217 | if apply_spectral_norm:
218 | conv = spectral_norm(conv)
219 | self.convs.append(
220 | nn.Sequential(
221 | conv,
222 | nn.LeakyReLU(0.2),
223 | )
224 | )
225 | self.image_pos_embs.append(
226 | SinusoidalPositionalEmbedding(32 * 2 ** i, emb_type="add")
227 | )
228 | self.conv_to_latent = nn.Sequential(
229 | nn.Linear(32 * 2 ** (num_blocks - 1), latent_dim),
230 | nn.LeakyReLU(0.2),
231 | )
232 | if self.output_type == "attention_pool":
233 | self.cls_token = nn.Parameter(torch.randn(1, 1, latent_dim))
234 | self.attention_pool = MaskedTransformer(
235 | latent_dim,
236 | 1,
237 | latent_dim,
238 | attention_latent_dim,
239 | num_self_heads,
240 | )
241 | self.classification_head = nn.Linear(latent_dim, 1)
242 |
243 | def forward(
244 | self,
245 | input: TensorType["batch_size", "num_input_channels", "height", "width"],
246 | segmentation_map: TensorType["batch_size", "num_labels", "height", "width"],
247 | ) -> TensorType["batch_size", "num_input_channels", "style_dim"]:
248 | batch_size = input.shape[0]
249 |
250 | latents = repeat(self.latents, " l d -> b l d", b=batch_size)
251 |
252 | if self.concat_segmaps:
253 | input = torch.cat([input, segmentation_map], dim=1)
254 |
255 | cross_attention_masks = []
256 | flattened_inputs = []
257 | for i, conv in enumerate(self.convs):
258 | input = conv(input)
259 | cross_attention_masks.append(
260 | rearrange(
261 | torch.repeat_interleave(
262 | F.interpolate(
263 | segmentation_map, size=input.size()[2:], mode="nearest"
264 | ),
265 | self.num_latent_per_labels,
266 | dim=1,
267 | ),
268 | "b n h w -> b n (h w)",
269 | )
270 | )
271 | flattened_inputs.append(
272 | rearrange(self.image_pos_embs[i](input), " b c h w -> b (h w) c")
273 | )
274 |
275 | for i, block_name in enumerate(self.backbone):
276 | latents = self.backbone[block_name]["cross_attention"](
277 | latents, flattened_inputs[i], cross_attention_masks[i]
278 | )
279 | convs_features = self.conv_to_latent(flattened_inputs[-1])
280 | latents_and_convs = torch.cat([latents, convs_features], dim=1)
281 |
282 | if self.output_type == "patchgan":
283 | flattened_inputs.append(self.classification_head(latents_and_convs))
284 | elif self.output_type == "mean_pool":
285 | flattened_inputs.append(
286 | self.classification_head(latents_and_convs.mean(dim=1))
287 | )
288 | elif self.output_type == "attention_pool":
289 | output_token = repeat(self.cls_token, "() n d -> b n d", b=batch_size)
290 | output_token = self.attention_pool(output_token, latents_and_convs)
291 | flattened_inputs.append(self.classification_head(output_token))
292 | else:
293 | raise ValueError("Not a supported disc output type")
294 |
295 | if self.keep_intermediate_results:
296 | return flattened_inputs[1:]
297 | else:
298 | return flattened_inputs[-1]
299 |
--------------------------------------------------------------------------------
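
A standalone illustration of how the discriminator above builds its cross-attention mask: each label channel of the segmentation map is duplicated once per latent token, and the spatial dimensions are flattened so that every latent row of the mask aligns with the flattened image tokens. The sizes are arbitrary toy values.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

num_labels, num_latent_per_labels = 4, 2
segmap = F.one_hot(torch.randint(0, num_labels, (1, 8, 8)), num_labels)
segmap = segmap.permute(0, 3, 1, 2).float()        # (1, num_labels, 8, 8)

mask = rearrange(
    torch.repeat_interleave(segmap, num_latent_per_labels, dim=1),
    "b n h w -> b n (h w)",
)
print(mask.shape)  # torch.Size([1, 8, 64]): one row per latent token, one column per pixel
```
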
/models/discriminators/oasis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import torch.nn.utils.spectral_norm as spectral_norm
5 |
6 | from models.utils_blocks.base import BaseNetwork
7 |
8 |
9 | class OASISDiscriminator(BaseNetwork):
10 | """
11 |     OASIS Discriminator. A U-Net shaped discriminator that labels every pixel of the
12 |     input as one of the N semantic classes or as an extra "fake" class, so the output
13 |     has segmap_num_channels + 1 channels at the input resolution.
14 |
15 |     Parameters:
16 |     -----------
17 |     image_num_channels: int,
18 |         Number of channels in the real/generated image
19 |     segmap_num_channels: int,
20 |         Number of channels in the segmentation map
21 |     apply_spectral_norm: bool, default True
22 |         Whether or not to apply spectral normalization
23 |     apply_grad_norm: bool, default False
24 |         Whether or not to normalize the logits by the norm of their input gradient
25 |
26 |
27 | """
28 |
29 | def __init__(
30 | self,
31 | image_num_channels: int,
32 | segmap_num_channels: int,
33 | apply_spectral_norm: bool = True,
34 | apply_grad_norm: bool = False,
35 | ):
36 | super().__init__()
37 | self.apply_grad_norm = apply_grad_norm
38 | output_channel = segmap_num_channels + 1 # for N+1 loss
39 | self.channels = [image_num_channels, 128, 128, 256, 256, 512, 512]
40 | num_res_blocks = 6
41 | self.body_up = nn.ModuleList([])
42 | self.body_down = nn.ModuleList([])
43 | # encoder part
44 | for i in range(num_res_blocks):
45 | self.body_down.append(
46 | OASISBlock(
47 | self.channels[i],
48 | self.channels[i + 1],
49 | -1,
50 | first=(i == 0),
51 | apply_spectral_norm=apply_spectral_norm,
52 | )
53 | )
54 | # decoder part
55 | self.body_up.append(
56 | OASISBlock(
57 | self.channels[-1],
58 | self.channels[-2],
59 | 1,
60 | apply_spectral_norm=apply_spectral_norm,
61 | )
62 | )
63 | for i in range(1, num_res_blocks - 1):
64 | self.body_up.append(
65 | OASISBlock(
66 | 2 * self.channels[-1 - i],
67 | self.channels[-2 - i],
68 | 1,
69 | apply_spectral_norm=apply_spectral_norm,
70 | )
71 | )
72 | self.body_up.append(
73 | OASISBlock(
74 | 2 * self.channels[1], 64, 1, apply_spectral_norm=apply_spectral_norm
75 | )
76 | )
77 | self.layer_up_last = nn.Conv2d(64, output_channel, 1, 1, 0)
78 |
79 | def forward(self, input, segmap=None):
80 | x = input
81 | # encoder
82 | encoder_res = list()
83 | for i in range(len(self.body_down)):
84 | x = self.body_down[i](x)
85 | encoder_res.append(x)
86 | # decoder
87 | x = self.body_up[0](x)
88 | for i in range(1, len(self.body_down)):
89 | x = self.body_up[i](torch.cat((encoder_res[-i - 1], x), dim=1))
90 | ans = self.layer_up_last(x)
91 | if self.apply_grad_norm:
92 | grad = torch.autograd.grad(
93 | ans,
94 | [input],
95 | torch.ones_like(ans),
96 | create_graph=True,
97 | retain_graph=True,
98 | )[0]
99 | grad_norm = torch.norm(torch.flatten(grad, start_dim=1), p=2, dim=1)
100 | grad_norm = grad_norm.view(-1, *[1 for _ in range(len(ans.shape) - 1)])
101 | ans = ans / (grad_norm + torch.abs(ans))
102 | return ans
103 |
104 |
105 | class OASISBlock(nn.Module):
106 | def __init__(
107 | self,
108 | fin: int,
109 | fout: int,
110 | up_or_down: int,
111 | first: bool = False,
112 | apply_spectral_norm: bool = True,
113 | ):
114 | super().__init__()
115 | # Attributes
116 | self.up_or_down = up_or_down
117 | self.first = first
118 | self.learned_shortcut = fin != fout
119 | fmiddle = fout
120 | if apply_spectral_norm:
121 | norm_layer = spectral_norm
122 | else:
123 | norm_layer = nn.Identity()
124 | if first:
125 | self.conv1 = nn.Sequential(norm_layer(nn.Conv2d(fin, fmiddle, 3, 1, 1)))
126 | else:
127 | if self.up_or_down > 0:
128 | self.conv1 = nn.Sequential(
129 | nn.LeakyReLU(0.2, False),
130 | nn.Upsample(scale_factor=2),
131 | norm_layer(nn.Conv2d(fin, fmiddle, 3, 1, 1)),
132 | )
133 | else:
134 | self.conv1 = nn.Sequential(
135 | nn.LeakyReLU(0.2, False),
136 | norm_layer(nn.Conv2d(fin, fmiddle, 3, 1, 1)),
137 | )
138 | self.conv2 = nn.Sequential(
139 | nn.LeakyReLU(0.2, False), norm_layer(nn.Conv2d(fmiddle, fout, 3, 1, 1))
140 | )
141 | if self.learned_shortcut:
142 | self.conv_s = norm_layer(nn.Conv2d(fin, fout, 1, 1, 0))
143 | if up_or_down > 0:
144 | self.sampling = nn.Upsample(scale_factor=2)
145 | elif up_or_down < 0:
146 | self.sampling = nn.AvgPool2d(2)
147 | else:
148 | self.sampling = nn.Sequential()
149 |
150 | def forward(self, x):
151 | x_s = self.shortcut(x)
152 | dx = self.conv1(x)
153 | dx = self.conv2(dx)
154 | if self.up_or_down < 0:
155 | dx = self.sampling(dx)
156 | out = x_s + dx
157 | return out
158 |
159 | def shortcut(self, x):
160 | if self.first:
161 | if self.up_or_down < 0:
162 | x = self.sampling(x)
163 | if self.learned_shortcut:
164 | x = self.conv_s(x)
165 | x_s = x
166 | else:
167 | if self.up_or_down > 0:
168 | x = self.sampling(x)
169 | if self.learned_shortcut:
170 | x = self.conv_s(x)
171 | if self.up_or_down < 0:
172 | x = self.sampling(x)
173 | x_s = x
174 | return x_s
175 |
--------------------------------------------------------------------------------
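
A small sketch with hypothetical sizes (not from the repo) showing the shape contract of the OASIS discriminator: it returns an (N+1)-way per-pixel logit map at the resolution of its input.

```python
import torch
from models.discriminators.oasis import OASISDiscriminator

disc = OASISDiscriminator(image_num_channels=3, segmap_num_channels=19)
images = torch.randn(2, 3, 256, 256)
logits = disc(images)
print(logits.shape)  # torch.Size([2, 20, 256, 256]): 19 semantic classes + 1 "fake" class
```
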
/models/discriminators/patchgan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchtyping import TensorType
5 | from typing import Union, Sequence
6 |
7 | import torch.nn.utils.spectral_norm as spectral_norm
8 | from collections import OrderedDict
9 | from math import ceil
10 |
11 | from models.utils_blocks.base import BaseNetwork
12 |
13 | from models.utils_blocks.equallr import EqualConv2d
14 | from functools import partial
15 |
16 |
17 | class MultiScalePatchGanDiscriminator(BaseNetwork):
18 | """
19 | This is the Multi Scale Discriminator. We use multiple discriminators at different scales
20 |
21 | Parameters:
22 | -----------
23 | num_discriminator: int,
24 | How many discriminators do we use
25 | image_num_channels: int,
26 | Number of input images channels
27 | segmap_num_channels: int,
28 | Number of segmentation map channels
29 | num_features_fst_conv: int,
30 | How many kernels at the first convolution
31 | num_layers: int,
32 | How many layers per discriminator
33 | apply_spectral_norm: bool = True,
34 |         Whether or not to apply spectral normalization
35 | keep_intermediate_results: bool = True
36 | Whether or not to keep intermediate discriminators feature maps
37 |
38 | """
39 |
40 | def __init__(
41 | self,
42 | num_discriminator: int,
43 | image_num_channels: int,
44 | segmap_num_channels: int,
45 | num_features_fst_conv: int,
46 | num_layers: int,
47 | apply_spectral_norm: bool = True,
48 | apply_grad_norm: bool = False,
49 | keep_intermediate_results: bool = True,
50 | use_equalized_lr: bool = False,
51 | lr_mul: float = 1.0,
52 | ):
53 | super().__init__()
54 |
55 | self.keep_intermediate_results = keep_intermediate_results
56 |
57 | self.image_num_channels = image_num_channels
58 |
59 | self.discriminators = nn.ModuleDict(OrderedDict())
60 |
61 | for i in range(num_discriminator):
62 | self.discriminators.update(
63 | {
64 | f"discriminator_{i}": PatchGANDiscriminator(
65 | image_num_channels=image_num_channels,
66 | segmap_num_channels=segmap_num_channels,
67 | num_features_fst_conv=num_features_fst_conv,
68 | num_layers=num_layers,
69 | apply_spectral_norm=apply_spectral_norm,
70 | apply_grad_norm=apply_grad_norm,
71 | keep_intermediate_results=keep_intermediate_results,
72 | use_equalized_lr=use_equalized_lr,
73 | lr_mul=lr_mul,
74 | ),
75 | }
76 | )
77 | self.downsample = nn.AvgPool2d(
78 | kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False
79 | )
80 |
81 | def forward(
82 | self,
83 | input: TensorType["batch_size", "input_channels", "height", "width"],
84 | ) -> Union[
85 | Sequence[
86 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]]
87 | ],
88 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]],
89 | ]:
90 | results = []
91 | for disc_name in self.discriminators:
92 | result = self.discriminators[disc_name](input)
93 | if not self.keep_intermediate_results:
94 | result = [result]
95 | results.append(result)
96 | image = self.downsample(input[:, : self.image_num_channels])
97 | segmentation_map = F.interpolate(
98 | input[:, self.image_num_channels :],
99 | size=image.size()[2:],
100 | mode="nearest",
101 | )
102 | input = torch.cat([image, segmentation_map], dim=1)
103 | return results
104 |
105 |
106 | class PatchGANDiscriminator(BaseNetwork):
107 | """
108 | PatchGAN Discriminator. Perform patch-wise GAN discrimination to avoid ignoring local high dimensional features.
109 |
110 | Parameters:
111 | -----------
112 | image_num_channels: int,
113 | Number of channels in the real/generated image
114 | segmap_num_channels: int,
115 | Number of channels in the segmentation map
116 | num_features_fst_conv: int,
117 | Number of features in the first convolution layer
118 | num_layers: int,
119 | Number of convolution layers
120 | apply_spectral_norm: bool, default True
121 | Whether or not to apply spectral normalization
122 | keep_intermediate_results: bool, default True
123 | Whether or not to keep each feature map output
124 | """
125 |
126 | def __init__(
127 | self,
128 | image_num_channels: int,
129 | segmap_num_channels: int,
130 | num_features_fst_conv: int,
131 | num_layers: int,
132 | apply_spectral_norm: bool = True,
133 | apply_grad_norm: bool = False,
134 | keep_intermediate_results: bool = True,
135 | use_equalized_lr: bool = False,
136 | lr_mul: float = 1.0,
137 | ):
138 | super().__init__()
139 | kernel_size = 4
140 | padding = int(ceil((kernel_size - 1.0) / 2))
141 | nffc = num_features_fst_conv
142 | self.keep_intermediate_results = keep_intermediate_results
143 | self.apply_grad_norm = apply_grad_norm
144 | self.model = nn.ModuleDict(OrderedDict())
145 | ConvLayer = (
146 | partial(EqualConv2d, lr_mul=lr_mul) if use_equalized_lr else nn.Conv2d
147 | )
148 | self.model.update(
149 | {
150 | "conv_0": nn.Sequential(
151 | ConvLayer(
152 | in_channels=image_num_channels + segmap_num_channels,
153 | out_channels=nffc,
154 | kernel_size=kernel_size,
155 | stride=2,
156 | padding=padding,
157 | ),
158 | nn.LeakyReLU(0.2, False),
159 | )
160 | }
161 | )
162 |
163 | for n in range(1, num_layers):
164 | nffc_prev = nffc
165 | nffc = min(2 * nffc_prev, 512)
166 | stride = 1 if n == num_layers - 1 else 2
167 | conv = ConvLayer(
168 | in_channels=nffc_prev,
169 | out_channels=nffc,
170 | kernel_size=kernel_size,
171 | padding=padding,
172 | stride=stride,
173 | )
174 | if apply_spectral_norm:
175 | conv = spectral_norm(conv)
176 | self.model.update(
177 | {
178 | f"conv_{n}": nn.Sequential(
179 | conv,
180 | nn.InstanceNorm2d(nffc),
181 | nn.LeakyReLU(0.2, False),
182 | )
183 | }
184 | )
185 | self.model.update(
186 | {
187 | f"last_conv": nn.Sequential(
188 | ConvLayer(
189 | in_channels=nffc,
190 | out_channels=1,
191 | kernel_size=kernel_size,
192 | stride=1,
193 | padding=padding,
194 | )
195 | )
196 | }
197 | )
198 |
199 | def forward(
200 | self,
201 | input: TensorType["batch_size", "input_channels", "height", "width"],
202 | ) -> Union[
203 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]],
204 | TensorType["batch_size", 1, "output_height", "output_width"],
205 | ]:
206 | results = [input]
207 | for conv_name in self.model:
208 | results.append(self.model[conv_name](results[-1]))
209 | if self.keep_intermediate_results:
210 | return results[1:]
211 | else:
212 | return results[-1]
213 |
--------------------------------------------------------------------------------
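
A usage sketch with hypothetical hyper-parameters (not from the repo): the multi-scale PatchGAN expects the image and the one-hot segmentation map concatenated along the channel axis, and returns one list of intermediate feature maps per scale.

```python
import torch
import torch.nn.functional as F
from models.discriminators.patchgan import MultiScalePatchGanDiscriminator

disc = MultiScalePatchGanDiscriminator(
    num_discriminator=2,
    image_num_channels=3,
    segmap_num_channels=19,
    num_features_fst_conv=64,
    num_layers=4,
)
image = torch.randn(2, 3, 256, 256)
segmap = F.one_hot(torch.randint(0, 19, (2, 256, 256)), 19).permute(0, 3, 1, 2).float()

results = disc(torch.cat([image, segmap], dim=1))
print(len(results), len(results[0]))  # 2 scales, 5 feature maps per scale (4 convs + final conv)
```
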
/models/discriminators/stylegan2.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from functools import partial
6 | from torchtyping import TensorType
7 | from typing import Union, Sequence
8 | from collections import OrderedDict
9 | import torch.nn.utils.spectral_norm as spectral_norm
10 |
11 | from models.utils_blocks.equallr import EqualConv2d, EqualLinear
12 |
13 | from math import sqrt
14 |
15 | from einops.layers.torch import Rearrange
16 |
17 | from kornia.filters import filter2d
18 |
19 | from models.utils_blocks.base import BaseNetwork
20 |
21 |
22 | class MultiStyleGan2Discriminator(BaseNetwork):
23 | """
24 | This is the Multi Scale Discriminator. We use multiple discriminators at different scales
25 |
26 | Parameters:
27 | -----------
28 | num_discriminator: int,
29 | How many discriminators do we use
30 | image_num_channels: int,
31 | Number of input images channels
32 | segmap_num_channels: int,
33 | Number of segmentation map channels
34 | num_features_fst_conv: int,
35 | How many kernels at the first convolution
36 | num_layers: int,
37 | How many layers per discriminator
38 | apply_spectral_norm: bool = True,
39 |         Whether or not to apply spectral normalization
40 | keep_intermediate_results: bool = True
41 | Whether or not to keep intermediate discriminators feature maps
42 |
43 | """
44 |
45 | def __init__(
46 | self,
47 | num_discriminator: int,
48 | image_num_channels: int,
49 | segmap_num_channels: int,
50 | num_features_fst_conv: int,
51 | num_layers: int,
52 | fmap_max: int,
53 | apply_spectral_norm: bool = True,
54 | apply_grad_norm: bool = False,
55 | keep_intermediate_results: bool = True,
56 | use_equalized_lr=False,
57 | lr_mul=1,
58 | ):
59 | super().__init__()
60 |
61 | self.keep_intermediate_results = keep_intermediate_results
62 |
63 | self.image_num_channels = image_num_channels
64 |
65 | self.discriminators = nn.ModuleDict(OrderedDict())
66 |
67 | for i in range(num_discriminator):
68 | self.discriminators.update(
69 | {
70 | f"discriminator_{i}": StyleGan2Discriminator(
71 | image_num_channels=image_num_channels,
72 | segmap_num_channels=segmap_num_channels,
73 | num_features_fst_conv=num_features_fst_conv,
74 | num_layers=num_layers - i,
75 | fmap_max=fmap_max,
76 | apply_spectral_norm=apply_spectral_norm,
77 | apply_grad_norm=apply_grad_norm,
78 | keep_intermediate_results=keep_intermediate_results,
79 | use_equalized_lr=use_equalized_lr,
80 | lr_mul=lr_mul,
81 | ),
82 | }
83 | )
84 | self.downsample = nn.AvgPool2d(
85 | kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False
86 | )
87 |
88 | def forward(
89 | self,
90 | input: TensorType["batch_size", "input_channels", "height", "width"],
91 | ) -> Union[
92 | Sequence[
93 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]]
94 | ],
95 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]],
96 | ]:
97 | results = []
98 | for disc_name in self.discriminators:
99 | result = self.discriminators[disc_name](input)
100 | if not self.keep_intermediate_results:
101 | result = [result]
102 | results.append(result)
103 | image = self.downsample(input[:, : self.image_num_channels])
104 | segmentation_map = F.interpolate(
105 | input[:, self.image_num_channels :],
106 | size=image.size()[2:],
107 | mode="nearest",
108 | )
109 | input = torch.cat([image, segmentation_map], dim=1)
110 |
111 | return results
112 |
113 |
114 | class StyleGan2Discriminator(BaseNetwork):
115 | """
116 | Taken from lucidrain repo https://github.com/lucidrains/stylegan2-pytorch
117 |
118 | """
119 |
120 | def __init__(
121 | self,
122 | num_layers,
123 | image_num_channels,
124 | segmap_num_channels,
125 | num_features_fst_conv,
126 | fmap_max,
127 | keep_intermediate_results=True,
128 | apply_grad_norm=False,
129 | apply_spectral_norm=False,
130 | use_equalized_lr=False,
131 | lr_mul=1,
132 | ):
133 | super().__init__()
134 | self.apply_grad_norm = apply_grad_norm
135 | self.keep_intermediate_results = keep_intermediate_results
136 | init_channels = image_num_channels + segmap_num_channels
137 | blocks = []
138 | ConvLayer = (
139 | partial(EqualConv2d, lr_mul=lr_mul) if use_equalized_lr else nn.Conv2d
140 | )
141 | LinearLayer = (
142 | partial(EqualLinear, lr_mul=lr_mul) if use_equalized_lr else nn.Linear
143 | )
144 | filters = [init_channels] + [
145 | num_features_fst_conv * (2**i) for i in range(num_layers + 1)
146 | ]
147 | set_fmap_max = partial(min, fmap_max)
148 | filters = list(map(set_fmap_max, filters))
149 | chan_in_out = list(zip(filters[:-1], filters[1:]))
150 | for i, (input_dim, output_dim) in enumerate(chan_in_out):
151 | is_not_last = i != len(chan_in_out) - 1
152 | block = StyleGan2DiscBlock(
153 | input_dim,
154 | output_dim,
155 | downsample=is_not_last,
156 | apply_spectral_norm=apply_spectral_norm,
157 | use_equalized_lr=use_equalized_lr,
158 | lr_mul=lr_mul,
159 | )
160 | blocks.append(block)
161 | self.blocks = nn.ModuleList(blocks)
162 |
163 | chan_last = filters[-1]
164 | latent_dim = 2 * 2 * chan_last
165 |
166 | self.final_block = nn.Sequential(
167 | ConvLayer(chan_last, chan_last, 3, padding=1),
168 | Rearrange("b c h w -> b (c h w)"),
169 | LinearLayer(latent_dim, 1),
170 | )
171 |
172 | def forward(self, input):
173 | if self.apply_grad_norm:
174 | input.requires_grad_(True)
175 | results = [input]
176 | for block in self.blocks:
177 | results.append(block(results[-1]))
178 | results.append(self.final_block(results[-1]))
179 | if self.keep_intermediate_results:
180 | return results[1:]
181 | else:
182 | return results[-1]
183 |
184 |
185 | class StyleGan2DiscBlock(nn.Module):
186 | """
187 | Taken from lucidrain repo https://github.com/lucidrains/stylegan2-pytorch
188 |
189 | """
190 |
191 | def __init__(
192 | self,
193 | in_channels,
194 | out_channels,
195 | downsample=True,
196 | apply_spectral_norm=False,
197 | use_equalized_lr=False,
198 | lr_mul=1,
199 | ):
200 | super().__init__()
201 | spectral_norm_op = spectral_norm if apply_spectral_norm else nn.Identity()
202 | ConvLayer = (
203 | partial(EqualConv2d, lr_mul=lr_mul) if use_equalized_lr else nn.Conv2d
204 | )
205 | self.conv_res = spectral_norm_op(
206 | ConvLayer(in_channels, out_channels, 1, stride=(2 if downsample else 1))
207 | )
208 | self.net = nn.Sequential(
209 | spectral_norm_op(ConvLayer(in_channels, out_channels, 3, padding=1)),
210 | nn.LeakyReLU(0.2),
211 | spectral_norm_op(ConvLayer(out_channels, out_channels, 3, padding=1)),
212 | nn.LeakyReLU(0.2),
213 | )
214 |
215 | self.downsample = (
216 | nn.Sequential(
217 | # Blur(),
218 | spectral_norm_op(
219 | ConvLayer(out_channels, out_channels, 3, padding=1, stride=2)
220 | ),
221 | )
222 | if downsample
223 | else None
224 | )
225 |
226 | def forward(self, x):
227 | res = self.conv_res(x)
228 | x = self.net(x)
229 | if self.downsample is not None:
230 | x = self.downsample(x)
231 | x = (x + res) / sqrt(2)
232 | return x
233 |
234 |
235 | class Blur(nn.Module):
236 | """
237 | Taken from lucidrain repo https://github.com/lucidrains/stylegan2-pytorch
238 |
239 | """
240 |
241 | def __init__(self):
242 | super().__init__()
243 | f = torch.Tensor([1, 2, 1])
244 | self.register_buffer("f", f)
245 |
246 | def forward(self, x):
247 | f = self.f
248 | f = f[None, None, :] * f[None, :, None]
249 | return filter2d(x, f, normalized=True)
250 |
--------------------------------------------------------------------------------
/models/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/models/encoders/__init__.py
--------------------------------------------------------------------------------
/models/encoders/groupdnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.utils.spectral_norm as spectral_norm
5 | from models.utils_blocks.base import BaseNetwork
6 |
7 | class GroupDNetStyleEncoder(BaseNetwork):
8 | def __init__(self, num_labels):
9 | super().__init__()
10 | self.num_labels = num_labels
11 | kw = 3
12 | pw = 1
13 | ndf = 32 * num_labels
14 | self.layer1 = nn.Sequential(
15 | spectral_norm(
16 | nn.Conv2d(
17 | 3 * num_labels, ndf, kw, stride=2, padding=pw, groups=num_labels
18 | )
19 | ),
20 | nn.InstanceNorm2d(ndf, affine=False),
21 | )
22 | self.layer2 = nn.Sequential(
23 | spectral_norm(
24 | nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw, groups=num_labels)
25 | ),
26 | nn.InstanceNorm2d(ndf * 2, affine=False),
27 | )
28 | self.layer3 = nn.Sequential(
29 | spectral_norm(
30 | nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw, groups=num_labels)
31 | ),
32 | nn.InstanceNorm2d(ndf * 4, affine=False),
33 | )
34 | self.layer4 = nn.Sequential(
35 | spectral_norm(
36 | nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw, groups=num_labels)
37 | ),
38 | nn.InstanceNorm2d(ndf * 8, affine=False),
39 | )
40 | self.layer5 = nn.Sequential(
41 | spectral_norm(
42 | nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw, groups=num_labels)
43 | ),
44 | nn.InstanceNorm2d(ndf * 8, affine=False),
45 | )
46 | self.layer6 = nn.Sequential(
47 | spectral_norm(
48 | nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw, groups=num_labels)
49 | ),
50 | nn.InstanceNorm2d(ndf * 8, affine=False),
51 | )
52 |
53 | self.so = s0 = 4
54 | self.fc_mu = nn.Conv2d(
55 | ndf * 8, num_labels * 8, kw, stride=1, padding=pw, groups=num_labels
56 | )
57 | self.fc_var = nn.Conv2d(
58 | ndf * 8, num_labels * 8, kw, stride=1, padding=pw, groups=num_labels
59 | )
60 |
61 | self.actvn = nn.LeakyReLU(0.2, False)
62 |
63 | def trans_img(self, input_semantics, real_image):
64 | images = None
65 | seg_range = input_semantics.size()[1]
66 | for i in range(input_semantics.size(0)):
67 | resize_image = None
68 | for n in range(0, seg_range):
69 | seg_image = real_image[i] * input_semantics[i][n]
70 | # resize seg_image
71 | c_sum = seg_image.sum(dim=0)
72 | y_seg = c_sum.sum(dim=0)
73 | x_seg = c_sum.sum(dim=1)
74 | y_id = y_seg.nonzero()
75 | if y_id.size()[0] == 0:
76 | seg_image = seg_image.unsqueeze(dim=0)
77 | # resize_image = torch.cat((resize_image, seg_image), dim=0)
78 | if resize_image is None:
79 | resize_image = seg_image
80 | else:
81 | resize_image = torch.cat((resize_image, seg_image), dim=1)
82 | continue
83 | # print(y_id)
84 | y_min = y_id[0][0]
85 | y_max = y_id[-1][0]
86 | x_id = x_seg.nonzero()
87 | x_min = x_id[0][0]
88 | x_max = x_id[-1][0]
89 | seg_image = seg_image.unsqueeze(dim=0)
90 | seg_image = F.interpolate(
91 | seg_image[:, :, x_min : x_max + 1, y_min : y_max + 1],
92 | size=[256, 256],
93 | )
94 | if resize_image is None:
95 | resize_image = seg_image
96 | else:
97 | resize_image = torch.cat((resize_image, seg_image), dim=1)
98 | if images is None:
99 | images = resize_image
100 | else:
101 | images = torch.cat((images, resize_image), dim=0)
102 | return images
103 |
104 | def forward(self, image, segmap=None):
105 | image = self.trans_img(segmap, image)
106 | image = self.layer1(image)
107 | image = self.layer2(self.actvn(image))
108 | image = self.layer3(self.actvn(image))
109 | image = self.layer4(self.actvn(image))
110 | image = self.layer5(self.actvn(image))
111 | image = self.layer6(self.actvn(image))
112 |
113 | image = self.actvn(image)
114 |
115 | mu = self.fc_mu(image)
116 | logvar = self.fc_var(image)
117 |
118 | return [mu, logvar], None
--------------------------------------------------------------------------------
/models/encoders/inade.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.utils_blocks.base import BaseNetwork
4 | from utils.partial_conv import InstanceAwareConv2d
5 |
6 | class InstanceAdaptiveEncoder(BaseNetwork):
7 | def __init__(self, num_labels, noise_dim, use_vae=True):
8 | super().__init__()
9 | kw = 3
10 | pw = 1
11 | ndf = 64
12 | conv_layer = InstanceAwareConv2d
13 |
14 | self.layer1 = conv_layer(3, ndf, kw, stride=2, padding=pw)
15 | self.norm1 = nn.InstanceNorm2d(ndf)
16 | self.layer2 = conv_layer(ndf * 1, ndf * 2, kw, stride=2, padding=pw)
17 | self.norm2 = nn.InstanceNorm2d(ndf * 2)
18 | self.layer3 = conv_layer(ndf * 2, ndf * 4, kw, stride=2, padding=pw)
19 | self.norm3 = nn.InstanceNorm2d(ndf * 4)
20 | self.layer4 = conv_layer(ndf * 4, ndf * 8, kw, stride=2, padding=pw)
21 | self.norm4 = nn.InstanceNorm2d(ndf * 8)
22 |
23 | self.middle = conv_layer(ndf * 8, ndf * 4, kw, stride=1, padding=pw)
24 | self.norm_middle = nn.InstanceNorm2d(ndf * 4)
25 | self.up1 = conv_layer(ndf * 8, ndf * 2, kw, stride=1, padding=pw)
26 | self.norm_up1 = nn.InstanceNorm2d(ndf * 2)
27 | self.up2 = conv_layer(ndf * 4, ndf * 1, kw, stride=1, padding=pw)
28 | self.norm_up2 = nn.InstanceNorm2d(ndf)
29 | self.up3 = conv_layer(ndf * 2, ndf, kw, stride=1, padding=pw)
30 | self.norm_up3 = nn.InstanceNorm2d(ndf)
31 |
32 | self.up = nn.Upsample(scale_factor=2, mode="bilinear")
33 | self.num_labels = num_labels
34 |
35 | self.scale_conv_mu = conv_layer(ndf, noise_dim, kw, stride=1, padding=pw)
36 | self.scale_conv_var = conv_layer(ndf, noise_dim, kw, stride=1, padding=pw)
37 | self.bias_conv_mu = conv_layer(ndf, noise_dim, kw, stride=1, padding=pw)
38 | self.bias_conv_var = conv_layer(ndf, noise_dim, kw, stride=1, padding=pw)
39 |
40 | self.actvn = nn.LeakyReLU(0.2, False)
41 |
42 | def instAvgPooling(self, x, instances):
43 | inst_num = instances.size()[1]
44 | for i in range(inst_num):
45 | inst_mask = torch.unsqueeze(instances[:, i, :, :], 1) # [n,1,h,w]
46 | pixel_num = torch.sum(
47 | torch.sum(inst_mask, dim=2, keepdim=True), dim=3, keepdim=True
48 | )
49 | pixel_num[pixel_num == 0] = 1
50 | feat = x * inst_mask
51 | feat = (
52 | torch.sum(torch.sum(feat, dim=2, keepdim=True), dim=3, keepdim=True)
53 | / pixel_num
54 | )
55 | if i == 0:
56 | out = torch.unsqueeze(feat[:, :, 0, 0], 1) # [n,1,c]
57 | else:
58 | out = torch.cat([out, torch.unsqueeze(feat[:, :, 0, 0], 1)], 1)
59 | # inst_pool_feats.append(feat[:,:,0,0]) # [n, 64]
60 | return out
61 |
62 | def forward(self, real_image, input_semantics):
63 | # instances [n,1,h,w], input_instances [n,inst_nc,h,w]
64 | instances = torch.argmax(input_semantics, 1, keepdim=True).float()
65 | x1 = self.actvn(self.norm1(self.layer1(real_image, instances)))
66 | x2 = self.actvn(self.norm2(self.layer2(x1, instances)))
67 | x3 = self.actvn(self.norm3(self.layer3(x2, instances)))
68 | x4 = self.actvn(self.norm4(self.layer4(x3, instances)))
69 | y = self.up(self.actvn(self.norm_middle(self.middle(x4, instances))))
70 | y1 = self.up(
71 | self.actvn(self.norm_up1(self.up1(torch.cat([y, x3], 1), instances)))
72 | )
73 | y2 = self.up(
74 | self.actvn(self.norm_up2(self.up2(torch.cat([y1, x2], 1), instances)))
75 | )
76 | y3 = self.up(
77 | self.actvn(self.norm_up3(self.up3(torch.cat([y2, x1], 1), instances)))
78 | )
79 |
80 | scale_mu = self.scale_conv_mu(y3, instances)
81 | scale_var = self.scale_conv_var(y3, instances)
82 | bias_mu = self.bias_conv_mu(y3, instances)
83 | bias_var = self.bias_conv_var(y3, instances)
84 |
85 | scale_mus = self.instAvgPooling(scale_mu, input_semantics)
86 | scale_vars = self.instAvgPooling(scale_var, input_semantics)
87 | bias_mus = self.instAvgPooling(bias_mu, input_semantics)
88 | bias_vars = self.instAvgPooling(bias_var, input_semantics)
89 |
90 | return (scale_mus, scale_vars, bias_mus, bias_vars), None
91 |
--------------------------------------------------------------------------------
/models/encoders/sat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from einops import repeat, rearrange
6 |
7 | from torchtyping import TensorType
8 | from typing import Tuple
9 | from collections import OrderedDict
10 | from models.utils_blocks.base import BaseNetwork
11 |
12 |
13 | from models.utils_blocks.attention import (
14 | MaskedTransformer,
15 | SinusoidalPositionalEmbedding,
16 | )
17 |
18 | from models.utils_blocks.equallr import EqualConv2d
19 | from functools import partial
20 |
21 |
22 | class SemanticAttentionTransformerEncoder(BaseNetwork):
23 | """
24 |     Semantic Attention Transformer (SAT) encoder. A set of per-label latent tokens
25 |     cross-attends to the image features, with the attention masked by the segmentation
26 |     map so that each latent only aggregates features from its own semantic region.
27 |
28 |     Parameters:
29 |     -----------
30 |     num_input_channels: int,
31 |         Number of input channels
32 |     latent_dim: int,
33 |         Dimension of the output style tokens (style_dim)
34 |
35 | """
36 |
37 | def __init__(
38 | self,
39 | num_input_channels: int,
40 | positional_embedding_dim: int,
41 | num_labels: int,
42 | num_latent_per_labels: int,
43 | latent_dim: int,
44 | num_blocks: int,
45 | attention_latent_dim: int,
46 | num_cross_heads: int,
47 | num_self_heads: int,
48 | type_of_initial_latents: str = "learned",
49 | image_conv: bool = False,
50 | reverse_conv: bool = False,
51 | num_latents_bg: int = None,
52 | conv_features_dim_first: int = 16,
53 | content_dim: int = 512,
54 | use_vae: bool = False,
55 | use_self_attention: bool = True,
56 | use_equalized_lr: bool = False,
57 | lr_mul: float = 1.0,
58 | ):
59 | super().__init__()
60 | nf = conv_features_dim_first
61 | if num_latents_bg is None:
62 | num_latents_bg = num_latent_per_labels
63 | if image_conv:
64 | self.image_pos_embs = nn.ModuleList(
65 | [
66 | SinusoidalPositionalEmbedding(
67 | positional_embedding_dim, emb_type="concat"
68 | )
69 | ]
70 | )
71 | self.convs = nn.ModuleList([nn.Identity()])
72 |
73 | else:
74 | self.image_pos_emb = SinusoidalPositionalEmbedding(
75 | positional_embedding_dim, emb_type="concat"
76 | )
77 | ConvLayer = (
78 | partial(EqualConv2d, lr_mul=lr_mul) if use_equalized_lr else nn.Conv2d
79 | )
80 | self.lr_mul = lr_mul
81 | self.reverse_conv = reverse_conv
82 | self.image_conv = image_conv
83 | self.use_self_attention = use_self_attention
84 |
85 | image_emb_dim = num_input_channels + positional_embedding_dim
86 |
87 | num_latents = (num_labels - 1) * num_latent_per_labels + num_latents_bg
88 | self.num_latent_per_labels = num_latent_per_labels
89 | self.num_latents_bg = num_latents_bg
90 |
91 | self.return_attention = False
92 |
93 | latents_mask = [
94 | torch.FloatTensor(
95 | [[1.0 for _ in range(num_latents_bg)] for _ in range(num_latents_bg)]
96 | )
97 | ] + [
98 | torch.FloatTensor(
99 | [
100 | [1.0 for _ in range(num_latent_per_labels)]
101 | for _ in range(num_latent_per_labels)
102 | ]
103 | )
104 | for _ in range(num_labels - 1)
105 | ]
106 |
107 | latents_mask = torch.block_diag(*latents_mask).unsqueeze(0)
108 | self.register_buffer("latents_mask", latents_mask)
109 |
110 | self.type_of_initial_latents = type_of_initial_latents
111 | self.latent_dim = latent_dim
112 | self.num_latents = num_latents
113 | if self.type_of_initial_latents == "learned":
114 | self.latents = nn.Parameter(
115 | torch.randn(num_latents, latent_dim).div_(lr_mul)
116 | )
117 | elif self.type_of_initial_latents == "fixed_random":
118 | self.latents = torch.randn(num_latents, latent_dim)
119 |
120 | self.backbone = nn.ModuleDict(OrderedDict())
121 |
122 | for i in range(num_blocks):
123 | module = nn.ModuleDict(OrderedDict())
124 | if reverse_conv:
125 | module.update(
126 | {
127 | "cross_attention": MaskedTransformer(
128 | latent_dim,
129 | num_latents,
130 | image_emb_dim
131 | if i == num_blocks - 1 or not image_conv
132 | else nf * 2 ** (num_blocks - 1 - i),
133 | attention_latent_dim,
134 | num_cross_heads,
135 | use_equalized_lr=use_equalized_lr,
136 | lr_mul=lr_mul,
137 | ),
138 | }
139 | )
140 | else:
141 | module.update(
142 | {
143 | "cross_attention": MaskedTransformer(
144 | latent_dim,
145 | num_latents,
146 | image_emb_dim if i == 0 or not image_conv else nf * 2**i,
147 | attention_latent_dim,
148 | num_cross_heads,
149 | use_equalized_lr=use_equalized_lr,
150 | lr_mul=lr_mul,
151 | ),
152 | }
153 | )
154 | if use_self_attention:
155 | module.update(
156 | {
157 | "self_attention": MaskedTransformer(
158 | latent_dim,
159 | num_latents,
160 | latent_dim,
161 | attention_latent_dim,
162 | num_self_heads,
163 | use_equalized_lr=use_equalized_lr,
164 | lr_mul=lr_mul,
165 | ),
166 | }
167 | )
168 | self.backbone.update({f"block_{i}": module})
169 | if i > 0 and image_conv:
170 | self.convs.append(
171 | nn.Sequential(
172 | ConvLayer(
173 | num_input_channels if i == 1 else nf * 2 ** (i - 1),
174 | nf * 2**i,
175 | kernel_size=3,
176 | padding=1,
177 | stride=2,
178 | ),
179 | nn.LeakyReLU(0.2),
180 | )
181 | )
182 | self.image_pos_embs.append(
183 | SinusoidalPositionalEmbedding(nf * 2**i, emb_type="add")
184 | )
185 | self.use_vae = use_vae
186 |
187 | if self.use_vae:
188 | self.vae_mapper = nn.ModuleList([nn.Linear(latent_dim, 2*latent_dim) for _ in range(num_latents)])
189 |
190 | def forward(
191 | self,
192 | input: TensorType["batch_size", "num_input_channels", "height", "width"],
193 | segmentation_map: TensorType["batch_size", "num_labels", "height", "width"],
194 | ) -> Tuple[
195 | TensorType["batch_size", "num_segmap_labels", "style_dim"],
196 | TensorType[
197 | "batch_size",
198 | "output_dim",
199 | "output_fmap_heigth",
200 | "output_fmap_width",
201 | ],
202 | ]:
203 | batch_size = input.shape[0]
204 | if self.type_of_initial_latents == "learned":
205 | latents = repeat(self.latents, " l d -> b l d", b=batch_size) * self.lr_mul
206 | elif self.type_of_initial_latents == "fixed_random":
207 | latents = (
208 | repeat(self.latents.to(input.device), " l d -> b l d", b=batch_size)
209 | * self.lr_mul
210 | )
211 | elif self.type_of_initial_latents == "random":
212 | latents = torch.randn(
213 | batch_size, self.num_latents, self.latent_dim, device=input.device
214 | )
215 |
216 | if self.return_attention:
217 | self.image_sizes = []
218 |
219 | if self.image_conv:
220 | cross_attention_masks = []
221 | flattened_inputs = []
222 | for i, conv in enumerate(self.convs):
223 | input = conv(input)
224 | if i == len(self.convs) - 1:
225 | output_fmap = input
226 |
227 | interpolated_segmap = F.interpolate(
228 | segmentation_map, size=input.size()[2:], mode="nearest"
229 | )
230 | if self.return_attention:
231 | self.image_sizes.append(input.size()[2:])
232 | cross_attention_masks.append(
233 | rearrange(
234 | torch.cat(
235 | [
236 | torch.repeat_interleave(
237 | interpolated_segmap[:, 0].unsqueeze(1),
238 | self.num_latents_bg,
239 | dim=1,
240 | ),
241 | torch.repeat_interleave(
242 | interpolated_segmap[:, 1:],
243 | self.num_latent_per_labels,
244 | dim=1,
245 | ),
246 | ],
247 | dim=1,
248 | ),
249 | "b n h w -> b n (h w)",
250 | )
251 | )
252 |
253 | flattened_inputs.append(
254 | rearrange(self.image_pos_embs[i](input), " b c h w -> b (h w) c")
255 | )
256 | if self.reverse_conv:
257 | cross_attention_masks = cross_attention_masks[::-1]
258 | flattened_inputs = flattened_inputs[::-1]
259 | else:
260 | cross_attention_mask = rearrange(
261 | torch.cat(
262 | [
263 | torch.repeat_interleave(
264 | segmentation_map[:, 0].unsqueeze(1),
265 | self.num_latents_bg,
266 | dim=1,
267 | ),
268 | torch.repeat_interleave(
269 | segmentation_map[:, 1:],
270 | self.num_latent_per_labels,
271 | dim=1,
272 | ),
273 | ],
274 | dim=1,
275 | ),
276 | "b n h w -> b n (h w)",
277 | )
278 | flattened_input = rearrange(
279 | self.image_pos_emb(input), " b c h w -> b (h w) c"
280 | )
281 |
282 | for i, block_name in enumerate(self.backbone):
283 | if self.image_conv:
284 | if self.return_attention:
285 | latents, _ = self.backbone[block_name]["cross_attention"](
286 | latents,
287 | flattened_inputs[i],
288 | cross_attention_masks[i],
289 | return_attention=self.return_attention,
290 | )
291 | else:
292 | latents = self.backbone[block_name]["cross_attention"](
293 | latents, flattened_inputs[i], cross_attention_masks[i]
294 | )
295 | else:
296 | if self.return_attention:
297 | latents, _ = self.backbone[block_name]["cross_attention"](
298 | latents,
299 | flattened_input,
300 | cross_attention_mask,
301 | return_attention=self.return_attention,
302 | )
303 | else:
304 | latents = self.backbone[block_name]["cross_attention"](
305 | latents, flattened_input, cross_attention_mask
306 | )
307 | if self.use_self_attention:
308 | latents = self.backbone[block_name]["self_attention"](
309 | latents, latents, self.latents_mask
310 | )
311 | if self.use_vae:
312 | latents_vae = torch.zeros((*latents.shape[:-1], 2*latents.shape[-1]), device=latents.device)
313 | for i, mapper in enumerate(self.vae_mapper):
314 | latents_vae[:, i, :] = mapper(latents[:, i, :])
315 | latents = latents_vae.chunk(2, dim=-1)
316 | return latents, None
317 |
--------------------------------------------------------------------------------
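
A minimal instantiation sketch of the SAT encoder with hypothetical hyper-parameters (not taken from the configs, and assuming the attention utilities behave as they are used above): it maps an image and its one-hot segmentation map to one style token per latent slot, i.e. `(num_labels - 1) * num_latent_per_labels + num_latents_bg` tokens of size `latent_dim`.

```python
import torch
import torch.nn.functional as F
from models.encoders.sat import SemanticAttentionTransformerEncoder

encoder = SemanticAttentionTransformerEncoder(
    num_input_channels=3,
    positional_embedding_dim=16,
    num_labels=19,
    num_latent_per_labels=2,
    latent_dim=256,
    num_blocks=2,
    attention_latent_dim=128,
    num_cross_heads=4,
    num_self_heads=4,
)
image = torch.randn(2, 3, 64, 64)
segmap = F.one_hot(torch.randint(0, 19, (2, 64, 64)), 19).permute(0, 3, 1, 2).float()

style_tokens, _ = encoder(image, segmap)
print(style_tokens.shape)  # torch.Size([2, 38, 256]): 18 labels * 2 latents + 2 background latents
```
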
/models/encoders/sean.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | from torchtyping import TensorType
7 | from collections import OrderedDict
8 | from models.utils_blocks.base import BaseNetwork
9 |
10 |
11 | class RegionalAveragePoolingStyleEncoder(BaseNetwork):
12 | """
13 |     Encoder that encodes a style vector for each segmentation label by regional average pooling
14 |
15 | Parameters:
16 | -----------
17 | num_input_channels: int,
18 | Number of input channels
19 | latent_dim: int,
20 | Number of output channels (style_dim)
21 |     num_features_fst_conv: int,
22 | Number of kernels at first conv
23 |
24 | """
25 |
26 | def __init__(
27 | self,
28 | num_input_channels: int,
29 | latent_dim: int,
30 | num_features_fst_conv: int = 32,
31 | use_vae: bool = False,
32 | ):
33 | super().__init__()
34 |
35 | nffc = num_features_fst_conv
36 | self.first_layer = nn.Sequential(
37 | nn.ReflectionPad2d(1),
38 | nn.Conv2d(num_input_channels, nffc, kernel_size=3, padding=0),
39 | nn.InstanceNorm2d(nffc),
40 | nn.LeakyReLU(0.2, False),
41 | )
42 | self.bottleneck = nn.ModuleDict(OrderedDict())
43 |
44 | for i in range(2):
45 | mult = 2**i
46 | module = nn.Sequential(
47 | nn.Conv2d(
48 | mult * nffc,
49 | 2 * mult * nffc,
50 | kernel_size=3,
51 | stride=2,
52 | padding=1,
53 | ),
54 | nn.InstanceNorm2d(2 * mult * nffc),
55 | nn.LeakyReLU(0.2, False),
56 | )
57 | self.bottleneck.update({f"down_{i}": module})
58 |
59 | self.bottleneck.update(
60 | {
61 | f"up_{1}": nn.Sequential(
62 | nn.ConvTranspose2d(
63 | 4 * nffc,
64 | nffc * 8,
65 | kernel_size=3,
66 | stride=2,
67 | padding=1,
68 | output_padding=1,
69 | ),
70 | nn.InstanceNorm2d(4 * nffc),
71 | nn.LeakyReLU(0.2, False),
72 | )
73 | }
74 | )
75 |
76 | self.last_layer = nn.Sequential(
77 | nn.ReflectionPad2d(1),
78 | nn.Conv2d(nffc * 8, latent_dim, kernel_size=3, padding=0),
79 | nn.InstanceNorm2d(nffc),
80 | nn.Tanh(),
81 | )
82 |
83 | self.use_vae = use_vae
84 | if self.use_vae:
85 | self.vae_mapper = nn.Linear(latent_dim, 2*latent_dim)
86 | def forward(
87 | self,
88 | input: TensorType["batch_size", "num_input_channels", "height", "width"],
89 | segmentation_map: TensorType["batch_size", "num_labels", "height", "width"],
90 | ) -> TensorType["batch_size", "num_input_channels", "style_dim"]:
91 | x = self.first_layer(input)
92 | for _, block_name in enumerate(self.bottleneck):
93 | x = self.bottleneck[block_name](x)
94 | x = self.last_layer(x)
95 | segmentation_map = F.interpolate(
96 | segmentation_map, size=x.size()[2:], mode="nearest"
97 | )
98 | (batch_size, style_dim, *_) = x.shape
99 | num_labels = segmentation_map.shape[1]
100 | style_codes = torch.zeros(batch_size, num_labels, style_dim, device=x.device)
101 | for i in range(num_labels):
102 | num_components = segmentation_map[:, i].unsqueeze(1).sum((2, 3))
103 | num_components[num_components == 0] = 1
104 | style_codes[:, i] = (segmentation_map[:, i].unsqueeze(1) * x).sum(
105 | (2, 3)
106 | ) / num_components
107 | if self.use_vae:
108 | style_codes = self.vae_mapper(style_codes)
109 | style_codes = style_codes.chunk(2, dim=-1)
110 | return style_codes, None
111 |
--------------------------------------------------------------------------------
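
A standalone illustration of the regional average pooling performed at the end of the SEAN encoder forward pass: the style code of a label is the mean of the feature map over the pixels belonging to that label, with empty regions left at zero. The tensors below are random stand-ins.

```python
import torch
import torch.nn.functional as F

features = torch.randn(2, 512, 32, 32)  # stand-in for the encoder output (batch, style_dim, h, w)
segmap = F.one_hot(torch.randint(0, 19, (2, 256, 256)), 19).permute(0, 3, 1, 2).float()
segmap = F.interpolate(segmap, size=features.shape[2:], mode="nearest")

style_codes = torch.zeros(2, 19, 512)
for i in range(19):
    region = segmap[:, i].unsqueeze(1)               # (batch, 1, h, w) mask for label i
    count = region.sum((2, 3)).clamp(min=1)          # avoid dividing by zero for absent labels
    style_codes[:, i] = (region * features).sum((2, 3)) / count
print(style_codes.shape)  # torch.Size([2, 19, 512])
```
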
/models/encoders/spade.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch.nn.utils.spectral_norm as spectral_norm
4 | from models.utils_blocks.base import BaseNetwork
5 |
6 |
7 | class SPADEStyleEncoder(BaseNetwork):
8 | """
9 |     Encoder that encodes one style vector for the whole image, as done in SPADE.
10 |
11 | Parameters:
12 | -----------
13 |     use_vae: bool, default True
14 |         Kept for interface compatibility; the encoder always returns [mu, logvar]
15 | """
16 | def __init__(self, use_vae=True):
17 | super().__init__()
18 | kw = 3
19 | pw = 1
20 | ndf = 64
21 | self.layer1 = nn.Sequential(
22 | spectral_norm(nn.Conv2d(3, ndf, kw, stride=2, padding=pw)),
23 | nn.InstanceNorm2d(ndf, affine=False),
24 | )
25 | self.layer2 = nn.Sequential(
26 | spectral_norm(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw)),
27 | nn.InstanceNorm2d(ndf * 2, affine=False),
28 | )
29 | self.layer3 = nn.Sequential(
30 | spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw)),
31 | nn.InstanceNorm2d(ndf * 4, affine=False),
32 | )
33 | self.layer4 = nn.Sequential(
34 | spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw)),
35 | nn.InstanceNorm2d(ndf * 8, affine=False),
36 | )
37 | self.layer5 = nn.Sequential(
38 | spectral_norm(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)),
39 | nn.InstanceNorm2d(ndf * 8, affine=False),
40 | )
41 | self.layer6 = nn.Sequential(
42 | spectral_norm(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)),
43 | nn.InstanceNorm2d(ndf * 8, affine=False),
44 | )
45 |
46 | self.so = s0 = 4
47 | self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256)
48 | self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256)
49 |
50 | self.actvn = nn.LeakyReLU(0.2, False)
51 |
52 | def forward(self, image, segmap=None):
53 | if image.shape[2] != 256 or image.shape[3] != 256:
54 | image = F.interpolate(image, size=(256, 256), mode="bilinear")
55 | image = self.layer1(image)
56 | image = self.layer2(self.actvn(image))
57 | image = self.layer3(self.actvn(image))
58 | image = self.layer4(self.actvn(image))
59 | image = self.layer5(self.actvn(image))
60 | image = self.layer6(self.actvn(image))
61 |
62 | image = self.actvn(image)
63 |
64 | image = image.view(image.size(0), -1)
65 | mu = self.fc_mu(image)
66 | logvar = self.fc_var(image)
67 |
68 | return [mu, logvar], None
69 |
--------------------------------------------------------------------------------
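
A short sketch (not from the repo) of how a `[mu, logvar]` pair like the one returned above is typically turned into a style vector, via the usual VAE reparameterisation trick:

```python
import torch

mu, logvar = torch.randn(2, 256), torch.randn(2, 256)  # stand-ins for the encoder output
std = torch.exp(0.5 * logvar)
z = mu + std * torch.randn_like(std)                    # sampled style vector, shape (2, 256)
```
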
/models/generators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/models/generators/__init__.py
--------------------------------------------------------------------------------
/models/utils_blocks/EMA.py:
--------------------------------------------------------------------------------
1 | class EMA:
2 | def __init__(self, beta):
3 | super().__init__()
4 | self.beta = beta
5 |
6 | def update_average(self, old, new):
7 | if old is None:
8 | return new
9 | return old * self.beta + (1 - self.beta) * new
10 |
--------------------------------------------------------------------------------
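
A hypothetical usage sketch for the EMA helper above: keep an exponential moving average copy of a model by blending its parameters into a shadow copy after every optimisation step. The model below is a placeholder.

```python
import copy
import torch
import torch.nn as nn
from models.utils_blocks.EMA import EMA

model = nn.Linear(8, 8)
ema_model = copy.deepcopy(model)  # shadow copy updated with the moving average
ema = EMA(beta=0.999)

# after each optimiser step:
with torch.no_grad():
    for p, ema_p in zip(model.parameters(), ema_model.parameters()):
        ema_p.copy_(ema.update_average(ema_p, p))
```
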
/models/utils_blocks/SAM.py:
--------------------------------------------------------------------------------
1 | ### Taken from https://github.com/davda54/sam/blob/main/sam.py
2 | import torch
3 |
4 |
5 | class SAM(torch.optim.Optimizer):
6 | def __init__(self, params, rho=0.05, adaptive=False, **kwargs):
7 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
8 |
9 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
10 | super(SAM, self).__init__(params, defaults)
11 |
12 | self.base_optimizer = torch.optim.AdamW(self.param_groups, **kwargs)
13 | self.param_groups = self.base_optimizer.param_groups
14 |
15 | @torch.no_grad()
16 | def first_step(self, zero_grad=False):
17 | grad_norm = self._grad_norm()
18 | for group in self.param_groups:
19 | scale = group["rho"] / (grad_norm + 1e-12)
20 |
21 | for p in group["params"]:
22 | if p.grad is None:
23 | continue
24 | self.state[p]["old_p"] = p.data.clone()
25 | e_w = (
26 | (torch.pow(p, 2) if group["adaptive"] else 1.0)
27 | * p.grad
28 | * scale.to(p)
29 | )
30 | p.add_(e_w) # climb to the local maximum "w + e(w)"
31 |
32 | if zero_grad:
33 | self.zero_grad()
34 |
35 | @torch.no_grad()
36 | def second_step(self, zero_grad=False):
37 | for group in self.param_groups:
38 | for p in group["params"]:
39 | if p.grad is None:
40 | continue
41 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)"
42 |
43 | self.base_optimizer.step() # do the actual "sharpness-aware" update
44 |
45 | if zero_grad:
46 | self.zero_grad()
47 |
48 | @torch.no_grad()
49 | def step(self, closure=None):
50 | assert (
51 | closure is not None
52 | ), "Sharpness Aware Minimization requires closure, but it was not provided"
53 | closure = torch.enable_grad()(
54 | closure
55 | ) # the closure should do a full forward-backward pass
56 |
57 | self.first_step(zero_grad=True)
58 | closure()
59 | self.second_step()
60 |
61 | def _grad_norm(self):
62 | shared_device = self.param_groups[0]["params"][
63 | 0
64 | ].device # put everything on the same device, in case of model parallelism
65 | norm = torch.norm(
66 | torch.stack(
67 | [
68 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad)
69 | .norm(p=2)
70 | .to(shared_device)
71 | for group in self.param_groups
72 | for p in group["params"]
73 | if p.grad is not None
74 | ]
75 | ),
76 | p=2,
77 | )
78 | return norm
79 |
80 | def load_state_dict(self, state_dict):
81 | super().load_state_dict(state_dict)
82 | self.base_optimizer.param_groups = self.param_groups
83 |
--------------------------------------------------------------------------------
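
A usage sketch for the SAM optimizer above, with a placeholder model and data: gradients must already exist when `step` is called (they drive the ascent to "w + e(w)"), and the closure must redo a full forward/backward pass at the perturbed weights before the base AdamW update.

```python
import torch
import torch.nn as nn
from models.utils_blocks.SAM import SAM

model = nn.Linear(8, 1)
optimizer = SAM(model.parameters(), rho=0.05, lr=1e-3)
x, y = torch.randn(16, 8), torch.randn(16, 1)

def closure():
    # recompute the loss and gradients at the perturbed weights
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = nn.functional.mse_loss(model(x), y)
loss.backward()           # first pass: gradients used for the ascent step
optimizer.step(closure)   # perturb, re-run closure, then apply the AdamW update at "w"
optimizer.zero_grad()
```
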
/models/utils_blocks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/models/utils_blocks/__init__.py
--------------------------------------------------------------------------------
/models/utils_blocks/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | from torchtyping import TensorType
7 | from einops import repeat, rearrange
8 | from models.utils_blocks.equallr import EqualLinear
9 | from functools import partial
10 |
11 |
12 | class MaskedAttention(nn.Module):
13 | """
 14 |     Masked Attention Module. Can be used for both self-attention and cross-attention.
 15 |
 16 |     to_dim: int,
 17 |         Dimension of the query tokens
 18 |     from_dim: int,
 19 |         Dimension of the key/value tokens
 20 |     num_heads: int,
 21 |         Number of attention heads (latent_dim is split evenly across the heads)
22 | """
23 |
24 | def __init__(
25 | self,
26 | to_dim: int,
27 | from_dim: int,
28 | latent_dim: int,
29 | num_heads: int,
30 | use_equalized_lr: bool = False,
31 | lr_mul: float = 1.0,
32 | ):
33 | super().__init__()
34 | self.latent_dim_k = latent_dim // num_heads
35 | self.num_heads = num_heads
36 | LinearLayer = (
37 | partial(EqualLinear, lr_mul=lr_mul) if use_equalized_lr else nn.Linear
38 | )
39 |
40 | # Mappings Query, Key Values
41 | self.q_linear = LinearLayer(to_dim, latent_dim)
42 | self.v_linear = LinearLayer(from_dim, latent_dim)
43 | self.k_linear = LinearLayer(from_dim, latent_dim)
44 |
45 | # Final output mapping
46 | self.out = nn.Sequential(LinearLayer(latent_dim, to_dim))
47 |
48 | def forward(
49 | self,
50 | X_to: TensorType["batch_size", "num_to_tokens", "to_dim"],
51 | X_from: TensorType["batch_size", "num_from_tokens", "from_dim"],
52 | mask_from: TensorType["batch_size", "num_to_tokens", "num_from_tokens"] = None,
53 | return_attention: bool = False,
54 | ):
55 | Q = rearrange(self.q_linear(X_to), " b t (h k) -> b h t k ", h=self.num_heads)
 56 |         K = rearrange(self.v_linear(X_from), " b t (h k) -> b h t k ", h=self.num_heads)  # note: K is computed with v_linear
 57 |         V = rearrange(self.k_linear(X_from), " b t (h k) -> b h t k ", h=self.num_heads)  # and V with k_linear; both map from_dim -> latent_dim, so only the naming is swapped
58 |
59 | attn = torch.einsum("bhtk,bhfk->bhtf", [Q, K]) / math.sqrt(self.latent_dim_k)
60 |
61 | if mask_from is not None:
62 | mask_from = mask_from.unsqueeze(1)
63 | attn = attn.masked_fill(mask_from == 0, -1e4)
64 |
65 | attn = F.softmax(attn, dim=-1)
66 |
67 | output = torch.einsum("bhtf,bhfk->bhtk", [attn, V])
68 | output = rearrange(output, "b h t k -> b t (h k)")
69 | output = self.out(output)
70 |
71 | if return_attention:
72 | return output, attn
73 | else:
74 | return output
75 |
76 |
77 | class MaskedTransformer(nn.Module):
78 | def __init__(
79 | self,
80 | to_dim,
81 | to_len,
82 | from_dim,
83 | latent_dim,
84 | num_heads,
85 | use_equalized_lr=False,
86 | lr_mul=1,
87 | ):
88 | super().__init__()
89 | self.attention = MaskedAttention(
90 | to_dim,
91 | from_dim,
92 | latent_dim,
93 | num_heads,
94 | use_equalized_lr=use_equalized_lr,
95 | lr_mul=lr_mul,
96 | )
97 |
98 | LinearLayer = (
99 | partial(EqualLinear, lr_mul=lr_mul) if use_equalized_lr else nn.Linear
100 | )
101 | self.ln_1 = nn.LayerNorm((to_len, to_dim))
102 | self.fc = nn.Sequential(
103 | LinearLayer(to_dim, to_dim),
104 | nn.LeakyReLU(2e-1),
105 | LinearLayer(to_dim, to_dim),
106 | nn.LeakyReLU(2e-1),
107 | )
108 | self.ln_2 = nn.LayerNorm((to_len, to_dim))
109 |
110 | def forward(self, X_to, X_from, mask=None, return_attention=False):
111 | if return_attention:
112 | X_to_out, attn = self.attention(X_to, X_from, mask, return_attention)
113 | else:
114 | X_to_out = self.attention(X_to, X_from, mask, return_attention)
115 | X_to = self.ln_1(X_to_out + X_to)
116 | X_to_out = self.fc(X_to)
117 | X_to = self.ln_2(X_to_out + X_to)
118 | if return_attention:
119 | return X_to, attn
120 | else:
121 | return X_to
122 |
123 |
124 | class SinusoidalPositionalEmbedding(nn.Module):
125 | """
126 | Sinusoidal Positional Encoder
127 | Adapted from https://github.com/lucidrains/transganformer
128 |
129 | dim: int,
130 | Tokens dimension
131 | """
132 |
133 | def __init__(self, dim: int, emb_type: str = "add"):
134 | super().__init__()
135 | dim //= 2
136 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
137 | self.emb_type = emb_type
138 | self.register_buffer("inv_freq", inv_freq)
139 |
 140 |     def forward(self, x: TensorType["batch_size", "dim", "height", "width"]):
141 | h = torch.linspace(-1.0, 1.0, x.shape[-2], device=x.device).type_as(
142 | self.inv_freq
143 | )
144 | w = torch.linspace(-1.0, 1.0, x.shape[-1], device=x.device).type_as(
145 | self.inv_freq
146 | )
147 | sinu_inp_h = torch.einsum("i , j -> i j", h, self.inv_freq)
148 | sinu_inp_w = torch.einsum("i , j -> i j", w, self.inv_freq)
149 | sinu_inp_h = repeat(sinu_inp_h, "h c -> () c h w", w=x.shape[-1])
150 | sinu_inp_w = repeat(sinu_inp_w, "w c -> () c h w", h=x.shape[-2])
151 | sinu_inp = torch.cat((sinu_inp_w, sinu_inp_h), dim=1)
152 | emb = torch.cat((sinu_inp.sin(), sinu_inp.cos()), dim=1)
153 | if self.emb_type == "add":
154 | x_emb = x + emb
155 | elif self.emb_type == "concat":
156 | emb = repeat(emb, "1 ... -> b ...", b=x.shape[0])
157 | x_emb = torch.cat([x, emb], dim=1)
158 | return x_emb
159 |
160 |
161 | class LearnedPositionalEmbedding(nn.Module):
162 | """
163 | Learned Positional Embedding
164 |
165 | Parameters:
166 | -----------
167 | num_tokens_max: int,
 168 |         Maximum size of the sequence length
169 | dim_tokens: int,
170 | Size of the embedding dim
171 | """
172 |
173 | def __init__(self, num_tokens_max: int, dim_tokens: int):
174 | super().__init__()
175 | self.num_tokens_max = num_tokens_max
176 | self.dim_tokens = dim_tokens
 177 |         self.weights = nn.Parameter(torch.zeros(num_tokens_max, dim_tokens))  # zero-init; torch.Tensor() would leave the values uninitialized
178 |
179 | def forward(self, x: TensorType["batch_size", "num_tokens", "dim_tokens"]):
180 | _, num_tokens = x.shape[:2]
181 | assert num_tokens <= self.num_tokens_max
182 | return x + self.weights[:num_tokens].view(1, num_tokens, self.dim_tokens)
183 |
--------------------------------------------------------------------------------
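A quick shape check for `MaskedAttention` with made-up dimensions (2 images, 16 latent tokens attending to 64 image tokens); entries of `mask_from` equal to 0 are filled with -1e4 before the softmax, so the corresponding `from` tokens receive almost no attention:

```python
import torch
from models.utils_blocks.attention import MaskedAttention

attn = MaskedAttention(to_dim=128, from_dim=64, latent_dim=256, num_heads=4)
X_to = torch.randn(2, 16, 128)            # queries: 16 latent tokens of dim 128
X_from = torch.randn(2, 64, 64)           # keys/values: 64 image tokens of dim 64
mask = torch.randint(0, 2, (2, 16, 64))   # 0 = forbidden, 1 = visible

out, weights = attn(X_to, X_from, mask_from=mask, return_attention=True)
print(out.shape)      # torch.Size([2, 16, 128])
print(weights.shape)  # torch.Size([2, 4, 16, 64]), one attention map per head
```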
/models/utils_blocks/base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class BaseNetwork(nn.Module):
5 | def __init__(self):
6 | super(BaseNetwork, self).__init__()
7 |
8 | def init_weight(self, do_init: bool = True):
9 | def init_func(m):
10 | classname = m.__class__.__name__
11 | if classname.find("BatchNorm2d") != -1:
12 | if hasattr(m, "weight") and m.weight is not None:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | if hasattr(m, "bias") and m.bias is not None:
15 | nn.init.constant_(m.bias.data, 0.0)
16 | elif hasattr(m, "weight") and (
17 | classname.find("Conv") != -1 or classname.find("Linear") != -1
18 | ):
19 | nn.init.xavier_normal_(m.weight.data, gain=0.02)
20 |
21 | if do_init:
22 | self.apply(init_func)
23 |
24 | for m in self.children():
25 | if hasattr(m, "init_weights"):
26 | m.init_weights(do_init)
27 |
--------------------------------------------------------------------------------
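`init_weight` applies Xavier initialization (gain 0.02) to convolutional and linear layers and N(1, 0.02) / zeros to BatchNorm parameters; a toy subclass, purely illustrative:

```python
import torch.nn as nn
from models.utils_blocks.base import BaseNetwork

class TinyHead(BaseNetwork):  # hypothetical module, not part of the repository
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)

net = TinyHead()
net.init_weight()  # Xavier(gain=0.02) on the conv, N(1, 0.02)/zero on the BatchNorm
```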
/models/utils_blocks/equallr.py:
--------------------------------------------------------------------------------
1 | ## Modified from https://github.com/rosinality/stylegan2-pytorch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from math import sqrt
5 | import torch
6 |
7 |
8 | class EqualConv2d(nn.Module):
9 | def __init__(
10 | self,
11 | in_channels,
12 | out_channels,
13 | kernel_size,
14 | stride=1,
15 | padding=0,
16 | bias=True,
17 | lr_mul=1,
18 | ):
19 | super().__init__()
20 |
21 | self.weight = nn.Parameter(
22 | torch.randn(out_channels, in_channels, kernel_size, kernel_size).div_(
23 | lr_mul
24 | ),
25 | requires_grad=True,
26 | )
27 | # torch.nn.init.kaiming_uniform_(self.weight, a=sqrt(5))
28 | self.scale = (1 / sqrt(in_channels * kernel_size**2)) * lr_mul
29 | self.stride = stride
30 | self.padding = padding
31 | self.lr_mul = lr_mul
32 |
33 | if bias:
34 | self.bias = nn.Parameter(torch.zeros(out_channels), requires_grad=True)
35 |
36 | else:
37 | self.bias = None
38 |
39 | def forward(self, input):
40 | out = F.conv2d(
41 | input,
42 | self.weight * self.scale,
 43 |             bias=self.bias * self.lr_mul if self.bias is not None else None,
44 | stride=self.stride,
45 | padding=self.padding,
46 | )
47 |
48 | return out
49 |
50 | def __repr__(self):
51 | return (
52 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
53 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
54 | )
55 |
56 |
57 | class EqualLinear(nn.Module):
58 | def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1):
59 | super().__init__()
60 |
61 | self.weight = nn.Parameter(
62 | torch.randn(out_dim, in_dim).div_(lr_mul), requires_grad=True
63 | )
64 | # torch.nn.init.kaiming_uniform_(self.weight, a=sqrt(5))
65 |
66 | if bias:
67 | self.bias = nn.Parameter(
68 | torch.zeros(out_dim).fill_(bias_init), requires_grad=True
69 | )
70 |
71 | else:
72 | self.bias = None
73 | self.scale = (1 / sqrt(in_dim)) * lr_mul
74 | self.lr_mul = lr_mul
75 |
76 | def forward(self, input):
 77 |         out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul if self.bias is not None else None)
78 | return out
79 |
80 | def __repr__(self):
81 | return (
82 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
83 | )
84 |
--------------------------------------------------------------------------------
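The equalized-learning-rate trick stores weights divided by `lr_mul` and rescales them at forward time by `scale = lr_mul / sqrt(fan_in)`, so the layer's effective step size is multiplied by `lr_mul`; a minimal check with arbitrary sizes:

```python
import torch
from models.utils_blocks.equallr import EqualLinear

layer = EqualLinear(in_dim=512, out_dim=256, lr_mul=0.01)
x = torch.randn(4, 512)
y = layer(x)     # forward uses weight * scale, with scale = lr_mul / sqrt(in_dim)
print(y.shape)   # torch.Size([4, 256])
print(layer)     # EqualLinear(512, 256)
```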
/preprocess_fid_features.py:
--------------------------------------------------------------------------------
1 | from metrics.fid import compute_fid_features
2 | from data.datamodule import ImageDataModule
3 | import hydra
4 | from pathlib import Path
5 |
6 |
7 | @hydra.main(config_path="configs", config_name="config")
8 | def compute_fid_for_dataset(cfg):
9 | datamodule = ImageDataModule(cfg.dataset)
10 | datamodule.setup()
11 | train_path = Path(cfg.dataset.path) / Path("train/stats")
12 | train_path.mkdir(parents=True, exist_ok=True)
13 | print(f"Computing FID features for train set and saving to {train_path}")
14 | compute_fid_features(datamodule.train_dataloader(), train_path, device=cfg.compnode.accelerator)
15 |
16 | test_path = Path(cfg.dataset.path) / Path("test/stats")
17 | test_path.mkdir(parents=True, exist_ok=True)
18 | print(f"Computing FID features for test set and saving to {test_path}")
19 | compute_fid_features(datamodule.test_dataloader(), test_path, device=cfg.compnode.accelerator)
20 |
21 | if __name__ == "__main__":
22 | compute_fid_for_dataset()
--------------------------------------------------------------------------------
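The script is Hydra-driven, so the dataset and accelerator come from config-group overrides; a plausible invocation given the groups under `configs/` (the exact override names depend on your local `configs/config.yaml`):

```bash
python preprocess_fid_features.py dataset=CelebAHQ compnode=single-gpu
```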
/train.py:
--------------------------------------------------------------------------------
 1 | # Training entry point: composes the Hydra config and runs a PyTorch Lightning fit.
2 | from data.datamodule import ImageDataModule
3 | from gan import Geppetto
4 | import hydra
5 | import shutil
6 | import wandb
7 | import os
8 | from utils.callbacks import GANImageLog, StopIfNaN, LogAttention
9 | from pytorch_lightning.utilities.rank_zero import _get_rank
10 |
11 | from pathlib import Path
12 |
13 | from omegaconf import OmegaConf
14 |
15 |
16 | @hydra.main(config_path="configs", config_name="config")
17 | def run(cfg):
18 |
19 | # print(OmegaConf.to_yaml(cfg, resolve=True))
20 |
21 | dict_config = OmegaConf.to_container(cfg, resolve=True)
22 |
23 | Path(cfg.checkpoints.dirpath).mkdir(parents=True, exist_ok=True)
24 |
25 | shutil.copyfile(".hydra/config.yaml", f"{cfg.checkpoints.dirpath}/config.yaml")
26 |
27 | log_dict = {}
28 |
29 | log_dict["model"] = dict_config["model"]
30 |
31 | log_dict["dataset"] = dict_config["dataset"]
32 |
33 | datamodule = ImageDataModule(cfg.dataset)
34 |
35 | # logger.log_hyperparams(dict_config)
36 |
37 | checkpoint_callback = hydra.utils.instantiate(cfg.checkpoints)
38 |
39 | image_log_callback = GANImageLog()
40 |
41 | stop_if_nan_callback = StopIfNaN(
42 | ["train/gen_loss_step", "train/disc_loss_step"]
43 | )
44 |
45 | progress_bar = hydra.utils.instantiate(cfg.progress_bar)
46 |
47 | callbacks = [
48 | checkpoint_callback,
49 | image_log_callback,
50 | stop_if_nan_callback,
51 | progress_bar,
52 | ]
53 | if cfg.model.name == "SCAM":
54 | callbacks.append(LogAttention())
55 |
56 | rank = _get_rank()
57 |
58 | if os.path.isfile(Path(cfg.checkpoints.dirpath) / Path("wandb_id.txt")):
59 | with open(
60 | Path(cfg.checkpoints.dirpath) / Path("wandb_id.txt"), "r"
61 | ) as wandb_id_file:
62 | wandb_id = wandb_id_file.readline()
63 | else:
64 | wandb_id = wandb.util.generate_id()
 65 |         print(f"Generated wandb id: {wandb_id}")
66 | if rank == 0:
67 | with open(
68 | Path(cfg.checkpoints.dirpath) / Path("wandb_id.txt"), "w"
69 | ) as wandb_id_file:
70 | wandb_id_file.write(str(wandb_id))
71 |
72 | if (Path(cfg.checkpoints.dirpath) / Path("last.ckpt")).exists():
73 |
74 | print("Loading checkpoints")
75 | checkpoint_path = Path(cfg.checkpoints.dirpath) / Path("last.ckpt")
76 |
77 | logger = hydra.utils.instantiate(cfg.logger, id=wandb_id, resume="allow")
78 | model = Geppetto.load_from_checkpoint(checkpoint_path, cfg=cfg.model)
79 | logger.watch(model)
80 | trainer = hydra.utils.instantiate(
81 | cfg.trainer,
82 | strategy=cfg.trainer.strategy,
83 | logger=logger,
84 | callbacks=callbacks,
85 | resume_from_checkpoint=str(checkpoint_path),
86 | )
87 | else:
88 | logger = hydra.utils.instantiate(cfg.logger, id=wandb_id, resume="allow")
89 | logger._wandb_init.update({"config": log_dict})
90 | model = Geppetto(cfg.model)
91 | logger.watch(model)
92 | trainer = hydra.utils.instantiate(
93 | cfg.trainer,
94 | strategy=cfg.trainer.strategy,
95 | logger=logger,
96 | callbacks=callbacks,
97 | )
98 | # trainer.fit_loop.epoch_loop.batch_loop.connect(optimizer_loop=YieldLoop())
99 |
100 | trainer.fit(model, datamodule)
101 |
102 | trainer.test(model, dataloaders=datamodule, ckpt_path="best")
103 |
104 |
105 | if __name__ == "__main__":
106 | run()
107 |
--------------------------------------------------------------------------------
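Training is launched the same way, composing the model, dataset and compute-node groups shown in `configs/`; a plausible (unverified) launch command:

```bash
python train.py model=scam dataset=idesigner compnode=single-gpu
```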
/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/utils/.DS_Store
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/utils/__init__.py
--------------------------------------------------------------------------------
/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pytorch_lightning import Callback
3 | from torch.utils.data import DataLoader
4 | import torch
5 | import torch.nn.functional as F
6 | import wandb
7 | import numpy as np
8 | from einops import rearrange
9 | from utils.utils import AttentionVis, get_palette, remap_image_torch
10 |
11 |
12 | log = logging.getLogger(__name__)
13 |
14 |
15 | class ReInitOptimAfterSanity(Callback):
16 | def on_train_start(self, trainer, pl_module):
17 | optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(
18 | pl_module
19 | )
20 | trainer.optimizers = optimizers
21 | trainer.lr_schedulers = lr_schedulers
22 | trainer.optimizer_frequencies = optimizer_frequencies
23 |
24 |
25 | class LogAttention(Callback):
26 | def __init__(self, num_samples: int = 8):
27 | super().__init__()
28 | self.num_samples = num_samples
29 | self.ready = True
30 |
31 | def on_sanity_check_start(self, trainer, pl_module):
32 | self.ready = False
33 |
34 | def on_sanity_check_end(self, trainer, pl_module):
35 | """Start executing this callback only after all validation sanity checks end."""
36 | self.ready = True
37 |
38 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
39 | if trainer.global_step % 25000 == 0:
40 | self.log_attention(trainer, pl_module)
41 |
42 | def on_test_end(self, trainer, pl_module):
43 | self.log_attention(trainer, pl_module)
44 |
45 | def log_attention(self, trainer, pl_module):
46 | if self.ready:
47 | logger = trainer.logger
48 | experiment = logger.experiment
49 | log_encoder = (
50 | pl_module.cfg.encoder._target_
51 | == "models.encoder.SemanticAttentionTransformerEncoder"
52 | )
53 |
54 | attnn_viz = AttentionVis(log_encoder=log_encoder)
55 |
56 | # Create tokens palette
57 | np.random.seed(42)
58 | num_tokens = (
59 | (pl_module.cfg.generator.num_labels - 1)
60 | * pl_module.cfg.generator.num_labels_split
61 | + pl_module.cfg.generator.num_labels_bg
62 | )
63 | palette_tokens = torch.from_numpy(get_palette(num_tokens)).to(
64 | device=pl_module.device
65 | )
66 | # Create labels palette
67 | np.random.seed(10)
68 | palette_segmask = torch.from_numpy(
69 | get_palette(pl_module.cfg.generator.num_labels)
70 | ).to(device=pl_module.device)
71 |
72 | val_dataloader = DataLoader(
73 | trainer.datamodule.test_dataset,
74 | batch_size=self.num_samples,
75 | shuffle=True,
76 | )
77 | logs = dict()
78 |
79 | real_images, segmentation_maps = next(iter(val_dataloader))
80 | real_images = real_images.to(device=pl_module.device)
81 | segmentation_maps = segmentation_maps.to(device=pl_module.device)
82 |
83 | segmask_colorized = palette_segmask[segmentation_maps.argmax(1)]
84 |
85 | segmask_colorized = rearrange(segmask_colorized, "b h w c -> b c h w")
86 | with torch.no_grad():
87 | outputs = attnn_viz.encode_and_generate(
88 | pl_module, real_images, segmentation_maps
89 | )
90 | output_images = outputs["output"]
91 | del outputs["output"]
92 | for attn_key, attn_val in outputs.items():
93 |
94 | attentions = [
95 | F.interpolate(
96 | attention,
97 | size=(output_images.shape[2], output_images.shape[3]),
98 | )
99 | for attention in attn_val
100 | ]
101 |
102 | attentions_colorized = [
103 | palette_tokens[attention.argmax(1)] for attention in attentions
104 | ]
105 |
106 | attentions_colorized = [
107 | rearrange(attention, "b h w c -> b c h w")
108 | for attention in attentions_colorized
109 | ]
110 |
111 | attn_viz = rearrange(
112 | [
113 | remap_image_torch(real_images),
114 | remap_image_torch(output_images),
115 | segmask_colorized,
116 | *attentions_colorized,
117 | ],
118 | "l b c h w -> b c h (l w)",
119 | ).float()
120 | attn_viz = [
121 | wandb.Image(
122 | attn_viz[i],
123 | caption="Reference; Reconstruction; Segmask; Attn matrix argmax from first layers to last",
124 | )
125 | for i in range(attn_viz.shape[0])
126 | ]
127 | logs.update({f"Images/{attn_key}": attn_viz})
128 | experiment.log(logs, step=trainer.global_step)
129 |
130 |
131 | class GANImageLog(Callback):
132 | def __init__(self, num_samples: int = 8):
133 | super().__init__()
134 | self.num_samples = num_samples
135 | self.ready = True
136 |
137 | def on_sanity_check_start(self, trainer, pl_module):
138 | self.ready = False
139 |
140 | def on_sanity_check_end(self, trainer, pl_module):
141 | """Start executing this callback only after all validation sanity checks end."""
142 | self.ready = True
143 |
144 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
145 | if trainer.global_step % 25000 == 0:
146 | self.log_images(trainer, pl_module)
147 |
148 | def on_test_end(self, trainer, pl_module):
149 | self.log_images(trainer, pl_module)
150 |
151 | def log_images(self, trainer, pl_module):
152 | if self.ready:
153 | logger = trainer.logger
154 | experiment = logger.experiment
155 |
 156 |             # get a validation batch from the validation data loader
157 | val_dataloader = DataLoader(
158 | trainer.datamodule.test_dataset,
159 | batch_size=self.num_samples,
160 | shuffle=True,
161 | )
162 |
163 | real_images, segmentation_maps = next(iter(val_dataloader))
164 | real_images = real_images.to(device=pl_module.device)
165 | segmentation_maps = segmentation_maps.to(device=pl_module.device)
166 |
167 | with torch.no_grad():
168 |
169 | styles_codes, content_code = pl_module.encoder(
170 | real_images, segmentation_maps
171 | )
172 | reco_images = pl_module.generator(
173 | segmentation_maps, styles_codes, content_code
174 | )
175 | reco_images = rearrange(
176 | [real_images, reco_images], "l b c h w -> b c h (l w)"
177 | )
178 | segmask_reco = rearrange(
179 | [segmentation_maps, segmentation_maps], "l b c h w -> b c h (l w)"
180 | )
181 | segmask_reco = segmask_reco.argmax(dim=1).cpu().numpy()
182 | permutation = torch.randperm(self.num_samples)
183 | if pl_module.cfg.losses.lambda_kld > 0:
184 | permuted_style_codes = [
185 | styles_codes[i][permutation] for i in range(len(styles_codes))
186 | ]
187 | else:
188 | permuted_style_codes = styles_codes[permutation]
189 |
190 | swap_images = pl_module.generator(
191 | segmentation_maps, permuted_style_codes, content_code
192 | )
193 | segmask_swap = rearrange(
194 | [
195 | segmentation_maps,
196 | segmentation_maps[permutation],
197 | segmentation_maps,
198 | ],
199 | "l b c h w -> b c h (l w)",
200 | )
201 | segmask_swap = segmask_swap.argmax(dim=1).cpu().numpy()
202 | swap_images = rearrange(
203 | [real_images, real_images[permutation], swap_images],
204 | "l b c h w -> b c h (l w)",
205 | )
206 | reco_images = [
207 | wandb.Image(
208 | reco_images[i],
209 | caption="Left: Real; Right: Reconstruction",
210 | masks={
211 | "Segmentation": {
212 | "mask_data": segmask_reco[i],
213 | }
214 | },
215 | )
216 | for i in range(reco_images.shape[0])
217 | ]
218 | experiment.log(
219 | {"Images/Reconstruction": reco_images}, step=trainer.global_step
220 | )
221 |
222 | swap_images = [
223 | wandb.Image(
224 | swap_images[i],
225 | caption="Left: Semantic ref; Middle: style ref; Right: Swap",
226 | masks={
227 | "Segmentation": {
228 | "mask_data": segmask_swap[i],
229 | }
230 | },
231 | )
232 | for i in range(swap_images.shape[0])
233 | ]
234 | experiment.log({"Images/Swap": swap_images}, step=trainer.global_step)
235 |
236 |
237 | class StopIfNaN(Callback):
238 | def __init__(self, monitor):
239 | super().__init__()
240 | self.monitor = monitor
241 | self.continuous_nan_batchs = 0
242 |
243 | def on_train_batch_end(
244 | self,
245 | trainer,
246 | pl_module,
247 | outputs,
248 | batch,
249 | batch_idx,
250 | ) -> None:
251 | logs = trainer.callback_metrics
252 | i = 0
253 | found_metric = False
254 | while i < len(self.monitor) and not found_metric:
255 | if self.monitor[i] in logs.keys():
256 | current = logs[self.monitor[i]].squeeze()
257 | found_metric = True
258 | else:
259 | i += 1
260 | if not found_metric:
261 | raise ValueError("Asked metric not in logs")
262 |
263 | if not torch.isfinite(current):
264 | self.continuous_nan_batchs += 1
265 | if self.continuous_nan_batchs >= 5:
266 | trainer.should_stop = True
 267 |                 log.info(f"Training interrupted because of NaN in {self.monitor}")
268 | else:
269 | self.continuous_nan_batchs = 0
270 |
271 | def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx) -> None:
272 | valid_gradients = True
273 | for name, param in pl_module.named_parameters():
274 | if param.grad is not None:
275 | valid_gradients = not (
276 | torch.isnan(param.grad).any() or torch.isinf(param.grad).any()
277 | )
278 | if not valid_gradients:
279 | break
280 |
281 | if not valid_gradients:
282 | log.warning(
 283 |                 "Detected inf or NaN values in gradients; not updating model parameters"
284 | )
285 | optimizer.zero_grad()
286 |
--------------------------------------------------------------------------------
/utils/partial_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
 3 | from torch import nn
 4 | import math
 5 | from torch.nn import init
 6 |
 7 |
 8 | # Convolution that ignores neighboring pixels belonging to a different instance than the center pixel.
9 | class InstanceAwareConv2d(nn.Module):
10 | def __init__(self, fin, fout, kw, stride=1, padding=1):
11 | super().__init__()
12 | self.kw = kw
13 | self.stride = stride
14 | self.padding = padding
15 | self.fin = fin
16 | self.fout = fout
17 | self.unfold = nn.Unfold(kw, stride=stride, padding=padding)
18 | self.weight = nn.Parameter(torch.Tensor(fout, fin, kw, kw))
19 | self.bias = nn.Parameter(torch.Tensor(fout))
20 | self.reset_parameters()
21 |
22 | def reset_parameters(self):
23 | init.kaiming_uniform_(self.weight, a=math.sqrt(5))
24 | if self.bias is not None:
25 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
26 | bound = 1 / math.sqrt(fan_in)
27 | init.uniform_(self.bias, -bound, bound)
28 |
29 | def forward(self, x, instances, check=False):
30 | N, C, H, W = x.size()
 31 |         # compute the binary mask from the instance map
32 | instances = F.interpolate(instances, x.size()[2:], mode="nearest") # [n,1,h,w]
33 | inst_unf = self.unfold(instances)
 34 |         # subtract the center pixel
35 | center = torch.unsqueeze(inst_unf[:, self.kw * self.kw // 2, :], 1)
36 | mask_unf = inst_unf - center
37 | # clip the absolute value to 0~1
38 | mask_unf = torch.abs(mask_unf)
39 | mask_unf = torch.clamp(mask_unf, 0, 1)
40 | mask_unf = 1.0 - mask_unf # [n,k*k,L]
41 | # multiply mask_unf and x
42 | x_unf = self.unfold(x) # [n,c*k*k,L]
 43 |         x_unf = x_unf.view(N, C, -1, x_unf.size()[-1])  # [n,c,k*k,L]
44 | mask = torch.unsqueeze(mask_unf, 1) # [n,1,k*k,L]
45 | mask_x = mask * x_unf # [n,c,k*k,L]
46 | mask_x = mask_x.view(N, -1, mask_x.size()[-1]) # [n,c*k*k,L]
47 | # conv operation
48 | weight = self.weight.view(self.fout, -1) # [fout, c*k*k]
49 | out = torch.einsum("cm,nml->ncl", weight, mask_x)
50 | # x_unf = torch.unsqueeze(x_unf, 1) # [n,1,c*k*k,L]
51 | # out = torch.mul(masked_weight, x_unf).sum(dim=2, keepdim=False) # [n,fout,L]
52 | bias = torch.unsqueeze(torch.unsqueeze(self.bias, 0), -1) # [1,fout,1]
53 | out = out + bias
54 | out = out.view(N, self.fout, H // self.stride, W // self.stride)
55 | # print('weight:',self.weight[0,0,...])
56 | # print('bias:',self.bias)
57 |
58 | if check:
59 | out2 = nn.functional.conv2d(
60 | x, self.weight, self.bias, stride=self.stride, padding=self.padding
61 | )
62 | print((out - out2).abs().max())
63 | return out
64 |
--------------------------------------------------------------------------------
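`InstanceAwareConv2d` unfolds the input, zeroes out neighbors whose instance id differs from the center pixel, then applies the convolution weights to the masked patches; a shape sketch with random data (`check=True` prints how far the result is from a plain `conv2d`, which is non-zero wherever several instances meet):

```python
import torch
from utils.partial_conv import InstanceAwareConv2d

conv = InstanceAwareConv2d(fin=3, fout=8, kw=3)           # stride=1, padding=1 keeps H x W
x = torch.randn(2, 3, 64, 64)
instances = torch.randint(0, 5, (2, 1, 64, 64)).float()   # one instance id per pixel
out = conv(x, instances, check=True)                      # prints |out - conv2d(x)|.max()
print(out.shape)                                          # torch.Size([2, 8, 64, 64])
```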
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from einops import rearrange
4 |
5 |
6 | def remap_image_numpy(image):
7 |
8 | image_numpy = ((image + 1) / 2.0) * 255.0
9 | return np.clip(image_numpy, 0, 255).astype(int)
10 |
11 |
12 | def remap_image_torch(image):
13 |
14 | image_torch = ((image + 1) / 2.0) * 255.0
15 | return torch.clip(image_torch, 0, 255).type(torch.uint8)
16 |
17 |
18 | def get_palette(num_cls):
19 | """Returns the color map for visualizing the segmentation mask.
20 | Args:
21 | num_cls: Number of classes
22 | Returns:
23 | The color map
24 | """
25 | palette = np.random.randint(0, 255, size=(num_cls, 3))
26 | return palette
27 |
28 |
29 | class AttentionVis:
30 | def __init__(self, log_encoder=True):
31 | self.latent_to_image_gen_masks = []
32 | self.image_to_latent_gen_masks = []
33 | self.log_encoder = log_encoder
34 | if self.log_encoder:
35 | self.encoder_masks = []
36 |
37 | def encoder_hook_fn(self, m, i, o):
38 | self.encoder_masks.append(o[1])
39 |
40 | def latent_to_image_gen_hook_fn(self, m, i, o):
41 | self.latent_to_image_gen_masks.append(o[1])
42 |
43 | def image_to_latent_gen_hook_fn(self, m, i, o):
44 | self.image_to_latent_gen_masks.append(o[1])
45 |
46 | def encode_and_generate(self, model, images, masks):
47 | ### activate attention
48 | num_gen_blocks = model.cfg.generator.num_up_layers
49 | gen_mod_blocks = [
50 | getattr(model.generator.backbone[f"SCAM_block_{i}"], f"mod_{j}")
51 | for i in range(num_gen_blocks)
52 | for j in range(2)
53 | ]
54 |
55 | for gen_mod_block in gen_mod_blocks:
56 | gen_mod_block.return_attention = True
57 | if self.log_encoder:
58 | num_enc_blocks = model.cfg.encoder.num_blocks
59 | enc_blocks = [
60 | model.encoder.backbone[f"block_{i}"]["cross_attention"]
61 | for i in range(num_enc_blocks)
62 | ]
63 | model.encoder.return_attention = True
64 |
65 | handles = []
66 | for mod in gen_mod_blocks:
67 | handle = mod.latent_to_image.register_forward_hook(
68 | self.latent_to_image_gen_hook_fn
69 | )
70 | handles.append(handle)
71 | handle = mod.image_to_latent.register_forward_hook(
72 | self.image_to_latent_gen_hook_fn
73 | )
74 | handles.append(handle)
75 | if self.log_encoder:
76 | for enc_block in enc_blocks:
77 | handle = enc_block.register_forward_hook(self.encoder_hook_fn)
78 | handles.append(handle)
79 |
80 | # ### retrieve attention
81 | output = model.encode_and_generate(images, masks)
82 | gen_mod_dims = [(mod.height, mod.width) for mod in gen_mod_blocks]
83 | if self.log_encoder:
84 | enc_dims = model.encoder.image_sizes
85 | # ### Remove hook
86 | for handle in handles:
87 | handle.remove()
 88 |         # ### deactivate attention output
89 | for gen_mod_block in gen_mod_blocks:
90 | gen_mod_block.return_attention = False
91 | if self.log_encoder:
92 | model.encoder.return_attention = False
93 |
94 | latent_to_image_gen_attention_masks = [
95 | rearrange(
96 | attention_mask,
97 | "b heads (h w) c-> b c (heads h) w",
98 | h=h,
99 | w=w,
100 | )
101 | for (h, w), attention_mask in zip(
102 | gen_mod_dims, self.latent_to_image_gen_masks
103 | )
104 | ]
105 | image_to_latent_gen_attention_masks = [
106 | rearrange(
107 | attention_mask,
108 | "b heads c (h w)-> b c (heads h) w",
109 | h=h,
110 | w=w,
111 | )
112 | for (h, w), attention_mask in zip(
113 | gen_mod_dims, self.image_to_latent_gen_masks
114 | )
115 | ]
116 | if self.log_encoder:
117 | encoder_attention_masks = [
118 | rearrange(
119 | attention_mask,
120 | "b heads c (h w)-> b c (heads h) w",
121 | h=h,
122 | w=w,
123 | )
124 | for (h, w), attention_mask in zip(enc_dims, self.encoder_masks)
125 | ]
126 | self.encoder_masks = []
127 | self.latent_to_image_gen_masks = []
128 | self.image_to_latent_gen_masks = []
129 |
130 | if self.log_encoder:
131 | return {
132 | "output": output,
133 | "latent_to_image_gen_attn": latent_to_image_gen_attention_masks,
134 | "image_to_latent_gen_attn": image_to_latent_gen_attention_masks,
135 | "encoder_attn": encoder_attention_masks,
136 | }
137 | else:
138 | return {
139 | "output": output,
140 | "latent_to_image_gen_attn": latent_to_image_gen_attention_masks,
141 | "image_to_latent_gen_attn": image_to_latent_gen_attention_masks,
142 | }
143 |
144 | # def generate(self, model, style_codes, masks):
145 | # ### activate attention
146 | # model.generator.backbone.SCAM_block_5.mod_1.return_attention = True
147 | # output = model.generator(masks, style_codes)
148 | # self.height = model.generator.backbone.SCAM_block_5.mod_1.height
149 | # self.width = model.generator.backbone.SCAM_block_5.mod_1.width
150 | # ### register hook
151 | # handle = model.generator.backbone.SCAM_block_5.mod_1.latent_to_image.register_forward_hook(
152 | # self.hook_fn
153 | # )
154 | # ### retrieve attention
155 | # output = model.generator(masks, style_codes)
156 | # ### Remove hook
157 | # handle.remove()
 158 |     #     ### deactivate attention output
159 | # model.generator.backbone.SCAM_block_4.mod_1.return_attention = False
160 | # return output, self.attention_masks
161 |
--------------------------------------------------------------------------------
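`get_palette` plus tensor indexing is how the callbacks colorize an argmaxed segmentation map, and `remap_image_torch` brings `[-1, 1]` network outputs back to `uint8`; a condensed version of that pattern with dummy tensors:

```python
import torch
from einops import rearrange
from utils.utils import get_palette, remap_image_torch

palette = torch.from_numpy(get_palette(10))                       # [10, 3] random class colors
seg = torch.randint(0, 10, (2, 64, 64))                           # class index per pixel
seg_rgb = rearrange(palette[seg], "b h w c -> b c h w")           # colorized segmentation maps
img_uint8 = remap_image_torch(torch.rand(2, 3, 64, 64) * 2 - 1)   # [-1, 1] -> [0, 255] uint8
```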