├── genforce ├── __init__.py ├── metrics │ ├── __init__.py │ ├── README.md │ └── fid.py ├── utils │ ├── __init__.py │ ├── logger_test.py │ └── misc.py ├── runners │ ├── __init__.py │ ├── losses │ │ ├── __init__.py │ │ └── logistic_gan_loss.py │ ├── controllers │ │ ├── __init__.py │ │ ├── cache_cleaner.py │ │ ├── snapshoter.py │ │ ├── checkpointer.py │ │ ├── timer.py │ │ ├── fid_evaluator.py │ │ ├── running_logger.py │ │ └── progress_scheduler.py │ ├── misc.py │ ├── stylegan_runner.py │ └── running_stats.py ├── datasets │ ├── __init__.py │ ├── README.md │ ├── dataloaders.py │ ├── distributed_sampler.py │ └── transforms.py ├── .gitignore ├── scripts │ ├── dist_train.sh │ ├── dist_test.sh │ ├── stylegan_training_demo.sh │ ├── slurm_train.sh │ └── slurm_test.sh ├── models │ ├── sync_op.py │ └── __init__.py ├── configs │ ├── stylegan_ffhq256_val.py │ ├── stylegan_ffhq1024_val.py │ ├── stylegan_demo.py │ ├── stylegan_ffhq256.py │ └── stylegan_ffhq1024.py ├── LICENSE ├── train.py ├── test.py ├── README.md ├── synthesize.py └── my_get_GD.py ├── .gitattributes ├── run_SMILE.sh ├── utils_RLBMI.py ├── my_merge_all_tensors.py ├── run_baselines.sh ├── LICENSE ├── test.sh ├── train_classification_models ├── README.md ├── self-train_VGGFace2 │ ├── kernel.py │ ├── train_efficientnet_b0.py │ ├── train_mobilenet_v2.py │ ├── train_vision_transformer.py │ └── train_inception_v3.py └── self-train_CASIA │ └── train_efficientnet_b0_casia.py ├── my_generate_blackbox_attack_dataset.py ├── README.md ├── my_sample_z_w_space.py ├── vgg_m_face_bn_dag.py ├── environment.yml ├── swintransformer_4finetune.py ├── vgg_face_dag.py ├── vitb16_4finetune.py ├── inceptionv3_4finetune.py └── mobilenetv2_4finetune.py /genforce/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /genforce/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /genforce/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /genforce/runners/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all runners.""" 3 | 4 | from .stylegan_runner import StyleGANRunner 5 | 6 | __all__ = ['StyleGANRunner'] 7 | -------------------------------------------------------------------------------- /genforce/runners/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all loss functions.""" 3 | 4 | from .logistic_gan_loss import LogisticGANLoss 5 | 6 | __all__ = ['LogisticGANLoss'] 7 | -------------------------------------------------------------------------------- /genforce/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects datasets and data loaders.""" 3 | 4 | from .datasets import BaseDataset 5 | from .dataloaders import IterDataLoader 6 | 7 | __all__ = ['BaseDataset', 'IterDataLoader'] 8 | 
-------------------------------------------------------------------------------- /genforce/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | 4 | /.vscode/ 5 | /.idea/ 6 | *.sw[pon] 7 | 8 | /data/ 9 | /work_dirs/ 10 | *.jpg 11 | *.png 12 | *.jpeg 13 | *.gif 14 | *.avi 15 | *.mp4 16 | 17 | *.npy 18 | *.txt 19 | *.json 20 | *.log 21 | *.html 22 | *.tar 23 | *.zip 24 | events.* 25 | 26 | *.pth 27 | *.pkl 28 | *.h5 29 | *.dat 30 | -------------------------------------------------------------------------------- /genforce/scripts/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPUS=$1 4 | CONFIG=$2 5 | WORK_DIR=$3 6 | PORT=${PORT:-29500} 7 | 8 | python3 -m torch.distributed.launch \ 9 | --nproc_per_node=${GPUS} \ 10 | --master_port=${PORT} \ 11 | ./train.py ${CONFIG} \ 12 | --work_dir ${WORK_DIR} \ 13 | --launcher="pytorch" \ 14 | ${@:4} 15 | -------------------------------------------------------------------------------- /genforce/scripts/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPUS=$1 4 | CONFIG=$2 5 | WORK_DIR=$3 6 | CHECKPOINT=$4 7 | PORT=${PORT:-29500} 8 | 9 | python -m torch.distributed.launch \ 10 | --nproc_per_node=${GPUS} \ 11 | --master_port=${PORT} \ 12 | ./test.py ${CONFIG} \ 13 | --work_dir ${WORK_DIR} \ 14 | --checkpoint ${CHECKPOINT} \ 15 | --launcher="pytorch" \ 16 | ${@:5} 17 | -------------------------------------------------------------------------------- /genforce/models/sync_op.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the synchronizing operator.""" 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | __all__ = ['all_gather'] 8 | 9 | 10 | def all_gather(tensor): 11 | """Gathers tensor from all devices and does averaging.""" 12 | if not dist.is_initialized(): 13 | return tensor 14 | 15 | world_size = dist.get_world_size() 16 | tensor_list = [torch.ones_like(tensor) for _ in range(world_size)] 17 | dist.all_gather(tensor_list, tensor, async_op=False) 18 | return torch.mean(torch.stack(tensor_list, dim=0), dim=0) 19 | -------------------------------------------------------------------------------- /genforce/runners/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all controllers.""" 3 | 4 | from .cache_cleaner import CacheCleaner 5 | from .checkpointer import Checkpointer 6 | from .fid_evaluator import FIDEvaluator 7 | from .lr_scheduler import LRScheduler 8 | from .progress_scheduler import ProgressScheduler 9 | from .running_logger import RunningLogger 10 | from .snapshoter import Snapshoter 11 | from .timer import Timer 12 | 13 | __all__ = [ 14 | 'CacheCleaner', 'Checkpointer', 'FIDEvaluator', 'LRScheduler', 15 | 'ProgressScheduler', 'RunningLogger', 'Snapshoter', 'Timer' 16 | ] 17 | -------------------------------------------------------------------------------- /genforce/scripts/stylegan_training_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "==================================================" 4 | echo "Please ensure you have installed the requirements!" 5 | 6 | # Download data. 7 | echo "Downloading data ..." 
8 | mkdir -p data/ 9 | wget -nv https://www.dropbox.com/s/vvtcqcujdjeq3zs/mini_animeface.zip?dl=1 \ 10 | -O data/demo.zip --quiet 11 | 12 | # Launch training. 13 | echo "Launch training job with 1 GPU." 14 | echo "==================================================" 15 | PORT=6666 ./scripts/dist_train.sh 2 \ 16 | configs/stylegan_demo.py \ 17 | work_dirs/stylegan_demo 18 | -------------------------------------------------------------------------------- /run_SMILE.sh: -------------------------------------------------------------------------------- 1 | for target in {1..50} 2 | 3 | do 4 | python my_whitebox_attacks.py --attack_mode ours-w --target_dataset vggface2 --dataset celeba_partial256 --arch_name_target inception_resnetv1_vggface2 --target $target --epochs 200 --arch_name_finetune inception_resnetv1_casia --finetune_mode 'vggface2->CASIA' --num_experts 3 --EorOG SMILE --population_size 2500 5 | python my_blackbox_attacks.py --attack_mode ours-surrogate_model --target_dataset vggface2 --dataset celeba_partial256 --arch_name_target inception_resnetv1_vggface2 --target $target --budget 1000 --population_size 2500 --epochs 200 --finetune_mode 'vggface2->CASIA' --arch_name_finetune inception_resnetv1_casia --EorOG SMILE --lr 0.2 --x 1.7 6 | done 7 | -------------------------------------------------------------------------------- /genforce/scripts/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | WORK_DIR=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_NODE=${CPUS_PER_NODE:-8} 12 | 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | PY_ARGS=${@:5} 15 | 16 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 17 | srun -p ${PARTITION} \ 18 | --job-name=${JOB_NAME} \ 19 | --gres=gpu:${GPUS_PER_NODE} \ 20 | --ntasks=${GPUS} \ 21 | --ntasks-per-node=${GPUS_PER_NODE} \ 22 | --cpus-per-task=${CPUS_PER_NODE} \ 23 | --kill-on-bad-exit=1 \ 24 | ${SRUN_ARGS} \ 25 | python -u ./train.py ${CONFIG} \ 26 | --work_dir=${WORK_DIR} \ 27 | --launcher="slurm" \ 28 | ${PY_ARGS} 29 | -------------------------------------------------------------------------------- /genforce/scripts/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | WORK_DIR=$4 9 | CHECKPOINT=$5 10 | GPUS=${GPUS:-8} 11 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 12 | CPUS_PER_NODE=${CPUS_PER_NODE:-8} 13 | 14 | SRUN_ARGS=${SRUN_ARGS:-""} 15 | PY_ARGS=${@:6} 16 | 17 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 18 | srun -p ${PARTITION} \ 19 | --job-name=${JOB_NAME} \ 20 | --gres=gpu:${GPUS_PER_NODE} \ 21 | --ntasks=${GPUS} \ 22 | --ntasks-per-node=${GPUS_PER_NODE} \ 23 | --cpus-per-task=${CPUS_PER_TASK} \ 24 | --kill-on-bad-exit=1 \ 25 | ${SRUN_ARGS} \ 26 | python -u ./test.py ${CONFIG} \ 27 | --work_dir=${WORK_DIR} \ 28 | --checkpoint ${CHECKPOINT} \ 29 | --launcher="slurm" \ 30 | ${PY_ARGS} 31 | -------------------------------------------------------------------------------- /genforce/configs/stylegan_ffhq256_val.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Configuration for testing StyleGAN on FF-HQ (256) dataset. 3 | 4 | All settings are particularly used for one replica (GPU), such as `batch_size` 5 | and `num_workers`. 
6 | """ 7 | 8 | runner_type = 'StyleGANRunner' 9 | gan_type = 'stylegan' 10 | resolution = 256 11 | batch_size = 64 12 | 13 | data = dict( 14 | num_workers=4, 15 | # val=dict(root_dir='data/ffhq', resolution=resolution), 16 | val=dict(root_dir='data/ffhq.zip', data_format='zip', 17 | resolution=resolution), 18 | ) 19 | 20 | modules = dict( 21 | discriminator=dict( 22 | model=dict(gan_type=gan_type, resolution=resolution), 23 | kwargs_val=dict(), 24 | ), 25 | generator=dict( 26 | model=dict(gan_type=gan_type, resolution=resolution), 27 | kwargs_val=dict(trunc_psi=0.7, trunc_layers=8, randomize_noise=False), 28 | ) 29 | ) 30 | -------------------------------------------------------------------------------- /genforce/configs/stylegan_ffhq1024_val.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Configuration for testing StyleGAN on FF-HQ (1024) dataset. 3 | 4 | All settings are particularly used for one replica (GPU), such as `batch_size` 5 | and `num_workers`. 6 | """ 7 | 8 | runner_type = 'StyleGANRunner' 9 | gan_type = 'stylegan' 10 | resolution = 1024 11 | batch_size = 16 12 | 13 | data = dict( 14 | num_workers=4, 15 | # val=dict(root_dir='data/ffhq', resolution=resolution), 16 | val=dict(root_dir='data/ffhq.zip', data_format='zip', 17 | resolution=resolution), 18 | ) 19 | 20 | modules = dict( 21 | discriminator=dict( 22 | model=dict(gan_type=gan_type, resolution=resolution), 23 | kwargs_val=dict(), 24 | ), 25 | generator=dict( 26 | model=dict(gan_type=gan_type, resolution=resolution), 27 | kwargs_val=dict(trunc_psi=0.7, trunc_layers=8, randomize_noise=False), 28 | ) 29 | ) 30 | -------------------------------------------------------------------------------- /utils_RLBMI.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | 4 | def load_my_state_dict(self, state_dict): 5 | own_state = self.state_dict() 6 | for name, param in state_dict.items(): 7 | if name not in own_state: 8 | print(name) 9 | continue 10 | own_state[name].copy_(param.data) 11 | 12 | def get_deprocessor(): 13 | proc = [] 14 | proc.append(transforms.Resize((112, 112))) 15 | proc.append(transforms.ToTensor()) 16 | return transforms.Compose(proc) 17 | 18 | def low2high(img): 19 | bs = img.size(0) 20 | proc = get_deprocessor() 21 | img_tensor = img.detach().cpu().float() 22 | img = torch.zeros(bs, 3, 112, 112) 23 | for i in range(bs): 24 | img_i = transforms.ToPILImage()(img_tensor[i, :, :, :]).convert('RGB') 25 | img_i = proc(img_i) 26 | img[i, :, :, :] = img_i[:, :, :] 27 | 28 | img = img.cuda() 29 | return img 30 | -------------------------------------------------------------------------------- /genforce/runners/controllers/cache_cleaner.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the running controller to clean cache.""" 3 | 4 | import torch 5 | 6 | from .base_controller import BaseController 7 | 8 | __all__ = ['CacheCleaner'] 9 | 10 | 11 | class CacheCleaner(BaseController): 12 | """Defines the running controller to clean cache. 13 | 14 | This controller is used to empty the GPU cache after each iteration. 15 | 16 | NOTE: The controller is set to `LAST` priority by default. 
17 | """ 18 | 19 | def __init__(self, config=None): 20 | config = config or dict() 21 | config.setdefault('priority', 'LAST') 22 | config.setdefault('every_n_iters', 1) 23 | super().__init__(config) 24 | 25 | def setup(self, runner): 26 | torch.cuda.empty_cache() 27 | 28 | def close(self, runner): 29 | torch.cuda.empty_cache() 30 | 31 | def execute_after_iteration(self, runner): 32 | torch.cuda.empty_cache() 33 | -------------------------------------------------------------------------------- /my_merge_all_tensors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | import argparse 4 | import os 5 | import glob 6 | 7 | import torch 8 | 9 | 10 | def merge(root_dir, remove=False): 11 | files = sorted(glob.glob(os.path.join(root_dir, 'sample_*_img_logits.pt'))) 12 | print('#files,', len(files)) 13 | all_res = [] 14 | for f in files: 15 | all_res.append(torch.load(f, map_location=torch.device('cpu'))) 16 | all_res = torch.cat(all_res, dim=0) 17 | print('all_res.shape', all_res.shape) 18 | torch.save(all_res, os.path.join(root_dir, 'all_logits.pt')) 19 | 20 | if remove: 21 | for f in files: 22 | os.remove(f) 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--remove', action='store_true', help='remove files used to merge') 28 | parser.add_argument('root_dir') 29 | args = parser.parse_args() 30 | merge(args.root_dir, remove=args.remove) 31 | -------------------------------------------------------------------------------- /run_baselines.sh: -------------------------------------------------------------------------------- 1 | for target in {1..50} 2 | 3 | do 4 | # Mirror-w 5 | python my_whitebox_attacks.py --attack_mode Mirror-w --target_dataset vggface2 --dataset celeba_partial256 --arch_name_target inception_resnetv1_vggface2 --target $target --epochs 20000 --population_size 100000 6 | 7 | # PPA 8 | python my_whitebox_attacks.py --attack_mode PPA --target_dataset vggface2 --dataset celeba_partial256 --arch_name_target inception_resnetv1_vggface2 --target $target --population_size 5000 --epochs 70 --candidate 200 --final_selection 50 --iterations 100 --lr 0.005 9 | 10 | # Mirror-b 11 | python my_blackbox_attacks.py --attack_mode Mirror-b --target_dataset vggface2 --dataset celeba_partial256 --arch_name_target inception_resnetv1_vggface2 --target $target --population_size 1000 --generations 20 12 | 13 | # RLBMI 14 | python my_blackbox_attacks.py --attack_mode RLB-MI --target_dataset vggface2 --dataset celeba_RLBMI --arch_name_target inception_resnetv1_vggface2 --target $target --max_episodes 40000 15 | 16 | done 17 | -------------------------------------------------------------------------------- /genforce/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 GenForce 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 L1ziang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /genforce/runners/controllers/snapshoter.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the running controller for saving snapshot.""" 3 | 4 | from .base_controller import BaseController 5 | 6 | __all__ = ['Snapshoter'] 7 | 8 | 9 | class Snapshoter(BaseController): 10 | """Defines the running controller for evaluation. 11 | 12 | NOTE: The controller is set to `LAST` priority by default. 13 | """ 14 | 15 | def __init__(self, config): 16 | config.setdefault('priority', 'LAST') 17 | super().__init__(config) 18 | 19 | self.num = config.get('num', 100) 20 | 21 | def setup(self, runner): 22 | assert hasattr(runner, 'synthesize') 23 | 24 | def execute_after_iteration(self, runner): 25 | mode = runner.mode # save runner mode. 26 | runner.synthesize(self.num, 27 | html_name=f'snapshot_{runner.iter:06d}.html', 28 | save_raw_synthesis=False) 29 | runner.logger.info(f'Saving snapshot at iter {runner.iter:06d} ' 30 | f'({runner.seen_img / 1000:.1f} kimg).') 31 | runner.set_mode(mode) # restore runner mode. 32 | -------------------------------------------------------------------------------- /genforce/runners/misc.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Misc utility functions used for model running.""" 3 | 4 | __all__ = ['format_time'] 5 | 6 | 7 | def format_time(seconds): 8 | """Formats seconds to readable time string. 9 | 10 | Args: 11 | seconds: Number of seconds to format. 12 | 13 | Returns: 14 | The formatted time string. 
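        (For example, following the branches below: `format_time(45.6)`
        gives '45.60s', `format_time(3661)` gives ' 1h01m', and
        `format_time(90000)` gives ' 1d01h'.)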
15 | 16 | Raises: 17 | ValueError: If the input `seconds` is less than 0. 18 | """ 19 | if seconds < 0: 20 | raise ValueError(f'Input `seconds` should be greater than or equal to ' 21 | f'0, but `{seconds}` is received!') 22 | 23 | # Returns seconds as float if less than 1 minute. 24 | if seconds < 10: 25 | return f'{seconds:5.3f}s' 26 | if seconds < 60: 27 | return f'{seconds:5.2f}s' 28 | 29 | seconds = int(seconds + 0.5) 30 | days, seconds = divmod(seconds, 86400) 31 | hours, seconds = divmod(seconds, 3600) 32 | minutes, seconds = divmod(seconds, 60) 33 | if days: 34 | return f'{days:2d}d{hours:02d}h' 35 | if hours: 36 | return f'{hours:2d}h{minutes:02d}m' 37 | return f'{minutes:2d}m{seconds:02d}s' 38 | -------------------------------------------------------------------------------- /genforce/datasets/README.md: -------------------------------------------------------------------------------- 1 | # Data Preparation 2 | 3 | ## Data Format 4 | 5 | Currently, our dataloader is able to load data from 6 | 7 | - a directory that is full of images 8 | - a `lmdb` file 9 | - an image list 10 | - a compressed file (i.e., `zip` package) 11 | 12 | by modifying `data_format` in the configuration. 13 | 14 | **NOTE:** For some computing clusters whose I/O speed may be slow, we recommend the `zip` format for two reasons. First, `zip` file is easy to create. Second, this can load a large file at one time instead of loading small files repeatedly. 15 | 16 | ## Data Sampling 17 | 18 | Considering that most generative models are trained in the unit of iterations instead of epochs, we change the default data loader to an *iter-based* one. Besides, the original distributed data sampler is also modified to make the shuffling correspond to iteration instead of epoch. 19 | 20 | **NOTE:** In order to reduce the data re-loading cost between epochs, we manually extend the length of sampled indices to make it much more efficient. 21 | 22 | ## Data Augmentation 23 | 24 | To better align with the original implementation of PGGAN and StyleGAN (i.e., models that require progressive training), we support progressive resize in `transforms.py`, which downsamples images with the maximum resize factor of 2 at each time. 25 | -------------------------------------------------------------------------------- /genforce/metrics/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation Metrics 2 | 3 | Frechet Inception Distance (FID) is commonly used to evaluate generative model. It employs an [Inception Model](https://arxiv.org/abs/1512.00567) (pretrained on ImageNet) to extract features from both real and synthesized images. 4 | 5 | ## Inception Model 6 | 7 | For [PGGAN](https://github.com/tkarras/progressive_growing_of_gans), [StyleGAN](https://github.com/NVlabs/stylegan), etc, they use inception model from the [TensorFlow Models](https://github.com/tensorflow/models) repository, whose implementation is slightly different from that of `torchvision`. Hence, to make the evaluation metric comparable between different training frameworks (i.e., PyTorch and TensorFlow), we modify `torchvision/models/inception.py` as `inception.py`. The ported pre-trained weight is borrowed from [this repo](https://github.com/mseitzer/pytorch-fid). 8 | 9 | **NOTE:** We also support using the model from `torchvision` to compute the FID. However, please be aware that the FID value from `torchvision` is usually ~1.5 smaller than that from the TensorFlow model. 
10 | 11 | Please use the following code to choose which model to use. 12 | 13 | ```python 14 | from metrics.inception import build_inception_model 15 | 16 | inception_model_tf = build_inception_model(align_tf=True) 17 | inception_model_pth = build_inception_model(align_tf=False) 18 | ``` 19 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | targets=$(echo {1..50} | tr ' ' ',') 2 | 3 | # Mirror-w 4 | python my_whitebox_attacks.py --attack_mode Mirror-w --target_dataset vggface2 --dataset celeba_partial256 --arch_name_target inception_resnetv1_vggface2 --test_target=$targets --epochs 20000 --population_size 100000 --test_only 5 | 6 | # PPA 7 | python my_whitebox_attacks.py --attack_mode PPA --target_dataset vggface2 --dataset celeba_partial256 --arch_name_target inception_resnetv1_vggface2 --test_target=$targets --population_size 5000 --epochs 70 --candidate 200 --final_selection 50 --iterations 100 --lr 0.005 --test_only 8 | 9 | # Mirror-b 10 | python my_blackbox_attacks.py --attack_mode Mirror-b --target_dataset vggface2 --dataset celeba_partial256 --arch_name_target inception_resnetv1_vggface2 --test_target=$targets --population_size 1000 --generations 20 --test_only 11 | 12 | # RLBMI 13 | python my_blackbox_attacks.py --attack_mode RLB-MI --target_dataset vggface2 --dataset celeba_RLBMI --arch_name_target inception_resnetv1_vggface2 --test_target=$targets --max_episodes 40000 --test_only 14 | 15 | # SMILE 16 | python my_blackbox_attacks.py --attack_mode ours-surrogate_model --target_dataset vggface2 --dataset celeba_partial256 --arch_name_target inception_resnetv1_vggface2 --test_target=$targets --budget 1000 --population_size 2500 --finetune_mode 'vggface2->CASIA' --arch_name_finetune inception_resnetv1_casia --EorOG SMILE --epochs 200 --lr 0.2 --test_only 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /genforce/runners/controllers/checkpointer.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the running controller to handle checkpoints.""" 3 | 4 | import os.path 5 | 6 | from .base_controller import BaseController 7 | 8 | __all__ = ['Checkpointer'] 9 | 10 | class Checkpointer(BaseController): 11 | """Defines the running controller to handle checkpoints. 12 | 13 | This controller is used to save and load checkpoints. 14 | 15 | NOTE: This controller is set to `LAST` priority by default and will only be 16 | executed on the master worker. 
17 | """ 18 | 19 | def __init__(self, config): 20 | assert isinstance(config, dict) 21 | config.setdefault('priority', 'LAST') 22 | config.setdefault('master_only', True) 23 | super().__init__(config) 24 | 25 | self._save_dir = config.get('checkpoint_dir', None) 26 | self._save_running_metadata = config.get('save_running_metadata', True) 27 | self._save_learning_rate = config.get('save_learning_rate', True) 28 | self._save_optimizer = config.get('save_optimizer', True) 29 | self._save_running_stats = config.get('save_running_stats', False) 30 | 31 | def execute_after_iteration(self, runner): 32 | save_dir = self._save_dir or runner.work_dir 33 | save_filename = f'checkpoint_iter{runner.iter:06d}.pth' 34 | runner.save(filepath=os.path.join(save_dir, save_filename), 35 | running_metadata=self._save_running_metadata, 36 | learning_rate=self._save_learning_rate, 37 | optimizer=self._save_optimizer, 38 | running_stats=self._save_running_stats) 39 | -------------------------------------------------------------------------------- /train_classification_models/README.md: -------------------------------------------------------------------------------- 1 | # Datasets and Model Training 2 | 3 | ## Datasets Download 4 | 5 | - **VGGFace2**: Download from [Kaggle](https://www.kaggle.com/datasets/dimarodionov/vggface2) 6 | 7 | - **CASIA**: Download from [Drive](https://drive.google.com/file/d/1A9tijVZYYt5bbIXwXK7Ud-LOnGfvXF50/view?usp=sharing) 8 | 9 | ## Models Download 10 | - Download from [Drive](https://drive.google.com/drive/folders/1xtYJXiWTcX6cpZiRU8wTOubYWae-iJeu?usp=drive_link) 11 | 12 | --- 13 | 14 | ## Model Training 15 | 16 | ### Self-Trained Classification Models for VGGFace2 17 | 18 | **Training Command**: 19 | ```bash 20 | python ./self-train_VGGFace2/train_efficientnet_b0.py 21 | ``` 22 | 23 | **Defense Training Command**: 24 | ```bash 25 | python ./self-train_VGGFace2/train_model_defense.py 26 | ``` 27 | 28 | ### Parameter Description 29 | | Parameter | Type | Description | 30 | |------|------|-------| 31 | | `--defense_method` | str | Optional:`BiDO`/`MID`/`LS`/`TL` | 32 | 33 | #### Hyperparameters of the defense 34 | | Defenses | Parameter | type | 35 | |---------|------|------| 36 | | **BiDO** | `--coef_hidden_input` | float | 37 | | | `--coef_hidden_output` | float | 38 | | **MID** | `--beta` | float | 39 | | **LS** | `--coef_label_smoothing` | float | 40 | | **TL** | `--layer_name` | str | 41 | --- 42 | 43 | ## Acknowledge 44 | The defenses are implemented based on the following repositories. We extend our gratitude to the authors for open-sourcing their code. 45 | 46 | [BiDO](https://github.com/AlanPeng0897/Defend_MI), [MID](https://github.com/Jiachen-T-Wang/mi-defense), [LS](https://github.com/LukasStruppek/Plug-and-Play-Attacks), [TL](https://github.com/hosytuyen/TL-DMI), [MIA-ToolBox](https://github.com/ffhibnese/Model-Inversion-Attack-ToolBox) 47 | 48 | -------------------------------------------------------------------------------- /genforce/metrics/fid.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the functions to compute Frechet Inception Distance (FID). 3 | 4 | FID metric is introduced in paper 5 | 6 | GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash 7 | Equilibrium. Heusel et al. NeurIPS 2017. 
8 | 9 | See details at https://arxiv.org/pdf/1706.08500.pdf 10 | """ 11 | 12 | import numpy as np 13 | import scipy.linalg 14 | 15 | __all__ = ['extract_feature', 'compute_fid'] 16 | 17 | 18 | def extract_feature(inception_model, images): 19 | """Extracts feature from input images with given model. 20 | 21 | NOTE: The input images are assumed to be with pixel range [-1, 1]. 22 | 23 | Args: 24 | inception_model: The model used to extract features. 25 | images: The input image tensor to extract features from. 26 | 27 | Returns: 28 | A `numpy.ndarray`, containing the extracted features. 29 | """ 30 | features = inception_model(images, output_logits=False) 31 | features = features.detach().cpu().numpy() 32 | assert features.ndim == 2 and features.shape[1] == 2048 33 | return features 34 | 35 | 36 | def compute_fid(fake_features, real_features): 37 | """Computes FID based on the features extracted from fake and real data. 38 | 39 | Given the mean and covariance (m_f, C_f) of fake data and (m_r, C_r) of real 40 | data, the FID metric can be computed by 41 | 42 | d^2 = ||m_f - m_r||_2^2 + Tr(C_f + C_r - 2(C_f C_r)^0.5) 43 | 44 | Args: 45 | fake_features: The features extracted from fake data. 46 | real_features: The features extracted from real data. 47 | 48 | Returns: 49 | A real number, suggesting the FID value. 50 | """ 51 | 52 | m_f = np.mean(fake_features, axis=0) 53 | C_f = np.cov(fake_features, rowvar=False) 54 | m_r = np.mean(real_features, axis=0) 55 | C_r = np.cov(real_features, rowvar=False) 56 | 57 | fid = np.sum((m_f - m_r) ** 2) + np.trace( 58 | C_f + C_r - 2 * scipy.linalg.sqrtm(np.dot(C_f, C_r))) 59 | return np.real(fid) 60 | -------------------------------------------------------------------------------- /genforce/runners/controllers/timer.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the running controller to record time.""" 3 | 4 | import time 5 | 6 | from .base_controller import BaseController 7 | 8 | __all__ = ['Timer'] 9 | 10 | 11 | class Timer(BaseController): 12 | """Defines the running controller to record running time. 13 | 14 | This controller will be executed every iteration (both before and after) to 15 | summarize the data preparation time as well as the model running time. 16 | Besides, this controller will also mark the start and end time of the 17 | running process. 18 | 19 | NOTE: This controller is set to `LOW` priority by default and will only be 20 | executed on the master worker. 
21 | """ 22 | 23 | def __init__(self, config=None): 24 | config = config or dict() 25 | config.setdefault('priority', 'LOW') 26 | config.setdefault('every_n_iters', 1) 27 | config.setdefault('master_only', True) 28 | super().__init__(config) 29 | 30 | self.time = time.time() 31 | 32 | def setup(self, runner): 33 | runner.running_stats.add( 34 | 'data_time', log_format='time', log_name='data time') 35 | runner.running_stats.add( 36 | 'iter_time', log_format='time', log_name='iter time') 37 | runner.running_stats.add( 38 | 'run_time', log_format='time', log_name='run time', 39 | log_strategy='CURRENT') 40 | self.time = time.time() 41 | runner.start_time = self.time 42 | 43 | def close(self, runner): 44 | runner.end_time = time.time() 45 | 46 | def execute_before_iteration(self, runner): 47 | start_time = time.time() 48 | runner.running_stats.update({'data_time': start_time - self.time}) 49 | 50 | def execute_after_iteration(self, runner): 51 | end_time = time.time() 52 | runner.running_stats.update({'iter_time': end_time - self.time}) 53 | runner.running_stats.update({'run_time': end_time - runner.start_time}) 54 | self.time = end_time 55 | -------------------------------------------------------------------------------- /genforce/runners/controllers/fid_evaluator.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the running controller for evaluation.""" 3 | 4 | import os.path 5 | import time 6 | 7 | from .base_controller import BaseController 8 | from ..misc import format_time 9 | 10 | __all__ = ['FIDEvaluator'] 11 | 12 | 13 | class FIDEvaluator(BaseController): 14 | """Defines the running controller for evaluation. 15 | 16 | This controller is used to evalute the GAN model using FID metric. 17 | 18 | NOTE: The controller is set to `LAST` priority by default. 19 | """ 20 | 21 | def __init__(self, config): 22 | assert isinstance(config, dict) 23 | config.setdefault('priority', 'LAST') 24 | super().__init__(config) 25 | 26 | self.num = config.get('num', 50000) 27 | self.ignore_cache = config.get('ignore_cache', False) 28 | self.align_tf = config.get('align_tf', True) 29 | self.file = None 30 | 31 | def setup(self, runner): 32 | assert hasattr(runner, 'fid') 33 | file_path = os.path.join(runner.work_dir, f'metric_fid{self.num}.txt') 34 | if runner.rank == 0: 35 | self.file = open(file_path, 'w') 36 | 37 | def close(self, runner): 38 | if runner.rank == 0: 39 | self.file.close() 40 | 41 | def execute_after_iteration(self, runner): 42 | mode = runner.mode # save runner mode. 43 | start_time = time.time() 44 | fid_value = runner.fid(self.num, 45 | ignore_cache=self.ignore_cache, 46 | align_tf=self.align_tf) 47 | duration_str = format_time(time.time() - start_time) 48 | log_str = (f'FID: {fid_value:.5f} at iter {runner.iter:06d} ' 49 | f'({runner.seen_img / 1000:.1f} kimg). ({duration_str})') 50 | runner.logger.info(log_str) 51 | if runner.rank == 0: 52 | date = time.strftime("%Y-%m-%d %H:%M:%S") 53 | self.file.write(f'[{date}] {log_str}\n') 54 | self.file.flush() 55 | runner.set_mode(mode) # restore runner mode. 56 | -------------------------------------------------------------------------------- /genforce/configs/stylegan_demo.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Configuration for StyleGAN training demo. 3 | 4 | All settings are particularly used for one replica (GPU), such as `batch_size` 5 | and `num_workers`. 
6 | """ 7 | 8 | runner_type = 'StyleGANRunner' 9 | gan_type = 'stylegan' 10 | resolution = 64 11 | batch_size = 4 12 | val_batch_size = 32 13 | total_img = 100_000 14 | 15 | # Training dataset is repeated at the beginning to avoid loading dataset 16 | # repeatedly at the end of each epoch. This can save some I/O time. 17 | data = dict( 18 | num_workers=4, 19 | repeat=500, 20 | train=dict(root_dir='data/demo.zip', data_format='zip', 21 | resolution=resolution, mirror=0.5), 22 | val=dict(root_dir='data/demo.zip', data_format='zip', 23 | resolution=resolution), 24 | ) 25 | 26 | controllers = dict( 27 | RunningLogger=dict(every_n_iters=10), 28 | ProgressScheduler=dict( 29 | every_n_iters=1, init_res=8, minibatch_repeats=4, 30 | lod_training_img=5_000, lod_transition_img=5_000, 31 | batch_size_schedule=dict(res4=64, res8=32, res16=16, res32=8), 32 | ), 33 | Snapshoter=dict(every_n_iters=500, first_iter=True, num=200), 34 | FIDEvaluator=dict(every_n_iters=5000, first_iter=True, num=50000), 35 | Checkpointer=dict(every_n_iters=5000, first_iter=True), 36 | ) 37 | 38 | modules = dict( 39 | discriminator=dict( 40 | model=dict(gan_type=gan_type, resolution=resolution), 41 | lr=dict(lr_type='FIXED'), 42 | opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)), 43 | kwargs_train=dict(), 44 | kwargs_val=dict(), 45 | ), 46 | generator=dict( 47 | model=dict(gan_type=gan_type, resolution=resolution), 48 | lr=dict(lr_type='FIXED'), 49 | opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)), 50 | kwargs_train=dict(w_moving_decay=0.995, style_mixing_prob=0.9, 51 | trunc_psi=1.0, trunc_layers=0, randomize_noise=True), 52 | kwargs_val=dict(trunc_psi=1.0, trunc_layers=0, randomize_noise=False), 53 | g_smooth_img=10000, 54 | ) 55 | ) 56 | 57 | loss = dict( 58 | type='LogisticGANLoss', 59 | d_loss_kwargs=dict(r1_gamma=10.0), 60 | g_loss_kwargs=dict(), 61 | ) 62 | -------------------------------------------------------------------------------- /genforce/utils/logger_test.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Unit test for logger.""" 3 | 4 | import time 5 | 6 | from .logger import build_logger 7 | 8 | 9 | def test_logger(): 10 | """Test function.""" 11 | logger = build_logger('normal', logger_name='normal', logfile_name='') 12 | rich_logger = build_logger('rich', logger_name='rich', logfile_name='') 13 | dumb_logger = build_logger('dumb', logger_name='dumb', logfile_name='') 14 | 15 | print('-------------------------------') 16 | print('| Test `utils.logger.Logger`. |') 17 | print('-------------------------------') 18 | logger.print('log') 19 | logger.debug('log') 20 | logger.info('log') 21 | logger.warning('log') 22 | logger.init_pbar() 23 | task1 = logger.add_pbar_task('Task 1', 500) 24 | task2 = logger.add_pbar_task('Task 2', 1000) 25 | for _ in range(1000): 26 | logger.update_pbar(task1, 1) 27 | logger.update_pbar(task2, 1) 28 | time.sleep(0.005) 29 | logger.close_pbar() 30 | print('Success!') 31 | 32 | print('-----------------------------------') 33 | print('| Test `utils.logger.RichLogger`. 
|') 34 | print('-----------------------------------') 35 | rich_logger.print('rich_log') 36 | rich_logger.debug('rich_log') 37 | rich_logger.info('rich_log') 38 | rich_logger.warning('rich_log') 39 | rich_logger.init_pbar() 40 | task1 = rich_logger.add_pbar_task('Rich Task 1', 500) 41 | task2 = rich_logger.add_pbar_task('Rich Task 2', 1000) 42 | for _ in range(1000): 43 | rich_logger.update_pbar(task1, 1) 44 | rich_logger.update_pbar(task2, 1) 45 | time.sleep(0.005) 46 | rich_logger.close_pbar() 47 | print('Success!') 48 | 49 | print('-----------------------------------') 50 | print('| Test `utils.logger.DumbLogger`. |') 51 | print('-----------------------------------') 52 | dumb_logger.print('dumb_log') 53 | dumb_logger.debug('dumb_log') 54 | dumb_logger.info('dumb_log') 55 | dumb_logger.warning('dumb_log') 56 | dumb_logger.init_pbar() 57 | task1 = dumb_logger.add_pbar_task('Dumb Task 1', 500) 58 | task2 = dumb_logger.add_pbar_task('Dumb Task 2', 1000) 59 | for _ in range(1000): 60 | dumb_logger.update_pbar(task1, 1) 61 | dumb_logger.update_pbar(task2, 1) 62 | time.sleep(0.005) 63 | dumb_logger.close_pbar() 64 | print('Success!') 65 | -------------------------------------------------------------------------------- /genforce/configs/stylegan_ffhq256.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Configuration for training StyleGAN on FF-HQ (256) dataset. 3 | 4 | All settings are particularly used for one replica (GPU), such as `batch_size` 5 | and `num_workers`. 6 | """ 7 | 8 | runner_type = 'StyleGANRunner' 9 | gan_type = 'stylegan' 10 | resolution = 256 11 | batch_size = 4 12 | val_batch_size = 64 13 | total_img = 25000_000 14 | 15 | # Training dataset is repeated at the beginning to avoid loading dataset 16 | # repeatedly at the end of each epoch. This can save some I/O time. 
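# Presumably, `repeat=500` below is what realizes this: the sampled index
# list covers the dataset many times per loader "epoch", so re-initialization
# happens far less often (see genforce/datasets/README.md).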
17 | data = dict( 18 | num_workers=4, 19 | repeat=500, 20 | # train=dict(root_dir='data/ffhq', resolution=resolution, mirror=0.5), 21 | # val=dict(root_dir='data/ffhq', resolution=resolution), 22 | train=dict(root_dir='data/ffhq.zip', data_format='zip', 23 | resolution=resolution, mirror=0.5), 24 | val=dict(root_dir='data/ffhq.zip', data_format='zip', 25 | resolution=resolution), 26 | ) 27 | 28 | controllers = dict( 29 | RunningLogger=dict(every_n_iters=10), 30 | ProgressScheduler=dict( 31 | every_n_iters=1, init_res=8, minibatch_repeats=4, 32 | lod_training_img=600_000, lod_transition_img=600_000, 33 | batch_size_schedule=dict(res4=64, res8=32, res16=16, res32=8), 34 | ), 35 | Snapshoter=dict(every_n_iters=500, first_iter=True, num=200), 36 | FIDEvaluator=dict(every_n_iters=5000, first_iter=True, num=50000), 37 | Checkpointer=dict(every_n_iters=5000, first_iter=True), 38 | ) 39 | 40 | modules = dict( 41 | discriminator=dict( 42 | model=dict(gan_type=gan_type, resolution=resolution), 43 | lr=dict(lr_type='FIXED'), 44 | opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)), 45 | kwargs_train=dict(), 46 | kwargs_val=dict(), 47 | ), 48 | generator=dict( 49 | model=dict(gan_type=gan_type, resolution=resolution), 50 | lr=dict(lr_type='FIXED'), 51 | opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)), 52 | kwargs_train=dict(w_moving_decay=0.995, style_mixing_prob=0.9, 53 | trunc_psi=1.0, trunc_layers=0, randomize_noise=True), 54 | kwargs_val=dict(trunc_psi=1.0, trunc_layers=0, randomize_noise=False), 55 | g_smooth_img=10_000, 56 | ) 57 | ) 58 | 59 | loss = dict( 60 | type='LogisticGANLoss', 61 | d_loss_kwargs=dict(r1_gamma=10.0), 62 | g_loss_kwargs=dict(), 63 | ) 64 | -------------------------------------------------------------------------------- /genforce/configs/stylegan_ffhq1024.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Configuration for training StyleGAN on FF-HQ (1024) dataset. 3 | 4 | All settings are particularly used for one replica (GPU), such as `batch_size` 5 | and `num_workers`. 6 | """ 7 | 8 | runner_type = 'StyleGANRunner' 9 | gan_type = 'stylegan' 10 | resolution = 1024 11 | batch_size = 4 12 | val_batch_size = 16 13 | total_img = 25000_000 14 | 15 | # Training dataset is repeated at the beginning to avoid loading dataset 16 | # repeatedly at the end of each epoch. This can save some I/O time. 
17 | data = dict( 18 | num_workers=4, 19 | repeat=500, 20 | # train=dict(root_dir='data/ffhq', resolution=resolution, mirror=0.5), 21 | # val=dict(root_dir='data/ffhq', resolution=resolution), 22 | train=dict(root_dir='data/ffhq.zip', data_format='zip', 23 | resolution=resolution, mirror=0.5), 24 | val=dict(root_dir='data/ffhq.zip', data_format='zip', 25 | resolution=resolution), 26 | ) 27 | 28 | controllers = dict( 29 | RunningLogger=dict(every_n_iters=10), 30 | ProgressScheduler=dict( 31 | every_n_iters=1, init_res=8, minibatch_repeats=4, 32 | lod_training_img=600_000, lod_transition_img=600_000, 33 | batch_size_schedule=dict(res4=64, res8=32, res16=16, res32=8), 34 | ), 35 | Snapshoter=dict(every_n_iters=500, first_iter=True, num=200), 36 | FIDEvaluator=dict(every_n_iters=5000, first_iter=True, num=50000), 37 | Checkpointer=dict(every_n_iters=5000, first_iter=True), 38 | ) 39 | 40 | modules = dict( 41 | discriminator=dict( 42 | model=dict(gan_type=gan_type, resolution=resolution), 43 | lr=dict(lr_type='FIXED'), 44 | opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)), 45 | kwargs_train=dict(), 46 | kwargs_val=dict(), 47 | ), 48 | generator=dict( 49 | model=dict(gan_type=gan_type, resolution=resolution), 50 | lr=dict(lr_type='FIXED'), 51 | opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)), 52 | kwargs_train=dict(w_moving_decay=0.995, style_mixing_prob=0.9, 53 | trunc_psi=1.0, trunc_layers=0, randomize_noise=True), 54 | kwargs_val=dict(trunc_psi=1.0, trunc_layers=0, randomize_noise=False), 55 | g_smooth_img=10_000, 56 | ) 57 | ) 58 | 59 | loss = dict( 60 | type='LogisticGANLoss', 61 | d_loss_kwargs=dict(r1_gamma=10.0), 62 | g_loss_kwargs=dict(), 63 | ) 64 | -------------------------------------------------------------------------------- /genforce/runners/stylegan_runner.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the runner for StyleGAN.""" 3 | 4 | from copy import deepcopy 5 | 6 | from .base_gan_runner import BaseGANRunner 7 | 8 | __all__ = ['StyleGANRunner'] 9 | 10 | 11 | class StyleGANRunner(BaseGANRunner): 12 | """Defines the runner for StyleGAN.""" 13 | 14 | def __init__(self, config, logger): 15 | super().__init__(config, logger) 16 | self.lod = getattr(self, 'lod', None) 17 | 18 | def build_models(self): 19 | super().build_models() 20 | self.g_smooth_img = self.config.modules['generator'].get( 21 | 'g_smooth_img', 10000) 22 | self.models['generator_smooth'] = deepcopy(self.models['generator']) 23 | 24 | def build_loss(self): 25 | super().build_loss() 26 | self.running_stats.add( 27 | f'Gs_beta', log_format='.4f', log_strategy='CURRENT') 28 | 29 | def train_step(self, data, **train_kwargs): 30 | # Set level-of-details. 31 | G = self.get_module(self.models['generator']) 32 | D = self.get_module(self.models['discriminator']) 33 | Gs = self.get_module(self.models['generator_smooth']) 34 | G.synthesis.lod.data.fill_(self.lod) 35 | D.lod.data.fill_(self.lod) 36 | Gs.synthesis.lod.data.fill_(self.lod) 37 | 38 | # Update discriminator. 39 | self.set_model_requires_grad('discriminator', True) 40 | self.set_model_requires_grad('generator', False) 41 | 42 | d_loss = self.loss.d_loss(self, data) 43 | self.optimizers['discriminator'].zero_grad() 44 | d_loss.backward() 45 | self.optimizers['discriminator'].step() 46 | 47 | # Life-long update for generator. 
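        # `beta` below is the per-step decay of the exponential moving average
        # kept in `generator_smooth`: each step consumes `batch_size * world_size`
        # images, so once about `g_smooth_img` images have been seen, the old
        # averaged weights have decayed by a factor of 0.5.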
48 | beta = 0.5 ** (self.batch_size * self.world_size / self.g_smooth_img) 49 | self.running_stats.update({'Gs_beta': beta}) 50 | self.moving_average_model(model=self.models['generator'], 51 | avg_model=self.models['generator_smooth'], 52 | beta=beta) 53 | 54 | # Update generator. 55 | if self._iter % self.config.get('D_repeats', 1) == 0: 56 | self.set_model_requires_grad('discriminator', False) 57 | self.set_model_requires_grad('generator', True) 58 | g_loss = self.loss.g_loss(self, data) 59 | self.optimizers['generator'].zero_grad() 60 | g_loss.backward() 61 | self.optimizers['generator'].step() 62 | 63 | def load(self, **kwargs): 64 | super().load(**kwargs) 65 | G = self.get_module(self.models['generator']) 66 | D = self.get_module(self.models['discriminator']) 67 | Gs = self.get_module(self.models['generator_smooth']) 68 | if kwargs['running_metadata']: 69 | lod = G.synthesis.lod.cpu().tolist() 70 | assert lod == D.lod.cpu().tolist() 71 | assert lod == Gs.synthesis.lod.cpu().tolist() 72 | self.lod = lod 73 | -------------------------------------------------------------------------------- /my_generate_blackbox_attack_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | import argparse 4 | import glob 5 | import os 6 | import random 7 | import sys 8 | from tqdm import tqdm 9 | import torch 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torch.utils.data import Dataset, DataLoader 13 | import torchvision.utils as vutils 14 | from my_utils import crop_img, resize_img, normalize, create_folder, Tee 15 | from my_target_models import get_model, get_input_resolution 16 | 17 | random.seed(0) 18 | 19 | @torch.no_grad() 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--seed', type=int, default=666, help='set the seed') 23 | parser.add_argument('--no-cuda', action='store_true') 24 | parser.add_argument('--arch_name', default='resnet50', type=str, help='model name from torchvision or resnet50v15') 25 | parser.add_argument('--use_dropout', action='store_true', help='use dropout to mitigate overfitting') 26 | parser.add_argument('target_dataset', choices=['vggface', 'vggface2', 'CASIA'], help='use which target dataset') 27 | parser.add_argument('dataset', choices=['ffhq', 'celeba_partial256'], help='use which dataset') 28 | args = parser.parse_args() 29 | 30 | torch.backends.cudnn.benchmark = True 31 | 32 | exp_name = os.path.join('blackbox_attack_data', args.target_dataset, args.arch_name, args.dataset) 33 | create_folder(exp_name) 34 | Tee(os.path.join(exp_name, 'output.log'), 'w') 35 | print(args) 36 | 37 | torch.manual_seed(args.seed) 38 | device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu') 39 | 40 | net = get_model(args.arch_name, device, args.use_dropout) 41 | 42 | if args.arch_name == 'resnet50': 43 | resolution = get_input_resolution(args.arch_name) 44 | elif args.arch_name == 'inception_resnetv1_vggface2': 45 | resolution = get_input_resolution(args.arch_name) 46 | elif args.arch_name == 'inception_resnetv1_casia': 47 | resolution = get_input_resolution(args.arch_name) 48 | elif args.arch_name == 'mobilenet_v2': 49 | resolution = 224 50 | elif args.arch_name == 'efficientnet_b0': 51 | resolution = 256 52 | elif args.arch_name == 'efficientnet_b0_casia': 53 | resolution = 256 54 | elif args.arch_name == 'inception_v3': 55 | resolution = 342 56 | elif args.arch_name == 'swin_transformer': 57 | 
resolution = 260 58 | elif args.arch_name == 'vision_transformer': 59 | resolution = 224 60 | elif args.arch_name == 'vgg16': 61 | resolution = 224 62 | elif args.arch_name == 'vgg16bn': 63 | resolution = 224 64 | elif args.arch_name == 'sphere20a': 65 | resolution = get_input_resolution(args.arch_name) 66 | 67 | 68 | arch_name = args.arch_name 69 | 70 | if args.dataset == 'celeba_partial256': 71 | img_dir = './stylegan_sample_z_stylegan_celeba_partial256_0.7_8_25' 72 | elif args.dataset == 'ffhq': 73 | img_dir = './stylegan_sample_z_stylegan_ffhq256_0.7_8_25' 74 | imgs_files = sorted(glob.glob(os.path.join(img_dir, 'sample_*_img.pt'))) 75 | 76 | assert len(imgs_files) > 0 77 | 78 | for img_gen_file in tqdm(imgs_files): 79 | save_filename = os.path.join(exp_name, os.path.basename(img_gen_file)[:-3]+'_logits.pt') 80 | fake = torch.load(img_gen_file).to(device) 81 | fake = crop_img(fake, arch_name) 82 | fake = normalize(resize_img(fake*255., resolution), args.arch_name) 83 | prediction = net(fake) 84 | if arch_name == 'sphere20a': 85 | prediction = prediction[0] 86 | torch.save(prediction, save_filename) 87 | 88 | 89 | if __name__ == '__main__': 90 | main() 91 | -------------------------------------------------------------------------------- /genforce/models/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all available models together.""" 3 | 4 | from .model_zoo import MODEL_ZOO 5 | from .pggan_generator import PGGANGenerator 6 | from .pggan_discriminator import PGGANDiscriminator 7 | from .stylegan_generator import StyleGANGenerator 8 | from .stylegan_discriminator import StyleGANDiscriminator 9 | from .stylegan2_generator import StyleGAN2Generator 10 | from .stylegan2_discriminator import StyleGAN2Discriminator 11 | 12 | __all__ = [ 13 | 'MODEL_ZOO', 'PGGANGenerator', 'PGGANDiscriminator', 'StyleGANGenerator', 14 | 'StyleGANDiscriminator', 'StyleGAN2Generator', 'StyleGAN2Discriminator', 15 | 'build_generator', 'build_discriminator', 'build_model' 16 | ] 17 | 18 | _GAN_TYPES_ALLOWED = ['pggan', 'stylegan', 'stylegan2'] 19 | _MODULES_ALLOWED = ['generator', 'discriminator'] 20 | 21 | 22 | def build_generator(gan_type, resolution, **kwargs): 23 | """Builds generator by GAN type. 24 | 25 | Args: 26 | gan_type: GAN type to which the generator belong. 27 | resolution: Synthesis resolution. 28 | **kwargs: Additional arguments to build the generator. 29 | 30 | Raises: 31 | ValueError: If the `gan_type` is not supported. 32 | NotImplementedError: If the `gan_type` is not implemented. 33 | """ 34 | if gan_type not in _GAN_TYPES_ALLOWED: 35 | raise ValueError(f'Invalid GAN type: `{gan_type}`!\n' 36 | f'Types allowed: {_GAN_TYPES_ALLOWED}.') 37 | 38 | if gan_type == 'pggan': 39 | return PGGANGenerator(resolution, **kwargs) 40 | if gan_type == 'stylegan': 41 | return StyleGANGenerator(resolution, **kwargs) 42 | if gan_type == 'stylegan2': 43 | return StyleGAN2Generator(resolution, **kwargs) 44 | raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!') 45 | 46 | 47 | def build_discriminator(gan_type, resolution, **kwargs): 48 | """Builds discriminator by GAN type. 49 | 50 | Args: 51 | gan_type: GAN type to which the discriminator belong. 52 | resolution: Synthesis resolution. 53 | **kwargs: Additional arguments to build the discriminator. 54 | 55 | Raises: 56 | ValueError: If the `gan_type` is not supported. 57 | NotImplementedError: If the `gan_type` is not implemented. 
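    Example:
        A minimal sketch given the dispatch below:
        `build_discriminator('stylegan', 256)` returns a
        `StyleGANDiscriminator` built for resolution 256.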
58 | """ 59 | if gan_type not in _GAN_TYPES_ALLOWED: 60 | raise ValueError(f'Invalid GAN type: `{gan_type}`!\n' 61 | f'Types allowed: {_GAN_TYPES_ALLOWED}.') 62 | 63 | if gan_type == 'pggan': 64 | return PGGANDiscriminator(resolution, **kwargs) 65 | if gan_type == 'stylegan': 66 | return StyleGANDiscriminator(resolution, **kwargs) 67 | if gan_type == 'stylegan2': 68 | return StyleGAN2Discriminator(resolution, **kwargs) 69 | raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!') 70 | 71 | 72 | def build_model(gan_type, module, resolution, **kwargs): 73 | """Builds a GAN module (generator/discriminator/etc). 74 | 75 | Args: 76 | gan_type: GAN type to which the model belong. 77 | module: GAN module to build, such as generator or discrimiantor. 78 | resolution: Synthesis resolution. 79 | **kwargs: Additional arguments to build the discriminator. 80 | 81 | Raises: 82 | ValueError: If the `module` is not supported. 83 | NotImplementedError: If the `module` is not implemented. 84 | """ 85 | if module not in _MODULES_ALLOWED: 86 | raise ValueError(f'Invalid module: `{module}`!\n' 87 | f'Modules allowed: {_MODULES_ALLOWED}.') 88 | 89 | if module == 'generator': 90 | return build_generator(gan_type, resolution, **kwargs) 91 | if module == 'discriminator': 92 | return build_discriminator(gan_type, resolution, **kwargs) 93 | raise NotImplementedError(f'Unsupported module `{module}`!') 94 | -------------------------------------------------------------------------------- /genforce/runners/controllers/running_logger.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the running controller to save the running log.""" 3 | 4 | import os 5 | import json 6 | 7 | import warnings 8 | warnings.filterwarnings('ignore', category=FutureWarning) # Ignore TF warning. 9 | 10 | # pylint: disable=wrong-import-position 11 | import torch 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | from ..misc import format_time 15 | from .base_controller import BaseController 16 | # pylint: enable=wrong-import-position 17 | 18 | __all__ = ['RunningLogger'] 19 | 20 | 21 | class RunningLogger(BaseController): 22 | """Defines the running controller to save the running log. 23 | 24 | This controller is able to save the log message in different formats: 25 | 26 | (1) Text format, which will be printed on screen and saved to the log file. 27 | (2) JSON format, which will be saved to `{runner.work_dir}/log.json`. 28 | (3) Tensorboard format. 29 | 30 | NOTE: The controller is set to `90` priority by default and will only be 31 | executed on the master worker. 
32 | """ 33 | 34 | def __init__(self, config=None): 35 | config = config or dict() 36 | config.setdefault('priority', 90) 37 | config.setdefault('every_n_iters', 1) 38 | config.setdefault('master_only', True) 39 | super().__init__(config) 40 | 41 | self._text_format = config.get('text_format', True) 42 | self._log_order = config.get('log_order', None) 43 | self._json_format = config.get('json_format', True) 44 | self._json_logpath = self._json_filename = 'log.json' 45 | self._tensorboard_format = config.get('tensorboard_format', True) 46 | self.tensorboard_writer = None 47 | 48 | def setup(self, runner): 49 | if self._text_format: 50 | runner.running_stats.log_order = self._log_order 51 | if self._json_format: 52 | self._json_logpath = os.path.join( 53 | runner.work_dir, self._json_filename) 54 | if self._tensorboard_format: 55 | event_dir = os.path.join(runner.work_dir, 'events') 56 | os.makedirs(event_dir, exist_ok=True) 57 | self.tensorboard_writer = SummaryWriter(log_dir=event_dir) 58 | 59 | def close(self, runner): 60 | if self._tensorboard_format: 61 | self.tensorboard_writer.close() 62 | 63 | def execute_after_iteration(self, runner): 64 | # Prepare log data. 65 | log_data = {name: stats.get_log_value() 66 | for name, stats in runner.running_stats.stats_pool.items()} 67 | 68 | # Save in text format. 69 | msg = f'Iter {runner.iter:6d}/{runner.total_iters:6d}' 70 | msg += f', {runner.running_stats}' 71 | memory = torch.cuda.max_memory_allocated() / (1024 ** 3) 72 | msg += f' (memory: {memory:.1f}G)' 73 | if 'iter_time' in log_data: 74 | eta = log_data['iter_time'] * (runner.total_iters - runner.iter) 75 | msg += f' (ETA: {format_time(eta)})' 76 | runner.logger.info(msg) 77 | 78 | # Save in JSON format. 79 | if self._json_format: 80 | with open(self._json_logpath, 'a+') as f: 81 | json.dump(log_data, f) 82 | f.write('\n') 83 | 84 | # Save in Tensorboard format. 85 | if self._tensorboard_format: 86 | for name, value in log_data.items(): 87 | if name in ['data_time', 'iter_time', 'run_time']: 88 | continue 89 | if name.startswith('loss_'): 90 | self.tensorboard_writer.add_scalar( 91 | name.replace('loss_', 'loss/'), value) 92 | elif name.startswith('lr_'): 93 | self.tensorboard_writer.add_scalar( 94 | name.replace('lr_', 'learning_rate/'), value) 95 | else: 96 | self.tensorboard_writer.add_scalar(name, value) 97 | 98 | # Clear running stats. 99 | runner.running_stats.clear() 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # _From Head to Tail: Efficient Black-box Model Inversion Attack via Long-tailed Learning_ - CVPR 2025 2 | 3 | ## 📄 [Paper](https://arxiv.org/abs/2503.16266) 4 | 5 | ## 📝 Abstract 6 | _Model Inversion Attacks (MIAs) aim to reconstruct private training data from models, leading to privacy leakage, particularly in facial recognition systems. Although many studies have enhanced the effectiveness of white-box MIAs, less attention has been paid to improving efficiency and utility under limited attacker capabilities. Existing black-box MIAs necessitate an impractical number of queries, incurring significant overhead. Therefore, we analyze the limitations of existing MIAs and introduce **S**urrogate **M**odel-based **I**nversion with **L**ong-tailed **E**nhancement (**SMILE**), a high-resolution oriented and query-efficient MIA for the black-box setting. 
We begin by analyzing the initialization of MIAs from a data distribution perspective and propose a long-tailed surrogate training method to obtain high-quality initial points. We then enhance the attack's effectiveness by employing the gradient-free black-box optimization algorithm selected by NGOpt. Our experiments show that **SMILE** outperforms existing state-of-the-art black-box MIAs while requiring only about 5% of the query overhead._ 7 | 8 | ## 📦 Environment Installation 9 | ```bash 10 | conda env create -f environment.yml 11 | ``` 12 | 13 | ## 🔍 Datasets and Models Download 14 | 15 | - **Datasets** : refer to ./train_classification_models/README.md 16 | - **Models** : refer to [./checkpoints](https://drive.google.com/drive/folders/1Ka5s0e8UdXKNUOFdIDBxfJAQ2TfiJG_r?usp=drive_link) & [./classification_models](https://drive.google.com/drive/folders/14I9n1pPuHWJiBbdhDTsaoFajSyoXMmvA?usp=drive_link) 17 | - **conf_mask.pt** : refer to [./conf_mask.pt](https://drive.google.com/file/d/19QQE0DZffsdBFQv0lOad4U9T3a9O8XHF/view?usp=drive_link) 18 | 19 | - **Pre-trained target models** : The GANs and classification models we use follow [MIRROR](https://github.com/njuaplusplus/mirror). 20 | 21 | ## 😃 SMILE 22 | 23 | 1. Initial sampling of 2.5K synthetic images 24 | ```bash 25 | python my_sample_z_w_space.py 26 | ``` 27 | 28 | 2. Query the black-box target model to obtain the output 29 | ```bash 30 | python my_generate_blackbox_attack_dataset.py --arch_name inception_resnetv1_vggface2 vggface2 celeba_partial256 31 | ``` 32 | 33 | 3. Merge all tensors 34 | ```bash 35 | python my_merge_all_tensors.py blackbox_attack_data/vggface2/inception_resnetv1_vggface2/celeba_partial256/ 36 | ``` 37 | 38 | 4. Long-tailed surrogate training 39 | ```bash 40 | python long-tailed_surrogate_training.py --target_dataset vggface2 --dataset celeba_partial256 --arch_name_target inception_resnetv1_vggface2 --arch_name_finetune inception_resnetv1_casia --finetune_mode 'vggface2->CASIA' --epoch 200 --batch_size 128 --query_num 2500 41 | ``` 42 | 43 | 5. Local white-box attacks & gradient-free black-box attacks 44 | ```bash 45 | run_SMILE.sh 46 | ``` 47 | 48 | Baselines: 49 | ```bash 50 | run_baselines.sh 51 | ``` 52 | 53 | ## 🔨 Evaluation for Attacks 54 | Data generation for evaluation: gen_eval_data.py 55 | ```bash 56 | test.sh 57 | ``` 58 | 59 | ## 📚 Evaluation for Models 60 | Datasets for evaluating the accuracy of surrogate models & self-trained classification models 61 | 62 | refer to ./train_classification_models/README.md 63 | 64 | 65 | ## 🔥 Acknowledgement 66 | 67 | The codebase is based on [MIRROR](https://github.com/njuaplusplus/mirror). 68 | 69 | The StyleGAN models are based on [genforce/genforce](https://github.com/genforce/genforce). 70 | 71 | VGG16/VGG16BN/Resnet50 models are from [their official websites](https://www.robots.ox.ac.uk/~albanie/pytorch-models.html). 72 | 73 | InceptionResnetV1 is from [timesler/facenet-pytorch](https://github.com/timesler/facenet-pytorch). 74 | 75 | SphereFace is from [clcarwin/sphereface_pytorch](https://github.com/clcarwin/sphereface_pytorch). 76 | 77 | Our baselines are implemented based on the following repositories. We extend our gratitude to the authors for open-sourcing their code. 
78 | [MIRROR](https://github.com/njuaplusplus/mirror), [PPA](https://github.com/LukasStruppek/Plug-and-Play-Attacks), [RLBMI](https://github.com/HanGyojin/RLB-MI) 79 | 80 | ## 📜 Citation 81 | 82 | ``` 83 | @article{li2025head, 84 | title={From Head to Tail: Efficient Black-box Model Inversion Attack via Long-tailed Learning}, 85 | author={Li, Ziang and Zhang, Hongguang and Wang, Juan and Chen, Meihui and Hu, Hongxin and Yi, Wenzhe and Xu, Xiaoyang and Yang, Mengda and Ma, Chenjun}, 86 | journal={arXiv preprint arXiv:2503.16266}, 87 | year={2025} 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /my_sample_z_w_space.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | import argparse 4 | import os 5 | import glob 6 | import math 7 | 8 | from PIL import Image 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torchvision.transforms.functional as F 14 | from torchvision.utils import save_image 15 | from tqdm import tqdm 16 | 17 | use_w_space = False 18 | repeat_w = True # if False, opt w+ space 19 | num_layers = 14 # 14 for stylegan w+ space with stylegan_celeba_partial256 20 | # num_layers = 18 # 14 for stylegan w+ space with stylegan_celebahq1024 21 | use_z_plus_space = False # to use z+ space, set this and use_w_space to be true and repeat_w to be false 22 | trunc_psi = 0.7 23 | trunc_layers = 8 24 | 25 | # genforce_model = 'pggan_celebahq1024' 26 | genforce_model = 'stylegan_celeba_partial256' 27 | # genforce_model = 'stylegan_celebahq1024' 28 | # genforce_model = 'stylegan2_ffhq1024' 29 | # genforce_model = 'stylegan_ffhq256' 30 | # genforce_model = 'stylegan_cat256' 31 | # genforce_model = 'stylegan_animeportrait512' 32 | # genforce_model = 'stylegan_animeface512' 33 | # genforce_model = 'stylegan_artface512' 34 | # genforce_model = 'stylegan_car512' 35 | # genforce_model = 'stylegan2_car512' 36 | 37 | 38 | if use_z_plus_space: 39 | use_w_space = True 40 | repeat_w = False 41 | else: 42 | use_w_space = False 43 | repeat_w = True 44 | 45 | 46 | def get_generator(batch_size, device): 47 | from genforce import my_get_GD 48 | # global use_w_space 49 | # if genforce_model.startswith('stylegan'): 50 | # use_w_space = False 51 | use_discri = False 52 | generator, discri = my_get_GD.main(device, genforce_model, batch_size, batch_size, use_w_space=use_w_space, use_discri=use_discri, repeat_w=repeat_w, use_z_plus_space=use_z_plus_space, trunc_psi=trunc_psi, trunc_layers=trunc_layers) 53 | return generator 54 | 55 | 56 | @torch.no_grad() 57 | def sample(): 58 | device = 'cuda' 59 | latent_dim = 512 60 | batch_size = 100 61 | generator = get_generator(batch_size, device) 62 | RESOLUTION = 256 63 | 64 | SIZE = 25 # 20K sampling 65 | iter_times = SIZE * (100 // batch_size) 66 | 67 | for i in tqdm(range(1, iter_times+1)): 68 | if use_z_plus_space: 69 | signal_file = './my_sample_zplus_w_space_{}.signal'.format(SIZE) 70 | else: 71 | signal_file = './my_sample_z_w_space_{}.signal'.format(SIZE) 72 | if not os.path.isfile(signal_file): 73 | with open(signal_file, 'w') as out_file: 74 | out_file.write('0') 75 | 76 | with open(signal_file) as in_file: 77 | line = in_file.readline().strip() 78 | if line and int(line) == 1: 79 | print('Stop iteration now') 80 | break 81 | 82 | if use_z_plus_space: 83 | exit() 84 | latent_in = torch.randn(batch_size*num_layers, latent_dim, device=device) 85 | dirname = 
f'/home/a402-3070/storage/LZA/MI_data/stylegan_sample_zplus_{genforce_model}_{trunc_psi}_{trunc_layers}_{SIZE}' 86 | filename = f'{dirname}/sample_{i}' 87 | else: 88 | latent_in = torch.randn(batch_size, latent_dim, device=device) 89 | dirname = f'./stylegan_sample_z_{genforce_model}_{trunc_psi}_{trunc_layers}_{SIZE}' 90 | filename = f'{dirname}/sample_{i}' 91 | 92 | if not os.path.isdir(dirname): 93 | os.mkdir(dirname) 94 | 95 | 96 | img_gen = generator(latent_in) 97 | torch.save(img_gen, f'{filename}_img.pt') 98 | 99 | if use_z_plus_space: 100 | latent_in = latent_in.view(batch_size, num_layers, latent_dim) 101 | torch.save(latent_in, f'{filename}_latent.pt') 102 | 103 | img_gen = F.resize(img_gen, (RESOLUTION, RESOLUTION)) 104 | 105 | # collect all_ws.pt file 106 | all_ws = [] 107 | all_latent_files = sorted(glob.glob(f'./{dirname}/sample_*_latent.pt*')) 108 | for i in tqdm(range(0, len(all_latent_files), batch_size)): 109 | latent_files = all_latent_files[i:i+batch_size] 110 | latent_in = [torch.load(f) for f in latent_files] 111 | latent_in = torch.cat(latent_in, dim=0) 112 | w = generator.G.mapping(latent_in.to(device))['w'] 113 | all_ws.append(w) 114 | 115 | all_ws = torch.cat(all_ws, dim=0).cpu() 116 | torch.save(all_ws, f'{dirname}/{genforce_model}_all_ws.pt') 117 | 118 | 119 | if __name__ == '__main__': 120 | sample() 121 | -------------------------------------------------------------------------------- /genforce/utils/misc.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Misc utility functions.""" 3 | 4 | import os 5 | import sys 6 | import subprocess 7 | from importlib import import_module 8 | import argparse 9 | from easydict import EasyDict 10 | 11 | import torch 12 | import torch.distributed as dist 13 | import torch.multiprocessing as mp 14 | 15 | __all__ = [ 16 | 'init_dist', 'bool_parser', 'DictAction', 'parse_config', 'update_config' 17 | ] 18 | 19 | 20 | def init_dist(launcher, backend='nccl', **kwargs): 21 | """Initializes distributed environment.""" 22 | if mp.get_start_method(allow_none=True) is None: 23 | mp.set_start_method('spawn') 24 | if launcher == 'pytorch': 25 | rank = int(os.environ['RANK']) 26 | num_gpus = torch.cuda.device_count() 27 | torch.cuda.set_device(rank % num_gpus) 28 | dist.init_process_group(backend=backend, **kwargs) 29 | elif launcher == 'slurm': 30 | proc_id = int(os.environ['SLURM_PROCID']) 31 | ntasks = int(os.environ['SLURM_NTASKS']) 32 | node_list = os.environ['SLURM_NODELIST'] 33 | num_gpus = torch.cuda.device_count() 34 | torch.cuda.set_device(proc_id % num_gpus) 35 | addr = subprocess.getoutput( 36 | f'scontrol show hostname {node_list} | head -n1') 37 | port = os.environ.get('PORT', 29500) 38 | os.environ['MASTER_PORT'] = str(port) 39 | os.environ['MASTER_ADDR'] = addr 40 | os.environ['WORLD_SIZE'] = str(ntasks) 41 | os.environ['RANK'] = str(proc_id) 42 | dist.init_process_group(backend=backend) 43 | else: 44 | raise NotImplementedError(f'Not implemented launcher type: ' 45 | f'`{launcher}`!') 46 | 47 | def bool_parser(arg): 48 | """Parses an argument to boolean.""" 49 | if isinstance(arg, bool): 50 | return arg 51 | if arg.lower() in ['1', 'true', 't', 'yes', 'y']: 52 | return True 53 | if arg.lower() in ['0', 'false', 'f', 'no', 'n']: 54 | return False 55 | raise argparse.ArgumentTypeError(f'`{arg}` cannot be converted to boolean!') 56 | 57 | 58 | class DictAction(argparse.Action): 59 | """Argparse action to split an argument into key-value. 
60 | 61 | NOTE: This class is borrowed from 62 | https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py 63 | """ 64 | 65 | @staticmethod 66 | def _parse_int_float_bool(val): 67 | try: 68 | return int(val) 69 | except ValueError: 70 | pass 71 | try: 72 | return float(val) 73 | except ValueError: 74 | pass 75 | if val.lower() in ['true', 'false']: 76 | return val.lower() == 'true' 77 | return val 78 | 79 | def __call__(self, parser, namespace, values, option_string=None): 80 | options = {} 81 | for kv in values: 82 | key, val = kv.split('=', maxsplit=1) 83 | val = [self._parse_int_float_bool(v) for v in val.split(',')] 84 | if len(val) == 1: 85 | val = val[0] 86 | options[key] = val 87 | setattr(namespace, self.dest, options) 88 | 89 | 90 | def parse_config(config_file): 91 | """Parses configuration from python file.""" 92 | assert os.path.isfile(config_file) 93 | directory = os.path.dirname(config_file) 94 | filename = os.path.basename(config_file) 95 | module_name, extension = os.path.splitext(filename) 96 | assert extension == '.py' 97 | sys.path.insert(0, directory) 98 | module = import_module(module_name) 99 | sys.path.pop(0) 100 | config = EasyDict() 101 | for key, value in module.__dict__.items(): 102 | if key.startswith('__'): 103 | continue 104 | config[key] = value 105 | del sys.modules[module_name] 106 | return config 107 | 108 | 109 | def update_config(config, new_config): 110 | """Updates configuration in a hierarchical level. 111 | 112 | For key-value pair {'a.b.c.d': v} in `new_config`, the `config` will be 113 | updated by 114 | 115 | config['a']['b']['c']['d'] = v 116 | """ 117 | if new_config is None: 118 | return config 119 | 120 | assert isinstance(config, dict) 121 | assert isinstance(new_config, dict) 122 | 123 | for key, val in new_config.items(): 124 | hierarchical_keys = key.split('.') 125 | temp = config 126 | for sub_key in hierarchical_keys[:-1]: 127 | temp = temp[sub_key] 128 | temp[hierarchical_keys[-1]] = val 129 | 130 | return config 131 | -------------------------------------------------------------------------------- /train_classification_models/self-train_VGGFace2/kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def distmat(X): 6 | """distance matrix""" 7 | assert X.ndim == 2 8 | r = torch.sum(X * X, dim=1, keepdim=True) 9 | # r = r.view([-1, 1]) 10 | a = torch.mm(X, torch.transpose(X, 0, 1)) 11 | D = r.expand_as(a) - 2 * a + torch.transpose(r, 0, 1).expand_as(a) 12 | D = torch.abs(D) 13 | return D 14 | 15 | 16 | def sigma_estimation(X, Y): 17 | """sigma from median distance""" 18 | D = distmat(torch.cat([X, Y])) 19 | D = D.detach().cpu().numpy() 20 | Itri = np.tril_indices(D.shape[0], -1) 21 | Tri = D[Itri] 22 | med = np.median(Tri) 23 | if med <= 0: 24 | med = np.mean(Tri) 25 | if med < 1e-2: 26 | med = 1e-2 27 | return med 28 | 29 | 30 | def hisc_kernelmat(X, sigma, ktype='gaussian'): 31 | """kernel matrix baker""" 32 | m = int(X.size()[0]) 33 | H = torch.eye(m) - (1.0 / m) * torch.ones([m, m]) 34 | 35 | if ktype == "gaussian": 36 | Dxx = distmat(X) 37 | 38 | if sigma: 39 | variance = 2.0 * sigma * sigma * X.size()[1] 40 | Kx = torch.exp(-Dxx / variance).type(torch.FloatTensor) # kernel matrices 41 | # print(sigma, torch.mean(Kx), torch.max(Kx), torch.min(Kx)) 42 | else: 43 | try: 44 | sx = sigma_estimation(X, X) 45 | Kx = torch.exp(-Dxx / (2.0 * sx * sx)).type(torch.FloatTensor) 46 | except RuntimeError as e: 47 | raise RuntimeError( 48 | 
"Unstable sigma {} with maximum/minimum input ({},{})".format( 49 | sx, torch.max(X), torch.min(X) 50 | ) 51 | ) 52 | 53 | elif ktype == "linear": 54 | Kx = torch.mm(X, X.T).type(torch.FloatTensor) 55 | 56 | elif ktype == 'IMQ': 57 | Dxx = distmat(X) 58 | Kx = 1 * torch.rsqrt(Dxx + 1) 59 | 60 | Kxc = torch.mm(Kx, H) 61 | 62 | return Kxc 63 | 64 | 65 | def hsic_normalized_cca(x, y, sigma, ktype='gaussian'): 66 | m = int(x.size()[0]) 67 | Kxc = hisc_kernelmat(x, sigma=sigma) 68 | Kyc = hisc_kernelmat(y, sigma=sigma, ktype=ktype) 69 | 70 | epsilon = 1e-5 71 | K_I = torch.eye(m) 72 | Kxc_i = torch.inverse(Kxc + epsilon * m * K_I) 73 | Kyc_i = torch.inverse(Kyc + epsilon * m * K_I) 74 | Rx = Kxc.mm(Kxc_i) 75 | Ry = Kyc.mm(Kyc_i) 76 | Pxy = torch.sum(torch.mul(Rx, Ry.t())) 77 | 78 | return Pxy 79 | 80 | 81 | def hsic_objective(hidden, h_target, h_data, sigma, ktype='gaussian'): 82 | hsic_hx_val = hsic_normalized_cca(hidden, h_data, sigma=sigma) 83 | hsic_hy_val = hsic_normalized_cca(hidden, h_target, sigma=sigma, ktype=ktype) 84 | 85 | return hsic_hx_val, hsic_hy_val 86 | 87 | 88 | def coco_kernelmat(X, sigma, ktype='gaussian'): 89 | """kernel matrix baker""" 90 | m = int(X.size()[0]) 91 | H = torch.eye(m) - (1.0 / m) * torch.ones([m, m]) 92 | 93 | if ktype == "gaussian": 94 | Dxx = distmat(X) 95 | 96 | if sigma: 97 | variance = 2.0 * sigma * sigma * X.size()[1] 98 | Kx = torch.exp(-Dxx / variance).type(torch.FloatTensor) # kernel matrices 99 | # print(sigma, torch.mean(Kx), torch.max(Kx), torch.min(Kx)) 100 | else: 101 | try: 102 | sx = sigma_estimation(X, X) 103 | Kx = torch.exp(-Dxx / (2.0 * sx * sx)).type(torch.FloatTensor) 104 | except RuntimeError as e: 105 | raise RuntimeError( 106 | "Unstable sigma {} with maximum/minimum input ({},{})".format( 107 | sx, torch.max(X), torch.min(X) 108 | ) 109 | ) 110 | 111 | ## Adding linear kernel 112 | elif ktype == "linear": 113 | Kx = torch.mm(X, X.T).type(torch.FloatTensor) 114 | 115 | elif ktype == 'IMQ': 116 | Dxx = distmat(X) 117 | Kx = 1 * torch.rsqrt(Dxx + 1) 118 | 119 | Kxc = torch.mm(H, torch.mm(Kx, H)) 120 | 121 | return Kxc 122 | 123 | 124 | def coco_normalized_cca(x, y, sigma, ktype='gaussian'): 125 | m = int(x.size()[0]) 126 | K = coco_kernelmat(x, sigma=sigma) 127 | L = coco_kernelmat(y, sigma=sigma, ktype=ktype) 128 | 129 | res = torch.sqrt(torch.norm(torch.mm(K, L))) / m 130 | return res 131 | 132 | 133 | def coco_objective(hidden, h_target, h_data, sigma, ktype='gaussian'): 134 | coco_hx_val = coco_normalized_cca(hidden, h_data, sigma=sigma) 135 | coco_hy_val = coco_normalized_cca(hidden, h_target, sigma=sigma, ktype=ktype) 136 | 137 | return coco_hx_val, coco_hy_val 138 | -------------------------------------------------------------------------------- /genforce/datasets/dataloaders.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class of data loader.""" 3 | 4 | import argparse 5 | 6 | from torch.utils.data import DataLoader 7 | from .distributed_sampler import DistributedSampler 8 | from .datasets import BaseDataset 9 | 10 | 11 | __all__ = ['IterDataLoader'] 12 | 13 | 14 | class IterDataLoader(object): 15 | """Iteration-based data loader.""" 16 | 17 | def __init__(self, 18 | dataset, 19 | batch_size, 20 | shuffle=True, 21 | num_workers=1, 22 | current_iter=0, 23 | repeat=1): 24 | """Initializes the data loader. 25 | 26 | Args: 27 | dataset: The dataset to load data from. 28 | batch_size: The batch size on each GPU. 29 | shuffle: Whether to shuffle the data. 
(default: True) 30 | num_workers: Number of data workers for each GPU. (default: 1) 31 | current_iter: The current number of iterations. (default: 0) 32 | repeat: The repeating number of the whole dataloader. (default: 1) 33 | """ 34 | self._dataset = dataset 35 | self.batch_size = batch_size 36 | self.shuffle = shuffle 37 | self.num_workers = num_workers 38 | self._dataloader = None 39 | self.iter_loader = None 40 | self._iter = current_iter 41 | self.repeat = repeat 42 | self.build_dataloader() 43 | 44 | def build_dataloader(self): 45 | """Builds data loader.""" 46 | dist_sampler = DistributedSampler(self._dataset, 47 | shuffle=self.shuffle, 48 | current_iter=self._iter, 49 | repeat=self.repeat) 50 | 51 | self._dataloader = DataLoader(self._dataset, 52 | batch_size=self.batch_size, 53 | shuffle=(dist_sampler is None), 54 | num_workers=self.num_workers, 55 | drop_last=self.shuffle, 56 | pin_memory=True, 57 | sampler=dist_sampler) 58 | self.iter_loader = iter(self._dataloader) 59 | 60 | 61 | def overwrite_param(self, batch_size=None, resolution=None): 62 | """Overwrites some parameters for progressive training.""" 63 | if (not batch_size) and (not resolution): 64 | return 65 | if (batch_size == self.batch_size) and ( 66 | resolution == self.dataset.resolution): 67 | return 68 | if batch_size: 69 | self.batch_size = batch_size 70 | if resolution: 71 | self._dataset.resolution = resolution 72 | self.build_dataloader() 73 | 74 | @property 75 | def iter(self): 76 | """Returns the current iteration.""" 77 | return self._iter 78 | 79 | @property 80 | def dataset(self): 81 | """Returns the dataset.""" 82 | return self._dataset 83 | 84 | @property 85 | def dataloader(self): 86 | """Returns the data loader.""" 87 | return self._dataloader 88 | 89 | def __next__(self): 90 | try: 91 | data = next(self.iter_loader) 92 | self._iter += 1 93 | except StopIteration: 94 | self._dataloader.sampler.__reset__(self._iter) 95 | self.iter_loader = iter(self._dataloader) 96 | data = next(self.iter_loader) 97 | self._iter += 1 98 | return data 99 | 100 | def __len__(self): 101 | return len(self._dataloader) 102 | 103 | 104 | def dataloader_test(root_dir, test_num=10): 105 | """Tests data loader.""" 106 | res = 2 107 | bs = 2 108 | dataset = BaseDataset(root_dir=root_dir, resolution=res) 109 | dataloader = IterDataLoader(dataset=dataset, 110 | batch_size=bs, 111 | shuffle=False) 112 | for _ in range(test_num): 113 | data_batch = next(dataloader) 114 | image = data_batch['image'] 115 | assert image.shape == (bs, 3, res, res) 116 | res *= 2 117 | bs += 1 118 | dataloader.overwrite_param(batch_size=bs, resolution=res) 119 | 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser(description='Test Data Loader.') 123 | parser.add_argument('root_dir', type=str, 124 | help='Root directory of the dataset.') 125 | parser.add_argument('--test_num', type=int, default=10, 126 | help='Number of tests. 
(default: %(default)s)') 127 | args = parser.parse_args() 128 | dataloader_test(args.root_dir, args.test_num) 129 | -------------------------------------------------------------------------------- /genforce/runners/losses/logistic_gan_loss.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Defines loss functions.""" 3 | 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | __all__ = ['LogisticGANLoss'] 9 | 10 | apply_loss_scaling = lambda x: x * torch.exp(x * np.log(2.0)) 11 | undo_loss_scaling = lambda x: x * torch.exp(-x * np.log(2.0)) 12 | 13 | 14 | class LogisticGANLoss(object): 15 | """Contains the class to compute logistic GAN loss.""" 16 | 17 | def __init__(self, runner, d_loss_kwargs=None, g_loss_kwargs=None): 18 | """Initializes with models and arguments for computing losses.""" 19 | self.d_loss_kwargs = d_loss_kwargs or dict() 20 | self.g_loss_kwargs = g_loss_kwargs or dict() 21 | self.r1_gamma = self.d_loss_kwargs.get('r1_gamma', 10.0) 22 | self.r2_gamma = self.d_loss_kwargs.get('r2_gamma', 0.0) 23 | 24 | runner.running_stats.add( 25 | f'g_loss', log_format='.3f', log_strategy='AVERAGE') 26 | runner.running_stats.add( 27 | f'd_loss', log_format='.3f', log_strategy='AVERAGE') 28 | if self.r1_gamma != 0: 29 | runner.running_stats.add( 30 | f'real_grad_penalty', log_format='.3f', log_strategy='AVERAGE') 31 | if self.r2_gamma != 0: 32 | runner.running_stats.add( 33 | f'fake_grad_penalty', log_format='.3f', log_strategy='AVERAGE') 34 | 35 | @staticmethod 36 | def preprocess_image(images, lod=0, **_unused_kwargs): 37 | """Pre-process images.""" 38 | if lod != int(lod): 39 | downsampled_images = F.avg_pool2d( 40 | images, kernel_size=2, stride=2, padding=0) 41 | upsampled_images = F.interpolate( 42 | downsampled_images, scale_factor=2, mode='nearest') 43 | alpha = lod - int(lod) 44 | images = images * (1 - alpha) + upsampled_images * alpha 45 | if int(lod) == 0: 46 | return images 47 | return F.interpolate( 48 | images, scale_factor=(2 ** int(lod)), mode='nearest') 49 | 50 | @staticmethod 51 | def compute_grad_penalty(images, scores): 52 | """Computes gradient penalty.""" 53 | image_grad = torch.autograd.grad( 54 | outputs=scores.sum(), 55 | inputs=images, 56 | create_graph=True, 57 | retain_graph=True)[0].view(images.shape[0], -1) 58 | penalty = image_grad.pow(2).sum(dim=1).mean() 59 | return penalty 60 | 61 | def d_loss(self, runner, data): 62 | """Computes loss for discriminator.""" 63 | G = runner.models['generator'] 64 | D = runner.models['discriminator'] 65 | 66 | reals = self.preprocess_image(data['image'], lod=runner.lod) 67 | reals.requires_grad = True 68 | labels = data.get('label', None) 69 | 70 | latents = torch.randn(reals.shape[0], runner.z_space_dim).cuda() 71 | latents.requires_grad = True 72 | # TODO: Use random labels. 
73 | fakes = G(latents, label=labels, **runner.G_kwargs_train)['image'] 74 | real_scores = D(reals, label=labels, **runner.D_kwargs_train) 75 | fake_scores = D(fakes, label=labels, **runner.D_kwargs_train) 76 | 77 | d_loss = F.softplus(fake_scores).mean() 78 | d_loss += F.softplus(-real_scores).mean() 79 | runner.running_stats.update({'d_loss': d_loss.item()}) 80 | 81 | real_grad_penalty = torch.zeros_like(d_loss) 82 | fake_grad_penalty = torch.zeros_like(d_loss) 83 | if self.r1_gamma: 84 | real_grad_penalty = self.compute_grad_penalty(reals, real_scores) 85 | runner.running_stats.update( 86 | {'real_grad_penalty': real_grad_penalty.item()}) 87 | if self.r2_gamma: 88 | fake_grad_penalty = self.compute_grad_penalty(fakes, fake_scores) 89 | runner.running_stats.update( 90 | {'fake_grad_penalty': fake_grad_penalty.item()}) 91 | 92 | return (d_loss + 93 | real_grad_penalty * (self.r1_gamma * 0.5) + 94 | fake_grad_penalty * (self.r2_gamma * 0.5)) 95 | 96 | def g_loss(self, runner, data): # pylint: disable=no-self-use 97 | """Computes loss for generator.""" 98 | # TODO: Use random labels. 99 | G = runner.models['generator'] 100 | D = runner.models['discriminator'] 101 | batch_size = data['image'].shape[0] 102 | labels = data.get('label', None) 103 | 104 | latents = torch.randn(batch_size, runner.z_space_dim).cuda() 105 | fakes = G(latents, label=labels, **runner.G_kwargs_train)['image'] 106 | fake_scores = D(fakes, label=labels, **runner.D_kwargs_train) 107 | 108 | g_loss = F.softplus(-fake_scores).mean() 109 | runner.running_stats.update({'g_loss': g_loss.item()}) 110 | 111 | return g_loss 112 | -------------------------------------------------------------------------------- /genforce/train.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Main function for model training.""" 3 | 4 | import os.path 5 | import shutil 6 | import warnings 7 | import random 8 | import argparse 9 | import numpy as np 10 | 11 | import torch 12 | import torch.distributed as dist 13 | 14 | import runners 15 | from utils.logger import build_logger 16 | from utils.misc import init_dist 17 | from utils.misc import DictAction, parse_config, update_config 18 | 19 | 20 | def parse_args(): 21 | """Parses arguments.""" 22 | parser = argparse.ArgumentParser(description='Run model training.') 23 | parser.add_argument('config', type=str, 24 | help='Path to the training configuration.') 25 | parser.add_argument('--work_dir', type=str, required=True, 26 | help='The work directory to save logs and checkpoints.') 27 | parser.add_argument('--resume_path', type=str, default=None, 28 | help='Path to the checkpoint to resume training.') 29 | parser.add_argument('--weight_path', type=str, default=None, 30 | help='Path to the checkpoint to load model weights, ' 31 | 'but not resume other states.') 32 | parser.add_argument('--seed', type=int, default=None, 33 | help='Random seed. (default: %(default)s)') 34 | parser.add_argument('--launcher', type=str, default='pytorch', 35 | choices=['pytorch', 'slurm'], 36 | help='Launcher type. (default: %(default)s)') 37 | parser.add_argument('--backend', type=str, default='nccl', 38 | help='Backend for distributed launcher. (default: ' 39 | '%(default)s)') 40 | parser.add_argument('--rank', type=int, default=-1, 41 | help='Node rank for distributed running. (default: ' 42 | '%(default)s)') 43 | parser.add_argument('--local_rank', type=int, default=0, 44 | help='Rank of the current node. 
(default: %(default)s)') 45 | parser.add_argument('--options', nargs='+', action=DictAction, 46 | help='arguments in dict') 47 | return parser.parse_args() 48 | 49 | 50 | def main(): 51 | """Main function.""" 52 | # Parse arguments. 53 | args = parse_args() 54 | 55 | # Parse configurations. 56 | config = parse_config(args.config) 57 | config = update_config(config, args.options) 58 | config.work_dir = args.work_dir 59 | config.resume_path = args.resume_path 60 | config.weight_path = args.weight_path 61 | config.seed = args.seed 62 | config.launcher = args.launcher 63 | config.backend = args.backend 64 | 65 | # Set CUDNN. 66 | config.cudnn_benchmark = config.get('cudnn_benchmark', True) 67 | config.cudnn_deterministic = config.get('cudnn_deterministic', False) 68 | torch.backends.cudnn.benchmark = config.cudnn_benchmark 69 | torch.backends.cudnn.deterministic = config.cudnn_deterministic 70 | 71 | # Set random seed. 72 | if config.seed is not None: 73 | random.seed(config.seed) 74 | np.random.seed(config.seed) 75 | torch.manual_seed(config.seed) 76 | config.cudnn_deterministic = True 77 | torch.backends.cudnn.deterministic = True 78 | warnings.warn('Random seed is set for training! ' 79 | 'This will turn on the CUDNN deterministic setting, ' 80 | 'which may slow down the training considerably! ' 81 | 'Unexpected behavior can be observed when resuming from ' 82 | 'checkpoints.') 83 | 84 | # Set launcher. 85 | config.is_distributed = True 86 | init_dist(config.launcher, backend=config.backend) 87 | config.num_gpus = dist.get_world_size() 88 | 89 | # Setup logger. 90 | if dist.get_rank() == 0: 91 | logger_type = config.get('logger_type', 'normal') 92 | logger = build_logger(logger_type, work_dir=config.work_dir) 93 | shutil.copy(args.config, os.path.join(config.work_dir, 'config.py')) 94 | commit_id = os.popen('git rev-parse HEAD').readline() 95 | logger.info(f'Commit ID: {commit_id}') 96 | else: 97 | logger = build_logger('dumb', work_dir=config.work_dir) 98 | 99 | # Start training. 100 | runner = getattr(runners, config.runner_type)(config, logger) 101 | if config.resume_path: 102 | runner.load(filepath=config.resume_path, 103 | running_metadata=True, 104 | learning_rate=True, 105 | optimizer=True, 106 | running_stats=False) 107 | if config.weight_path: 108 | runner.load(filepath=config.weight_path, 109 | running_metadata=False, 110 | learning_rate=False, 111 | optimizer=False, 112 | running_stats=False) 113 | runner.train() 114 | 115 | 116 | if __name__ == '__main__': 117 | main() 118 | -------------------------------------------------------------------------------- /genforce/test.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Main function for model inference.""" 3 | 4 | import os.path 5 | import shutil 6 | import argparse 7 | 8 | import torch 9 | import torch.distributed as dist 10 | 11 | import runners 12 | from utils.logger import build_logger 13 | from utils.misc import init_dist 14 | from utils.misc import DictAction, parse_config, update_config 15 | 16 | 17 | def parse_args(): 18 | """Parses arguments.""" 19 | parser = argparse.ArgumentParser(description='Run model inference.') 20 | parser.add_argument('config', type=str, 21 | help='Path to the inference configuration.') 22 | parser.add_argument('--work_dir', type=str, required=True, 23 | help='The work directory to save logs and checkpoints.') 24 | parser.add_argument('--checkpoint', type=str, required=True, 25 | help='Path to the checkpoint to load. 
(default: ' 26 | '%(default)s)') 27 | parser.add_argument('--synthesis_num', type=int, default=1000, 28 | help='Number of samples to synthesize. Set as 0 to ' 29 | 'disable synthesis. (default: %(default)s)') 30 | parser.add_argument('--fid_num', type=int, default=50000, 31 | help='Number of samples to compute FID. Set as 0 to ' 32 | 'disable FID test. (default: %(default)s)') 33 | parser.add_argument('--use_torchvision', action='store_true', 34 | help='Whether to use the Inception model from ' 35 | '`torchvision` to compute FID. (default: False)') 36 | parser.add_argument('--launcher', type=str, default='pytorch', 37 | choices=['pytorch', 'slurm'], 38 | help='Launcher type. (default: %(default)s)') 39 | parser.add_argument('--backend', type=str, default='nccl', 40 | help='Backend for distributed launcher. (default: ' 41 | '%(default)s)') 42 | parser.add_argument('--rank', type=int, default=-1, 43 | help='Node rank for distributed running. (default: ' 44 | '%(default)s)') 45 | parser.add_argument('--local_rank', type=int, default=0, 46 | help='Rank of the current node. (default: %(default)s)') 47 | parser.add_argument('--options', nargs='+', action=DictAction, 48 | help='arguments in dict') 49 | return parser.parse_args() 50 | 51 | 52 | def main(): 53 | """Main function.""" 54 | # Parse arguments. 55 | args = parse_args() 56 | 57 | # Parse configurations. 58 | config = parse_config(args.config) 59 | config = update_config(config, args.options) 60 | config.work_dir = args.work_dir 61 | config.checkpoint = args.checkpoint 62 | config.launcher = args.launcher 63 | config.backend = args.backend 64 | if not os.path.isfile(config.checkpoint): 65 | raise FileNotFoundError(f'Checkpoint file `{config.checkpoint}` is ' 66 | f'missing!') 67 | 68 | # Set CUDNN. 69 | config.cudnn_benchmark = config.get('cudnn_benchmark', True) 70 | config.cudnn_deterministic = config.get('cudnn_deterministic', False) 71 | torch.backends.cudnn.benchmark = config.cudnn_benchmark 72 | torch.backends.cudnn.deterministic = config.cudnn_deterministic 73 | 74 | # Setting for launcher. 75 | config.is_distributed = True 76 | init_dist(config.launcher, backend=config.backend) 77 | config.num_gpus = dist.get_world_size() 78 | 79 | # Setup logger. 80 | if dist.get_rank() == 0: 81 | logger_type = config.get('logger_type', 'normal') 82 | logger = build_logger(logger_type, work_dir=config.work_dir) 83 | shutil.copy(args.config, os.path.join(config.work_dir, 'config.py')) 84 | commit_id = os.popen('git rev-parse HEAD').readline() 85 | logger.info(f'Commit ID: {commit_id}') 86 | else: 87 | logger = build_logger('dumb', work_dir=config.work_dir) 88 | 89 | # Start inference. 90 | runner = getattr(runners, config.runner_type)(config, logger) 91 | runner.load(filepath=config.checkpoint, 92 | running_metadata=False, 93 | learning_rate=False, 94 | optimizer=False, 95 | running_stats=False) 96 | 97 | if args.synthesis_num > 0: 98 | num = args.synthesis_num 99 | logger.print() 100 | logger.info(f'Synthesizing images ...') 101 | runner.synthesize(num, html_name=f'synthesis_{num}.html') 102 | logger.info(f'Finish synthesizing {num} images.') 103 | 104 | if args.fid_num > 0: 105 | num = args.fid_num 106 | logger.print() 107 | logger.info(f'Testing FID ...') 108 | fid_value = runner.fid(num, align_tf=not args.use_torchvision) 109 | logger.info(f'Finish testing FID on {num} samples. 
' 110 | f'The result is {fid_value:.6f}.') 111 | 112 | 113 | if __name__ == '__main__': 114 | main() 115 | -------------------------------------------------------------------------------- /genforce/README.md: -------------------------------------------------------------------------------- 1 | **NOTE: The code is copied from [original GenForce repo](https://github.com/genforce/genforce/tree/e7e82e47027f8653636b379a2d314d8e4ca91ef6). The major modifications are in the `models/stylegan2_generator.py` and `models/stylegan_generator.py` to support w vector inputs. We also added a file `my_get_GD.py` for easier access to the generator and discriminator.** 2 | 3 | # GenForce Lib for Generative Modeling 4 | 5 | An efficient PyTorch library for deep generative modeling. May the Generative Force (GenForce) be with You. 6 | 7 | ![image](./teaser.gif) 8 | 9 | ## Highlights 10 | 11 | - **Distributed** training framework. 12 | - **Fast** training speed. 13 | - **Modular** design for prototyping new models. 14 | - **Highly** reproducible training of StyleGAN compared to [the official TensorFlow version](https://github.com/NVlabs/stylegan). 15 | - **Model zoo** containing a rich set of pretrained GAN models, with [Colab live demo](https://colab.research.google.com/github/genforce/genforce/blob/master/docs/synthesize_demo.ipynb) to play. 16 | 17 | We will also support the following functions *in the very near future*. Please **STAY TUNED**. 18 | 19 | - Training of PGGAN and StyleGAN2 (and likely BigGAN too). 20 | - Benchmark on model training. 21 | - Training of GAN encoder from [In-Domain GAN Inversion](https://genforce.github.io/idinvert). 22 | - Other recent work from our [GenForce](http://genforce.github.io/). 23 | 24 | ## Installation 25 | 26 | 1. Create a virtual environment via `conda`. 27 | 28 | ```shell 29 | conda create -n genforce python=3.7 30 | conda activate genforce 31 | ``` 32 | 33 | 2. Install `torch` and `torchvision`. 34 | 35 | ```shell 36 | conda install pytorch cudatoolkit=10.1 torchvision -c pytorch 37 | ``` 38 | 39 | 3. Install requirements. 40 | 41 | ```shell 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | ## Quick Demo 46 | 47 | We provide a quick training demo, `scripts/stylegan_training_demo.py`, which allows you to train StyleGAN on a toy dataset (500 animeface images with 64 x 64 resolution). Try it via 48 | 49 | ```shell 50 | ./scripts/stylegan_training_demo.sh 51 | ``` 52 | 53 | We also provide an inference demo, `synthesize.py`, which allows you to synthesize images with pre-trained models. Generated images can be found at `work_dirs/synthesis_results/`. Try it via 54 | 55 | ```shell 56 | python synthesize.py stylegan_ffhq1024 57 | ``` 58 | 59 | You can also play the demo at [Colab](https://colab.research.google.com/github/genforce/genforce/blob/master/docs/synthesize_demo.ipynb). 60 | 61 | ## Get Started 62 | 63 | ### Test 64 | 65 | Pre-trained models can be found at [model zoo](MODEL_ZOO.md). 
66 | 67 | - On local machine: 68 | 69 | ```shell 70 | GPUS=8 71 | CONFIG=configs/stylegan_ffhq256_val.py 72 | WORK_DIR=work_dirs/stylegan_ffhq256_val 73 | CHECKPOINT=checkpoints/stylegan_ffhq256.pth 74 | ./scripts/dist_test.sh ${GPUS} ${CONFIG} ${WORK_DIR} ${CHECKPOINT} 75 | ``` 76 | 77 | - Using `slurm`: 78 | 79 | ```shell 80 | CONFIG=configs/stylegan_ffhq256_val.py 81 | WORK_DIR=work_dirs/stylegan_ffhq256_val 82 | CHECKPOINT=checkpoints/stylegan_ffhq256.pth 83 | GPUS=8 ./scripts/slurm_test.sh ${PARTITION} ${JOB_NAME} \ 84 | ${CONFIG} ${WORK_DIR} ${CHECKPOINT} 85 | ``` 86 | 87 | ### Train 88 | 89 | All log files generated during training, such as log messages, checkpoints, and synthesis snapshots, will be saved to the work directory. 90 | 91 | - On local machine: 92 | 93 | ```shell 94 | GPUS=8 95 | CONFIG=configs/stylegan_ffhq256.py 96 | WORK_DIR=work_dirs/stylegan_ffhq256_train 97 | ./scripts/dist_train.sh ${GPUS} ${CONFIG} ${WORK_DIR} \ 98 | [--options additional_arguments] 99 | ``` 100 | 101 | - Using `slurm`: 102 | 103 | ```shell 104 | CONFIG=configs/stylegan_ffhq256.py 105 | WORK_DIR=work_dirs/stylegan_ffhq256_train 106 | GPUS=8 ./scripts/slurm_train.sh ${PARTITION} ${JOB_NAME} \ 107 | ${CONFIG} ${WORK_DIR} \ 108 | [--options additional_arguments] 109 | ``` 110 | 111 | ## Contributors 112 | 113 | | Member | Module | 114 | | :-- | :-- | 115 | |[Yujun Shen](http://shenyujun.github.io/) | models and running controllers 116 | |[Yinghao Xu](https://justimyhxu.github.io/) | runner and loss functions 117 | |[Ceyuan Yang](http://ceyuan.me/) | data loader 118 | |[Jiapeng Zhu](https://zhujiapeng.github.io/) | evaluation metrics 119 | |[Bolei Zhou](http://bzhou.ie.cuhk.edu.hk/) | cheerleader 120 | 121 | **NOTE:** The above table only lists the person in charge of each module. We help each other a lot and develop as a **TEAM**. 122 | 123 | *We welcome external contributors to join us in improving this library.* 124 | 125 | ## License 126 | 127 | The project is under the [MIT License](./LICENSE). 128 | 129 | ## Acknowledgement 130 | 131 | We thank [PGGAN](https://github.com/tkarras/progressive_growing_of_gans), [StyleGAN](https://github.com/NVlabs/stylegan), [StyleGAN2](https://github.com/NVlabs/stylegan2) for their work on high-quality image synthesis. We also thank [MMCV](https://github.com/open-mmlab/mmcv) for the inspiration on the design of controllers. 132 | 133 | ## BibTex 134 | 135 | We open source this library to the community to facilitate the research of generative modeling. If you do like our work and use the codebase or models for your research, please cite our work as follows. 
136 | 137 | ```bibtex 138 | @misc{genforce2020, 139 | title = {GenForce}, 140 | author = {Shen, Yujun and Xu, Yinghao and Yang, Ceyuan and Zhu, Jiapeng and Zhou, Bolei}, 141 | howpublished = {\url{https://github.com/genforce/genforce}}, 142 | year = {2020} 143 | } 144 | ``` 145 | -------------------------------------------------------------------------------- /vgg_m_face_bn_dag.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class Vgg_m_face_bn_dag(nn.Module): 9 | 10 | def __init__(self, use_dropout=False): 11 | super(Vgg_m_face_bn_dag, self).__init__() 12 | self.meta = {'mean': [131.45376586914062, 103.98748016357422, 91.46234893798828], 13 | 'std': [1, 1, 1], 14 | 'imageSize': [224, 224, 3]} 15 | self.conv1 = nn.Conv2d(3, 96, kernel_size=[7, 7], stride=(2, 2)) 16 | self.bn49 = nn.BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 17 | self.relu1 = nn.ReLU(inplace=True) 18 | self.pool1 = nn.MaxPool2d(kernel_size=[3, 3], stride=[2, 2], padding=0, dilation=1, ceil_mode=False) 19 | self.conv2 = nn.Conv2d(96, 256, kernel_size=[5, 5], stride=(2, 2), padding=(1, 1)) 20 | self.bn50 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 21 | self.relu2 = nn.ReLU(inplace=True) 22 | self.pool2 = nn.MaxPool2d(kernel_size=[3, 3], stride=[2, 2], padding=(0, 0), dilation=1, ceil_mode=True) 23 | self.conv3 = nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 24 | self.bn51 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 25 | self.relu3 = nn.ReLU(inplace=True) 26 | self.conv4 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 27 | self.bn52 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 28 | self.relu4 = nn.ReLU(inplace=True) 29 | self.conv5 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 30 | self.bn53 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 31 | self.relu5 = nn.ReLU(inplace=True) 32 | self.pool5 = nn.MaxPool2d(kernel_size=[3, 3], stride=[2, 2], padding=0, dilation=1, ceil_mode=False) 33 | self.fc6 = nn.Conv2d(512, 4096, kernel_size=[6, 6], stride=(1, 1)) 34 | self.bn54 = nn.BatchNorm2d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 35 | self.relu6 = nn.ReLU(inplace=True) 36 | self.fc7 = nn.Conv2d(4096, 4096, kernel_size=[1, 1], stride=(1, 1)) 37 | self.bn55 = nn.BatchNorm2d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 38 | self.relu7 = nn.ReLU(inplace=True) 39 | self.fc8 = nn.Linear(in_features=4096, out_features=2622, bias=True) 40 | 41 | self.use_dropout = use_dropout 42 | 43 | if self.use_dropout: 44 | # fc_dropout_probs = {1: 0.6, 2: 0.5, 3: 0.3} 45 | conv_dropout_probs = {1: 0.5, 2: 0.2, 3: 0.2, 4: 0.1, 5: 0.1} 46 | # self.fc_dropouts = {k: partial(nn.functional.dropout, p=v) for k, v in fc_dropout_probs.items()} 47 | self.conv_dropouts = {k: nn.Dropout2d(v) for k, v in conv_dropout_probs.items()} 48 | 49 | print(f'conv_dropout_probs: {conv_dropout_probs}') # f'fc_dropout_probs: {fc_dropout_probs}\n' 50 | 51 | def forward(self, x0): 52 | if self.use_dropout: 53 | for x in self.conv_dropouts.values(): 54 | x.training = True 55 | 56 | k = random.randint(1, 5) 57 | conv_dropout_layers = set(random.choices(range(1, 7), k=k)) # 6 means no dropout 58 | # k = 
random.randint(1, 3) 59 | # fc_dropout_layers = set(random.choices(range(1, 5), k=k)) # 4 means no dropout 60 | 61 | conv_dropout = self.conv_dropouts[len(conv_dropout_layers)] 62 | 63 | # fc_dropout = self.fc_dropouts[len(fc_dropout_layers)] 64 | else: 65 | conv_dropout_layers = set() 66 | # fc_dropout_layers = set() 67 | conv_dropout = None 68 | # fc_dropout = None 69 | 70 | # print('conv_dropout_layers', conv_dropout_layers) 71 | # print('fc_dropout_layers', fc_dropout_layers) 72 | 73 | x1 = self.conv1(x0) 74 | x2 = self.bn49(x1) 75 | x3 = self.relu1(x2) 76 | x4 = self.pool1(x3) 77 | 78 | if 1 in conv_dropout_layers: 79 | x4 = conv_dropout(x4) 80 | 81 | x5 = self.conv2(x4) 82 | x6 = self.bn50(x5) 83 | x7 = self.relu2(x6) 84 | x8 = self.pool2(x7) 85 | 86 | if 2 in conv_dropout_layers: 87 | x8 = conv_dropout(x8) 88 | 89 | x9 = self.conv3(x8) 90 | x10 = self.bn51(x9) 91 | x11 = self.relu3(x10) 92 | x12 = self.conv4(x11) 93 | x13 = self.bn52(x12) 94 | x14 = self.relu4(x13) 95 | x15 = self.conv5(x14) 96 | x16 = self.bn53(x15) 97 | x17 = self.relu5(x16) 98 | x18 = self.pool5(x17) 99 | 100 | if 3 in conv_dropout_layers: 101 | x18 = conv_dropout(x18) 102 | 103 | x19 = self.fc6(x18) 104 | x20 = self.bn54(x19) 105 | x21 = self.relu6(x20) 106 | 107 | if 4 in conv_dropout_layers: 108 | x21 = conv_dropout(x21) 109 | 110 | x22 = self.fc7(x21) 111 | x23 = self.bn55(x22) 112 | x24_preflatten = self.relu7(x23) 113 | 114 | if 5 in conv_dropout_layers: 115 | x24_preflatten = conv_dropout(x24_preflatten) 116 | 117 | x24 = x24_preflatten.view(x24_preflatten.size(0), -1) 118 | x25 = self.fc8(x24) 119 | return x25 120 | 121 | 122 | def vgg_m_face_bn_dag(weights_path=None, **kwargs): 123 | """ 124 | load imported model instance 125 | 126 | Args: 127 | weights_path (str): If set, loads model weights from the given path 128 | """ 129 | model = Vgg_m_face_bn_dag(**kwargs) 130 | if weights_path: 131 | state_dict = torch.load(weights_path) 132 | model.load_state_dict(state_dict) 133 | return model 134 | -------------------------------------------------------------------------------- /train_classification_models/self-train_CASIA/train_efficientnet_b0_casia.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torchvision 4 | import torch.nn.functional as F 5 | from torchvision import transforms 6 | from torch.utils.data.dataloader import DataLoader 7 | from torch.utils.data.dataset import Subset 8 | import numpy as np 9 | from PIL import Image 10 | from sklearn.model_selection import train_test_split 11 | import torchvision.models as models 12 | from typing import Any, Callable, List, Optional, Tuple 13 | from torch import nn, Tensor 14 | from tqdm import tqdm 15 | 16 | import logging 17 | import os 18 | 19 | def setup_logger(log_path): 20 | if not os.path.exists(log_path): 21 | os.makedirs(log_path) 22 | log_file = os.path.join(log_path, 'training.log') 23 | 24 | logger = logging.getLogger('TrainingLogger') 25 | logger.setLevel(logging.DEBUG) 26 | file_handler = logging.FileHandler(log_file) 27 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 28 | file_handler.setFormatter(formatter) 29 | logger.addHandler(file_handler) 30 | 31 | return logger 32 | 33 | 34 | def train_one_epoch(model, train_loader, optimizer, device, logger): 35 | model.train() 36 | running_loss = 0.0 37 | for i, (inputs, labels) in enumerate(tqdm(train_loader, desc="Training", leave=False)): 38 | inputs, labels = inputs.to(device), 
labels.to(device) 39 | optimizer.zero_grad() 40 | outputs = model(inputs) 41 | loss = F.cross_entropy(outputs, labels) 42 | loss.backward() 43 | optimizer.step() 44 | running_loss += loss.item() 45 | # Log every 100 batches 46 | if (i + 1) % 20 == 0: 47 | logger.info(f"Batch {i+1}/{len(train_loader)}, Loss: {loss.item():.4f}") 48 | return running_loss / len(train_loader) 49 | 50 | def test(model, test_loader, device, logger): 51 | model.eval() 52 | correct = 0 53 | total = 0 54 | with torch.no_grad(): 55 | for inputs, labels in tqdm(test_loader, desc="Testing", leave=False): 56 | inputs, labels = inputs.to(device), labels.to(device) 57 | outputs = model(inputs) 58 | _, predicted = outputs.max(1) 59 | total += labels.size(0) 60 | correct += predicted.eq(labels.view_as(predicted)).sum().item() 61 | 62 | accuracy = 100 * correct / total 63 | logger.info(f"Finished Testing, Final Accuracy: {accuracy:.5f}%") 64 | return accuracy 65 | 66 | 67 | class BasicConv2d(nn.Module): 68 | def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: 69 | super().__init__() 70 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 71 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 72 | 73 | def forward(self, x: Tensor) -> Tensor: 74 | x = self.conv(x) 75 | x = self.bn(x) 76 | return F.relu(x, inplace=True) 77 | 78 | class Normalize(torch.nn.Module): 79 | 80 | def __init__(self, mean, std): 81 | super().__init__() 82 | self.mean = mean 83 | self.std = std 84 | 85 | def forward(self, image_tensor): 86 | image_tensor = (image_tensor-torch.tensor(self.mean, device=image_tensor.device)[:, None, None])/torch.tensor(self.std, device=image_tensor.device)[:, None, None] 87 | return image_tensor 88 | 89 | 90 | 91 | def main(args): 92 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 93 | args.device = device 94 | torch.manual_seed(666) 95 | 96 | log_path = args.LOG 97 | logger = setup_logger(log_path) 98 | 99 | args.resolution = 224 100 | Mean = [127.5, 127.5, 127.5] 101 | Std = [1., 1., 1.] 
102 | 103 | T_resize = 360 104 | RESIZE = 256 105 | transform = transforms.Compose([ 106 | transforms.PILToTensor(), 107 | transforms.Resize(RESIZE), 108 | Normalize(Mean, Std) 109 | ]) 110 | 111 | totalset = torchvision.datasets.ImageFolder("/root/autodl-tmp/CASIA-WebFace", transform=transform) 112 | 113 | trainset_list, testset_list = train_test_split(list(range(len(totalset.samples))), test_size=0.05, random_state=666) 114 | trainsete= Subset(totalset, trainset_list) 115 | testset= Subset(totalset, testset_list) 116 | train_loader = DataLoader(trainsete, batch_size=128, shuffle=True, num_workers=8, pin_memory=True) 117 | test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=8, pin_memory=True) 118 | 119 | efficientnet_b0 = models.efficientnet.efficientnet_b0(pretrained=True) 120 | 121 | efficientnet_b0.classifier = nn.Sequential( 122 | nn.Dropout(p=0.2, inplace=True), 123 | nn.Linear(1280, 10575), 124 | ) 125 | 126 | efficientnet_b0 = efficientnet_b0.to(device) 127 | 128 | optimizer = torch.optim.Adam(efficientnet_b0.parameters(), lr=0.001) 129 | 130 | best_acc = 0.0 131 | for epoch in range(args.epochs): 132 | train_loss = train_one_epoch(efficientnet_b0, train_loader, optimizer, device, logger) 133 | test_acc = test(efficientnet_b0, test_loader, device, logger) 134 | print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}") 135 | if test_acc > best_acc: 136 | best_acc = test_acc 137 | torch.save(efficientnet_b0.state_dict(), './models_CASIA/efficientnet_b0_best_model.pth') 138 | print("Saved best model") 139 | 140 | 141 | if __name__ == '__main__': 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument('--batch_size', default=128, type=int, help='batch size') 144 | parser.add_argument('--epochs', default=200, type=int, help='number of epochs to train') 145 | parser.add_argument('--LOG', default='./LOG_CASIA/logs_efficientnet_b0', type=str) 146 | args = parser.parse_args() 147 | 148 | main(args) -------------------------------------------------------------------------------- /train_classification_models/self-train_VGGFace2/train_efficientnet_b0.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torchvision 4 | import torch.nn.functional as F 5 | from torchvision import transforms 6 | from torch.utils.data.dataloader import DataLoader 7 | from torch.utils.data.dataset import Subset 8 | import numpy as np 9 | from PIL import Image 10 | from sklearn.model_selection import train_test_split 11 | import torchvision.models as models 12 | from typing import Any, Callable, List, Optional, Tuple 13 | from torch import nn, Tensor 14 | from tqdm import tqdm 15 | 16 | import logging 17 | import os 18 | 19 | def setup_logger(log_path): 20 | if not os.path.exists(log_path): 21 | os.makedirs(log_path) 22 | log_file = os.path.join(log_path, 'training.log') 23 | 24 | logger = logging.getLogger('TrainingLogger') 25 | logger.setLevel(logging.DEBUG) 26 | file_handler = logging.FileHandler(log_file) 27 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 28 | file_handler.setFormatter(formatter) 29 | logger.addHandler(file_handler) 30 | 31 | return logger 32 | 33 | 34 | def train_one_epoch(model, train_loader, optimizer, device, logger): 35 | model.train() 36 | running_loss = 0.0 37 | for i, (inputs, labels) in enumerate(tqdm(train_loader, desc="Training", leave=False)): 38 | inputs, labels = inputs.to(device), labels.to(device) 39 | 
optimizer.zero_grad() 40 | outputs = model(inputs) 41 | loss = F.cross_entropy(outputs, labels) 42 | loss.backward() 43 | optimizer.step() 44 | running_loss += loss.item() 45 | if (i + 1) % 20 == 0: 46 | logger.info(f"Batch {i+1}/{len(train_loader)}, Loss: {loss.item():.4f}") 47 | return running_loss / len(train_loader) 48 | 49 | def test(model, test_loader, device, logger): 50 | model.eval() 51 | correct = 0 52 | total = 0 53 | with torch.no_grad(): 54 | for inputs, labels in tqdm(test_loader, desc="Testing", leave=False): 55 | inputs, labels = inputs.to(device), labels.to(device) 56 | outputs = model(inputs) 57 | _, predicted = outputs.max(1) 58 | total += labels.size(0) 59 | correct += predicted.eq(labels.view_as(predicted)).sum().item() 60 | 61 | accuracy = 100 * correct / total 62 | logger.info(f"Finished Testing, Final Accuracy: {accuracy:.5f}%") 63 | return accuracy 64 | 65 | 66 | class BasicConv2d(nn.Module): 67 | def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: 68 | super().__init__() 69 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 70 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 71 | 72 | def forward(self, x: Tensor) -> Tensor: 73 | x = self.conv(x) 74 | x = self.bn(x) 75 | return F.relu(x, inplace=True) 76 | 77 | class Normalize(torch.nn.Module): 78 | 79 | def __init__(self, mean, std): 80 | super().__init__() 81 | self.mean = mean 82 | self.std = std 83 | 84 | def forward(self, image_tensor): 85 | image_tensor = (image_tensor-torch.tensor(self.mean, device=image_tensor.device)[:, None, None])/torch.tensor(self.std, device=image_tensor.device)[:, None, None] 86 | return image_tensor 87 | 88 | def main(args): 89 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 90 | args.device = device 91 | torch.manual_seed(666) 92 | 93 | log_path = args.LOG 94 | logger = setup_logger(log_path) 95 | 96 | args.resolution = 224 97 | Mean = [131.0912, 103.8827, 91.4953] 98 | Std = [1., 1., 1.] 
99 | 100 | T_resize = 360 101 | RESIZE = 256 102 | transform = transforms.Compose([ 103 | transforms.PILToTensor(), 104 | transforms.Resize(T_resize), 105 | transforms.CenterCrop(args.resolution), 106 | transforms.Resize(RESIZE), 107 | Normalize(Mean, Std) 108 | ]) 109 | 110 | totalset = torchvision.datasets.ImageFolder("/root/autodl-tmp/vggface2/train", transform=transform) 111 | 112 | trainset_list, testset_list = train_test_split(list(range(len(totalset.samples))), test_size=0.1, random_state=666) 113 | trainsete= Subset(totalset, trainset_list) 114 | testset= Subset(totalset, testset_list) 115 | train_loader = DataLoader(trainsete, batch_size=128, shuffle=True, num_workers=8, pin_memory=True) 116 | test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=8, pin_memory=True) 117 | 118 | efficientnet_b0 = models.efficientnet.efficientnet_b0(pretrained=True) 119 | 120 | efficientnet_b0.classifier = nn.Sequential( 121 | nn.Dropout(p=0.2, inplace=True), 122 | nn.Linear(1280, 8631), 123 | ) 124 | 125 | efficientnet_b0 = efficientnet_b0.to(device) 126 | 127 | optimizer = torch.optim.Adam(efficientnet_b0.parameters(), lr=0.001) 128 | 129 | best_acc = 0.0 130 | for epoch in range(args.epochs): 131 | train_loss = train_one_epoch(efficientnet_b0, train_loader, optimizer, device, logger) 132 | test_acc = test(efficientnet_b0, test_loader, device, logger) 133 | print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}") 134 | if test_acc > best_acc: 135 | best_acc = test_acc 136 | torch.save(efficientnet_b0.state_dict(), './models/efficientnet_b0_best_model.pth') 137 | print("Saved best model") 138 | 139 | 140 | 141 | if __name__ == '__main__': 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument('--batch_size', default=128, type=int, help='batch size') 144 | parser.add_argument('--epochs', default=20, type=int, help='number of epochs to train') 145 | parser.add_argument('--LOG', default='./LOG/logs_efficientnet_b0', type=str) 146 | args = parser.parse_args() 147 | 148 | main(args) -------------------------------------------------------------------------------- /train_classification_models/self-train_VGGFace2/train_mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torchvision 4 | import torch.nn.functional as F 5 | from torchvision import transforms 6 | from torch.utils.data.dataloader import DataLoader 7 | from torch.utils.data.dataset import Subset 8 | import numpy as np 9 | from PIL import Image 10 | from sklearn.model_selection import train_test_split 11 | import torchvision.models as models 12 | from typing import Any, Callable, List, Optional, Tuple 13 | from torch import nn, Tensor 14 | from tqdm import tqdm 15 | 16 | import logging 17 | import os 18 | 19 | def setup_logger(log_path): 20 | if not os.path.exists(log_path): 21 | os.makedirs(log_path) 22 | log_file = os.path.join(log_path, 'training.log') 23 | 24 | logger = logging.getLogger('TrainingLogger') 25 | logger.setLevel(logging.DEBUG) 26 | file_handler = logging.FileHandler(log_file) 27 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 28 | file_handler.setFormatter(formatter) 29 | logger.addHandler(file_handler) 30 | 31 | return logger 32 | 33 | 34 | def train_one_epoch(model, train_loader, optimizer, device, logger): 35 | model.train() 36 | running_loss = 0.0 37 | for i, (inputs, labels) in enumerate(tqdm(train_loader, desc="Training", leave=False)): 
38 | inputs, labels = inputs.to(device), labels.to(device) 39 | optimizer.zero_grad() 40 | outputs = model(inputs) 41 | loss = F.cross_entropy(outputs, labels) 42 | loss.backward() 43 | optimizer.step() 44 | running_loss += loss.item() 45 | if (i + 1) % 20 == 0: 46 | logger.info(f"Batch {i+1}/{len(train_loader)}, Loss: {loss.item():.4f}") 47 | return running_loss / len(train_loader) 48 | 49 | def test(model, test_loader, device, logger): 50 | model.eval() 51 | correct = 0 52 | total = 0 53 | with torch.no_grad(): 54 | for inputs, labels in tqdm(test_loader, desc="Testing", leave=False): 55 | inputs, labels = inputs.to(device), labels.to(device) 56 | outputs = model(inputs) 57 | _, predicted = outputs.max(1) 58 | total += labels.size(0) 59 | correct += predicted.eq(labels.view_as(predicted)).sum().item() 60 | 61 | accuracy = 100 * correct / total 62 | logger.info(f"Finished Testing, Final Accuracy: {accuracy:.5f}%") 63 | return accuracy 64 | 65 | class BasicConv2d(nn.Module): 66 | def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: 67 | super().__init__() 68 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 69 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 70 | 71 | def forward(self, x: Tensor) -> Tensor: 72 | x = self.conv(x) 73 | x = self.bn(x) 74 | return F.relu(x, inplace=True) 75 | 76 | class Normalize(torch.nn.Module): 77 | 78 | def __init__(self, mean, std): 79 | super().__init__() 80 | self.mean = mean 81 | self.std = std 82 | 83 | def forward(self, image_tensor): 84 | image_tensor = (image_tensor-torch.tensor(self.mean, device=image_tensor.device)[:, None, None])/torch.tensor(self.std, device=image_tensor.device)[:, None, None] 85 | return image_tensor 86 | 87 | def main(args): 88 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 89 | args.device = device 90 | torch.manual_seed(666) 91 | 92 | log_path = args.LOG 93 | logger = setup_logger(log_path) 94 | 95 | args.resolution = 224 96 | Mean = [131.0912, 103.8827, 91.4953] 97 | Std = [1., 1., 1.] 
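# Two details of the data pipeline below, as read from the code (not authoritative):
# the 90/10 train/test split is deterministic because random_state=666 fixes the
# shuffle over the ImageFolder sample indices before wrapping them in Subset, and
# the DataLoaders hard-code batch_size=128, so the --batch_size command-line
# argument appears to be unused in this script.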
98 | 99 | T_resize = 360 100 | RESIZE = 224 101 | transform = transforms.Compose([ 102 | transforms.PILToTensor(), 103 | transforms.Resize(T_resize), 104 | transforms.CenterCrop(args.resolution), 105 | transforms.Resize(RESIZE), 106 | Normalize(Mean, Std) 107 | ]) 108 | 109 | totalset = torchvision.datasets.ImageFolder("/root/autodl-tmp/vggface2/train", transform=transform) 110 | 111 | trainset_list, testset_list = train_test_split(list(range(len(totalset.samples))), test_size=0.1, random_state=666) 112 | trainsete= Subset(totalset, trainset_list) 113 | testset= Subset(totalset, testset_list) 114 | train_loader = DataLoader(trainsete, batch_size=128, shuffle=True, num_workers=8, pin_memory=True) 115 | test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=8, pin_memory=True) 116 | 117 | mobilenet_v2 = models.mobilenet.mobilenet_v2(pretrained=True) 118 | 119 | mobilenet_v2.classifier = nn.Sequential( 120 | nn.Dropout(p=0.2), 121 | nn.Linear(mobilenet_v2.last_channel, 8631), 122 | ) 123 | 124 | print(mobilenet_v2) 125 | 126 | mobilenet_v2 = mobilenet_v2.to(device) 127 | 128 | optimizer = torch.optim.Adam(mobilenet_v2.parameters(), lr=0.001) 129 | 130 | best_acc = 0.0 131 | for epoch in range(args.epochs): 132 | train_loss = train_one_epoch(mobilenet_v2, train_loader, optimizer, device, logger) 133 | test_acc = test(mobilenet_v2, test_loader, device, logger) 134 | print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}") 135 | if test_acc > best_acc: 136 | best_acc = test_acc 137 | torch.save(mobilenet_v2.state_dict(), './models/mobilenet_v2_best_model.pth') 138 | print("Saved best model") 139 | 140 | 141 | 142 | if __name__ == '__main__': 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument('--batch_size', default=128, type=int, help='batch size') 145 | parser.add_argument('--epochs', default=20, type=int, help='number of epochs to train') 146 | parser.add_argument('--LOG', default='./LOG/logs_mobilenet_v2', type=str) 147 | args = parser.parse_args() 148 | 149 | main(args) -------------------------------------------------------------------------------- /train_classification_models/self-train_VGGFace2/train_vision_transformer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torchvision 4 | import torch.nn.functional as F 5 | from torchvision import transforms 6 | from torch.utils.data.dataloader import DataLoader 7 | from torch.utils.data.dataset import Subset 8 | import numpy as np 9 | from PIL import Image 10 | from sklearn.model_selection import train_test_split 11 | import torchvision.models as models 12 | from typing import Any, Callable, List, Optional, Tuple 13 | from torch import nn, Tensor 14 | from tqdm import tqdm 15 | 16 | import logging 17 | import os 18 | from collections import OrderedDict 19 | 20 | def setup_logger(log_path): 21 | if not os.path.exists(log_path): 22 | os.makedirs(log_path) 23 | log_file = os.path.join(log_path, 'training.log') 24 | 25 | logger = logging.getLogger('TrainingLogger') 26 | logger.setLevel(logging.DEBUG) 27 | file_handler = logging.FileHandler(log_file) 28 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 29 | file_handler.setFormatter(formatter) 30 | logger.addHandler(file_handler) 31 | 32 | return logger 33 | 34 | 35 | def train_one_epoch(model, train_loader, optimizer, device, logger): 36 | model.train() 37 | running_loss = 0.0 38 | for i, (inputs, labels) in 
enumerate(tqdm(train_loader, desc="Training", leave=False)): 39 | inputs, labels = inputs.to(device), labels.to(device) 40 | optimizer.zero_grad() 41 | outputs = model(inputs) 42 | loss = F.cross_entropy(outputs, labels) 43 | loss.backward() 44 | optimizer.step() 45 | running_loss += loss.item() 46 | if (i + 1) % 20 == 0: 47 | logger.info(f"Batch {i+1}/{len(train_loader)}, Loss: {loss.item():.4f}") 48 | return running_loss / len(train_loader) 49 | 50 | def test(model, test_loader, device, logger): 51 | model.eval() 52 | correct = 0 53 | total = 0 54 | with torch.no_grad(): 55 | for inputs, labels in tqdm(test_loader, desc="Testing", leave=False): 56 | inputs, labels = inputs.to(device), labels.to(device) 57 | outputs = model(inputs) 58 | _, predicted = outputs.max(1) 59 | total += labels.size(0) 60 | correct += predicted.eq(labels.view_as(predicted)).sum().item() 61 | 62 | accuracy = 100 * correct / total 63 | logger.info(f"Finished Testing, Final Accuracy: {accuracy:.5f}%") 64 | return accuracy 65 | 66 | class BasicConv2d(nn.Module): 67 | def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: 68 | super().__init__() 69 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 70 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 71 | 72 | def forward(self, x: Tensor) -> Tensor: 73 | x = self.conv(x) 74 | x = self.bn(x) 75 | return F.relu(x, inplace=True) 76 | 77 | class Normalize(torch.nn.Module): 78 | 79 | def __init__(self, mean, std): 80 | super().__init__() 81 | self.mean = mean 82 | self.std = std 83 | 84 | def forward(self, image_tensor): 85 | image_tensor = (image_tensor-torch.tensor(self.mean, device=image_tensor.device)[:, None, None])/torch.tensor(self.std, device=image_tensor.device)[:, None, None] 86 | return image_tensor 87 | 88 | def main(args): 89 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 90 | args.device = device 91 | torch.manual_seed(666) 92 | 93 | log_path = args.LOG 94 | logger = setup_logger(log_path) 95 | 96 | args.resolution = 224 97 | Mean = [131.0912, 103.8827, 91.4953] 98 | Std = [1., 1., 1.] 
99 | 100 | T_resize = 360 101 | RESIZE = 224 102 | transform = transforms.Compose([ 103 | transforms.PILToTensor(), 104 | transforms.Resize(T_resize), 105 | transforms.CenterCrop(args.resolution), 106 | transforms.RandomHorizontalFlip(), # Random horizontal flip for data augmentation 107 | transforms.Resize(RESIZE), 108 | Normalize(Mean, Std) 109 | ]) 110 | 111 | totalset = torchvision.datasets.ImageFolder("/root/autodl-tmp/vggface2/train", transform=transform) 112 | 113 | trainset_list, testset_list = train_test_split(list(range(len(totalset.samples))), test_size=0.1, random_state=666) 114 | trainsete= Subset(totalset, trainset_list) 115 | testset= Subset(totalset, testset_list) 116 | train_loader = DataLoader(trainsete, batch_size=128, shuffle=True, num_workers=8, pin_memory=True) 117 | test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=8, pin_memory=True) 118 | 119 | vision_transformer = models.vision_transformer.vit_b_16(pretrained=True) 120 | 121 | heads_layers: OrderedDict[str, nn.Module] = OrderedDict() 122 | heads_layers["head"] = nn.Linear(768, 8631) 123 | vision_transformer.heads = nn.Sequential(heads_layers) 124 | 125 | vision_transformer = vision_transformer.to(device) 126 | 127 | optimizer = torch.optim.Adam(vision_transformer.parameters(), lr=0.001) 128 | 129 | best_acc = 0.0 130 | for epoch in range(args.epochs): 131 | train_loss = train_one_epoch(vision_transformer, train_loader, optimizer, device, logger) 132 | test_acc = test(vision_transformer, test_loader, device, logger) 133 | print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}") 134 | if test_acc > best_acc: 135 | best_acc = test_acc 136 | torch.save(vision_transformer.state_dict(), './models/vision_transformer_2_best_model.pth') 137 | print("Saved best model") 138 | 139 | 140 | 141 | if __name__ == '__main__': 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument('--batch_size', default=20, type=int, help='batch size') 144 | parser.add_argument('--epochs', default=6, type=int, help='number of epochs to train') 145 | parser.add_argument('--LOG', default='./LOG/vision_transformer_2', type=str) 146 | args = parser.parse_args() 147 | 148 | main(args) -------------------------------------------------------------------------------- /genforce/datasets/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the distributed data sampler. 3 | 4 | This file is mostly borrowed from `torch/utils/data/distributed.py`. 5 | 6 | However, initializing the data loader and data sampler can sometimes be time 7 | consuming (since it may load a large amount of data at once). To avoid 8 | re-initializing the data loader again and again, we modified the sampler to 9 | support loading the data only once and then repeating the data loader. 10 | Please use the class member `repeat` to control how many times you want the 11 | data loader to repeat. After `repeat` times, the data will be re-loaded. 12 | 13 | NOTE: The number of repeat times should not be very large, especially when there 14 | are many samples in the dataset. We recommend setting `repeat = 500` for 15 | datasets with ~50K samples.
16 | """ 17 | 18 | # pylint: disable=line-too-long 19 | 20 | import math 21 | from typing import TypeVar, Optional, Iterator 22 | 23 | import torch 24 | from torch.utils.data import Sampler, Dataset 25 | import torch.distributed as dist 26 | 27 | 28 | T_co = TypeVar('T_co', covariant=True) 29 | 30 | 31 | class DistributedSampler(Sampler): 32 | r"""Sampler that restricts data loading to a subset of the dataset. 33 | 34 | It is especially useful in conjunction with 35 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each 36 | process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a 37 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the 38 | original dataset that is exclusive to it. 39 | 40 | .. note:: 41 | Dataset is assumed to be of constant size. 42 | 43 | Arguments: 44 | dataset: Dataset used for sampling. 45 | num_replicas (int, optional): Number of processes participating in 46 | distributed training. By default, :attr:`rank` is retrieved from the 47 | current distributed group. 48 | rank (int, optional): Rank of the current process within :attr:`num_replicas`. 49 | By default, :attr:`rank` is retrieved from the current distributed 50 | group. 51 | shuffle (bool, optional): If ``True`` (default), sampler will shuffle the 52 | indices. 53 | seed (int, optional): random seed used to shuffle the sampler if 54 | :attr:`shuffle=True`. This number should be identical across all 55 | processes in the distributed group. Default: ``0``. 56 | drop_last (bool, optional): if ``True``, then the sampler will drop the 57 | tail of the data to make it evenly divisible across the number of 58 | replicas. If ``False``, the sampler will add extra indices to make 59 | the data evenly divisible across the replicas. Default: ``False``. 60 | current_iter (int, optional): Number of current iteration. Default: ``0``. 61 | repeat (int, optional): Repeating number of the whole dataloader. Default: ``1000``. 62 | 63 | .. warning:: 64 | In distributed mode, calling the :meth:`set_epoch` method at 65 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator 66 | is necessary to make shuffling work properly across multiple epochs. Otherwise, 67 | the same ordering will be always used. 68 | 69 | """ 70 | 71 | def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, 72 | rank: Optional[int] = None, shuffle: bool = True, 73 | seed: int = 0, drop_last: bool = False, current_iter: int = 0, 74 | repeat: int = 1000) -> None: 75 | super().__init__(None) 76 | if num_replicas is None: 77 | if not dist.is_available(): 78 | raise RuntimeError("Requires distributed package to be available") 79 | num_replicas = dist.get_world_size() 80 | if rank is None: 81 | if not dist.is_available(): 82 | raise RuntimeError("Requires distributed package to be available") 83 | rank = dist.get_rank() 84 | self.dataset = dataset 85 | self.num_replicas = num_replicas 86 | self.rank = rank 87 | self.iter = current_iter 88 | self.drop_last = drop_last 89 | 90 | # NOTE: self.dataset_length is `repeat X len(self.dataset)` 91 | self.repeat = repeat 92 | self.dataset_length = len(self.dataset) * self.repeat 93 | 94 | if self.drop_last and self.dataset_length % self.num_replicas != 0: 95 | # Split to nearest available length that is evenly divisible. 96 | # This is to ensure each rank receives the same amount of data when 97 | # using this Sampler. 
98 | self.num_samples = math.ceil( 99 | (self.dataset_length - self.num_replicas) / self.num_replicas 100 | ) 101 | else: 102 | self.num_samples = math.ceil(self.dataset_length / self.num_replicas) 103 | 104 | 105 | self.total_size = self.num_samples * self.num_replicas 106 | self.shuffle = shuffle 107 | self.seed = seed 108 | self.__generate_indices__() 109 | 110 | def __generate_indices__(self) -> None: 111 | g = torch.Generator() 112 | indices_bank = [] 113 | for iter_ in range(self.iter, self.iter + self.repeat): 114 | g.manual_seed(self.seed + iter_) 115 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 116 | indices_bank.extend(indices) 117 | self.indices = indices_bank 118 | 119 | def __iter__(self) -> Iterator[T_co]: 120 | if self.shuffle: 121 | # deterministically shuffle based on iter and seed 122 | indices = self.indices 123 | else: 124 | indices = list(range(self.dataset_length)) 125 | 126 | if not self.drop_last: 127 | # add extra samples to make it evenly divisible 128 | indices += indices[:(self.total_size - len(indices))] 129 | else: 130 | # remove tail of data to make it evenly divisible. 131 | indices = indices[:self.total_size] 132 | 133 | # subsample 134 | indices = indices[self.rank:self.total_size:self.num_replicas] 135 | return iter(indices) 136 | 137 | def __len__(self) -> int: 138 | return self.num_samples 139 | 140 | def __reset__(self, iteration: int) -> None: 141 | self.iter = iteration 142 | self.__generate_indices__() 143 | 144 | # pylint: enable=line-too-long 145 | -------------------------------------------------------------------------------- /genforce/synthesize.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """A simple tool to synthesize images with pre-trained models.""" 3 | 4 | import os 5 | import argparse 6 | import subprocess 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | import torch 11 | 12 | from models import MODEL_ZOO 13 | from models import build_generator 14 | from utils.misc import bool_parser 15 | from utils.visualizer import HtmlPageVisualizer 16 | from utils.visualizer import save_image 17 | 18 | 19 | def postprocess(images): 20 | """Post-processes images from `torch.Tensor` to `numpy.ndarray`.""" 21 | images = images.detach().cpu().numpy() 22 | images = (images + 1) * 255 / 2 23 | images = np.clip(images + 0.5, 0, 255).astype(np.uint8) 24 | images = images.transpose(0, 2, 3, 1) 25 | return images 26 | 27 | 28 | def parse_args(): 29 | """Parses arguments.""" 30 | parser = argparse.ArgumentParser( 31 | description='Synthesize images with pre-trained models.') 32 | parser.add_argument('model_name', type=str, 33 | help='Name to the pre-trained model.') 34 | parser.add_argument('--save_dir', type=str, default=None, 35 | help='Directory to save the results. If not specified, ' 36 | 'the results will be saved to ' 37 | '`work_dirs/synthesis/` by default. ' 38 | '(default: %(default)s)') 39 | parser.add_argument('--num', type=int, default=100, 40 | help='Number of samples to synthesize. ' 41 | '(default: %(default)s)') 42 | parser.add_argument('--batch_size', type=int, default=1, 43 | help='Batch size. (default: %(default)s)') 44 | parser.add_argument('--generate_html', type=bool_parser, default=True, 45 | help='Whether to use HTML page to visualize the ' 46 | 'synthesized results. (default: %(default)s)') 47 | parser.add_argument('--save_raw_synthesis', type=bool_parser, default=False, 48 | help='Whether to save raw synthesis. 
' 49 | '(default: %(default)s)') 50 | parser.add_argument('--seed', type=int, default=0, 51 | help='Seed for sampling. (default: %(default)s)') 52 | parser.add_argument('--trunc_psi', type=float, default=0.7, 53 | help='Psi factor used for truncation. This is ' 54 | 'particularly applicable to StyleGAN (v1/v2). ' 55 | '(default: %(default)s)') 56 | parser.add_argument('--trunc_layers', type=int, default=8, 57 | help='Number of layers to perform truncation. This is ' 58 | 'particularly applicable to StyleGAN (v1/v2). ' 59 | '(default: %(default)s)') 60 | parser.add_argument('--randomize_noise', type=bool_parser, default=False, 61 | help='Whether to randomize the layer-wise noise. This ' 62 | 'is particularly applicable to StyleGAN (v1/v2). ' 63 | '(default: %(default)s)') 64 | return parser.parse_args() 65 | 66 | 67 | def main(): 68 | """Main function.""" 69 | args = parse_args() 70 | if args.num <= 0: 71 | return 72 | if not args.save_raw_synthesis and not args.generate_html: 73 | return 74 | 75 | # Parse model configuration. 76 | if args.model_name not in MODEL_ZOO: 77 | raise SystemExit(f'Model `{args.model_name}` is not registered in ' 78 | f'`models/model_zoo.py`!') 79 | model_config = MODEL_ZOO[args.model_name].copy() 80 | url = model_config.pop('url') # URL to download model if needed. 81 | 82 | # Get work directory and job name. 83 | if args.save_dir: 84 | work_dir = args.save_dir 85 | else: 86 | work_dir = os.path.join('work_dirs', 'synthesis') 87 | os.makedirs(work_dir, exist_ok=True) 88 | job_name = f'{args.model_name}_{args.num}' 89 | if args.save_raw_synthesis: 90 | os.makedirs(os.path.join(work_dir, job_name), exist_ok=True) 91 | 92 | # Build generation and get synthesis kwargs. 93 | print(f'Building generator for model `{args.model_name}` ...') 94 | generator = build_generator(**model_config) 95 | synthesis_kwargs = dict(trunc_psi=args.trunc_psi, 96 | trunc_layers=args.trunc_layers, 97 | randomize_noise=args.randomize_noise) 98 | print(f'Finish building generator.') 99 | 100 | # Load pre-trained weights. 101 | os.makedirs('checkpoints', exist_ok=True) 102 | checkpoint_path = os.path.join('checkpoints', args.model_name + '.pth') 103 | print(f'Loading checkpoint from `{checkpoint_path}` ...') 104 | if not os.path.exists(checkpoint_path): 105 | print(f' Downloading checkpoint from `{url}` ...') 106 | subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url]) 107 | print(f' Finish downloading checkpoint.') 108 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 109 | if 'generator_smooth' in checkpoint: 110 | generator.load_state_dict(checkpoint['generator_smooth']) 111 | else: 112 | generator.load_state_dict(checkpoint['generator']) 113 | generator = generator.cuda() 114 | generator.eval() 115 | print(f'Finish loading checkpoint.') 116 | 117 | # Set random seed. 118 | np.random.seed(args.seed) 119 | torch.manual_seed(args.seed) 120 | 121 | # Sample and synthesize. 
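# The loop below is batched, as read from the code: each step draws
# len(sub_indices) latent codes of shape [batch_size, z_space_dim], runs the
# generator without gradients, and `postprocess` maps the output from float CHW
# in [-1, 1] to uint8 HWC in [0, 255] before saving or adding to the HTML page.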
122 | print(f'Synthesizing {args.num} samples ...') 123 | indices = list(range(args.num)) 124 | if args.generate_html: 125 | html = HtmlPageVisualizer(grid_size=args.num) 126 | for batch_idx in tqdm(range(0, args.num, args.batch_size)): 127 | sub_indices = indices[batch_idx:batch_idx + args.batch_size] 128 | code = torch.randn(len(sub_indices), generator.z_space_dim).cuda() 129 | with torch.no_grad(): 130 | images = generator(code, **synthesis_kwargs)['image'] 131 | images = postprocess(images) 132 | for sub_idx, image in zip(sub_indices, images): 133 | if args.save_raw_synthesis: 134 | save_path = os.path.join( 135 | work_dir, job_name, f'{sub_idx:06d}.jpg') 136 | save_image(save_path, image) 137 | if args.generate_html: 138 | row_idx, col_idx = divmod(sub_idx, html.num_cols) 139 | html.set_cell(row_idx, col_idx, image=image, 140 | text=f'Sample {sub_idx:06d}') 141 | if args.generate_html: 142 | html.save(os.path.join(work_dir, f'{job_name}.html')) 143 | print(f'Finish synthesizing {args.num} samples.') 144 | 145 | 146 | if __name__ == '__main__': 147 | main() 148 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: SMILE 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1 9 | - _openmp_mutex=5.1 10 | - anaconda-anon-usage=0.4.4 11 | - archspec=0.2.3 12 | - boltons=23.0.0 13 | - brotli-python=1.0.9 14 | - bzip2=1.0.8 15 | - c-ares=1.19.1 16 | - ca-certificates=2024.3.11 17 | - certifi=2024.2.2 18 | - cffi=1.16.0 19 | - charset-normalizer=2.0.4 20 | - conda=24.4.0 21 | - conda-content-trust=0.2.0 22 | - conda-libmamba-solver=24.1.0 23 | - conda-package-handling=2.2.0 24 | - conda-package-streaming=0.9.0 25 | - cryptography=42.0.5 26 | - distro=1.9.0 27 | - expat=2.6.2 28 | - fmt=9.1.0 29 | - icu=73.1 30 | - idna=3.7 31 | - jsonpatch=1.33 32 | - jsonpointer=2.1 33 | - krb5=1.20.1 34 | - ld_impl_linux-64=2.38 35 | - libarchive=3.6.2 36 | - libcurl=8.7.1 37 | - libedit=3.1.20230828 38 | - libev=4.33 39 | - libffi=3.4.4 40 | - libgcc-ng=11.2.0 41 | - libgomp=11.2.0 42 | - libmamba=1.5.8 43 | - libmambapy=1.5.8 44 | - libnghttp2=1.57.0 45 | - libsolv=0.7.24 46 | - libssh2=1.11.0 47 | - libstdcxx-ng=11.2.0 48 | - libuuid=1.41.5 49 | - libxml2=2.10.4 50 | - lz4-c=1.9.4 51 | - menuinst=2.0.2 52 | - ncurses=6.4 53 | - openssl=3.0.13 54 | - packaging=23.2 55 | - pcre2=10.42 56 | - pip=24.0 57 | - platformdirs=3.10.0 58 | - pluggy=1.0.0 59 | - pybind11-abi=5 60 | - pycosat=0.6.6 61 | - pycparser=2.21 62 | - pysocks=1.7.1 63 | - python=3.12.3 64 | - readline=8.2 65 | - reproc=14.2.4 66 | - reproc-cpp=14.2.4 67 | - requests=2.31.0 68 | - ruamel.yaml=0.17.21 69 | - setuptools=69.5.1 70 | - sqlite=3.45.3 71 | - tk=8.6.14 72 | - tqdm=4.66.2 73 | - truststore=0.8.0 74 | - tzdata=2024a 75 | - urllib3=2.1.0 76 | - wheel=0.43.0 77 | - xz=5.4.6 78 | - yaml-cpp=0.8.0 79 | - zlib=1.2.13 80 | - zstandard=0.22.0 81 | - zstd=1.5.5 82 | - pip: 83 | - absl-py==2.1.0 84 | - anyio==4.4.0 85 | - argon2-cffi==23.1.0 86 | - argon2-cffi-bindings==21.2.0 87 | - arrow==1.3.0 88 | - asttokens==2.4.1 89 | - async-lru==2.0.4 90 | - attrs==23.2.0 91 | - babel==2.15.0 92 | - bayesian-optimization==1.5.1 93 | - beautifulsoup4==4.12.3 94 | - bleach==6.1.0 95 | - cma==4.0.0 96 | - 
colorama==0.4.6 97 | - comm==0.2.2 98 | - contourpy==1.2.1 99 | - cycler==0.12.1 100 | - debugpy==1.8.1 101 | - decorator==5.1.1 102 | - defusedxml==0.7.1 103 | - easydict==1.13 104 | - einops==0.8.0 105 | - executing==2.0.1 106 | - facenet-pytorch==2.5.3 107 | - fastjsonschema==2.19.1 108 | - filelock==3.14.0 109 | - fonttools==4.53.0 110 | - fqdn==1.5.1 111 | - fsspec==2024.5.0 112 | - grpcio==1.64.0 113 | - h11==0.14.0 114 | - h5py==3.10.0 115 | - httpcore==1.0.5 116 | - httpx==0.27.0 117 | - huggingface-hub==0.25.1 118 | - imageio==2.25.1 119 | - importlib-metadata==4.11.2 120 | - ipykernel==6.29.4 121 | - ipython==8.25.0 122 | - ipywidgets==8.1.3 123 | - isoduration==20.11.0 124 | - jedi==0.19.1 125 | - jinja2==3.1.4 126 | - joblib==1.4.2 127 | - json5==0.9.25 128 | - jsonschema==4.22.0 129 | - jsonschema-specifications==2023.12.1 130 | - jupyter-client==8.6.2 131 | - jupyter-core==5.7.2 132 | - jupyter-events==0.10.0 133 | - jupyter-lsp==2.2.5 134 | - jupyter-server==2.14.1 135 | - jupyter-server-terminals==0.5.3 136 | - jupyterlab==4.2.1 137 | - jupyterlab-language-pack-zh-cn==4.2.post1 138 | - jupyterlab-pygments==0.3.0 139 | - jupyterlab-server==2.27.2 140 | - jupyterlab-widgets==3.0.11 141 | - keras==2.13.1 142 | - kiwisolver==1.4.5 143 | - llvmlite==0.43.0 144 | - markdown==3.6 145 | - markupsafe==2.1.5 146 | - matplotlib==3.5.1 147 | - matplotlib-inline==0.1.7 148 | - mistune==3.0.2 149 | - mpmath==1.3.0 150 | - mtcnn==0.1.1 151 | - mxnet==1.9.1 152 | - nbclient==0.10.0 153 | - nbconvert==7.16.4 154 | - nbformat==5.10.4 155 | - nest-asyncio==1.6.0 156 | - networkx==3.3 157 | - nevergrad==0.5.0 158 | - notebook-shim==0.2.4 159 | - numba==0.60.0 160 | - numpy==1.26.4 161 | - nvidia-cublas-cu12==12.1.3.1 162 | - nvidia-cuda-cupti-cu12==12.1.105 163 | - nvidia-cuda-nvrtc-cu12==12.1.105 164 | - nvidia-cuda-runtime-cu12==12.1.105 165 | - nvidia-cudnn-cu12==8.9.2.26 166 | - nvidia-cufft-cu12==11.0.2.54 167 | - nvidia-curand-cu12==10.3.2.106 168 | - nvidia-cusolver-cu12==11.4.5.107 169 | - nvidia-cusparse-cu12==12.1.0.106 170 | - nvidia-nccl-cu12==2.20.5 171 | - nvidia-nvjitlink-cu12==12.5.40 172 | - nvidia-nvtx-cu12==12.1.105 173 | - onnx==1.17.0 174 | - onnx2torch==1.5.13 175 | - opencv-python==4.5.5.62 176 | - overrides==7.7.0 177 | - pandocfilters==1.5.1 178 | - parso==0.8.4 179 | - pexpect==4.9.0 180 | - pillow==10.3.0 181 | - prometheus-client==0.20.0 182 | - prompt-toolkit==3.0.45 183 | - protobuf==5.27.0 184 | - psutil==5.9.8 185 | - ptyprocess==0.7.0 186 | - pure-eval==0.2.2 187 | - pygments==2.18.0 188 | - pynndescent==0.5.13 189 | - pyparsing==3.1.2 190 | - python-dateutil==2.9.0.post0 191 | - python-graphviz==0.8.4 192 | - python-json-logger==2.0.7 193 | - pyyaml==6.0.1 194 | - pyzmq==26.0.3 195 | - referencing==0.35.1 196 | - rfc3339-validator==0.1.4 197 | - rfc3986-validator==0.1.1 198 | - rpds-py==0.18.1 199 | - scikit-learn==1.5.2 200 | - scipy==1.14.1 201 | - send2trash==1.8.3 202 | - six==1.16.0 203 | - sniffio==1.3.1 204 | - soupsieve==2.5 205 | - stack-data==0.6.3 206 | - supervisor==4.2.5 207 | - sympy==1.12.1 208 | - tensorboard==2.16.2 209 | - tensorboard-data-server==0.7.2 210 | - tensorboard-logger==0.1.0 211 | - terminado==0.18.1 212 | - threadpoolctl==3.5.0 213 | - tinycss2==1.3.0 214 | - torch==2.3.0 215 | - torchvision==0.18.0+cu121 216 | - tornado==6.4 217 | - traitlets==5.14.3 218 | - types-python-dateutil==2.9.0.20240316 219 | - typing-extensions==4.12.1 220 | - umap==0.1.1 221 | - umap-learn==0.5.7 222 | - uri-template==1.3.0 223 | - vit-pytorch==1.7.12 
224 | - wcwidth==0.2.13 225 | - webcolors==1.13 226 | - webencodings==0.5.1 227 | - websocket-client==1.8.0 228 | - werkzeug==3.0.3 229 | - widgetsnbextension==4.0.11 230 | - zipp==3.20.2 231 | prefix: /root/miniconda3 232 | -------------------------------------------------------------------------------- /swintransformer_4finetune.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | from typing import Any, Callable, List, Optional 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn, Tensor 8 | 9 | from torchvision.models.swin_transformer import _ovewrite_named_param, handle_legacy_interface 10 | from torchvision.models.swin_transformer import _patch_merging_pad, SwinTransformerBlock 11 | 12 | from torchvision.utils import _log_api_usage_once 13 | from torchvision.ops.misc import MLP, Permute 14 | 15 | class PatchMerging(nn.Module): 16 | """Patch Merging Layer. 17 | Args: 18 | dim (int): Number of input channels. 19 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 20 | """ 21 | 22 | def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): 23 | super().__init__() 24 | _log_api_usage_once(self) 25 | self.dim = dim 26 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 27 | self.norm = norm_layer(4 * dim) 28 | 29 | def forward(self, x: Tensor): 30 | """ 31 | Args: 32 | x (Tensor): input tensor with expected layout of [..., H, W, C] 33 | Returns: 34 | Tensor with layout of [..., H/2, W/2, 2*C] 35 | """ 36 | x = _patch_merging_pad(x) 37 | x = self.norm(x) 38 | x = self.reduction(x) # ... H/2 W/2 2*C 39 | return x 40 | 41 | 42 | class SwinTransformer_E(nn.Module): 43 | """ 44 | Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using 45 | Shifted Windows" `_ paper. 46 | Args: 47 | patch_size (List[int]): Patch size. 48 | embed_dim (int): Patch embedding dimension. 49 | depths (List(int)): Depth of each Swin Transformer layer. 50 | num_heads (List(int)): Number of attention heads in different layers. 51 | window_size (List[int]): Window size. 52 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. 53 | dropout (float): Dropout rate. Default: 0.0. 54 | attention_dropout (float): Attention dropout rate. Default: 0.0. 55 | stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1. 56 | num_classes (int): Number of classes for classification head. Default: 1000. 57 | block (nn.Module, optional): SwinTransformer Block. Default: None. 58 | norm_layer (nn.Module, optional): Normalization layer. Default: None. 59 | downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging. 
60 | """ 61 | 62 | def __init__( 63 | self, 64 | patch_size: List[int], 65 | embed_dim: int, 66 | depths: List[int], 67 | num_heads: List[int], 68 | window_size: List[int], 69 | mlp_ratio: float = 4.0, 70 | dropout: float = 0.0, 71 | attention_dropout: float = 0.0, 72 | stochastic_depth_prob: float = 0.1, 73 | num_classes: int = 1000, 74 | norm_layer: Optional[Callable[..., nn.Module]] = None, 75 | block: Optional[Callable[..., nn.Module]] = None, 76 | downsample_layer: Callable[..., nn.Module] = PatchMerging, 77 | num_experts: int = 8, 78 | ): 79 | super().__init__() 80 | _log_api_usage_once(self) 81 | self.num_classes = num_classes 82 | self.num_experts = num_experts 83 | 84 | if block is None: 85 | block = SwinTransformerBlock 86 | if norm_layer is None: 87 | norm_layer = partial(nn.LayerNorm, eps=1e-5) 88 | 89 | layers: List[nn.Module] = [] 90 | # split image into non-overlapping patches 91 | layers.append( 92 | nn.Sequential( 93 | nn.Conv2d( 94 | 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]) 95 | ), 96 | Permute([0, 2, 3, 1]), 97 | norm_layer(embed_dim), 98 | ) 99 | ) 100 | 101 | total_stage_blocks = sum(depths) 102 | stage_block_id = 0 103 | # build SwinTransformer blocks 104 | for i_stage in range(len(depths)): 105 | stage: List[nn.Module] = [] 106 | dim = embed_dim * 2**i_stage 107 | for i_layer in range(depths[i_stage]): 108 | # adjust stochastic depth probability based on the depth of the stage block 109 | sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) 110 | stage.append( 111 | block( 112 | dim, 113 | num_heads[i_stage], 114 | window_size=window_size, 115 | shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size], 116 | mlp_ratio=mlp_ratio, 117 | dropout=dropout, 118 | attention_dropout=attention_dropout, 119 | stochastic_depth_prob=sd_prob, 120 | norm_layer=norm_layer, 121 | ) 122 | ) 123 | stage_block_id += 1 124 | layers.append(nn.Sequential(*stage)) 125 | # add patch merging layer 126 | if i_stage < (len(depths) - 1): 127 | layers.append(downsample_layer(dim, norm_layer)) 128 | self.features = nn.Sequential(*layers) 129 | 130 | num_features = embed_dim * 2 ** (len(depths) - 1) 131 | self.norm = norm_layer(num_features) 132 | self.permute = Permute([0, 3, 1, 2]) # B H W C -> B C H W 133 | self.avgpool = nn.AdaptiveAvgPool2d(1) 134 | self.flatten = nn.Flatten(1) 135 | self.head = nn.Linear(num_features, num_classes) 136 | 137 | for m in self.modules(): 138 | if isinstance(m, nn.Linear): 139 | nn.init.trunc_normal_(m.weight, std=0.02) 140 | if m.bias is not None: 141 | nn.init.zeros_(m.bias) 142 | 143 | def forward(self, x): 144 | for i, layer in enumerate(self.features): 145 | if i == len(self.features) - 1: 146 | break 147 | x = layer(x) 148 | for ii, layer_ in enumerate(self.features[-1]): 149 | if ii == len(self.features[-1]) - 1: 150 | break 151 | x = layer_(x) 152 | outs = [] 153 | for ind in range(self.num_experts): 154 | tmp = self.features[-1][-1].mlp[ind](x) 155 | tmp = self.norm(tmp) 156 | tmp = self.permute(tmp) 157 | tmp = self.avgpool(tmp) 158 | tmp = self.flatten(tmp) 159 | tmp = self.head[ind](tmp) 160 | outs.append(tmp) 161 | x = torch.stack(outs, dim=1).mean(dim=1) 162 | return x, outs 163 | -------------------------------------------------------------------------------- /genforce/my_get_GD.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """This file is modifed from synthesize.py. 
The goal is to return a generator which output an image in range [0., 1.]""" 3 | 4 | import os 5 | import argparse 6 | import subprocess 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | import torch 11 | from torchvision.utils import save_image 12 | 13 | from .models import MODEL_ZOO 14 | from .models import build_generator, build_discriminator 15 | from .utils.misc import bool_parser 16 | from .utils.visualizer import HtmlPageVisualizer 17 | 18 | 19 | def postprocess(images): 20 | """change the range from [-1, 1] to [0., 1.]""" 21 | images = torch.clamp((images + 1.) / 2., 0., 1.) 22 | return images 23 | 24 | 25 | def parse_args(model_name, num, batch_size, trunc_psi=0.7, trunc_layers=8): 26 | """Parses arguments.""" 27 | parser = argparse.ArgumentParser( 28 | description='Synthesize images with pre-trained models.') 29 | parser.add_argument('model_name', type=str, 30 | help='Name to the pre-trained model.') 31 | parser.add_argument('--save_dir', type=str, default=None, 32 | help='Directory to save the results. If not specified, ' 33 | 'the results will be saved to ' 34 | '`work_dirs/synthesis/` by default. ' 35 | '(default: %(default)s)') 36 | parser.add_argument('--num', type=int, default=num, 37 | help='Number of samples to synthesize. ' 38 | '(default: %(default)s)') 39 | parser.add_argument('--batch_size', type=int, default=batch_size, 40 | help='Batch size. (default: %(default)s)') 41 | parser.add_argument('--seed', type=int, default=0, 42 | help='Seed for sampling. (default: %(default)s)') 43 | parser.add_argument('--trunc_psi', type=float, default=trunc_psi, 44 | help='Psi factor used for truncation. This is ' 45 | 'particularly applicable to StyleGAN (v1/v2). ' 46 | '(default: %(default)s)') 47 | parser.add_argument('--trunc_layers', type=int, default=trunc_layers, 48 | help='Number of layers to perform truncation. This is ' 49 | 'particularly applicable to StyleGAN (v1/v2). ' 50 | '(default: %(default)s)') 51 | parser.add_argument('--randomize_noise', type=bool_parser, default=False, 52 | help='Whether to randomize the layer-wise noise. This ' 53 | 'is particularly applicable to StyleGAN (v1/v2). ' 54 | '(default: %(default)s)') 55 | # return parser.parse_args([model_name, f'--num={num}', f'--batch_size={batch_size}', ]) 56 | return parser.parse_args([model_name, ]) 57 | 58 | 59 | def main(device, model_name, num, batch_size, use_w_space=True, use_discri=True, repeat_w=True, use_z_plus_space=False, trunc_psi=0.7, trunc_layers=8): 60 | """Main function.""" 61 | args = parse_args(model_name, num, batch_size, trunc_psi, trunc_layers) 62 | print(args) 63 | if args.num <= 0: 64 | return 65 | 66 | # Parse model configuration. 67 | if args.model_name not in MODEL_ZOO: 68 | raise SystemExit(f'Model `{args.model_name}` is not registered in ' 69 | f'`models/model_zoo.py`!') 70 | model_config = MODEL_ZOO[args.model_name].copy() 71 | url = model_config.pop('url') # URL to download model if needed. 72 | 73 | # Get work directory and job name. 74 | if args.save_dir: 75 | work_dir = args.save_dir 76 | else: 77 | work_dir = os.path.join('work_dirs', 'synthesis') 78 | os.makedirs(work_dir, exist_ok=True) 79 | 80 | # Build generation and get synthesis kwargs. 
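# A rough reading of the latent-space flags used further down (not authoritative):
# `repeat_w` controls whether a single w code is shared across all synthesis
# layers, `use_z_plus_space` means the incoming code is a stack of per-layer z
# vectors that are first pushed through the mapping network and reshaped to
# [batch_size, num_layers, w_space_dim], and `use_w_space` tells the generator
# that the code is already in w/w+ space rather than z space.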
81 | print(f'Building generator for model `{args.model_name}` ...') 82 | if model_name.startswith('stylegan'): 83 | generator = build_generator(**model_config, repeat_w=repeat_w) 84 | else: 85 | generator = build_generator(**model_config) 86 | synthesis_kwargs = dict(trunc_psi=args.trunc_psi, 87 | trunc_layers=args.trunc_layers, 88 | randomize_noise=args.randomize_noise) 89 | print('Finish building generator.') 90 | 91 | # Build discriminator 92 | if use_discri: 93 | print(f'Building discriminator for model `{args.model_name}` ...') 94 | discriminator = build_discriminator(**model_config) 95 | print('Finish building discriminator.') 96 | else: 97 | discriminator = None 98 | 99 | # Load pre-trained weights. 100 | os.makedirs('checkpoints', exist_ok=True) 101 | checkpoint_path = os.path.join('checkpoints', args.model_name + '.pth') 102 | # print(checkpoint_path) 103 | print(f'Loading checkpoint from `{checkpoint_path}` ...') 104 | if not os.path.exists(checkpoint_path): 105 | print(f' Downloading checkpoint from `{url}` ...') 106 | subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url]) 107 | print(' Finish downloading checkpoint.') 108 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 109 | 110 | if 'generator_smooth' in checkpoint: 111 | generator.load_state_dict(checkpoint['generator_smooth']) 112 | else: 113 | generator.load_state_dict(checkpoint['generator']) 114 | generator = generator.to(device) 115 | generator.eval() 116 | if use_discri: 117 | discriminator.load_state_dict(checkpoint['discriminator']) 118 | discriminator = discriminator.to(device) 119 | discriminator.eval() 120 | print('Finish loading checkpoint.') 121 | 122 | # Set random seed. 123 | np.random.seed(args.seed) 124 | torch.manual_seed(args.seed) 125 | 126 | def fake_generator(code): 127 | # Sample and synthesize. 
128 | # print(f'Synthesizing {args.num} samples ...') 129 | # code = torch.randn(args.batch_size, generator.z_space_dim).cuda() 130 | if use_z_plus_space: 131 | code = generator.mapping(code)['w'] 132 | code = code.view(args.batch_size, generator.num_layers, generator.w_space_dim) 133 | images = generator(code, **synthesis_kwargs, use_w_space=use_w_space)['image'] 134 | images = postprocess(images) 135 | # save_image(images, os.path.join(work_dir, 'tmp.png'), nrow=5) 136 | # print(f'Finish synthesizing {args.num} samples.') 137 | return images 138 | 139 | return Fake_G(generator, fake_generator), discriminator 140 | 141 | 142 | class Fake_G: 143 | 144 | def __init__(self, G, g_function): 145 | self.G = G 146 | self.g_function = g_function 147 | 148 | def __call__(self, code): 149 | # print(f'code.shape {code.shape}') 150 | return self.g_function(code) 151 | 152 | def zero_grad(self): 153 | self.G.zero_grad() 154 | 155 | 156 | if __name__ == '__main__': 157 | # main('stylegan_ffhq1024', 7, 7) 158 | # main('stylegan_ffhq256', 35, 35) 159 | main('stylegan_celeba_partial256', 35, 35) 160 | -------------------------------------------------------------------------------- /vgg_face_dag.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class Vgg_face_dag(nn.Module): 9 | 10 | def __init__(self, use_dropout=False): 11 | super(Vgg_face_dag, self).__init__() 12 | self.meta = {'mean': [129.186279296875, 104.76238250732422, 93.59396362304688], 13 | 'std': [1, 1, 1], 14 | 'imageSize': [224, 224, 3]} 15 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 16 | self.relu1_1 = nn.ReLU(inplace=True) 17 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 18 | self.relu1_2 = nn.ReLU(inplace=True) 19 | self.pool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False) 20 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 21 | self.relu2_1 = nn.ReLU(inplace=True) 22 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 23 | self.relu2_2 = nn.ReLU(inplace=True) 24 | self.pool2 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False) 25 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 26 | self.relu3_1 = nn.ReLU(inplace=True) 27 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 28 | self.relu3_2 = nn.ReLU(inplace=True) 29 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 30 | self.relu3_3 = nn.ReLU(inplace=True) 31 | self.pool3 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False) 32 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 33 | self.relu4_1 = nn.ReLU(inplace=True) 34 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 35 | self.relu4_2 = nn.ReLU(inplace=True) 36 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 37 | self.relu4_3 = nn.ReLU(inplace=True) 38 | self.pool4 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False) 39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 40 | self.relu5_1 = nn.ReLU(inplace=True) 41 | self.conv5_2 
= nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 42 | self.relu5_2 = nn.ReLU(inplace=True) 43 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 44 | self.relu5_3 = nn.ReLU(inplace=True) 45 | self.pool5 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False) 46 | self.fc6 = nn.Linear(in_features=25088, out_features=4096, bias=True) 47 | self.relu6 = nn.ReLU(inplace=True) 48 | self.dropout6 = nn.Dropout(p=0.5) 49 | self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True) 50 | self.relu7 = nn.ReLU(inplace=True) 51 | self.dropout7 = nn.Dropout(p=0.5) 52 | self.fc8 = nn.Linear(in_features=4096, out_features=2622, bias=True) 53 | 54 | self.use_dropout = use_dropout 55 | 56 | if self.use_dropout: 57 | fc_dropout_probs = {1: 0.6, 2: 0.5} 58 | conv_dropout_probs = {1: 0.5, 2: 0.2, 3: 0.2, 4: 0.1, 5: 0.1} 59 | self.fc_dropouts = {k: partial(nn.functional.dropout, p=v) for k, v in fc_dropout_probs.items()} 60 | self.conv_dropouts = {k: nn.Dropout2d(v) for k, v in conv_dropout_probs.items()} 61 | 62 | print(f'fc_dropout_probs: {fc_dropout_probs}\n' 63 | f'conv_dropout_probs: {conv_dropout_probs}' 64 | ) 65 | 66 | def forward(self, x0): 67 | if self.use_dropout: 68 | for x in self.conv_dropouts.values(): 69 | x.training = True 70 | 71 | k = random.randint(1, 5) 72 | conv_dropout_layers = set(random.choices(range(1, 7), k=k)) # 6 means no dropout 73 | k = 2 # random.randint(1, 2) 74 | fc_dropout_layers = set(random.choices(range(1, 4), k=k)) # 3 means no dropout 75 | 76 | conv_dropout = self.conv_dropouts[len(conv_dropout_layers)] 77 | 78 | fc_dropout = self.fc_dropouts[len(fc_dropout_layers)] 79 | else: 80 | conv_dropout_layers = set() 81 | fc_dropout_layers = set() 82 | conv_dropout = None 83 | fc_dropout = None 84 | 85 | x1 = self.conv1_1(x0) 86 | x2 = self.relu1_1(x1) 87 | x3 = self.conv1_2(x2) 88 | x4 = self.relu1_2(x3) 89 | x5 = self.pool1(x4) 90 | 91 | if 1 in conv_dropout_layers: 92 | x5 = conv_dropout(x5) 93 | 94 | x6 = self.conv2_1(x5) 95 | x7 = self.relu2_1(x6) 96 | x8 = self.conv2_2(x7) 97 | x9 = self.relu2_2(x8) 98 | x10 = self.pool2(x9) 99 | 100 | if 2 in conv_dropout_layers: 101 | x10 = conv_dropout(x10) 102 | 103 | x11 = self.conv3_1(x10) 104 | x12 = self.relu3_1(x11) 105 | x13 = self.conv3_2(x12) 106 | x14 = self.relu3_2(x13) 107 | x15 = self.conv3_3(x14) 108 | x16 = self.relu3_3(x15) 109 | x17 = self.pool3(x16) 110 | 111 | if 3 in conv_dropout_layers: 112 | x17 = conv_dropout(x17) 113 | 114 | x18 = self.conv4_1(x17) 115 | x19 = self.relu4_1(x18) 116 | x20 = self.conv4_2(x19) 117 | x21 = self.relu4_2(x20) 118 | x22 = self.conv4_3(x21) 119 | x23 = self.relu4_3(x22) 120 | x24 = self.pool4(x23) 121 | 122 | if 4 in conv_dropout_layers: 123 | x24 = conv_dropout(x24) 124 | 125 | x25 = self.conv5_1(x24) 126 | x26 = self.relu5_1(x25) 127 | x27 = self.conv5_2(x26) 128 | x28 = self.relu5_2(x27) 129 | x29 = self.conv5_3(x28) 130 | x30 = self.relu5_3(x29) 131 | x31_preflatten = self.pool5(x30) 132 | 133 | if 5 in conv_dropout_layers: 134 | x31_preflatten = conv_dropout(x31_preflatten) 135 | 136 | x31 = x31_preflatten.view(x31_preflatten.size(0), -1) 137 | x32 = self.fc6(x31) 138 | x33 = self.relu6(x32) 139 | x34 = self.dropout6(x33) 140 | 141 | if 1 in fc_dropout_layers: 142 | x34 = fc_dropout(x34) 143 | 144 | x35 = self.fc7(x34) 145 | x36 = self.relu7(x35) 146 | x37 = self.dropout7(x36) 147 | 148 | if 2 in fc_dropout_layers: 149 | x37 = fc_dropout(x37) 150 | 151 | x38 = self.fc8(x37) 152 | 
return x38 153 | 154 | 155 | def vgg_face_dag(weights_path=None, **kwargs): 156 | """ 157 | load imported model instance 158 | 159 | Args: 160 | weights_path (str): If set, loads model weights from the given path 161 | """ 162 | model = Vgg_face_dag(**kwargs) 163 | if weights_path: 164 | state_dict = torch.load(weights_path) 165 | model.load_state_dict(state_dict) 166 | return model 167 | -------------------------------------------------------------------------------- /train_classification_models/self-train_VGGFace2/train_inception_v3.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torchvision 4 | import torch.nn.functional as F 5 | from torchvision import transforms 6 | from torch.utils.data.dataloader import DataLoader 7 | from torch.utils.data.dataset import Subset 8 | import numpy as np 9 | from PIL import Image 10 | from sklearn.model_selection import train_test_split 11 | import torchvision.models as models 12 | from typing import Any, Callable, List, Optional, Tuple 13 | from torch import nn, Tensor 14 | from tqdm import tqdm 15 | import logging 16 | import os 17 | 18 | 19 | def setup_logger(log_path): 20 | if not os.path.exists(log_path): 21 | os.makedirs(log_path) 22 | log_file = os.path.join(log_path, 'training.log') 23 | 24 | logger = logging.getLogger('TrainingLogger') 25 | logger.setLevel(logging.DEBUG) 26 | file_handler = logging.FileHandler(log_file) 27 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 28 | file_handler.setFormatter(formatter) 29 | logger.addHandler(file_handler) 30 | 31 | return logger 32 | 33 | def train_one_epoch(model, train_loader, optimizer, device, logger): 34 | model.train() 35 | running_loss = 0.0 36 | for i, (inputs, labels) in enumerate(tqdm(train_loader, desc="Training", leave=False)): 37 | inputs, labels = inputs.to(device), labels.to(device) 38 | optimizer.zero_grad() 39 | outputs, aux_outputs = model(inputs) 40 | loss1 = F.cross_entropy(outputs, labels) 41 | loss2 = F.cross_entropy(aux_outputs, labels) 42 | loss = loss1 + 0.3 * loss2 43 | loss.backward() 44 | optimizer.step() 45 | running_loss += loss.item() 46 | if (i + 1) % 20 == 0: 47 | logger.info(f"Batch {i+1}/{len(train_loader)}, Loss: {loss.item():.4f}") 48 | return running_loss / len(train_loader) 49 | 50 | def test(model, test_loader, device, logger): 51 | model.eval() 52 | correct = 0 53 | total = 0 54 | with torch.no_grad(): 55 | for inputs, labels in tqdm(test_loader, desc="Testing", leave=False): 56 | inputs, labels = inputs.to(device), labels.to(device) 57 | outputs = model(inputs) 58 | _, predicted = torch.max(outputs.data, 1) 59 | total += labels.size(0) 60 | correct += (predicted == labels).sum().item() 61 | accuracy = 100 * correct / total 62 | logger.info(f"Finished Testing, Final Accuracy: {accuracy:.5f}%") 63 | return accuracy 64 | 65 | class InceptionAux(nn.Module): 66 | def __init__( 67 | self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None 68 | ) -> None: 69 | super().__init__() 70 | if conv_block is None: 71 | conv_block = BasicConv2d 72 | self.conv0 = conv_block(in_channels, 128, kernel_size=1) 73 | self.conv1 = conv_block(128, 768, kernel_size=5) 74 | self.conv1.stddev = 0.01 # type: ignore[assignment] 75 | self.fc = nn.Linear(768, num_classes) 76 | self.fc.stddev = 0.001 # type: ignore[assignment] 77 | 78 | def forward(self, x: Tensor) -> Tensor: 79 | # N x 768 x 17 x 17 80 | x = F.avg_pool2d(x, kernel_size=5, 
stride=3) 81 | # N x 768 x 5 x 5 82 | x = self.conv0(x) 83 | # N x 128 x 5 x 5 84 | x = self.conv1(x) 85 | # N x 768 x 1 x 1 86 | # Adaptive average pooling 87 | x = F.adaptive_avg_pool2d(x, (1, 1)) 88 | # N x 768 x 1 x 1 89 | x = torch.flatten(x, 1) 90 | # N x 768 91 | x = self.fc(x) 92 | # N x 1000 93 | return x 94 | 95 | class BasicConv2d(nn.Module): 96 | def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: 97 | super().__init__() 98 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 99 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 100 | 101 | def forward(self, x: Tensor) -> Tensor: 102 | x = self.conv(x) 103 | x = self.bn(x) 104 | return F.relu(x, inplace=True) 105 | 106 | class Normalize(torch.nn.Module): 107 | 108 | def __init__(self, mean, std): 109 | super().__init__() 110 | self.mean = mean 111 | self.std = std 112 | 113 | def forward(self, image_tensor): 114 | image_tensor = (image_tensor-torch.tensor(self.mean, device=image_tensor.device)[:, None, None])/torch.tensor(self.std, device=image_tensor.device)[:, None, None] 115 | return image_tensor 116 | 117 | def main(args): 118 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') 119 | args.device = device 120 | torch.manual_seed(666) 121 | 122 | log_path = args.LOG 123 | logger = setup_logger(log_path) 124 | 125 | args.resolution = 224 126 | Mean = [131.0912, 103.8827, 91.4953] 127 | Std = [1., 1., 1.] 128 | 129 | T_resize = 360 130 | RESIZE = 342 131 | transform = transforms.Compose([ 132 | transforms.PILToTensor(), 133 | transforms.Resize(T_resize), 134 | transforms.CenterCrop(args.resolution), 135 | transforms.Resize(RESIZE), 136 | Normalize(Mean, Std) 137 | ]) 138 | 139 | totalset = torchvision.datasets.ImageFolder("/root/autodl-tmp/vggface2/train", transform=transform) 140 | 141 | trainset_list, testset_list = train_test_split(list(range(len(totalset.samples))), test_size=0.1, random_state=666) 142 | trainset= Subset(totalset, trainset_list) 143 | testset= Subset(totalset, testset_list) 144 | 145 | train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8, pin_memory=True) 146 | test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=8, pin_memory=True) 147 | 148 | inception_v3 = models.inception_v3(pretrained=True) 149 | 150 | # self.AuxLogits = inception_aux(768, num_classes) 151 | # self.fc = nn.Linear(2048, num_classes) 152 | 153 | if hasattr(inception_v3, 'AuxLogits'): 154 | inception_v3.AuxLogits = InceptionAux(768, 8631) 155 | 156 | inception_v3.fc = nn.Linear(2048, 8631) 157 | 158 | inception_v3 = inception_v3.to(device) 159 | 160 | optimizer = torch.optim.Adam(inception_v3.parameters(), lr=0.001) 161 | 162 | best_acc = 0.0 163 | for epoch in range(args.epochs): 164 | train_loss = train_one_epoch(inception_v3, train_loader, optimizer, device, logger) 165 | test_acc = test(inception_v3, test_loader, device, logger) 166 | print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}") 167 | if test_acc > best_acc: 168 | best_acc = test_acc 169 | torch.save(inception_v3.state_dict(), './models/inception_v3_best_model.pth') 170 | print("Saved best model") 171 | 172 | 173 | 174 | if __name__ == '__main__': 175 | parser = argparse.ArgumentParser() 176 | parser.add_argument('--batch_size', default=64, type=int, help='batch size') 177 | parser.add_argument('--epochs', default=20, type=int, help='number of epochs to train') 178 | parser.add_argument('--LOG', 
default='./LOG/logs_inception_v3', type=str) 179 | args = parser.parse_args() 180 | 181 | main(args) -------------------------------------------------------------------------------- /vitb16_4finetune.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | from functools import partial 4 | from typing import Any, Callable, Dict, List, NamedTuple, Optional 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from torchvision.utils import _log_api_usage_once 10 | from torchvision.models.vision_transformer import Conv2dNormActivation, ConvStemConfig, Encoder 11 | 12 | class VisionTransformer_E(nn.Module): 13 | """Vision Transformer as per https://arxiv.org/abs/2010.11929.""" 14 | 15 | def __init__( 16 | self, 17 | image_size: int, 18 | patch_size: int, 19 | num_layers: int, 20 | num_heads: int, 21 | hidden_dim: int, 22 | mlp_dim: int, 23 | dropout: float = 0.0, 24 | attention_dropout: float = 0.0, 25 | num_classes: int = 1000, 26 | representation_size: Optional[int] = None, 27 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 28 | conv_stem_configs: Optional[List[ConvStemConfig]] = None, 29 | num_experts: int = 8, 30 | ): 31 | super().__init__() 32 | _log_api_usage_once(self) 33 | torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!") 34 | self.image_size = image_size 35 | self.patch_size = patch_size 36 | self.hidden_dim = hidden_dim 37 | self.mlp_dim = mlp_dim 38 | self.attention_dropout = attention_dropout 39 | self.dropout = dropout 40 | self.num_classes = num_classes 41 | self.representation_size = representation_size 42 | self.norm_layer = norm_layer 43 | 44 | self.num_experts = num_experts 45 | 46 | if conv_stem_configs is not None: 47 | # As per https://arxiv.org/abs/2106.14881 48 | seq_proj = nn.Sequential() 49 | prev_channels = 3 50 | for i, conv_stem_layer_config in enumerate(conv_stem_configs): 51 | seq_proj.add_module( 52 | f"conv_bn_relu_{i}", 53 | Conv2dNormActivation( 54 | in_channels=prev_channels, 55 | out_channels=conv_stem_layer_config.out_channels, 56 | kernel_size=conv_stem_layer_config.kernel_size, 57 | stride=conv_stem_layer_config.stride, 58 | norm_layer=conv_stem_layer_config.norm_layer, 59 | activation_layer=conv_stem_layer_config.activation_layer, 60 | ), 61 | ) 62 | prev_channels = conv_stem_layer_config.out_channels 63 | seq_proj.add_module( 64 | "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1) 65 | ) 66 | self.conv_proj: nn.Module = seq_proj 67 | else: 68 | self.conv_proj = nn.Conv2d( 69 | in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size 70 | ) 71 | 72 | seq_length = (image_size // patch_size) ** 2 73 | 74 | # Add a class token 75 | self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) 76 | seq_length += 1 77 | 78 | self.encoder = Encoder( 79 | seq_length, 80 | num_layers, 81 | num_heads, 82 | hidden_dim, 83 | mlp_dim, 84 | dropout, 85 | attention_dropout, 86 | norm_layer, 87 | ) 88 | self.seq_length = seq_length 89 | 90 | heads_layers: OrderedDict[str, nn.Module] = OrderedDict() 91 | if representation_size is None: 92 | heads_layers["head"] = nn.Linear(hidden_dim, num_classes) 93 | else: 94 | heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size) 95 | heads_layers["act"] = nn.Tanh() 96 | heads_layers["head"] = nn.Linear(representation_size, num_classes) 97 | 98 | self.heads = nn.Sequential(heads_layers) 99 | 100 | 
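# Note (a reading of the code, not authoritative): `forward` below indexes
# `self.heads[ind]` once per expert, so with the single "head" entry built here
# the model only runs as-is when num_experts == 1 (as in the __main__ example);
# a multi-expert variant would presumably attach one head module per expert.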
if isinstance(self.conv_proj, nn.Conv2d): 101 | # Init the patchify stem 102 | fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1] 103 | nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in)) 104 | if self.conv_proj.bias is not None: 105 | nn.init.zeros_(self.conv_proj.bias) 106 | elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d): 107 | # Init the last 1x1 conv of the conv stem 108 | nn.init.normal_( 109 | self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels) 110 | ) 111 | if self.conv_proj.conv_last.bias is not None: 112 | nn.init.zeros_(self.conv_proj.conv_last.bias) 113 | 114 | if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear): 115 | fan_in = self.heads.pre_logits.in_features 116 | nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in)) 117 | nn.init.zeros_(self.heads.pre_logits.bias) 118 | 119 | if isinstance(self.heads.head, nn.Linear): 120 | nn.init.zeros_(self.heads.head.weight) 121 | nn.init.zeros_(self.heads.head.bias) 122 | 123 | def _process_input(self, x: torch.Tensor) -> torch.Tensor: 124 | n, c, h, w = x.shape 125 | p = self.patch_size 126 | torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!") 127 | torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!") 128 | n_h = h // p 129 | n_w = w // p 130 | 131 | # (n, c, h, w) -> (n, hidden_dim, n_h, n_w) 132 | x = self.conv_proj(x) 133 | # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)) 134 | x = x.reshape(n, self.hidden_dim, n_h * n_w) 135 | 136 | # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim) 137 | # The self attention layer expects inputs in the format (N, S, E) 138 | # where S is the source sequence length, N is the batch size, E is the 139 | # embedding dimension 140 | x = x.permute(0, 2, 1) 141 | 142 | return x 143 | 144 | def forward(self, x: torch.Tensor): 145 | # Reshape and permute the input tensor 146 | x = self._process_input(x) 147 | n = x.shape[0] 148 | 149 | # Expand the class token to the full batch 150 | batch_class_token = self.class_token.expand(n, -1, -1) 151 | x = torch.cat([batch_class_token, x], dim=1) 152 | 153 | x = self.encoder(x) 154 | outs = [] 155 | for ind in range(self.num_experts): 156 | tmp = x[:, 0] 157 | tmp = self.heads[ind](tmp) 158 | outs.append(tmp) 159 | x = torch.stack(outs, dim=1).mean(dim=1) 160 | return x, outs 161 | 162 | 163 | if __name__ == '__main__': 164 | model_4finetune = VisionTransformer_E(num_classes=8631, num_experts = 1, patch_size=16, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, image_size=224) 165 | print(model_4finetune.conv_proj) 166 | print(len(model_4finetune.encoder)) 167 | -------------------------------------------------------------------------------- /genforce/runners/controllers/progress_scheduler.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the running controller to control progressive training. 3 | 4 | This controller is applicable to the models that need to progressively change 5 | the batch size, learning rate, etc. 
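In the StyleGAN setting used in this repo, that typically means starting from a
low resolution (4x4 by default, controlled by `init_res`) and doubling the
resolution after a fixed number of seen images, with per-resolution batch sizes
and learning-rate scales.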
6 | """ 7 | 8 | import numpy as np 9 | 10 | from .base_controller import BaseController 11 | 12 | __all__ = ['ProgressScheduler'] 13 | 14 | _BATCH_SIZE_SCHEDULE_DICT = { 15 | 4: 16, 8: 8, 16: 4, 32: 2, 64: 1, 128: 1, 256: 1, 512: 1, 1024: 1, 16 | } 17 | _MAX_BATCH_SIZE = 64 18 | 19 | _LEARNING_RATE_SCHEDULE_DICT = { 20 | 4: 1, 8: 1, 16: 1, 32: 1, 64: 1, 128: 1.5, 256: 2, 512: 3, 1024: 3, 21 | } 22 | 23 | 24 | class ProgressScheduler(BaseController): 25 | """Defines the running controller to control progressive training. 26 | 27 | NOTE: The controller is set to `HIGH` priority by default. 28 | """ 29 | 30 | def __init__(self, config): 31 | assert isinstance(config, dict) 32 | config.setdefault('priority', 'HIGH') 33 | config.setdefault('every_n_iters', 1) 34 | super().__init__(config) 35 | 36 | self.base_batch_size = 0 37 | self.base_lrs = dict() 38 | 39 | self.total_img = 0 40 | self.init_res = config.get('init_res', 4) 41 | self.final_res = self.init_res 42 | self.init_lod = 0 43 | self.batch_size_schedule = config.get('batch_size_schedule', dict()) 44 | self.lr_schedule = config.get('lr_schedule', dict()) 45 | self.minibatch_repeats = config.get('minibatch_repeats', 4) 46 | 47 | self.lod_training_img = config.get('lod_training_img', 600_000) 48 | self.lod_transition_img = config.get('lod_transition_img', 600_000) 49 | self.lod_duration = (self.lod_training_img + self.lod_transition_img) 50 | 51 | # Whether to reset the optimizer state at the beginning of each phase. 52 | self.reset_optimizer = config.get('reset_optimizer', True) 53 | 54 | def get_batch_size(self, resolution): 55 | """Gets batch size for a particular resolution.""" 56 | if self.batch_size_schedule: 57 | return self.batch_size_schedule.get( 58 | f'res{resolution}', self.base_batch_size) 59 | batch_size_scale = _BATCH_SIZE_SCHEDULE_DICT[resolution] 60 | return min(_MAX_BATCH_SIZE, self.base_batch_size * batch_size_scale) 61 | 62 | def get_lr_scale(self, resolution): 63 | """Gets learning rate scale for a particular resolution.""" 64 | if self.lr_schedule: 65 | return self.lr_schedule.get(f'res{resolution}', 1) 66 | return _LEARNING_RATE_SCHEDULE_DICT[resolution] 67 | 68 | def setup(self, runner): 69 | # Set level of detail (lod). 70 | self.final_res = runner.resolution 71 | self.init_lod = np.log2(self.final_res // self.init_res) 72 | runner.lod = -1.0 73 | 74 | # Save default batch size and learning rate. 75 | self.base_batch_size = runner.batch_size 76 | for lr_name, lr_scheduler in runner.lr_schedulers.items(): 77 | self.base_lrs[lr_name] = lr_scheduler.base_lrs 78 | 79 | # Add running stats for logging. 80 | runner.running_stats.add( 81 | 'kimg', log_format='7.1f', log_name='kimg', log_strategy='CURRENT') 82 | runner.running_stats.add( 83 | 'lod', log_format='4.2f', log_name='lod', log_strategy='CURRENT') 84 | runner.running_stats.add( 85 | 'minibatch', log_format='4d', log_name='minibatch', 86 | log_strategy='CURRENT') 87 | 88 | # Log progressive schedule. 89 | runner.logger.info(f'Progressive Schedule:') 90 | res = self.init_res 91 | lod = int(self.init_lod) 92 | while res <= self.final_res: 93 | batch_size = self.get_batch_size(res) 94 | lr_scale = self.get_lr_scale(res) 95 | runner.logger.info(f' Resolution {res:4d} (lod {lod}): ' 96 | f'batch size ' 97 | f'{batch_size:3d} * {runner.world_size:2d}, ' 98 | f'learning rate scale {lr_scale:.1f}') 99 | res *= 2 100 | lod -= 1 101 | assert lod == -1 and res == self.final_res * 2 102 | 103 | # Compute total running iterations. 
104 |         assert hasattr(runner.config, 'total_img')
105 |         self.total_img = runner.config.total_img
106 |         current_img = 0
107 |         num_iters = 0
108 |         while current_img < self.total_img:
109 |             phase = (current_img + self.lod_transition_img) // self.lod_duration
110 |             phase = np.clip(phase, 0, self.init_lod)
111 |             if num_iters % self.minibatch_repeats == 0:
112 |                 resolution = self.init_res * (2 ** int(phase))
113 |             current_img += self.get_batch_size(resolution) * runner.world_size
114 |             num_iters += 1
115 |         runner.total_iters = num_iters
116 | 
117 |     def execute_before_iteration(self, runner):
118 |         is_first_iter = (runner.iter - runner.start_iter == 1)
119 | 
120 |         # Adjust hyper-parameters only at some particular iteration.
121 |         if (not is_first_iter) and (runner.iter % self.minibatch_repeats != 1):
122 |             return
123 | 
124 |         # Compute level-of-details.
125 |         phase, subphase = divmod(runner.seen_img, self.lod_duration)
126 |         lod = self.init_lod - phase
127 |         if self.lod_transition_img:
128 |             transition_img = max(subphase - self.lod_training_img, 0)
129 |             lod = lod - transition_img / self.lod_transition_img
130 |         lod = max(lod, 0.0)
131 |         resolution = self.init_res * (2 ** int(np.ceil(self.init_lod - lod)))
132 |         batch_size = self.get_batch_size(resolution)
133 |         lr_scale = self.get_lr_scale(resolution)
134 | 
135 |         pre_lod = runner.lod
136 |         pre_resolution = runner.train_loader.dataset.resolution
137 |         runner.lod = lod
138 | 
139 |         # Reset optimizer state if needed.
140 |         if self.reset_optimizer:
141 |             if int(lod) != int(pre_lod) or np.ceil(lod) != np.ceil(pre_lod):
142 |                 runner.logger.info(f'Reset the optimizer state at '
143 |                                    f'iter {runner.iter:06d} (lod {lod:.6f}).')
144 |                 for name in runner.optimizers:
145 |                     runner.optimizers[name].state.clear()
146 | 
147 |         # Rebuild the dataset and adjust the learning rate if needed.
148 | if is_first_iter or resolution != pre_resolution: 149 | runner.logger.info(f'Rebuild the dataset at ' 150 | f'iter {runner.iter:06d} (lod {lod:.6f}).') 151 | runner.train_loader.overwrite_param( 152 | batch_size=batch_size, resolution=resolution) 153 | runner.batch_size = batch_size 154 | for lr_name, base_lrs in self.base_lrs.items(): 155 | runner.lr_schedulers[lr_name].base_lrs = [ 156 | lr * lr_scale for lr in base_lrs] 157 | 158 | def execute_after_iteration(self, runner): 159 | minibatch = runner.batch_size * runner.world_size 160 | runner.running_stats.update({'kimg': runner.seen_img / 1000}) 161 | runner.running_stats.update({'lod': runner.lod}) 162 | runner.running_stats.update({'minibatch': minibatch}) 163 | -------------------------------------------------------------------------------- /inceptionv3_4finetune.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import namedtuple 3 | from functools import partial 4 | from typing import Any, Callable, List, Optional, Tuple 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn, Tensor 8 | from torchvision.utils import _log_api_usage_once 9 | from torchvision.models.inception import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux 10 | 11 | 12 | class Inception3_E(nn.Module): 13 | def __init__( 14 | self, 15 | num_classes: int = 1000, 16 | aux_logits: bool = True, 17 | # transform_input: bool = False, 18 | transform_input: bool = True, 19 | inception_blocks: Optional[List[Callable[..., nn.Module]]] = None, 20 | init_weights: Optional[bool] = False,#None, 21 | dropout: float = 0.5, 22 | num_experts: int = None, 23 | ) -> None: 24 | super().__init__() 25 | self.num_experts = num_experts 26 | _log_api_usage_once(self) 27 | if inception_blocks is None: 28 | inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux] 29 | if init_weights is None: 30 | warnings.warn( 31 | "The default weight initialization of inception_v3 will be changed in future releases of " 32 | "torchvision. 
If you wish to keep the old behavior (which leads to long initialization times" 33 | " due to scipy/scipy#11299), please set init_weights=True.", 34 | FutureWarning, 35 | ) 36 | init_weights = True 37 | if len(inception_blocks) != 7: 38 | raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}") 39 | conv_block = inception_blocks[0] 40 | inception_a = inception_blocks[1] 41 | inception_b = inception_blocks[2] 42 | inception_c = inception_blocks[3] 43 | inception_d = inception_blocks[4] 44 | inception_e = inception_blocks[5] 45 | inception_aux = inception_blocks[6] 46 | 47 | self.aux_logits = aux_logits 48 | self.transform_input = transform_input 49 | self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2) 50 | self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) 51 | self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) 52 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2) 53 | self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) 54 | self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) 55 | self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2) 56 | self.Mixed_5b = inception_a(192, pool_features=32) 57 | self.Mixed_5c = inception_a(256, pool_features=64) 58 | self.Mixed_5d = inception_a(288, pool_features=64) 59 | self.Mixed_6a = inception_b(288) 60 | self.Mixed_6b = inception_c(768, channels_7x7=128) 61 | self.Mixed_6c = inception_c(768, channels_7x7=160) 62 | self.Mixed_6d = inception_c(768, channels_7x7=160) 63 | self.Mixed_6e = inception_c(768, channels_7x7=192) 64 | self.AuxLogits: Optional[nn.Module] = None 65 | if aux_logits: 66 | self.AuxLogits = inception_aux(768, num_classes) 67 | self.Mixed_7a = inception_d(768) 68 | self.Mixed_7b = inception_e(1280) 69 | self.Mixed_7c = inception_e(2048) 70 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 71 | self.dropout = nn.Dropout(p=dropout) 72 | self.fc = nn.Linear(2048, num_classes) 73 | if init_weights: 74 | for m in self.modules(): 75 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 76 | stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1 # type: ignore 77 | torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2) 78 | elif isinstance(m, nn.BatchNorm2d): 79 | nn.init.constant_(m.weight, 1) 80 | nn.init.constant_(m.bias, 0) 81 | 82 | def _transform_input(self, x: Tensor) -> Tensor: 83 | if self.transform_input: 84 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 85 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 86 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 87 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 88 | return x 89 | 90 | # def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]: 91 | def forward(self, x: Tensor, train_flag: int) -> Tuple[Tensor, Optional[Tensor]]: 92 | # N x 3 x 299 x 299 93 | x = self.Conv2d_1a_3x3(x) 94 | # N x 32 x 149 x 149 95 | x = self.Conv2d_2a_3x3(x) 96 | # N x 32 x 147 x 147 97 | x = self.Conv2d_2b_3x3(x) 98 | # N x 64 x 147 x 147 99 | x = self.maxpool1(x) 100 | # N x 64 x 73 x 73 101 | x = self.Conv2d_3b_1x1(x) 102 | # N x 80 x 73 x 73 103 | x = self.Conv2d_4a_3x3(x) 104 | # N x 192 x 71 x 71 105 | x = self.maxpool2(x) 106 | # N x 192 x 35 x 35 107 | x = self.Mixed_5b(x) 108 | # N x 256 x 35 x 35 109 | x = self.Mixed_5c(x) 110 | # N x 288 x 35 x 35 111 | x = self.Mixed_5d(x) 112 | # N x 288 x 35 x 35 113 | x = self.Mixed_6a(x) 114 | # N x 768 x 17 x 17 115 | x = self.Mixed_6b(x) 116 | # N x 768 x 17 x 17 117 | x = 
self.Mixed_6c(x) 118 | # N x 768 x 17 x 17 119 | x = self.Mixed_6d(x) 120 | # N x 768 x 17 x 17 121 | x = self.Mixed_6e(x) 122 | # N x 768 x 17 x 17 123 | aux: Optional[Tensor] = None 124 | aux_outs: Optional[list] = None 125 | if self.AuxLogits is not None: 126 | # if self.training: 127 | if train_flag == 1: 128 | aux_outs = [] 129 | for ind in range(self.num_experts): 130 | tmp = self.AuxLogits[ind](x) 131 | aux_outs.append(tmp) 132 | aux = torch.stack(aux_outs, dim=1).mean(dim=1) 133 | 134 | # aux = self.AuxLogits(x) 135 | # N x 768 x 17 x 17 136 | x = self.Mixed_7a(x) 137 | # N x 1280 x 8 x 8 138 | x = self.Mixed_7b(x) 139 | # N x 2048 x 8 x 8 140 | 141 | x_outs = [] 142 | for ind in range(self.num_experts): 143 | tmp = self.Mixed_7c[ind](x) 144 | tmp = self.avgpool(tmp) 145 | tmp = self.dropout(tmp) 146 | tmp = torch.flatten(tmp, 1) 147 | tmp = self.fc[ind](tmp) 148 | x_outs.append(tmp) 149 | x = torch.stack(x_outs, dim=1).mean(dim=1) 150 | return x, x_outs, aux, aux_outs 151 | 152 | 153 | 154 | x = self.Mixed_7c(x) 155 | # N x 2048 x 8 x 8 156 | # Adaptive average pooling 157 | x = self.avgpool(x) 158 | # N x 2048 x 1 x 1 159 | x = self.dropout(x) 160 | # N x 2048 x 1 x 1 161 | x = torch.flatten(x, 1) 162 | # N x 2048 163 | x = self.fc(x) 164 | # N x 1000 (num_classes) 165 | return x, aux 166 | 167 | # if __name__ == '__main__': 168 | # model_4finetune = Inception3_E(num_classes = 8631, num_experts = 1) 169 | # print(model_4finetune.features[0:3]) 170 | # print(len(model_4finetune.features)) -------------------------------------------------------------------------------- /genforce/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | """Contains transform functions.""" 2 | 3 | import cv2 4 | import numpy as np 5 | import PIL.Image 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | __all__ = [ 13 | 'crop_resize_image', 'progressive_resize_image', 'resize_image', 14 | 'normalize_image', 'normalize_latent_code', 'ImageResizing', 15 | 'ImageNormalization', 'LatentCodeNormalization', 16 | ] 17 | 18 | 19 | def crop_resize_image(image, size): 20 | """Crops a square patch and then resizes it to the given size. 21 | 22 | Args: 23 | image: The input image to crop and resize. 24 | size: An integer, indicating the target size. 25 | 26 | Returns: 27 | An image with target size. 28 | 29 | Raises: 30 | TypeError: If the input `image` is not with type `numpy.ndarray`. 31 | ValueError: If the input `image` is not with shape [H, W, C]. 32 | """ 33 | if not isinstance(image, np.ndarray): 34 | raise TypeError(f'Input image should be with type `numpy.ndarray`, ' 35 | f'but `{type(image)}` is received!') 36 | if image.ndim != 3: 37 | raise ValueError(f'Input image should be with shape [H, W, C], ' 38 | f'but `{image.shape}` is received!') 39 | 40 | height, width, channel = image.shape 41 | short_side = min(height, width) 42 | image = image[(height - short_side) // 2:(height + short_side) // 2, 43 | (width - short_side) // 2:(width + short_side) // 2] 44 | pil_image = PIL.Image.fromarray(image) 45 | pil_image = pil_image.resize((size, size), PIL.Image.ANTIALIAS) 46 | image = np.asarray(pil_image) 47 | assert image.shape == (size, size, channel) 48 | return image 49 | 50 | 51 | def progressive_resize_image(image, size): 52 | """Resizes image to target size progressively. 53 | 54 | Different from normal resize, this function will reduce the image size 55 | progressively. 
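    For example, resizing a 1024x1024 image to 256x256 first downsamples it to
    512x512 and then to 256x256.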
In each step, the maximum reduce factor is 2. 56 | 57 | NOTE: This function can only handle square images, and can only be used for 58 | downsampling. 59 | 60 | Args: 61 | image: The input (square) image to resize. 62 | size: An integer, indicating the target size. 63 | 64 | Returns: 65 | An image with target size. 66 | 67 | Raises: 68 | TypeError: If the input `image` is not with type `numpy.ndarray`. 69 | ValueError: If the input `image` is not with shape [H, W, C]. 70 | """ 71 | if not isinstance(image, np.ndarray): 72 | raise TypeError(f'Input image should be with type `numpy.ndarray`, ' 73 | f'but `{type(image)}` is received!') 74 | if image.ndim != 3: 75 | raise ValueError(f'Input image should be with shape [H, W, C], ' 76 | f'but `{image.shape}` is received!') 77 | 78 | height, width, channel = image.shape 79 | assert height == width 80 | assert height >= size 81 | num_iters = int(np.log2(height) - np.log2(size)) 82 | for _ in range(num_iters): 83 | height = max(height // 2, size) 84 | image = cv2.resize(image, (height, height), 85 | interpolation=cv2.INTER_LINEAR) 86 | assert image.shape == (size, size, channel) 87 | return image 88 | 89 | 90 | def resize_image(image, size): 91 | """Resizes image to target size. 92 | 93 | NOTE: We use adaptive average pooing for image resizing. Instead of bilinear 94 | interpolation, average pooling is able to acquire information from more 95 | pixels, such that the resized results can be with higher quality. 96 | 97 | Args: 98 | image: The input image tensor, with shape [C, H, W], to resize. 99 | size: An integer or a tuple of integer, indicating the target size. 100 | 101 | Returns: 102 | An image tensor with target size. 103 | 104 | Raises: 105 | TypeError: If the input `image` is not with type `torch.Tensor`. 106 | ValueError: If the input `image` is not with shape [C, H, W]. 107 | """ 108 | if not isinstance(image, torch.Tensor): 109 | raise TypeError(f'Input image should be with type `torch.Tensor`, ' 110 | f'but `{type(image)}` is received!') 111 | if image.ndim != 3: 112 | raise ValueError(f'Input image should be with shape [C, H, W], ' 113 | f'but `{image.shape}` is received!') 114 | 115 | image = F.adaptive_avg_pool2d(image.unsqueeze(0), size).squeeze(0) 116 | return image 117 | 118 | 119 | def normalize_image(image, mean=127.5, std=127.5): 120 | """Normalizes image by subtracting mean and dividing std. 121 | 122 | Args: 123 | image: The input image tensor to normalize. 124 | mean: The mean value to subtract from the input tensor. (default: 127.5) 125 | std: The standard deviation to normalize the input tensor. (default: 126 | 127.5) 127 | 128 | Returns: 129 | A normalized image tensor. 130 | 131 | Raises: 132 | TypeError: If the input `image` is not with type `torch.Tensor`. 133 | """ 134 | if not isinstance(image, torch.Tensor): 135 | raise TypeError(f'Input image should be with type `torch.Tensor`, ' 136 | f'but `{type(image)}` is received!') 137 | out = (image - mean) / std 138 | return out 139 | 140 | 141 | def normalize_latent_code(latent_code, adjust_norm=True): 142 | """Normalizes latent code. 143 | 144 | NOTE: The latent code will always be normalized along the last axis. 145 | Meanwhile, if `adjust_norm` is set as `True`, the norm of the result will be 146 | adjusted to `sqrt(latent_code.shape[-1])` in order to avoid too small value. 147 | 148 | Args: 149 | latent_code: The input latent code tensor to normalize. 150 | adjust_norm: Whether to adjust the norm of the output. 
(default: True) 151 | 152 | Returns: 153 | A normalized latent code tensor. 154 | 155 | Raises: 156 | TypeError: If the input `latent_code` is not with type `torch.Tensor`. 157 | """ 158 | if not isinstance(latent_code, torch.Tensor): 159 | raise TypeError(f'Input latent code should be with type ' 160 | f'`torch.Tensor`, but `{type(latent_code)}` is ' 161 | f'received!') 162 | dim = latent_code.shape[-1] 163 | norm = latent_code.pow(2).sum(-1, keepdim=True).pow(0.5) 164 | out = latent_code / norm 165 | if adjust_norm: 166 | out = out * (dim ** 0.5) 167 | return out 168 | 169 | 170 | class ImageResizing(nn.Module): 171 | """Implements the image resizing layer.""" 172 | 173 | def __init__(self, size): 174 | super().__init__() 175 | self.size = size 176 | 177 | def forward(self, image): 178 | return resize_image(image, self.size) 179 | 180 | 181 | class ImageNormalization(nn.Module): 182 | """Implements the image normalization layer.""" 183 | 184 | def __init__(self, mean=127.5, std=127.5): 185 | super().__init__() 186 | self.mean = mean 187 | self.std = std 188 | 189 | def forward(self, image): 190 | return normalize_image(image, self.mean, self.std) 191 | 192 | 193 | class LatentCodeNormalization(nn.Module): 194 | """Implements the latent code normalization layer.""" 195 | 196 | def __init__(self, adjust_norm=True): 197 | super().__init__() 198 | self.adjust_norm = adjust_norm 199 | 200 | def forward(self, latent_code): 201 | return normalize_latent_code(latent_code, self.adjust_norm) 202 | -------------------------------------------------------------------------------- /genforce/runners/running_stats.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class for recording the running stats. 3 | 4 | Here, running stats refers to the statictical information in the running 5 | process, such as loss values, learning rates, running time, etc. 6 | """ 7 | 8 | from .misc import format_time 9 | 10 | __all__ = ['SingleStats', 'RunningStats'] 11 | 12 | 13 | class SingleStats(object): 14 | """A class to record the stats corresponding to a particular variable. 15 | 16 | This class is log-friendly and supports customized log format, including: 17 | 18 | (1) Numerical log format, such as `.3f`, `.1e`, `05d`, and `>10s`. 19 | (2) Customized log name (name of the stats to show in the log). 20 | (3) Additional string (e.g., measure unit) as the tail of log message. 21 | 22 | Furthermore, this class also supports logging the stats with different 23 | strategies, including: 24 | 25 | (1) CURRENT: The current value will be logged. 26 | (2) AVERAGE: The averaged value (from the beginning) will be logged. 27 | (3) SUM: The cumulative value (from the beginning) will be logged. 28 | """ 29 | 30 | def __init__(self, 31 | name, 32 | log_format='.3f', 33 | log_name=None, 34 | log_tail=None, 35 | log_strategy='AVERAGE'): 36 | """Initializes the stats with log format. 37 | 38 | Args: 39 | name: Name of the stats. Should be a string without spaces. 40 | log_format: The numerical log format. Use `time` to log time 41 | duration. (default: `.3f`) 42 | log_name: The name shown in the log. `None` means to directly use 43 | the stats name. (default: None) 44 | log_tail: The tailing log message. (default: None) 45 | log_strategy: Strategy to log this stats. `CURRENT`, `AVERAGE`, and 46 | `SUM` are supported. (default: `AVERAGE`) 47 | 48 | Raises: 49 | ValueError: If the input `log_strategy` is not supported. 
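        Example (illustrative values):
            stats = SingleStats('lr', log_format='.1e', log_strategy='CURRENT')
            stats.update(2e-4)
            str(stats)  # -> 'lr: 2.0e-04'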
50 | """ 51 | log_strategy = log_strategy.upper() 52 | if log_strategy not in ['CURRENT', 'AVERAGE', 'SUM']: 53 | raise ValueError(f'Invalid log strategy `{self.log_strategy}`!') 54 | 55 | self._name = name 56 | self._log_format = log_format 57 | self._log_name = log_name or name 58 | self._log_tail = log_tail or '' 59 | self._log_strategy = log_strategy 60 | 61 | # Stats Data. 62 | self.val = 0 # Current value. 63 | self.sum = 0 # Cumulative value. 64 | self.avg = 0 # Averaged value. 65 | self.cnt = 0 # Count number. 66 | 67 | @property 68 | def name(self): 69 | """Gets the name of the stats.""" 70 | return self._name 71 | 72 | @property 73 | def log_format(self): 74 | """Gets tne numerical log format of the stats.""" 75 | return self._log_format 76 | 77 | @property 78 | def log_name(self): 79 | """Gets the log name of the stats.""" 80 | return self._log_name 81 | 82 | @property 83 | def log_tail(self): 84 | """Gets the tailing log message of the stats.""" 85 | return self._log_tail 86 | 87 | @property 88 | def log_strategy(self): 89 | """Gets the log strategy of the stats.""" 90 | return self._log_strategy 91 | 92 | def clear(self): 93 | """Clears the stats data.""" 94 | self.val = 0 95 | self.sum = 0 96 | self.avg = 0 97 | self.cnt = 0 98 | 99 | def update(self, value): 100 | """Updates the stats data.""" 101 | self.val = value 102 | self.cnt = self.cnt + 1 103 | self.sum = self.sum + value 104 | self.avg = self.sum / self.cnt 105 | 106 | def get_log_value(self): 107 | """Gets value for logging according to the log strategy.""" 108 | if self.log_strategy == 'CURRENT': 109 | return self.val 110 | if self.log_strategy == 'AVERAGE': 111 | return self.avg 112 | if self.log_strategy == 'SUM': 113 | return self.sum 114 | raise NotImplementedError(f'Log strategy `{self.log_strategy}` is not ' 115 | f'implemented!') 116 | 117 | def __str__(self): 118 | """Gets log message.""" 119 | if self.log_format == 'time': 120 | value_str = f'{format_time(self.get_log_value())}' 121 | else: 122 | value_str = f'{self.get_log_value():{self.log_format}}' 123 | return f'{self.log_name}: {value_str}{self.log_tail}' 124 | 125 | 126 | class RunningStats(object): 127 | """A class to record all the running stats. 128 | 129 | Basically, this class contains a dictionary of SingleStats. 130 | 131 | Example: 132 | 133 | running_stats = RunningStats() 134 | running_stats.add('loss', log_format='.3f', log_strategy='AVERAGE') 135 | running_stats.add('time', log_format='time', log_name='Iter Time', 136 | log_strategy='CURRENT') 137 | running_stats.log_order = ['time', 'loss'] 138 | running_stats.update({'loss': 0.46, 'time': 12}) 139 | running_stats.update({'time': 14.5, 'loss': 0.33}) 140 | print(running_stats) 141 | """ 142 | 143 | def __init__(self, log_delimiter=', '): 144 | """Initializes the running stats with the log delimiter. 145 | 146 | Args: 147 | log_delimiter: This delimiter is used to connect the log messages 148 | from different stats. (default: `, `) 149 | """ 150 | self._log_delimiter = log_delimiter 151 | self.stats_pool = dict() # The stats pool. 152 | self.log_order = None # Order of the stats to log. 153 | 154 | @property 155 | def log_delimiter(self): 156 | """Gets the log delimiter between different stats.""" 157 | return self._log_delimiter 158 | 159 | def add(self, name, **kwargs): 160 | """Adds a new SingleStats to the dictionary. 161 | 162 | Additional arguments include: 163 | 164 | log_format: The numerical log format. Use `time` to log time duration. 
165 | (default: `.3f`) 166 | log_name: The name shown in the log. `None` means to directly use the 167 | stats name. (default: None) 168 | log_tail: The tailing log message. (default: None) 169 | log_strategy: Strategy to log this stats. `CURRENT`, `AVERAGE`, and 170 | `SUM` are supported. (default: `AVERAGE`) 171 | """ 172 | if name in self.stats_pool: 173 | return 174 | self.stats_pool[name] = SingleStats(name, **kwargs) 175 | 176 | def clear(self, exclude_list=None): 177 | """Clears the stats data (if needed). 178 | 179 | Args: 180 | exclude_list: A list of stats names whose data will not be cleared. 181 | """ 182 | exclude_list = set(exclude_list or []) 183 | for name, stats in self.stats_pool.items(): 184 | if name not in exclude_list: 185 | stats.clear() 186 | 187 | def update(self, kwargs): 188 | """Updates the stats data by name.""" 189 | for name, value in kwargs.items(): 190 | if name not in self.stats_pool: 191 | self.add(name) 192 | self.stats_pool[name].update(value) 193 | 194 | def __getattr__(self, name): 195 | """Gets a particular SingleStats by name.""" 196 | if name in self.stats_pool: 197 | return self.stats_pool[name] 198 | if name in self.__dict__: 199 | return self.__dict__[name] 200 | raise AttributeError(f'`{self.__class__.__name__}` object has no ' 201 | f'attribute `{name}`!') 202 | 203 | def __str__(self): 204 | """Gets log message.""" 205 | self.log_order = self.log_order or list(self.stats_pool) 206 | log_strings = [str(self.stats_pool[name]) for name in self.log_order] 207 | return self.log_delimiter.join(log_strings) 208 | -------------------------------------------------------------------------------- /mobilenetv2_4finetune.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from functools import partial 5 | from typing import Any, Callable, List, Optional 6 | from torchvision.utils import _log_api_usage_once 7 | from torch import nn, Tensor 8 | from torchvision.ops.misc import Conv2dNormActivation 9 | 10 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: 11 | """ 12 | This function is taken from the original tf repo. 13 | It ensures that all layers have a channel number that is divisible by 8 14 | It can be seen here: 15 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 16 | """ 17 | if min_value is None: 18 | min_value = divisor 19 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 20 | # Make sure that round down does not go down by more than 10%. 
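    # For instance (hypothetical call): _make_divisible(27, 8) first rounds to
    # max(8, int(27 + 4) // 8 * 8) = 24; since 24 < 0.9 * 27, the value is
    # bumped by `divisor` to 32.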
21 | if new_v < 0.9 * v: 22 | new_v += divisor 23 | return new_v 24 | 25 | # necessary for backwards compatibility 26 | class InvertedResidual(nn.Module): 27 | def __init__( 28 | self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None 29 | ) -> None: 30 | super().__init__() 31 | self.stride = stride 32 | if stride not in [1, 2]: 33 | raise ValueError(f"stride should be 1 or 2 instead of {stride}") 34 | 35 | if norm_layer is None: 36 | norm_layer = nn.BatchNorm2d 37 | 38 | hidden_dim = int(round(inp * expand_ratio)) 39 | self.use_res_connect = self.stride == 1 and inp == oup 40 | 41 | layers: List[nn.Module] = [] 42 | if expand_ratio != 1: 43 | # pw 44 | layers.append( 45 | Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) 46 | ) 47 | layers.extend( 48 | [ 49 | # dw 50 | Conv2dNormActivation( 51 | hidden_dim, 52 | hidden_dim, 53 | stride=stride, 54 | groups=hidden_dim, 55 | norm_layer=norm_layer, 56 | activation_layer=nn.ReLU6, 57 | ), 58 | # pw-linear 59 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 60 | norm_layer(oup), 61 | ] 62 | ) 63 | self.conv = nn.Sequential(*layers) 64 | self.out_channels = oup 65 | self._is_cn = stride > 1 66 | 67 | def forward(self, x: Tensor) -> Tensor: 68 | if self.use_res_connect: 69 | return x + self.conv(x) 70 | else: 71 | return self.conv(x) 72 | 73 | class MobileNetV2_E(nn.Module): 74 | def __init__( 75 | self, 76 | num_classes: int = 1000, 77 | width_mult: float = 1.0, 78 | inverted_residual_setting: Optional[List[List[int]]] = None, 79 | round_nearest: int = 8, 80 | block: Optional[Callable[..., nn.Module]] = None, 81 | norm_layer: Optional[Callable[..., nn.Module]] = None, 82 | dropout: float = 0.2, 83 | num_experts: int = 8, 84 | ) -> None: 85 | """ 86 | MobileNet V2 main class 87 | 88 | Args: 89 | num_classes (int): Number of classes 90 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 91 | inverted_residual_setting: Network structure 92 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 93 | Set to 1 to turn off rounding 94 | block: Module specifying inverted residual building block for mobilenet 95 | norm_layer: Module specifying the normalization layer to use 96 | dropout (float): The droupout probability 97 | 98 | """ 99 | super().__init__() 100 | _log_api_usage_once(self) 101 | 102 | self.num_experts = num_experts 103 | 104 | self.last_new_layer = None 105 | 106 | if block is None: 107 | block = InvertedResidual 108 | 109 | if norm_layer is None: 110 | norm_layer = nn.BatchNorm2d 111 | 112 | input_channel = 32 113 | last_channel = 1280 114 | 115 | if inverted_residual_setting is None: 116 | inverted_residual_setting = [ 117 | # t, c, n, s 118 | [1, 16, 1, 1], 119 | [6, 24, 2, 2], 120 | [6, 32, 3, 2], 121 | [6, 64, 4, 2], 122 | [6, 96, 3, 1], 123 | [6, 160, 3, 2], 124 | [6, 320, 1, 1], 125 | ] 126 | 127 | # only check the first element, assuming user knows t,c,n,s are required 128 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 129 | raise ValueError( 130 | f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}" 131 | ) 132 | 133 | # building first layer 134 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 135 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 136 | features: List[nn.Module] = [ 
137 | Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) 138 | ] 139 | # building inverted residual blocks 140 | for t, c, n, s in inverted_residual_setting: 141 | output_channel = _make_divisible(c * width_mult, round_nearest) 142 | for i in range(n): 143 | stride = s if i == 0 else 1 144 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 145 | input_channel = output_channel 146 | # building last several layers 147 | features.append( 148 | Conv2dNormActivation( 149 | input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6 150 | ) 151 | ) 152 | # make it nn.Sequential 153 | self.features = nn.Sequential(*features) 154 | 155 | # building classifier 156 | self.classifier = nn.Sequential( 157 | nn.Dropout(p=dropout), 158 | nn.Linear(self.last_channel, num_classes), 159 | ) 160 | 161 | # weight initialization 162 | for m in self.modules(): 163 | if isinstance(m, nn.Conv2d): 164 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 165 | if m.bias is not None: 166 | nn.init.zeros_(m.bias) 167 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 168 | nn.init.ones_(m.weight) 169 | nn.init.zeros_(m.bias) 170 | elif isinstance(m, nn.Linear): 171 | nn.init.normal_(m.weight, 0, 0.01) 172 | nn.init.zeros_(m.bias) 173 | 174 | def _forward_impl(self, x: Tensor, num_experts: int) -> Tensor: 175 | # This exists since TorchScript doesn't support inheritance, so the superclass method 176 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 177 | # x = self.features(x) 178 | for i in range(len(self.features) - num_experts): 179 | x = self.features[i](x) 180 | 181 | outs = [] 182 | for ind in range(num_experts): 183 | tmp = self.features[-1*(ind+1)](x) 184 | tmp = nn.functional.adaptive_avg_pool2d(tmp, (1, 1)) 185 | tmp = torch.flatten(tmp, 1) 186 | tmp = self.classifier[ind](tmp) 187 | outs.append(tmp) 188 | x = torch.stack(outs, dim=1).mean(dim=1) 189 | return x, outs 190 | 191 | def forward(self, x: Tensor) -> Tensor: 192 | return self._forward_impl(x, self.num_experts) 193 | 194 | if __name__ == '__main__': 195 | model_4finetune = MobileNetV2_E(num_classes = 8631, num_experts = 1) 196 | print(model_4finetune.features[0:3]) 197 | print(len(model_4finetune.features)) --------------------------------------------------------------------------------