├── README.md ├── __init__.py ├── checkpoints ├── .DS_Store └── .empty ├── configs ├── compnode │ ├── 1-node-cluster.yaml │ ├── 2-node-cluster.yaml │ ├── 4-node-cluster.yaml │ ├── cpu.yaml │ ├── mps.yaml │ ├── single-gpu.yaml │ └── two-gpus.yaml ├── config.yaml ├── dataset │ ├── ADE20K.yaml │ ├── CelebAHQ.yaml │ ├── idesigner.yaml │ ├── train_aug │ │ ├── basic.yaml │ │ ├── basicpad.yaml │ │ ├── crop.yaml │ │ ├── none.yaml │ │ └── nonepad.yaml │ └── val_aug │ │ ├── crop.yaml │ │ ├── none.yaml │ │ └── nonepad.yaml └── model │ ├── clade.yaml │ ├── disc_augments │ └── basic.yaml │ ├── discriminator │ ├── mcad.yaml │ ├── oasis.yaml │ ├── patchgan.yaml │ └── stylegan2.yaml │ ├── encoder │ ├── groupdnet.yaml │ ├── inade.yaml │ ├── sat.yaml │ ├── sean.yaml │ └── spade.yaml │ ├── generator │ ├── clade.yaml │ ├── groupdnet.yaml │ ├── inade.yaml │ ├── scam.yaml │ ├── sean.yaml │ ├── sean_clade.yaml │ └── spade.yaml │ ├── groupdnet.yaml │ ├── inade.yaml │ ├── optim │ ├── adam.yaml │ ├── adamw.yaml │ └── sam.yaml │ ├── scam.yaml │ ├── sean.yaml │ ├── sean_clade.yaml │ └── spade.yaml ├── data ├── __init__.py ├── datamodule.py └── image_dataset.py ├── datasets └── .empty ├── environment.yml ├── gan.py ├── medias ├── .DS_Store └── teaser.png ├── metrics ├── __init__.py ├── fid.py ├── models │ └── fastreid │ │ ├── __init__.py │ │ ├── layers │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── am_softmax.py │ │ ├── arc_softmax.py │ │ ├── batch_drop.py │ │ ├── batch_norm.py │ │ ├── circle_softmax.py │ │ ├── context_block.py │ │ ├── frn.py │ │ ├── gather_layer.py │ │ ├── non_local.py │ │ ├── pooling.py │ │ ├── se_layer.py │ │ └── splat.py │ │ └── model.py └── reid.py ├── models ├── __init__.py ├── diffaugment.py ├── discriminators │ ├── __init__.py │ ├── mcad.py │ ├── oasis.py │ ├── patchgan.py │ └── stylegan2.py ├── encoders │ ├── __init__.py │ ├── groupdnet.py │ ├── inade.py │ ├── sat.py │ ├── sean.py │ └── spade.py ├── generators │ ├── __init__.py │ ├── clade.py │ ├── groupdnet.py │ ├── inade.py │ ├── scam.py │ ├── sean.py │ ├── sean_clade.py │ └── spade.py ├── loss.py └── utils_blocks │ ├── EMA.py │ ├── SAM.py │ ├── __init__.py │ ├── attention.py │ ├── base.py │ └── equallr.py ├── preprocess_fid_features.py ├── train.py └── utils ├── .DS_Store ├── __init__.py ├── callbacks.py ├── partial_conv.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # SCAM! Transferring humans between images with Semantic Cross Attention Modulation 2 | 3 | Official PyTorch implementation of "SCAM! Transferring humans between images with Semantic Cross Attention Modulation", ECCV 2022. 4 | 5 | Arxiv | Website 6 | 7 | 8 |
## BibTeX

If you find this code or method useful in your research, please cite the paper:

```
@inproceedings{dufour2022scam,
  title={SCAM! Transferring humans between images with Semantic Cross Attention Modulation},
  author={Dufour, Nicolas and Picard, David and Kalogeiton, Vicky},
  booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
  year={2022}
}
```

## Installation guide
To install this repo, follow these steps:

1. Install Anaconda or Miniconda
2. Run `conda create -n scam python=3.9.12`
3. Activate the environment: `conda activate scam`
4. Install PyTorch 1.12 and torchvision 0.13 matching your device:
   - For CPU:
     ```bash
     conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cpuonly -c pytorch
     ```
   - For CUDA 11.6:
     ```bash
     conda install pytorch==1.12.0 torchvision==0.13.0 cudatoolkit=11.6 -c pytorch -c conda-forge
     ```
5. Install the remaining dependencies:
   ```bash
   conda env update --file environment.yml
   ```

## Prepare the data

### Download CelebAHQ

Download link: dataset provided and processed by the SEAN authors.
Unzip it in the `datasets` folder, then run:

```bash
python preprocess_fid_features.py dataset=CelebAHQ
```

To run the preprocessing on GPU instead:

```bash
python preprocess_fid_features.py dataset=CelebAHQ compnode=single-gpu
```

## Add a new dataset
If you want to add a new dataset, follow the same directory structure as the existing datasets:

Two folders: `train` and `test`.
Each folder contains 4 sub-folders: `images`, `labels`, `vis` and `stats`.
- images: the RGB images, as jpg or png.
- labels: the segmentation labels, as png files where each pixel takes a value between 0 and num_labels.
- vis: RGB visualizations of the labels.
- stats: Inception statistics of the train set (mean.pt and sigma.pt), produced by the preprocessing script below from the train dataloader.

Create a config file for the dataset in `configs/dataset` (the config file must have the same name as the dataset) and run:

```bash
python preprocess_fid_features.py dataset=dataset_name compnode=(cpu or single-gpu)
```

## Run SCAM
Our code logs to wandb, so make sure you are logged in to wandb before starting.
To run our code you just need to do:

```bash
python train.py
```

By default the code runs on CPU. To run it on GPU:

```bash
python train.py compnode=single-gpu
```

Other compute configs exist in `configs/compnode`. If none suits your needs, you can easily create one thanks to the modularity of hydra.

In the paper, the main compute config is the 4-GPU config called `1-node-cluster` (NVIDIA V100 32GB, total batch size of 32, i.e. 8 per GPU).

By default, our method SCAM is run. To try one of the baselines we compare against, run for example:

```bash
python train.py model=sean
```

We include implementations for SEAN, SPADE, INADE, CLADE and GroupDNet.

To swap datasets, you can do:

```bash
python train.py dataset=CelebAHQ
```

We support 3 datasets: `CelebAHQ`, `idesigner` and `ADE20K`.
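If you add your own dataset (see the "Add a new dataset" section above), the corresponding Hydra config is a small YAML file. The sketch below is only an illustration: `my_dataset`, the image size and the label count are placeholders, and the fields mirror the existing `CelebAHQ.yaml`, `ADE20K.yaml` and `idesigner.yaml` configs dumped later in this file.

```yaml
# configs/dataset/my_dataset.yaml (hypothetical example)
defaults:
  - train_aug: none              # or basic / basicpad / crop / nonepad, see configs/dataset/train_aug
  - val_aug: none                # see configs/dataset/val_aug

path: ${data_dir}/my_dataset     # expects my_dataset/{train,test}/{images,labels,vis,stats}
name: my_dataset
height: 256
width: 256
num_labels_orig: 20              # placeholder: number of classes in the label maps
num_labels: ${dataset.num_labels_orig}
num_channels: 3
num_workers: 4
image_extension: jpg             # jpg or png
label_merge_strat: none
```

Once the config exists and the FID statistics have been precomputed, the dataset can be selected like any other, e.g. `python train.py dataset=my_dataset`.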
If you want to run different experiments, make sure to change `experiment_name_comp` or `experiment_name` when running. Otherwise our code will resume training from the checkpointed weights of the experiment with the same name.

## Pretrained models
Coming soon
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/__init__.py
--------------------------------------------------------------------------------
/checkpoints/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/checkpoints/.DS_Store
--------------------------------------------------------------------------------
/checkpoints/.empty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/checkpoints/.empty
--------------------------------------------------------------------------------
/configs/compnode/1-node-cluster.yaml:
--------------------------------------------------------------------------------
devices: 4
progress_bar_refresh_rate: 2
sync_batchnorm: True
accelerator: gpu
precision: 16
strategy: ddp
batch_size: 8
num_nodes: 1
--------------------------------------------------------------------------------
/configs/compnode/2-node-cluster.yaml:
--------------------------------------------------------------------------------
devices: 4
progress_bar_refresh_rate: 2
sync_batchnorm: True
accelerator: gpu
precision: 16
strategy: ddp
batch_size: 8
num_nodes: 2
--------------------------------------------------------------------------------
/configs/compnode/4-node-cluster.yaml:
--------------------------------------------------------------------------------
devices: 4
progress_bar_refresh_rate: 2
sync_batchnorm: True
accelerator: gpu
precision: 16
strategy: ddp
batch_size: 8
num_nodes: 4
--------------------------------------------------------------------------------
/configs/compnode/cpu.yaml:
--------------------------------------------------------------------------------
devices: 1
progress_bar_refresh_rate: 2
sync_batchnorm: False
accelerator: cpu
precision: 32
strategy: null
batch_size: 4
num_nodes: 1
--------------------------------------------------------------------------------
/configs/compnode/mps.yaml:
--------------------------------------------------------------------------------
devices: 1
progress_bar_refresh_rate: 2
sync_batchnorm: False
accelerator: mps
precision: 32
strategy: null
batch_size: 4
num_nodes: 1
--------------------------------------------------------------------------------
/configs/compnode/single-gpu.yaml:
--------------------------------------------------------------------------------
devices: 1
progress_bar_refresh_rate: 2
sync_batchnorm: False
accelerator: gpu
precision: 16
strategy: null
batch_size: 4
num_nodes: 1
--------------------------------------------------------------------------------
/configs/compnode/two-gpus.yaml:
-------------------------------------------------------------------------------- 1 | devices: 2 2 | progress_bar_refresh_rate: 2 3 | sync_batchnorm: False 4 | accelerator: gpu 5 | precision: 16 6 | strategy: ddp 7 | batch_size: 4 8 | num_nodes: 1 -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: scam 3 | - compnode: cpu 4 | - dataset: CelebAHQ 5 | 6 | trainer: 7 | _target_: pytorch_lightning.Trainer 8 | max_steps: 100000 9 | devices: ${compnode.devices} 10 | accelerator: ${compnode.accelerator} 11 | sync_batchnorm: ${compnode.sync_batchnorm} 12 | strategy: ${compnode.strategy} 13 | log_every_n_steps: 1 14 | num_nodes: ${compnode.num_nodes} 15 | precision: ${compnode.precision} 16 | dataset: 17 | batch_size: ${compnode.batch_size} 18 | 19 | logger: 20 | _target_: pytorch_lightning.loggers.WandbLogger 21 | save_dir: ${root_dir}/wandb 22 | name: ${experiment_name} 23 | project: Pose_Transfer 24 | log_model: False 25 | offline: True 26 | 27 | checkpoints: 28 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 29 | dirpath: ${root_dir}/checkpoints/${experiment_name} 30 | monitor: val/reco_fid 31 | save_last: True 32 | every_n_epochs: 1 33 | 34 | progress_bar: 35 | _target_: pytorch_lightning.callbacks.TQDMProgressBar 36 | refresh_rate: ${compnode.progress_bar_refresh_rate} 37 | 38 | hydra: 39 | run: 40 | dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} 41 | 42 | data_dir: ${root_dir}/datasets 43 | root_dir: ${hydra:runtime.cwd} 44 | experiment_name_comp: base 45 | experiment_name: ${dataset.name}_${model.name}_${experiment_name_comp} 46 | -------------------------------------------------------------------------------- /configs/dataset/ADE20K.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_aug: crop 3 | - val_aug: crop 4 | 5 | path: ${data_dir}/ADE20K 6 | name: ADE20K 7 | height: 256 8 | width: 256 9 | num_labels_orig: 151 10 | num_labels: ${dataset.num_labels_orig} 11 | num_channels: 3 12 | num_workers: 4 13 | image_extension: jpg 14 | label_merge_strat: none -------------------------------------------------------------------------------- /configs/dataset/CelebAHQ.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_aug: none 3 | - val_aug: none 4 | 5 | path: ${data_dir}/CelebA-HQ 6 | name: CelebA-HQ 7 | height: 256 8 | width: 256 9 | num_labels: ${dataset.num_labels_orig} 10 | num_labels_orig: 19 11 | num_channels: 3 12 | num_workers: 10 13 | image_extension: jpg 14 | label_merge_strat: none -------------------------------------------------------------------------------- /configs/dataset/idesigner.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_aug: none 3 | - val_aug: none 4 | 5 | path: ${data_dir}/idesigner 6 | name: idesigner 7 | height: 256 8 | width: 170 9 | num_labels_orig: 20 10 | num_labels: 3 11 | num_channels: 3 12 | num_workers: 10 13 | image_extension: png 14 | label_merge_strat: body_face_background -------------------------------------------------------------------------------- /configs/dataset/train_aug/basic.yaml: -------------------------------------------------------------------------------- 1 | _target_: albumentations.Compose 2 | transforms: 3 | - _target_: albumentations.ColorJitter 4 | p: 0.4 5 | - 
_target_: albumentations.HorizontalFlip 6 | p: 0.5 7 | - _target_: albumentations.OneOf 8 | p: 1 9 | transforms: 10 | - _target_: albumentations.RandomResizedCrop 11 | height: ${dataset.height} 12 | width: ${dataset.width} 13 | scale: [0.5, 1] 14 | ratio: [1,1] 15 | p: 0.7 16 | - _target_: albumentations.Resize 17 | height: ${dataset.height} 18 | width: ${dataset.width} 19 | p: 0.3 20 | 21 | - _target_: albumentations.Normalize 22 | mean: 0.5 23 | std: 0.5 24 | - _target_: albumentations.pytorch.ToTensorV2 -------------------------------------------------------------------------------- /configs/dataset/train_aug/basicpad.yaml: -------------------------------------------------------------------------------- 1 | _target_: albumentations.Compose 2 | transforms: 3 | - _target_: albumentations.ColorJitter 4 | p: 0.4 5 | - _target_: albumentations.HorizontalFlip 6 | p: 0.5 7 | - _target_: albumentations.LongestMaxSize 8 | p: 1 9 | max_size: ${dataset.height} 10 | - _target_: albumentations.PadIfNeeded 11 | p: 1 12 | border_mode: 0 13 | min_height: ${dataset.height} 14 | min_width: ${dataset.width} 15 | value: [0,0,0] 16 | mask_value: 0 17 | 18 | - _target_: albumentations.Normalize 19 | mean: 0.5 20 | std: 0.5 21 | - _target_: albumentations.pytorch.ToTensorV2 -------------------------------------------------------------------------------- /configs/dataset/train_aug/crop.yaml: -------------------------------------------------------------------------------- 1 | _target_: albumentations.Compose 2 | transforms: 3 | - _target_: albumentations.Resize 4 | height: ${dataset.height} 5 | width: ${dataset.width} 6 | - _target_: albumentations.ColorJitter 7 | p: 0.4 8 | - _target_: albumentations.HorizontalFlip 9 | p: 0.5 10 | - _target_: albumentations.OneOf 11 | transforms: 12 | - _target_: albumentations.RandomSizedCrop 13 | p: 0.7 14 | w2h_ratio: 0.664 15 | min_max_height: [400, 512] 16 | height: ${dataset.height} 17 | width: ${dataset.width} 18 | - _target_: albumentations.RandomSizedCrop 19 | p: 0.3 20 | w2h_ratio: 0.664 21 | min_max_height: [128, 400] 22 | height: ${dataset.height} 23 | width: ${dataset.width} 24 | p: 0.6 25 | - _target_: albumentations.Normalize 26 | mean: 0.5 27 | std: 0.5 28 | - _target_: albumentations.pytorch.ToTensorV2 -------------------------------------------------------------------------------- /configs/dataset/train_aug/none.yaml: -------------------------------------------------------------------------------- 1 | _target_: albumentations.Compose 2 | transforms: 3 | - _target_: albumentations.Resize 4 | height: ${dataset.height} 5 | width: ${dataset.width} 6 | - _target_: albumentations.Normalize 7 | mean: 0.5 8 | std: 0.5 9 | - _target_: albumentations.pytorch.ToTensorV2 -------------------------------------------------------------------------------- /configs/dataset/train_aug/nonepad.yaml: -------------------------------------------------------------------------------- 1 | _target_: albumentations.Compose 2 | transforms: 3 | - _target_: albumentations.LongestMaxSize 4 | p: 1 5 | max_size: ${dataset.height} 6 | - _target_: albumentations.PadIfNeeded 7 | p: 1 8 | border_mode: 0 9 | min_height: ${dataset.height} 10 | min_width: ${dataset.width} 11 | value: [0,0,0] 12 | mask_value: 0 13 | - _target_: albumentations.HorizontalFlip 14 | p: 0.5 15 | - _target_: albumentations.Normalize 16 | mean: 0.5 17 | std: 0.5 18 | - _target_: albumentations.pytorch.ToTensorV2 -------------------------------------------------------------------------------- 
/configs/dataset/val_aug/crop.yaml: -------------------------------------------------------------------------------- 1 | _target_: albumentations.Compose 2 | transforms: 3 | - _target_: albumentations.SmallestMaxSize 4 | p: 1 5 | max_size: ${dataset.height} 6 | - _target_: albumentations.CenterCrop 7 | p: 1 8 | height: ${dataset.height} 9 | width: ${dataset.width} 10 | - _target_: albumentations.HorizontalFlip 11 | p: 0.5 12 | - _target_: albumentations.Normalize 13 | mean: 0.5 14 | std: 0.5 15 | - _target_: albumentations.pytorch.ToTensorV2 -------------------------------------------------------------------------------- /configs/dataset/val_aug/none.yaml: -------------------------------------------------------------------------------- 1 | _target_: albumentations.Compose 2 | transforms: 3 | - _target_: albumentations.Resize 4 | height: ${dataset.height} 5 | width: ${dataset.width} 6 | 7 | - _target_: albumentations.Normalize 8 | mean: 0.5 9 | std: 0.5 10 | - _target_: albumentations.pytorch.ToTensorV2 -------------------------------------------------------------------------------- /configs/dataset/val_aug/nonepad.yaml: -------------------------------------------------------------------------------- 1 | _target_: albumentations.Compose 2 | transforms: 3 | - _target_: albumentations.LongestMaxSize 4 | p: 1 5 | max_size: ${dataset.height} 6 | - _target_: albumentations.PadIfNeeded 7 | p: 1 8 | border_mode: 0 9 | min_height: ${dataset.height} 10 | min_width: ${dataset.width} 11 | value: [0,0,0] 12 | mask_value: 0 13 | - _target_: albumentations.Normalize 14 | mean: 0.5 15 | std: 0.5 16 | - _target_: albumentations.pytorch.ToTensorV2 -------------------------------------------------------------------------------- /configs/model/clade.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - disc_augments: null 3 | - generator: clade 4 | - discriminator: patchgan 5 | - encoder: spade 6 | - optim: adam 7 | 8 | change_init: True 9 | gradient_clip_val: 0 10 | name: CLADE 11 | dataset_path: ${dataset.path} 12 | losses: 13 | lambda_gan: 1 14 | lambda_gan_end: ${model.losses.lambda_gan} 15 | lambda_gan_decay_steps: 1 16 | 17 | lambda_fm: 10.0 18 | lambda_fm_end: ${model.losses.lambda_fm} 19 | lambda_fm_decay_steps: 1 20 | 21 | lambda_label_mix: 0 22 | lambda_label_mix_end: ${model.losses.lambda_label_mix} 23 | lambda_label_mix_decay_steps: 1 24 | 25 | lambda_l1: 0 26 | lambda_l1_end: ${model.losses.lambda_l1} 27 | lambda_l1_decay_steps: 1 28 | 29 | lambda_perceptual: 10.0 30 | lambda_perceptual_end: ${model.losses.lambda_perceptual} 31 | lambda_perceptual_decay_steps: 1 32 | 33 | lambda_r1: 0 34 | lambda_r1_end: ${model.losses.lambda_r1} 35 | lambda_r1_decay_steps: 1 36 | 37 | lambda_kld: 0.05 38 | lambda_kld_end: ${model.losses.lambda_kld} 39 | lambda_kld_decay_steps: 1 40 | 41 | lazy_r1_step: 16 42 | gan_loss_type: hinge 43 | gan_loss_on_swaps: False 44 | use_adaptive_lambda : False -------------------------------------------------------------------------------- /configs/model/disc_augments/basic.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.Sequential 2 | _args_: 3 | - _target_: kornia.augmentation.augmentation.Denormalize 4 | mean: 0.5 5 | std: 0.5 6 | - _target_: kornia.augmentation.augmentation.ColorJitter 7 | p: 0.8 8 | brightness : 0.2 9 | contrast: 0.3 10 | hue: 0.2 11 | - _target_: kornia.augmentation.augmentation.RandomErasing 12 | p: 0.5 13 | - _target_: 
kornia.augmentation.augmentation.Normalize 14 | mean: 0.5 15 | std: 0.5 16 | -------------------------------------------------------------------------------- /configs/model/discriminator/mcad.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.discriminators.mcad.MultiScaleMCADiscriminator 2 | num_discriminator: 2 3 | image_num_channels: ${dataset.num_channels} 4 | segmap_num_channels: ${dataset.num_labels} 5 | positional_embedding_dim: 40 6 | num_labels: ${dataset.num_labels} 7 | num_latent_per_labels: 8 8 | latent_dim: 64 9 | num_blocks: 6 10 | attention_latent_dim: 64 11 | num_cross_heads: 1 12 | num_self_heads: 1 13 | apply_spectral_norm: False 14 | concat_segmaps: True 15 | output_type: attention_pool 16 | keep_intermediate_results: True -------------------------------------------------------------------------------- /configs/model/discriminator/oasis.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.discriminators.oasis.OASISDiscriminator 2 | image_num_channels: ${dataset.num_channels} 3 | segmap_num_channels: ${dataset.num_labels} 4 | apply_spectral_norm: True 5 | apply_grad_norm: False -------------------------------------------------------------------------------- /configs/model/discriminator/patchgan.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.discriminators.patchgan.MultiScalePatchGanDiscriminator 2 | num_discriminator: 2 3 | image_num_channels: ${dataset.num_channels} 4 | segmap_num_channels: ${dataset.num_labels} 5 | num_features_fst_conv: 64 6 | num_layers: 4 7 | apply_spectral_norm: True 8 | apply_grad_norm: False 9 | keep_intermediate_results: True 10 | use_equalized_lr: False 11 | lr_mul: 1.0 -------------------------------------------------------------------------------- /configs/model/discriminator/stylegan2.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.discriminators.stylegan2.MultiStyleGan2Discriminator 2 | num_discriminator: 1 3 | image_num_channels: ${dataset.num_channels} 4 | segmap_num_channels: ${dataset.num_labels} 5 | num_features_fst_conv: 64 6 | num_layers: 7 7 | fmap_max: 512 8 | apply_spectral_norm: False 9 | apply_grad_norm: False 10 | keep_intermediate_results: True 11 | use_equalized_lr: False 12 | lr_mul: 1.0 -------------------------------------------------------------------------------- /configs/model/encoder/groupdnet.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.encoders.groupdnet.GroupDNetStyleEncoder 2 | num_labels: ${dataset.num_labels} 3 | -------------------------------------------------------------------------------- /configs/model/encoder/inade.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.encoders.inade.InstanceAdaptiveEncoder 2 | num_labels: ${dataset.num_labels} 3 | noise_dim: ${model.generator.noise_dim} 4 | use_vae: True -------------------------------------------------------------------------------- /configs/model/encoder/sat.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.encoders.sat.SemanticAttentionTransformerEncoder 2 | num_input_channels: 3 3 | positional_embedding_dim: 40 4 | num_labels: ${dataset.num_labels} 5 | num_latent_per_labels: 8 6 | num_latents_bg: 8 7 | latent_dim: 256 8 | type_of_initial_latents: learned 
9 | attention_latent_dim: 256 10 | num_blocks: 7 11 | num_cross_heads: 1 12 | num_self_heads: 1 13 | image_conv: True 14 | reverse_conv: False 15 | conv_features_dim_first: 16 16 | use_self_attention: True 17 | use_equalized_lr: False 18 | lr_mul: 1.0 19 | use_vae: False -------------------------------------------------------------------------------- /configs/model/encoder/sean.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.encoders.sean.RegionalAveragePoolingStyleEncoder 2 | num_input_channels: ${dataset.num_channels} 3 | latent_dim: 512 4 | num_features_fst_conv: 32 5 | use_vae: False -------------------------------------------------------------------------------- /configs/model/encoder/spade.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.encoders.spade.SPADEStyleEncoder 2 | use_vae: True -------------------------------------------------------------------------------- /configs/model/generator/clade.yaml: -------------------------------------------------------------------------------- 1 | _target_ : models.generators.clade.CLADEGenerator 2 | num_filters_last_layer: 64 3 | num_up_layers: 7 4 | height: ${dataset.height} 5 | width: ${dataset.width} 6 | num_labels: ${dataset.num_labels} 7 | kernel_size: 3 8 | num_output_channels: ${dataset.num_channels} 9 | apply_spectral_norm: True 10 | use_vae : True 11 | use_dists: False -------------------------------------------------------------------------------- /configs/model/generator/groupdnet.yaml: -------------------------------------------------------------------------------- 1 | _target_ : models.generators.groupdnet.GroupDNetGenerator 2 | num_filters_last_layer: 64 3 | num_up_layers: 7 4 | height: ${dataset.height} 5 | width: ${dataset.width} 6 | num_labels: ${dataset.num_labels} 7 | kernel_size: 3 8 | num_output_channels: ${dataset.num_channels} 9 | apply_spectral_norm: True 10 | use_vae : True -------------------------------------------------------------------------------- /configs/model/generator/inade.yaml: -------------------------------------------------------------------------------- 1 | _target_ : models.generators.inade.INADEGenerator 2 | num_filters_last_layer: 64 3 | num_up_layers: 7 4 | height: ${dataset.height} 5 | width: ${dataset.width} 6 | num_labels: ${dataset.num_labels} 7 | noise_dim: 64 8 | kernel_size: 3 9 | num_output_channels: ${dataset.num_channels} 10 | apply_spectral_norm: True 11 | use_vae: True -------------------------------------------------------------------------------- /configs/model/generator/scam.yaml: -------------------------------------------------------------------------------- 1 | _target_ : models.generators.scam.SCAMGenerator 2 | num_filters_last_layer: 64 3 | num_up_layers: 6 4 | height: ${dataset.height} 5 | width: ${dataset.width} 6 | num_labels: ${dataset.num_labels} 7 | num_labels_split: 8 8 | num_labels_bg: 8 9 | style_dim: ${model.encoder.latent_dim} 10 | kernel_size: 3 11 | attention_latent_dim: 256 12 | num_heads: 1 13 | attention_type: duplex 14 | num_up_layers_with_mask_adain: 0 15 | num_output_channels: ${dataset.num_channels} 16 | latent_pos_emb: none 17 | apply_spectral_norm: False 18 | split_latents: False 19 | norm_type: InstanceNorm 20 | architecture: skip 21 | add_noise: True 22 | modulate: True 23 | use_equalized_lr: False 24 | lr_mul: 1.0 25 | use_vae: False -------------------------------------------------------------------------------- 
/configs/model/generator/sean.yaml: -------------------------------------------------------------------------------- 1 | _target_ : models.generators.sean.SEANGenerator 2 | num_filters_last_layer: 64 3 | num_up_layers: 7 4 | height: ${dataset.height} 5 | width: ${dataset.width} 6 | num_labels: ${dataset.num_labels} 7 | style_dim: ${model.encoder.latent_dim} 8 | kernel_size: 3 9 | num_output_channels: ${dataset.num_channels} 10 | apply_spectral_norm: True 11 | use_vae: False -------------------------------------------------------------------------------- /configs/model/generator/sean_clade.yaml: -------------------------------------------------------------------------------- 1 | _target_ : models.generators.sean_clade.SEANCLADEGenerator 2 | num_filters_last_layer: 64 3 | num_up_layers: 7 4 | height: ${dataset.height} 5 | width: ${dataset.width} 6 | num_labels: ${dataset.num_labels} 7 | style_dim: ${model.encoder.latent_dim} 8 | kernel_size: 3 9 | num_output_channels: ${dataset.num_channels} 10 | apply_spectral_norm: True 11 | use_dists: False -------------------------------------------------------------------------------- /configs/model/generator/spade.yaml: -------------------------------------------------------------------------------- 1 | _target_ : models.generators.spade.SPADEGenerator 2 | num_filters_last_layer: 64 3 | num_up_layers: 7 4 | height: ${dataset.height} 5 | width: ${dataset.width} 6 | num_labels: ${dataset.num_labels} 7 | kernel_size: 3 8 | num_output_channels: ${dataset.num_channels} 9 | apply_spectral_norm: True 10 | use_vae : True -------------------------------------------------------------------------------- /configs/model/groupdnet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - disc_augments: null 3 | - generator: groupdnet 4 | - discriminator: patchgan 5 | - encoder: groupdnet 6 | - optim: adam 7 | 8 | change_init: True 9 | gradient_clip_val: 0 10 | name: GroupDNet 11 | dataset_path: ${dataset.path} 12 | losses: 13 | lambda_gan: 1 14 | lambda_gan_end: ${model.losses.lambda_gan} 15 | lambda_gan_decay_steps: 1 16 | 17 | lambda_fm: 10.0 18 | lambda_fm_end: ${model.losses.lambda_fm} 19 | lambda_fm_decay_steps: 1 20 | 21 | lambda_label_mix: 0 22 | lambda_label_mix_end: ${model.losses.lambda_label_mix} 23 | lambda_label_mix_decay_steps: 1 24 | 25 | lambda_l1: 0 26 | lambda_l1_end: ${model.losses.lambda_l1} 27 | lambda_l1_decay_steps: 1 28 | 29 | lambda_perceptual: 10.0 30 | lambda_perceptual_end: ${model.losses.lambda_perceptual} 31 | lambda_perceptual_decay_steps: 1 32 | 33 | lambda_r1: 0 34 | lambda_r1_end: ${model.losses.lambda_r1} 35 | lambda_r1_decay_steps: 1 36 | 37 | lambda_kld: 0.05 38 | lambda_kld_end: ${model.losses.lambda_kld} 39 | lambda_kld_decay_steps: 1 40 | lazy_r1_step: 16 41 | gan_loss_type: hinge 42 | gan_loss_on_swaps: False 43 | use_adaptive_lambda : False -------------------------------------------------------------------------------- /configs/model/inade.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - disc_augments: null 3 | - generator: inade 4 | - discriminator: patchgan 5 | - encoder: inade 6 | - optim: adam 7 | 8 | change_init: True 9 | gradient_clip_val: 0 10 | name: INADE 11 | dataset_path: ${dataset.path} 12 | losses: 13 | lambda_gan: 1 14 | lambda_gan_end: ${model.losses.lambda_gan} 15 | lambda_gan_decay_steps: 1 16 | 17 | lambda_fm: 10.0 18 | lambda_fm_end: ${model.losses.lambda_fm} 19 | lambda_fm_decay_steps: 1 20 | 21 | 
lambda_label_mix: 0 22 | lambda_label_mix_end: ${model.losses.lambda_label_mix} 23 | lambda_label_mix_decay_steps: 1 24 | 25 | lambda_l1: 0 26 | lambda_l1_end: ${model.losses.lambda_l1} 27 | lambda_l1_decay_steps: 1 28 | 29 | lambda_perceptual: 10.0 30 | lambda_perceptual_end: ${model.losses.lambda_perceptual} 31 | lambda_perceptual_decay_steps: 1 32 | 33 | lambda_r1: 0 34 | lambda_r1_end: ${model.losses.lambda_r1} 35 | lambda_r1_decay_steps: 1 36 | 37 | lambda_kld: 0.05 38 | lambda_kld_end: ${model.losses.lambda_kld} 39 | lambda_kld_decay_steps: 1 40 | lazy_r1_step: 16 41 | gan_loss_type: hinge 42 | gan_loss_on_swaps: False 43 | use_adaptive_lambda : False -------------------------------------------------------------------------------- /configs/model/optim/adam.yaml: -------------------------------------------------------------------------------- 1 | disc_optim: 2 | _target_: torch.optim.Adam 3 | lr: ${model.optim.disc_lr} 4 | betas: ${model.optim.betas} 5 | weight_decay: ${model.optim.weight_decay} 6 | gen_optim: 7 | _target_: torch.optim.Adam 8 | lr: ${model.optim.gen_lr} 9 | betas: ${model.optim.betas} 10 | weight_decay: ${model.optim.weight_decay} 11 | SAM: False 12 | betas : [0, 0.999] 13 | weight_decay: 0 14 | disc_lr: 4e-4 15 | gen_lr: 1e-4 -------------------------------------------------------------------------------- /configs/model/optim/adamw.yaml: -------------------------------------------------------------------------------- 1 | disc_optim: 2 | _target_: torch.optim.AdamW 3 | lr: ${model.optim.disc_lr} 4 | betas: ${model.optim.betas} 5 | weight_decay: ${model.optim.weight_decay} 6 | gen_optim: 7 | _target_: torch.optim.AdamW 8 | lr: ${model.optim.gen_lr} 9 | betas: ${model.optim.betas} 10 | weight_decay: ${model.optim.weight_decay} 11 | SAM: False 12 | betas : [0.9, 0.999] 13 | weight_decay: 0.01 14 | disc_lr: 4e-4 15 | gen_lr: 1e-4 -------------------------------------------------------------------------------- /configs/model/optim/sam.yaml: -------------------------------------------------------------------------------- 1 | disc_optim: 2 | _target_: models.utils_blocks.SAM.SAM 3 | rho: 0.5 4 | adaptive: True 5 | lr: 0.0004 6 | betas: [0.9, 0.999] 7 | weight_decay: 0.01 8 | gen_optim: 9 | _target_: models.utils_blocks.SAM.SAM 10 | rho: 0.5 11 | adaptive: True 12 | lr: 0.0001 13 | betas: [0.9, 0.999] 14 | weight_decay: 0.01 15 | SAM: True 16 | -------------------------------------------------------------------------------- /configs/model/scam.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - disc_augments: null 3 | - generator: scam 4 | - discriminator: patchgan 5 | - encoder: sat 6 | - optim: adamw 7 | 8 | change_init: False 9 | gradient_clip_val: 0 10 | name: SCAM 11 | dataset_path: ${dataset.path} 12 | 13 | 14 | discriminator: 15 | apply_spectral_norm: False 16 | apply_grad_norm: True 17 | 18 | generator: 19 | apply_spectral_norm: False 20 | losses: 21 | lambda_gan: 1 22 | lambda_gan_end: ${model.losses.lambda_gan} 23 | lambda_gan_decay_steps: 1 24 | 25 | lambda_fm: 0 26 | lambda_fm_end: ${model.losses.lambda_fm} 27 | lambda_fm_decay_steps: 1 28 | 29 | lambda_label_mix: 0 30 | lambda_label_mix_end: ${model.losses.lambda_label_mix} 31 | lambda_label_mix_decay_steps: 1 32 | 33 | lambda_l1: 10.0 34 | lambda_l1_end: ${model.losses.lambda_l1} 35 | lambda_l1_decay_steps: 1 36 | 37 | lambda_perceptual: 10.0 38 | lambda_perceptual_end: ${model.losses.lambda_perceptual} 39 | lambda_perceptual_decay_steps: 1 40 | 41 | 
lambda_r1: 0 42 | lambda_r1_end: ${model.losses.lambda_r1} 43 | lambda_r1_decay_steps: 1 44 | 45 | lambda_kld: 0 46 | lambda_kld_end: ${model.losses.lambda_kld} 47 | lambda_kld_decay_steps: 1 48 | 49 | lazy_r1_step: 16 50 | gan_loss_type: hinge 51 | gan_loss_on_swaps: False 52 | use_adaptive_lambda: False -------------------------------------------------------------------------------- /configs/model/sean.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - disc_augments: null 3 | - generator: sean 4 | - discriminator: patchgan 5 | - encoder: sean 6 | - optim: adam 7 | 8 | change_init: True 9 | gradient_clip_val: 0 10 | name: SEAN 11 | dataset_path: ${dataset.path} 12 | losses: 13 | lambda_gan: 1 14 | lambda_gan_end: ${model.losses.lambda_gan} 15 | lambda_gan_decay_steps: 1 16 | 17 | lambda_fm: 10.0 18 | lambda_fm_end: ${model.losses.lambda_fm} 19 | lambda_fm_decay_steps: 1 20 | 21 | lambda_label_mix: 0 22 | lambda_label_mix_end: ${model.losses.lambda_label_mix} 23 | lambda_label_mix_decay_steps: 1 24 | 25 | lambda_l1: 0 26 | lambda_l1_end: ${model.losses.lambda_l1} 27 | lambda_l1_decay_steps: 1 28 | 29 | lambda_perceptual: 10.0 30 | lambda_perceptual_end: ${model.losses.lambda_perceptual} 31 | lambda_perceptual_decay_steps: 1 32 | 33 | lambda_r1: 0 34 | lambda_r1_end: ${model.losses.lambda_r1} 35 | lambda_r1_decay_steps: 1 36 | 37 | lambda_kld: 0 38 | lambda_kld_end: ${model.losses.lambda_kld} 39 | lambda_kld_decay_steps: 1 40 | 41 | lazy_r1_step: 16 42 | gan_loss_type: hinge 43 | gan_loss_on_swaps: False 44 | use_adaptive_lambda: False -------------------------------------------------------------------------------- /configs/model/sean_clade.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - disc_augments: null 3 | - generator: sean_clade 4 | - discriminator: patchgan 5 | - encoder: sean 6 | - optim: adam 7 | 8 | change_init: True 9 | gradient_clip_val: 0 10 | name: SEAN_CLADE 11 | dataset_path: ${dataset.path} 12 | losses: 13 | lambda_gan: 1 14 | lambda_gan_end: ${model.losses.lambda_gan} 15 | lambda_gan_decay_steps: 1 16 | 17 | lambda_fm: 10.0 18 | lambda_fm_end: ${model.losses.lambda_fm} 19 | lambda_fm_decay_steps: 1 20 | 21 | lambda_label_mix: 0 22 | lambda_label_mix_end: ${model.losses.lambda_label_mix} 23 | lambda_label_mix_decay_steps: 1 24 | 25 | lambda_l1: 0 26 | lambda_l1_end: ${model.losses.lambda_l1} 27 | lambda_l1_decay_steps: 1 28 | 29 | lambda_perceptual: 10.0 30 | lambda_perceptual_end: ${model.losses.lambda_perceptual} 31 | lambda_perceptual_decay_steps: 1 32 | 33 | lambda_r1: 0 34 | lambda_r1_end: ${model.losses.lambda_r1} 35 | lambda_r1_decay_steps: 1 36 | 37 | lambda_kld: 0 38 | lambda_kld_end: ${model.losses.lambda_kld} 39 | lambda_kld_decay_steps: 1 40 | lazy_r1_step: 16 41 | gan_loss_type: hinge 42 | gan_loss_on_swaps: False 43 | use_adaptive_lambda : False -------------------------------------------------------------------------------- /configs/model/spade.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - disc_augments: null 3 | - generator: spade 4 | - discriminator: patchgan 5 | - encoder: spade 6 | - optim: adam 7 | 8 | change_init: True 9 | gradient_clip_val: 0 10 | name: SPADE 11 | dataset_path: ${dataset.path} 12 | losses: 13 | lambda_gan: 1 14 | lambda_gan_end: ${model.losses.lambda_gan} 15 | lambda_gan_decay_steps: 1 16 | 17 | lambda_fm: 10.0 18 | lambda_fm_end: ${model.losses.lambda_fm} 19 | 
lambda_fm_decay_steps: 1 20 | 21 | lambda_label_mix: 0 22 | lambda_label_mix_end: ${model.losses.lambda_label_mix} 23 | lambda_label_mix_decay_steps: 1 24 | 25 | lambda_l1: 0 26 | lambda_l1_end: ${model.losses.lambda_l1} 27 | lambda_l1_decay_steps: 1 28 | 29 | lambda_perceptual: 10.0 30 | lambda_perceptual_end: ${model.losses.lambda_perceptual} 31 | lambda_perceptual_decay_steps: 1 32 | 33 | lambda_r1: 0 34 | lambda_r1_end: ${model.losses.lambda_r1} 35 | lambda_r1_decay_steps: 1 36 | 37 | lambda_kld: 0.05 38 | lambda_kld_end: ${model.losses.lambda_kld} 39 | lambda_kld_decay_steps: 1 40 | lazy_r1_step: 16 41 | gan_loss_type: hinge 42 | gan_loss_on_swaps: False 43 | use_adaptive_lambda: False -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/data/__init__.py -------------------------------------------------------------------------------- /data/datamodule.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchvision.transforms as transforms 3 | from torch.utils.data import DataLoader 4 | from pathlib import Path 5 | from .image_dataset import ImageDataset 6 | import albumentations as A 7 | from hydra.utils import instantiate 8 | 9 | 10 | class ImageDataModule(pl.LightningDataModule): 11 | """ 12 | Module to load image data 13 | """ 14 | 15 | def __init__(self, cfg): 16 | super().__init__() 17 | self.cfg = cfg 18 | 19 | def setup(self, stage=None): 20 | 21 | train_transforms = instantiate(self.cfg.train_aug) 22 | val_transforms = instantiate(self.cfg.val_aug) 23 | train_path = Path(self.cfg.path) / Path("train") 24 | test_path = Path(self.cfg.path) / Path("test") 25 | 26 | self.train_dataset = ImageDataset( 27 | train_path, 28 | self.cfg.num_labels_orig, 29 | train_transforms, 30 | image_extension=self.cfg.image_extension, 31 | label_merge_strat=self.cfg.label_merge_strat, 32 | ) 33 | self.test_dataset = ImageDataset( 34 | test_path, 35 | self.cfg.num_labels_orig, 36 | val_transforms, 37 | image_extension=self.cfg.image_extension, 38 | label_merge_strat=self.cfg.label_merge_strat, 39 | ) 40 | 41 | def train_dataloader(self): 42 | return DataLoader( 43 | self.train_dataset, 44 | batch_size=self.cfg.batch_size, 45 | shuffle=True, 46 | pin_memory=True, 47 | num_workers=self.cfg.num_workers, 48 | ) 49 | 50 | def val_dataloader(self): 51 | return DataLoader( 52 | self.test_dataset, 53 | batch_size=self.cfg.batch_size, 54 | shuffle=False, 55 | pin_memory=True, 56 | num_workers=self.cfg.num_workers, 57 | ) 58 | 59 | def test_dataloader(self): 60 | return DataLoader( 61 | self.test_dataset, 62 | batch_size=self.cfg.batch_size, 63 | shuffle=False, 64 | pin_memory=True, 65 | num_workers=self.cfg.num_workers, 66 | ) 67 | -------------------------------------------------------------------------------- /data/image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | import os 4 | from pathlib import Path 5 | import PIL 6 | from PIL import Image 7 | import numpy as np 8 | import torch.nn.functional as F 9 | 10 | 11 | class ImageDataset(Dataset): 12 | """ 13 | Image Dataset in jpg format 14 | """ 15 | 16 | def __init__( 17 | self, 18 | path, 19 | num_labels, 20 | transforms=None, 21 | image_extension="jpg", 
22 | label_merge_strat="none", 23 | ): 24 | super().__init__() 25 | self.image_dir = Path(path) / Path("images") 26 | self.labels = Path(path) / Path("labels") 27 | 28 | self.image_extension = image_extension 29 | 30 | images_list = sorted(os.listdir(self.image_dir)) 31 | 32 | self.image_list_filtered = [] 33 | 34 | for image_name in images_list: 35 | if image_name.endswith(self.image_extension): 36 | self.image_list_filtered.append(image_name.split(".")[0]) 37 | 38 | self.transforms = transforms 39 | self.num_labels = num_labels 40 | self.label_merge_strat = label_merge_strat 41 | 42 | def __getitem__(self, index): 43 | image_name = self.image_list_filtered[index] 44 | image = np.array( 45 | Image.open( 46 | self.image_dir / Path(f"{image_name}.{self.image_extension}") 47 | ).convert("RGB") 48 | ) 49 | segmentation_mask = np.array( 50 | Image.open(self.labels / Path(f"{image_name}.png")).resize( 51 | image.shape[:-1][::-1], resample=Image.NEAREST 52 | ) 53 | ) 54 | if self.transforms: 55 | augmented = self.transforms(image=image, mask=segmentation_mask) 56 | image = augmented["image"] 57 | segmentation_mask = augmented["mask"] 58 | segmentation_mask = self.onehot_encode_labels(segmentation_mask) 59 | if self.label_merge_strat == "body_background": 60 | background = segmentation_mask[0] 61 | body = segmentation_mask[1:].max(dim=0).values 62 | segmentation_mask = torch.stack([background, body], dim=0) 63 | 64 | elif self.label_merge_strat == "body_face_background": 65 | background = segmentation_mask[0] 66 | face = ( 67 | torch.stack( 68 | [ 69 | segmentation_mask[1], 70 | segmentation_mask[2], 71 | segmentation_mask[4], 72 | segmentation_mask[13], 73 | ], 74 | dim=0, 75 | ) 76 | .max(dim=0) 77 | .values 78 | ) 79 | segmentation_mask[4] = torch.zeros_like(segmentation_mask[4]) 80 | segmentation_mask[13] = torch.zeros_like(segmentation_mask[13]) 81 | body = segmentation_mask[3:].max(dim=0).values 82 | segmentation_mask = torch.stack([background, face, body], dim=0) 83 | return image, segmentation_mask 84 | 85 | def __len__(self): 86 | return len(self.image_list_filtered) 87 | 88 | def onehot_encode_labels(self, labels): 89 | height, width = labels.shape 90 | labels = labels.unsqueeze(0).long() 91 | input_label = torch.FloatTensor(self.num_labels, height, width).zero_() 92 | input_semantics = input_label.scatter_(0, labels, 1.0) 93 | return input_semantics 94 | -------------------------------------------------------------------------------- /datasets/.empty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/datasets/.empty -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: scam 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - numpy=1.21.2 7 | - pillow=9.0.1 8 | - python=3.9.12 9 | - pytorch=1.12 10 | - pip=21.2.4 11 | - pip: 12 | - torchvision==0.13 13 | - albumentations==1.1.0 14 | - einops==0.4.1 15 | - hydra-core==1.1.1 16 | - kornia==0.6.4 17 | - pytorch-lightning==1.6.0 18 | - torch-fidelity==0.3.0 19 | - torchmetrics==0.7.3 20 | - torchtyping==0.1.4 21 | - tqdm==4.64.0 22 | - wandb==0.12.4 -------------------------------------------------------------------------------- /medias/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/medias/.DS_Store -------------------------------------------------------------------------------- /medias/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/medias/teaser.png -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/metrics/__init__.py -------------------------------------------------------------------------------- /metrics/fid.py: -------------------------------------------------------------------------------- 1 | from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 2 | from typing import Any, Callable, List, Optional, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torch import Tensor 8 | from torch.autograd import Function 9 | from torchvision.transforms import Resize 10 | from pathlib import Path 11 | from torchmetrics.metric import Metric 12 | from tqdm import tqdm 13 | 14 | from utils.utils import remap_image_torch 15 | 16 | 17 | class MatrixSquareRoot(Function): 18 | """Square root of a positive definite matrix. 19 | All credit to: 20 | https://github.com/steveli/pytorch-sqrtm/blob/master/sqrtm.py 21 | """ 22 | 23 | @staticmethod 24 | def forward(ctx: Any, input: Tensor) -> Tensor: 25 | import scipy 26 | 27 | m = input.detach().cpu().numpy().astype(np.float_) 28 | scipy_res, _ = scipy.linalg.sqrtm(m, disp=False) 29 | sqrtm = torch.from_numpy(scipy_res.real).to(input) 30 | ctx.save_for_backward(sqrtm) 31 | return sqrtm 32 | 33 | @staticmethod 34 | def backward(ctx: Any, grad_output: Tensor) -> Tensor: 35 | import scipy 36 | 37 | grad_input = None 38 | if ctx.needs_input_grad[0]: 39 | (sqrtm,) = ctx.saved_tensors 40 | sqrtm = sqrtm.data.cpu().numpy().astype(np.float_) 41 | gm = grad_output.data.cpu().numpy().astype(np.float_) 42 | 43 | # Given a positive semi-definite matrix X, 44 | # since X = X^{1/2}X^{1/2}, we can compute the gradient of the 45 | # matrix square root dX^{1/2} by solving the Sylvester equation: 46 | # dX = (d(X^{1/2})X^{1/2} + X^{1/2}(dX^{1/2}). 
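# (Note: scipy.linalg.solve_sylvester(a, b, q) solves a @ z + z @ b = q, so the
#  call below solves X^{1/2} @ z + z @ X^{1/2} = grad_output for z, which is the
#  gradient of the loss with respect to the input matrix.)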
47 | grad_sqrtm = scipy.linalg.solve_sylvester(sqrtm, sqrtm, gm) 48 | 49 | grad_input = torch.from_numpy(grad_sqrtm).to(grad_output) 50 | return grad_input 51 | 52 | 53 | sqrtm = MatrixSquareRoot.apply 54 | 55 | 56 | class NoTrainInceptionV3(FeatureExtractorInceptionV3): 57 | def __init__( 58 | self, 59 | name: str, 60 | features_list: List[str], 61 | feature_extractor_weights_path: Optional[str] = None, 62 | ) -> None: 63 | super().__init__(name, features_list, feature_extractor_weights_path) 64 | self.eval() 65 | self.num_out_feat = int(features_list[0]) 66 | 67 | def train(self, mode): 68 | return super().train(False) 69 | 70 | def forward(self, x: Tensor) -> Tensor: 71 | x = remap_image_torch(x) 72 | out = super().forward(x) 73 | return out[0].reshape(x.shape[0], -1) 74 | 75 | 76 | def compute_fid(mu_1, mu_2, sigma_1, sigma_2, eps=1e-6): 77 | mean_diff = mu_1 - mu_2 78 | 79 | mean_dist = mean_diff.dot(mean_diff) 80 | 81 | covmean = sqrtm(sigma_1.mm(sigma_2)) 82 | 83 | if not torch.isfinite(covmean).all(): 84 | offset = torch.eye(sigma_1.size(0), device=mu_1.device, dtype=mu_1.dtype) * eps 85 | covmean = sqrtm((sigma_1 + offset).mm(sigma_2 + offset)) 86 | return ( 87 | mean_dist 88 | + torch.trace(sigma_1) 89 | + torch.trace(sigma_2) 90 | - 2 * torch.trace(covmean) 91 | ) 92 | 93 | 94 | class FID(Metric): 95 | def __init__( 96 | self, 97 | feature_extractor, 98 | real_features_path, 99 | compute_on_step: bool = False, 100 | dist_sync_on_step: bool = False, 101 | process_group: Optional[Any] = None, 102 | dist_sync_fn: Callable = None, 103 | ): 104 | 105 | super().__init__( 106 | compute_on_step=compute_on_step, 107 | dist_sync_on_step=dist_sync_on_step, 108 | process_group=process_group, 109 | dist_sync_fn=dist_sync_fn, 110 | ) 111 | 112 | self.feature_extractor = feature_extractor 113 | 114 | mean_real = torch.load(f"{real_features_path}/mean.pt") 115 | sigma_real = torch.load(f"{real_features_path}/sigma.pt") 116 | 117 | self.add_state("mean_real", default=mean_real, dist_reduce_fx="mean") 118 | self.add_state("sigma_real", default=sigma_real, dist_reduce_fx="mean") 119 | 120 | self.add_state( 121 | "generated_features_sum", 122 | torch.zeros(self.feature_extractor.num_out_feat, dtype=torch.double), 123 | dist_reduce_fx="sum", 124 | ) 125 | self.add_state( 126 | "generated_features_cov_sum", 127 | torch.zeros( 128 | ( 129 | self.feature_extractor.num_out_feat, 130 | self.feature_extractor.num_out_feat, 131 | ), 132 | dtype=torch.double, 133 | ), 134 | dist_reduce_fx="sum", 135 | ) 136 | self.add_state( 137 | "generated_features_num_samples", 138 | torch.tensor(0).long(), 139 | dist_reduce_fx="sum", 140 | ) 141 | 142 | def update(self, images): 143 | features = self.feature_extractor(images).double() 144 | self.generated_features_sum += features.sum(dim=0) 145 | self.generated_features_cov_sum += features.t().mm(features) 146 | self.generated_features_num_samples += images.shape[0] 147 | 148 | def compute(self): 149 | mean_real = self.mean_real 150 | mean_generated = ( 151 | self.generated_features_sum / self.generated_features_num_samples 152 | ).unsqueeze(dim=0) 153 | 154 | sigma_real = self.sigma_real 155 | sigma_generated = ( 156 | self.generated_features_cov_sum 157 | - self.generated_features_num_samples 158 | * mean_generated.t().mm(mean_generated) 159 | ) / (self.generated_features_num_samples - 1) 160 | return compute_fid(mean_real, mean_generated[0], sigma_real, sigma_generated) 161 | 162 | 163 | def compute_fid_features(dataloader, save_path, device="cpu"): 164 | 
feature_extractor = NoTrainInceptionV3( 165 | name="inception-v3-compat", features_list=[str(2048)] 166 | ).to(device) 167 | real_features_sum = torch.zeros( 168 | feature_extractor.num_out_feat, dtype=torch.double, device=device 169 | ) 170 | real_features_cov_sum = torch.zeros( 171 | ( 172 | feature_extractor.num_out_feat, 173 | feature_extractor.num_out_feat, 174 | ), 175 | dtype=torch.double, 176 | device=device, 177 | ) 178 | real_features_num_samples = 0 179 | with torch.no_grad(): 180 | for i, batch in enumerate(tqdm(dataloader)): 181 | images, _ = batch 182 | images = images.to(device) 183 | features = feature_extractor(images).double() 184 | real_features_num_samples += features.shape[0] 185 | real_features_sum += features.sum(dim=0) 186 | real_features_cov_sum += features.t().mm(features) 187 | mean = (real_features_sum / real_features_num_samples).unsqueeze(0) 188 | sigma = (real_features_cov_sum - real_features_num_samples * mean.t().mm(mean)) / ( 189 | real_features_num_samples - 1 190 | ) 191 | mean = mean.squeeze(0) 192 | torch.save(mean.cpu(), Path(save_path) / Path("mean.pt")) 193 | torch.save(sigma.cpu(), Path(save_path) / Path("sigma.pt")) 194 | -------------------------------------------------------------------------------- /metrics/models/fastreid/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | 8 | __version__ = "0.1.0" -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .activation import * 8 | from .arc_softmax import ArcSoftmax 9 | from .circle_softmax import CircleSoftmax 10 | from .am_softmax import AMSoftmax 11 | from .batch_drop import BatchDrop 12 | from .batch_norm import * 13 | from .context_block import ContextBlock 14 | from .frn import FRN, TLU 15 | from .non_local import Non_local 16 | from .pooling import * 17 | from .se_layer import SELayer 18 | from .splat import SplAtConv2d 19 | from .gather_layer import GatherLayer 20 | -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/activation.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: xingyu liao 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | __all__ = [ 14 | 'Mish', 15 | 'Swish', 16 | 'MemoryEfficientSwish', 17 | 'GELU'] 18 | 19 | 20 | class Mish(nn.Module): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | # inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!) 
26 | return x * (torch.tanh(F.softplus(x))) 27 | 28 | 29 | class Swish(nn.Module): 30 | def forward(self, x): 31 | return x * torch.sigmoid(x) 32 | 33 | 34 | class SwishImplementation(torch.autograd.Function): 35 | @staticmethod 36 | def forward(ctx, i): 37 | result = i * torch.sigmoid(i) 38 | ctx.save_for_backward(i) 39 | return result 40 | 41 | @staticmethod 42 | def backward(ctx, grad_output): 43 | i = ctx.saved_variables[0] 44 | sigmoid_i = torch.sigmoid(i) 45 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) 46 | 47 | 48 | class MemoryEfficientSwish(nn.Module): 49 | def forward(self, x): 50 | return SwishImplementation.apply(x) 51 | 52 | 53 | class GELU(nn.Module): 54 | """ 55 | Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU 56 | """ 57 | 58 | def forward(self, x): 59 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 60 | -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/am_softmax.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: xingyu liao 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from torch.nn import Parameter 11 | 12 | 13 | class AMSoftmax(nn.Module): 14 | r"""Implement of large margin cosine distance: 15 | Args: 16 | in_feat: size of each input sample 17 | num_classes: size of each output sample 18 | """ 19 | 20 | def __init__(self, cfg, in_feat, num_classes): 21 | super().__init__() 22 | self.in_features = in_feat 23 | self._num_classes = num_classes 24 | self.s = cfg.MODEL.HEADS.SCALE 25 | self.m = cfg.MODEL.HEADS.MARGIN 26 | self.weight = Parameter(torch.Tensor(num_classes, in_feat)) 27 | nn.init.xavier_uniform_(self.weight) 28 | 29 | def forward(self, features, targets): 30 | # --------------------------- cos(theta) & phi(theta) --------------------------- 31 | cosine = F.linear(F.normalize(features), F.normalize(self.weight)) 32 | phi = cosine - self.m 33 | # --------------------------- convert label to one-hot --------------------------- 34 | targets = F.one_hot(targets, num_classes=self._num_classes) 35 | output = (targets * phi) + ((1.0 - targets) * cosine) 36 | output *= self.s 37 | 38 | return output 39 | 40 | def extra_repr(self): 41 | return 'in_features={}, num_classes={}, scale={}, margin={}'.format( 42 | self.in_feat, self._num_classes, self.s, self.m 43 | ) 44 | -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/arc_softmax.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn import Parameter 13 | 14 | 15 | class ArcSoftmax(nn.Module): 16 | def __init__(self, cfg, in_feat, num_classes): 17 | super().__init__() 18 | self.in_feat = in_feat 19 | self._num_classes = num_classes 20 | self.s = cfg.MODEL.HEADS.SCALE 21 | self.m = cfg.MODEL.HEADS.MARGIN 22 | 23 | self.cos_m = math.cos(self.m) 24 | self.sin_m = math.sin(self.m) 25 | self.threshold = math.cos(math.pi - self.m) 26 | self.mm = math.sin(math.pi - self.m) * self.m 27 | 28 | self.weight = Parameter(torch.Tensor(num_classes, in_feat)) 29 | nn.init.xavier_uniform_(self.weight) 30 | 
self.register_buffer('t', torch.zeros(1)) 31 | 32 | def forward(self, features, targets): 33 | # get cos(theta) 34 | cos_theta = F.linear(F.normalize(features), F.normalize(self.weight)) 35 | cos_theta = cos_theta.clamp(-1, 1) # for numerical stability 36 | 37 | target_logit = cos_theta[torch.arange(0, features.size(0)), targets].view(-1, 1) 38 | 39 | sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2)) 40 | cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin) 41 | mask = cos_theta > cos_theta_m 42 | final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm) 43 | 44 | hard_example = cos_theta[mask] 45 | with torch.no_grad(): 46 | self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t 47 | cos_theta[mask] = hard_example * (self.t + hard_example) 48 | cos_theta.scatter_(1, targets.view(-1, 1).long(), final_target_logit) 49 | pred_class_logits = cos_theta * self.s 50 | return pred_class_logits 51 | 52 | def extra_repr(self): 53 | return 'in_features={}, num_classes={}, scale={}, margin={}'.format( 54 | self.in_feat, self._num_classes, self.s, self.m 55 | ) 56 | -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/batch_drop.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import random 8 | 9 | from torch import nn 10 | 11 | 12 | class BatchDrop(nn.Module): 13 | """ref: https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py 14 | batch drop mask 15 | """ 16 | 17 | def __init__(self, h_ratio, w_ratio): 18 | super(BatchDrop, self).__init__() 19 | self.h_ratio = h_ratio 20 | self.w_ratio = w_ratio 21 | 22 | def forward(self, x): 23 | if self.training: 24 | h, w = x.size()[-2:] 25 | rh = round(self.h_ratio * h) 26 | rw = round(self.w_ratio * w) 27 | sx = random.randint(0, h - rh) 28 | sy = random.randint(0, w - rw) 29 | mask = x.new_ones(x.size()) 30 | mask[:, :, sx:sx + rh, sy:sy + rw] = 0 31 | x = x * mask 32 | return x 33 | -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/batch_norm.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | __all__ = [ 14 | "BatchNorm", 15 | "IBN", 16 | "GhostBatchNorm", 17 | "FrozenBatchNorm", 18 | "SyncBatchNorm", 19 | "get_norm", 20 | ] 21 | 22 | 23 | class BatchNorm(nn.BatchNorm2d): 24 | def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0, 25 | bias_init=0.0, **kwargs): 26 | super().__init__(num_features, eps=eps, momentum=momentum) 27 | if weight_init is not None: nn.init.constant_(self.weight, weight_init) 28 | if bias_init is not None: nn.init.constant_(self.bias, bias_init) 29 | self.weight.requires_grad_(not weight_freeze) 30 | self.bias.requires_grad_(not bias_freeze) 31 | 32 | 33 | class SyncBatchNorm(nn.SyncBatchNorm): 34 | def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0, 35 | bias_init=0.0): 36 | super().__init__(num_features, eps=eps, momentum=momentum) 37 | if weight_init is not None: 
nn.init.constant_(self.weight, weight_init) 38 | if bias_init is not None: nn.init.constant_(self.bias, bias_init) 39 | self.weight.requires_grad_(not weight_freeze) 40 | self.bias.requires_grad_(not bias_freeze) 41 | 42 | 43 | class IBN(nn.Module): 44 | def __init__(self, planes, bn_norm, **kwargs): 45 | super(IBN, self).__init__() 46 | half1 = int(planes / 2) 47 | self.half = half1 48 | half2 = planes - half1 49 | self.IN = nn.InstanceNorm2d(half1, affine=True) 50 | self.BN = get_norm(bn_norm, half2, **kwargs) 51 | 52 | def forward(self, x): 53 | split = torch.split(x, self.half, 1) 54 | out1 = self.IN(split[0].contiguous()) 55 | out2 = self.BN(split[1].contiguous()) 56 | out = torch.cat((out1, out2), 1) 57 | return out 58 | 59 | 60 | class GhostBatchNorm(BatchNorm): 61 | def __init__(self, num_features, num_splits=1, **kwargs): 62 | super().__init__(num_features, **kwargs) 63 | self.num_splits = num_splits 64 | self.register_buffer('running_mean', torch.zeros(num_features)) 65 | self.register_buffer('running_var', torch.ones(num_features)) 66 | 67 | def forward(self, input): 68 | N, C, H, W = input.shape 69 | if self.training or not self.track_running_stats: 70 | self.running_mean = self.running_mean.repeat(self.num_splits) 71 | self.running_var = self.running_var.repeat(self.num_splits) 72 | outputs = F.batch_norm( 73 | input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var, 74 | self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits), 75 | True, self.momentum, self.eps).view(N, C, H, W) 76 | self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0) 77 | self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0) 78 | return outputs 79 | else: 80 | return F.batch_norm( 81 | input, self.running_mean, self.running_var, 82 | self.weight, self.bias, False, self.momentum, self.eps) 83 | 84 | 85 | class FrozenBatchNorm(BatchNorm): 86 | """ 87 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 88 | It contains non-trainable buffers called 89 | "weight" and "bias", "running_mean", "running_var", 90 | initialized to perform identity transformation. 91 | The pre-trained backbone models from Caffe2 only contain "weight" and "bias", 92 | which are computed from the original four parameters of BN. 93 | The affine transform `x * weight + bias` will perform the equivalent 94 | computation of `(x - running_mean) / sqrt(running_var) * weight + bias`. 95 | When loading a backbone model from Caffe2, "running_mean" and "running_var" 96 | will be left unchanged as identity transformation. 97 | Other pre-trained backbone models may contain all 4 parameters. 98 | The forward is implemented by `F.batch_norm(..., training=False)`. 99 | """ 100 | 101 | _version = 3 102 | 103 | def __init__(self, num_features, eps=1e-5, **kwargs): 104 | super().__init__(num_features, weight_freeze=True, bias_freeze=True, **kwargs) 105 | self.num_features = num_features 106 | self.eps = eps 107 | 108 | def forward(self, x): 109 | if x.requires_grad: 110 | # When gradients are needed, F.batch_norm will use extra memory 111 | # because its backward op computes gradients for weight/bias as well. 
112 | scale = self.weight * (self.running_var + self.eps).rsqrt() 113 | bias = self.bias - self.running_mean * scale 114 | scale = scale.reshape(1, -1, 1, 1) 115 | bias = bias.reshape(1, -1, 1, 1) 116 | return x * scale + bias 117 | else: 118 | # When gradients are not needed, F.batch_norm is a single fused op 119 | # and provide more optimization opportunities. 120 | return F.batch_norm( 121 | x, 122 | self.running_mean, 123 | self.running_var, 124 | self.weight, 125 | self.bias, 126 | training=False, 127 | eps=self.eps, 128 | ) 129 | 130 | def _load_from_state_dict( 131 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 132 | ): 133 | version = local_metadata.get("version", None) 134 | 135 | if version is None or version < 2: 136 | # No running_mean/var in early versions 137 | # This will silent the warnings 138 | if prefix + "running_mean" not in state_dict: 139 | state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean) 140 | if prefix + "running_var" not in state_dict: 141 | state_dict[prefix + "running_var"] = torch.ones_like(self.running_var) 142 | 143 | if version is not None and version < 3: 144 | logger = logging.getLogger(__name__) 145 | logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip("."))) 146 | # In version < 3, running_var are used without +eps. 147 | state_dict[prefix + "running_var"] -= self.eps 148 | 149 | super()._load_from_state_dict( 150 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 151 | ) 152 | 153 | def __repr__(self): 154 | return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps) 155 | 156 | @classmethod 157 | def convert_frozen_batchnorm(cls, module): 158 | """ 159 | Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm. 160 | Args: 161 | module (torch.nn.Module): 162 | Returns: 163 | If module is BatchNorm/SyncBatchNorm, returns a new module. 164 | Otherwise, in-place convert module and return it. 
165 | Similar to convert_sync_batchnorm in 166 | https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py 167 | """ 168 | bn_module = nn.modules.batchnorm 169 | bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm) 170 | res = module 171 | if isinstance(module, bn_module): 172 | res = cls(module.num_features) 173 | if module.affine: 174 | res.weight.data = module.weight.data.clone().detach() 175 | res.bias.data = module.bias.data.clone().detach() 176 | res.running_mean.data = module.running_mean.data 177 | res.running_var.data = module.running_var.data 178 | res.eps = module.eps 179 | else: 180 | for name, child in module.named_children(): 181 | new_child = cls.convert_frozen_batchnorm(child) 182 | if new_child is not child: 183 | res.add_module(name, new_child) 184 | return res 185 | 186 | 187 | def get_norm(norm, out_channels, **kwargs): 188 | """ 189 | Args: 190 | norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN; 191 | or a callable that thakes a channel number and returns 192 | the normalization layer as a nn.Module 193 | out_channels: number of channels for normalization layer 194 | 195 | Returns: 196 | nn.Module or None: the normalization layer 197 | """ 198 | if isinstance(norm, str): 199 | if len(norm) == 0: 200 | return None 201 | norm = { 202 | "BN": BatchNorm, 203 | "GhostBN": GhostBatchNorm, 204 | "FrozenBN": FrozenBatchNorm, 205 | "GN": lambda channels, **args: nn.GroupNorm(32, channels), 206 | "syncBN": SyncBatchNorm, 207 | }[norm] 208 | return norm(out_channels, **kwargs) 209 | -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/circle_softmax.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn import Parameter 13 | 14 | 15 | class CircleSoftmax(nn.Module): 16 | def __init__(self, cfg, in_feat, num_classes): 17 | super().__init__() 18 | self.in_feat = in_feat 19 | self._num_classes = num_classes 20 | self.s = cfg.MODEL.HEADS.SCALE 21 | self.m = cfg.MODEL.HEADS.MARGIN 22 | 23 | self.weight = Parameter(torch.Tensor(num_classes, in_feat)) 24 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 25 | 26 | def forward(self, features, targets): 27 | sim_mat = F.linear(F.normalize(features), F.normalize(self.weight)) 28 | alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.) 29 | alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.) 
30 | delta_p = 1 - self.m 31 | delta_n = self.m 32 | 33 | s_p = self.s * alpha_p * (sim_mat - delta_p) 34 | s_n = self.s * alpha_n * (sim_mat - delta_n) 35 | 36 | targets = F.one_hot(targets, num_classes=self._num_classes) 37 | 38 | pred_class_logits = targets * s_p + (1.0 - targets) * s_n 39 | 40 | return pred_class_logits 41 | 42 | def extra_repr(self): 43 | return 'in_features={}, num_classes={}, scale={}, margin={}'.format( 44 | self.in_feat, self._num_classes, self.s, self.m 45 | ) 46 | -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/context_block.py: -------------------------------------------------------------------------------- 1 | # copy from https://github.com/xvjiarui/GCNet/blob/master/mmdet/ops/gcb/context_block.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | __all__ = ['ContextBlock'] 7 | 8 | 9 | def last_zero_init(m): 10 | if isinstance(m, nn.Sequential): 11 | nn.init.constant_(m[-1].weight, val=0) 12 | if hasattr(m[-1], 'bias') and m[-1].bias is not None: 13 | nn.init.constant_(m[-1].bias, 0) 14 | else: 15 | nn.init.constant_(m.weight, val=0) 16 | if hasattr(m, 'bias') and m.bias is not None: 17 | nn.init.constant_(m.bias, 0) 18 | 19 | 20 | class ContextBlock(nn.Module): 21 | 22 | def __init__(self, 23 | inplanes, 24 | ratio, 25 | pooling_type='att', 26 | fusion_types=('channel_add',)): 27 | super(ContextBlock, self).__init__() 28 | assert pooling_type in ['avg', 'att'] 29 | assert isinstance(fusion_types, (list, tuple)) 30 | valid_fusion_types = ['channel_add', 'channel_mul'] 31 | assert all([f in valid_fusion_types for f in fusion_types]) 32 | assert len(fusion_types) > 0, 'at least one fusion should be used' 33 | self.inplanes = inplanes 34 | self.ratio = ratio 35 | self.planes = int(inplanes * ratio) 36 | self.pooling_type = pooling_type 37 | self.fusion_types = fusion_types 38 | if pooling_type == 'att': 39 | self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1) 40 | self.softmax = nn.Softmax(dim=2) 41 | else: 42 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 43 | if 'channel_add' in fusion_types: 44 | self.channel_add_conv = nn.Sequential( 45 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 46 | nn.LayerNorm([self.planes, 1, 1]), 47 | nn.ReLU(inplace=True), # yapf: disable 48 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)) 49 | else: 50 | self.channel_add_conv = None 51 | if 'channel_mul' in fusion_types: 52 | self.channel_mul_conv = nn.Sequential( 53 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 54 | nn.LayerNorm([self.planes, 1, 1]), 55 | nn.ReLU(inplace=True), # yapf: disable 56 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)) 57 | else: 58 | self.channel_mul_conv = None 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self): 62 | if self.pooling_type == 'att': 63 | nn.init.kaiming_normal_(self.conv_mask.weight, a=0, mode='fan_in', nonlinearity='relu') 64 | if hasattr(self.conv_mask, 'bias') and self.conv_mask.bias is not None: 65 | nn.init.constant_(self.conv_mask.bias, 0) 66 | self.conv_mask.inited = True 67 | 68 | if self.channel_add_conv is not None: 69 | last_zero_init(self.channel_add_conv) 70 | if self.channel_mul_conv is not None: 71 | last_zero_init(self.channel_mul_conv) 72 | 73 | def spatial_pool(self, x): 74 | batch, channel, height, width = x.size() 75 | if self.pooling_type == 'att': 76 | input_x = x 77 | # [N, C, H * W] 78 | input_x = input_x.view(batch, channel, height * width) 79 | # [N, 1, C, H * W] 80 | input_x = input_x.unsqueeze(1) 81 
| # [N, 1, H, W] 82 | context_mask = self.conv_mask(x) 83 | # [N, 1, H * W] 84 | context_mask = context_mask.view(batch, 1, height * width) 85 | # [N, 1, H * W] 86 | context_mask = self.softmax(context_mask) 87 | # [N, 1, H * W, 1] 88 | context_mask = context_mask.unsqueeze(-1) 89 | # [N, 1, C, 1] 90 | context = torch.matmul(input_x, context_mask) 91 | # [N, C, 1, 1] 92 | context = context.view(batch, channel, 1, 1) 93 | else: 94 | # [N, C, 1, 1] 95 | context = self.avg_pool(x) 96 | 97 | return context 98 | 99 | def forward(self, x): 100 | # [N, C, 1, 1] 101 | context = self.spatial_pool(x) 102 | 103 | out = x 104 | if self.channel_mul_conv is not None: 105 | # [N, C, 1, 1] 106 | channel_mul_term = torch.sigmoid(self.channel_mul_conv(context)) 107 | out = out * channel_mul_term 108 | if self.channel_add_conv is not None: 109 | # [N, C, 1, 1] 110 | channel_add_term = self.channel_add_conv(context) 111 | out = out + channel_add_term 112 | 113 | return out 114 | -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/frn.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn.modules.batchnorm import BatchNorm2d 10 | from torch.nn import ReLU, LeakyReLU 11 | from torch.nn.parameter import Parameter 12 | 13 | 14 | class TLU(nn.Module): 15 | def __init__(self, num_features): 16 | """max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau""" 17 | super(TLU, self).__init__() 18 | self.num_features = num_features 19 | self.tau = Parameter(torch.Tensor(num_features)) 20 | self.reset_parameters() 21 | 22 | def reset_parameters(self): 23 | nn.init.zeros_(self.tau) 24 | 25 | def extra_repr(self): 26 | return 'num_features={num_features}'.format(**self.__dict__) 27 | 28 | def forward(self, x): 29 | return torch.max(x, self.tau.view(1, self.num_features, 1, 1)) 30 | 31 | 32 | class FRN(nn.Module): 33 | def __init__(self, num_features, eps=1e-6, is_eps_leanable=False): 34 | """ 35 | weight = gamma, bias = beta 36 | beta, gamma: 37 | Variables of shape [1, 1, 1, C]. if TensorFlow 38 | Variables of shape [1, C, 1, 1]. if PyTorch 39 | eps: A scalar constant or learnable variable. 40 | """ 41 | super(FRN, self).__init__() 42 | 43 | self.num_features = num_features 44 | self.init_eps = eps 45 | self.is_eps_leanable = is_eps_leanable 46 | 47 | self.weight = Parameter(torch.Tensor(num_features)) 48 | self.bias = Parameter(torch.Tensor(num_features)) 49 | if is_eps_leanable: 50 | self.eps = Parameter(torch.Tensor(1)) 51 | else: 52 | self.register_buffer('eps', torch.Tensor([eps])) 53 | self.reset_parameters() 54 | 55 | def reset_parameters(self): 56 | nn.init.ones_(self.weight) 57 | nn.init.zeros_(self.bias) 58 | if self.is_eps_leanable: 59 | nn.init.constant_(self.eps, self.init_eps) 60 | 61 | def extra_repr(self): 62 | return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__) 63 | 64 | def forward(self, x): 65 | """ 66 | 0, 1, 2, 3 -> (B, H, W, C) in TensorFlow 67 | 0, 1, 2, 3 -> (B, C, H, W) in PyTorch 68 | TensorFlow code 69 | nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True) 70 | x = x * tf.rsqrt(nu2 + tf.abs(eps)) 71 | # This Code include TLU function max(y, tau) 72 | return tf.maximum(gamma * x + beta, tau) 73 | """ 74 | # Compute the mean norm of activations per channel. 
75 | nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True) 76 | 77 | # Perform FRN. 78 | x = x * torch.rsqrt(nu2 + self.eps.abs()) 79 | 80 | # Scale and Bias 81 | x = self.weight.view(1, self.num_features, 1, 1) * x + self.bias.view(1, self.num_features, 1, 1) 82 | # x = self.weight * x + self.bias 83 | return x 84 | 85 | 86 | def bnrelu_to_frn(module): 87 | """ 88 | Convert 'BatchNorm2d + ReLU' to 'FRN + TLU' 89 | """ 90 | mod = module 91 | before_name = None 92 | before_child = None 93 | is_before_bn = False 94 | 95 | for name, child in module.named_children(): 96 | if is_before_bn and isinstance(child, (ReLU, LeakyReLU)): 97 | # Convert BN to FRN 98 | if isinstance(before_child, BatchNorm2d): 99 | mod.add_module( 100 | before_name, FRN(num_features=before_child.num_features)) 101 | else: 102 | raise NotImplementedError() 103 | 104 | # Convert ReLU to TLU 105 | mod.add_module(name, TLU(num_features=before_child.num_features)) 106 | else: 107 | mod.add_module(name, bnrelu_to_frn(child)) 108 | 109 | before_name = name 110 | before_child = child 111 | is_before_bn = isinstance(child, BatchNorm2d) 112 | return mod 113 | 114 | 115 | def convert(module, flag_name): 116 | mod = module 117 | before_ch = None 118 | for name, child in module.named_children(): 119 | if hasattr(child, flag_name) and getattr(child, flag_name): 120 | if isinstance(child, BatchNorm2d): 121 | before_ch = child.num_features 122 | mod.add_module(name, FRN(num_features=child.num_features)) 123 | # TODO bn is no good... 124 | if isinstance(child, (ReLU, LeakyReLU)): 125 | mod.add_module(name, TLU(num_features=before_ch)) 126 | else: 127 | mod.add_module(name, convert(child, flag_name)) 128 | return mod 129 | 130 | 131 | def remove_flags(module, flag_name): 132 | mod = module 133 | for name, child in module.named_children(): 134 | if hasattr(child, 'is_convert_frn'): 135 | delattr(child, flag_name) 136 | mod.add_module(name, remove_flags(child, flag_name)) 137 | else: 138 | mod.add_module(name, remove_flags(child, flag_name)) 139 | return mod 140 | 141 | 142 | def bnrelu_to_frn2(model, input_size=(3, 128, 128), batch_size=2, flag_name='is_convert_frn'): 143 | forard_hooks = list() 144 | backward_hooks = list() 145 | 146 | is_before_bn = [False] 147 | 148 | def register_forward_hook(module): 149 | def hook(self, input, output): 150 | if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model): 151 | is_before_bn.append(False) 152 | return 153 | 154 | # input and output is required in hook def 155 | is_converted = is_before_bn[-1] and isinstance(self, (ReLU, LeakyReLU)) 156 | if is_converted: 157 | setattr(self, flag_name, True) 158 | is_before_bn.append(isinstance(self, BatchNorm2d)) 159 | 160 | forard_hooks.append(module.register_forward_hook(hook)) 161 | 162 | is_before_relu = [False] 163 | 164 | def register_backward_hook(module): 165 | def hook(self, input, output): 166 | if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model): 167 | is_before_relu.append(False) 168 | return 169 | is_converted = is_before_relu[-1] and isinstance(self, BatchNorm2d) 170 | if is_converted: 171 | setattr(self, flag_name, True) 172 | is_before_relu.append(isinstance(self, (ReLU, LeakyReLU))) 173 | 174 | backward_hooks.append(module.register_backward_hook(hook)) 175 | 176 | # multiple inputs to the network 177 | if isinstance(input_size, tuple): 178 | input_size = [input_size] 179 | 180 | # batch_size of 2 for batchnorm 181 | x = [torch.rand(batch_size, *in_size) for in_size in input_size] 182 | 183 | # register hook 
184 | model.apply(register_forward_hook) 185 | model.apply(register_backward_hook) 186 | 187 | # make a forward pass 188 | output = model(*x) 189 | output.sum().backward() # Raw output is not enabled to use backward() 190 | 191 | # remove these hooks 192 | for h in forard_hooks: 193 | h.remove() 194 | for h in backward_hooks: 195 | h.remove() 196 | 197 | model = convert(model, flag_name=flag_name) 198 | model = remove_flags(model, flag_name=flag_name) 199 | return model 200 | -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/gather_layer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: xingyu liao 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | # based on: https://github.com/open-mmlab/OpenSelfSup/blob/master/openselfsup/models/utils/gather_layer.py 8 | 9 | import torch 10 | import torch.distributed as dist 11 | 12 | 13 | class GatherLayer(torch.autograd.Function): 14 | """Gather tensors from all process, supporting backward propagation. 15 | """ 16 | 17 | @staticmethod 18 | def forward(ctx, input): 19 | ctx.save_for_backward(input) 20 | output = [torch.zeros_like(input) \ 21 | for _ in range(dist.get_world_size())] 22 | dist.all_gather(output, input) 23 | return tuple(output) 24 | 25 | @staticmethod 26 | def backward(ctx, *grads): 27 | input, = ctx.saved_tensors 28 | grad_out = torch.zeros_like(input) 29 | grad_out[:] = grads[dist.get_rank()] 30 | return grad_out 31 | -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/non_local.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | 4 | import torch 5 | from torch import nn 6 | from .batch_norm import get_norm 7 | 8 | 9 | class Non_local(nn.Module): 10 | def __init__(self, in_channels, bn_norm, reduc_ratio=2): 11 | super(Non_local, self).__init__() 12 | 13 | self.in_channels = in_channels 14 | self.inter_channels = in_channels // reduc_ratio 15 | 16 | self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 17 | kernel_size=1, stride=1, padding=0) 18 | 19 | self.W = nn.Sequential( 20 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 21 | kernel_size=1, stride=1, padding=0), 22 | get_norm(bn_norm, self.in_channels), 23 | ) 24 | nn.init.constant_(self.W[1].weight, 0.0) 25 | nn.init.constant_(self.W[1].bias, 0.0) 26 | 27 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 28 | kernel_size=1, stride=1, padding=0) 29 | 30 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 31 | kernel_size=1, stride=1, padding=0) 32 | 33 | def forward(self, x): 34 | """ 35 | :param x: (b, t, h, w) 36 | :return x: (b, t, h, w) 37 | """ 38 | batch_size = x.size(0) 39 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 40 | g_x = g_x.permute(0, 2, 1) 41 | 42 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 43 | theta_x = theta_x.permute(0, 2, 1) 44 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 45 | f = torch.matmul(theta_x, phi_x) 46 | N = f.size(-1) 47 | f_div_C = f / N 48 | 49 | y = torch.matmul(f_div_C, g_x) 50 | y = y.permute(0, 2, 1).contiguous() 51 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 52 | W_y = self.W(y) 53 | z = W_y + x 54 | return z 55 | 
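> Note: a quick way to sanity-check the fastreid layers above (for example the `Non_local` block, which wraps an embedded-Gaussian non-local operation and a `get_norm`-built normalization layer around a residual connection) is a shape-preserving smoke test. This is only an illustrative sketch, assuming the repo root is on `PYTHONPATH`; the channel count, batch size and feature-map size below are arbitrary.

```python
import torch
from metrics.models.fastreid.layers import Non_local, get_norm  # same imports used by metrics/models/fastreid/model.py

# get_norm maps a string to a normalization layer, e.g. "BN" -> BatchNorm, "GN" -> GroupNorm(32, C)
bn = get_norm("BN", 256)

# Non_local(in_channels, bn_norm) preserves the input shape: the output is W(y) + x (residual)
block = Non_local(256, "BN")
x = torch.randn(2, 256, 16, 16)
z = block(x)
assert z.shape == x.shape
```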
-------------------------------------------------------------------------------- /metrics/models/fastreid/layers/pooling.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: l1aoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | __all__ = ["Flatten", 12 | "GeneralizedMeanPooling", 13 | "GeneralizedMeanPoolingP", 14 | "FastGlobalAvgPool2d", 15 | "AdaptiveAvgMaxPool2d", 16 | "ClipGlobalAvgPool2d", 17 | ] 18 | 19 | 20 | class Flatten(nn.Module): 21 | def forward(self, input): 22 | return input.view(input.size(0), -1) 23 | 24 | 25 | class GeneralizedMeanPooling(nn.Module): 26 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. 27 | The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` 28 | - At p = infinity, one gets Max Pooling 29 | - At p = 1, one gets Average Pooling 30 | The output is of size H x W, for any input size. 31 | The number of output features is equal to the number of input planes. 32 | Args: 33 | output_size: the target output size of the image of the form H x W. 34 | Can be a tuple (H, W) or a single H for a square image H x H 35 | H and W can be either a ``int``, or ``None`` which means the size will 36 | be the same as that of the input. 37 | """ 38 | 39 | def __init__(self, norm=3, output_size=1, eps=1e-6): 40 | super(GeneralizedMeanPooling, self).__init__() 41 | assert norm > 0 42 | self.p = float(norm) 43 | self.output_size = output_size 44 | self.eps = eps 45 | 46 | def forward(self, x): 47 | x = x.clamp(min=self.eps).pow(self.p) 48 | return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p) 49 | 50 | def __repr__(self): 51 | return self.__class__.__name__ + '(' \ 52 | + str(self.p) + ', ' \ 53 | + 'output_size=' + str(self.output_size) + ')' 54 | 55 | 56 | class GeneralizedMeanPoolingP(GeneralizedMeanPooling): 57 | """ Same, but norm is trainable 58 | """ 59 | 60 | def __init__(self, norm=3, output_size=1, eps=1e-6): 61 | super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) 62 | self.p = nn.Parameter(torch.ones(1) * norm) 63 | 64 | 65 | class AdaptiveAvgMaxPool2d(nn.Module): 66 | def __init__(self): 67 | super(AdaptiveAvgMaxPool2d, self).__init__() 68 | self.gap = FastGlobalAvgPool2d() 69 | self.gmp = nn.AdaptiveMaxPool2d(1) 70 | 71 | def forward(self, x): 72 | avg_feat = self.gap(x) 73 | max_feat = self.gmp(x) 74 | feat = avg_feat + max_feat 75 | return feat 76 | 77 | 78 | class FastGlobalAvgPool2d(nn.Module): 79 | def __init__(self, flatten=False): 80 | super(FastGlobalAvgPool2d, self).__init__() 81 | self.flatten = flatten 82 | 83 | def forward(self, x): 84 | if self.flatten: 85 | in_size = x.size() 86 | return x.view((in_size[0], in_size[1], -1)).mean(dim=2) 87 | else: 88 | return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) 89 | 90 | 91 | class ClipGlobalAvgPool2d(nn.Module): 92 | def __init__(self): 93 | super().__init__() 94 | self.avgpool = FastGlobalAvgPool2d() 95 | 96 | def forward(self, x): 97 | x = self.avgpool(x) 98 | x = torch.clamp(x, min=0., max=1.) 
99 | return x 100 | -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/se_layer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from torch import nn 8 | 9 | 10 | class SELayer(nn.Module): 11 | def __init__(self, channel, reduction=16): 12 | super(SELayer, self).__init__() 13 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 14 | self.fc = nn.Sequential( 15 | nn.Linear(channel, int(channel / reduction), bias=False), 16 | nn.ReLU(inplace=True), 17 | nn.Linear(int(channel / reduction), channel, bias=False), 18 | nn.Sigmoid() 19 | ) 20 | 21 | def forward(self, x): 22 | b, c, _, _ = x.size() 23 | y = self.avg_pool(x).view(b, c) 24 | y = self.fc(y).view(b, c, 1, 1) 25 | return x * y.expand_as(x) 26 | -------------------------------------------------------------------------------- /metrics/models/fastreid/layers/splat.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: xingyu liao 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | from torch.nn import Conv2d, ReLU 11 | from torch.nn.modules.utils import _pair 12 | from metrics.models.fastreid.layers import get_norm 13 | 14 | 15 | class SplAtConv2d(nn.Module): 16 | """Split-Attention Conv2d 17 | """ 18 | 19 | def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0), 20 | dilation=(1, 1), groups=1, bias=True, 21 | radix=2, reduction_factor=4, 22 | rectify=False, rectify_avg=False, norm_layer=None, num_splits=1, 23 | dropblock_prob=0.0, **kwargs): 24 | super(SplAtConv2d, self).__init__() 25 | padding = _pair(padding) 26 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 27 | self.rectify_avg = rectify_avg 28 | inter_channels = max(in_channels * radix // reduction_factor, 32) 29 | self.radix = radix 30 | self.cardinality = groups 31 | self.channels = channels 32 | self.dropblock_prob = dropblock_prob 33 | if self.rectify: 34 | from rfconv import RFConv2d 35 | self.conv = RFConv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation, 36 | groups=groups * radix, bias=bias, average_mode=rectify_avg, **kwargs) 37 | else: 38 | self.conv = Conv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation, 39 | groups=groups * radix, bias=bias, **kwargs) 40 | self.use_bn = norm_layer is not None 41 | if self.use_bn: 42 | self.bn0 = get_norm(norm_layer, channels * radix) 43 | self.relu = ReLU(inplace=True) 44 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 45 | if self.use_bn: 46 | self.bn1 = get_norm(norm_layer, inter_channels) 47 | self.fc2 = Conv2d(inter_channels, channels * radix, 1, groups=self.cardinality) 48 | 49 | self.rsoftmax = rSoftMax(radix, groups) 50 | 51 | def forward(self, x): 52 | x = self.conv(x) 53 | if self.use_bn: 54 | x = self.bn0(x) 55 | if self.dropblock_prob > 0.0: 56 | x = self.dropblock(x) 57 | x = self.relu(x) 58 | 59 | batch, rchannel = x.shape[:2] 60 | if self.radix > 1: 61 | splited = torch.split(x, rchannel // self.radix, dim=1) 62 | gap = sum(splited) 63 | else: 64 | gap = x 65 | gap = F.adaptive_avg_pool2d(gap, 1) 66 | gap = self.fc1(gap) 67 | 68 | if self.use_bn: 69 | gap = self.bn1(gap) 70 | gap = self.relu(gap) 71 | 72 | atten = self.fc2(gap) 73 | atten = 
self.rsoftmax(atten).view(batch, -1, 1, 1) 74 | 75 | if self.radix > 1: 76 | attens = torch.split(atten, rchannel // self.radix, dim=1) 77 | out = sum([att * split for (att, split) in zip(attens, splited)]) 78 | else: 79 | out = atten * x 80 | return out.contiguous() 81 | 82 | 83 | class rSoftMax(nn.Module): 84 | def __init__(self, radix, cardinality): 85 | super().__init__() 86 | self.radix = radix 87 | self.cardinality = cardinality 88 | 89 | def forward(self, x): 90 | batch = x.size(0) 91 | if self.radix > 1: 92 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 93 | x = F.softmax(x, dim=1) 94 | x = x.reshape(batch, -1) 95 | else: 96 | x = torch.sigmoid(x) 97 | return x 98 | -------------------------------------------------------------------------------- /metrics/models/fastreid/model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | import os 9 | import math 10 | 11 | import torch 12 | from torch import nn 13 | 14 | from metrics.models.fastreid.layers import ( 15 | IBN, 16 | SELayer, 17 | Non_local, 18 | get_norm, 19 | ) 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | model_urls = { 24 | '18x': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 25 | '34x': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 26 | '50x': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 27 | '101x': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 28 | '152x': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 29 | 'ibn_18x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth', 30 | 'ibn_34x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth', 31 | 'ibn_50x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth', 32 | 'ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth', 33 | 'se_ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/se_resnet101_ibn_a-fabed4e2.pth', 34 | } 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion = 1 39 | 40 | def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False, 41 | stride=1, downsample=None, reduction=16): 42 | super(BasicBlock, self).__init__() 43 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 44 | if with_ibn: 45 | self.bn1 = IBN(planes, bn_norm) 46 | else: 47 | self.bn1 = get_norm(bn_norm, planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 49 | self.bn2 = get_norm(bn_norm, planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | if with_se: 52 | self.se = SELayer(planes, reduction) 53 | else: 54 | self.se = nn.Identity() 55 | self.downsample = downsample 56 | self.stride = stride 57 | 58 | def forward(self, x): 59 | identity = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv2(out) 66 | out = self.bn2(out) 67 | 68 | if self.downsample is not None: 69 | identity = self.downsample(x) 70 | 71 | out += identity 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class Bottleneck(nn.Module): 78 | expansion = 4 79 | 80 | def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False, 81 | stride=1, downsample=None, reduction=16): 82 | super(Bottleneck, 
self).__init__() 83 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 84 | if with_ibn: 85 | self.bn1 = IBN(planes, bn_norm) 86 | else: 87 | self.bn1 = get_norm(bn_norm, planes) 88 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 89 | padding=1, bias=False) 90 | self.bn2 = get_norm(bn_norm, planes) 91 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 92 | self.bn3 = get_norm(bn_norm, planes * self.expansion) 93 | self.relu = nn.ReLU(inplace=True) 94 | if with_se: 95 | self.se = SELayer(planes * self.expansion, reduction) 96 | else: 97 | self.se = nn.Identity() 98 | self.downsample = downsample 99 | self.stride = stride 100 | 101 | def forward(self, x): 102 | residual = x 103 | 104 | out = self.conv1(x) 105 | out = self.bn1(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv2(out) 109 | out = self.bn2(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv3(out) 113 | out = self.bn3(out) 114 | out = self.se(out) 115 | 116 | if self.downsample is not None: 117 | residual = self.downsample(x) 118 | 119 | out += residual 120 | out = self.relu(out) 121 | 122 | return out 123 | 124 | 125 | class ResNet(nn.Module): 126 | def __init__(self, last_stride, bn_norm, with_ibn, with_se, with_nl, block, layers, non_layers): 127 | self.inplanes = 64 128 | super().__init__() 129 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 130 | bias=False) 131 | self.bn1 = get_norm(bn_norm, 64) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 134 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) 135 | self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn, with_se) 136 | self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn, with_se) 137 | self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn, with_se) 138 | self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_se=with_se) 139 | 140 | self.random_init() 141 | 142 | # fmt: off 143 | if with_nl: self._build_nonlocal(layers, non_layers, bn_norm) 144 | else: self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = [] 145 | # fmt: on 146 | 147 | def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", with_ibn=False, with_se=False): 148 | downsample = None 149 | if stride != 1 or self.inplanes != planes * block.expansion: 150 | downsample = nn.Sequential( 151 | nn.Conv2d(self.inplanes, planes * block.expansion, 152 | kernel_size=1, stride=stride, bias=False), 153 | get_norm(bn_norm, planes * block.expansion), 154 | ) 155 | 156 | layers = [] 157 | layers.append(block(self.inplanes, planes, bn_norm, with_ibn, with_se, stride, downsample)) 158 | self.inplanes = planes * block.expansion 159 | for i in range(1, blocks): 160 | layers.append(block(self.inplanes, planes, bn_norm, with_ibn, with_se)) 161 | 162 | return nn.Sequential(*layers) 163 | 164 | def _build_nonlocal(self, layers, non_layers, bn_norm): 165 | self.NL_1 = nn.ModuleList( 166 | [Non_local(256, bn_norm) for _ in range(non_layers[0])]) 167 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) 168 | self.NL_2 = nn.ModuleList( 169 | [Non_local(512, bn_norm) for _ in range(non_layers[1])]) 170 | self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])]) 171 | self.NL_3 = nn.ModuleList( 172 | [Non_local(1024, bn_norm) for _ in range(non_layers[2])]) 173 | self.NL_3_idx = 
sorted([layers[2] - (i + 1) for i in range(non_layers[2])]) 174 | self.NL_4 = nn.ModuleList( 175 | [Non_local(2048, bn_norm) for _ in range(non_layers[3])]) 176 | self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])]) 177 | 178 | def forward(self, x): 179 | x = self.conv1(x) 180 | x = self.bn1(x) 181 | x = self.relu(x) 182 | x = self.maxpool(x) 183 | 184 | NL1_counter = 0 185 | if len(self.NL_1_idx) == 0: 186 | self.NL_1_idx = [-1] 187 | for i in range(len(self.layer1)): 188 | x = self.layer1[i](x) 189 | if i == self.NL_1_idx[NL1_counter]: 190 | _, C, H, W = x.shape 191 | x = self.NL_1[NL1_counter](x) 192 | NL1_counter += 1 193 | # Layer 2 194 | NL2_counter = 0 195 | if len(self.NL_2_idx) == 0: 196 | self.NL_2_idx = [-1] 197 | for i in range(len(self.layer2)): 198 | x = self.layer2[i](x) 199 | if i == self.NL_2_idx[NL2_counter]: 200 | _, C, H, W = x.shape 201 | x = self.NL_2[NL2_counter](x) 202 | NL2_counter += 1 203 | # Layer 3 204 | NL3_counter = 0 205 | if len(self.NL_3_idx) == 0: 206 | self.NL_3_idx = [-1] 207 | for i in range(len(self.layer3)): 208 | x = self.layer3[i](x) 209 | if i == self.NL_3_idx[NL3_counter]: 210 | _, C, H, W = x.shape 211 | x = self.NL_3[NL3_counter](x) 212 | NL3_counter += 1 213 | # Layer 4 214 | NL4_counter = 0 215 | if len(self.NL_4_idx) == 0: 216 | self.NL_4_idx = [-1] 217 | for i in range(len(self.layer4)): 218 | x = self.layer4[i](x) 219 | if i == self.NL_4_idx[NL4_counter]: 220 | _, C, H, W = x.shape 221 | x = self.NL_4[NL4_counter](x) 222 | NL4_counter += 1 223 | 224 | return x 225 | 226 | def random_init(self): 227 | for m in self.modules(): 228 | if isinstance(m, nn.Conv2d): 229 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 230 | nn.init.normal_(m.weight, 0, math.sqrt(2. / n)) 231 | elif isinstance(m, nn.BatchNorm2d): 232 | nn.init.constant_(m.weight, 1) 233 | nn.init.constant_(m.bias, 0) 234 | 235 | 236 | def init_pretrained_weights(key): 237 | """Initializes model with pretrained weights. 238 | Layers that don't match with pretrained layers in name or size are kept unchanged. 239 | """ 240 | import os 241 | import errno 242 | import gdown 243 | 244 | def _get_torch_home(): 245 | ENV_TORCH_HOME = 'TORCH_HOME' 246 | ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' 247 | DEFAULT_CACHE_DIR = '~/.cache' 248 | torch_home = os.path.expanduser( 249 | os.getenv( 250 | ENV_TORCH_HOME, 251 | os.path.join( 252 | os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch' 253 | ) 254 | ) 255 | ) 256 | return torch_home 257 | 258 | torch_home = _get_torch_home() 259 | model_dir = os.path.join(torch_home, 'checkpoints') 260 | try: 261 | os.makedirs(model_dir) 262 | except OSError as e: 263 | if e.errno == errno.EEXIST: 264 | # Directory already exists, ignore. 265 | pass 266 | else: 267 | # Unexpected OSError, re-raise. 268 | raise 269 | 270 | filename = model_urls[key].split('/')[-1] 271 | 272 | cached_file = os.path.join(model_dir, filename) 273 | 274 | if not os.path.exists(cached_file): 275 | if comm.is_main_process(): 276 | gdown.download(model_urls[key], cached_file, quiet=False) 277 | 278 | comm.synchronize() 279 | 280 | logger.info(f"Loading pretrained model from {cached_file}") 281 | state_dict = torch.load(cached_file, map_location=torch.device('cpu')) 282 | 283 | return state_dict 284 | 285 | 286 | def build_resnet_backbone(): 287 | """ 288 | Create a ResNet instance from config. 289 | Returns: 290 | ResNet: a :class:`ResNet` instance. 
291 | """ 292 | 293 | # fmt: off 294 | pretrain = True 295 | pretrain_path = "metrics/models/weights/lup_moco_r50.pth" 296 | if not os.path.exists(pretrain_path): 297 | raise ValueError("Import model weights first") 298 | last_stride = 1 299 | bn_norm = "BN" 300 | with_ibn = False 301 | with_se = False 302 | with_nl = False 303 | depth = "50x" 304 | # fmt: on 305 | 306 | num_blocks_per_stage = { 307 | '18x': [2, 2, 2, 2], 308 | '34x': [3, 4, 6, 3], 309 | '50x': [3, 4, 6, 3], 310 | '101x': [3, 4, 23, 3], 311 | '152x': [3, 8, 36, 3], 312 | }[depth] 313 | 314 | nl_layers_per_stage = { 315 | '18x': [0, 0, 0, 0], 316 | '34x': [0, 0, 0, 0], 317 | '50x': [0, 2, 3, 0], 318 | '101x': [0, 2, 9, 0], 319 | '152x': [0, 4, 12, 0] 320 | }[depth] 321 | 322 | block = { 323 | '18x': BasicBlock, 324 | '34x': BasicBlock, 325 | '50x': Bottleneck, 326 | '101x': Bottleneck, 327 | '152x': Bottleneck, 328 | }[depth] 329 | 330 | model = ResNet(last_stride, bn_norm, with_ibn, with_se, with_nl, block, 331 | num_blocks_per_stage, nl_layers_per_stage) 332 | if pretrain: 333 | # Load pretrain path if specifically 334 | if pretrain_path: 335 | try: 336 | state_dict = torch.load(pretrain_path, map_location=torch.device('cpu')) 337 | logger.info(f"Loading pretrained model from {pretrain_path}") 338 | except FileNotFoundError as e: 339 | logger.info(f'{pretrain_path} is not found! Please check this path.') 340 | raise e 341 | except KeyError as e: 342 | logger.info("State dict keys error! Please check the state dict.") 343 | raise e 344 | else: 345 | key = depth 346 | if with_ibn: key = 'ibn_' + key 347 | if with_se: key = 'se_' + key 348 | 349 | state_dict = init_pretrained_weights(key) 350 | 351 | incompatible = model.load_state_dict(state_dict, strict=False) 352 | if incompatible.missing_keys: 353 | logger.info( 354 | get_missing_parameters_message(incompatible.missing_keys) 355 | ) 356 | if incompatible.unexpected_keys: 357 | logger.info( 358 | get_unexpected_parameters_message(incompatible.unexpected_keys) 359 | ) 360 | 361 | return model -------------------------------------------------------------------------------- /metrics/reid.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Optional, Union 2 | from torchmetrics.metric import Metric 3 | import torch 4 | from tqdm import tqdm 5 | from metrics.models.fastreid.model import build_resnet_backbone 6 | import torch.nn.functional as F 7 | 8 | 9 | class ReidSubjectSuperiority(Metric): 10 | def __init__( 11 | self, 12 | compute_on_step: bool = False, 13 | dist_sync_on_step: bool = False, 14 | process_group: Optional[Any] = None, 15 | dist_sync_fn: Callable = None, 16 | ): 17 | 18 | super().__init__( 19 | compute_on_step=compute_on_step, 20 | dist_sync_on_step=dist_sync_on_step, 21 | process_group=process_group, 22 | dist_sync_fn=dist_sync_fn, 23 | ) 24 | self.reid_model = build_resnet_backbone() 25 | 26 | self.add_state("more_similar_to_subject", torch.tensor(0), dist_reduce_fx="sum") 27 | 28 | self.add_state( 29 | "num_samples", 30 | torch.tensor(0).long(), 31 | dist_reduce_fx="sum", 32 | ) 33 | 34 | def update( 35 | self, 36 | swap_images, 37 | background_images, 38 | background_segmask, 39 | subject_images, 40 | subject_segmask, 41 | ): 42 | self.num_samples += swap_images.shape[0] 43 | _, _, h, w = swap_images.shape 44 | background_background_segmask = 1 - background_segmask[:, 0, :, :] 45 | h_max = (background_background_segmask.sum(dim=2) > 0).long().argmax(dim=1) 46 | w_max = 
(background_background_segmask.sum(dim=1) > 0).long().argmax(dim=1) 47 | 48 | h_rev = [i for i in range(background_background_segmask.shape[1])] 49 | 50 | h_min = h - (background_background_segmask.sum(dim=2) > 0)[ 51 | :, h_rev[::-1] 52 | ].long().argmax(dim=1) 53 | w_rev = [i for i in range(background_background_segmask.shape[2])] 54 | w_min = w - (background_background_segmask.sum(dim=1) > 0)[ 55 | :, w_rev[::-1] 56 | ].long().argmax(dim=1) 57 | background_bbox = torch.stack([h_max, h_min, w_max, w_min], dim=1) 58 | 59 | subject_segmask = 1 - subject_segmask[:, 0, :, :] 60 | h_max = (subject_segmask.sum(dim=2) > 0).long().argmax(dim=1).long() 61 | w_max = (subject_segmask.sum(dim=1) > 0).long().argmax(dim=1).long() 62 | h_min = h - (subject_segmask.sum(dim=2) > 0)[:, h_rev[::-1]].long().argmax( 63 | dim=1 64 | ) 65 | w_min = w - (subject_segmask.sum(dim=1) > 0)[:, w_rev[::-1]].long().argmax( 66 | dim=1 67 | ) 68 | subject_bbox = torch.stack([h_max, h_min, w_max, w_min], dim=1) 69 | 70 | swap_images_cropped = [] 71 | 72 | subject_images_cropped = [] 73 | 74 | background_images_cropped = [] 75 | 76 | for i in range(swap_images.shape[0]): 77 | swap_image = swap_images[i] 78 | background_image = background_images[i] 79 | subject_image = subject_images[i] 80 | swap_image_bbox = background_bbox[i] 81 | subject_image_bbox = subject_bbox[i] 82 | 83 | swap_images_cropped.append( 84 | F.interpolate( 85 | swap_image[ 86 | :, 87 | swap_image_bbox[0] : swap_image_bbox[1], 88 | swap_image_bbox[2] : swap_image_bbox[3], 89 | ].unsqueeze(0), 90 | size=(256, 128), 91 | mode="bilinear", 92 | ) 93 | ) 94 | 95 | subject_images_cropped.append( 96 | F.interpolate( 97 | subject_image[ 98 | :, 99 | subject_image_bbox[0] : subject_image_bbox[1], 100 | subject_image_bbox[2] : subject_image_bbox[3], 101 | ].unsqueeze(0), 102 | size=(256, 128), 103 | mode="bilinear", 104 | ) 105 | ) 106 | background_images_cropped.append( 107 | F.interpolate( 108 | background_image[ 109 | :, 110 | swap_image_bbox[0] : swap_image_bbox[1], 111 | swap_image_bbox[2] : swap_image_bbox[3], 112 | ].unsqueeze(0), 113 | size=(256, 128), 114 | mode="bilinear", 115 | ) 116 | ) 117 | swap_images_cropped = torch.cat(swap_images_cropped, dim=0) 118 | subject_images_cropped = torch.cat(subject_images_cropped, dim=0) 119 | background_images_cropped = torch.cat(background_images_cropped, dim=0) 120 | 121 | swap_images_features = self.reid_model(swap_images_cropped).mean(dim=(2, 3)) 122 | subject_images_features = self.reid_model(subject_images_cropped).mean( 123 | dim=(2, 3) 124 | ) 125 | background_images_features = self.reid_model(background_images_cropped).mean( 126 | dim=(2, 3) 127 | ) 128 | 129 | background_cos_sim = F.cosine_similarity( 130 | swap_images_features, background_images_features, dim=1 131 | ) 132 | subject_cos_sim = F.cosine_similarity( 133 | swap_images_features, subject_images_features, dim=1 134 | ) 135 | 136 | self.more_similar_to_subject += (subject_cos_sim > background_cos_sim).sum() 137 | 138 | def compute(self): 139 | return self.more_similar_to_subject / self.num_samples 140 | 141 | 142 | class ReidSubjectSimilarity(Metric): 143 | def __init__( 144 | self, 145 | compute_on_step: bool = False, 146 | dist_sync_on_step: bool = False, 147 | process_group: Optional[Any] = None, 148 | dist_sync_fn: Callable = None, 149 | ): 150 | 151 | super().__init__( 152 | compute_on_step=compute_on_step, 153 | dist_sync_on_step=dist_sync_on_step, 154 | process_group=process_group, 155 | dist_sync_fn=dist_sync_fn, 156 | ) 157 | 
self.reid_model = build_resnet_backbone() 158 | 159 | self.add_state( 160 | "subject_cosine_sim", torch.tensor(0).float(), dist_reduce_fx="sum" 161 | ) 162 | 163 | self.add_state( 164 | "num_samples", 165 | torch.tensor(0).long(), 166 | dist_reduce_fx="sum", 167 | ) 168 | 169 | def update( 170 | self, 171 | swap_images, 172 | background_images, 173 | background_segmask, 174 | subject_images, 175 | subject_segmask, 176 | ): 177 | self.num_samples += swap_images.shape[0] 178 | _, _, h, w = swap_images.shape 179 | background_background_segmask = 1 - background_segmask[:, 0, :, :] 180 | h_max = (background_background_segmask.sum(dim=2) > 0).long().argmax(dim=1) 181 | w_max = (background_background_segmask.sum(dim=1) > 0).long().argmax(dim=1) 182 | 183 | h_rev = [i for i in range(background_background_segmask.shape[1])] 184 | 185 | h_min = h - (background_background_segmask.sum(dim=2) > 0)[ 186 | :, h_rev[::-1] 187 | ].long().argmax(dim=1) 188 | w_rev = [i for i in range(background_background_segmask.shape[2])] 189 | w_min = w - (background_background_segmask.sum(dim=1) > 0)[ 190 | :, w_rev[::-1] 191 | ].long().argmax(dim=1) 192 | background_bbox = torch.stack([h_max, h_min, w_max, w_min], dim=1) 193 | 194 | subject_segmask = 1 - subject_segmask[:, 0, :, :] 195 | h_max = (subject_segmask.sum(dim=2) > 0).long().argmax(dim=1) 196 | w_max = (subject_segmask.sum(dim=1) > 0).long().argmax(dim=1) 197 | h_min = h - (subject_segmask.sum(dim=2) > 0)[:, h_rev[::-1]].long().argmax( 198 | dim=1 199 | ) 200 | w_min = w - (subject_segmask.sum(dim=1) > 0)[:, w_rev[::-1]].long().argmax( 201 | dim=1 202 | ) 203 | subject_bbox = torch.stack([h_max, h_min, w_max, w_min], dim=1) 204 | 205 | swap_images_cropped = [] 206 | 207 | subject_images_cropped = [] 208 | 209 | for i in range(swap_images.shape[0]): 210 | swap_image = swap_images[i] 211 | subject_image = subject_images[i] 212 | swap_image_bbox = background_bbox[i] 213 | subject_image_bbox = subject_bbox[i] 214 | 215 | swap_images_cropped.append( 216 | F.interpolate( 217 | swap_image[ 218 | :, 219 | swap_image_bbox[0] : swap_image_bbox[1], 220 | swap_image_bbox[2] : swap_image_bbox[3], 221 | ].unsqueeze(0), 222 | size=(256, 128), 223 | mode="bilinear", 224 | ) 225 | ) 226 | subject_images_cropped.append( 227 | F.interpolate( 228 | subject_image[ 229 | :, 230 | subject_image_bbox[0] : subject_image_bbox[1], 231 | subject_image_bbox[2] : subject_image_bbox[3], 232 | ].unsqueeze(0), 233 | size=(256, 128), 234 | mode="bilinear", 235 | ) 236 | ) 237 | swap_images_cropped = torch.cat(swap_images_cropped, dim=0) 238 | subject_images_cropped = torch.cat(subject_images_cropped, dim=0) 239 | 240 | swap_images_features = self.reid_model(swap_images_cropped).mean(dim=(2, 3)) 241 | subject_images_features = self.reid_model(subject_images_cropped).mean( 242 | dim=(2, 3) 243 | ) 244 | 245 | subject_cos_sim = F.cosine_similarity( 246 | swap_images_features, subject_images_features, dim=1 247 | ).sum() 248 | 249 | self.subject_cosine_sim += subject_cos_sim 250 | 251 | def compute(self): 252 | return self.subject_cosine_sim / self.num_samples 253 | 254 | 255 | class ReidBackgroundSuperiority(Metric): 256 | def __init__( 257 | self, 258 | compute_on_step: bool = False, 259 | dist_sync_on_step: bool = False, 260 | process_group: Optional[Any] = None, 261 | dist_sync_fn: Callable = None, 262 | ): 263 | 264 | super().__init__( 265 | compute_on_step=compute_on_step, 266 | dist_sync_on_step=dist_sync_on_step, 267 | process_group=process_group, 268 | dist_sync_fn=dist_sync_fn, 269 | ) 
270 | self.reid_model = build_resnet_backbone() 271 | 272 | self.add_state( 273 | "background_cosine_sim", torch.tensor(0).float(), dist_reduce_fx="sum" 274 | ) 275 | 276 | self.add_state( 277 | "num_samples", 278 | torch.tensor(0).long(), 279 | dist_reduce_fx="sum", 280 | ) 281 | 282 | def update( 283 | self, 284 | swap_images, 285 | background_images, 286 | background_segmask, 287 | subject_images, 288 | subject_segmask, 289 | ): 290 | self.num_samples += swap_images.shape[0] 291 | _, _, h, w = swap_images.shape 292 | background_background_segmask = 1 - background_segmask[:, 0, :, :] 293 | h_max = (background_background_segmask.sum(dim=2) > 0).long().argmax(dim=1) 294 | w_max = (background_background_segmask.sum(dim=1) > 0).long().argmax(dim=1) 295 | 296 | h_rev = [i for i in range(background_background_segmask.shape[1])] 297 | 298 | h_min = h - (background_background_segmask.sum(dim=2) > 0)[ 299 | :, h_rev[::-1] 300 | ].long().argmax(dim=1) 301 | w_rev = [i for i in range(background_background_segmask.shape[2])] 302 | w_min = w - (background_background_segmask.sum(dim=1) > 0)[ 303 | :, w_rev[::-1] 304 | ].long().argmax(dim=1) 305 | background_bbox = torch.stack([h_max, h_min, w_max, w_min], dim=1) 306 | 307 | swap_images_cropped = [] 308 | 309 | background_images_cropped = [] 310 | 311 | for i in range(swap_images.shape[0]): 312 | swap_image = swap_images[i] 313 | background_image = background_images[i] 314 | swap_image_bbox = background_bbox[i] 315 | 316 | swap_images_cropped.append( 317 | F.interpolate( 318 | swap_image[ 319 | :, 320 | swap_image_bbox[0] : swap_image_bbox[1], 321 | swap_image_bbox[2] : swap_image_bbox[3], 322 | ].unsqueeze(0), 323 | size=(256, 128), 324 | mode="bilinear", 325 | ) 326 | ) 327 | background_images_cropped.append( 328 | F.interpolate( 329 | background_image[ 330 | :, 331 | swap_image_bbox[0] : swap_image_bbox[1], 332 | swap_image_bbox[2] : swap_image_bbox[3], 333 | ].unsqueeze(0), 334 | size=(256, 128), 335 | mode="bilinear", 336 | ) 337 | ) 338 | swap_images_cropped = torch.cat(swap_images_cropped, dim=0) 339 | background_images_cropped = torch.cat(background_images_cropped, dim=0) 340 | 341 | swap_images_features = self.reid_model(swap_images_cropped).mean(dim=(2, 3)) 342 | background_images_features = self.reid_model(background_images_cropped).mean( 343 | dim=(2, 3) 344 | ) 345 | 346 | background_cos_sim = F.cosine_similarity( 347 | swap_images_features, background_images_features, dim=1 348 | ).sum() 349 | 350 | self.background_cosine_sim += background_cos_sim 351 | 352 | def compute(self): 353 | return self.background_cosine_sim / self.num_samples 354 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/models/__init__.py -------------------------------------------------------------------------------- /models/diffaugment.py: -------------------------------------------------------------------------------- 1 | import kornia.augmentation as K 2 | import torch.nn as nn 3 | 4 | 5 | class SimpleAugmentation(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.aug = nn.Sequential( 9 | K.Denormalize(mean=0.5, std=0.5), 10 | K.ColorJitter(p=0.8, brightness=0.2, contrast=0.3, hue=0.2), 11 | K.RandomErasing(p=0.5), 12 | K.Normalize(mean=0.5, std=0.5), 13 | ) 14 | 15 | def forward(self, x): 16 | return self.aug(x) 17 | 
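> Note: `SimpleAugmentation` above chains Kornia ops that first map images from the `[-1, 1]` range back to `[0, 1]` (`Denormalize(mean=0.5, std=0.5)`), apply color jitter and random erasing, and then renormalize, so it expects inputs already normalized with mean 0.5 and std 0.5. A minimal usage sketch (the batch size and resolution are arbitrary assumptions, and the random tensor just stands in for a batch of images):

```python
import torch
from models.diffaugment import SimpleAugmentation

aug = SimpleAugmentation()
# stand-in batch, normalized to [-1, 1] as the module expects
images = torch.rand(4, 3, 256, 256) * 2 - 1
augmented = aug(images)
assert augmented.shape == images.shape  # augmentations are photometric / masking only, shape is preserved
```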
-------------------------------------------------------------------------------- /models/discriminators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/models/discriminators/__init__.py -------------------------------------------------------------------------------- /models/discriminators/mcad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchtyping import TensorType 5 | from typing import Union, Sequence 6 | 7 | import torch.nn.utils.spectral_norm as spectral_norm 8 | from collections import OrderedDict 9 | 10 | from einops import repeat, rearrange 11 | 12 | from models.utils_blocks.base import BaseNetwork 13 | 14 | from models.utils_blocks.attention import ( 15 | MaskedTransformer, 16 | SinusoidalPositionalEmbedding, 17 | ) 18 | 19 | 20 | class MultiScaleMCADiscriminator(BaseNetwork): 21 | """ 22 | This is the Multi Scale Discriminator. We use multiple discriminators at different scales 23 | 24 | Parameters: 25 | ----------- 26 | num_discriminator: int, 27 | How many discriminators do we use 28 | image_num_channels: int, 29 | Number of input images channels 30 | segmap_num_channels: int, 31 | Number of segmentation map channels 32 | num_features_fst_conv: int, 33 | How many kernels at the first convolution 34 | num_layers: int, 35 | How many layers per discriminator 36 | apply_spectral_norm: bool = True, 37 | Wheter or not to apply spectral normalization 38 | keep_intermediate_results: bool = True 39 | Whether or not to keep intermediate discriminators feature maps 40 | 41 | """ 42 | 43 | def __init__( 44 | self, 45 | num_discriminator: int, 46 | image_num_channels: int, 47 | segmap_num_channels: int, 48 | positional_embedding_dim: int, 49 | num_labels: int, 50 | num_latent_per_labels: int, 51 | latent_dim: int, 52 | num_blocks: int, 53 | attention_latent_dim: int, 54 | num_cross_heads: int, 55 | num_self_heads: int, 56 | apply_spectral_norm: bool = True, 57 | concat_segmaps: bool = False, 58 | output_type: str = "patchgan", 59 | keep_intermediate_results: bool = True, 60 | ): 61 | super().__init__() 62 | 63 | self.keep_intermediate_results = keep_intermediate_results 64 | 65 | self.discriminators = nn.ModuleDict(OrderedDict()) 66 | 67 | for i in range(num_discriminator): 68 | self.discriminators.update( 69 | { 70 | f"discriminator_{i}": MaskedCrossAttentionDiscriminator( 71 | image_num_channels=image_num_channels, 72 | segmap_num_channels=segmap_num_channels, 73 | positional_embedding_dim=positional_embedding_dim, 74 | num_labels=num_labels, 75 | num_latent_per_labels=num_latent_per_labels, 76 | latent_dim=latent_dim, 77 | num_blocks=num_blocks, 78 | attention_latent_dim=attention_latent_dim, 79 | num_cross_heads=num_cross_heads, 80 | num_self_heads=num_self_heads, 81 | apply_spectral_norm=apply_spectral_norm, 82 | concat_segmaps=concat_segmaps, 83 | output_type=output_type, 84 | keep_intermediate_results=keep_intermediate_results, 85 | ), 86 | } 87 | ) 88 | self.downsample = nn.AvgPool2d( 89 | kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False 90 | ) 91 | 92 | def forward( 93 | self, 94 | input: TensorType["batch_size", "input_channels", "height", "width"], 95 | segmentation_map: TensorType[ 96 | "batch_size", "num_labels", "height", "width" 97 | ] = None, 98 | ) -> Union[ 99 | Sequence[ 100 | 
Sequence[TensorType["batch_size", 1, "output_height", "output_width"]] 101 | ], 102 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]], 103 | ]: 104 | results = [] 105 | for disc_name in self.discriminators: 106 | result = self.discriminators[disc_name](input, segmentation_map) 107 | if not self.keep_intermediate_results: 108 | result = [result] 109 | results.append(result) 110 | input = self.downsample(input) 111 | segmentation_map = F.interpolate( 112 | segmentation_map, size=input.size()[2:], mode="nearest" 113 | ) 114 | 115 | return results 116 | 117 | 118 | class MaskedCrossAttentionDiscriminator(BaseNetwork): 119 | """ 120 | Encoder that encode style vectors for each segmentation labels doing regional average pooling 121 | 122 | Parameters: 123 | ----------- 124 | num_input_channels: int, 125 | Number of input channels 126 | latent_dim: int, 127 | Number of output channels (style_dim) 128 | num_features_fst_conv: iny, 129 | Number of kernels at first conv 130 | 131 | """ 132 | 133 | def __init__( 134 | self, 135 | image_num_channels: int, 136 | segmap_num_channels: int, 137 | positional_embedding_dim: int, 138 | num_labels: int, 139 | num_latent_per_labels: int, 140 | latent_dim: int, 141 | num_blocks: int, 142 | attention_latent_dim: int, 143 | num_cross_heads: int, 144 | num_self_heads: int, 145 | apply_spectral_norm: bool = True, 146 | concat_segmaps: bool = False, 147 | output_type: str = "patchgan", 148 | keep_intermediate_results: bool = True, 149 | ): 150 | super().__init__() 151 | 152 | self.concat_segmaps = concat_segmaps 153 | self.output_type = output_type 154 | self.keep_intermediate_results = keep_intermediate_results 155 | self.image_pos_embs = nn.ModuleList( 156 | [SinusoidalPositionalEmbedding(positional_embedding_dim, emb_type="concat")] 157 | ) 158 | self.convs = nn.ModuleList([nn.Identity()]) 159 | 160 | num_input_channels = image_num_channels + segmap_num_channels 161 | 162 | image_emb_dim = num_input_channels + positional_embedding_dim 163 | 164 | num_latents = num_labels * num_latent_per_labels 165 | self.num_latent_per_labels = num_latent_per_labels 166 | 167 | latents_mask = torch.block_diag( 168 | *[ 169 | torch.FloatTensor( 170 | [ 171 | [1.0 for _ in range(num_latent_per_labels)] 172 | for _ in range(num_latent_per_labels) 173 | ] 174 | ) 175 | for _ in range(num_labels) 176 | ] 177 | ).unsqueeze(0) 178 | self.register_buffer("latents_mask", latents_mask) 179 | 180 | self.latents = nn.Parameter(torch.Tensor(num_latents, latent_dim)) 181 | 182 | self.backbone = nn.ModuleDict(OrderedDict()) 183 | 184 | for i in range(num_blocks): 185 | module = nn.ModuleDict(OrderedDict()) 186 | module.update( 187 | { 188 | "cross_attention": MaskedTransformer( 189 | latent_dim, 190 | num_latents, 191 | image_emb_dim if i == 0 else 32 * 2 ** i, 192 | attention_latent_dim, 193 | num_cross_heads, 194 | ), 195 | } 196 | ) 197 | module.update( 198 | { 199 | "self_attention": MaskedTransformer( 200 | latent_dim, 201 | num_latents, 202 | latent_dim, 203 | attention_latent_dim, 204 | num_self_heads, 205 | ), 206 | } 207 | ) 208 | self.backbone.update({f"block_{i}": module}) 209 | if i > 0: 210 | conv = nn.Conv2d( 211 | num_input_channels if i == 1 else 32 * 2 ** (i - 1), 212 | 32 * 2 ** i, 213 | kernel_size=3, 214 | padding=1, 215 | stride=2, 216 | ) 217 | if apply_spectral_norm: 218 | conv = spectral_norm(conv) 219 | self.convs.append( 220 | nn.Sequential( 221 | conv, 222 | nn.LeakyReLU(0.2), 223 | ) 224 | ) 225 | self.image_pos_embs.append( 226 | 
SinusoidalPositionalEmbedding(32 * 2 ** i, emb_type="add") 227 | ) 228 | self.conv_to_latent = nn.Sequential( 229 | nn.Linear(32 * 2 ** (num_blocks - 1), latent_dim), 230 | nn.LeakyReLU(0.2), 231 | ) 232 | if self.output_type == "attention_pool": 233 | self.cls_token = nn.Parameter(torch.randn(1, 1, latent_dim)) 234 | self.attention_pool = MaskedTransformer( 235 | latent_dim, 236 | 1, 237 | latent_dim, 238 | attention_latent_dim, 239 | num_self_heads, 240 | ) 241 | self.classification_head = nn.Linear(latent_dim, 1) 242 | 243 | def forward( 244 | self, 245 | input: TensorType["batch_size", "num_input_channels", "height", "width"], 246 | segmentation_map: TensorType["batch_size", "num_labels", "height", "width"], 247 | ) -> TensorType["batch_size", "num_input_channels", "style_dim"]: 248 | batch_size = input.shape[0] 249 | 250 | latents = repeat(self.latents, " l d -> b l d", b=batch_size) 251 | 252 | if self.concat_segmaps: 253 | input = torch.cat([input, segmentation_map], dim=1) 254 | 255 | cross_attention_masks = [] 256 | flattened_inputs = [] 257 | for i, conv in enumerate(self.convs): 258 | input = conv(input) 259 | cross_attention_masks.append( 260 | rearrange( 261 | torch.repeat_interleave( 262 | F.interpolate( 263 | segmentation_map, size=input.size()[2:], mode="nearest" 264 | ), 265 | self.num_latent_per_labels, 266 | dim=1, 267 | ), 268 | "b n h w -> b n (h w)", 269 | ) 270 | ) 271 | flattened_inputs.append( 272 | rearrange(self.image_pos_embs[i](input), " b c h w -> b (h w) c") 273 | ) 274 | 275 | for i, block_name in enumerate(self.backbone): 276 | latents = self.backbone[block_name]["cross_attention"]( 277 | latents, flattened_inputs[i], cross_attention_masks[i] 278 | ) 279 | convs_features = self.conv_to_latent(flattened_inputs[-1]) 280 | latents_and_convs = torch.cat([latents, convs_features], dim=1) 281 | 282 | if self.output_type == "patchgan": 283 | flattened_inputs.append(self.classification_head(latents_and_convs)) 284 | elif self.output_type == "mean_pool": 285 | flattened_inputs.append( 286 | self.classification_head(latents_and_convs.mean(dim=1)) 287 | ) 288 | elif self.output_type == "attention_pool": 289 | output_token = repeat(self.cls_token, "() n d -> b n d", b=batch_size) 290 | output_token = self.attention_pool(output_token, latents_and_convs) 291 | flattened_inputs.append(self.classification_head(output_token)) 292 | else: 293 | raise ValueError("Not a supported disc output type") 294 | 295 | if self.keep_intermediate_results: 296 | return flattened_inputs[1:] 297 | else: 298 | return flattened_inputs[-1] 299 | -------------------------------------------------------------------------------- /models/discriminators/oasis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torch.nn.utils.spectral_norm as spectral_norm 5 | 6 | from models.utils_blocks.base import BaseNetwork 7 | 8 | 9 | class OASISDiscriminator(BaseNetwork): 10 | """ 11 | PatchGAN Discriminator. Perform patch-wise GAN discrimination to avoid ignoring local high dimensional features. 
12 | 13 | Parameters: 14 | ----------- 15 | image_num_channels: int, 16 | Number of channels in the real/generated image 17 | segmap_num_channels: int, 18 | Number of channels in the segmentation map 19 | num_features_fst_conv: int, 20 | Number of features in the first convolution layer 21 | num_layers: int, 22 | Number of convolution layers 23 | apply_spectral_norm: bool, default True 24 | Whether or not to apply spectral normalization 25 | keep_intermediate_results: bool, default True 26 | Whether or not to keep each feature map output 27 | """ 28 | 29 | def __init__( 30 | self, 31 | image_num_channels: int, 32 | segmap_num_channels: int, 33 | apply_spectral_norm: bool = True, 34 | apply_grad_norm: bool = False, 35 | ): 36 | super().__init__() 37 | self.apply_grad_norm = apply_grad_norm 38 | output_channel = segmap_num_channels + 1 # for N+1 loss 39 | self.channels = [image_num_channels, 128, 128, 256, 256, 512, 512] 40 | num_res_blocks = 6 41 | self.body_up = nn.ModuleList([]) 42 | self.body_down = nn.ModuleList([]) 43 | # encoder part 44 | for i in range(num_res_blocks): 45 | self.body_down.append( 46 | OASISBlock( 47 | self.channels[i], 48 | self.channels[i + 1], 49 | -1, 50 | first=(i == 0), 51 | apply_spectral_norm=apply_spectral_norm, 52 | ) 53 | ) 54 | # decoder part 55 | self.body_up.append( 56 | OASISBlock( 57 | self.channels[-1], 58 | self.channels[-2], 59 | 1, 60 | apply_spectral_norm=apply_spectral_norm, 61 | ) 62 | ) 63 | for i in range(1, num_res_blocks - 1): 64 | self.body_up.append( 65 | OASISBlock( 66 | 2 * self.channels[-1 - i], 67 | self.channels[-2 - i], 68 | 1, 69 | apply_spectral_norm=apply_spectral_norm, 70 | ) 71 | ) 72 | self.body_up.append( 73 | OASISBlock( 74 | 2 * self.channels[1], 64, 1, apply_spectral_norm=apply_spectral_norm 75 | ) 76 | ) 77 | self.layer_up_last = nn.Conv2d(64, output_channel, 1, 1, 0) 78 | 79 | def forward(self, input, segmap=None): 80 | x = input 81 | # encoder 82 | encoder_res = list() 83 | for i in range(len(self.body_down)): 84 | x = self.body_down[i](x) 85 | encoder_res.append(x) 86 | # decoder 87 | x = self.body_up[0](x) 88 | for i in range(1, len(self.body_down)): 89 | x = self.body_up[i](torch.cat((encoder_res[-i - 1], x), dim=1)) 90 | ans = self.layer_up_last(x) 91 | if self.apply_grad_norm: 92 | grad = torch.autograd.grad( 93 | ans, 94 | [input], 95 | torch.ones_like(ans), 96 | create_graph=True, 97 | retain_graph=True, 98 | )[0] 99 | grad_norm = torch.norm(torch.flatten(grad, start_dim=1), p=2, dim=1) 100 | grad_norm = grad_norm.view(-1, *[1 for _ in range(len(ans.shape) - 1)]) 101 | ans = ans / (grad_norm + torch.abs(ans)) 102 | return ans 103 | 104 | 105 | class OASISBlock(nn.Module): 106 | def __init__( 107 | self, 108 | fin: int, 109 | fout: int, 110 | up_or_down: int, 111 | first: bool = False, 112 | apply_spectral_norm: bool = True, 113 | ): 114 | super().__init__() 115 | # Attributes 116 | self.up_or_down = up_or_down 117 | self.first = first 118 | self.learned_shortcut = fin != fout 119 | fmiddle = fout 120 | if apply_spectral_norm: 121 | norm_layer = spectral_norm 122 | else: 123 | norm_layer = nn.Identity() 124 | if first: 125 | self.conv1 = nn.Sequential(norm_layer(nn.Conv2d(fin, fmiddle, 3, 1, 1))) 126 | else: 127 | if self.up_or_down > 0: 128 | self.conv1 = nn.Sequential( 129 | nn.LeakyReLU(0.2, False), 130 | nn.Upsample(scale_factor=2), 131 | norm_layer(nn.Conv2d(fin, fmiddle, 3, 1, 1)), 132 | ) 133 | else: 134 | self.conv1 = nn.Sequential( 135 | nn.LeakyReLU(0.2, False), 136 | norm_layer(nn.Conv2d(fin, fmiddle, 3, 
1, 1)), 137 | ) 138 | self.conv2 = nn.Sequential( 139 | nn.LeakyReLU(0.2, False), norm_layer(nn.Conv2d(fmiddle, fout, 3, 1, 1)) 140 | ) 141 | if self.learned_shortcut: 142 | self.conv_s = norm_layer(nn.Conv2d(fin, fout, 1, 1, 0)) 143 | if up_or_down > 0: 144 | self.sampling = nn.Upsample(scale_factor=2) 145 | elif up_or_down < 0: 146 | self.sampling = nn.AvgPool2d(2) 147 | else: 148 | self.sampling = nn.Sequential() 149 | 150 | def forward(self, x): 151 | x_s = self.shortcut(x) 152 | dx = self.conv1(x) 153 | dx = self.conv2(dx) 154 | if self.up_or_down < 0: 155 | dx = self.sampling(dx) 156 | out = x_s + dx 157 | return out 158 | 159 | def shortcut(self, x): 160 | if self.first: 161 | if self.up_or_down < 0: 162 | x = self.sampling(x) 163 | if self.learned_shortcut: 164 | x = self.conv_s(x) 165 | x_s = x 166 | else: 167 | if self.up_or_down > 0: 168 | x = self.sampling(x) 169 | if self.learned_shortcut: 170 | x = self.conv_s(x) 171 | if self.up_or_down < 0: 172 | x = self.sampling(x) 173 | x_s = x 174 | return x_s 175 | -------------------------------------------------------------------------------- /models/discriminators/patchgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchtyping import TensorType 5 | from typing import Union, Sequence 6 | 7 | import torch.nn.utils.spectral_norm as spectral_norm 8 | from collections import OrderedDict 9 | from math import ceil 10 | 11 | from models.utils_blocks.base import BaseNetwork 12 | 13 | from models.utils_blocks.equallr import EqualConv2d 14 | from functools import partial 15 | 16 | 17 | class MultiScalePatchGanDiscriminator(BaseNetwork): 18 | """ 19 | This is the Multi Scale Discriminator. 
We use multiple discriminators at different scales 20 | 21 | Parameters: 22 | ----------- 23 | num_discriminator: int, 24 | How many discriminators do we use 25 | image_num_channels: int, 26 | Number of input images channels 27 | segmap_num_channels: int, 28 | Number of segmentation map channels 29 | num_features_fst_conv: int, 30 | How many kernels at the first convolution 31 | num_layers: int, 32 | How many layers per discriminator 33 | apply_spectral_norm: bool = True, 34 | Wheter or not to apply spectral normalization 35 | keep_intermediate_results: bool = True 36 | Whether or not to keep intermediate discriminators feature maps 37 | 38 | """ 39 | 40 | def __init__( 41 | self, 42 | num_discriminator: int, 43 | image_num_channels: int, 44 | segmap_num_channels: int, 45 | num_features_fst_conv: int, 46 | num_layers: int, 47 | apply_spectral_norm: bool = True, 48 | apply_grad_norm: bool = False, 49 | keep_intermediate_results: bool = True, 50 | use_equalized_lr: bool = False, 51 | lr_mul: float = 1.0, 52 | ): 53 | super().__init__() 54 | 55 | self.keep_intermediate_results = keep_intermediate_results 56 | 57 | self.image_num_channels = image_num_channels 58 | 59 | self.discriminators = nn.ModuleDict(OrderedDict()) 60 | 61 | for i in range(num_discriminator): 62 | self.discriminators.update( 63 | { 64 | f"discriminator_{i}": PatchGANDiscriminator( 65 | image_num_channels=image_num_channels, 66 | segmap_num_channels=segmap_num_channels, 67 | num_features_fst_conv=num_features_fst_conv, 68 | num_layers=num_layers, 69 | apply_spectral_norm=apply_spectral_norm, 70 | apply_grad_norm=apply_grad_norm, 71 | keep_intermediate_results=keep_intermediate_results, 72 | use_equalized_lr=use_equalized_lr, 73 | lr_mul=lr_mul, 74 | ), 75 | } 76 | ) 77 | self.downsample = nn.AvgPool2d( 78 | kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False 79 | ) 80 | 81 | def forward( 82 | self, 83 | input: TensorType["batch_size", "input_channels", "height", "width"], 84 | ) -> Union[ 85 | Sequence[ 86 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]] 87 | ], 88 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]], 89 | ]: 90 | results = [] 91 | for disc_name in self.discriminators: 92 | result = self.discriminators[disc_name](input) 93 | if not self.keep_intermediate_results: 94 | result = [result] 95 | results.append(result) 96 | image = self.downsample(input[:, : self.image_num_channels]) 97 | segmentation_map = F.interpolate( 98 | input[:, self.image_num_channels :], 99 | size=image.size()[2:], 100 | mode="nearest", 101 | ) 102 | input = torch.cat([image, segmentation_map], dim=1) 103 | return results 104 | 105 | 106 | class PatchGANDiscriminator(BaseNetwork): 107 | """ 108 | PatchGAN Discriminator. Perform patch-wise GAN discrimination to avoid ignoring local high dimensional features. 
109 | 110 | Parameters: 111 | ----------- 112 | image_num_channels: int, 113 | Number of channels in the real/generated image 114 | segmap_num_channels: int, 115 | Number of channels in the segmentation map 116 | num_features_fst_conv: int, 117 | Number of features in the first convolution layer 118 | num_layers: int, 119 | Number of convolution layers 120 | apply_spectral_norm: bool, default True 121 | Whether or not to apply spectral normalization 122 | keep_intermediate_results: bool, default True 123 | Whether or not to keep each feature map output 124 | """ 125 | 126 | def __init__( 127 | self, 128 | image_num_channels: int, 129 | segmap_num_channels: int, 130 | num_features_fst_conv: int, 131 | num_layers: int, 132 | apply_spectral_norm: bool = True, 133 | apply_grad_norm: bool = False, 134 | keep_intermediate_results: bool = True, 135 | use_equalized_lr: bool = False, 136 | lr_mul: float = 1.0, 137 | ): 138 | super().__init__() 139 | kernel_size = 4 140 | padding = int(ceil((kernel_size - 1.0) / 2)) 141 | nffc = num_features_fst_conv 142 | self.keep_intermediate_results = keep_intermediate_results 143 | self.apply_grad_norm = apply_grad_norm 144 | self.model = nn.ModuleDict(OrderedDict()) 145 | ConvLayer = ( 146 | partial(EqualConv2d, lr_mul=lr_mul) if use_equalized_lr else nn.Conv2d 147 | ) 148 | self.model.update( 149 | { 150 | "conv_0": nn.Sequential( 151 | ConvLayer( 152 | in_channels=image_num_channels + segmap_num_channels, 153 | out_channels=nffc, 154 | kernel_size=kernel_size, 155 | stride=2, 156 | padding=padding, 157 | ), 158 | nn.LeakyReLU(0.2, False), 159 | ) 160 | } 161 | ) 162 | 163 | for n in range(1, num_layers): 164 | nffc_prev = nffc 165 | nffc = min(2 * nffc_prev, 512) 166 | stride = 1 if n == num_layers - 1 else 2 167 | conv = ConvLayer( 168 | in_channels=nffc_prev, 169 | out_channels=nffc, 170 | kernel_size=kernel_size, 171 | padding=padding, 172 | stride=stride, 173 | ) 174 | if apply_spectral_norm: 175 | conv = spectral_norm(conv) 176 | self.model.update( 177 | { 178 | f"conv_{n}": nn.Sequential( 179 | conv, 180 | nn.InstanceNorm2d(nffc), 181 | nn.LeakyReLU(0.2, False), 182 | ) 183 | } 184 | ) 185 | self.model.update( 186 | { 187 | f"last_conv": nn.Sequential( 188 | ConvLayer( 189 | in_channels=nffc, 190 | out_channels=1, 191 | kernel_size=kernel_size, 192 | stride=1, 193 | padding=padding, 194 | ) 195 | ) 196 | } 197 | ) 198 | 199 | def forward( 200 | self, 201 | input: TensorType["batch_size", "input_channels", "height", "width"], 202 | ) -> Union[ 203 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]], 204 | TensorType["batch_size", 1, "output_height", "output_width"], 205 | ]: 206 | results = [input] 207 | for conv_name in self.model: 208 | results.append(self.model[conv_name](results[-1])) 209 | if self.keep_intermediate_results: 210 | return results[1:] 211 | else: 212 | return results[-1] 213 | -------------------------------------------------------------------------------- /models/discriminators/stylegan2.py: -------------------------------------------------------------------------------- 1 | from genericpath import exists 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functools import partial 6 | from torchtyping import TensorType 7 | from typing import Union, Sequence 8 | from collections import OrderedDict 9 | import torch.nn.utils.spectral_norm as spectral_norm 10 | 11 | from models.utils_blocks.equallr import EqualConv2d, EqualLinear 12 | 13 | from math import sqrt 14 | 15 | from 
einops.layers.torch import Rearrange 16 | 17 | from kornia.filters import filter2d 18 | 19 | from models.utils_blocks.base import BaseNetwork 20 | 21 | 22 | class MultiStyleGan2Discriminator(BaseNetwork): 23 | """ 24 | This is the Multi Scale Discriminator. We use multiple discriminators at different scales 25 | 26 | Parameters: 27 | ----------- 28 | num_discriminator: int, 29 | How many discriminators do we use 30 | image_num_channels: int, 31 | Number of input images channels 32 | segmap_num_channels: int, 33 | Number of segmentation map channels 34 | num_features_fst_conv: int, 35 | How many kernels at the first convolution 36 | num_layers: int, 37 | How many layers per discriminator 38 | apply_spectral_norm: bool = True, 39 | Wheter or not to apply spectral normalization 40 | keep_intermediate_results: bool = True 41 | Whether or not to keep intermediate discriminators feature maps 42 | 43 | """ 44 | 45 | def __init__( 46 | self, 47 | num_discriminator: int, 48 | image_num_channels: int, 49 | segmap_num_channels: int, 50 | num_features_fst_conv: int, 51 | num_layers: int, 52 | fmap_max: int, 53 | apply_spectral_norm: bool = True, 54 | apply_grad_norm: bool = False, 55 | keep_intermediate_results: bool = True, 56 | use_equalized_lr=False, 57 | lr_mul=1, 58 | ): 59 | super().__init__() 60 | 61 | self.keep_intermediate_results = keep_intermediate_results 62 | 63 | self.image_num_channels = image_num_channels 64 | 65 | self.discriminators = nn.ModuleDict(OrderedDict()) 66 | 67 | for i in range(num_discriminator): 68 | self.discriminators.update( 69 | { 70 | f"discriminator_{i}": StyleGan2Discriminator( 71 | image_num_channels=image_num_channels, 72 | segmap_num_channels=segmap_num_channels, 73 | num_features_fst_conv=num_features_fst_conv, 74 | num_layers=num_layers - i, 75 | fmap_max=fmap_max, 76 | apply_spectral_norm=apply_spectral_norm, 77 | apply_grad_norm=apply_grad_norm, 78 | keep_intermediate_results=keep_intermediate_results, 79 | use_equalized_lr=use_equalized_lr, 80 | lr_mul=lr_mul, 81 | ), 82 | } 83 | ) 84 | self.downsample = nn.AvgPool2d( 85 | kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False 86 | ) 87 | 88 | def forward( 89 | self, 90 | input: TensorType["batch_size", "input_channels", "height", "width"], 91 | ) -> Union[ 92 | Sequence[ 93 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]] 94 | ], 95 | Sequence[TensorType["batch_size", 1, "output_height", "output_width"]], 96 | ]: 97 | results = [] 98 | for disc_name in self.discriminators: 99 | result = self.discriminators[disc_name](input) 100 | if not self.keep_intermediate_results: 101 | result = [result] 102 | results.append(result) 103 | image = self.downsample(input[:, : self.image_num_channels]) 104 | segmentation_map = F.interpolate( 105 | input[:, self.image_num_channels :], 106 | size=image.size()[2:], 107 | mode="nearest", 108 | ) 109 | input = torch.cat([image, segmentation_map], dim=1) 110 | 111 | return results 112 | 113 | 114 | class StyleGan2Discriminator(BaseNetwork): 115 | """ 116 | Taken from lucidrain repo https://github.com/lucidrains/stylegan2-pytorch 117 | 118 | """ 119 | 120 | def __init__( 121 | self, 122 | num_layers, 123 | image_num_channels, 124 | segmap_num_channels, 125 | num_features_fst_conv, 126 | fmap_max, 127 | keep_intermediate_results=True, 128 | apply_grad_norm=False, 129 | apply_spectral_norm=False, 130 | use_equalized_lr=False, 131 | lr_mul=1, 132 | ): 133 | super().__init__() 134 | self.apply_grad_norm = apply_grad_norm 135 | 
self.keep_intermediate_results = keep_intermediate_results 136 | init_channels = image_num_channels + segmap_num_channels 137 | blocks = [] 138 | ConvLayer = ( 139 | partial(EqualConv2d, lr_mul=lr_mul) if use_equalized_lr else nn.Conv2d 140 | ) 141 | LinearLayer = ( 142 | partial(EqualLinear, lr_mul=lr_mul) if use_equalized_lr else nn.Linear 143 | ) 144 | filters = [init_channels] + [ 145 | num_features_fst_conv * (2**i) for i in range(num_layers + 1) 146 | ] 147 | set_fmap_max = partial(min, fmap_max) 148 | filters = list(map(set_fmap_max, filters)) 149 | chan_in_out = list(zip(filters[:-1], filters[1:])) 150 | for i, (input_dim, output_dim) in enumerate(chan_in_out): 151 | is_not_last = i != len(chan_in_out) - 1 152 | block = StyleGan2DiscBlock( 153 | input_dim, 154 | output_dim, 155 | downsample=is_not_last, 156 | apply_spectral_norm=apply_spectral_norm, 157 | use_equalized_lr=use_equalized_lr, 158 | lr_mul=lr_mul, 159 | ) 160 | blocks.append(block) 161 | self.blocks = nn.ModuleList(blocks) 162 | 163 | chan_last = filters[-1] 164 | latent_dim = 2 * 2 * chan_last 165 | 166 | self.final_block = nn.Sequential( 167 | ConvLayer(chan_last, chan_last, 3, padding=1), 168 | Rearrange("b c h w -> b (c h w)"), 169 | LinearLayer(latent_dim, 1), 170 | ) 171 | 172 | def forward(self, input): 173 | if self.apply_grad_norm: 174 | input.requires_grad_(True) 175 | results = [input] 176 | for block in self.blocks: 177 | results.append(block(results[-1])) 178 | results.append(self.final_block(results[-1])) 179 | if self.keep_intermediate_results: 180 | return results[1:] 181 | else: 182 | return results[-1] 183 | 184 | 185 | class StyleGan2DiscBlock(nn.Module): 186 | """ 187 | Taken from lucidrain repo https://github.com/lucidrains/stylegan2-pytorch 188 | 189 | """ 190 | 191 | def __init__( 192 | self, 193 | in_channels, 194 | out_channels, 195 | downsample=True, 196 | apply_spectral_norm=False, 197 | use_equalized_lr=False, 198 | lr_mul=1, 199 | ): 200 | super().__init__() 201 | spectral_norm_op = spectral_norm if apply_spectral_norm else nn.Identity() 202 | ConvLayer = ( 203 | partial(EqualConv2d, lr_mul=lr_mul) if use_equalized_lr else nn.Conv2d 204 | ) 205 | self.conv_res = spectral_norm_op( 206 | ConvLayer(in_channels, out_channels, 1, stride=(2 if downsample else 1)) 207 | ) 208 | self.net = nn.Sequential( 209 | spectral_norm_op(ConvLayer(in_channels, out_channels, 3, padding=1)), 210 | nn.LeakyReLU(0.2), 211 | spectral_norm_op(ConvLayer(out_channels, out_channels, 3, padding=1)), 212 | nn.LeakyReLU(0.2), 213 | ) 214 | 215 | self.downsample = ( 216 | nn.Sequential( 217 | # Blur(), 218 | spectral_norm_op( 219 | ConvLayer(out_channels, out_channels, 3, padding=1, stride=2) 220 | ), 221 | ) 222 | if downsample 223 | else None 224 | ) 225 | 226 | def forward(self, x): 227 | res = self.conv_res(x) 228 | x = self.net(x) 229 | if self.downsample is not None: 230 | x = self.downsample(x) 231 | x = (x + res) / sqrt(2) 232 | return x 233 | 234 | 235 | class Blur(nn.Module): 236 | """ 237 | Taken from lucidrain repo https://github.com/lucidrains/stylegan2-pytorch 238 | 239 | """ 240 | 241 | def __init__(self): 242 | super().__init__() 243 | f = torch.Tensor([1, 2, 1]) 244 | self.register_buffer("f", f) 245 | 246 | def forward(self, x): 247 | f = self.f 248 | f = f[None, None, :] * f[None, :, None] 249 | return filter2d(x, f, normalized=True) 250 | -------------------------------------------------------------------------------- /models/encoders/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/models/encoders/__init__.py -------------------------------------------------------------------------------- /models/encoders/groupdnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.utils.spectral_norm as spectral_norm 5 | from models.utils_blocks.base import BaseNetwork 6 | 7 | class GroupDNetStyleEncoder(BaseNetwork): 8 | def __init__(self, num_labels): 9 | super().__init__() 10 | self.num_labels = num_labels 11 | kw = 3 12 | pw = 1 13 | ndf = 32 * num_labels 14 | self.layer1 = nn.Sequential( 15 | spectral_norm( 16 | nn.Conv2d( 17 | 3 * num_labels, ndf, kw, stride=2, padding=pw, groups=num_labels 18 | ) 19 | ), 20 | nn.InstanceNorm2d(ndf, affine=False), 21 | ) 22 | self.layer2 = nn.Sequential( 23 | spectral_norm( 24 | nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw, groups=num_labels) 25 | ), 26 | nn.InstanceNorm2d(ndf * 2, affine=False), 27 | ) 28 | self.layer3 = nn.Sequential( 29 | spectral_norm( 30 | nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw, groups=num_labels) 31 | ), 32 | nn.InstanceNorm2d(ndf * 4, affine=False), 33 | ) 34 | self.layer4 = nn.Sequential( 35 | spectral_norm( 36 | nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw, groups=num_labels) 37 | ), 38 | nn.InstanceNorm2d(ndf * 8, affine=False), 39 | ) 40 | self.layer5 = nn.Sequential( 41 | spectral_norm( 42 | nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw, groups=num_labels) 43 | ), 44 | nn.InstanceNorm2d(ndf * 8, affine=False), 45 | ) 46 | self.layer6 = nn.Sequential( 47 | spectral_norm( 48 | nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw, groups=num_labels) 49 | ), 50 | nn.InstanceNorm2d(ndf * 8, affine=False), 51 | ) 52 | 53 | self.so = s0 = 4 54 | self.fc_mu = nn.Conv2d( 55 | ndf * 8, num_labels * 8, kw, stride=1, padding=pw, groups=num_labels 56 | ) 57 | self.fc_var = nn.Conv2d( 58 | ndf * 8, num_labels * 8, kw, stride=1, padding=pw, groups=num_labels 59 | ) 60 | 61 | self.actvn = nn.LeakyReLU(0.2, False) 62 | 63 | def trans_img(self, input_semantics, real_image): 64 | images = None 65 | seg_range = input_semantics.size()[1] 66 | for i in range(input_semantics.size(0)): 67 | resize_image = None 68 | for n in range(0, seg_range): 69 | seg_image = real_image[i] * input_semantics[i][n] 70 | # resize seg_image 71 | c_sum = seg_image.sum(dim=0) 72 | y_seg = c_sum.sum(dim=0) 73 | x_seg = c_sum.sum(dim=1) 74 | y_id = y_seg.nonzero() 75 | if y_id.size()[0] == 0: 76 | seg_image = seg_image.unsqueeze(dim=0) 77 | # resize_image = torch.cat((resize_image, seg_image), dim=0) 78 | if resize_image is None: 79 | resize_image = seg_image 80 | else: 81 | resize_image = torch.cat((resize_image, seg_image), dim=1) 82 | continue 83 | # print(y_id) 84 | y_min = y_id[0][0] 85 | y_max = y_id[-1][0] 86 | x_id = x_seg.nonzero() 87 | x_min = x_id[0][0] 88 | x_max = x_id[-1][0] 89 | seg_image = seg_image.unsqueeze(dim=0) 90 | seg_image = F.interpolate( 91 | seg_image[:, :, x_min : x_max + 1, y_min : y_max + 1], 92 | size=[256, 256], 93 | ) 94 | if resize_image is None: 95 | resize_image = seg_image 96 | else: 97 | resize_image = torch.cat((resize_image, seg_image), dim=1) 98 | if images is None: 99 | images = resize_image 100 | else: 101 | images = torch.cat((images, resize_image), dim=0) 102 | return images 
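# trans_img (above) prepares the grouped-convolution input: for every semantic class the
# image is masked by that class, the tight bounding box of the masked region is cropped
# and resized to 256x256, and the per-class crops are concatenated along the channel axis
# (3 * num_labels channels in total); classes absent from the image contribute an all-zero slice.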
103 | 104 | def forward(self, image, segmap=None): 105 | image = self.trans_img(segmap, image) 106 | image = self.layer1(image) 107 | image = self.layer2(self.actvn(image)) 108 | image = self.layer3(self.actvn(image)) 109 | image = self.layer4(self.actvn(image)) 110 | image = self.layer5(self.actvn(image)) 111 | image = self.layer6(self.actvn(image)) 112 | 113 | image = self.actvn(image) 114 | 115 | mu = self.fc_mu(image) 116 | logvar = self.fc_var(image) 117 | 118 | return [mu, logvar], None -------------------------------------------------------------------------------- /models/encoders/inade.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.utils_blocks.base import BaseNetwork 4 | from utils.partial_conv import InstanceAwareConv2d 5 | 6 | class InstanceAdaptiveEncoder(BaseNetwork): 7 | def __init__(self, num_labels, noise_dim, use_vae=True): 8 | super().__init__() 9 | kw = 3 10 | pw = 1 11 | ndf = 64 12 | conv_layer = InstanceAwareConv2d 13 | 14 | self.layer1 = conv_layer(3, ndf, kw, stride=2, padding=pw) 15 | self.norm1 = nn.InstanceNorm2d(ndf) 16 | self.layer2 = conv_layer(ndf * 1, ndf * 2, kw, stride=2, padding=pw) 17 | self.norm2 = nn.InstanceNorm2d(ndf * 2) 18 | self.layer3 = conv_layer(ndf * 2, ndf * 4, kw, stride=2, padding=pw) 19 | self.norm3 = nn.InstanceNorm2d(ndf * 4) 20 | self.layer4 = conv_layer(ndf * 4, ndf * 8, kw, stride=2, padding=pw) 21 | self.norm4 = nn.InstanceNorm2d(ndf * 8) 22 | 23 | self.middle = conv_layer(ndf * 8, ndf * 4, kw, stride=1, padding=pw) 24 | self.norm_middle = nn.InstanceNorm2d(ndf * 4) 25 | self.up1 = conv_layer(ndf * 8, ndf * 2, kw, stride=1, padding=pw) 26 | self.norm_up1 = nn.InstanceNorm2d(ndf * 2) 27 | self.up2 = conv_layer(ndf * 4, ndf * 1, kw, stride=1, padding=pw) 28 | self.norm_up2 = nn.InstanceNorm2d(ndf) 29 | self.up3 = conv_layer(ndf * 2, ndf, kw, stride=1, padding=pw) 30 | self.norm_up3 = nn.InstanceNorm2d(ndf) 31 | 32 | self.up = nn.Upsample(scale_factor=2, mode="bilinear") 33 | self.num_labels = num_labels 34 | 35 | self.scale_conv_mu = conv_layer(ndf, noise_dim, kw, stride=1, padding=pw) 36 | self.scale_conv_var = conv_layer(ndf, noise_dim, kw, stride=1, padding=pw) 37 | self.bias_conv_mu = conv_layer(ndf, noise_dim, kw, stride=1, padding=pw) 38 | self.bias_conv_var = conv_layer(ndf, noise_dim, kw, stride=1, padding=pw) 39 | 40 | self.actvn = nn.LeakyReLU(0.2, False) 41 | 42 | def instAvgPooling(self, x, instances): 43 | inst_num = instances.size()[1] 44 | for i in range(inst_num): 45 | inst_mask = torch.unsqueeze(instances[:, i, :, :], 1) # [n,1,h,w] 46 | pixel_num = torch.sum( 47 | torch.sum(inst_mask, dim=2, keepdim=True), dim=3, keepdim=True 48 | ) 49 | pixel_num[pixel_num == 0] = 1 50 | feat = x * inst_mask 51 | feat = ( 52 | torch.sum(torch.sum(feat, dim=2, keepdim=True), dim=3, keepdim=True) 53 | / pixel_num 54 | ) 55 | if i == 0: 56 | out = torch.unsqueeze(feat[:, :, 0, 0], 1) # [n,1,c] 57 | else: 58 | out = torch.cat([out, torch.unsqueeze(feat[:, :, 0, 0], 1)], 1) 59 | # inst_pool_feats.append(feat[:,:,0,0]) # [n, 64] 60 | return out 61 | 62 | def forward(self, real_image, input_semantics): 63 | # instances [n,1,h,w], input_instances [n,inst_nc,h,w] 64 | instances = torch.argmax(input_semantics, 1, keepdim=True).float() 65 | x1 = self.actvn(self.norm1(self.layer1(real_image, instances))) 66 | x2 = self.actvn(self.norm2(self.layer2(x1, instances))) 67 | x3 = self.actvn(self.norm3(self.layer3(x2, instances))) 68 | x4 = 
self.actvn(self.norm4(self.layer4(x3, instances))) 69 | y = self.up(self.actvn(self.norm_middle(self.middle(x4, instances)))) 70 | y1 = self.up( 71 | self.actvn(self.norm_up1(self.up1(torch.cat([y, x3], 1), instances))) 72 | ) 73 | y2 = self.up( 74 | self.actvn(self.norm_up2(self.up2(torch.cat([y1, x2], 1), instances))) 75 | ) 76 | y3 = self.up( 77 | self.actvn(self.norm_up3(self.up3(torch.cat([y2, x1], 1), instances))) 78 | ) 79 | 80 | scale_mu = self.scale_conv_mu(y3, instances) 81 | scale_var = self.scale_conv_var(y3, instances) 82 | bias_mu = self.bias_conv_mu(y3, instances) 83 | bias_var = self.bias_conv_var(y3, instances) 84 | 85 | scale_mus = self.instAvgPooling(scale_mu, input_semantics) 86 | scale_vars = self.instAvgPooling(scale_var, input_semantics) 87 | bias_mus = self.instAvgPooling(bias_mu, input_semantics) 88 | bias_vars = self.instAvgPooling(bias_var, input_semantics) 89 | 90 | return (scale_mus, scale_vars, bias_mus, bias_vars), None 91 | -------------------------------------------------------------------------------- /models/encoders/sat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from einops import repeat, rearrange 6 | 7 | from torchtyping import TensorType 8 | from typing import Tuple 9 | from collections import OrderedDict 10 | from models.utils_blocks.base import BaseNetwork 11 | 12 | 13 | from models.utils_blocks.attention import ( 14 | MaskedTransformer, 15 | SinusoidalPositionalEmbedding, 16 | ) 17 | 18 | from models.utils_blocks.equallr import EqualConv2d 19 | from functools import partial 20 | 21 | 22 | class SemanticAttentionTransformerEncoder(BaseNetwork): 23 | """ 24 | Encoder that encode style vectors for each segmentation labels doing regional average pooling 25 | 26 | Parameters: 27 | ----------- 28 | num_input_channels: int, 29 | Number of input channels 30 | latent_dim: int, 31 | Number of output channels (style_dim) 32 | num_features_fst_conv: iny, 33 | Number of kernels at first conv 34 | 35 | """ 36 | 37 | def __init__( 38 | self, 39 | num_input_channels: int, 40 | positional_embedding_dim: int, 41 | num_labels: int, 42 | num_latent_per_labels: int, 43 | latent_dim: int, 44 | num_blocks: int, 45 | attention_latent_dim: int, 46 | num_cross_heads: int, 47 | num_self_heads: int, 48 | type_of_initial_latents: str = "learned", 49 | image_conv: bool = False, 50 | reverse_conv: bool = False, 51 | num_latents_bg: int = None, 52 | conv_features_dim_first: int = 16, 53 | content_dim: int = 512, 54 | use_vae: bool = False, 55 | use_self_attention: bool = True, 56 | use_equalized_lr: bool = False, 57 | lr_mul: float = 1.0, 58 | ): 59 | super().__init__() 60 | nf = conv_features_dim_first 61 | if num_latents_bg is None: 62 | num_latents_bg = num_latent_per_labels 63 | if image_conv: 64 | self.image_pos_embs = nn.ModuleList( 65 | [ 66 | SinusoidalPositionalEmbedding( 67 | positional_embedding_dim, emb_type="concat" 68 | ) 69 | ] 70 | ) 71 | self.convs = nn.ModuleList([nn.Identity()]) 72 | 73 | else: 74 | self.image_pos_emb = SinusoidalPositionalEmbedding( 75 | positional_embedding_dim, emb_type="concat" 76 | ) 77 | ConvLayer = ( 78 | partial(EqualConv2d, lr_mul=lr_mul) if use_equalized_lr else nn.Conv2d 79 | ) 80 | self.lr_mul = lr_mul 81 | self.reverse_conv = reverse_conv 82 | self.image_conv = image_conv 83 | self.use_self_attention = use_self_attention 84 | 85 | image_emb_dim = num_input_channels + positional_embedding_dim 86 | 87 | 
num_latents = (num_labels - 1) * num_latent_per_labels + num_latents_bg 88 | self.num_latent_per_labels = num_latent_per_labels 89 | self.num_latents_bg = num_latents_bg 90 | 91 | self.return_attention = False 92 | 93 | latents_mask = [ 94 | torch.FloatTensor( 95 | [[1.0 for _ in range(num_latents_bg)] for _ in range(num_latents_bg)] 96 | ) 97 | ] + [ 98 | torch.FloatTensor( 99 | [ 100 | [1.0 for _ in range(num_latent_per_labels)] 101 | for _ in range(num_latent_per_labels) 102 | ] 103 | ) 104 | for _ in range(num_labels - 1) 105 | ] 106 | 107 | latents_mask = torch.block_diag(*latents_mask).unsqueeze(0) 108 | self.register_buffer("latents_mask", latents_mask) 109 | 110 | self.type_of_initial_latents = type_of_initial_latents 111 | self.latent_dim = latent_dim 112 | self.num_latents = num_latents 113 | if self.type_of_initial_latents == "learned": 114 | self.latents = nn.Parameter( 115 | torch.randn(num_latents, latent_dim).div_(lr_mul) 116 | ) 117 | elif self.type_of_initial_latents == "fixed_random": 118 | self.latents = torch.randn(num_latents, latent_dim) 119 | 120 | self.backbone = nn.ModuleDict(OrderedDict()) 121 | 122 | for i in range(num_blocks): 123 | module = nn.ModuleDict(OrderedDict()) 124 | if reverse_conv: 125 | module.update( 126 | { 127 | "cross_attention": MaskedTransformer( 128 | latent_dim, 129 | num_latents, 130 | image_emb_dim 131 | if i == num_blocks - 1 or not image_conv 132 | else nf * 2 ** (num_blocks - 1 - i), 133 | attention_latent_dim, 134 | num_cross_heads, 135 | use_equalized_lr=use_equalized_lr, 136 | lr_mul=lr_mul, 137 | ), 138 | } 139 | ) 140 | else: 141 | module.update( 142 | { 143 | "cross_attention": MaskedTransformer( 144 | latent_dim, 145 | num_latents, 146 | image_emb_dim if i == 0 or not image_conv else nf * 2**i, 147 | attention_latent_dim, 148 | num_cross_heads, 149 | use_equalized_lr=use_equalized_lr, 150 | lr_mul=lr_mul, 151 | ), 152 | } 153 | ) 154 | if use_self_attention: 155 | module.update( 156 | { 157 | "self_attention": MaskedTransformer( 158 | latent_dim, 159 | num_latents, 160 | latent_dim, 161 | attention_latent_dim, 162 | num_self_heads, 163 | use_equalized_lr=use_equalized_lr, 164 | lr_mul=lr_mul, 165 | ), 166 | } 167 | ) 168 | self.backbone.update({f"block_{i}": module}) 169 | if i > 0 and image_conv: 170 | self.convs.append( 171 | nn.Sequential( 172 | ConvLayer( 173 | num_input_channels if i == 1 else nf * 2 ** (i - 1), 174 | nf * 2**i, 175 | kernel_size=3, 176 | padding=1, 177 | stride=2, 178 | ), 179 | nn.LeakyReLU(0.2), 180 | ) 181 | ) 182 | self.image_pos_embs.append( 183 | SinusoidalPositionalEmbedding(nf * 2**i, emb_type="add") 184 | ) 185 | self.use_vae = use_vae 186 | 187 | if self.use_vae: 188 | self.vae_mapper = nn.ModuleList([nn.Linear(latent_dim, 2*latent_dim) for _ in range(num_latents)]) 189 | 190 | def forward( 191 | self, 192 | input: TensorType["batch_size", "num_input_channels", "height", "width"], 193 | segmentation_map: TensorType["batch_size", "num_labels", "height", "width"], 194 | ) -> Tuple[ 195 | TensorType["batch_size", "num_segmap_labels", "style_dim"], 196 | TensorType[ 197 | "batch_size", 198 | "output_dim", 199 | "output_fmap_heigth", 200 | "output_fmap_width", 201 | ], 202 | ]: 203 | batch_size = input.shape[0] 204 | if self.type_of_initial_latents == "learned": 205 | latents = repeat(self.latents, " l d -> b l d", b=batch_size) * self.lr_mul 206 | elif self.type_of_initial_latents == "fixed_random": 207 | latents = ( 208 | repeat(self.latents.to(input.device), " l d -> b l d", b=batch_size) 209 | * 
self.lr_mul 210 | ) 211 | elif self.type_of_initial_latents == "random": 212 | latents = torch.randn( 213 | batch_size, self.num_latents, self.latent_dim, device=input.device 214 | ) 215 | 216 | if self.return_attention: 217 | self.image_sizes = [] 218 | 219 | if self.image_conv: 220 | cross_attention_masks = [] 221 | flattened_inputs = [] 222 | for i, conv in enumerate(self.convs): 223 | input = conv(input) 224 | if i == len(self.convs) - 1: 225 | output_fmap = input 226 | 227 | interpolated_segmap = F.interpolate( 228 | segmentation_map, size=input.size()[2:], mode="nearest" 229 | ) 230 | if self.return_attention: 231 | self.image_sizes.append(input.size()[2:]) 232 | cross_attention_masks.append( 233 | rearrange( 234 | torch.cat( 235 | [ 236 | torch.repeat_interleave( 237 | interpolated_segmap[:, 0].unsqueeze(1), 238 | self.num_latents_bg, 239 | dim=1, 240 | ), 241 | torch.repeat_interleave( 242 | interpolated_segmap[:, 1:], 243 | self.num_latent_per_labels, 244 | dim=1, 245 | ), 246 | ], 247 | dim=1, 248 | ), 249 | "b n h w -> b n (h w)", 250 | ) 251 | ) 252 | 253 | flattened_inputs.append( 254 | rearrange(self.image_pos_embs[i](input), " b c h w -> b (h w) c") 255 | ) 256 | if self.reverse_conv: 257 | cross_attention_masks = cross_attention_masks[::-1] 258 | flattened_inputs = flattened_inputs[::-1] 259 | else: 260 | cross_attention_mask = rearrange( 261 | torch.cat( 262 | [ 263 | torch.repeat_interleave( 264 | segmentation_map[:, 0].unsqueeze(1), 265 | self.num_latents_bg, 266 | dim=1, 267 | ), 268 | torch.repeat_interleave( 269 | segmentation_map[:, 1:], 270 | self.num_latent_per_labels, 271 | dim=1, 272 | ), 273 | ], 274 | dim=1, 275 | ), 276 | "b n h w -> b n (h w)", 277 | ) 278 | flattened_input = rearrange( 279 | self.image_pos_emb(input), " b c h w -> b (h w) c" 280 | ) 281 | 282 | for i, block_name in enumerate(self.backbone): 283 | if self.image_conv: 284 | if self.return_attention: 285 | latents, _ = self.backbone[block_name]["cross_attention"]( 286 | latents, 287 | flattened_inputs[i], 288 | cross_attention_masks[i], 289 | return_attention=self.return_attention, 290 | ) 291 | else: 292 | latents = self.backbone[block_name]["cross_attention"]( 293 | latents, flattened_inputs[i], cross_attention_masks[i] 294 | ) 295 | else: 296 | if self.return_attention: 297 | latents, _ = self.backbone[block_name]["cross_attention"]( 298 | latents, 299 | flattened_input, 300 | cross_attention_mask, 301 | return_attention=self.return_attention, 302 | ) 303 | else: 304 | latents = self.backbone[block_name]["cross_attention"]( 305 | latents, flattened_input, cross_attention_mask 306 | ) 307 | if self.use_self_attention: 308 | latents = self.backbone[block_name]["self_attention"]( 309 | latents, latents, self.latents_mask 310 | ) 311 | if self.use_vae: 312 | latents_vae = torch.zeros((*latents.shape[:-1], 2*latents.shape[-1]), device=latents.device) 313 | for i, mapper in enumerate(self.vae_mapper): 314 | latents_vae[:, i, :] = mapper(latents[:, i, :]) 315 | latents = latents_vae.chunk(2, dim=-1) 316 | return latents, None 317 | -------------------------------------------------------------------------------- /models/encoders/sean.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | from torchtyping import TensorType 7 | from collections import OrderedDict 8 | from models.utils_blocks.base import BaseNetwork 9 | 10 | 11 | class RegionalAveragePoolingStyleEncoder(BaseNetwork): 
12 | """ 13 | Encoder that encode style vectors for each segmentation labels doing regional average pooling 14 | 15 | Parameters: 16 | ----------- 17 | num_input_channels: int, 18 | Number of input channels 19 | latent_dim: int, 20 | Number of output channels (style_dim) 21 | num_features_fst_conv: iny, 22 | Number of kernels at first conv 23 | 24 | """ 25 | 26 | def __init__( 27 | self, 28 | num_input_channels: int, 29 | latent_dim: int, 30 | num_features_fst_conv: int = 32, 31 | use_vae: bool = False, 32 | ): 33 | super().__init__() 34 | 35 | nffc = num_features_fst_conv 36 | self.first_layer = nn.Sequential( 37 | nn.ReflectionPad2d(1), 38 | nn.Conv2d(num_input_channels, nffc, kernel_size=3, padding=0), 39 | nn.InstanceNorm2d(nffc), 40 | nn.LeakyReLU(0.2, False), 41 | ) 42 | self.bottleneck = nn.ModuleDict(OrderedDict()) 43 | 44 | for i in range(2): 45 | mult = 2**i 46 | module = nn.Sequential( 47 | nn.Conv2d( 48 | mult * nffc, 49 | 2 * mult * nffc, 50 | kernel_size=3, 51 | stride=2, 52 | padding=1, 53 | ), 54 | nn.InstanceNorm2d(2 * mult * nffc), 55 | nn.LeakyReLU(0.2, False), 56 | ) 57 | self.bottleneck.update({f"down_{i}": module}) 58 | 59 | self.bottleneck.update( 60 | { 61 | f"up_{1}": nn.Sequential( 62 | nn.ConvTranspose2d( 63 | 4 * nffc, 64 | nffc * 8, 65 | kernel_size=3, 66 | stride=2, 67 | padding=1, 68 | output_padding=1, 69 | ), 70 | nn.InstanceNorm2d(4 * nffc), 71 | nn.LeakyReLU(0.2, False), 72 | ) 73 | } 74 | ) 75 | 76 | self.last_layer = nn.Sequential( 77 | nn.ReflectionPad2d(1), 78 | nn.Conv2d(nffc * 8, latent_dim, kernel_size=3, padding=0), 79 | nn.InstanceNorm2d(nffc), 80 | nn.Tanh(), 81 | ) 82 | 83 | self.use_vae = use_vae 84 | if self.use_vae: 85 | self.vae_mapper = nn.Linear(latent_dim, 2*latent_dim) 86 | def forward( 87 | self, 88 | input: TensorType["batch_size", "num_input_channels", "height", "width"], 89 | segmentation_map: TensorType["batch_size", "num_labels", "height", "width"], 90 | ) -> TensorType["batch_size", "num_input_channels", "style_dim"]: 91 | x = self.first_layer(input) 92 | for _, block_name in enumerate(self.bottleneck): 93 | x = self.bottleneck[block_name](x) 94 | x = self.last_layer(x) 95 | segmentation_map = F.interpolate( 96 | segmentation_map, size=x.size()[2:], mode="nearest" 97 | ) 98 | (batch_size, style_dim, *_) = x.shape 99 | num_labels = segmentation_map.shape[1] 100 | style_codes = torch.zeros(batch_size, num_labels, style_dim, device=x.device) 101 | for i in range(num_labels): 102 | num_components = segmentation_map[:, i].unsqueeze(1).sum((2, 3)) 103 | num_components[num_components == 0] = 1 104 | style_codes[:, i] = (segmentation_map[:, i].unsqueeze(1) * x).sum( 105 | (2, 3) 106 | ) / num_components 107 | if self.use_vae: 108 | style_codes = self.vae_mapper(style_codes) 109 | style_codes = style_codes.chunk(2, dim=-1) 110 | return style_codes, None 111 | -------------------------------------------------------------------------------- /models/encoders/spade.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch.nn.utils.spectral_norm as spectral_norm 4 | from models.utils_blocks.base import BaseNetwork 5 | 6 | 7 | class SPADEStyleEncoder(BaseNetwork): 8 | """ 9 | Encoder that encode one style vector for the whole image as done in SPADE. 
10 | 11 | Parameters: 12 | ----------- 13 | 14 | 15 | """ 16 | def __init__(self, use_vae=True): 17 | super().__init__() 18 | kw = 3 19 | pw = 1 20 | ndf = 64 21 | self.layer1 = nn.Sequential( 22 | spectral_norm(nn.Conv2d(3, ndf, kw, stride=2, padding=pw)), 23 | nn.InstanceNorm2d(ndf, affine=False), 24 | ) 25 | self.layer2 = nn.Sequential( 26 | spectral_norm(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw)), 27 | nn.InstanceNorm2d(ndf * 2, affine=False), 28 | ) 29 | self.layer3 = nn.Sequential( 30 | spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw)), 31 | nn.InstanceNorm2d(ndf * 4, affine=False), 32 | ) 33 | self.layer4 = nn.Sequential( 34 | spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw)), 35 | nn.InstanceNorm2d(ndf * 8, affine=False), 36 | ) 37 | self.layer5 = nn.Sequential( 38 | spectral_norm(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)), 39 | nn.InstanceNorm2d(ndf * 8, affine=False), 40 | ) 41 | self.layer6 = nn.Sequential( 42 | spectral_norm(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)), 43 | nn.InstanceNorm2d(ndf * 8, affine=False), 44 | ) 45 | 46 | self.so = s0 = 4 47 | self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256) 48 | self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256) 49 | 50 | self.actvn = nn.LeakyReLU(0.2, False) 51 | 52 | def forward(self, image, segmap=None): 53 | if image.shape[2] != 256 or image.shape[3] != 256: 54 | image = F.interpolate(image, size=(256, 256), mode="bilinear") 55 | image = self.layer1(image) 56 | image = self.layer2(self.actvn(image)) 57 | image = self.layer3(self.actvn(image)) 58 | image = self.layer4(self.actvn(image)) 59 | image = self.layer5(self.actvn(image)) 60 | image = self.layer6(self.actvn(image)) 61 | 62 | image = self.actvn(image) 63 | 64 | image = image.view(image.size(0), -1) 65 | mu = self.fc_mu(image) 66 | logvar = self.fc_var(image) 67 | 68 | return [mu, logvar], None 69 | -------------------------------------------------------------------------------- /models/generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/models/generators/__init__.py -------------------------------------------------------------------------------- /models/utils_blocks/EMA.py: -------------------------------------------------------------------------------- 1 | class EMA: 2 | def __init__(self, beta): 3 | super().__init__() 4 | self.beta = beta 5 | 6 | def update_average(self, old, new): 7 | if old is None: 8 | return new 9 | return old * self.beta + (1 - self.beta) * new 10 | -------------------------------------------------------------------------------- /models/utils_blocks/SAM.py: -------------------------------------------------------------------------------- 1 | ### Taken from https://github.com/davda54/sam/blob/main/sam.py 2 | import torch 3 | 4 | 5 | class SAM(torch.optim.Optimizer): 6 | def __init__(self, params, rho=0.05, adaptive=False, **kwargs): 7 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 8 | 9 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 10 | super(SAM, self).__init__(params, defaults) 11 | 12 | self.base_optimizer = torch.optim.AdamW(self.param_groups, **kwargs) 13 | self.param_groups = self.base_optimizer.param_groups 14 | 15 | @torch.no_grad() 16 | def first_step(self, zero_grad=False): 17 | grad_norm = self._grad_norm() 18 | for group in self.param_groups: 19 | scale = group["rho"] / (grad_norm + 1e-12) 20 | 21 | 
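# First SAM step: perturb every parameter by e(w) = rho * grad / ||grad|| (element-wise
# scaled by p^2 when adaptive), keeping a copy of the original weights so that
# second_step() can restore them before the actual optimizer update.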
for p in group["params"]: 22 | if p.grad is None: 23 | continue 24 | self.state[p]["old_p"] = p.data.clone() 25 | e_w = ( 26 | (torch.pow(p, 2) if group["adaptive"] else 1.0) 27 | * p.grad 28 | * scale.to(p) 29 | ) 30 | p.add_(e_w) # climb to the local maximum "w + e(w)" 31 | 32 | if zero_grad: 33 | self.zero_grad() 34 | 35 | @torch.no_grad() 36 | def second_step(self, zero_grad=False): 37 | for group in self.param_groups: 38 | for p in group["params"]: 39 | if p.grad is None: 40 | continue 41 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 42 | 43 | self.base_optimizer.step() # do the actual "sharpness-aware" update 44 | 45 | if zero_grad: 46 | self.zero_grad() 47 | 48 | @torch.no_grad() 49 | def step(self, closure=None): 50 | assert ( 51 | closure is not None 52 | ), "Sharpness Aware Minimization requires closure, but it was not provided" 53 | closure = torch.enable_grad()( 54 | closure 55 | ) # the closure should do a full forward-backward pass 56 | 57 | self.first_step(zero_grad=True) 58 | closure() 59 | self.second_step() 60 | 61 | def _grad_norm(self): 62 | shared_device = self.param_groups[0]["params"][ 63 | 0 64 | ].device # put everything on the same device, in case of model parallelism 65 | norm = torch.norm( 66 | torch.stack( 67 | [ 68 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad) 69 | .norm(p=2) 70 | .to(shared_device) 71 | for group in self.param_groups 72 | for p in group["params"] 73 | if p.grad is not None 74 | ] 75 | ), 76 | p=2, 77 | ) 78 | return norm 79 | 80 | def load_state_dict(self, state_dict): 81 | super().load_state_dict(state_dict) 82 | self.base_optimizer.param_groups = self.param_groups 83 | -------------------------------------------------------------------------------- /models/utils_blocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/models/utils_blocks/__init__.py -------------------------------------------------------------------------------- /models/utils_blocks/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | from torchtyping import TensorType 7 | from einops import repeat, rearrange 8 | from models.utils_blocks.equallr import EqualLinear 9 | from functools import partial 10 | 11 | 12 | class MaskedAttention(nn.Module): 13 | """ 14 | Masked Attention Module. 
Can be used for both self attention and cross attention 15 | 16 | to_dim: int, 17 | The dimension of the query token dim 18 | from_dim: int, 19 | The Dimension of the key token dim 20 | num_heads: int, 21 | Number of attention heads 22 | """ 23 | 24 | def __init__( 25 | self, 26 | to_dim: int, 27 | from_dim: int, 28 | latent_dim: int, 29 | num_heads: int, 30 | use_equalized_lr: bool = False, 31 | lr_mul: float = 1.0, 32 | ): 33 | super().__init__() 34 | self.latent_dim_k = latent_dim // num_heads 35 | self.num_heads = num_heads 36 | LinearLayer = ( 37 | partial(EqualLinear, lr_mul=lr_mul) if use_equalized_lr else nn.Linear 38 | ) 39 | 40 | # Mappings Query, Key Values 41 | self.q_linear = LinearLayer(to_dim, latent_dim) 42 | self.v_linear = LinearLayer(from_dim, latent_dim) 43 | self.k_linear = LinearLayer(from_dim, latent_dim) 44 | 45 | # Final output mapping 46 | self.out = nn.Sequential(LinearLayer(latent_dim, to_dim)) 47 | 48 | def forward( 49 | self, 50 | X_to: TensorType["batch_size", "num_to_tokens", "to_dim"], 51 | X_from: TensorType["batch_size", "num_from_tokens", "from_dim"], 52 | mask_from: TensorType["batch_size", "num_to_tokens", "num_from_tokens"] = None, 53 | return_attention: bool = False, 54 | ): 55 | Q = rearrange(self.q_linear(X_to), " b t (h k) -> b h t k ", h=self.num_heads) 56 | K = rearrange(self.v_linear(X_from), " b t (h k) -> b h t k ", h=self.num_heads) 57 | V = rearrange(self.k_linear(X_from), " b t (h k) -> b h t k ", h=self.num_heads) 58 | 59 | attn = torch.einsum("bhtk,bhfk->bhtf", [Q, K]) / math.sqrt(self.latent_dim_k) 60 | 61 | if mask_from is not None: 62 | mask_from = mask_from.unsqueeze(1) 63 | attn = attn.masked_fill(mask_from == 0, -1e4) 64 | 65 | attn = F.softmax(attn, dim=-1) 66 | 67 | output = torch.einsum("bhtf,bhfk->bhtk", [attn, V]) 68 | output = rearrange(output, "b h t k -> b t (h k)") 69 | output = self.out(output) 70 | 71 | if return_attention: 72 | return output, attn 73 | else: 74 | return output 75 | 76 | 77 | class MaskedTransformer(nn.Module): 78 | def __init__( 79 | self, 80 | to_dim, 81 | to_len, 82 | from_dim, 83 | latent_dim, 84 | num_heads, 85 | use_equalized_lr=False, 86 | lr_mul=1, 87 | ): 88 | super().__init__() 89 | self.attention = MaskedAttention( 90 | to_dim, 91 | from_dim, 92 | latent_dim, 93 | num_heads, 94 | use_equalized_lr=use_equalized_lr, 95 | lr_mul=lr_mul, 96 | ) 97 | 98 | LinearLayer = ( 99 | partial(EqualLinear, lr_mul=lr_mul) if use_equalized_lr else nn.Linear 100 | ) 101 | self.ln_1 = nn.LayerNorm((to_len, to_dim)) 102 | self.fc = nn.Sequential( 103 | LinearLayer(to_dim, to_dim), 104 | nn.LeakyReLU(2e-1), 105 | LinearLayer(to_dim, to_dim), 106 | nn.LeakyReLU(2e-1), 107 | ) 108 | self.ln_2 = nn.LayerNorm((to_len, to_dim)) 109 | 110 | def forward(self, X_to, X_from, mask=None, return_attention=False): 111 | if return_attention: 112 | X_to_out, attn = self.attention(X_to, X_from, mask, return_attention) 113 | else: 114 | X_to_out = self.attention(X_to, X_from, mask, return_attention) 115 | X_to = self.ln_1(X_to_out + X_to) 116 | X_to_out = self.fc(X_to) 117 | X_to = self.ln_2(X_to_out + X_to) 118 | if return_attention: 119 | return X_to, attn 120 | else: 121 | return X_to 122 | 123 | 124 | class SinusoidalPositionalEmbedding(nn.Module): 125 | """ 126 | Sinusoidal Positional Encoder 127 | Adapted from https://github.com/lucidrains/transganformer 128 | 129 | dim: int, 130 | Tokens dimension 131 | """ 132 | 133 | def __init__(self, dim: int, emb_type: str = "add"): 134 | super().__init__() 135 | dim //= 2 136 | 
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 137 | self.emb_type = emb_type 138 | self.register_buffer("inv_freq", inv_freq) 139 | 140 | def forward(self, x: TensorType["batch_size", "num_token", "dim"]): 141 | h = torch.linspace(-1.0, 1.0, x.shape[-2], device=x.device).type_as( 142 | self.inv_freq 143 | ) 144 | w = torch.linspace(-1.0, 1.0, x.shape[-1], device=x.device).type_as( 145 | self.inv_freq 146 | ) 147 | sinu_inp_h = torch.einsum("i , j -> i j", h, self.inv_freq) 148 | sinu_inp_w = torch.einsum("i , j -> i j", w, self.inv_freq) 149 | sinu_inp_h = repeat(sinu_inp_h, "h c -> () c h w", w=x.shape[-1]) 150 | sinu_inp_w = repeat(sinu_inp_w, "w c -> () c h w", h=x.shape[-2]) 151 | sinu_inp = torch.cat((sinu_inp_w, sinu_inp_h), dim=1) 152 | emb = torch.cat((sinu_inp.sin(), sinu_inp.cos()), dim=1) 153 | if self.emb_type == "add": 154 | x_emb = x + emb 155 | elif self.emb_type == "concat": 156 | emb = repeat(emb, "1 ... -> b ...", b=x.shape[0]) 157 | x_emb = torch.cat([x, emb], dim=1) 158 | return x_emb 159 | 160 | 161 | class LearnedPositionalEmbedding(nn.Module): 162 | """ 163 | Learned Positional Embedding 164 | 165 | Parameters: 166 | ----------- 167 | num_tokens_max: int, 168 | Max size of the sequence lenght 169 | dim_tokens: int, 170 | Size of the embedding dim 171 | """ 172 | 173 | def __init__(self, num_tokens_max: int, dim_tokens: int): 174 | super().__init__() 175 | self.num_tokens_max = num_tokens_max 176 | self.dim_tokens = dim_tokens 177 | self.weights = nn.Parameter(torch.Tensor(num_tokens_max, dim_tokens)) 178 | 179 | def forward(self, x: TensorType["batch_size", "num_tokens", "dim_tokens"]): 180 | _, num_tokens = x.shape[:2] 181 | assert num_tokens <= self.num_tokens_max 182 | return x + self.weights[:num_tokens].view(1, num_tokens, self.dim_tokens) 183 | -------------------------------------------------------------------------------- /models/utils_blocks/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BaseNetwork(nn.Module): 5 | def __init__(self): 6 | super(BaseNetwork, self).__init__() 7 | 8 | def init_weight(self, do_init: bool = True): 9 | def init_func(m): 10 | classname = m.__class__.__name__ 11 | if classname.find("BatchNorm2d") != -1: 12 | if hasattr(m, "weight") and m.weight is not None: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | if hasattr(m, "bias") and m.bias is not None: 15 | nn.init.constant_(m.bias.data, 0.0) 16 | elif hasattr(m, "weight") and ( 17 | classname.find("Conv") != -1 or classname.find("Linear") != -1 18 | ): 19 | nn.init.xavier_normal_(m.weight.data, gain=0.02) 20 | 21 | if do_init: 22 | self.apply(init_func) 23 | 24 | for m in self.children(): 25 | if hasattr(m, "init_weights"): 26 | m.init_weights(do_init) 27 | -------------------------------------------------------------------------------- /models/utils_blocks/equallr.py: -------------------------------------------------------------------------------- 1 | ## Modified from https://github.com/rosinality/stylegan2-pytorch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from math import sqrt 5 | import torch 6 | 7 | 8 | class EqualConv2d(nn.Module): 9 | def __init__( 10 | self, 11 | in_channels, 12 | out_channels, 13 | kernel_size, 14 | stride=1, 15 | padding=0, 16 | bias=True, 17 | lr_mul=1, 18 | ): 19 | super().__init__() 20 | 21 | self.weight = nn.Parameter( 22 | torch.randn(out_channels, in_channels, kernel_size, kernel_size).div_( 23 | lr_mul 24 | ), 25 | 
requires_grad=True, 26 | ) 27 | # torch.nn.init.kaiming_uniform_(self.weight, a=sqrt(5)) 28 | self.scale = (1 / sqrt(in_channels * kernel_size**2)) * lr_mul 29 | self.stride = stride 30 | self.padding = padding 31 | self.lr_mul = lr_mul 32 | 33 | if bias: 34 | self.bias = nn.Parameter(torch.zeros(out_channels), requires_grad=True) 35 | 36 | else: 37 | self.bias = None 38 | 39 | def forward(self, input): 40 | out = F.conv2d( 41 | input, 42 | self.weight * self.scale, 43 | bias=self.bias * self.lr_mul, 44 | stride=self.stride, 45 | padding=self.padding, 46 | ) 47 | 48 | return out 49 | 50 | def __repr__(self): 51 | return ( 52 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 53 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 54 | ) 55 | 56 | 57 | class EqualLinear(nn.Module): 58 | def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1): 59 | super().__init__() 60 | 61 | self.weight = nn.Parameter( 62 | torch.randn(out_dim, in_dim).div_(lr_mul), requires_grad=True 63 | ) 64 | # torch.nn.init.kaiming_uniform_(self.weight, a=sqrt(5)) 65 | 66 | if bias: 67 | self.bias = nn.Parameter( 68 | torch.zeros(out_dim).fill_(bias_init), requires_grad=True 69 | ) 70 | 71 | else: 72 | self.bias = None 73 | self.scale = (1 / sqrt(in_dim)) * lr_mul 74 | self.lr_mul = lr_mul 75 | 76 | def forward(self, input): 77 | out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) 78 | return out 79 | 80 | def __repr__(self): 81 | return ( 82 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 83 | ) 84 | -------------------------------------------------------------------------------- /preprocess_fid_features.py: -------------------------------------------------------------------------------- 1 | from metrics.fid import compute_fid_features 2 | from data.datamodule import ImageDataModule 3 | import hydra 4 | from pathlib import Path 5 | 6 | 7 | @hydra.main(config_path="configs", config_name="config") 8 | def compute_fid_for_dataset(cfg): 9 | datamodule = ImageDataModule(cfg.dataset) 10 | datamodule.setup() 11 | train_path = Path(cfg.dataset.path) / Path("train/stats") 12 | train_path.mkdir(parents=True, exist_ok=True) 13 | print(f"Computing FID features for train set and saving to {train_path}") 14 | compute_fid_features(datamodule.train_dataloader(), train_path, device=cfg.compnode.accelerator) 15 | 16 | test_path = Path(cfg.dataset.path) / Path("test/stats") 17 | test_path.mkdir(parents=True, exist_ok=True) 18 | print(f"Computing FID features for test set and saving to {test_path}") 19 | compute_fid_features(datamodule.test_dataloader(), test_path, device=cfg.compnode.accelerator) 20 | 21 | if __name__ == "__main__": 22 | compute_fid_for_dataset() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from re import T 2 | from data.datamodule import ImageDataModule 3 | from gan import Geppetto 4 | import hydra 5 | import shutil 6 | import wandb 7 | import os 8 | from utils.callbacks import GANImageLog, StopIfNaN, LogAttention 9 | from pytorch_lightning.utilities.rank_zero import _get_rank 10 | 11 | from pathlib import Path 12 | 13 | from omegaconf import OmegaConf 14 | 15 | 16 | @hydra.main(config_path="configs", config_name="config") 17 | def run(cfg): 18 | 19 | # print(OmegaConf.to_yaml(cfg, resolve=True)) 20 | 21 | dict_config = OmegaConf.to_container(cfg, resolve=True) 22 | 23 
| Path(cfg.checkpoints.dirpath).mkdir(parents=True, exist_ok=True) 24 | 25 | shutil.copyfile(".hydra/config.yaml", f"{cfg.checkpoints.dirpath}/config.yaml") 26 | 27 | log_dict = {} 28 | 29 | log_dict["model"] = dict_config["model"] 30 | 31 | log_dict["dataset"] = dict_config["dataset"] 32 | 33 | datamodule = ImageDataModule(cfg.dataset) 34 | 35 | # logger.log_hyperparams(dict_config) 36 | 37 | checkpoint_callback = hydra.utils.instantiate(cfg.checkpoints) 38 | 39 | image_log_callback = GANImageLog() 40 | 41 | stop_if_nan_callback = StopIfNaN( 42 | ["train/gen_loss_step", "train/disc_loss_step"] 43 | ) 44 | 45 | progress_bar = hydra.utils.instantiate(cfg.progress_bar) 46 | 47 | callbacks = [ 48 | checkpoint_callback, 49 | image_log_callback, 50 | stop_if_nan_callback, 51 | progress_bar, 52 | ] 53 | if cfg.model.name == "SCAM": 54 | callbacks.append(LogAttention()) 55 | 56 | rank = _get_rank() 57 | 58 | if os.path.isfile(Path(cfg.checkpoints.dirpath) / Path("wandb_id.txt")): 59 | with open( 60 | Path(cfg.checkpoints.dirpath) / Path("wandb_id.txt"), "r" 61 | ) as wandb_id_file: 62 | wandb_id = wandb_id_file.readline() 63 | else: 64 | wandb_id = wandb.util.generate_id() 65 | print(f"generated id{wandb_id}") 66 | if rank == 0: 67 | with open( 68 | Path(cfg.checkpoints.dirpath) / Path("wandb_id.txt"), "w" 69 | ) as wandb_id_file: 70 | wandb_id_file.write(str(wandb_id)) 71 | 72 | if (Path(cfg.checkpoints.dirpath) / Path("last.ckpt")).exists(): 73 | 74 | print("Loading checkpoints") 75 | checkpoint_path = Path(cfg.checkpoints.dirpath) / Path("last.ckpt") 76 | 77 | logger = hydra.utils.instantiate(cfg.logger, id=wandb_id, resume="allow") 78 | model = Geppetto.load_from_checkpoint(checkpoint_path, cfg=cfg.model) 79 | logger.watch(model) 80 | trainer = hydra.utils.instantiate( 81 | cfg.trainer, 82 | strategy=cfg.trainer.strategy, 83 | logger=logger, 84 | callbacks=callbacks, 85 | resume_from_checkpoint=str(checkpoint_path), 86 | ) 87 | else: 88 | logger = hydra.utils.instantiate(cfg.logger, id=wandb_id, resume="allow") 89 | logger._wandb_init.update({"config": log_dict}) 90 | model = Geppetto(cfg.model) 91 | logger.watch(model) 92 | trainer = hydra.utils.instantiate( 93 | cfg.trainer, 94 | strategy=cfg.trainer.strategy, 95 | logger=logger, 96 | callbacks=callbacks, 97 | ) 98 | # trainer.fit_loop.epoch_loop.batch_loop.connect(optimizer_loop=YieldLoop()) 99 | 100 | trainer.fit(model, datamodule) 101 | 102 | trainer.test(model, dataloaders=datamodule, ckpt_path="best") 103 | 104 | 105 | if __name__ == "__main__": 106 | run() 107 | -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/utils/.DS_Store -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicolas-dufour/SCAM/848484acd71c277f3ed48456e206870847dd400d/utils/__init__.py -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pytorch_lightning import Callback 3 | from torch.utils.data import DataLoader 4 | import torch 5 | import torch.nn.functional as F 6 | import wandb 7 | import numpy as np 8 | from einops import 
rearrange 9 | from utils.utils import AttentionVis, get_palette, remap_image_torch 10 | 11 | 12 | log = logging.getLogger(__name__) 13 | 14 | 15 | class ReInitOptimAfterSanity(Callback): 16 | def on_train_start(self, trainer, pl_module): 17 | optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers( 18 | pl_module 19 | ) 20 | trainer.optimizers = optimizers 21 | trainer.lr_schedulers = lr_schedulers 22 | trainer.optimizer_frequencies = optimizer_frequencies 23 | 24 | 25 | class LogAttention(Callback): 26 | def __init__(self, num_samples: int = 8): 27 | super().__init__() 28 | self.num_samples = num_samples 29 | self.ready = True 30 | 31 | def on_sanity_check_start(self, trainer, pl_module): 32 | self.ready = False 33 | 34 | def on_sanity_check_end(self, trainer, pl_module): 35 | """Start executing this callback only after all validation sanity checks end.""" 36 | self.ready = True 37 | 38 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 39 | if trainer.global_step % 25000 == 0: 40 | self.log_attention(trainer, pl_module) 41 | 42 | def on_test_end(self, trainer, pl_module): 43 | self.log_attention(trainer, pl_module) 44 | 45 | def log_attention(self, trainer, pl_module): 46 | if self.ready: 47 | logger = trainer.logger 48 | experiment = logger.experiment 49 | log_encoder = ( 50 | pl_module.cfg.encoder._target_ 51 | == "models.encoder.SemanticAttentionTransformerEncoder" 52 | ) 53 | 54 | attnn_viz = AttentionVis(log_encoder=log_encoder) 55 | 56 | # Create tokens palette 57 | np.random.seed(42) 58 | num_tokens = ( 59 | (pl_module.cfg.generator.num_labels - 1) 60 | * pl_module.cfg.generator.num_labels_split 61 | + pl_module.cfg.generator.num_labels_bg 62 | ) 63 | palette_tokens = torch.from_numpy(get_palette(num_tokens)).to( 64 | device=pl_module.device 65 | ) 66 | # Create labels palette 67 | np.random.seed(10) 68 | palette_segmask = torch.from_numpy( 69 | get_palette(pl_module.cfg.generator.num_labels) 70 | ).to(device=pl_module.device) 71 | 72 | val_dataloader = DataLoader( 73 | trainer.datamodule.test_dataset, 74 | batch_size=self.num_samples, 75 | shuffle=True, 76 | ) 77 | logs = dict() 78 | 79 | real_images, segmentation_maps = next(iter(val_dataloader)) 80 | real_images = real_images.to(device=pl_module.device) 81 | segmentation_maps = segmentation_maps.to(device=pl_module.device) 82 | 83 | segmask_colorized = palette_segmask[segmentation_maps.argmax(1)] 84 | 85 | segmask_colorized = rearrange(segmask_colorized, "b h w c -> b c h w") 86 | with torch.no_grad(): 87 | outputs = attnn_viz.encode_and_generate( 88 | pl_module, real_images, segmentation_maps 89 | ) 90 | output_images = outputs["output"] 91 | del outputs["output"] 92 | for attn_key, attn_val in outputs.items(): 93 | 94 | attentions = [ 95 | F.interpolate( 96 | attention, 97 | size=(output_images.shape[2], output_images.shape[3]), 98 | ) 99 | for attention in attn_val 100 | ] 101 | 102 | attentions_colorized = [ 103 | palette_tokens[attention.argmax(1)] for attention in attentions 104 | ] 105 | 106 | attentions_colorized = [ 107 | rearrange(attention, "b h w c -> b c h w") 108 | for attention in attentions_colorized 109 | ] 110 | 111 | attn_viz = rearrange( 112 | [ 113 | remap_image_torch(real_images), 114 | remap_image_torch(output_images), 115 | segmask_colorized, 116 | *attentions_colorized, 117 | ], 118 | "l b c h w -> b c h (l w)", 119 | ).float() 120 | attn_viz = [ 121 | wandb.Image( 122 | attn_viz[i], 123 | caption="Reference; Reconstruction; Segmask; Attn matrix argmax from 
first layers to last", 124 | ) 125 | for i in range(attn_viz.shape[0]) 126 | ] 127 | logs.update({f"Images/{attn_key}": attn_viz}) 128 | experiment.log(logs, step=trainer.global_step) 129 | 130 | 131 | class GANImageLog(Callback): 132 | def __init__(self, num_samples: int = 8): 133 | super().__init__() 134 | self.num_samples = num_samples 135 | self.ready = True 136 | 137 | def on_sanity_check_start(self, trainer, pl_module): 138 | self.ready = False 139 | 140 | def on_sanity_check_end(self, trainer, pl_module): 141 | """Start executing this callback only after all validation sanity checks end.""" 142 | self.ready = True 143 | 144 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 145 | if trainer.global_step % 25000 == 0: 146 | self.log_images(trainer, pl_module) 147 | 148 | def on_test_end(self, trainer, pl_module): 149 | self.log_images(trainer, pl_module) 150 | 151 | def log_images(self, trainer, pl_module): 152 | if self.ready: 153 | logger = trainer.logger 154 | experiment = logger.experiment 155 | 156 | # get a validation batch from the validation dat loader 157 | val_dataloader = DataLoader( 158 | trainer.datamodule.test_dataset, 159 | batch_size=self.num_samples, 160 | shuffle=True, 161 | ) 162 | 163 | real_images, segmentation_maps = next(iter(val_dataloader)) 164 | real_images = real_images.to(device=pl_module.device) 165 | segmentation_maps = segmentation_maps.to(device=pl_module.device) 166 | 167 | with torch.no_grad(): 168 | 169 | styles_codes, content_code = pl_module.encoder( 170 | real_images, segmentation_maps 171 | ) 172 | reco_images = pl_module.generator( 173 | segmentation_maps, styles_codes, content_code 174 | ) 175 | reco_images = rearrange( 176 | [real_images, reco_images], "l b c h w -> b c h (l w)" 177 | ) 178 | segmask_reco = rearrange( 179 | [segmentation_maps, segmentation_maps], "l b c h w -> b c h (l w)" 180 | ) 181 | segmask_reco = segmask_reco.argmax(dim=1).cpu().numpy() 182 | permutation = torch.randperm(self.num_samples) 183 | if pl_module.cfg.losses.lambda_kld > 0: 184 | permuted_style_codes = [ 185 | styles_codes[i][permutation] for i in range(len(styles_codes)) 186 | ] 187 | else: 188 | permuted_style_codes = styles_codes[permutation] 189 | 190 | swap_images = pl_module.generator( 191 | segmentation_maps, permuted_style_codes, content_code 192 | ) 193 | segmask_swap = rearrange( 194 | [ 195 | segmentation_maps, 196 | segmentation_maps[permutation], 197 | segmentation_maps, 198 | ], 199 | "l b c h w -> b c h (l w)", 200 | ) 201 | segmask_swap = segmask_swap.argmax(dim=1).cpu().numpy() 202 | swap_images = rearrange( 203 | [real_images, real_images[permutation], swap_images], 204 | "l b c h w -> b c h (l w)", 205 | ) 206 | reco_images = [ 207 | wandb.Image( 208 | reco_images[i], 209 | caption="Left: Real; Right: Reconstruction", 210 | masks={ 211 | "Segmentation": { 212 | "mask_data": segmask_reco[i], 213 | } 214 | }, 215 | ) 216 | for i in range(reco_images.shape[0]) 217 | ] 218 | experiment.log( 219 | {"Images/Reconstruction": reco_images}, step=trainer.global_step 220 | ) 221 | 222 | swap_images = [ 223 | wandb.Image( 224 | swap_images[i], 225 | caption="Left: Semantic ref; Middle: style ref; Right: Swap", 226 | masks={ 227 | "Segmentation": { 228 | "mask_data": segmask_swap[i], 229 | } 230 | }, 231 | ) 232 | for i in range(swap_images.shape[0]) 233 | ] 234 | experiment.log({"Images/Swap": swap_images}, step=trainer.global_step) 235 | 236 | 237 | class StopIfNaN(Callback): 238 | def __init__(self, monitor): 239 | 
super().__init__() 240 | self.monitor = monitor 241 | self.continuous_nan_batchs = 0 242 | 243 | def on_train_batch_end( 244 | self, 245 | trainer, 246 | pl_module, 247 | outputs, 248 | batch, 249 | batch_idx, 250 | ) -> None: 251 | logs = trainer.callback_metrics 252 | i = 0 253 | found_metric = False 254 | while i < len(self.monitor) and not found_metric: 255 | if self.monitor[i] in logs.keys(): 256 | current = logs[self.monitor[i]].squeeze() 257 | found_metric = True 258 | else: 259 | i += 1 260 | if not found_metric: 261 | raise ValueError("Asked metric not in logs") 262 | 263 | if not torch.isfinite(current): 264 | self.continuous_nan_batchs += 1 265 | if self.continuous_nan_batchs >= 5: 266 | trainer.should_stop = True 267 | log.info("Training interrupted because of NaN in {self.monitor}") 268 | else: 269 | self.continuous_nan_batchs = 0 270 | 271 | def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx) -> None: 272 | valid_gradients = True 273 | for name, param in pl_module.named_parameters(): 274 | if param.grad is not None: 275 | valid_gradients = not ( 276 | torch.isnan(param.grad).any() or torch.isinf(param.grad).any() 277 | ) 278 | if not valid_gradients: 279 | break 280 | 281 | if not valid_gradients: 282 | log.warning( 283 | f"detected inf or nan values in gradients. not updating model parameters" 284 | ) 285 | optimizer.zero_grad() 286 | -------------------------------------------------------------------------------- /utils/partial_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch import nn 5 | import math 6 | from torch.nn import init 7 | 8 | 9 | class InstanceAwareConv2d(nn.Module): 10 | def __init__(self, fin, fout, kw, stride=1, padding=1): 11 | super().__init__() 12 | self.kw = kw 13 | self.stride = stride 14 | self.padding = padding 15 | self.fin = fin 16 | self.fout = fout 17 | self.unfold = nn.Unfold(kw, stride=stride, padding=padding) 18 | self.weight = nn.Parameter(torch.Tensor(fout, fin, kw, kw)) 19 | self.bias = nn.Parameter(torch.Tensor(fout)) 20 | self.reset_parameters() 21 | 22 | def reset_parameters(self): 23 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 24 | if self.bias is not None: 25 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 26 | bound = 1 / math.sqrt(fan_in) 27 | init.uniform_(self.bias, -bound, bound) 28 | 29 | def forward(self, x, instances, check=False): 30 | N, C, H, W = x.size() 31 | # cal the binary mask from instance map 32 | instances = F.interpolate(instances, x.size()[2:], mode="nearest") # [n,1,h,w] 33 | inst_unf = self.unfold(instances) 34 | # substract the center pixel 35 | center = torch.unsqueeze(inst_unf[:, self.kw * self.kw // 2, :], 1) 36 | mask_unf = inst_unf - center 37 | # clip the absolute value to 0~1 38 | mask_unf = torch.abs(mask_unf) 39 | mask_unf = torch.clamp(mask_unf, 0, 1) 40 | mask_unf = 1.0 - mask_unf # [n,k*k,L] 41 | # multiply mask_unf and x 42 | x_unf = self.unfold(x) # [n,c*k*k,L] 43 | x_unf = x_unf.view(N, C, -1, x_unf.size()[-1]) # [n,c,,k*k,L] 44 | mask = torch.unsqueeze(mask_unf, 1) # [n,1,k*k,L] 45 | mask_x = mask * x_unf # [n,c,k*k,L] 46 | mask_x = mask_x.view(N, -1, mask_x.size()[-1]) # [n,c*k*k,L] 47 | # conv operation 48 | weight = self.weight.view(self.fout, -1) # [fout, c*k*k] 49 | out = torch.einsum("cm,nml->ncl", weight, mask_x) 50 | # x_unf = torch.unsqueeze(x_unf, 1) # [n,1,c*k*k,L] 51 | # out = torch.mul(masked_weight, 
x_unf).sum(dim=2, keepdim=False) # [n,fout,L] 52 | bias = torch.unsqueeze(torch.unsqueeze(self.bias, 0), -1) # [1,fout,1] 53 | out = out + bias 54 | out = out.view(N, self.fout, H // self.stride, W // self.stride) 55 | # print('weight:',self.weight[0,0,...]) 56 | # print('bias:',self.bias) 57 | 58 | if check: 59 | out2 = nn.functional.conv2d( 60 | x, self.weight, self.bias, stride=self.stride, padding=self.padding 61 | ) 62 | print((out - out2).abs().max()) 63 | return out 64 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | 5 | 6 | def remap_image_numpy(image): 7 | 8 | image_numpy = ((image + 1) / 2.0) * 255.0 9 | return np.clip(image_numpy, 0, 255).astype(int) 10 | 11 | 12 | def remap_image_torch(image): 13 | 14 | image_torch = ((image + 1) / 2.0) * 255.0 15 | return torch.clip(image_torch, 0, 255).type(torch.uint8) 16 | 17 | 18 | def get_palette(num_cls): 19 | """Returns the color map for visualizing the segmentation mask. 20 | Args: 21 | num_cls: Number of classes 22 | Returns: 23 | The color map 24 | """ 25 | palette = np.random.randint(0, 255, size=(num_cls, 3)) 26 | return palette 27 | 28 | 29 | class AttentionVis: 30 | def __init__(self, log_encoder=True): 31 | self.latent_to_image_gen_masks = [] 32 | self.image_to_latent_gen_masks = [] 33 | self.log_encoder = log_encoder 34 | if self.log_encoder: 35 | self.encoder_masks = [] 36 | 37 | def encoder_hook_fn(self, m, i, o): 38 | self.encoder_masks.append(o[1]) 39 | 40 | def latent_to_image_gen_hook_fn(self, m, i, o): 41 | self.latent_to_image_gen_masks.append(o[1]) 42 | 43 | def image_to_latent_gen_hook_fn(self, m, i, o): 44 | self.image_to_latent_gen_masks.append(o[1]) 45 | 46 | def encode_and_generate(self, model, images, masks): 47 | ### activate attention 48 | num_gen_blocks = model.cfg.generator.num_up_layers 49 | gen_mod_blocks = [ 50 | getattr(model.generator.backbone[f"SCAM_block_{i}"], f"mod_{j}") 51 | for i in range(num_gen_blocks) 52 | for j in range(2) 53 | ] 54 | 55 | for gen_mod_block in gen_mod_blocks: 56 | gen_mod_block.return_attention = True 57 | if self.log_encoder: 58 | num_enc_blocks = model.cfg.encoder.num_blocks 59 | enc_blocks = [ 60 | model.encoder.backbone[f"block_{i}"]["cross_attention"] 61 | for i in range(num_enc_blocks) 62 | ] 63 | model.encoder.return_attention = True 64 | 65 | handles = [] 66 | for mod in gen_mod_blocks: 67 | handle = mod.latent_to_image.register_forward_hook( 68 | self.latent_to_image_gen_hook_fn 69 | ) 70 | handles.append(handle) 71 | handle = mod.image_to_latent.register_forward_hook( 72 | self.image_to_latent_gen_hook_fn 73 | ) 74 | handles.append(handle) 75 | if self.log_encoder: 76 | for enc_block in enc_blocks: 77 | handle = enc_block.register_forward_hook(self.encoder_hook_fn) 78 | handles.append(handle) 79 | 80 | # ### retrieve attention 81 | output = model.encode_and_generate(images, masks) 82 | gen_mod_dims = [(mod.height, mod.width) for mod in gen_mod_blocks] 83 | if self.log_encoder: 84 | enc_dims = model.encoder.image_sizes 85 | # ### Remove hook 86 | for handle in handles: 87 | handle.remove() 88 | # ### deasctivate attention output 89 | for gen_mod_block in gen_mod_blocks: 90 | gen_mod_block.return_attention = False 91 | if self.log_encoder: 92 | model.encoder.return_attention = False 93 | 94 | latent_to_image_gen_attention_masks = [ 95 | rearrange( 96 | 
attention_mask, 97 | "b heads (h w) c-> b c (heads h) w", 98 | h=h, 99 | w=w, 100 | ) 101 | for (h, w), attention_mask in zip( 102 | gen_mod_dims, self.latent_to_image_gen_masks 103 | ) 104 | ] 105 | image_to_latent_gen_attention_masks = [ 106 | rearrange( 107 | attention_mask, 108 | "b heads c (h w)-> b c (heads h) w", 109 | h=h, 110 | w=w, 111 | ) 112 | for (h, w), attention_mask in zip( 113 | gen_mod_dims, self.image_to_latent_gen_masks 114 | ) 115 | ] 116 | if self.log_encoder: 117 | encoder_attention_masks = [ 118 | rearrange( 119 | attention_mask, 120 | "b heads c (h w)-> b c (heads h) w", 121 | h=h, 122 | w=w, 123 | ) 124 | for (h, w), attention_mask in zip(enc_dims, self.encoder_masks) 125 | ] 126 | self.encoder_masks = [] 127 | self.latent_to_image_gen_masks = [] 128 | self.image_to_latent_gen_masks = [] 129 | 130 | if self.log_encoder: 131 | return { 132 | "output": output, 133 | "latent_to_image_gen_attn": latent_to_image_gen_attention_masks, 134 | "image_to_latent_gen_attn": image_to_latent_gen_attention_masks, 135 | "encoder_attn": encoder_attention_masks, 136 | } 137 | else: 138 | return { 139 | "output": output, 140 | "latent_to_image_gen_attn": latent_to_image_gen_attention_masks, 141 | "image_to_latent_gen_attn": image_to_latent_gen_attention_masks, 142 | } 143 | 144 | # def generate(self, model, style_codes, masks): 145 | # ### activate attention 146 | # model.generator.backbone.SCAM_block_5.mod_1.return_attention = True 147 | # output = model.generator(masks, style_codes) 148 | # self.height = model.generator.backbone.SCAM_block_5.mod_1.height 149 | # self.width = model.generator.backbone.SCAM_block_5.mod_1.width 150 | # ### register hook 151 | # handle = model.generator.backbone.SCAM_block_5.mod_1.latent_to_image.register_forward_hook( 152 | # self.hook_fn 153 | # ) 154 | # ### retrieve attention 155 | # output = model.generator(masks, style_codes) 156 | # ### Remove hook 157 | # handle.remove() 158 | # ### deasctivate attention output 159 | # model.generator.backbone.SCAM_block_4.mod_1.return_attention = False 160 | # return output, self.attention_masks 161 | --------------------------------------------------------------------------------
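
A minimal sketch of how the closure-based `SAM.step` above is meant to be driven from a training loop. This is hypothetical usage code, not a file from the repository: the constructor arguments (`rho`, the wrapped `torch.optim.SGD`) are assumptions borrowed from the common upstream SAM implementation this file is adapted from, so check `models/utils_blocks/SAM.py` for the actual signature. The closure pattern itself follows directly from `step()`, which requires gradients at "w" before perturbing and re-runs the forward/backward pass at "w + e(w)".

```python
# Hypothetical usage sketch, not part of the repository.
import torch
import torch.nn as nn
from models.utils_blocks.SAM import SAM

model = nn.Linear(10, 1)
criterion = nn.MSELoss()
# Assumed constructor (upstream-style): wraps a base optimizer class and forwards its kwargs.
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=1e-3, momentum=0.9)

x, y = torch.randn(32, 10), torch.randn(32, 1)

def closure():
    # step() re-runs this full forward/backward pass at the perturbed point "w + e(w)".
    loss = criterion(model(x), y)
    loss.backward()
    return loss

# First backward pass at "w"; step() then perturbs the weights, calls the closure
# to get gradients at "w + e(w)", and applies the base optimizer update.
loss = criterion(model(x), y)
loss.backward()
optimizer.step(closure)
optimizer.zero_grad()
```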
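
A minimal sketch of calling `MaskedTransformer` from `models/utils_blocks/attention.py`. The batch size, token counts, and dimensions below are illustrative assumptions, not values used by the repository. The mask convention (1 = may attend, 0 = attention logit pushed to -1e4 before the softmax) follows the `masked_fill` call in `MaskedAttention.forward`, and `to_len` must match the number of query tokens because `ln_1`/`ln_2` are `LayerNorm((to_len, to_dim))`.

```python
# Hypothetical usage sketch, not part of the repository; shapes are illustrative.
import torch
from models.utils_blocks.attention import MaskedTransformer

batch_size, num_to_tokens, num_from_tokens = 2, 64, 16
to_dim, from_dim, latent_dim, num_heads = 128, 256, 128, 4

block = MaskedTransformer(
    to_dim=to_dim,
    to_len=num_to_tokens,  # LayerNorm((to_len, to_dim)) requires exactly this many query tokens
    from_dim=from_dim,
    latent_dim=latent_dim,
    num_heads=num_heads,
)

X_to = torch.randn(batch_size, num_to_tokens, to_dim)        # queries (e.g. image tokens)
X_from = torch.randn(batch_size, num_from_tokens, from_dim)  # keys/values (e.g. style tokens)

# Binary mask: 1 where a "to" token may attend to a "from" token, 0 where the
# attention logit is masked out before the softmax.
mask = torch.ones(batch_size, num_to_tokens, num_from_tokens)

out, attn = block(X_to, X_from, mask=mask, return_attention=True)
print(out.shape)   # torch.Size([2, 64, 128])
print(attn.shape)  # torch.Size([2, 4, 64, 16]) -> (batch, heads, to_tokens, from_tokens)
```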