├── .gitattributes ├── .gitignore ├── .vscode └── settings.json ├── README.md ├── docs ├── datasets.md ├── path.md └── simple_process.md ├── examples └── standard │ ├── README.md │ ├── adapter_training │ └── c2f.py │ ├── attacks │ ├── brepmi.py │ ├── c2f.py │ ├── gmi.py │ ├── lokt.py │ ├── lomma_gmi.py │ ├── mirror_black.py │ ├── mirror_white.py │ ├── plgmi.py │ ├── ppa.py │ ├── rlbmi.py │ └── vmi.py │ ├── classifier_training │ ├── celeba112.py │ ├── celeba224.py │ ├── celeba64.py │ ├── celeba64_ir152_bido.py │ ├── celeba64_ir152_ls.py │ ├── celeba64_ir152_tl.py │ ├── celeba64_ir152_vib.py │ ├── distill_celeba64_celeba64.py │ └── lokt_ffhq64_dense121_facescrub64_ir152.py │ ├── dataset_preprocess │ ├── afhqdogs256.py │ ├── celeba.py │ ├── ffhq256.py │ ├── ffhq64.py │ ├── lokt_generation.py │ ├── metfaces256.py │ └── plgmi_top_k_selection.py │ └── gan_training │ ├── gmi.py │ ├── kedmi.py │ ├── lokt.py │ └── plgmi.py ├── requirements.txt ├── requirements_ori.txt └── src └── modelinversion ├── __init__.py ├── attack ├── SecretGen │ ├── .DS_Store │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── dataloader.py │ ├── discri.py │ ├── eval_target.py │ ├── facenet.py │ ├── losses.py │ ├── models.py │ ├── requirements.txt │ ├── stage1.py │ ├── stage2.py │ ├── tgt_models │ │ ├── resnet152.py │ │ ├── vgg16.py │ │ └── vit.py │ ├── train_target.py │ └── utils.py ├── VMI │ ├── __init__.py │ └── vmi_attacker.py ├── __init__.py ├── attacker.py ├── attacker_ori.py ├── losses.py └── optimize │ ├── __init__.py │ ├── base.py │ ├── deepinversion.py │ ├── genetic.py │ └── rlb.py ├── configs ├── attack_config.py ├── classifier_config.py └── gan_config.py ├── datasets ├── __init__.py ├── base.py ├── celeba.py ├── facescrub.py ├── ffhq.py ├── generator.py ├── preprocess.py ├── split_files │ ├── private_test.txt │ ├── private_train.txt │ └── public.txt └── utils.py ├── defense ├── BiDO │ ├── __init__.py │ ├── kernel.py │ └── trainer.py ├── DP │ ├── __init__.py │ └── trainer.py ├── LS │ ├── __init__.py │ └── trainer.py ├── README.md ├── TL │ ├── __init__.py │ └── trainer.py ├── Vib │ ├── __init__.py │ └── trainer.py ├── __init__.py ├── base.py ├── distill │ ├── __init__.py │ └── trainer.py └── no_defense │ ├── __init__.py │ └── trainer.py ├── metrics ├── __init__.py ├── base.py ├── fid │ ├── __init__.py │ ├── fid_utils.py │ └── inceptionv3.py ├── psnr │ └── __init__.py └── ssim │ └── __init__.py ├── models ├── README.md ├── __init__.py ├── adapters │ ├── __init__.py │ ├── base.py │ └── c2f.py ├── base.py ├── classifiers │ ├── __init__.py │ ├── base.py │ ├── classifier112.py │ ├── classifier64.py │ ├── classifier_utils.py │ ├── evolve │ │ ├── __init__.py │ │ └── evolve.py │ ├── inception.py │ └── wrappers.py └── gans │ ├── __init__.py │ ├── base.py │ ├── cgan.py │ ├── simple.py │ └── stylegan2ada.py ├── sampler ├── __init__.py ├── base.py ├── flow │ ├── __init__.py │ ├── ais_utils.py │ ├── flow_utils.py │ ├── likelihood_models.py │ ├── model.py │ ├── modules.py │ ├── spectral_norm_adaptive.py │ └── toy_utils.py └── labelonly.py ├── scores ├── __init__.py ├── functional.py ├── imgscore.py └── latentscore.py ├── train ├── __init__.py ├── classifier │ ├── __init__.py │ ├── base.py │ ├── bido.py │ └── distill.py ├── gan.py └── mapping.py └── utils ├── __init__.py ├── accumulator.py ├── batch.py ├── check.py ├── config.py ├── constraint.py ├── hook.py ├── io.py ├── log.py ├── losses.py ├── outputs.py ├── random.py └── torchutil.py /.gitattributes: 
-------------------------------------------------------------------------------- 1 | *.py eol=lf 2 | *.txt eol=lf 3 | *.yaml eol=lf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | .idea/ 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | workspace.xml 81 | # SageMath parsed files 82 | *.sage.py 83 | .idea/workspace.xml 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | cache/ 106 | 107 | *.bak 108 | *.tmp 109 | *.pt 110 | *.pth 111 | *.tar 112 | *.pkl 113 | *.tar.gz 114 | results/* 115 | /results* 116 | # checkpoints/*/* 117 | dataset/*/* 118 | 119 | nohup.txt 120 | 121 | # force into git 122 | !README.md 123 | !/**/*.py 124 | !checkpoints/**/.gitkeep 125 | # !dataset/celeba/README.md 126 | /test.py 127 | /test.ipynb 128 | /test* 129 | .vscode 130 | test*.py 131 | test*.sh 132 | checkpoints_v2/ 133 | /datasets -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.extraPaths": [ 3 | "./src" 4 | ], 5 | "[python]": { 6 | "editor.defaultFormatter": "ms-python.black-formatter", 7 | "editor.formatOnSave": true, 8 | 9 | "editor.formatOnType": true 10 | }, 11 | "black-formatter.args": [ 12 | "-S" 13 | ], 14 | } -------------------------------------------------------------------------------- /docs/datasets.md: -------------------------------------------------------------------------------- 1 | 2 | # Datasets 3 | 4 | Here are the details for preprocessing datasets in 2 steps. We provide the preprocess tools for 5 | + celeba 6 | + facescrub 7 | + ffhq64 8 | + ffhq256 9 | + metfaces256 10 | + afhqdog256 11 | 12 | Note that when using the `celeba64` and `facescrub64` datasets you can directly use the transform `Resize((64,64))` in torchvision on `celeba112` and `facescrub112` datasets respectively. 
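For example, a minimal sketch (assuming a preprocessed `celeba112` dataset and the `CelebA112` loader described in [path.md](./path.md)):

```python
from torchvision.transforms import Compose, Resize, ToTensor
from modelinversion.datasets import CelebA112

# load 112x112 celeba images and downscale them to 64x64 on the fly
dataset = CelebA112(
    '<celeba112_dataset_path>',  # placeholder path to fill in
    output_transform=Compose([ToTensor(), Resize((64, 64), antialias=True)]),
)
```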
13 | 14 | 15 | ## Step 1: Download datasets 16 | 17 | ### CelebA 18 | 19 | Download the CelebA dataset from [here](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). 20 | 21 | The structure of the dataset is as follows: 22 | ``` 23 |  24 | ├── img_align_celeba 25 | ├── identity_CelebA.txt 26 | ├── list_attr_celeba.txt 27 | ├── list_bbox_celeba.txt 28 | ├── list_eval_partition.txt 29 | ├── list_landmarks_align_celeba.txt 30 | └── list_landmarks_celeba.txt 31 | ``` 32 | 33 | For `celeba` with low resolution, you can directly use the downloaded files above for step 2. 34 | 35 | For `celeba` with high resolution (e.g. $224\times 224$), you need to follow [HD-CelebA-Cropper](https://github.com/LynnHo/HD-CelebA-Cropper) to increase the resolution of the cropped and aligned samples. Run the cropper script and replace all the images in `img_align_celeba`. 36 | ```sh 37 | python align.py --crop_size_h 224 --crop_size_w 224 --order 3 --save_format png --face_factor 0.65 --n_worker 32 38 | ``` 39 | 40 | ### FaceScrub 41 | 42 | Use [this script](https://github.com/faceteam/facescrub) to download FaceScrub. Note that some download links are unavailable. 43 | 44 | The structure of the dataset is as follows: 45 | ``` 46 |  47 | ├── actors 48 | │ └── faces 49 | └── actresses 50 | └── faces 51 | ``` 52 | 53 | ### FFHQ 54 | 55 | For `ffhq64`, download [thumbnails128x128](https://drive.google.com/drive/folders/1tg-Ur7d4vk1T8Bn0pPpUSQPxlPGBlGfv). 56 | 57 | For `ffhq256`, download [images1024x1024](https://drive.google.com/drive/folders/1tZUcXDBeOibC6jcMCtgRRz67pzrAHeHL). 58 | 59 | ### MetFaces 60 | 61 | Download [here](https://drive.google.com/drive/folders/1iChdwdW7mZFUyivKtDwL8ehCNhYKQz6D). 62 | 63 | ### afhqdog 64 | 65 | Follow [StyleGAN2-ada](https://github.com/NVlabs/stylegan2-ada-pytorch) to download the afhqdog dataset. 66 | ```sh 67 | python dataset_tool.py --source=~/downloads/afhq/train/dog --dest=~/datasets/afhqdog.zip 68 | ``` 69 | 70 | ## Step 2: Preprocess data 71 | 72 | Fill in the paths in the corresponding scripts in [examples/standard/dataset_preprocess](../examples/standard/dataset_preprocess) and run the scripts. Note that **the FaceScrub dataset does not need to be preprocessed**. The parameters are as follows: 73 | + src_path: The path of the dataset you downloaded. 74 | + dst_path: The path for the preprocessed dataset. 75 | + split_file_path: Only `celeba` needs this parameter. We provide split files to split the dataset into train and test subsets for `celeba`. The split files are available [here](https://drive.google.com/drive/folders/13jGV8bsQnxZRMPSVOLzu3OVGWyQf5kpI). Note that you need to unzip the file. 76 | 77 | The file structure of the split files for `celeba` is as follows: 78 | ``` 79 | split_files/ 80 | ├── private_test.txt 81 | ├── private_train.txt 82 | └── public.txt 83 | ``` -------------------------------------------------------------------------------- /docs/path.md: -------------------------------------------------------------------------------- 1 | # Path Tutorial 2 | 3 | ## Experiment Folder 4 | 5 | The experiment folder is the folder where the experiment results are saved. 6 | 7 | ## Dataset 8 | 9 | ### CelebA 10 | 11 | To use the CelebA dataset, please follow [datasets.md](./datasets.md) to download and preprocess the dataset.
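Preprocessing boils down to a single call; below is a minimal sketch based on [dataset_preprocess/celeba.py](../examples/standard/dataset_preprocess/celeba.py), with placeholder paths to fill in:

```python
from modelinversion.datasets.preprocess import preprocess_celeba

src_path = '<celeba_download_path>'      # folder containing img_align_celeba etc.
dst_path = '<celeba_preprocess_dir>'     # where public/private_train/private_test are written
split_files_path = '<split_files_path>'  # unzipped split files, see datasets.md

preprocess_celeba(src_path, dst_path, split_files_path, mode='copy')
```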
12 | 13 | The preprocessed CelebA dataset is organized into the following structure: 14 | ``` 15 | 16 | ├── public 17 | ├── private_train 18 | └── private_test 19 | ``` 20 | 21 | To access the dataset, the `dataset_path` format should follow `<dataset_root>/<subset>`, where `<subset>` corresponds to one of the available subsets: public, private_train, or private_test. 22 | 23 | You can load the dataset using the following code: 24 | 25 | ```python 26 | from modelinversion.datasets import CelebA64 27 | from torchvision.transforms import ToTensor 28 | 29 | dataset_path = "<dataset_root>/<subset>" 30 | 31 | dataset = CelebA64(dataset_path, output_transform=ToTensor()) 32 | ``` 33 | 34 | In this example, CelebA64 refers to the CelebA dataset with a resolution of $64\times 64$ pixels. We also provide datasets at other resolutions, including CelebA112, CelebA224 and CelebA299. 35 | 36 | ### FaceScrub 37 | 38 | The downloaded FaceScrub dataset is organized into the following structure: 39 | ``` 40 | 41 | ├── actors 42 | └── actresses 43 | ``` 44 | 45 | The `dataset_path` corresponds to the `facescrub_download_path` where the dataset is downloaded. 46 | 47 | You can load the dataset using the following code: 48 | 49 | ```python 50 | from modelinversion.datasets import FaceScrub64 51 | from torchvision.transforms import ToTensor 52 | 53 | dataset_path = "" 54 | 55 | train_dataset = FaceScrub64(dataset_path, train=True, output_transform=ToTensor()) 56 | test_dataset = FaceScrub64(dataset_path, train=False, output_transform=ToTensor()) 57 | ``` 58 | 59 | In this example, FaceScrub64 refers to the FaceScrub dataset with a resolution of $64\times 64$ pixels. We also provide datasets at other resolutions, including FaceScrub112, FaceScrub224 and FaceScrub299. 60 | 61 | ### Labeled Datasets 62 | 63 | Datasets generated by [top_k_selection.py](../examples/standard/dataset_preprocess/plgmi_top_k_selection.py) or [lokt_generation.py](../examples/standard/dataset_preprocess/lokt_generation.py) are organized into the following structure: 64 | ``` 65 | 66 | ├── 0 67 | ├── 1 68 | ├── 2 69 | ├── 3 70 | ├── ... 71 | ``` 72 | 73 | You can load the dataset using the following code: 74 | ```python 75 | from modelinversion.datasets import LabelImageFolder, CelebA64 76 | from torchvision.transforms import ToTensor 77 | 78 | # if the dataset is extracted from the original celeba dataset, use the CelebA64 / CelebA112 series 79 | dataset = CelebA64(dataset_path, output_transform=ToTensor()) 80 | 81 | # otherwise, use LabelImageFolder 82 | dataset = LabelImageFolder(dataset_path, transform=ToTensor()) 83 | ``` 84 | 85 | ### Public Datasets 86 | 87 | Public datasets include FFHQ64, FFHQ256 and MetFaces256. To load these datasets, simply use `ImageFolder` from `torchvision.datasets`. 88 | 89 | ```python 90 | from torchvision.datasets import ImageFolder 91 | from torchvision.transforms import ToTensor 92 | 93 | dataset_path = "" 94 | 95 | dataset = ImageFolder(dataset_path, transform=ToTensor()) 96 | ``` 97 | 98 | ## Checkpoint Path 99 | 100 | 101 | The format of the model checkpoint path should be `<checkpoint_dir>/<model_name>.pth`.
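The checkpoints used by the example scripts store the weights under a `state_dict` key, so a model can also be restored manually. A minimal sketch, assuming an `IR152_64` classifier checkpoint loaded the same way as in [attacks/brepmi.py](../examples/standard/attacks/brepmi.py):

```python
import torch
from modelinversion.models import IR152_64

# build the architecture first, then load the saved weights
model = IR152_64(num_classes=1000)
state_dict = torch.load('<target_model_ckpt_path>', map_location='cpu')['state_dict']
model.load_state_dict(state_dict)
model.eval()
```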
102 | 103 | You can also load the models using the following code: 104 | 105 | ```python 106 | from modelinversion.models import ( 107 | auto_classifier_from_pretrained, 108 | auto_generator_from_pretrained, 109 | auto_discriminator_from_pretrained, 110 | ) 111 | 112 | target_model_ckpt_path = "" 113 | eval_model_ckpt_path = "" 114 | generator_ckpt_path = "" 115 | discriminator_ckpt_path = "" 116 | 117 | target_model = auto_classifier_from_pretrained(target_model_ckpt_path) 118 | eval_model = auto_classifier_from_pretrained(eval_model_ckpt_path) 119 | generator = auto_generator_from_pretrained(generator_ckpt_path) 120 | discriminator = auto_discriminator_from_pretrained(discriminator_ckpt_path) 121 | 122 | ``` 123 | -------------------------------------------------------------------------------- /docs/simple_process.md: -------------------------------------------------------------------------------- 1 | # Simple Process 2 | 3 | Here we take GMI as an example to introduce the whole process. 4 | 5 | ## Data Preparation 6 | 7 | Follow [datasets.md](./datasets.md) to prepare the CelebA dataset. 8 | 9 | The structure of the dataset should be like this: 10 | ``` 11 | 12 | ├── public 13 | ├── private_train 14 | └── private_test 15 | ``` 16 | 17 | ## Classifier Training 18 | 19 | Here we train IR152 as the target model and FaceNet112 as the eval model. 20 | 21 | ### IR152 22 | 23 | To train the IR152 model with [classifier_training/celeba64.py](../examples/standard/classifier_training/celeba64.py) as an example, you can fill in the paths in the script: 24 | ```python 25 | save_name = f'{model_name}.pth' 26 | train_dataset_path = '<dataset_dir>/private_train' 27 | test_dataset_path = '<dataset_dir>/private_test' 28 | experiment_dir = '' 29 | backbone_path = '' 30 | ``` 31 | 32 | The `backbone_path` is the path to the pre-trained IR152 backbone model. You can download it from [Google Drive](https://drive.google.com/file/d/1qz6Z6X7Q1j7Q6j0VY9Zj1X0Zj1X0Zj1X/view?usp=sharing)/checkpoints_v2.0/classifier/backbones/Backbone_IR_152_Epoch_112_Batch_2547328_Time_2019-07-13-02-59_checkpoint.pth. 33 | 34 | The model will be saved in `<experiment_dir>/<save_name>`, denoted as `<target_model_ckpt_path>` in the following text. 35 | 36 | ### FaceNet112 37 | 38 | To train the FaceNet112 model with [classifier_training/celeba112.py](../examples/standard/classifier_training/celeba112.py) as an example, you can fill in the paths in the script: 39 | ```python 40 | save_name = f'{model_name}.pth' 41 | train_dataset_path = '<dataset_dir>/private_train' 42 | test_dataset_path = '<dataset_dir>/private_test' 43 | experiment_dir = '' 44 | backbone_path = '' 45 | ``` 46 | 47 | The `backbone_path` is the path to the pre-trained FaceNet112 backbone model. You can download it from [Google Drive](https://drive.google.com/file/d/1qz6Z6X7Q1j7Q6j0VY9Zj1X0Zj1X0Zj1X/view?usp=sharing)/checkpoints_v2.0/classifier/backbones/backbone_ir50_ms1m_epoch120.pth. 48 | 49 | 50 | The model will be saved in `<experiment_dir>/<save_name>`, denoted as `<eval_model_ckpt_path>` in the following text. 51 | 52 | ## GMI GAN training 53 | 54 | To train the GMI GAN with [gan_training/gmi.py](../examples/standard/gan_training/gmi.py) as an example, you can fill in the paths in the script: 55 | ```python 56 | dataset_path = '<dataset_dir>/public' 57 | experiment_dir = '' 58 | ``` 59 | 60 | The generator and discriminator will be saved in `<experiment_dir>/G.pth` and `<experiment_dir>/D.pth`, denoted as `<generator_ckpt_path>` and `<discriminator_ckpt_path>` in the following text. 61 | 62 | 63 | ## GMI Attack 64 | 65 | The attack script is [attacks/gmi.py](../examples/standard/attacks/gmi.py).
You can fill in the paths in the script: 66 | ```python 67 | experiment_dir = '' 68 | device_ids_available = '0' 69 | num_classes = 1000 70 | generator_ckpt_path = '' 71 | discriminator_ckpt_path = '' 72 | target_model_ckpt_path = '' 73 | eval_model_ckpt_path = '' 74 | eval_dataset_path = '<dataset_dir>/private_train' 75 | ``` 76 | 77 | The attack results will be saved in `<experiment_dir>`. 78 | 79 | Evaluation results are shown in `<experiment_dir>/optimized/evaluation.csv`. -------------------------------------------------------------------------------- /examples/standard/README.md: -------------------------------------------------------------------------------- 1 | 2 | # examples-standard 3 | 4 | Here are some standard examples of how to use this toolbox. 5 | 6 | ## adapter_training 7 | 8 | 1. train the adapter (mapping) model for C2FMI: 9 | ```bash 10 | cd adapter_training 11 | python c2f.py 12 | ``` 13 | 14 | 2. train the adapter (mapping) model for C2FMI on high-resolution images: 15 | ```bash 16 | cd adapter_training 17 | python c2f_high.py 18 | ``` 19 | 20 | ## attacks 21 | 22 | 1. configure the options in each attack script as instructed, such as in `brepmi.py`: 23 | ```python 24 | experiment_dir = '' 25 | generator_ckpt_path = '' 26 | target_model_ckpt_path = '' 27 | eval_model_ckpt_path = '' 28 | eval_dataset_path = '' 29 | ``` 30 | 31 | 2. run the attack methods of BREPMI, C2FMI, GMI, LOKT, LOMMA, MIRROR, PLGMI, PPA, RLBMI, VMI: 32 | ```bash 33 | cd attacks 34 | python <script_name>.py 35 | ``` 36 | 37 | ## classifier_training 38 | 39 | 1. configure the options in each training script as instructed, such as in `celeba64.py`: 40 | ```python 41 | train_dataset_path = '' 42 | test_dataset_path = '' 43 | experiment_dir = '' 44 | backbone_path = '' 45 | ``` 46 | 47 | 2. run the training scripts of classifiers under various resolutions: 48 | ```bash 49 | cd classifier_training 50 | python <script_name>.py 51 | ``` 52 | 53 | ## dataset_preprocess 54 | 55 | 1. configure the options in each preprocess script as instructed, such as in `afhqdogs256.py`: 56 | ```python 57 | src_path = '' 58 | dst_path = '' 59 | ``` 60 | 61 | 2. run the preprocess scripts of datasets under various resolutions: 62 | ```bash 63 | cd dataset_preprocess 64 | python <script_name>.py 65 | ``` 66 | 67 | ## gan_training 68 | 69 | 1. configure the options in each GAN training script as instructed, such as in `gmi.py`: 70 | ```python 71 | dataset_path = '' 72 | experiment_dir = '' 73 | ``` 74 | 75 | 2.
run the GAN training scripts: 76 | ```bash 77 | cd gan_training 78 | python <script_name>.py 79 | ``` 80 | -------------------------------------------------------------------------------- /examples/standard/adapter_training/c2f.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append("../../../src") 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision.transforms import ToTensor, Compose, Resize 11 | 12 | from modelinversion.models import ( 13 | SimpleGenerator64, 14 | GmiDiscriminator64, 15 | auto_classifier_from_pretrained, 16 | ) 17 | from modelinversion.models.adapters.c2f import C2fThreeLayerMlpOutputMapping 18 | from modelinversion.train import GmiGanTrainer, GmiGanTrainConfig, train_mapping_model 19 | from modelinversion.utils import Logger 20 | from modelinversion.datasets import InfiniteSamplerWrapper, CelebA64 21 | 22 | if __name__ == '__main__': 23 | 24 | target_model_ckpt_path = '' 25 | embed_model_ckpt_path = '/casia_incv1.pth' 26 | dataset_path = '' 27 | 28 | dataset_map_name = 'ffhq64_facescrub64' 29 | target_name = 'ir152' 30 | experiment_dir = f'../../../results_mapping/c2f/{dataset_map_name}/{target_name}' 31 | 32 | batch_size = 256 33 | 34 | device_ids_str = '3' 35 | 36 | # prepare logger 37 | 38 | now_time = time.strftime(r'%Y%m%d_%H%M', time.localtime(time.time())) 39 | logger = Logger(experiment_dir, f'train_gan_{now_time}.log') 40 | 41 | # prepare devices 42 | 43 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_str 44 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 45 | device = torch.device(device) 46 | gpu_devices = [i for i in range(torch.cuda.device_count())] 47 | 48 | # prepare target models 49 | 50 | target_model = auto_classifier_from_pretrained(target_model_ckpt_path) 51 | target_model = nn.DataParallel(target_model, device_ids=gpu_devices).to(device) 52 | target_model.eval() 53 | 54 | embed_model = auto_classifier_from_pretrained(embed_model_ckpt_path) 55 | embed_model = nn.DataParallel(embed_model, device_ids=gpu_devices).to(device) 56 | embed_model.eval() 57 | # print(target_model.training) 58 | # exit() 59 | 60 | # prepare dataset 61 | 62 | from torchvision.datasets import ImageFolder 63 | 64 | dataset = ImageFolder( 65 | dataset_path, 66 | transform=ToTensor(), 67 | ) 68 | # dataset = CelebA64(dataset_path, ToTensor()) 69 | dataloader = DataLoader( 70 | dataset, 71 | batch_size=batch_size, 72 | shuffle=True, 73 | # sampler=InfiniteSamplerWrapper(dataset), 74 | ) 75 | 76 | mapping = C2fThreeLayerMlpOutputMapping( 77 | target_model.module.num_classes, 4096, embed_model.module.num_classes 78 | ) 79 | mapping = nn.DataParallel(mapping).to(device) 80 | mapping.train() 81 | 82 | optimizer = torch.optim.Adam(mapping.parameters(), lr=0.001) 83 | optim_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.8) 84 | 85 | train_mapping_model( 86 | 40, 87 | mapping, 88 | optimizer, 89 | target_model, 90 | embed_model, 91 | dataloader, 92 | device=device, 93 | save_path=os.path.join(experiment_dir, 'mapping.pth'), 94 | schedular=optim_scheduler, 95 | ) 96 | -------------------------------------------------------------------------------- /examples/standard/attacks/brepmi.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | import time 5 | 6 | sys.path.append('../../../src') 7 | 8
| import torch 9 | from torch import nn 10 | from torchvision.transforms import ToTensor, Compose, Resize 11 | 12 | from modelinversion.models import ( 13 | SimpleGenerator64, 14 | IR152_64, 15 | FaceNet112, 16 | ) 17 | from modelinversion.sampler import LabelOnlySelectLatentsSampler 18 | from modelinversion.utils import Logger 19 | from modelinversion.attack import ( 20 | BrepOptimizationConfig, 21 | BrepOptimization, 22 | ImageClassifierAttackConfig, 23 | ImageClassifierAttacker, 24 | ) 25 | from modelinversion.scores import ImageClassificationAugmentLabelOnlyScore 26 | from modelinversion.metrics import ( 27 | ImageClassifierAttackAccuracy, 28 | ImageDistanceMetric, 29 | ImageFidPRDCMetric, 30 | ) 31 | from modelinversion.datasets import CelebA112 32 | 33 | if __name__ == '__main__': 34 | 35 | experiment_dir = '' 36 | device_ids_available = '2' 37 | num_classes = 1000 38 | generator_ckpt_path = '' 39 | target_model_ckpt_path = '' 40 | eval_model_ckpt_path = '' 41 | eval_dataset_path = '' 42 | attack_targets = list(range(10)) 43 | 44 | batch_size = 100 45 | 46 | # prepare logger 47 | 48 | now_time = time.strftime(r'%Y%m%d_%H%M', time.localtime(time.time())) 49 | logger = Logger(experiment_dir, f'attack_{now_time}.log') 50 | 51 | # prepare devices 52 | 53 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_available 54 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 55 | device = torch.device(device) 56 | gpu_devices = [i for i in range(torch.cuda.device_count())] 57 | 58 | # prepare models 59 | 60 | z_dim = 100 61 | 62 | target_model = IR152_64(num_classes=num_classes) 63 | eval_model = FaceNet112(num_classes, register_last_feature_hook=True) 64 | generator = SimpleGenerator64(in_dim=z_dim) 65 | # discriminator = KedmiDiscriminator64(num_classes=num_classes) 66 | 67 | target_model.load_state_dict( 68 | torch.load(target_model_ckpt_path, map_location='cpu')['state_dict'] 69 | ) 70 | eval_model.load_state_dict( 71 | torch.load(eval_model_ckpt_path, map_location='cpu')['state_dict'] 72 | ) 73 | generator.load_state_dict( 74 | torch.load(generator_ckpt_path, map_location='cpu')['state_dict'] 75 | ) 76 | 77 | target_model = nn.DataParallel(target_model, device_ids=gpu_devices).to(device) 78 | eval_model = nn.DataParallel(eval_model, device_ids=gpu_devices).to(device) 79 | generator = nn.DataParallel(generator, device_ids=gpu_devices).to(device) 80 | 81 | target_model.eval() 82 | eval_model.eval() 83 | generator.eval() 84 | 85 | latents_sampler = LabelOnlySelectLatentsSampler( 86 | z_dim, batch_size, generator, target_model, device=device 87 | ) 88 | 89 | # prepare eval dataset 90 | 91 | eval_dataset = CelebA112( 92 | eval_dataset_path, 93 | output_transform=ToTensor(), 94 | ) 95 | 96 | # prepare optimization 97 | 98 | optimization_config = BrepOptimizationConfig( 99 | experiment_dir=experiment_dir, device=device, iter_times=1000 100 | ) 101 | 102 | image_score_fn = ImageClassificationAugmentLabelOnlyScore( 103 | classifier=target_model, device=device, correct_score=1, wrong_score=-1 104 | ) 105 | 106 | optimization_fn = BrepOptimization( 107 | config=optimization_config, generator=generator, image_score_fn=image_score_fn 108 | ) 109 | 110 | # prepare metrics 111 | 112 | accuracy_metric = ImageClassifierAttackAccuracy( 113 | batch_size, eval_model, device=device, description='evaluation' 114 | ) 115 | 116 | distance_metric = ImageDistanceMetric( 117 | batch_size, 118 | eval_model, 119 | eval_dataset, 120 | device=device, 121 | description='evaluation', 122 | 
save_individual_res_dir=experiment_dir, 123 | ) 124 | 125 | fid_prdc_metric = ImageFidPRDCMetric( 126 | batch_size, 127 | eval_dataset, 128 | device=device, 129 | save_individual_prdc_dir=experiment_dir, 130 | fid=True, 131 | prdc=True, 132 | ) 133 | 134 | # prepare attack 135 | 136 | attack_config = ImageClassifierAttackConfig( 137 | latents_sampler, 138 | optimize_num=50, 139 | optimize_batch_size=batch_size, 140 | optimize_fn=optimization_fn, 141 | save_dir=experiment_dir, 142 | save_optimized_images=True, 143 | save_final_images=False, 144 | eval_metrics=[accuracy_metric, distance_metric, fid_prdc_metric], 145 | eval_optimized_result=True, 146 | eval_final_result=False, 147 | ) 148 | 149 | attacker = ImageClassifierAttacker(attack_config) 150 | 151 | attacker.attack(attack_targets) 152 | -------------------------------------------------------------------------------- /examples/standard/classifier_training/celeba112.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append('../../../src') 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision.transforms import ( 11 | ToTensor, 12 | Compose, 13 | ColorJitter, 14 | RandomResizedCrop, 15 | RandomHorizontalFlip, 16 | Normalize, 17 | Resize, 18 | ) 19 | 20 | from modelinversion.models import FaceNet112 21 | from modelinversion.train import SimpleTrainer, SimpleTrainConfig 22 | from modelinversion.utils import Logger 23 | from modelinversion.datasets import InfiniteSamplerWrapper, CelebA112 24 | 25 | if __name__ == '__main__': 26 | 27 | # prepare path args 28 | 29 | num_classes = 1000 30 | model_name = 'facenet112' 31 | save_name = f'{model_name}.pth' 32 | train_dataset_path = '' 33 | test_dataset_path = '' 34 | experiment_dir = '' 35 | backbone_path = '' 36 | 37 | batch_size = 128 38 | epoch_num = 100 39 | 40 | device_ids_available = '1' 41 | pin_memory = True 42 | 43 | # prepare logger 44 | 45 | now_time = time.strftime(r'%Y%m%d_%H%M', time.localtime(time.time())) 46 | logger = Logger(experiment_dir, f'train_classifier_{now_time}.log') 47 | 48 | # prepare devices 49 | 50 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_available 51 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 52 | device = torch.device(device) 53 | gpu_devices = [i for i in range(torch.cuda.device_count())] 54 | 55 | # prepare target model 56 | 57 | model = FaceNet112(num_classes, backbone_path=backbone_path) 58 | model = nn.DataParallel(model, device_ids=gpu_devices).to(device) 59 | 60 | optimizer = torch.optim.SGD( 61 | model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4 62 | ) 63 | lr_schedular = None 64 | 65 | # prepare dataset 66 | 67 | train_dataset = CelebA112( 68 | train_dataset_path, 69 | output_transform=Compose( 70 | [ 71 | ToTensor(), 72 | RandomHorizontalFlip(p=0.5), 73 | ] 74 | ), 75 | ) 76 | test_dataset = CelebA112( 77 | test_dataset_path, 78 | output_transform=Compose([ToTensor()]), 79 | ) 80 | 81 | train_loader = DataLoader( 82 | train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory 83 | ) 84 | test_loader = DataLoader( 85 | test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory 86 | ) 87 | 88 | # prepare train config 89 | 90 | config = SimpleTrainConfig( 91 | experiment_dir=experiment_dir, 92 | save_name=save_name, 93 | # train args 94 | device=device, 95 | model=model, 96 | optimizer=optimizer, 97 | lr_scheduler=lr_schedular, 98 | 
loss_fn='cross_entropy', 99 | ) 100 | 101 | trainer = SimpleTrainer(config) 102 | 103 | trainer.train(epoch_num, train_loader, test_loader) 104 | -------------------------------------------------------------------------------- /examples/standard/classifier_training/celeba224.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append('../../../src') 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision.transforms import ( 11 | ToTensor, 12 | Compose, 13 | ColorJitter, 14 | RandomResizedCrop, 15 | RandomHorizontalFlip, 16 | Normalize, 17 | ) 18 | 19 | from modelinversion.models import TorchvisionClassifierModel 20 | from modelinversion.train import SimpleTrainer, SimpleTrainConfig 21 | from modelinversion.utils import Logger 22 | from modelinversion.datasets import InfiniteSamplerWrapper, CelebA224 23 | 24 | if __name__ == '__main__': 25 | 26 | num_classes = 1000 27 | torchvison_model_name = 'resnet152' 28 | save_name = f'{torchvison_model_name}.pth' 29 | train_dataset_path = '../../../test/celeba/private_train' 30 | test_dataset_path = '../../../test/celeba/private_test' 31 | experiment_dir = '../../../test/resnet152' 32 | 33 | batch_size = 128 34 | epoch_num = 100 35 | 36 | device_ids_available = '0' 37 | pin_memory = True 38 | 39 | # prepare logger 40 | 41 | now_time = time.strftime(r'%Y%m%d_%H%M', time.localtime(time.time())) 42 | logger = Logger(experiment_dir, f'train_classifier_{now_time}.log') 43 | 44 | # prepare devices 45 | 46 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_available 47 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 48 | device = torch.device(device) 49 | gpu_devices = [i for i in range(torch.cuda.device_count())] 50 | 51 | # prepare target model 52 | 53 | model = TorchvisionClassifierModel( 54 | arch_name=torchvison_model_name, num_classes=num_classes, weights='DEFAULT' 55 | ) 56 | model = nn.DataParallel(model, device_ids=gpu_devices).to(device) 57 | 58 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=[0.9, 0.999]) 59 | lr_schedular = torch.optim.lr_scheduler.MultiStepLR( 60 | optimizer, milestones=[75, 90], gamma=0.1 61 | ) 62 | 63 | # prepare dataset 64 | 65 | train_dataset = CelebA224( 66 | train_dataset_path, 67 | output_transform=Compose( 68 | [ 69 | ToTensor(), 70 | RandomResizedCrop( 71 | size=(224, 224), scale=(0.85, 1), ratio=(1, 1), antialias=True 72 | ), 73 | ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.1), 74 | RandomHorizontalFlip(p=0.5), 75 | Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 76 | ] 77 | ), 78 | ) 79 | test_dataset = CelebA224( 80 | test_dataset_path, 81 | output_transform=Compose( 82 | [ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 83 | ), 84 | ) 85 | train_loader = DataLoader( 86 | train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory 87 | ) 88 | test_loader = DataLoader( 89 | test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory 90 | ) 91 | 92 | # prepare train config 93 | 94 | config = SimpleTrainConfig( 95 | experiment_dir=experiment_dir, 96 | save_name=save_name, 97 | device=device, 98 | model=model, 99 | optimizer=optimizer, 100 | lr_scheduler=lr_schedular, 101 | loss_fn='cross_entropy', 102 | ) 103 | 104 | trainer = SimpleTrainer(config) 105 | 106 | trainer.train(epoch_num, train_loader, test_loader) 107 | 
-------------------------------------------------------------------------------- /examples/standard/classifier_training/celeba64.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append('../../../src') 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision.transforms import ( 11 | ToTensor, 12 | Compose, 13 | ColorJitter, 14 | RandomResizedCrop, 15 | RandomHorizontalFlip, 16 | Normalize, 17 | Resize, 18 | ) 19 | 20 | from modelinversion.models import IR152_64 21 | from modelinversion.train import SimpleTrainer, SimpleTrainConfig 22 | from modelinversion.utils import Logger 23 | from modelinversion.datasets import InfiniteSamplerWrapper, CelebA64 24 | 25 | if __name__ == '__main__': 26 | 27 | # prepare path args 28 | 29 | num_classes = 1000 30 | model_name = 'ir152' 31 | save_name = f'{model_name}.pth' 32 | train_dataset_path = '' 33 | test_dataset_path = '' 34 | experiment_dir = '' 35 | backbone_path = '' 36 | 37 | batch_size = 128 38 | epoch_num = 100 39 | 40 | device_ids_available = '1' 41 | pin_memory = True 42 | 43 | # prepare logger 44 | 45 | now_time = time.strftime(r'%Y%m%d_%H%M', time.localtime(time.time())) 46 | logger = Logger(experiment_dir, f'train_classifier_{now_time}.log') 47 | 48 | # prepare devices 49 | 50 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_available 51 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 52 | device = torch.device(device) 53 | gpu_devices = [i for i in range(torch.cuda.device_count())] 54 | 55 | # prepare target model 56 | 57 | model = IR152_64(num_classes, backbone_path=backbone_path) 58 | model = nn.DataParallel(model, device_ids=gpu_devices).to(device) 59 | 60 | optimizer = torch.optim.SGD( 61 | model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4 62 | ) 63 | lr_schedular = None 64 | 65 | # prepare dataset 66 | 67 | train_dataset = CelebA64( 68 | train_dataset_path, 69 | output_transform=Compose( 70 | [ 71 | ToTensor(), 72 | RandomHorizontalFlip(p=0.5), 73 | ] 74 | ), 75 | ) 76 | test_dataset = CelebA64( 77 | test_dataset_path, 78 | output_transform=Compose([ToTensor()]), 79 | ) 80 | 81 | train_loader = DataLoader( 82 | train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory 83 | ) 84 | test_loader = DataLoader( 85 | test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory 86 | ) 87 | 88 | # prepare train config 89 | 90 | config = SimpleTrainConfig( 91 | experiment_dir=experiment_dir, 92 | save_name=save_name, 93 | # train args 94 | device=device, 95 | model=model, 96 | optimizer=optimizer, 97 | lr_scheduler=lr_schedular, 98 | loss_fn='cross_entropy', 99 | ) 100 | 101 | trainer = SimpleTrainer(config) 102 | 103 | trainer.train(epoch_num, train_loader, test_loader) 104 | -------------------------------------------------------------------------------- /examples/standard/classifier_training/celeba64_ir152_bido.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append('../../../src') 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision.datasets import ImageFolder 11 | from torchvision.transforms import ( 12 | ToTensor, 13 | Compose, 14 | ColorJitter, 15 | RandomResizedCrop, 16 | RandomHorizontalFlip, 17 | Normalize, 18 | Resize, 19 | ) 20 | 21 | from modelinversion.models import IR152_64, 
BiDOWrapper 22 | from modelinversion.train import BiDOTrainConfig, BiDOTrainer 23 | from modelinversion.utils import Logger 24 | from modelinversion.datasets import CelebA 25 | 26 | if __name__ == '__main__': 27 | 28 | num_classes = 1000 29 | model_name = 'ir152' 30 | save_name = f'celeba64_{model_name}_bido_ih0.05_oh2.pth' 31 | train_dataset_path = '' 32 | test_dataset_path = '' 33 | experiment_dir = '' 34 | backbone_path = '../../../checkpoints_v2/classifier/backbones/Backbone_IR_152_Epoch_112_Batch_2547328_Time_2019-07-13-02-59_checkpoint.pth' 35 | 36 | batch_size = 128 37 | epoch_num = 150 38 | 39 | device_ids_str = '2' 40 | pin_memory = False 41 | 42 | # prepare logger 43 | 44 | now_time = time.strftime(r'%Y%m%d_%H%M', time.localtime(time.time())) 45 | logger = Logger(experiment_dir, f'train_classifier_{now_time}.log') 46 | 47 | # prepare devices 48 | 49 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_str 50 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 51 | device = torch.device(device) 52 | gpu_devices = [i for i in range(torch.cuda.device_count())] 53 | 54 | # prepare target model 55 | 56 | model = IR152_64( 57 | num_classes, backbone_path=backbone_path, register_last_feature_hook=True 58 | ) 59 | model = BiDOWrapper(model) 60 | model = nn.DataParallel(model, device_ids=gpu_devices).to(device) 61 | 62 | optimizer = torch.optim.SGD( 63 | model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4 64 | ) 65 | lr_schedular = torch.optim.lr_scheduler.MultiStepLR( 66 | optimizer, milestones=[75, 100, 125, 140], gamma=0.3 67 | ) 68 | lr_schedular = None 69 | 70 | # prepare dataset 71 | 72 | train_dataset = CelebA( 73 | train_dataset_path, 74 | crop_center=True, 75 | preprocess_resolution=64, 76 | transform=Compose( 77 | [ 78 | ToTensor(), 79 | RandomHorizontalFlip(p=0.5), 80 | ] 81 | ), 82 | ) 83 | test_dataset = CelebA( 84 | test_dataset_path, 85 | crop_center=True, 86 | preprocess_resolution=64, 87 | transform=Compose([ToTensor()]), 88 | ) 89 | 90 | train_loader = DataLoader( 91 | train_dataset, 92 | batch_size=batch_size, 93 | shuffle=True, 94 | pin_memory=pin_memory, 95 | num_workers=4, 96 | ) 97 | test_loader = DataLoader( 98 | test_dataset, 99 | batch_size=batch_size, 100 | shuffle=False, 101 | pin_memory=pin_memory, 102 | num_workers=4, 103 | ) 104 | 105 | # prepare train config 106 | 107 | config = BiDOTrainConfig( 108 | experiment_dir=experiment_dir, 109 | save_name=save_name, 110 | device=device, 111 | model=model, 112 | optimizer=optimizer, 113 | lr_scheduler=lr_schedular, 114 | loss_fn='cross_entropy', 115 | coef_hidden_input=0.05, 116 | coef_hidden_output=2, 117 | ) 118 | 119 | trainer = BiDOTrainer(config) 120 | 121 | trainer.train(epoch_num, train_loader, test_loader) 122 | -------------------------------------------------------------------------------- /examples/standard/classifier_training/celeba64_ir152_ls.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append('../../../src') 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision.datasets import ImageFolder 11 | from torchvision.transforms import ( 12 | ToTensor, 13 | Compose, 14 | ColorJitter, 15 | RandomResizedCrop, 16 | RandomHorizontalFlip, 17 | Normalize, 18 | Resize, 19 | ) 20 | 21 | from modelinversion.models import IR152_64 22 | from modelinversion.train import SimpleTrainer, SimpleTrainConfig 23 | from modelinversion.utils import Logger, 
LabelSmoothingCrossEntropyLoss 24 | from modelinversion.datasets import CelebA 25 | 26 | if __name__ == '__main__': 27 | 28 | num_classes = 1000 29 | model_name = 'ir152' 30 | save_name = f'celeba64_{model_name}_ls_-0.05.pth' 31 | train_dataset_path = '' 32 | test_dataset_path = '' 33 | experiment_dir = '' 34 | backbone_path = '../../../checkpoints_v2/classifier/backbones/Backbone_IR_152_Epoch_112_Batch_2547328_Time_2019-07-13-02-59_checkpoint.pth' 35 | 36 | batch_size = 128 37 | epoch_num = 100 38 | 39 | device_ids_str = '1' 40 | pin_memory = False 41 | 42 | # prepare logger 43 | 44 | now_time = time.strftime(r'%Y%m%d_%H%M', time.localtime(time.time())) 45 | logger = Logger(experiment_dir, f'train_classifier_{now_time}.log') 46 | 47 | # prepare devices 48 | 49 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_str 50 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 51 | device = torch.device(device) 52 | gpu_devices = [i for i in range(torch.cuda.device_count())] 53 | 54 | # prepare target model 55 | 56 | model = IR152_64(num_classes, backbone_path=backbone_path) 57 | model = nn.DataParallel(model, device_ids=gpu_devices).to(device) 58 | 59 | optimizer = torch.optim.SGD( 60 | model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4 61 | ) 62 | lr_schedular = torch.optim.lr_scheduler.MultiStepLR( 63 | optimizer, milestones=[75, 90], gamma=0.1 64 | ) 65 | lr_schedular = None 66 | 67 | # prepare dataset 68 | 69 | train_dataset = CelebA( 70 | train_dataset_path, 71 | crop_center=True, 72 | preprocess_resolution=64, 73 | transform=Compose( 74 | [ 75 | ToTensor(), 76 | RandomHorizontalFlip(p=0.5), 77 | ] 78 | ), 79 | ) 80 | test_dataset = CelebA( 81 | test_dataset_path, 82 | crop_center=True, 83 | preprocess_resolution=64, 84 | transform=Compose([ToTensor()]), 85 | ) 86 | 87 | train_loader = DataLoader( 88 | train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory 89 | ) 90 | test_loader = DataLoader( 91 | test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory 92 | ) 93 | 94 | # prepare train config 95 | 96 | config = SimpleTrainConfig( 97 | experiment_dir=experiment_dir, 98 | save_name=save_name, 99 | device=device, 100 | model=model, 101 | optimizer=optimizer, 102 | lr_scheduler=lr_schedular, 103 | loss_fn=LabelSmoothingCrossEntropyLoss(-0.05), 104 | ) 105 | 106 | trainer = SimpleTrainer(config) 107 | 108 | trainer.train(epoch_num, train_loader, test_loader) 109 | -------------------------------------------------------------------------------- /examples/standard/classifier_training/celeba64_ir152_tl.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append('../../../src') 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision.datasets import ImageFolder 11 | from torchvision.transforms import ( 12 | ToTensor, 13 | Compose, 14 | ColorJitter, 15 | RandomResizedCrop, 16 | RandomHorizontalFlip, 17 | Normalize, 18 | Resize, 19 | ) 20 | 21 | from modelinversion.models import IR152_64 22 | from modelinversion.train import SimpleTrainer, SimpleTrainConfig 23 | from modelinversion.utils import Logger, freeze_front_layers 24 | from modelinversion.datasets import CelebA 25 | 26 | if __name__ == '__main__': 27 | 28 | num_classes = 1000 29 | model_name = 'ir152' 30 | save_name = f'celeba64_{model_name}_tl.pth' 31 | train_dataset_path = '' 32 | test_dataset_path = '' 33 | experiment_dir = '' 34 | 
backbone_path = '../../../checkpoints_v2/classifier/backbones/Backbone_IR_152_Epoch_112_Batch_2547328_Time_2019-07-13-02-59_checkpoint.pth' 35 | 36 | batch_size = 128 37 | epoch_num = 100 38 | 39 | device_ids_str = '1' 40 | pin_memory = False 41 | 42 | # prepare logger 43 | 44 | now_time = time.strftime(r'%Y%m%d_%H%M', time.localtime(time.time())) 45 | logger = Logger(experiment_dir, f'train_classifier_{now_time}.log') 46 | 47 | # prepare devices 48 | 49 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_str 50 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 51 | device = torch.device(device) 52 | gpu_devices = [i for i in range(torch.cuda.device_count())] 53 | 54 | # prepare target model 55 | 56 | model = IR152_64(num_classes, backbone_path=backbone_path) 57 | model = nn.DataParallel(model, device_ids=gpu_devices).to(device) 58 | freeze_front_layers(model) 59 | 60 | optimizer = torch.optim.SGD( 61 | model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4 62 | ) 63 | lr_schedular = torch.optim.lr_scheduler.MultiStepLR( 64 | optimizer, milestones=[75, 90], gamma=0.1 65 | ) 66 | lr_schedular = None 67 | 68 | # prepare dataset 69 | 70 | train_dataset = CelebA( 71 | train_dataset_path, 72 | crop_center=True, 73 | preprocess_resolution=64, 74 | transform=Compose( 75 | [ 76 | ToTensor(), 77 | RandomHorizontalFlip(p=0.5), 78 | ] 79 | ), 80 | ) 81 | test_dataset = CelebA( 82 | test_dataset_path, 83 | crop_center=True, 84 | preprocess_resolution=64, 85 | transform=Compose([ToTensor()]), 86 | ) 87 | 88 | train_loader = DataLoader( 89 | train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory 90 | ) 91 | test_loader = DataLoader( 92 | test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory 93 | ) 94 | 95 | # prepare train config 96 | 97 | config = SimpleTrainConfig( 98 | experiment_dir=experiment_dir, 99 | save_name=save_name, 100 | device=device, 101 | model=model, 102 | optimizer=optimizer, 103 | lr_scheduler=lr_schedular, 104 | loss_fn='cross_entropy', 105 | ) 106 | 107 | trainer = SimpleTrainer(config) 108 | 109 | trainer.train(epoch_num, train_loader, test_loader) 110 | -------------------------------------------------------------------------------- /examples/standard/classifier_training/celeba64_ir152_vib.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append('../../../src') 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision.datasets import ImageFolder 11 | from torchvision.transforms import ( 12 | ToTensor, 13 | Compose, 14 | ColorJitter, 15 | RandomResizedCrop, 16 | RandomHorizontalFlip, 17 | Normalize, 18 | Resize, 19 | ) 20 | 21 | from modelinversion.models import IR152_64, VibWrapper 22 | from modelinversion.train import VibTrainConfig, VibTrainer 23 | from modelinversion.utils import Logger 24 | from modelinversion.datasets import CelebA 25 | 26 | if __name__ == '__main__': 27 | 28 | num_classes = 1000 29 | model_name = 'ir152' 30 | save_name = f'celeba64_{model_name}_vib_-0.01.pth' 31 | train_dataset_path = '' 32 | test_dataset_path = '' 33 | experiment_dir = '' 34 | backbone_path = '../../../checkpoints_v2/classifier/backbones/Backbone_IR_152_Epoch_112_Batch_2547328_Time_2019-07-13-02-59_checkpoint.pth' 35 | 36 | batch_size = 128 37 | epoch_num = 100 38 | 39 | device_ids_str = '1' 40 | pin_memory = False 41 | 42 | # prepare logger 43 | 44 | now_time = time.strftime(r'%Y%m%d_%H%M', 
time.localtime(time.time())) 45 | logger = Logger(experiment_dir, f'train_classifier_{now_time}.log') 46 | 47 | # prepare devices 48 | 49 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_str 50 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 51 | device = torch.device(device) 52 | gpu_devices = [i for i in range(torch.cuda.device_count())] 53 | 54 | # prepare target model 55 | 56 | model = IR152_64( 57 | num_classes, backbone_path=backbone_path, register_last_feature_hook=True 58 | ) 59 | model = VibWrapper(model) 60 | model = nn.DataParallel(model, device_ids=gpu_devices).to(device) 61 | 62 | optimizer = torch.optim.SGD( 63 | model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4 64 | ) 65 | lr_schedular = torch.optim.lr_scheduler.MultiStepLR( 66 | optimizer, milestones=[75, 90], gamma=0.1 67 | ) 68 | lr_schedular = None 69 | 70 | # prepare dataset 71 | 72 | train_dataset = CelebA( 73 | train_dataset_path, 74 | crop_center=True, 75 | preprocess_resolution=64, 76 | transform=Compose( 77 | [ 78 | ToTensor(), 79 | RandomHorizontalFlip(p=0.5), 80 | ] 81 | ), 82 | ) 83 | test_dataset = CelebA( 84 | test_dataset_path, 85 | crop_center=True, 86 | preprocess_resolution=64, 87 | transform=Compose([ToTensor()]), 88 | ) 89 | 90 | train_loader = DataLoader( 91 | train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory 92 | ) 93 | test_loader = DataLoader( 94 | test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory 95 | ) 96 | 97 | # prepare train config 98 | 99 | config = VibTrainConfig( 100 | experiment_dir=experiment_dir, 101 | save_name=save_name, 102 | device=device, 103 | model=model, 104 | optimizer=optimizer, 105 | lr_scheduler=lr_schedular, 106 | loss_fn='cross_entropy', 107 | ) 108 | 109 | trainer = VibTrainer(config) 110 | 111 | trainer.train(epoch_num, train_loader, test_loader) 112 | -------------------------------------------------------------------------------- /examples/standard/classifier_training/distill_celeba64_celeba64.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append('../../../src') 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision.transforms import ( 11 | ToTensor, 12 | Compose, 13 | ColorJitter, 14 | RandomResizedCrop, 15 | RandomHorizontalFlip, 16 | Normalize, 17 | Resize, 18 | ) 19 | 20 | from modelinversion.models import IR152_64, EfficientNet_b0_64 21 | from modelinversion.train import DistillTrainer, DistillTrainConfig 22 | from modelinversion.utils import Logger 23 | from modelinversion.datasets import CelebA64 24 | 25 | if __name__ == '__main__': 26 | 27 | num_classes = 1000 28 | model_name = 'efficientnet_b0' 29 | teacher_name = 'ir152' 30 | save_name = f'celeba64_{model_name}.pth' 31 | train_dataset_path = '../../../test/celeba/public' 32 | test_dataset_path = '../../../test/celeba/private_test' 33 | experiment_dir = f'../../../results/distill_celeba64_{model_name}_{teacher_name}_v3' 34 | teacher_ckpt_path = ( 35 | '../../../checkpoints_v2/classifier/celeba64/celeba64_ir152_93.71.pth' 36 | ) 37 | 38 | batch_size = 128 39 | epoch_num = 100 40 | 41 | device_ids_available = '1' 42 | pin_memory = False 43 | 44 | # prepare logger 45 | 46 | now_time = time.strftime(r'%Y%m%d_%H%M', time.localtime(time.time())) 47 | logger = Logger(experiment_dir, f'train_classifier_{now_time}.log') 48 | 49 | # prepare devices 50 | 51 | os.environ["CUDA_VISIBLE_DEVICES"] = 
device_ids_available 52 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 53 | device = torch.device(device) 54 | gpu_devices = [i for i in range(torch.cuda.device_count())] 55 | 56 | # prepare target model 57 | 58 | teacher = IR152_64(num_classes) 59 | teacher.load_state_dict( 60 | torch.load(teacher_ckpt_path, map_location='cpu')['state_dict'] 61 | ) 62 | teacher = teacher.to(device) 63 | 64 | model = EfficientNet_b0_64(num_classes, prtrained=True) 65 | model = nn.DataParallel(model, device_ids=gpu_devices).to(device) 66 | 67 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 68 | lr_schedular = None 69 | 70 | # prepare dataset 71 | 72 | train_dataset = CelebA64( 73 | train_dataset_path, 74 | transform=Compose( 75 | [ 76 | ToTensor(), 77 | RandomHorizontalFlip(p=0.5), 78 | ] 79 | ), 80 | ) 81 | test_dataset = CelebA64( 82 | test_dataset_path, 83 | transform=Compose([ToTensor()]), 84 | ) 85 | 86 | train_loader = DataLoader( 87 | train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory 88 | ) 89 | test_loader = DataLoader( 90 | test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory 91 | ) 92 | 93 | # prepare train config 94 | 95 | config = DistillTrainConfig( 96 | experiment_dir=experiment_dir, 97 | save_name=save_name, 98 | device=device, 99 | model=model, 100 | optimizer=optimizer, 101 | lr_scheduler=lr_schedular, 102 | teacher=teacher, 103 | ) 104 | 105 | trainer = DistillTrainer(config) 106 | 107 | trainer.train(epoch_num, train_loader, test_loader) 108 | -------------------------------------------------------------------------------- /examples/standard/dataset_preprocess/afhqdogs256.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../../../src') 4 | 5 | from modelinversion.datasets.preprocess import preprocess_afhqdogs256 6 | 7 | if __name__ == '__main__': 8 | 9 | src_path = '' 10 | dst_path = '' 11 | 12 | preprocess_afhqdogs256(src_path, dst_path) 13 | -------------------------------------------------------------------------------- /examples/standard/dataset_preprocess/celeba.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../../../src') 4 | 5 | from modelinversion.datasets.preprocess import preprocess_celeba 6 | 7 | if __name__ == '__main__': 8 | 9 | src_path = '' 10 | dst_path = '' 11 | split_files_path = '' 12 | mode = 'copy' 13 | 14 | preprocess_celeba(src_path, dst_path, split_files_path, mode=mode) 15 | -------------------------------------------------------------------------------- /examples/standard/dataset_preprocess/ffhq256.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../../../src') 4 | 5 | from modelinversion.datasets.preprocess import preprocess_ffhq256 6 | 7 | if __name__ == '__main__': 8 | 9 | src_path = '' 10 | dst_path = '' 11 | 12 | preprocess_ffhq256(src_path, dst_path) 13 | -------------------------------------------------------------------------------- /examples/standard/dataset_preprocess/ffhq64.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../../../src') 4 | 5 | from modelinversion.datasets.preprocess import preprocess_ffhq64 6 | 7 | if __name__ == '__main__': 8 | 9 | src_path = '' 10 | dst_path = '' 11 | 12 | preprocess_ffhq64(src_path, dst_path) 13 | 
-------------------------------------------------------------------------------- /examples/standard/dataset_preprocess/lokt_generation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append('../../../src') 5 | 6 | import torch 7 | from torch import nn 8 | import torchvision.transforms as TF 9 | 10 | from modelinversion.models import ( 11 | auto_classifier_from_pretrained, 12 | auto_generator_from_pretrained, 13 | ) 14 | from modelinversion.datasets import ( 15 | generator_generate_datasets, 16 | preprocess_celeba_fn, 17 | GeneratorDataset, 18 | ) 19 | 20 | if __name__ == '__main__': 21 | 22 | num_classes = 1000 23 | generator_ckpt_path = ( 24 | '../../../checkpoints_v2/attacks/lokt/lokt_celeba64_celeba64_ir152_G.pt' 25 | ) 26 | target_model_ckpt_path = ( 27 | '../../../checkpoints_v2/classifier/celeba64/celeba64_ir152_93.71.pth' 28 | ) 29 | dst_dataset_path = '../../../results/lokt_celeba_celeba_ir152_dataset/celeba64_celeba64_ir152_dataset.pt' 30 | 31 | batch_size = 200 32 | device_ids_available = '3' 33 | 34 | # prepare devices 35 | 36 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_available 37 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 38 | device = torch.device(device) 39 | gpu_devices = [i for i in range(torch.cuda.device_count())] 40 | 41 | # dataset generator 42 | 43 | z_dim = 128 44 | 45 | generator = auto_generator_from_pretrained(generator_ckpt_path) 46 | generator = generator.to(device) 47 | generator.eval() 48 | 49 | # prepare target models 50 | 51 | target_model = auto_classifier_from_pretrained(target_model_ckpt_path) 52 | target_model = nn.DataParallel(target_model, device_ids=gpu_devices).to(device) 53 | target_model.eval() 54 | 55 | dataset = GeneratorDataset.create( 56 | z_dim, 57 | num_classes=num_classes, 58 | generate_num_per_class=500, 59 | generator=generator, 60 | target_model=target_model, 61 | batch_size=batch_size, 62 | device=device, 63 | ) 64 | 65 | dataset.save(dst_dataset_path) 66 | -------------------------------------------------------------------------------- /examples/standard/dataset_preprocess/metfaces256.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../../../src') 4 | 5 | from modelinversion.datasets.preprocess import preprocess_metfaces256 6 | 7 | if __name__ == '__main__': 8 | 9 | src_path = '' 10 | dst_path = '' 11 | 12 | preprocess_metfaces256(src_path, dst_path) 13 | -------------------------------------------------------------------------------- /examples/standard/dataset_preprocess/plgmi_top_k_selection.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append('../../../src') 5 | 6 | import torch 7 | from torch import nn 8 | from torchvision.transforms import functional as TF 9 | 10 | from modelinversion.models import IR152_64 11 | from modelinversion.datasets import top_k_selection 12 | 13 | if __name__ == '__main__': 14 | 15 | top_k = 30 16 | num_classes = 1000 17 | target_model_ckpt_path = '' 18 | src_dataset_path = '' 19 | dst_dataset_path = '' 20 | 21 | batch_size = 50 22 | device_ids_available = '0' 23 | 24 | # prepare devices 25 | 26 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_available 27 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 28 | device = torch.device(device) 29 | gpu_devices = [i for i in range(torch.cuda.device_count())] 30 | 31 | # prepare target models 
32 | 33 | target_model = IR152_64(num_classes=num_classes) 34 | target_model.load_state_dict( 35 | torch.load(target_model_ckpt_path, map_location='cpu')['state_dict'] 36 | ) 37 | target_model = nn.DataParallel(target_model, device_ids=gpu_devices).to(device) 38 | 39 | # dataset generation 40 | 41 | top_k_selection( 42 | top_k=top_k, 43 | src_dataset_path=src_dataset_path, 44 | dst_dataset_path=dst_dataset_path, 45 | batch_size=batch_size, 46 | target_model=target_model, 47 | num_classes=num_classes, 48 | device=device, 49 | create_aug_images_fn=lambda img: TF.resize(img, (64, 64), antialias=True), 50 | ) 51 | -------------------------------------------------------------------------------- /examples/standard/gan_training/gmi.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append("../../../src") 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision.transforms import ToTensor, Compose, Resize 11 | 12 | from modelinversion.models import SimpleGenerator64, GmiDiscriminator64 13 | from modelinversion.train import GmiGanTrainer, GmiGanTrainConfig 14 | from modelinversion.utils import Logger 15 | from modelinversion.datasets import InfiniteSamplerWrapper, CelebA64 16 | 17 | if __name__ == "__main__": 18 | 19 | dataset_path = '' 20 | experiment_dir = '' 21 | 22 | batch_size = 64 23 | max_iters = 150000 24 | 25 | device_ids_available = "2" 26 | 27 | # prepare logger 28 | 29 | now_time = time.strftime(r"%Y%m%d_%H%M", time.localtime(time.time())) 30 | logger = Logger(experiment_dir, f"train_gan_{now_time}.log") 31 | 32 | # prepare devices 33 | 34 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_available 35 | device = "cuda" if torch.cuda.is_available() else "cpu" 36 | device = torch.device(device) 37 | gpu_devices = [i for i in range(torch.cuda.device_count())] 38 | 39 | # prepare dataset 40 | 41 | dataset = CelebA64( 42 | dataset_path, 43 | output_transform=ToTensor(), 44 | ) 45 | dataloader = DataLoader( 46 | dataset, batch_size=batch_size, sampler=InfiniteSamplerWrapper(dataset) 47 | ) 48 | 49 | # prepare GANs 50 | 51 | z_dim = 100 52 | 53 | generator = SimpleGenerator64(in_dim=z_dim) 54 | discriminator = GmiDiscriminator64() 55 | 56 | generator = nn.DataParallel(generator, device_ids=gpu_devices).to(device) 57 | discriminator = nn.DataParallel(discriminator, device_ids=gpu_devices).to(device) 58 | 59 | gen_optimizer = torch.optim.Adam( 60 | generator.parameters(), lr=0.0002, betas=(0.5, 0.999) 61 | ) 62 | dis_optimizer = torch.optim.Adam( 63 | discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999) 64 | ) 65 | 66 | # prepare trainer 67 | 68 | train_config = GmiGanTrainConfig( 69 | experiment_dir=experiment_dir, 70 | # train args 71 | batch_size=batch_size, 72 | input_size=z_dim, 73 | generator=generator, 74 | discriminator=discriminator, 75 | device=device, 76 | gen_optimizer=gen_optimizer, 77 | dis_optimizer=dis_optimizer, 78 | # log args 79 | save_ckpt_iters=1000, 80 | show_images_iters=1000, 81 | show_train_info_iters=100, 82 | ) 83 | 84 | trainer = GmiGanTrainer(train_config) 85 | 86 | # train gan 87 | 88 | trainer.train(dataloader, max_iters) 89 | -------------------------------------------------------------------------------- /examples/standard/gan_training/kedmi.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append('../../../src') 6 
| 7 | import kornia 8 | import torch 9 | from torch import nn 10 | from torch.utils.data import DataLoader 11 | from torchvision.transforms import ToTensor, Compose 12 | 13 | from modelinversion.models import ( 14 | auto_classifier_from_pretrained, 15 | KedmiDiscriminator64, 16 | SimpleGenerator64, 17 | ) 18 | from modelinversion.train import KedmiGanTrainer, KedmiGanTrainConfig 19 | from modelinversion.utils import Logger 20 | from modelinversion.datasets import InfiniteSamplerWrapper, CelebA64 21 | 22 | 23 | if __name__ == '__main__': 24 | 25 | num_classes = 1000 26 | target_model_ckpt_path = ( 27 | '../../../checkpoints_v2/classifier/celeba64/celeba64_ir152_93.71.pth' 28 | ) 29 | dataset_path = '../../../dataset/celeba/private_train' 30 | experiment_dir = '../../../results/kedmi_celeba64_celeba64_ir152_gan' 31 | 32 | batch_size = 64 33 | max_iters = 50000 34 | 35 | device_ids_str = '0' 36 | 37 | # prepare logger 38 | 39 | now_time = time.strftime(r'%Y%m%d_%H%M', time.localtime(time.time())) 40 | logger = Logger(experiment_dir, f'train_gan_{now_time}.log') 41 | 42 | # prepare devices 43 | 44 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_str 45 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 46 | device = torch.device(device) 47 | gpu_devices = [i for i in range(torch.cuda.device_count())] 48 | 49 | # prepare target models 50 | 51 | target_model = auto_classifier_from_pretrained(target_model_ckpt_path) 52 | target_model = nn.DataParallel(target_model, device_ids=gpu_devices).to(device) 53 | target_model.eval() 54 | 55 | # prepare dataset 56 | 57 | from torchvision.datasets import ImageFolder 58 | 59 | dataset = CelebA64(dataset_path, ToTensor()) 60 | dataloader = iter( 61 | DataLoader( 62 | dataset, 63 | batch_size=batch_size, 64 | # shuffle=True, 65 | sampler=InfiniteSamplerWrapper(dataset), 66 | ) 67 | ) 68 | 69 | # prepare GANs 70 | 71 | z_dim = 100 72 | 73 | generator = SimpleGenerator64(in_dim=z_dim) 74 | discriminator = KedmiDiscriminator64(num_classes) 75 | 76 | generator = nn.DataParallel(generator, device_ids=gpu_devices).to(device) 77 | discriminator = nn.DataParallel(discriminator, device_ids=gpu_devices).to(device) 78 | 79 | gen_optimizer = torch.optim.Adam( 80 | generator.parameters(), lr=0.0002, betas=(0.5, 0.999) 81 | ) 82 | dis_optimizer = torch.optim.Adam( 83 | discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999) 84 | ) 85 | 86 | # prepare trainer 87 | 88 | config = KedmiGanTrainConfig( 89 | experiment_dir=experiment_dir, 90 | batch_size=batch_size, 91 | input_size=z_dim, 92 | generator=generator, 93 | discriminator=discriminator, 94 | target_model=target_model, 95 | device=device, 96 | augment=None, 97 | gen_optimizer=gen_optimizer, 98 | dis_optimizer=dis_optimizer, 99 | save_ckpt_iters=1000, 100 | show_images_iters=1000, 101 | show_train_info_iters=100, 102 | ) 103 | 104 | trainer = KedmiGanTrainer(config) 105 | 106 | # train gan 107 | 108 | trainer.train(dataloader, max_iters) 109 | -------------------------------------------------------------------------------- /examples/standard/gan_training/lokt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append('../../../src') 6 | 7 | import kornia 8 | import torch 9 | from torch import nn 10 | from torch.utils.data import DataLoader 11 | from torchvision.datasets import ImageFolder 12 | from torchvision.transforms import ToTensor, Compose, Resize 13 | 14 | from modelinversion.models import IR152_64, 
LoktGenerator64, LoktDiscriminator64 15 | from modelinversion.train import LoktGanTrainer, LoktGanTrainConfig 16 | from modelinversion.utils import Logger, set_random_seed 17 | from modelinversion.datasets import InfiniteSamplerWrapper, CelebA64 18 | 19 | if __name__ == '__main__': 20 | 21 | num_classes = 1000 22 | target_model_ckpt_path = ( 23 | '../../../checkpoints_v2/classifier/celeba64/celeba64_ir152_93.71.pth' 24 | ) 25 | dataset_path = '../../../dataset/celeba_low/public' 26 | experiment_dir = '' 27 | 28 | batch_size = 256 29 | max_iters = 105000 30 | 31 | device_ids_str = '1' 32 | 33 | # prepare logger 34 | 35 | now_time = time.strftime(r'%Y%m%d_%H%M', time.localtime(time.time())) 36 | logger = Logger(experiment_dir, f'train_gan_{now_time}.log') 37 | 38 | # prepare devices 39 | 40 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_str 41 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 42 | device = torch.device(device) 43 | gpu_devices = [i for i in range(torch.cuda.device_count())] 44 | 45 | set_random_seed(46) 46 | 47 | # prepare target models 48 | 49 | target_model = IR152_64(num_classes=num_classes) 50 | target_model.load_state_dict( 51 | torch.load(target_model_ckpt_path, map_location='cpu')['state_dict'] 52 | ) 53 | target_model = nn.DataParallel(target_model, device_ids=gpu_devices).to(device) 54 | target_model.eval() 55 | 56 | # prepare dataset 57 | 58 | dataset = CelebA64( 59 | dataset_path, 60 | output_transform=Compose([ToTensor()]), 61 | ) 62 | dataloader = iter( 63 | DataLoader( 64 | dataset, 65 | batch_size=batch_size, 66 | sampler=InfiniteSamplerWrapper(dataset), 67 | num_workers=4, 68 | ) 69 | ) 70 | 71 | # prepare GANs 72 | 73 | z_dim = 128 74 | 75 | generator = LoktGenerator64(num_classes, dim_z=z_dim) 76 | discriminator = LoktDiscriminator64(num_classes) 77 | 78 | generator = nn.DataParallel(generator, device_ids=gpu_devices).to(device) 79 | discriminator = nn.DataParallel(discriminator, device_ids=gpu_devices).to(device) 80 | 81 | gen_optimizer = torch.optim.Adam( 82 | generator.parameters(), lr=0.0002, betas=(0.0, 0.9) 83 | ) 84 | dis_optimizer = torch.optim.Adam( 85 | discriminator.parameters(), lr=0.0002, betas=(0.0, 0.9) 86 | ) 87 | 88 | # prepare trainer 89 | 90 | # data_augment = kornia.augmentation.container.ImageSequential( 91 | # kornia.augmentation.RandomResizedCrop( 92 | # (64, 64), scale=(0.8, 1.0), ratio=(1.0, 1.0) 93 | # ), 94 | # kornia.augmentation.ColorJitter(brightness=0.2, contrast=0.2, p=0.5), 95 | # kornia.augmentation.RandomHorizontalFlip(), 96 | # kornia.augmentation.RandomRotation(5), 97 | # ) 98 | 99 | train_config = LoktGanTrainConfig( 100 | experiment_dir=experiment_dir, 101 | batch_size=batch_size, 102 | input_size=z_dim, 103 | generator=generator, 104 | discriminator=discriminator, 105 | num_classes=num_classes, 106 | target_model=target_model, 107 | classification_loss_fn='cross_entropy', 108 | device=device, 109 | augment=None, 110 | gen_optimizer=gen_optimizer, 111 | dis_optimizer=dis_optimizer, 112 | save_ckpt_iters=2000, 113 | start_class_loss_iters=5000, 114 | show_images_iters=2000, 115 | show_train_info_iters=473, 116 | class_loss_weight=1.5, 117 | ) 118 | 119 | # PlgmiGanTrainer( 120 | # experiment_dir=experiment_dir, 121 | # batch_size=batch_size, 122 | # input_size=z_dim, 123 | # generator=generator, 124 | # discriminator=discriminator, 125 | # num_classes=num_classes, 126 | # target_model=target_model, 127 | # classification_loss_fn='max_margin', 128 | # device=device, 129 | # augment=data_augment, 130 | # 
gen_optimizer=gen_optimizer, 131 | # dis_optimizer=dis_optimizer, 132 | # save_ckpt_iters=1000, 133 | # show_images_iters=1000, 134 | # show_train_info_iters=100, 135 | # ) 136 | 137 | # train gan 138 | 139 | trainer = LoktGanTrainer(train_config) 140 | 141 | trainer.train(dataloader, max_iters) 142 | -------------------------------------------------------------------------------- /examples/standard/gan_training/plgmi.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | sys.path.append('../../../src') 6 | 7 | import kornia 8 | import torch 9 | from torch import nn 10 | from torch.utils.data import DataLoader 11 | from torchvision.transforms import ToTensor, Compose, Resize 12 | 13 | from modelinversion.models import IR152_64, PlgmiGenerator64, PlgmiDiscriminator64 14 | from modelinversion.train import PlgmiGanTrainer, PlgmiGanTrainConfig 15 | from modelinversion.utils import Logger 16 | from modelinversion.datasets import InfiniteSamplerWrapper, CelebA64 17 | 18 | if __name__ == '__main__': 19 | 20 | top_k = 30 21 | num_classes = 1000 22 | target_model_ckpt_path = '' 23 | dataset_path = '' 24 | experiment_dir = '' 25 | 26 | batch_size = 64 27 | max_iters = 150000 28 | 29 | device_ids_available = '0' 30 | 31 | # prepare logger 32 | 33 | now_time = time.strftime(r'%Y%m%d_%H%M', time.localtime(time.time())) 34 | logger = Logger(experiment_dir, f'train_gan_{now_time}.log') 35 | 36 | # prepare devices 37 | 38 | os.environ["CUDA_VISIBLE_DEVICES"] = device_ids_available 39 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 40 | device = torch.device(device) 41 | gpu_devices = [i for i in range(torch.cuda.device_count())] 42 | 43 | # prepare target models 44 | 45 | target_model = IR152_64(num_classes=num_classes) 46 | target_model.load_state_dict( 47 | torch.load(target_model_ckpt_path, map_location='cpu')['state_dict'] 48 | ) 49 | target_model = nn.DataParallel(target_model, device_ids=gpu_devices).to(device) 50 | 51 | # prepare dataset 52 | 53 | def _noise_adder(img): 54 | return torch.empty_like(img, dtype=img.dtype).uniform_(0.0, 1 / 256.0) + img 55 | 56 | dataset = CelebA64( 57 | dataset_path, 58 | output_transform=Compose([ToTensor(), _noise_adder]), 59 | ) 60 | dataloader = DataLoader( 61 | dataset, batch_size=batch_size, sampler=InfiniteSamplerWrapper(dataset) 62 | ) 63 | 64 | # prepare GANs 65 | 66 | z_dim = 128 67 | 68 | generator = PlgmiGenerator64(num_classes, dim_z=z_dim) 69 | discriminator = PlgmiDiscriminator64(num_classes) 70 | 71 | generator = nn.DataParallel(generator, device_ids=gpu_devices).to(device) 72 | discriminator = nn.DataParallel(discriminator, device_ids=gpu_devices).to(device) 73 | 74 | gen_optimizer = torch.optim.Adam( 75 | generator.parameters(), lr=0.0002, betas=(0.0, 0.9) 76 | ) 77 | dis_optimizer = torch.optim.Adam( 78 | discriminator.parameters(), lr=0.0002, betas=(0.0, 0.9) 79 | ) 80 | 81 | # prepare trainer 82 | 83 | data_augment = kornia.augmentation.container.ImageSequential( 84 | kornia.augmentation.RandomResizedCrop( 85 | (64, 64), scale=(0.8, 1.0), ratio=(1.0, 1.0) 86 | ), 87 | kornia.augmentation.ColorJitter(brightness=0.2, contrast=0.2, p=0.5), 88 | kornia.augmentation.RandomHorizontalFlip(), 89 | kornia.augmentation.RandomRotation(5), 90 | ) 91 | 92 | train_configs = PlgmiGanTrainConfig( 93 | experiment_dir=experiment_dir, 94 | # train args 95 | batch_size=batch_size, 96 | input_size=z_dim, 97 | generator=generator, 98 | discriminator=discriminator, 99 | 
num_classes=num_classes,
100 |         target_model=target_model,
101 |         classification_loss_fn='max_margin',
102 |         device=device,
103 |         augment=data_augment,
104 |         gen_optimizer=gen_optimizer,
105 |         dis_optimizer=dis_optimizer,
106 |         # log args
107 |         save_ckpt_iters=1000,
108 |         show_images_iters=1000,
109 |         show_train_info_iters=100,
110 |     )
111 | 
112 |     # train gan
113 | 
114 |     trainer = PlgmiGanTrainer(train_configs)
115 | 
116 |     trainer.train(dataloader, max_iters)
117 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | adjustText==1.1.1
 2 | apex==0.9.10dev
 3 | dlib==19.24.4
 4 | facenet_pytorch==2.5.2
 5 | ignite==1.1.0
 6 | kornia==0.7.2
 7 | lmdb==1.4.1
 8 | matplotlib==3.8.2
 9 | ml_collections==0.1.1
10 | monai==1.3.1
11 | numpy==1.26.4
12 | opencv_python==4.9.0.80
13 | pandas==2.2.2
14 | Pillow==10.3.0
15 | pytorch_fid==0.2.1
16 | PyYAML==6.0.1
17 | scipy==1.13.1
18 | seaborn==0.13.2
19 | tensorboardX==2.6.2.2
20 | timm==0.6.13
21 | tqdm==4.66.1
22 | wandb==0.16.1
23 | --find-links https://download.pytorch.org/whl/torch_stable.html
24 | torch==2.0.1+cu118
25 | torchvision==0.15.2+cu118
26 | 
--------------------------------------------------------------------------------
/requirements_ori.txt:
--------------------------------------------------------------------------------
 1 | adjustText==0.8
 2 | apex==0.9.10dev
 3 | beautifulsoup4==4.12.2
 4 | click==8.1.7
 5 | conv==0.2
 6 | cryptography==41.0.5
 7 | dlib==19.24.2
 8 | easydict==1.10
 9 | facenet_pytorch==2.5.3
10 | h5py==3.10.0
11 | ignite==1.1.0
12 | imageio==2.31.6
13 | ipdb==0.13.13
14 | kornia==0.7.0
15 | lmdb==1.4.1
16 | matplotlib==3.8.1
17 | ml_collections==0.1.1
18 | monai==1.3.0
19 | moviepy==1.0.3
20 | numpy==1.24.1
21 | opencv_contrib_python_headless==4.8.0.76
22 | pandas==2.1.2
23 | Pillow==9.3.0
24 | psutil==5.9.0
25 | pyspng==0.1.1
26 | PyYAML==6.0.1
27 | Requests==2.31.0
28 | rich==13.6.0
29 | scikit_learn==1.3.2
30 | scipy==1.11.3
31 | seaborn==0.13.0
32 | six==1.16.0
33 | scikit-image
34 | tensorboard==2.15.0
35 | tensorboardX==2.6.2.2
36 | timm==0.9.8
37 | torch==2.0.1
38 | torchvision==0.15.2
39 | tqdm==4.66.1
40 | turbojpeg==0.0.2
41 | wandb==0.15.12
--------------------------------------------------------------------------------
/src/modelinversion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ffhibnese/Model-Inversion-Attack-ToolBox/d9ad48f997aa204e28696b87ce49a25384fcf794/src/modelinversion/__init__.py
--------------------------------------------------------------------------------
/src/modelinversion/attack/SecretGen/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ffhibnese/Model-Inversion-Attack-ToolBox/d9ad48f997aa204e28696b87ce49a25384fcf794/src/modelinversion/attack/SecretGen/.DS_Store
--------------------------------------------------------------------------------
/src/modelinversion/attack/SecretGen/.gitignore:
--------------------------------------------------------------------------------
 1 | __pycache__
 2 | checkpoint
 3 | data
 4 | logs
 5 | output
 6 | result
 7 | premodels
 8 | vec
 9 | tmp
10 | 
11 | *.out
12 | *.png
13 | *.tar
14 | *.sh
15 | *.csv
--------------------------------------------------------------------------------
/src/modelinversion/attack/SecretGen/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2022 AI Secure
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/src/modelinversion/attack/SecretGen/README.md:
--------------------------------------------------------------------------------
 1 | # SecretGen: Privacy Recovery on Pre-trained Models via Distribution Discrimination
 2 | 
 3 | 
 4 | ## Requirements
 5 | Python 3.8 or higher
 6 | PyTorch 1.8 or higher
 7 | ```
 8 | $ pip install -r requirements.txt
 9 | ```
10 | 
11 | 
12 | ## Performing Attack
13 | stage1.py: Train the generation backbone on public data.
14 | ```
15 | $ python stage1.py --name <name> --mask <mask>
16 | ```
17 | Set `bb` to `True` in the black-box case; this uses a public model instead of the target model for the diversity loss.
18 | 
19 | stage2.py: Perform the attack.
20 | ```
21 | $ python stage2.py --name <name> --mask <mask> --target <target>
22 | ```
23 | For the `target` parameter:
24 | - `pii`: PII (whitebox)
25 | - `pii-bb`: PII (blackbox)
26 | - `gmi`: GMI
27 | - `init-bb`: SecretGen (blackbox)
28 | - `full-bb`: SecretGen (blackbox + ground truth label)
29 | - `init-wb`: SecretGen (whitebox)
30 | - `full`: SecretGen (whitebox + ground truth label)
31 | 
32 | Set `save` to `True` to run evaluation protocol 2, which requires a completely recovered dataset.
33 | 
34 | 
35 | ## Pre-trained Checkpoints
36 | We release the checkpoints for our VGG16 target model and the corresponding generation backbones at this link:
37 | 
38 | https://drive.google.com/drive/folders/149LMfBEmhcFr1S2y6PLXf3WqqPA8-We0?usp=sharing
--------------------------------------------------------------------------------
/src/modelinversion/attack/SecretGen/discri.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | 
 5 | 
 6 | class DGWGAN(nn.Module):
 7 |     def __init__(self, in_dim=3, dim=64):
 8 |         super(DGWGAN, self).__init__()
 9 | 
10 |         def conv_ln_lrelu(in_dim, out_dim):
11 |             return nn.Sequential(
12 |                 nn.Conv2d(in_dim, out_dim, 5, 2, 2),
13 |                 # Since there is no effective implementation of LayerNorm,
14 |                 # we use InstanceNorm2d instead of LayerNorm here.
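                # (A GroupNorm with a single group would normalize over
                # (C, H, W) exactly like LayerNorm, without requiring a fixed
                # spatial size; the original keeps InstanceNorm2d.)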
15 | nn.InstanceNorm2d(out_dim, affine=True), 16 | nn.LeakyReLU(0.2), 17 | ) 18 | 19 | self.layer1 = nn.Sequential(nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2)) 20 | self.layer2 = conv_ln_lrelu(dim, dim * 2) 21 | self.layer3 = conv_ln_lrelu(dim * 2, dim * 4) 22 | self.layer4 = conv_ln_lrelu(dim * 4, dim * 8) 23 | self.layer5 = nn.Conv2d(dim * 8, 1, 4) 24 | 25 | def forward(self, x): 26 | feat1 = self.layer1(x) 27 | feat2 = self.layer2(feat1) 28 | feat3 = self.layer3(feat2) 29 | feat4 = self.layer4(feat3) 30 | y = self.layer5(feat4) 31 | y = y.view(-1) 32 | return [feat1, feat2, feat3, feat4], y 33 | 34 | 35 | class DLWGAN(nn.Module): 36 | def __init__(self, in_dim=3, dim=64): 37 | super(DLWGAN, self).__init__() 38 | 39 | def conv_ln_lrelu(in_dim, out_dim): 40 | return nn.Sequential( 41 | nn.Conv2d(in_dim, out_dim, 5, 2, 2), 42 | # Since there is no effective implementation of LayerNorm, 43 | # we use InstanceNorm2d instead of LayerNorm here. 44 | nn.InstanceNorm2d(out_dim, affine=True), 45 | nn.LeakyReLU(0.2), 46 | ) 47 | 48 | self.layer1 = nn.Sequential(nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2)) 49 | self.layer2 = conv_ln_lrelu(dim, dim * 2) 50 | self.layer3 = conv_ln_lrelu(dim * 2, dim * 4) 51 | self.layer4 = nn.Conv2d(dim * 4, 1, 4) 52 | 53 | def forward(self, x): 54 | feat1 = self.layer1(x) 55 | feat2 = self.layer2(feat1) 56 | feat3 = self.layer3(feat2) 57 | y = self.layer4(feat3) 58 | return [feat1, feat2, feat3], y 59 | -------------------------------------------------------------------------------- /src/modelinversion/attack/SecretGen/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.loss import _Loss 3 | 4 | from utils import low2high112 5 | 6 | 7 | def completion_network_loss(input, output, mask): 8 | bs = input.size(0) 9 | loss = torch.sum(torch.abs(output * mask - input * mask)) / bs 10 | # return mse_loss(output * mask, input * mask) 11 | return loss 12 | 13 | 14 | def noise_loss(V, img1, img2): 15 | # img1 = low2high(img1) 16 | # img2 = low2high(img2) 17 | feat1, __, ___ = V(img1) 18 | feat2, __, ___ = V(img2) 19 | 20 | loss = torch.mean(torch.abs(feat1 - feat2)) 21 | # return mse_loss(output * mask, input * mask) 22 | return loss 23 | 24 | 25 | class ContextLoss(_Loss): 26 | def forward(self, mask, gen, images): 27 | bs = gen.size(0) 28 | context_loss = ( 29 | torch.sum(torch.abs(torch.mul(mask, gen) - torch.mul(mask, images))) / bs 30 | ) 31 | return context_loss 32 | 33 | 34 | class CrossEntropyLoss(_Loss): 35 | def forward(self, out, gt): 36 | bs = out.size(0) 37 | # print(out.size(), gt.size()) 38 | loss = -torch.mul(gt.float(), torch.log(out.float() + 1e-7)) 39 | loss = torch.sum(loss) / bs 40 | return loss 41 | 42 | 43 | class FeatLoss(_Loss): 44 | def forward(self, fake_feat, real_feat): 45 | num = len(fake_feat) 46 | loss = torch.zeros(1).cuda() 47 | for i in range(num): 48 | loss += torch.mean(torch.abs(fake_feat[i] - real_feat[i])) 49 | 50 | return loss 51 | -------------------------------------------------------------------------------- /src/modelinversion/attack/SecretGen/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboardX 2 | adjustText 3 | matplotlib 4 | seaborn 5 | scikit-learn 6 | tqdm 7 | ml_collections -------------------------------------------------------------------------------- /src/modelinversion/attack/SecretGen/tgt_models/resnet152.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | class ResNet152(nn.Module): 7 | def __init__(self, num_classes=1000, vis=False): 8 | super(ResNet152, self).__init__() 9 | self.vis = vis 10 | model = torchvision.models.resnet152(pretrained=True) 11 | self.feature = nn.Sequential(*list(model.children())[:-1]) 12 | self.feat_dim = 2048 13 | self.num_of_classes = num_classes 14 | self.fc_layer = nn.Sequential( 15 | nn.Linear(self.feat_dim, self.num_of_classes), nn.Softmax(dim=1) 16 | ) 17 | 18 | def classifier(self, x): 19 | out = self.fc_layer(x) 20 | __, iden = torch.max(out, dim=1) 21 | return out, iden 22 | 23 | def forward(self, x): 24 | if self.vis: 25 | out = [] 26 | for module in self.feature[0]: 27 | x = module(x) 28 | out.append(torch.flatten(x, 1)) 29 | x = x.contiguous().view(x.size(0), -1) 30 | for module in self.fc_layer: 31 | x = module(x) 32 | out.append(torch.flatten(x, 1)) 33 | return out 34 | 35 | feature = self.feature(x) 36 | feature = feature.contiguous().view(feature.size(0), -1) 37 | out, iden = self.classifier(feature) 38 | return feature, out, iden 39 | -------------------------------------------------------------------------------- /src/modelinversion/attack/SecretGen/tgt_models/vgg16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | class VGG16(nn.Module): 7 | def __init__(self, num_classes=1000, vis=False): 8 | super(VGG16, self).__init__() 9 | self.vis = vis 10 | model = torchvision.models.vgg16_bn(pretrained=True) 11 | self.feature = nn.Sequential(*list(model.children())[:-2]) 12 | self.feat_dim = 512 * 2 * 2 13 | self.num_of_classes = num_classes 14 | self.fc_layer = nn.Sequential( 15 | nn.Linear(self.feat_dim, self.num_of_classes), nn.Softmax(dim=1) 16 | ) 17 | 18 | def classifier(self, x): 19 | out = self.fc_layer(x) 20 | __, iden = torch.max(out, dim=1) 21 | return out, iden 22 | 23 | def forward(self, x): 24 | if self.vis: 25 | out = [] 26 | for module in self.feature[0]: 27 | x = module(x) 28 | out.append(torch.flatten(x, 1)) 29 | x = x.contiguous().view(x.size(0), -1) 30 | for module in self.fc_layer: 31 | x = module(x) 32 | out.append(torch.flatten(x, 1)) 33 | return out 34 | 35 | feature = self.feature(x) 36 | feature = feature.contiguous().view(feature.size(0), -1) 37 | out, iden = self.classifier(feature) 38 | return feature, out, iden 39 | -------------------------------------------------------------------------------- /src/modelinversion/attack/SecretGen/train_target.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | import argparse 6 | from tqdm import tqdm 7 | from dataloader import CelebA 8 | from tgt_models.vgg16 import VGG16 9 | from tgt_models.resnet152 import ResNet152 10 | from tensorboardX import SummaryWriter 11 | import os.path as osp 12 | import os 13 | import numpy as np 14 | import random 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | '--name', 20 | '-n', 21 | required=True, 22 | type=str, 23 | choices=['vgg16', 'resnet152'], 24 | help='type of model to use', 25 | ) 26 | parser.add_argument('--batch_size', default=64, type=int, help='batch size') 27 | parser.add_argument('--max_epoch', default=300, type=int, help='training 
epochs')
 28 | parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
 29 | opt = parser.parse_args()
 30 | print(opt)
 31 | 
 32 | 
 33 | torch.manual_seed(0)
 34 | torch.cuda.manual_seed(0)
 35 | np.random.seed(0)
 36 | random.seed(0)
 37 | 
 38 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
 39 | if opt.name == 'vgg16':
 40 |     net = VGG16(num_classes=1000).to(device)
 41 | elif opt.name == 'resnet152':
 42 |     net = ResNet152(num_classes=1000).to(device)
 43 | 
 44 | 
 45 | def seed_worker(worker_id):
 46 |     worker_seed = torch.initial_seed() % 2**32
 47 |     np.random.seed(worker_seed)
 48 |     random.seed(worker_seed)
 49 | 
 50 | 
 51 | g = torch.Generator()
 52 | g.manual_seed(0)
 53 | 
 54 | 
 55 | trainset = CelebA(split='pri')
 56 | testset = CelebA(split='pri-dev')
 57 | trainloader = DataLoader(
 58 |     trainset, opt.batch_size, shuffle=True, worker_init_fn=seed_worker, generator=g
 59 | )
 60 | testloader = DataLoader(
 61 |     testset, opt.batch_size, shuffle=False, worker_init_fn=seed_worker, generator=g
 62 | )
 63 | 
 64 | nll_loss = nn.NLLLoss()  # the networks end in Softmax, so NLL over their log-probabilities matches cross entropy on logits
 65 | optimizer = optim.SGD(net.parameters(), lr=opt.lr, momentum=0.9)
 66 | 
 67 | max_epoch = opt.max_epoch
 68 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epoch)
 69 | 
 70 | writer = SummaryWriter(log_dir=osp.join('logs', f'{opt.name}-train-pri'))
 71 | 
 72 | train_step = 0
 73 | test_step = 0
 74 | 
 75 | 
 76 | def train(epoch):
 77 |     global train_step
 78 | 
 79 |     print('\nEpoch: %d' % epoch)
 80 |     net.train()
 81 |     correct = 0
 82 |     total = 0
 83 |     progress_bar = tqdm(trainloader)
 84 |     for inputs, _, targets in progress_bar:
 85 |         inputs, targets = inputs.to(device), targets.to(device)
 86 | 
 87 |         optimizer.zero_grad()
 88 |         _, probs, _ = net(inputs)  # the target models return softmax probabilities
 89 |         loss = nll_loss(torch.log(probs), targets)
 90 |         loss.backward()
 91 |         optimizer.step()
 92 | 
 93 |         preds = torch.argmax(probs, dim=1)  # already probabilities, no extra softmax needed
 94 |         total += targets.size(0)
 95 |         correct += len(preds[preds == targets])
 96 | 
 97 |         progress_bar.set_description(f'train loss: {loss:.4f}')
 98 | 
 99 |         writer.add_scalar('train loss', loss, train_step)
100 |         train_step += 1
101 | 
102 |     acc = 100 * correct / total
103 |     writer.add_scalar('train acc', acc, epoch)
104 |     writer.add_scalar('lr', scheduler.get_last_lr()[0], epoch)
105 | 
106 | 
107 | def test(epoch):
108 |     global test_step
109 |     net.eval()
110 |     correct = 0
111 |     total = 0
112 |     with torch.no_grad():
113 |         progress_bar = tqdm(testloader)
114 |         for inputs, _, targets in progress_bar:
115 |             inputs, targets = inputs.to(device), targets.to(device)
116 |             _, probs, _ = net(inputs)
117 |             loss = nll_loss(torch.log(probs), targets)
118 | 
119 |             preds = torch.argmax(probs, dim=1)
120 |             total += targets.size(0)
121 |             correct += len(preds[preds == targets])
122 | 
123 |             progress_bar.set_description(f'test loss: {loss:.4f}')
124 | 
125 |             writer.add_scalar('test loss', loss, test_step)
126 |             test_step += 1
127 | 
128 |     acc = 100 * correct / total
129 |     writer.add_scalar('test acc', acc, epoch)
130 | 
131 |     state = {
132 |         'state_dict': net.state_dict(),
133 |         'acc': acc,
134 |     }
135 |     if not osp.isdir('premodels'):
136 |         os.mkdir('premodels')
137 |     torch.save(state, f'./premodels/{opt.name}-pri.tar')
138 | 
139 | 
140 | for epoch in range(max_epoch):
141 |     train(epoch)
142 |     test(epoch)
143 |     scheduler.step()
144 | 
--------------------------------------------------------------------------------
/src/modelinversion/attack/VMI/__init__.py:
--------------------------------------------------------------------------------
 1 | from 
.vmi_attacker import * -------------------------------------------------------------------------------- /src/modelinversion/attack/__init__.py: -------------------------------------------------------------------------------- 1 | from .attacker import ImageClassifierAttackConfig, ImageClassifierAttacker 2 | 3 | from .optimize import ( 4 | BaseImageOptimizationConfig, 5 | BaseImageOptimization, 6 | SimpleWhiteBoxOptimization, 7 | SimpleWhiteBoxOptimizationConfig, 8 | ImageAugmentWhiteBoxOptimization, 9 | ImageAugmentWhiteBoxOptimizationConfig, 10 | VarienceWhiteboxOptimization, 11 | VarienceWhiteboxOptimizationConfig, 12 | BrepOptimization, 13 | BrepOptimizationConfig, 14 | MinerWhiteBoxOptimization, 15 | MinerWhiteBoxOptimizationConfig, 16 | RlbOptimization, 17 | RlbOptimizationConfig, 18 | GeneticOptimizationConfig, 19 | GeneticOptimization, 20 | C2fGeneticOptimizationConfig, 21 | C2fGeneticOptimization, 22 | IntermediateWhiteboxOptimization, 23 | StyelGANIntermediateWhiteboxOptimization, 24 | IntermediateWhiteboxOptimizationConfig, 25 | ) 26 | 27 | from .losses import ( 28 | ImageAugmentClassificationLoss, 29 | ClassificationWithFeatureDistributionLoss, 30 | ComposeImageLoss, 31 | GmiDiscriminatorLoss, 32 | KedmiDiscriminatorLoss, 33 | VmiLoss, 34 | DeepInversionBatchNormPriorLoss, 35 | ImagePixelPriorLoss, 36 | ImageVariationPriorLoss, 37 | MultiModelOutputKLLoss, 38 | ) 39 | 40 | from .VMI import * 41 | -------------------------------------------------------------------------------- /src/modelinversion/attack/optimize/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( 2 | BaseImageOptimizationConfig, 3 | BaseImageOptimization, 4 | SimpleWhiteBoxOptimization, 5 | SimpleWhiteBoxOptimizationConfig, 6 | ImageAugmentWhiteBoxOptimization, 7 | ImageAugmentWhiteBoxOptimizationConfig, 8 | VarienceWhiteboxOptimization, 9 | VarienceWhiteboxOptimizationConfig, 10 | MinerWhiteBoxOptimization, 11 | MinerWhiteBoxOptimizationConfig, 12 | BrepOptimization, 13 | BrepOptimizationConfig, 14 | IntermediateWhiteboxOptimization, 15 | IntermediateWhiteboxOptimizationConfig, 16 | StyelGANIntermediateWhiteboxOptimization, 17 | ) 18 | 19 | from .rlb import RlbOptimization, RlbOptimizationConfig 20 | from .genetic import ( 21 | GeneticOptimization, 22 | GeneticOptimizationConfig, 23 | C2fGeneticOptimization, 24 | C2fGeneticOptimizationConfig, 25 | ) 26 | -------------------------------------------------------------------------------- /src/modelinversion/attack/optimize/deepinversion.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Callable, Tuple 3 | from torch import LongTensor, Tensor 4 | from ...attack.optimize.base import SimpleWhiteBoxOptimizationConfig 5 | from ...models import BaseImageGenerator 6 | from ...utils import DeepInversionBNFeatureHook 7 | from .base import * 8 | 9 | 10 | class DeepInversionOptimizationConfig(SimpleWhiteBoxOptimizationConfig): 11 | 12 | pass 13 | 14 | 15 | class DeepInversionOptimization(SimpleWhiteBoxOptimization): 16 | 17 | def __init__( 18 | self, 19 | config: SimpleWhiteBoxOptimizationConfig, 20 | image_loss_fn: Callable[ 21 | [Tensor, LongTensor], Tensor | Tuple[Tensor, OrderedDict] 22 | ], 23 | ) -> None: 24 | generator = lambda img, *args, **kwargs: img 25 | super().__init__(config, generator, image_loss_fn) 26 | -------------------------------------------------------------------------------- 
/src/modelinversion/configs/attack_config.py:
--------------------------------------------------------------------------------
  1 | from dataclasses import dataclass, field
  2 | from abc import ABC, abstractmethod
  3 | from typing import Optional
  4 | 
  5 | import torch
  6 | from torch import nn
  7 | from torch.utils.data import Dataset
  8 | 
  9 | from ..models import *
 10 | from ..sampler import *
 11 | from ..attack import *
 12 | from ..datasets import *
 13 | 
 14 | 
 15 | @dataclass
 16 | class BaseAttackConfig(ABC):
 17 |     experiment_dir: str
 18 |     generator_ckpt_path: str
 19 |     discriminator_ckpt_path: str
 20 |     target_model_ckpt_path: str
 21 |     eval_model_ckpt_path: str
 22 |     eval_dataset_path: str
 23 | 
 24 |     batch_size: int
 25 |     device: torch.device
 26 |     gpu_devices: list[int] = field(default_factory=list)
 27 |     attack_targets: list[int] = field(default_factory=list)
 28 | 
 29 |     def _parse_models(self):
 30 |         target_model = auto_classifier_from_pretrained(self.target_model_ckpt_path)
 31 |         eval_model = auto_classifier_from_pretrained(
 32 |             self.eval_model_ckpt_path, register_last_feature_hook=True
 33 |         )
 34 |         generator = auto_generator_from_pretrained(self.generator_ckpt_path)
 35 | 
 36 |         target_model = nn.DataParallel(target_model, device_ids=self.gpu_devices).to(self.device)
 37 |         eval_model = nn.DataParallel(eval_model, device_ids=self.gpu_devices).to(self.device)
 38 |         generator = nn.DataParallel(generator, device_ids=self.gpu_devices).to(self.device)
 39 |         target_model.eval()
 40 |         eval_model.eval()
 41 |         generator.eval()
 42 |         return target_model, eval_model, generator
 43 | 
 44 |     @abstractmethod
 45 |     def default_params(self):
 46 |         pass
 47 | 
 48 |     @abstractmethod
 49 |     def get_attacker(self):
 50 |         pass
 51 | 
 52 | 
 53 | @dataclass
 54 | class GmiAttackConfig(BaseAttackConfig):
 55 |     z_dim: int = 100
 56 |     optimize_num: int = 50
 57 | 
 58 |     def default_params(self):
 59 |         # prepare models
 60 | 
 61 |         self.latents_sampler = SimpleLatentsSampler(self.z_dim, self.batch_size)
 62 |         self.target_model, self.eval_model, self.generator = self._parse_models()
 63 |         discriminator = auto_discriminator_from_pretrained(self.discriminator_ckpt_path)
 64 |         discriminator = nn.DataParallel(discriminator, device_ids=self.gpu_devices).to(self.device)
 65 |         discriminator.eval()
 66 |         self.discriminator = discriminator
 67 | 
 68 |         # prepare optimization
 69 | 
 70 |         optimization_config = SimpleWhiteBoxOptimizationConfig(
 71 |             experiment_dir=self.experiment_dir,
 72 |             device=self.device,
 73 |             optimizer='SGD',
 74 |             optimizer_kwargs={'lr': 0.02, 'momentum': 0.9},
 75 |             iter_times=1500,
 76 |         )
 77 | 
 78 |         identity_loss_fn = ImageAugmentClassificationLoss(
 79 |             classifier=self.target_model, loss_fn='ce', create_aug_images_fn=None
 80 |         )
 81 | 
 82 |         discriminator_loss_fn = GmiDiscriminatorLoss(discriminator)
 83 | 
 84 |         loss_fn = ComposeImageLoss(
 85 |             [identity_loss_fn, discriminator_loss_fn], weights=[100, 1]
 86 |         )
 87 | 
 88 |         self.optimization_fn = SimpleWhiteBoxOptimization(
 89 |             optimization_config, self.generator, loss_fn
 90 |         )
 91 | 
 92 |     def get_attacker(
 93 |         self,
 94 |         save_optimized_images: bool = True,
 95 |         save_final_images: bool = False,
 96 |         eval_metrics: Optional[list] = None,
 97 |         eval_optimized_result: bool = True,
 98 |         eval_final_result: bool = False,
 99 |     ):
100 | 
101 |         # prepare attack
102 | 
103 |         attack_config = ImageClassifierAttackConfig(
104 |             # attack args
105 |             self.latents_sampler,
106 |             optimize_num=self.optimize_num,
107 |             optimize_batch_size=self.batch_size,
108 |             optimize_fn=self.optimization_fn,
109 |             # save path args
110 |             save_dir=self.experiment_dir,
111 |             save_optimized_images=save_optimized_images,
112 |             save_final_images=save_final_images,
113 |             # metric args
114 |             eval_metrics=eval_metrics if eval_metrics is not None else [],
115 |             eval_optimized_result=eval_optimized_result,
116 |             eval_final_result=eval_final_result,
117 |         )
118 | 
119 |         return ImageClassifierAttacker(attack_config)
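
if __name__ == '__main__':
    # Illustrative wiring of the config defined above (a sketch only): the
    # checkpoint and dataset paths are placeholders that must point at real
    # files before this can run end to end.
    demo_config = GmiAttackConfig(
        experiment_dir='./results/gmi_demo',
        generator_ckpt_path='<gmi_G.pt>',
        discriminator_ckpt_path='<gmi_D.pt>',
        target_model_ckpt_path='<target.pth>',
        eval_model_ckpt_path='<eval.pth>',
        eval_dataset_path='<private_dataset_dir>',
        batch_size=60,
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        gpu_devices=[0],
        attack_targets=list(range(10)),
    )
    demo_config.default_params()           # builds sampler, models and optimization
    attacker = demo_config.get_attacker()  # returns an ImageClassifierAttacker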
--------------------------------------------------------------------------------
/src/modelinversion/configs/classifier_config.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ffhibnese/Model-Inversion-Attack-ToolBox/d9ad48f997aa204e28696b87ce49a25384fcf794/src/modelinversion/configs/classifier_config.py
--------------------------------------------------------------------------------
/src/modelinversion/configs/gan_config.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ffhibnese/Model-Inversion-Attack-ToolBox/d9ad48f997aa204e28696b87ce49a25384fcf794/src/modelinversion/configs/gan_config.py
--------------------------------------------------------------------------------
/src/modelinversion/datasets/__init__.py:
--------------------------------------------------------------------------------
 1 | from .utils import (
 2 |     InfiniteSamplerWrapper,
 3 |     ClassSubset,
 4 |     top_k_selection,
 5 |     generator_generate_datasets,
 6 | )
 7 | from .generator import GeneratorDataset
 8 | from .base import LabelImageFolder
 9 | from .facescrub import (
10 |     FaceScrub,
11 |     preprocess_facescrub_fn,
12 |     FaceScrub64,
13 |     FaceScrub112,
14 |     FaceScrub224,
15 |     FaceScrub299,
16 | )
17 | from .celeba import (
18 |     CelebA,
19 |     preprocess_celeba_fn,
20 |     CelebA64,
21 |     CelebA112,
22 |     CelebA224,
23 |     CelebA299,
24 | )
25 | 
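# A typical composition of these exports: restrict a preprocessed CelebA to the
# identities under attack, e.g. (sketch only; `ClassSubset` is defined in
# .utils and its exact signature may differ from this):
#
#     private = ClassSubset(CelebA64('<celeba64_root>'), range(100))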
--------------------------------------------------------------------------------
/src/modelinversion/datasets/base.py:
--------------------------------------------------------------------------------
  1 | from typing import Any, Callable, Dict, List, Tuple, Optional
  2 | from torchvision.datasets import DatasetFolder
  3 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
  4 | import os
  5 | 
  6 | 
  7 | class LabelDatasetFolder(DatasetFolder):
  8 |     """A label data loader.
  9 | 
 10 |     The subfolders of the root should be named by their label number.
 11 | 
 12 |     Args:
 13 |         root (string): Root directory path.
 14 |         loader (callable): A function to load a sample given its path.
 15 |         extensions (tuple[string]): A list of allowed extensions.
 16 |             Both extensions and is_valid_file should not be passed.
 17 |         transform (callable, optional): A function/transform that takes in
 18 |             a sample and returns a transformed version.
 19 |             E.g, ``transforms.RandomCrop`` for images.
 20 |         target_transform (callable, optional): A function/transform that takes
 21 |             in the target and transforms it.
 22 |         is_valid_file (callable, optional): A function that takes the path of a file
 23 |             and checks if the file is a valid file (used to check for corrupt files).
 24 |             Both extensions and is_valid_file should not be passed.
 25 | 
 26 |     Attributes:
 27 |         classes (list): List of the class names sorted numerically by label.
 28 |         class_to_idx (dict): Dict with items (class_name, class_index).
 29 |         samples (list): List of (sample path, class_index) tuples
 30 |         targets (list): The class_index value for each image in the dataset
 31 |     """
 32 | 
 33 |     def __init__(
 34 |         self,
 35 |         root: str,
 36 |         loader: Callable[[str], Any],
 37 |         extensions: Tuple[str, ...] | None = None,
 38 |         transform: Callable[..., Any] | None = None,
 39 |         target_transform: Callable[..., Any] | None = None,
 40 |         is_valid_file: Callable[[str], bool] | None = None,
 41 |     ) -> None:
 42 |         super().__init__(
 43 |             root, loader, extensions, transform, target_transform, is_valid_file
 44 |         )
 45 | 
 46 |         classes, class_to_idx = self.find_classes(self.root)
 47 |         samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
 48 | 
 49 |         self.loader = loader
 50 |         self.extensions = extensions
 51 | 
 52 |         self.classes = classes
 53 |         self.class_to_idx = class_to_idx
 54 |         self.samples = samples
 55 |         self.targets = [s[1] for s in samples]
 56 | 
 57 |     def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
 58 |         classes = sorted(
 59 |             (
 60 |                 entry.name
 61 |                 for entry in os.scandir(directory)
 62 |                 # only purely numeric folder names encode valid labels
 63 |                 if entry.is_dir() and entry.name.isdigit()
 64 |             ),
 65 |             key=lambda x: int(x),
 66 |         )
 67 |         if not classes:
 68 |             raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
 69 | 
 70 |         class_to_idx = {cls_name: int(cls_name) for cls_name in classes}
 71 |         return classes, class_to_idx
 72 | 
 73 | 
 74 | class LabelImageFolder(LabelDatasetFolder):
 75 |     """A generic data loader where the images are arranged in this way by default: ::
 76 | 
 77 |         root/0/xxx.png
 78 |         root/0/xxy.png
 79 | 
 80 |         root/1/123.png
 81 |         root/1/nsdf3.png
 82 | 
 83 |     This class inherits from :class:`LabelDatasetFolder` so
 84 |     the same methods can be overridden to customize the dataset.
 85 | 
 86 |     Args:
 87 |         root (string): Root directory path.
 88 |         transform (callable, optional): A function/transform that takes in a PIL image
 89 |             and returns a transformed version. E.g, ``transforms.RandomCrop``
 90 |         target_transform (callable, optional): A function/transform that takes in the
 91 |             target and transforms it.
 92 |         loader (callable, optional): A function to load an image given its path.
 93 |         is_valid_file (callable, optional): A function that takes the path of an image file
 94 |             and checks if the file is a valid file (used to check for corrupt files).
 95 | 
 96 |     Attributes:
 97 |         classes (list): List of the class names sorted numerically by label.
 98 |         class_to_idx (dict): Dict with items (class_name, class_index).
 99 |         imgs (list): List of (image path, class_index) tuples
100 |     """
101 | 
102 |     def __init__(
103 |         self,
104 |         root: str,
105 |         transform: Optional[Callable] = None,
106 |         target_transform: Optional[Callable] = None,
107 |         loader: Callable[[str], Any] = default_loader,
108 |         is_valid_file: Optional[Callable[[str], bool]] = None,
109 |     ):
110 |         super().__init__(
111 |             root,
112 |             loader,
113 |             IMG_EXTENSIONS if is_valid_file is None else None,
114 |             transform=transform,
115 |             target_transform=target_transform,
116 |             is_valid_file=is_valid_file,
117 |         )
118 |         self.imgs = self.samples
--------------------------------------------------------------------------------
/src/modelinversion/datasets/celeba.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | 
  3 | import numpy as np
  4 | from torch.utils.data import ConcatDataset, Dataset, Subset
  5 | from torchvision.datasets import ImageFolder
  6 | import torchvision.transforms as TF
  7 | from .base import LabelImageFolder
  8 | 
  9 | 
 10 | def preprocess_celeba_fn(crop_center, output_resolution):
 11 |     """
 12 |     Do transformations to CelebA dataset.
13 | Support: center crop 14 | """ 15 | if crop_center: 16 | crop_size = 108 17 | return TF.Compose( 18 | [ 19 | TF.CenterCrop((crop_size, crop_size)), 20 | TF.Resize((output_resolution, output_resolution), antialias=True), 21 | ] 22 | ) 23 | else: 24 | return TF.Resize((output_resolution, output_resolution)) 25 | 26 | 27 | class CelebA(Dataset): 28 | 29 | def __init__( 30 | self, 31 | root_path, 32 | crop_center=False, 33 | preprocess_resolution=224, 34 | transform=None, 35 | ): 36 | 37 | self.preprocess_transform = preprocess_celeba_fn( 38 | crop_center, preprocess_resolution 39 | ) 40 | 41 | self.dataset = LabelImageFolder( 42 | root=root_path, transform=self.preprocess_transform 43 | ) 44 | self.name = 'CelebA' 45 | 46 | self.transform = transform 47 | self.targets = self.dataset.targets 48 | 49 | def __len__(self): 50 | return len(self.dataset) 51 | 52 | def __getitem__(self, idx): 53 | im, target = self.dataset[idx] 54 | if self.transform: 55 | return self.transform(im), target 56 | else: 57 | return im, target 58 | 59 | 60 | class CelebA64(CelebA): 61 | 62 | def __init__(self, root_path, output_transform=None): 63 | super().__init__(root_path, True, 64, output_transform) 64 | 65 | 66 | class CelebA112(CelebA): 67 | 68 | def __init__(self, root_path, output_transform=None): 69 | super().__init__(root_path, True, 112, output_transform) 70 | 71 | 72 | class CelebA224(CelebA): 73 | 74 | def __init__(self, root_path, output_transform=None): 75 | super().__init__(root_path, False, 224, output_transform) 76 | 77 | 78 | class CelebA299(CelebA): 79 | 80 | def __init__(self, root_path, output_transform=None): 81 | super().__init__(root_path, False, 299, output_transform) 82 | -------------------------------------------------------------------------------- /src/modelinversion/datasets/facescrub.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from torch.utils.data import ConcatDataset, Dataset, Subset 5 | from torchvision.datasets import ImageFolder 6 | import torchvision.transforms as TF 7 | 8 | 9 | def preprocess_facescrub_fn(crop_center, output_resolution): 10 | if crop_center: 11 | crop_size = int(54 * output_resolution / 64) 12 | return TF.Compose( 13 | [ 14 | TF.Resize((output_resolution, output_resolution), antialias=True), 15 | TF.CenterCrop((crop_size, crop_size)), 16 | TF.Resize((output_resolution, output_resolution), antialias=True), 17 | ] 18 | ) 19 | else: 20 | return TF.Resize((output_resolution, output_resolution)) 21 | 22 | 23 | class FaceScrub(Dataset): 24 | 25 | def __init__( 26 | self, 27 | root_path, 28 | train=False, 29 | crop_center=False, 30 | preprocess_resolution=224, 31 | transform=None, 32 | ): 33 | 34 | split_seed = 42 35 | root_actors = os.path.join(root_path, 'actors/faces') 36 | root_actresses = os.path.join(root_path, 'actresses/faces') 37 | dataset_actors = ImageFolder(root=root_actors, transform=None) 38 | target_transform_actresses = lambda x: x + len(dataset_actors.classes) 39 | dataset_actresses = ImageFolder( 40 | root=root_actresses, 41 | transform=None, 42 | target_transform=target_transform_actresses, 43 | ) 44 | dataset_actresses.class_to_idx = { 45 | key: value + len(dataset_actors.classes) 46 | for key, value in dataset_actresses.class_to_idx.items() 47 | } 48 | self.dataset = ConcatDataset([dataset_actors, dataset_actresses]) 49 | self.classes = dataset_actors.classes + dataset_actresses.classes 50 | self.class_to_idx = { 51 | **dataset_actors.class_to_idx, 52 | 
**dataset_actresses.class_to_idx, 53 | } 54 | self.targets = dataset_actors.targets + [ 55 | t + len(dataset_actors.classes) for t in dataset_actresses.targets 56 | ] 57 | self.name = 'facescrub_all' 58 | 59 | self.transform = transform 60 | self.preprocess_transform = preprocess_facescrub_fn( 61 | crop_center, preprocess_resolution 62 | ) 63 | 64 | indices = list(range(len(self.dataset))) 65 | np.random.seed(split_seed) 66 | np.random.shuffle(indices) 67 | training_set_size = int(0.9 * len(self.dataset)) 68 | train_idx = indices[:training_set_size] 69 | test_idx = indices[training_set_size:] 70 | 71 | # print(indices.__len__(), len(self.targets)) 72 | 73 | if train: 74 | self.dataset = Subset(self.dataset, train_idx) 75 | self.targets = np.array(self.targets)[train_idx].tolist() 76 | else: 77 | self.dataset = Subset(self.dataset, test_idx) 78 | self.targets = np.array(self.targets)[test_idx].tolist() 79 | 80 | def __len__(self): 81 | return len(self.dataset) 82 | 83 | def __getitem__(self, idx): 84 | im, _ = self.dataset[idx] 85 | im = self.preprocess_transform(im) 86 | if self.transform: 87 | return self.transform(im), self.targets[idx] 88 | else: 89 | return im, self.targets[idx] 90 | 91 | 92 | class FaceScrub64(FaceScrub): 93 | 94 | def __init__( 95 | self, 96 | root_path, 97 | train=True, 98 | output_transform=None, 99 | ): 100 | super().__init__(root_path, train, True, 64, output_transform) 101 | 102 | 103 | class FaceScrub112(FaceScrub): 104 | 105 | def __init__( 106 | self, 107 | root_path, 108 | train=True, 109 | output_transform=None, 110 | ): 111 | super().__init__(root_path, train, True, 112, output_transform) 112 | 113 | 114 | class FaceScrub224(FaceScrub): 115 | 116 | def __init__( 117 | self, 118 | root_path, 119 | train=True, 120 | output_transform=None, 121 | ): 122 | super().__init__(root_path, train, False, 224, output_transform) 123 | 124 | 125 | class FaceScrub299(FaceScrub): 126 | 127 | def __init__( 128 | self, 129 | root_path, 130 | train=True, 131 | output_transform=None, 132 | ): 133 | super().__init__(root_path, train, False, 299, output_transform) 134 | -------------------------------------------------------------------------------- /src/modelinversion/datasets/ffhq.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Callable, Any 3 | 4 | import numpy as np 5 | from torch.utils.data import ConcatDataset, Dataset, Subset 6 | from torchvision.datasets import ImageFolder 7 | from torchvision.datasets.folder import default_loader 8 | import torchvision.transforms as TF 9 | 10 | 11 | def preprocess_ffhq_fn(crop_center_size, output_resolution): 12 | if crop_center_size is not None: 13 | return TF.Compose( 14 | [ 15 | TF.CenterCrop((crop_center_size, crop_center_size)), 16 | TF.Resize((output_resolution, output_resolution), antialias=True), 17 | ] 18 | ) 19 | else: 20 | return TF.Resize((output_resolution, output_resolution)) 21 | 22 | 23 | class FFHQ(ImageFolder): 24 | 25 | def __init__( 26 | self, 27 | root_path: str, 28 | crop_center_size: Optional[int] = 800, 29 | preprocess_resolution: int = 224, 30 | output_transform: Callable[..., Any] | None = None, 31 | ): 32 | preprocess_transform = preprocess_ffhq_fn( 33 | crop_center_size, preprocess_resolution 34 | ) 35 | transform = ( 36 | preprocess_transform 37 | if output_transform is None 38 | else TF.Compose([preprocess_transform, output_transform]) 39 | ) 40 | super().__init__(root_path, transform) 41 | 42 | 43 | class FFHQ64(FFHQ): 44 | 
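    # e.g. FFHQ64('<ffhq_root>', output_transform=ToTensor()) mirrors how the
    # 64x64 GAN-training example scripts build their public datasets.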
45 | def __init__( 46 | self, root_path: str, output_transform: Callable[..., Any] | None = None 47 | ): 48 | 49 | super().__init__(root_path, 88, 64, output_transform) 50 | 51 | 52 | class FFHQ256(FFHQ): 53 | 54 | def __init__( 55 | self, root_path: str, output_transform: Callable[..., Any] | None = None 56 | ): 57 | 58 | super().__init__(root_path, 800, 256, output_transform) 59 | -------------------------------------------------------------------------------- /src/modelinversion/datasets/generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Sequence, Callable, Optional 3 | 4 | import torch 5 | from torch.utils.data import TensorDataset, DataLoader 6 | 7 | from ..utils import batch_apply 8 | 9 | 10 | class GeneratorDataset(TensorDataset): 11 | 12 | def __init__(self, z, y, pseudo_y, generator, device, transform=None) -> None: 13 | super().__init__(z, y, pseudo_y) 14 | self.generator = generator 15 | self.device = device 16 | self.transform = transform 17 | 18 | def __getitem__(self, index): 19 | return super().__getitem__(index) 20 | 21 | @classmethod 22 | def create( 23 | cls, 24 | input_shape: int | Sequence[int], 25 | num_classes: int, 26 | generate_num_per_class: int, 27 | generator, 28 | target_model, 29 | batch_size, 30 | device: torch.device, 31 | gan_to_target_transform: Optional[Callable] = None, 32 | ): 33 | labels = torch.arange(0, num_classes, dtype=torch.long).repeat_interleave( 34 | generate_num_per_class 35 | ) 36 | 37 | if isinstance(input_shape, int): 38 | input_shape = (input_shape,) 39 | 40 | @torch.no_grad() 41 | def generation(labels): 42 | shape = (len(labels), *input_shape) 43 | pseudo_y = labels.to(device) 44 | z = torch.randn(shape, device=device) 45 | imgs = generator(z, labels=pseudo_y) 46 | if gan_to_target_transform is not None: 47 | imgs = gan_to_target_transform(imgs) 48 | y = target_model(imgs)[0].argmax(dim=-1).detach().cpu() 49 | return z.detach().cpu(), y, pseudo_y.detach().cpu() 50 | 51 | z, y, pseudo_y = batch_apply( 52 | generation, labels, batch_size=batch_size, use_tqdm=True 53 | ) 54 | 55 | return cls(z, y, pseudo_y, generator, device, gan_to_target_transform) 56 | 57 | @classmethod 58 | def from_precreate( 59 | cls, save_path, generator, device, transform=None 60 | ) -> "GeneratorDataset": 61 | tensors = torch.load(save_path) 62 | return cls(*tensors, generator, device, transform) 63 | 64 | def save(self, save_path): 65 | save_dir, _ = os.path.split(save_path) 66 | os.makedirs(save_dir, exist_ok=True) 67 | torch.save(self.tensors, save_path) 68 | 69 | @torch.no_grad() 70 | def collate_fn(self, data): 71 | z, y, pseudo_y = zip(*data) 72 | z = torch.stack(z, dim=0).to(self.device) 73 | y = torch.stack(y, dim=0) 74 | pseudo_y = torch.stack(pseudo_y, dim=0).to(self.device) 75 | images = self.generator(z, labels=pseudo_y).detach().cpu() 76 | if self.transform is not None: 77 | images = self.transform(images) 78 | return images, y 79 | -------------------------------------------------------------------------------- /src/modelinversion/defense/BiDO/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import BiDOTrainArgs, BiDOTrainer 2 | -------------------------------------------------------------------------------- /src/modelinversion/defense/BiDO/kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def distmat(X): 6 | 
"""distance matrix""" 7 | assert X.ndim == 2 8 | r = torch.sum(X * X, dim=1, keepdim=True) 9 | # r = r.view([-1, 1]) 10 | a = torch.mm(X, torch.transpose(X, 0, 1)) 11 | D = r.expand_as(a) - 2 * a + torch.transpose(r, 0, 1).expand_as(a) 12 | D = torch.abs(D) 13 | return D 14 | 15 | 16 | def sigma_estimation(X, Y): 17 | """sigma from median distance""" 18 | D = distmat(torch.cat([X, Y])) 19 | D = D.detach().cpu().numpy() 20 | Itri = np.tril_indices(D.shape[0], -1) 21 | Tri = D[Itri] 22 | med = np.median(Tri) 23 | if med <= 0: 24 | med = np.mean(Tri) 25 | if med < 1e-2: 26 | med = 1e-2 27 | return med 28 | 29 | 30 | def hisc_kernelmat(X, sigma, ktype='gaussian'): 31 | """kernel matrix baker""" 32 | m = int(X.size()[0]) 33 | H = torch.eye(m) - (1.0 / m) * torch.ones([m, m]) 34 | 35 | if ktype == "gaussian": 36 | Dxx = distmat(X) 37 | 38 | if sigma: 39 | variance = 2.0 * sigma * sigma * X.size()[1] 40 | Kx = torch.exp(-Dxx / variance).type(torch.FloatTensor) # kernel matrices 41 | # print(sigma, torch.mean(Kx), torch.max(Kx), torch.min(Kx)) 42 | else: 43 | try: 44 | sx = sigma_estimation(X, X) 45 | Kx = torch.exp(-Dxx / (2.0 * sx * sx)).type(torch.FloatTensor) 46 | except RuntimeError as e: 47 | raise RuntimeError( 48 | "Unstable sigma {} with maximum/minimum input ({},{})".format( 49 | sx, torch.max(X), torch.min(X) 50 | ) 51 | ) 52 | 53 | elif ktype == "linear": 54 | Kx = torch.mm(X, X.T).type(torch.FloatTensor) 55 | 56 | elif ktype == 'IMQ': 57 | Dxx = distmat(X) 58 | Kx = 1 * torch.rsqrt(Dxx + 1) 59 | 60 | Kxc = torch.mm(Kx, H) 61 | 62 | return Kxc 63 | 64 | 65 | def hsic_normalized_cca(x, y, sigma, ktype='gaussian'): 66 | m = int(x.size()[0]) 67 | Kxc = hisc_kernelmat(x, sigma=sigma) 68 | Kyc = hisc_kernelmat(y, sigma=sigma, ktype=ktype) 69 | 70 | epsilon = 1e-5 71 | K_I = torch.eye(m) 72 | Kxc_i = torch.inverse(Kxc + epsilon * m * K_I) 73 | Kyc_i = torch.inverse(Kyc + epsilon * m * K_I) 74 | Rx = Kxc.mm(Kxc_i) 75 | Ry = Kyc.mm(Kyc_i) 76 | Pxy = torch.sum(torch.mul(Rx, Ry.t())) 77 | 78 | return Pxy 79 | 80 | 81 | def hsic_objective(hidden, h_target, h_data, sigma, ktype='gaussian'): 82 | hsic_hx_val = hsic_normalized_cca(hidden, h_data, sigma=sigma) 83 | hsic_hy_val = hsic_normalized_cca(hidden, h_target, sigma=sigma, ktype=ktype) 84 | 85 | return hsic_hx_val, hsic_hy_val 86 | 87 | 88 | def coco_kernelmat(X, sigma, ktype='gaussian'): 89 | """kernel matrix baker""" 90 | m = int(X.size()[0]) 91 | H = torch.eye(m) - (1.0 / m) * torch.ones([m, m]) 92 | 93 | if ktype == "gaussian": 94 | Dxx = distmat(X) 95 | 96 | if sigma: 97 | variance = 2.0 * sigma * sigma * X.size()[1] 98 | Kx = torch.exp(-Dxx / variance).type(torch.FloatTensor) # kernel matrices 99 | # print(sigma, torch.mean(Kx), torch.max(Kx), torch.min(Kx)) 100 | else: 101 | try: 102 | sx = sigma_estimation(X, X) 103 | Kx = torch.exp(-Dxx / (2.0 * sx * sx)).type(torch.FloatTensor) 104 | except RuntimeError as e: 105 | raise RuntimeError( 106 | "Unstable sigma {} with maximum/minimum input ({},{})".format( 107 | sx, torch.max(X), torch.min(X) 108 | ) 109 | ) 110 | 111 | ## Adding linear kernel 112 | elif ktype == "linear": 113 | Kx = torch.mm(X, X.T).type(torch.FloatTensor) 114 | 115 | elif ktype == 'IMQ': 116 | Dxx = distmat(X) 117 | Kx = 1 * torch.rsqrt(Dxx + 1) 118 | 119 | Kxc = torch.mm(H, torch.mm(Kx, H)) 120 | 121 | return Kxc 122 | 123 | 124 | def coco_normalized_cca(x, y, sigma, ktype='gaussian'): 125 | m = int(x.size()[0]) 126 | K = coco_kernelmat(x, sigma=sigma) 127 | L = coco_kernelmat(y, sigma=sigma, ktype=ktype) 128 | 129 | 
res = torch.sqrt(torch.norm(torch.mm(K, L))) / m 130 | return res 131 | 132 | 133 | def coco_objective(hidden, h_target, h_data, sigma, ktype='gaussian'): 134 | coco_hx_val = coco_normalized_cca(hidden, h_data, sigma=sigma) 135 | coco_hy_val = coco_normalized_cca(hidden, h_target, sigma=sigma, ktype=ktype) 136 | 137 | return coco_hx_val, coco_hy_val 138 | -------------------------------------------------------------------------------- /src/modelinversion/defense/BiDO/trainer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import torch 4 | from torch import LongTensor 5 | from torch.nn import Module, MaxPool2d, Sequential 6 | from torch.optim import Optimizer 7 | from torch.optim.lr_scheduler import LRScheduler 8 | from torch.nn import functional as F 9 | 10 | from ...models import ModelResult 11 | from ...models.classifiers import BaseTargetModel 12 | from ...utils import traverse_module, OutputHook, BaseHook 13 | from ...foldermanager import FolderManager 14 | from ..base import BaseTrainArgs, BaseTrainer 15 | from ...models.get_models import NUM_CLASSES 16 | from .kernel import hsic_objective, coco_objective 17 | 18 | 19 | @dataclass 20 | class BiDOTrainArgs(BaseTrainArgs): 21 | 22 | kernel_type: str = field( 23 | default='linear', metadata={'help': 'kernel type: linear, gaussian, IMQ'} 24 | ) 25 | 26 | bido_loss_type: str = field( 27 | default='hisc', metadata={'help': 'loss type: hisc, coco'} 28 | ) 29 | 30 | coef_hidden_input: float = field( 31 | default=0.05, metadata={'help': 'coef of loss between hidden and input'} 32 | ) 33 | coef_hidden_output: float = field( 34 | default=0.5, metadata={'help': 'coef of loss between hidden and output'} 35 | ) 36 | 37 | 38 | class BiDOTrainer(BaseTrainer): 39 | 40 | def __init__( 41 | self, 42 | args: BiDOTrainArgs, 43 | folder_manager: FolderManager, 44 | model: BaseTargetModel, 45 | optimizer: Optimizer, 46 | lr_scheduler: LRScheduler = None, 47 | **kwargs, 48 | ) -> None: 49 | super().__init__(args, folder_manager, model, optimizer, lr_scheduler, **kwargs) 50 | 51 | self.hiddens_hooks: list[BaseHook] = [] 52 | 53 | if self.args.bido_loss_type == 'hisc': 54 | self.objective_fn = hsic_objective 55 | elif self.args.bido_loss_type == 'coco': 56 | self.objective_fn = coco_objective 57 | else: 58 | raise RuntimeError( 59 | f'loss type `{self.args.bido_loss_type}` is not supported, valid loss types: `hisc` and `coco`' 60 | ) 61 | 62 | def _to_onehot(self, y, num_classes): 63 | """1-hot encodes a tensor""" 64 | # return torch.squeeze(torch.eye(num_classes)[y.cpu()], dim=1) 65 | return ( 66 | torch.zeros((len(y), num_classes)) 67 | .to(self.args.device) 68 | .scatter_(1, y.reshape(-1, 1), 1.0) 69 | ) 70 | 71 | # def _add_hook(self, module: Module): 72 | # if self.args.model_name == 'vgg16': 73 | # if isinstance(module, MaxPool2d): 74 | # self.hiddens_hooks.append(OutputHook(module)) 75 | # elif self.args.model_name in ['ir152', 'facenet64', 'facenet']: 76 | # if isinstance(module, Sequential): 77 | # self.hiddens_hooks.append(OutputHook(module)) 78 | # else: 79 | # raise RuntimeError(f'model {self.args.model_name} is not support for BiDO') 80 | 81 | def before_train(self): 82 | super().before_train() 83 | self.hiddens_hooks.clear() 84 | # traverse_module(self.model, self._add_hook, call_middle=True) 85 | self.hiddens_hooks.extend(self.model.create_hidden_hooks()) 86 | assert len(self.hiddens_hooks) > 0 87 | 88 | # print(f'hook num: {len(self.hiddens_hooks)}') 89 | 
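# The BiDO objective assembled in calc_loss below is: cross_entropy + coef_hidden_input * d(hidden, input) - coef_hidden_output * d(hidden, one_hot_label), summed over every hooked hidden layer, where d is the HSIC or COCO dependency measure selected by bido_loss_type.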
90 | def after_train(self): 91 | super().after_train() 92 | for hook in self.hiddens_hooks: 93 | hook.close() 94 | 95 | def calc_loss(self, inputs: torch.Tensor, result: ModelResult, labels: LongTensor): 96 | res = result.result 97 | bs = len(inputs) 98 | 99 | total_loss = 0 100 | cross_loss = F.cross_entropy(res, labels) 101 | 102 | total_loss += cross_loss 103 | 104 | h_data = inputs.view(bs, -1) 105 | h_label = ( 106 | self._to_onehot(labels, NUM_CLASSES[self.args.dataset_name]) 107 | .to(self.args.device) 108 | .view(bs, -1) 109 | ) 110 | 111 | for hidden_hook in self.hiddens_hooks: 112 | h_hidden = hidden_hook.get_feature().reshape(bs, -1) 113 | 114 | hidden_input_loss, hidden_output_loss = self.objective_fn( 115 | h_hidden, h_label, h_data, 5.0, self.args.kernel_type 116 | ) 117 | 118 | total_loss += self.args.coef_hidden_input * hidden_input_loss 119 | total_loss += -self.args.coef_hidden_output * hidden_output_loss 120 | 121 | return total_loss 122 | -------------------------------------------------------------------------------- /src/modelinversion/defense/DP/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ffhibnese/Model-Inversion-Attack-ToolBox/d9ad48f997aa204e28696b87ce49a25384fcf794/src/modelinversion/defense/DP/__init__.py -------------------------------------------------------------------------------- /src/modelinversion/defense/DP/trainer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from torch import LongTensor 5 | from torch.nn import Module 6 | import torch.nn.functional as F 7 | from torch.optim import Optimizer 8 | from torch.optim.lr_scheduler import LRScheduler 9 | 10 | from ..base import BaseTrainer, BaseTrainArgs 11 | from ...models import ModelResult 12 | from ...foldermanager import FolderManager 13 | 14 | 15 | @dataclass 16 | class DPTrainArgs(BaseTrainArgs): 17 | 18 | noise_multiplier: float = 0.01 19 | microbatch_size: int = 1 20 | 21 | 22 | class DPTrainer(BaseTrainer): 23 | 24 | def __init__( 25 | self, 26 | args: DPTrainArgs, 27 | folder_manager: FolderManager, 28 | model: Module, 29 | optimizer: Optimizer, 30 | scheduler: LRScheduler = None, 31 | **kwargs 32 | ) -> None: 33 | super().__init__(args, folder_manager, model, optimizer, scheduler, **kwargs) 34 | self.args: DPTrainArgs 35 | 36 | def calc_loss(self, inputs, result: ModelResult, labels: LongTensor): 37 | pred_res = result.result 38 | # pred_res = F.softmax(pred_res, dim=1) 39 | # return self.criterion(pred_res, labels) 40 | return F.cross_entropy(pred_res, labels, reduction='none') 41 | 42 | def before_train(self): 43 | super().before_train() 44 | # self.avg_norm = 0 45 | 46 | def _update_step(self, loss): 47 | 48 | # loss.backward() 49 | bs = len(loss) 50 | 51 | parameters = [param for param in self.model.parameters() if param.requires_grad] 52 | 53 | grad = [torch.zeros_like(param) for param in parameters] 54 | num_microbatch = (bs - 1) // self.args.microbatch_size + 1 55 | 56 | max_norm = self.args.clip_grad_norm 57 | 58 | # print(len(list(range(0, bs, num_microbatch)))) 59 | # exit() 60 | for j in range(0, bs, self.args.microbatch_size): 61 | self.optimizer.zero_grad() 62 | torch.autograd.backward( 63 | torch.mean(loss[j : min(j + self.args.microbatch_size, bs)]), 64 | retain_graph=True, 65 | ) 66 | 67 | l2norm = 0.0 68 | for param in parameters: 69 | l2norm += (param.grad * param.grad).sum() 70 | l2norm = 
torch.sqrt(l2norm) 71 | 72 | # self.avg_norm = self.avg_norm * 0.95 + l2norm * 0.05 73 | 74 | coef = 1 if max_norm is None else (max_norm / max(max_norm, l2norm.item())) 75 | grad = [g + param.grad * coef for param, g in zip(parameters, grad)] 76 | 77 | if max_norm is None: 78 | max_norm = 1.0 79 | 80 | for param, g in zip(parameters, grad): 81 | param.grad.data = g 82 | if self.args.noise_multiplier > 0: 83 | param.grad.data += ( 84 | torch.cuda.FloatTensor(g.size()) 85 | .normal_(0, self.args.noise_multiplier * float(max_norm)) 86 | .to(self.args.device) 87 | ) # torch.randn_like(g) * self.args.noise_multiplier * max_norm 88 | param.grad.data /= num_microbatch 89 | 90 | self.optimizer.step() 91 | -------------------------------------------------------------------------------- /src/modelinversion/defense/LS/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ffhibnese/Model-Inversion-Attack-ToolBox/d9ad48f997aa204e28696b87ce49a25384fcf794/src/modelinversion/defense/LS/__init__.py -------------------------------------------------------------------------------- /src/modelinversion/defense/LS/trainer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import torch 4 | from torch import LongTensor 5 | from torch.nn import Module, MaxPool2d, Sequential 6 | from torch.optim import Optimizer 7 | from torch.optim.lr_scheduler import LRScheduler 8 | from torch.nn import functional as F 9 | 10 | from ...models import ModelResult 11 | from ...models.classifiers import BaseTargetModel 12 | from ...utils import traverse_module, OutputHook, BaseHook 13 | from ...foldermanager import FolderManager 14 | from ..base import BaseTrainArgs, BaseTrainer 15 | 16 | 17 | @dataclass 18 | class LSTrainArgs(BaseTrainArgs): 19 | 20 | coef_label_smoothing: float = 0.1 21 | 22 | 23 | class LSTrainer(BaseTrainer): 24 | 25 | def __init__( 26 | self, 27 | args: LSTrainArgs, 28 | folder_manager: FolderManager, 29 | model: BaseTargetModel, 30 | optimizer: Optimizer, 31 | lr_scheduler: LRScheduler = None, 32 | **kwargs 33 | ) -> None: 34 | super().__init__(args, folder_manager, model, optimizer, lr_scheduler, **kwargs) 35 | 36 | def _neg_label_smoothing(self, inputs, labels): 37 | ls = self.args.coef_label_smoothing 38 | confidence = 1.0 - ls 39 | logprobs = F.log_softmax(inputs, dim=-1) 40 | nll_loss = -logprobs.gather(dim=-1, index=labels.unsqueeze(1)) 41 | nll_loss = nll_loss.squeeze(1) 42 | smooth_loss = -logprobs.mean(dim=-1) 43 | loss = confidence * nll_loss + ls * smooth_loss 44 | return torch.mean(loss, dim=0).sum() 45 | 46 | def calc_loss(self, inputs: torch.Tensor, result: ModelResult, labels: LongTensor): 47 | res = result.result 48 | bs = len(inputs) 49 | 50 | return self._neg_label_smoothing(res, labels) 51 | -------------------------------------------------------------------------------- /src/modelinversion/defense/README.md: -------------------------------------------------------------------------------- 1 | # ModelInversionAttackBox 2 | Defense algorithms. 
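Each defense here subclasses the shared `BaseTrainer`: BiDO, LS, Vib, and distillation only override `calc_loss`, while DP additionally replaces the gradient update step. Below is a minimal, hedged sketch of wiring up the label-smoothing (LS) defense; `model`, `optimizer`, and `folder_manager` are placeholders you prepare yourself, and `LSTrainArgs` may require further fields inherited from `BaseTrainArgs` that are not shown here.

```python
from modelinversion.defense import LSTrainArgs, LSTrainer

# model, optimizer and folder_manager are assumed to be prepared beforehand,
# e.g. following the scripts under examples/standard/classifier_training.
args = LSTrainArgs(coef_label_smoothing=0.1)  # plus any required BaseTrainArgs fields
trainer = LSTrainer(args, folder_manager, model, optimizer, lr_scheduler=None)
# the training loop entry point is inherited from BaseTrainer
```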
3 | -------------------------------------------------------------------------------- /src/modelinversion/defense/TL/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ffhibnese/Model-Inversion-Attack-ToolBox/d9ad48f997aa204e28696b87ce49a25384fcf794/src/modelinversion/defense/TL/__init__.py -------------------------------------------------------------------------------- /src/modelinversion/defense/TL/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import LongTensor 7 | from torch.nn import Module 8 | from torch.optim import Optimizer 9 | from torch.optim.lr_scheduler import LRScheduler 10 | 11 | from ..base import BaseTrainer, BaseTrainArgs 12 | from ..BiDO.trainer import BiDOTrainer, BiDOTrainArgs 13 | from ...models import ModelResult 14 | from ...foldermanager import FolderManager 15 | 16 | 17 | @dataclass 18 | class TLTrainArgs(BiDOTrainArgs): 19 | pass 20 | 21 | 22 | class TLTrainer(BiDOTrainer): 23 | 24 | def __init__( 25 | self, 26 | args: TLTrainArgs, 27 | folder_manager: FolderManager, 28 | model: Module, 29 | optimizer: Optimizer, 30 | scheduler: LRScheduler = None, 31 | **kwargs 32 | ) -> None: 33 | super().__init__(args, folder_manager, model, optimizer, scheduler, **kwargs) 34 | 35 | def before_train_step(self): 36 | super().before_train_step() 37 | self.model.freeze_front_layers() 38 | -------------------------------------------------------------------------------- /src/modelinversion/defense/Vib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ffhibnese/Model-Inversion-Attack-ToolBox/d9ad48f997aa204e28696b87ce49a25384fcf794/src/modelinversion/defense/Vib/__init__.py -------------------------------------------------------------------------------- /src/modelinversion/defense/Vib/trainer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch.nn.functional as F 4 | from torch import LongTensor 5 | from torch.nn import Module 6 | from torch.optim import Optimizer 7 | from torch.optim.lr_scheduler import LRScheduler 8 | 9 | from ..base import BaseTrainer, BaseTrainArgs 10 | from ...models import ModelResult 11 | from ...foldermanager import FolderManager 12 | 13 | 14 | @dataclass 15 | class VibTrainArgs(BaseTrainArgs): 16 | beta: float = 1e-2 17 | 18 | 19 | class VibTrainer(BaseTrainer): 20 | 21 | def __init__( 22 | self, 23 | args: BaseTrainArgs, 24 | folder_manager: FolderManager, 25 | model: Module, 26 | optimizer: Optimizer, 27 | scheduler: LRScheduler = None, 28 | **kwargs 29 | ) -> None: 30 | super().__init__(args, folder_manager, model, optimizer, scheduler, **kwargs) 31 | 32 | def calc_loss(self, inputs, result: ModelResult, labels: LongTensor): 33 | pred_res = result.result 34 | mu = result.addition_info['mu'] 35 | std = result.addition_info['std'] 36 | cross_loss = F.cross_entropy(pred_res, labels) 37 | info_loss = ( 38 | -0.5 * (1 + 2 * std.log() - mu.pow(2) - std.pow(2)).sum(dim=1).mean() 39 | ) 40 | loss = cross_loss + self.args.beta * info_loss 41 | return loss 42 | -------------------------------------------------------------------------------- /src/modelinversion/defense/__init__.py: 
-------------------------------------------------------------------------------- 1 | from ..trainer import BaseTrainArgs, TqdmStrategy 2 | from .BiDO import BiDOTrainArgs, BiDOTrainer 3 | from .no_defense.trainer import RegTrainer 4 | from .Vib.trainer import VibTrainer, VibTrainArgs 5 | from .TL.trainer import TLTrainArgs, TLTrainer 6 | from .DP.trainer import DPTrainArgs, DPTrainer 7 | from .LS.trainer import LSTrainArgs, LSTrainer 8 | -------------------------------------------------------------------------------- /src/modelinversion/defense/base.py: -------------------------------------------------------------------------------- 1 | from ..trainer import BaseTrainArgs, BaseTrainer, TrainStepResult 2 | -------------------------------------------------------------------------------- /src/modelinversion/defense/distill/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ffhibnese/Model-Inversion-Attack-ToolBox/d9ad48f997aa204e28696b87ce49a25384fcf794/src/modelinversion/defense/distill/__init__.py -------------------------------------------------------------------------------- /src/modelinversion/defense/distill/trainer.py: -------------------------------------------------------------------------------- 1 | from torch import LongTensor 2 | import torch 3 | 4 | from ..base import * 5 | from ...models import ModelResult 6 | import torch.nn.functional as F 7 | from ...foldermanager import FolderManager 8 | from torch.nn import Module 9 | from torch.optim import Optimizer 10 | from torch.optim.lr_scheduler import LRScheduler 11 | 12 | 13 | class DistillTrainer(BaseTrainer): 14 | 15 | def __init__( 16 | self, 17 | args: BaseTrainArgs, 18 | folder_manager: FolderManager, 19 | model: Module, 20 | optimizer: Optimizer, 21 | scheduler: LRScheduler = None, 22 | teacher_model: Module = None, 23 | **kwargs 24 | ) -> None: 25 | super().__init__(args, folder_manager, model, optimizer, scheduler, **kwargs) 26 | assert teacher_model is not None 27 | self.teacher = teacher_model.to(args.device) 28 | self.teacher.eval() 29 | 30 | def calc_loss(self, inputs, result: ModelResult, labels: LongTensor): 31 | return 0 32 | 33 | def _train_step(self, inputs, labels) -> TrainStepResult: 34 | self.before_train_step() 35 | 36 | result = self.model(inputs) 37 | 38 | pred_res = result 39 | teacher_res = self.teacher(inputs) 40 | loss = F.kl_div( 41 | F.log_softmax(pred_res, dim=-1), 42 | F.softmax(teacher_res, dim=-1), 43 | reduction='sum', 44 | ) 45 | 46 | acc = ( 47 | (torch.argmax(pred_res, dim=-1) == torch.argmax(teacher_res, dim=-1)) 48 | .float() 49 | .mean() 50 | ) 51 | 52 | self._update_step(loss) 53 | 54 | return TrainStepResult(loss.mean().item(), acc.item()) 55 | -------------------------------------------------------------------------------- /src/modelinversion/defense/no_defense/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ffhibnese/Model-Inversion-Attack-ToolBox/d9ad48f997aa204e28696b87ce49a25384fcf794/src/modelinversion/defense/no_defense/__init__.py -------------------------------------------------------------------------------- /src/modelinversion/defense/no_defense/trainer.py: -------------------------------------------------------------------------------- 1 | from torch import LongTensor 2 | from ..base import BaseTrainer, BaseTrainArgs 3 | from ...models import ModelResult 4 | import torch.nn.functional as F 5 | from
...foldermanager import FolderManager 6 | from torch.nn import Module 7 | from torch.optim import Optimizer 8 | from torch.optim.lr_scheduler import LRScheduler 9 | 10 | 11 | class RegTrainer(BaseTrainer): 12 | 13 | def __init__( 14 | self, 15 | args: BaseTrainArgs, 16 | folder_manager: FolderManager, 17 | model: Module, 18 | optimizer: Optimizer, 19 | scheduler: LRScheduler = None, 20 | **kwargs 21 | ) -> None: 22 | super().__init__(args, folder_manager, model, optimizer, scheduler, **kwargs) 23 | 24 | def calc_loss(self, inputs, result: ModelResult, labels: LongTensor): 25 | pred_res = result.result 26 | # return self.criterion(pred_res, labels) 27 | return F.cross_entropy(pred_res, labels) 28 | -------------------------------------------------------------------------------- /src/modelinversion/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # from .knn import generate_private_feats, calc_knn 2 | # from .fid.fid import calc_fid 3 | # from .psnr import calc_psnr 4 | 5 | from .base import * 6 | 7 | # __all__ = ['get_knn_dist', 'calc_fid'] 8 | -------------------------------------------------------------------------------- /src/modelinversion/metrics/fid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ffhibnese/Model-Inversion-Attack-ToolBox/d9ad48f997aa204e28696b87ce49a25384fcf794/src/modelinversion/metrics/fid/__init__.py -------------------------------------------------------------------------------- /src/modelinversion/metrics/fid/fid_utils.py: -------------------------------------------------------------------------------- 1 | """Derived from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py""" # NOQA 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from scipy import linalg 7 | 8 | import warnings 9 | 10 | 11 | def get_activations(images, model, batch_size=64, dims=2048, device=None): 12 | model.eval() 13 | 14 | d0 = len(images) 15 | if batch_size > d0: 16 | print( 17 | ( 18 | 'Warning: batch size is bigger than the data size. ' 19 | 'Setting batch size to data size' 20 | ) 21 | ) 22 | batch_size = d0 23 | 24 | n_batches = d0 // batch_size 25 | n_used_imgs = n_batches * batch_size 26 | 27 | pred_arr = np.empty((n_used_imgs, dims)) 28 | for i in range(n_batches): 29 | start = i * batch_size 30 | end = start + batch_size 31 | 32 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor) 33 | if device is not None: 34 | batch = batch.to(device) 35 | 36 | with torch.no_grad(): 37 | pred = model(batch)[0] 38 | 39 | # If model output is not scalar, apply global spatial average pooling. 40 | # This happens if you choose a dimensionality not equal 2048. 
41 | if pred.shape[2] != 1 or pred.shape[3] != 1: 42 | pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1)) 43 | 44 | pred_arr[start:end] = pred.cpu().numpy().reshape(batch_size, -1) 45 | 46 | return pred_arr 47 | 48 | 49 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 50 | mu1 = np.atleast_1d(mu1) 51 | mu2 = np.atleast_1d(mu2) 52 | 53 | sigma1 = np.atleast_2d(sigma1) 54 | sigma2 = np.atleast_2d(sigma2) 55 | 56 | assert ( 57 | mu1.shape == mu2.shape 58 | ), 'Training and test mean vectors have different lengths' 59 | assert ( 60 | sigma1.shape == sigma2.shape 61 | ), 'Training and test covariances have different dimensions' 62 | 63 | diff = mu1 - mu2 64 | 65 | # Product might be almost singular 66 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 67 | if not np.isfinite(covmean).all(): 68 | msg = ( 69 | 'fid calculation produces singular product; ' 70 | 'adding %s to diagonal of cov estimates' 71 | ) % eps 72 | print(msg) 73 | offset = np.eye(sigma1.shape[0]) * eps 74 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 75 | 76 | # Numerical error might give slight imaginary component 77 | if np.iscomplexobj(covmean): 78 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 79 | m = np.max(np.abs(covmean.imag)) 80 | # raise ValueError('Imaginary component {}'.format(m)) 81 | warnings.warn('Insufficient number of images for a stable FID estimate; returning FID=0.') 82 | return 0 83 | covmean = covmean.real 84 | 85 | tr_covmean = np.trace(covmean) 86 | 87 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 88 | 89 | 90 | def calculate_activation_statistics( 91 | images, model, batch_size=64, dims=2048, device=None 92 | ): 93 | act = get_activations(images, model, batch_size, dims, device) 94 | mu = np.mean(act, axis=0) 95 | sigma = np.cov(act, rowvar=False) 96 | return mu, sigma 97 | -------------------------------------------------------------------------------- /src/modelinversion/metrics/psnr/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import torchvision 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | def psnr(fake_imgs, real_imgs, combination=False, factor=1.0, eps=1e-6): 10 | """Calculate PSNR between fake images and real images. 11 | 12 | Args: 13 | fake_imgs (Tensor): n fake images 14 | real_imgs (Tensor): m real images 15 | eps (float, optional): small constant that avoids division by zero. Defaults to 1e-6. 16 | combination (bool, optional): 17 | if True: 18 | pair every fake image with every real image and calculate, m * n times 19 | else: 20 | fake images and real images are compared in pairs, n times 21 | Defaults to False. 22 | factor (float, optional): maximum possible pixel value used in the PSNR formula. Defaults to 1.0.
23 | """ 24 | 25 | if not combination and len(fake_imgs) != len(real_imgs): 26 | raise RuntimeError( 27 | 'number of fake imgs and real imgs should be the same when combination is False' 28 | ) 29 | 30 | def get_psnr(fake, real): 31 | mse = ((fake - real) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) 32 | return 10 * torch.log10(factor**2 / (mse + eps)) 33 | 34 | if combination: 35 | results = [] 36 | for i in range(len(fake_imgs)): 37 | fake = fake_imgs[i] 38 | ret = get_psnr(fake, real_imgs).max() 39 | results.append(ret.item()) 40 | return torch.Tensor(results).to(fake_imgs) 41 | else: 42 | return get_psnr(fake_imgs, real_imgs) 43 | 44 | 45 | def calc_psnr(recovery_img_dir, private_img_dir): 46 | 47 | trans = torchvision.transforms.ToTensor() 48 | 49 | psnr_all = 0 50 | num = 0 51 | 52 | for label in os.listdir(recovery_img_dir): 53 | recovery_label_dir = os.path.join(recovery_img_dir, label) 54 | private_label_dir = os.path.join(private_img_dir, label) 55 | if not os.path.exists(private_img_dir): 56 | continue 57 | 58 | def read_imgs(dir_name): 59 | 60 | res = [] 61 | for img_name in os.listdir(dir_name): 62 | img_path = os.path.join(dir_name, img_name) 63 | try: 64 | img = Image.open(img_path) 65 | except: 66 | continue 67 | res.append(trans(img)) 68 | res = torch.stack(res, dim=0) 69 | return res 70 | 71 | recovery_imgs = read_imgs(recovery_label_dir) 72 | private_imgs = read_imgs(private_label_dir) 73 | 74 | psnr_res = psnr(recovery_imgs, private_imgs, combination=True) 75 | 76 | psnr_all += psnr_res.sum().item() 77 | num += len(psnr_res) 78 | 79 | if num == 0: 80 | raise RuntimeError('no imgs') 81 | 82 | res = psnr_all / num 83 | 84 | print(f'psnr: {res}') 85 | 86 | return res 87 | -------------------------------------------------------------------------------- /src/modelinversion/metrics/ssim/__init__.py: -------------------------------------------------------------------------------- 1 | """This is code based on https://sudomake.ai/inception-score-explained/.""" 2 | 3 | import torch 4 | import torchvision 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | from collections import defaultdict 8 | import math 9 | import numpy as np 10 | from torch.autograd import Variable 11 | import os 12 | from PIL import Image 13 | 14 | 15 | def ssim_gaussian(window_size, sigma): 16 | gauss = torch.Tensor( 17 | [ 18 | math.exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) 19 | for x in range(window_size) 20 | ] 21 | ) 22 | return gauss / gauss.sum() 23 | 24 | 25 | def ssim_create_window(window_size, channel): 26 | _1D_window = ssim_gaussian(window_size, 1.5).unsqueeze(1) 27 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 28 | window = Variable( 29 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 30 | ) 31 | return window 32 | 33 | 34 | def _ssim_core(img1, img2, window, window_size, channel, size_average=True): 35 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 36 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 37 | 38 | mu1_sq = mu1.pow(2) 39 | mu2_sq = mu2.pow(2) 40 | mu1_mu2 = mu1 * mu2 41 | 42 | sigma1_sq = ( 43 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 44 | ) 45 | sigma2_sq = ( 46 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 47 | ) 48 | sigma12 = ( 49 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 50 | - mu1_mu2 51 | ) 52 | 53 | C1 = 0.01**2 54 | 
C2 = 0.03**2 55 | 56 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 57 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 58 | ) 59 | 60 | if size_average: 61 | return ssim_map.mean() 62 | else: 63 | return ssim_map.mean(1).mean(1).mean(1) 64 | 65 | 66 | def ssim(fake_imgs, real_imgs, combination=False, window_size=11): 67 | 68 | (_, channel, _, _) = fake_imgs.size() 69 | window = ssim_create_window(window_size, channel) 70 | 71 | window = window.to(fake_imgs) 72 | 73 | def get_ssim(fake, real): 74 | return _ssim_core(fake, real, window, window_size, channel, size_average=False) 75 | 76 | if combination: 77 | results = [] 78 | for i in range(len(fake_imgs)): 79 | fake = fake_imgs[i] 80 | ret = get_ssim(fake, real_imgs).max() # change to mean? 81 | results.append(ret.item()) 82 | return torch.Tensor(results).to(fake_imgs) 83 | else: 84 | return get_ssim(fake_imgs, real_imgs) 85 | 86 | 87 | def calc_ssim(recovery_img_dir, private_img_dir): 88 | 89 | trans = torchvision.transforms.ToTensor() 90 | 91 | ssim_all = 0 92 | num = 0 93 | 94 | for label in os.listdir(recovery_img_dir): 95 | recovery_label_dir = os.path.join(recovery_img_dir, label) 96 | private_label_dir = os.path.join(private_img_dir, label) 97 | if not os.path.exists(private_label_dir): 98 | continue 99 | 100 | def read_imgs(dir_name): 101 | 102 | res = [] 103 | for img_name in os.listdir(dir_name): 104 | img_path = os.path.join(dir_name, img_name) 105 | try: 106 | img = Image.open(img_path) 107 | except Exception: 108 | continue 109 | res.append(trans(img)) 110 | res = torch.stack(res, dim=0) 111 | return res 112 | 113 | recovery_imgs = read_imgs(recovery_label_dir) 114 | private_imgs = read_imgs(private_label_dir) 115 | 116 | ssim_res = ssim(recovery_imgs, private_imgs, combination=True) 117 | 118 | ssim_all += ssim_res.sum().item() 119 | num += len(ssim_res) 120 | 121 | if num == 0: 122 | raise RuntimeError('no images found') 123 | 124 | res = ssim_all / num 125 | 126 | print(f'ssim: {res}') 127 | 128 | return res 129 | -------------------------------------------------------------------------------- /src/modelinversion/models/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/modelinversion/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifiers import * 2 | from .gans import * 3 | from .adapters import * -------------------------------------------------------------------------------- /src/modelinversion/models/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .c2f import * -------------------------------------------------------------------------------- /src/modelinversion/models/adapters/base.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from abc import abstractmethod 3 | from copy import deepcopy 4 | from typing import Callable, Optional, Any 5 | from functools import wraps 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | import torchvision.models as tvmodel 11 | import torchvision.transforms.functional as TF 12 | from torchvision.models.inception import InceptionOutputs 13 | 14 | from ..base import ModelMixin 15 | from ...utils import traverse_name_module, FirstInputHook, BaseHook 16 | 17 | BUILDIN_ADAPTERS = {} 18 | CLASSNAME_TO_NAME_MAPPING = {} 19 |
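# Registry pattern: decorating a BaseAdapter subclass with @register_adapter('key') records it in BUILDIN_ADAPTERS, so construct_adapters_by_name('key', **kwargs) and auto_adapter_from_pretrained(...) can rebuild it from a string key; see c2f.py for the 'c2f_mlp3' entry.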
20 | 21 | def register_adapter(name: Optional[str] = None): 22 | """Register model for construct. 23 | 24 | Args: 25 | name (Optional[str], optional): The key of the model. Defaults to None. 26 | """ 27 | 28 | def wrapper(c): 29 | key = name if name is not None else c.__name__ 30 | CLASSNAME_TO_NAME_MAPPING[c.__name__] = key 31 | if key in BUILDIN_ADAPTERS: 32 | raise ValueError(f"An entry is already registered under the name '{key}'.") 33 | BUILDIN_ADAPTERS[key] = c 34 | return c 35 | 36 | return wrapper 37 | 38 | 39 | class BaseAdapter(ModelMixin): 40 | 41 | def save_pretrained(self, path, **add_infos): 42 | return super().save_pretrained( 43 | path, 44 | model_name=CLASSNAME_TO_NAME_MAPPING[self.__class__.__name__], 45 | **add_infos, 46 | ) 47 | 48 | 49 | class ModelConstructException(Exception): 50 | pass 51 | 52 | 53 | def construct_adapters_by_name(name: str, **kwargs): 54 | 55 | if name in BUILDIN_ADAPTERS: 56 | return BUILDIN_ADAPTERS[name](**kwargs) 57 | 58 | raise ModelConstructException(f'Module name {name} not found.') 59 | 60 | 61 | def list_adapters(): 62 | """List all valid module names""" 63 | return sorted(BUILDIN_ADAPTERS.keys()) 64 | 65 | 66 | def auto_adapter_from_pretrained(data_or_path, **kwargs): 67 | 68 | if isinstance(data_or_path, str): 69 | data = torch.load(data_or_path, map_location='cpu') 70 | else: 71 | data = data_or_path 72 | if 'model_name' not in data: 73 | raise RuntimeError('model_name is not contained in the data') 74 | 75 | cls: ModelMixin = BUILDIN_ADAPTERS[data['model_name']] 76 | return cls.from_pretrained(data_or_path, **kwargs) 77 | -------------------------------------------------------------------------------- /src/modelinversion/models/adapters/c2f.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules import Module 2 | from .base import * 3 | 4 | 5 | # @register_model('c2f_mlp2') 6 | class C2fOutputMapping(BaseAdapter): 7 | 8 | # @ModelMixin.register_to_config_init 9 | def __init__(self, input_dim, map: nn.Module, trunc: int = 1): 10 | super(C2fOutputMapping, self).__init__() 11 | 12 | self.trunc = trunc 13 | self.input_dim = input_dim 14 | self.map = map 15 | # 10575 16 | 17 | def forward(self, x): 18 | # input_dim = x.shape[-1] 19 | topk, index = torch.topk(x, self.trunc) 20 | topk = torch.clamp(torch.log(topk), min=-1000) + 50.0 21 | topk_min = topk.min(1, keepdim=True)[0] 22 | topk = topk + F.relu(-topk_min) 23 | x = torch.zeros_like(x).scatter_(1, index, topk) 24 | x = x.view(-1, self.input_dim) 25 | x = F.normalize(x, 2, dim=1) 26 | x = self.map(x) 27 | # x = F.normalize(x, 2, dim=1) 28 | return x 29 | 30 | 31 | @register_adapter('c2f_mlp3') 32 | class C2fThreeLayerMlpOutputMapping(C2fOutputMapping): 33 | 34 | @ModelMixin.register_to_config_init 35 | def __init__(self, input_dim, hidden_dim, output_dim, trunc: int = 1): 36 | map = nn.Sequential( 37 | nn.Linear(input_dim, hidden_dim), 38 | nn.LeakyReLU(0.2, inplace=True), 39 | nn.Linear(hidden_dim, hidden_dim), 40 | nn.LeakyReLU(0.2, inplace=True), 41 | # nn.Linear(4096, 4096), 42 | # nn.LeakyReLU(0.2, inplace=True), 43 | nn.Dropout(0.25), 44 | # nn.Linear(4096, 4096), 45 | # nn.LeakyReLU(0.2, inplace=True), 46 | nn.Linear(hidden_dim, output_dim), 47 | # nn.BatchNorm1d(128, eps=0.0000001, momentum=0.1, affine=True), 48 | ) 49 | 50 | super().__init__(input_dim, map, trunc) 51 | -------------------------------------------------------------------------------- /src/modelinversion/models/base.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | from torch.nn import Module 6 | from ..utils import ConfigMixin, safe_save 7 | 8 | 9 | class ModelMixin(Module, ConfigMixin): 10 | 11 | # def save_config(self, save_path: str): 12 | # os.makedirs(save_path, exist_ok=True) 13 | # with open(save_path, 'w', encoding='utf8') as f: 14 | # json.dump(f, self._config_mixin_dict) 15 | 16 | # @staticmethod 17 | # def load_config(config_path: str): 18 | # if not os.path.exists(config_path): 19 | # raise RuntimeError(f'config_path {config_path} is not existed.') 20 | 21 | # with open(config_path, 'r', encoding='utf8') as f: 22 | # kwargs = json.load(config_path) 23 | 24 | # return kwargs 25 | 26 | def save_pretrained(self, path, **add_infos): 27 | save_result = { 28 | 'state_dict': self.state_dict(), 29 | 'config': self.preprocess_config_before_save(self._config_mixin_dict), 30 | **add_infos, 31 | } 32 | safe_save(save_result, path) 33 | 34 | @classmethod 35 | def from_pretrained(cls, data_or_path, **config_kwargs): 36 | 37 | if isinstance(data_or_path, str): 38 | data: dict = torch.load(data_or_path, map_location='cpu') 39 | else: 40 | data = data_or_path 41 | 42 | kwargs = cls.postprocess_config_after_load(data['config']) 43 | for k in config_kwargs: 44 | kwargs[k] = config_kwargs[k] 45 | init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} 46 | model = cls(**init_kwargs) 47 | 48 | if 'state_dict' in data: 49 | state_dict = data['state_dict'] 50 | if state_dict is not None: 51 | # print(f'load state dict') 52 | model.load_state_dict(state_dict) 53 | 54 | return model 55 | -------------------------------------------------------------------------------- /src/modelinversion/models/classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( 2 | TorchvisionClassifierModel, 3 | ResNeSt, 4 | BaseImageClassifier, 5 | BaseImageEncoder, 6 | HOOK_NAME_FEATURE, 7 | HOOK_NAME_HIDDEN, 8 | list_classifiers, 9 | construct_classifiers_by_name, 10 | auto_classifier_from_pretrained, 11 | ) 12 | from .wrappers import ( 13 | VibWrapper, 14 | BiDOWrapper, 15 | get_default_create_hidden_hook_fn, 16 | origin_vgg16_64_hidden_hook_fn, 17 | ConditionPurifierWrapper, 18 | ) 19 | from .classifier64 import ( 20 | VGG16_64, 21 | IR152_64, 22 | FaceNet64, 23 | EfficientNet_b0_64, 24 | EfficientNet_b1_64, 25 | EfficientNet_b2_64, 26 | ) 27 | from .classifier112 import FaceNet112 28 | from .classifier_utils import generate_feature_statics 29 | from .inception import InceptionResnetV1_adaptive 30 | -------------------------------------------------------------------------------- /src/modelinversion/models/classifiers/classifier112.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from torch import Tensor 4 | import torchvision 5 | 6 | from ..base import ModelMixin 7 | from ...utils import BaseHook 8 | 9 | from .base import * 10 | from .evolve import evolve 11 | 12 | 13 | @register_model(name='facenet112') 14 | class FaceNet112(BaseImageClassifier): 15 | 16 | @ModelMixin.register_to_config_init 17 | def __init__( 18 | self, 19 | num_classes=1000, 20 | register_last_feature_hook=False, 21 | backbone_path: Optional[str] = None, 22 | ): 23 | super(FaceNet112, self).__init__( 24 | 112, 512, num_classes, register_last_feature_hook 25 | ) 26 | self.feature = evolve.IR_50_112((112, 112)) 27 | if 
backbone_path is not None: 28 | state_dict = torch.load(backbone_path, map_location='cpu') 29 | self.feature.load_state_dict(state_dict) 30 | self.feat_dim = 512 31 | 32 | self.fc_layer = nn.Linear(self.feat_dim, self.num_classes) 33 | 34 | # self.feature_hook = FirstInputHook(self.fc_layer) 35 | 36 | def get_last_feature_hook(self) -> BaseHook: 37 | return self.feature_hook 38 | 39 | def preprocess_config_before_save(self, config): 40 | config = deepcopy(config) 41 | del config['backbone_path'] 42 | return super().preprocess_config_before_save(config) 43 | 44 | def _forward_impl(self, image: Tensor, *args, **kwargs): 45 | feat = self.feature(image) 46 | feat = feat.view(feat.size(0), -1) 47 | out = self.fc_layer(feat) 48 | 49 | return out, {HOOK_NAME_FEATURE: feat} 50 | -------------------------------------------------------------------------------- /src/modelinversion/models/classifiers/classifier_utils.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from torch import Tensor 3 | from .base import * 4 | 5 | 6 | @torch.no_grad() 7 | def generate_feature_statics(dataloader, sample_num, classifier, device): 8 | 9 | features = [] 10 | for imgs in tqdm(dataloader, leave=False): 11 | if not isinstance(imgs, Tensor): 12 | imgs = imgs[0] 13 | 14 | if sample_num <= 0: 15 | break 16 | if sample_num < len(imgs): 17 | imgs = imgs[sample_num:] 18 | sample_num -= len(imgs) 19 | 20 | imgs = imgs.to(device) 21 | _, addition_info = classifier(imgs) 22 | if not HOOK_NAME_FEATURE in addition_info: 23 | raise RuntimeError( 24 | f'{HOOK_NAME_FEATURE} are not contains in the output of the classifier' 25 | ) 26 | features.append(addition_info[HOOK_NAME_FEATURE].cpu()) 27 | features = torch.cat(features, dim=0) 28 | features_mean = torch.mean(features, dim=0) 29 | features_std = torch.std(features, dim=0) 30 | return features_mean, features_std 31 | -------------------------------------------------------------------------------- /src/modelinversion/models/classifiers/evolve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ffhibnese/Model-Inversion-Attack-ToolBox/d9ad48f997aa204e28696b87ce49a25384fcf794/src/modelinversion/models/classifiers/evolve/__init__.py -------------------------------------------------------------------------------- /src/modelinversion/models/gans/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( 2 | BaseImageGenerator, 3 | BaseIntermediateImageGenerator, 4 | construct_generator_by_name, 5 | construct_discriminator_by_name, 6 | show_generators, 7 | show_discriminators, 8 | list_generators, 9 | list_discriminators, 10 | auto_generator_from_pretrained, 11 | auto_discriminator_from_pretrained, 12 | ) 13 | from .simple import ( 14 | SimpleGenerator64, 15 | SimpleGenerator256, 16 | GmiDiscriminator64, 17 | GmiDiscriminator256, 18 | KedmiDiscriminator64, 19 | KedmiDiscriminator256, 20 | ) 21 | from .cgan import ( 22 | PlgmiGenerator64, 23 | PlgmiGenerator256, 24 | PlgmiDiscriminator64, 25 | PlgmiDiscriminator256, 26 | LoktDiscriminator64, 27 | LoktDiscriminator256, 28 | LoktGenerator64, 29 | LoktGenerator256, 30 | ) 31 | from .stylegan2ada import ( 32 | get_stylegan2ada_generator, 33 | StyleGan2adaMappingWrapper, 34 | StyleGAN2adaSynthesisWrapper, 35 | ) 36 | -------------------------------------------------------------------------------- /src/modelinversion/models/gans/stylegan2ada.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | from typing import Optional 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn, Tensor 9 | 10 | from .base import BaseIntermediateImageGenerator 11 | from ...utils import check_shape 12 | 13 | 14 | class StyleGan2adaMappingWrapper(nn.Module): 15 | 16 | def __init__( 17 | self, 18 | mapping, 19 | single_w, 20 | truncation_psi=0.5, 21 | truncation_cutoff=8, 22 | *args, 23 | **kwargs, 24 | ) -> None: 25 | super().__init__(*args, **kwargs) 26 | self.mapping = mapping 27 | self.single_w = single_w 28 | self.truncation_psi = truncation_psi 29 | self.truncation_cutoff = truncation_cutoff 30 | self.w_dim = mapping.w_dim 31 | self.z_dim = mapping.z_dim 32 | self.num_ws = mapping.num_ws 33 | 34 | def forward(self, z): 35 | w = self.mapping( 36 | z, 37 | c=None, 38 | truncation_psi=self.truncation_psi, 39 | truncation_cutoff=self.truncation_cutoff, 40 | ) 41 | if self.single_w: 42 | w = w[:, [0]] 43 | return w 44 | 45 | 46 | class StyleGAN2adaSynthesisWrapper(BaseIntermediateImageGenerator): 47 | 48 | def __init__(self, synthesis, *args, **kwargs) -> None: 49 | block_num = len(synthesis.block_resolutions) 50 | super().__init__( 51 | synthesis.img_resolution, 52 | (synthesis.num_ws, synthesis.w_dim), 53 | block_num, 54 | *args, 55 | **kwargs, 56 | ) 57 | 58 | self.synthesis = synthesis 59 | 60 | def _forward_impl( 61 | self, 62 | ws: Tensor, 63 | intermediate_inputs: Optional[Tensor] = None, 64 | labels: torch.LongTensor | None = None, 65 | start_block: int = None, 66 | end_block: int = None, 67 | noise_mode='const', 68 | force_fp32=True, 69 | **kwargs, 70 | ): 71 | 72 | if 'noise_mode' not in kwargs: 73 | kwargs['noise_mode'] = noise_mode 74 | if 'force_fp32' not in kwargs: 75 | kwargs['force_fp32'] = force_fp32 76 | 77 | block_ws = [] 78 | with torch.autograd.profiler.record_function('split_ws'): 79 | if ws.shape[-2] == 1: 80 | ws = torch.repeat_interleave(ws, self.synthesis.num_ws, dim=-2) 81 | check_shape(ws, [None, self.synthesis.num_ws, self.synthesis.w_dim]) 82 | w_idx = 0 83 | for res in self.synthesis.block_resolutions: 84 | block = getattr(self.synthesis, f'b{res}') 85 | block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) 86 | w_idx += block.num_conv 87 | 88 | img = None 89 | x = intermediate_inputs 90 | for i in range(start_block, end_block): 91 | res = self.synthesis.block_resolutions[i] 92 | w = block_ws[i] 93 | 94 | block = getattr(self.synthesis, f'b{res}') 95 | x, img = block(x, img, w, **kwargs) 96 | return x if end_block < self.block_num else img 97 | 98 | 99 | def get_stylegan2ada_generator( 100 | stylegan2ada_path: str, 101 | checkpoint_path: str, 102 | single_w=True, 103 | truncation_psi=0.5, 104 | truncation_cutoff=8, 105 | ): 106 | 107 | sys.path.append(stylegan2ada_path) 108 | 109 | with open(checkpoint_path, 'rb') as f: 110 | G = pickle.load(f)['G_ema'] 111 | mapping = StyleGan2adaMappingWrapper( 112 | G.mapping, single_w, truncation_psi, truncation_cutoff 113 | ) 114 | 115 | synthesis = StyleGAN2adaSynthesisWrapper(G.synthesis) 116 | 117 | return mapping, synthesis 118 | -------------------------------------------------------------------------------- /src/modelinversion/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( 2 | BaseLatentsSampler, 3 | SimpleLatentsSampler, 4 | ImageAugmentSelectLatentsSampler, 5 | 
GaussianMixtureLatentsSampler, 6 | LayeredFlowLatentsSampler, 7 | ) 8 | from .labelonly import LabelOnlySelectLatentsSampler 9 | 10 | from .flow import LayeredFlowMiner, MixtureOfGMM, FlowConfig 11 | -------------------------------------------------------------------------------- /src/modelinversion/sampler/flow/__init__.py: -------------------------------------------------------------------------------- 1 | from .likelihood_models import * 2 | from dataclasses import dataclass, field 3 | 4 | 5 | class LayeredMineGAN(nn.Module): 6 | def __init__(self, miner, Gmapping): 7 | super(LayeredMineGAN, self).__init__() 8 | self.nz = miner.nz0 9 | self.miner = miner 10 | self.Gmapping = Gmapping 11 | 12 | def forward(self, z0): 13 | N, zdim = z0.shape 14 | z = self.miner(z0) # (N, zdim) -> (N, l, zdim) 15 | w = self.Gmapping(z.reshape(-1, zdim)) # (N * l, l, zdim) 16 | w = w[:, 0].reshape(N, -1, zdim) # (N, l, zdim) 17 | return w 18 | 19 | 20 | @dataclass 21 | class FlowConfig: 22 | k: int 23 | l: int 24 | flow_permutation: str 25 | flow_K: int 26 | flow_glow: bool = False 27 | flow_coupling: str = 'additive' 28 | flow_L: int = 1 29 | flow_use_actnorm: bool = True 30 | l_identity: list = field(default_factory=lambda: list(range(10))) 31 | -------------------------------------------------------------------------------- /src/modelinversion/sampler/flow/flow_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def compute_same_pad(kernel_size, stride): 6 | if isinstance(kernel_size, int): 7 | kernel_size = [kernel_size] 8 | 9 | if isinstance(stride, int): 10 | stride = [stride] 11 | 12 | assert len(stride) == len(kernel_size),\ 13 | "Pass kernel size and stride both as int, or both as equal length iterable" 14 | 15 | return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)] 16 | 17 | 18 | def pixels(tensor): 19 | return int(tensor.size(2) * tensor.size(3)) 20 | 21 | 22 | def uniform_binning_correction(x, n_bits=8): 23 | """Replaces x^i with q^i(x) = U(x, x + 1.0 / 256.0). 24 | 25 | Args: 26 | x: 4-D Tensor of shape (NCHW) 27 | n_bits: optional, number of bits per pixel. 28 | 29 | Returns: 30 | x: x ~ U(x, x + 1.0 / 256) 31 | objective: Equivalent to -q(x)*log(q(x)). 32 | """ 33 | b, c, h, w = x.size() 34 | n_bins = 2**n_bits 35 | chw = c * h * w 36 | # correct for pytorch.to_tensor 37 | x = x * 255. / 256. 38 | x = x + torch.zeros_like(x).uniform_(0, 1.0 / n_bins) 39 | x = torch.clamp(x, min=0, max=1) 40 | objective = -math.log(n_bins) * chw * torch.ones(b, device=x.device) 41 | return x, objective 42 | 43 | 44 | def split_feature(tensor, type="split"): 45 | """ 46 | type = ["split", "cross"] 47 | """ 48 | C = tensor.size(1) 49 | if type == "split": 50 | return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...] 51 | elif type == "cross": 52 | return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 53 | -------------------------------------------------------------------------------- /src/modelinversion/sampler/flow/toy_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import product 3 | 4 | 5 | def twod_mog_means(n_centers): 6 | 7 | assert np.sqrt(n_centers) == np.floor(np.sqrt(n_centers)) 8 | n_centers = int(np.sqrt(n_centers)) 9 | std = 1. / n_centers 10 | means = np.arange(-3., 3., 6. / n_centers) + 6.
/ n_centers / 2 11 | means = np.array(list(product(means, means))) 12 | return means 13 | 14 | 15 | def get_grid(low=-4, high=4, npts=20, ret_xy=False): 16 | delta = (high - low) / npts 17 | x, y = np.mgrid[low:high:delta, low:high:delta] 18 | # x, y = np.mgrid[low:high+delta:delta, low:high+delta:delta] 19 | pos = np.empty(x.shape + (2,)) 20 | pos[:, :, 0] = x 21 | pos[:, :, 1] = y 22 | if ret_xy: 23 | return x, y 24 | else: 25 | return pos.reshape(-1, 2) 26 | 27 | 28 | def compute_grid_f(f, npts=100, low=-4., high=4.): 29 | x = get_grid(low=low, high=high, npts=npts) 30 | fx = f(x)[:, None] 31 | return fx.reshape(npts, npts) 32 | 33 | # def compute_density(logdensity, npts=100, low=-4., high=4.): 34 | # x = get_grid(low=low, high=high, npts=npts) 35 | # logpx = logdensity(x)[:, None] 36 | 37 | # px = np.exp(logpx).reshape(npts, npts) 38 | # # px = np.exp(logpx).reshape(npts+1, npts+1) 39 | # return px 40 | 41 | 42 | # def plt_density(logdensity, ax, npts=100, low=-4., high=4., alpha=1, cmap='inferno'): 43 | # px = compute_grid_f(logdensity, npts, low, high) 44 | # ax.imshow(px, alpha=alpha, cmap=cmap) 45 | 46 | 47 | def plt_contourf(f, ax, npts=20, low=-4., high=4., fill=True, **kwargs): 48 | px = compute_grid_f(f, npts, low, high) 49 | x, y = get_grid(low=low, high=high, npts=npts, ret_xy=True) 50 | if fill: 51 | cont = ax.contourf 52 | else: 53 | cont = ax.contour 54 | return cont(x, y, px, **kwargs) 55 | 56 | 57 | def plt_samples(samples, ax, npts=100, low=-4., high=4., alpha=1, cmap='inferno'): 58 | ax.hist2d(samples[:, 0], samples[:, 1], range=[ 59 | [low, high], [low, high]], bins=npts, alpha=alpha, cmap=cmap) 60 | ax.invert_yaxis() 61 | -------------------------------------------------------------------------------- /src/modelinversion/sampler/labelonly.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from abc import ABC, abstractmethod 3 | from collections import defaultdict 4 | from typing import Callable, Optional, Iterable, Sequence 5 | 6 | import torch 7 | from torch import Tensor, LongTensor 8 | from tqdm import tqdm 9 | 10 | from ..models import BaseImageGenerator, BaseImageClassifier 11 | from ..utils import batch_apply, print_split_line, print_as_yaml 12 | from .base import SimpleLatentsSampler 13 | 14 | 15 | class LabelOnlySelectLatentsSampler(SimpleLatentsSampler): 16 | 17 | def __init__( 18 | self, 19 | input_size: int | Sequence[int], 20 | batch_size: int, 21 | generator: BaseImageGenerator, 22 | classifier: BaseImageClassifier, 23 | device: torch.device, 24 | latents_mapping: Optional[Callable] = None, 25 | image_transform: Optional[Callable[[Tensor], Tensor]] = None, 26 | max_iters: int = 100000, 27 | ) -> None: 28 | super().__init__(input_size, batch_size, latents_mapping) 29 | 30 | self.generator = generator 31 | self.classifier = classifier 32 | self.device = device 33 | self.image_transform = image_transform 34 | self.max_iters = max_iters 35 | 36 | def __call__(self, labels: list[int], sample_num: int): 37 | 38 | batch_latent_size = self.get_batch_latent_size(self.batch_size) 39 | 40 | res_labels = set(labels) 41 | 42 | results = defaultdict(list) 43 | 44 | for _ in tqdm(range(self.max_iters)): 45 | batch_latents = torch.randn(batch_latent_size, device=self.device) 46 | if self.latents_mapping: 47 | batch_latents = self.latents_mapping(batch_latents) 48 | batch_images = self.generator(batch_latents) 49 | if self.image_transform is not None: 50 | batch_images = self.image_transform(batch_images) 51 | 
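# label-only selection: a latent is kept for class c only when the target classifier's top-1 prediction on the generated image is c, and collection for a class stops once sample_num latents have been gathered.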
pred_scores = self.classifier(batch_images)[0] 52 | pred_labels = torch.argmax(pred_scores, dim=-1).detach().tolist() 53 | batch_latents = batch_latents.detach().cpu() 54 | 55 | for i, label in enumerate(pred_labels): 56 | if label in res_labels: 57 | results[label].append(batch_latents[i]) 58 | if len(results[label]) == sample_num: 59 | res_labels.remove(label) 60 | if len(res_labels) == 0: 61 | break 62 | 63 | unfinish_labels = [] 64 | res_labels = list(res_labels) 65 | 66 | for label in results: 67 | results[label] = torch.stack(results[label], dim=0) 68 | if len(results[label]) < sample_num: 69 | unfinish_labels.append(label) 70 | 71 | print_split_line('label only unfinish labels') 72 | print_as_yaml({'no sample labels': res_labels}) 73 | print_as_yaml({'insufficient sample labels': unfinish_labels}) 74 | print_split_line() 75 | 76 | return results 77 | -------------------------------------------------------------------------------- /src/modelinversion/scores/__init__.py: -------------------------------------------------------------------------------- 1 | from .imgscore import ( 2 | BaseImageClassificationScore, 3 | ImageClassificationAugmentConfidence, 4 | ImageClassificationAugmentLabelOnlyScore, 5 | ImageClassificationAugmentLossScore, 6 | ) 7 | from .latentscore import BaseLatentScore, LatentClassificationAugmentConfidence 8 | from .functional import ( 9 | cross_image_augment_scores, 10 | specific_image_augment_scores, 11 | specific_image_augment_loss_score, 12 | specific_image_augment_scores_label_only, 13 | ) 14 | -------------------------------------------------------------------------------- /src/modelinversion/scores/functional.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from abc import ABC, abstractmethod 3 | from typing import Callable, Optional, Iterable 4 | 5 | import torch 6 | from torch import Tensor, LongTensor 7 | 8 | from ..models import BaseImageClassifier 9 | from ..utils import TorchLoss 10 | 11 | 12 | # @torch.no_grad() 13 | def specific_image_augment_scores( 14 | model: BaseImageClassifier, 15 | device: torch.device, 16 | create_aug_images_fn: Optional[Callable[[Tensor], Iterable[Tensor]]], 17 | images: Tensor, 18 | labels: LongTensor, 19 | ): 20 | images = images.to(device) 21 | labels = labels.cpu() 22 | 23 | if create_aug_images_fn is None: 24 | create_aug_images_fn = lambda x: [x] 25 | 26 | scores = torch.zeros_like(labels, dtype=images.dtype, device='cpu') 27 | total_num = 0 28 | for trans in create_aug_images_fn(images): 29 | total_num += 1 30 | conf = model(trans)[0].softmax(dim=-1).detach().cpu() 31 | scores += torch.gather(conf, 1, labels.unsqueeze(1)).squeeze(1) 32 | return scores / total_num 33 | 34 | 35 | def specific_image_augment_loss_score( 36 | model: BaseImageClassifier, 37 | device: torch.device, 38 | create_aug_images_fn: Optional[Callable[[Tensor], Iterable[Tensor]]], 39 | images: Tensor, 40 | labels: LongTensor, 41 | loss_fn: Callable, 42 | ): 43 | images = images.detach().to(device) 44 | labels = labels.to(device) 45 | 46 | if create_aug_images_fn is None: 47 | create_aug_images_fn = lambda x: [x] 48 | 49 | losses = torch.zeros_like(labels, dtype=images.dtype, device='cpu') 50 | total_num = 0 51 | for trans in create_aug_images_fn(images): 52 | total_num += 1 53 | losses += -loss_fn(model(trans)[0], labels).detach().cpu() 54 | return losses / total_num 55 | 56 | 57 | def specific_image_augment_scores_label_only( 58 | model: BaseImageClassifier, 59 | device: torch.device, 60 | 
create_aug_images_fn: Optional[Callable[[Tensor], Iterable[Tensor]]], 61 | images: Tensor, 62 | labels: LongTensor, 63 | correct_score: float = 1, 64 | wrong_score=-1, 65 | ): 66 | images = images.detach().to(device) 67 | labels = labels.cpu() 68 | 69 | if create_aug_images_fn is not None: 70 | scores = torch.zeros_like(labels, dtype=images.dtype, device='cpu') 71 | total_num = 0 72 | for trans in create_aug_images_fn(images): 73 | total_num += 1 74 | correct = model(trans)[0].argmax(dim=-1).detach().cpu() == labels 75 | scores += torch.where(correct, correct_score, wrong_score) 76 | return scores / total_num 77 | else: 78 | correct = model(images)[0].argmax(dim=-1).detach().cpu() == labels 79 | return torch.where(correct, correct_score, wrong_score).to(images.dtype) 80 | 81 | 82 | # @torch.no_grad() 83 | def cross_image_augment_scores( 84 | model: BaseImageClassifier, 85 | device: torch.device, 86 | create_aug_images_fn: Optional[Callable[[Tensor], Iterable[Tensor]]], 87 | images: Tensor, 88 | ): 89 | images = images.detach().to(device) 90 | 91 | if create_aug_images_fn is not None: 92 | scores = 0 93 | total_num = 0 94 | for trans in create_aug_images_fn(images): 95 | total_num += 1 96 | conf = model(trans)[0].softmax(dim=-1).cpu() 97 | scores += conf 98 | res = scores / total_num 99 | else: 100 | conf = model(images)[0].softmax(dim=-1).cpu() 101 | res = conf 102 | 103 | return res 104 | -------------------------------------------------------------------------------- /src/modelinversion/scores/latentscore.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Callable, Optional 3 | 4 | import torch 5 | from torch import Tensor, LongTensor 6 | 7 | from .imgscore import * 8 | from ..models import BaseImageClassifier, BaseImageGenerator 9 | from .functional import specific_image_augment_scores 10 | 11 | 12 | class BaseLatentScore(ABC): 13 | """This is a class for generating scores for each latent vector with the corresponding label.""" 14 | 15 | def __init__(self) -> None: 16 | super().__init__() 17 | 18 | @abstractmethod 19 | def __call__(self, latents: Tensor, labels: LongTensor | list[int]) -> Tensor: 20 | """The scoring function to score all latent vectors with the corresponding labels. 21 | 22 | Args: 23 | latents (Tensor): Latent vectors. 24 | labels (LongTensor): The corresponding labels for the latent vectors. The length of `labels` should be the same as that of `latents`. 25 | 26 | Returns: 27 | Tensor: The score of each latent vector. 28 | """ 29 | pass 30 | 31 | 32 | class LatentClassificationAugmentConfidence(BaseLatentScore): 33 | """This is a class for generating scores for each latent vector with the corresponding label. The score is calculated from the confidence of the classifier model. 34 | 35 | Args: 36 | generator (BaseImageGenerator): 37 | The image generator. 38 | model (BaseImageClassifier): 39 | The image classifier to generate scores. 40 | device (device): 41 | The device used for calculation. Please keep it the same as the device of `generator` and `model`. 42 | create_aug_images_fn (Callable[[Tensor], Iterable[Tensor]], optional): 43 | The function to create a list of augmented images that will be used to calculate the score. Defaults to `None`.
44 | """ 45 | 46 | def __init__( 47 | self, 48 | generator: BaseImageGenerator, 49 | model: BaseImageClassifier, 50 | device: torch.device, 51 | create_aug_images_fn: Optional[Callable[[Tensor], Iterable[Tensor]]] = None, 52 | ) -> None: 53 | self.generator = generator 54 | self.model = model 55 | self.device = device 56 | self.create_aug_images_fn = create_aug_images_fn 57 | 58 | 59 | @torch.no_grad() 60 | def __call__(self, latents: Tensor, labels: LongTensor | list[int]) -> Tensor: 61 | latents = latents.to(self.device) 62 | labels = torch.LongTensor(labels).to(self.device) 63 | images = self.generator(latents, labels=labels) 64 | return specific_image_augment_scores( 65 | self.model, self.device, self.create_aug_images_fn, images, labels 66 | ) 67 | -------------------------------------------------------------------------------- /src/modelinversion/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .gan import ( 2 | PlgmiGanTrainer, 3 | GmiGanTrainer, 4 | KedmiGanTrainer, 5 | LoktGanTrainer, 6 | PlgmiGanTrainConfig, 7 | GmiGanTrainConfig, 8 | KedmiGanTrainConfig, 9 | LoktGanTrainConfig, 10 | ) 11 | from .classifier import ( 12 | BaseTrainConfig, 13 | BaseTrainer, 14 | SimpleTrainConfig, 15 | SimpleTrainer, 16 | VibTrainConfig, 17 | VibTrainer, 18 | BiDOTrainConfig, 19 | BiDOTrainer, 20 | DistillTrainer, 21 | DistillTrainConfig, 22 | ) 23 | 24 | from .mapping import train_mapping_model -------------------------------------------------------------------------------- /src/modelinversion/train/classifier/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .bido import BiDOTrainConfig, BiDOTrainer 3 | from .distill import DistillTrainConfig, DistillTrainer 4 | -------------------------------------------------------------------------------- /src/modelinversion/train/classifier/distill.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | from .base import * 7 | from ...models import BaseImageClassifier 8 | 9 | 10 | @dataclass 11 | class DistillTrainConfig(BaseTrainConfig): 12 | 13 | teacher: BaseImageClassifier = None 14 | 15 | 16 | class DistillTrainer(BaseTrainer): 17 | 18 | def __init__(self, config: DistillTrainConfig, *args, **kwargs) -> None: 19 | super().__init__(config, *args, **kwargs) 20 | 21 | self.config: DistillTrainConfig 22 | 23 | if config.teacher is None: 24 | raise RuntimeError('Teacher model should not be None') 25 | 26 | def calc_loss(self, inputs, result, labels: LongTensor): 27 | result = result[0] 28 | teacher_result = self.config.teacher(inputs)[0] 29 | 30 | loss = F.kl_div( 31 | F.log_softmax(result, dim=-1), 32 | F.softmax(teacher_result, dim=-1), 33 | reduction='batchmean', 34 | ) 35 | 36 | return loss 37 | 38 | @torch.no_grad() 39 | def calc_train_acc(self, inputs, result, labels: torch.LongTensor): 40 | res = result[0] 41 | if isinstance(res, InceptionOutputs): 42 | res, _ = res 43 | assert res.ndim <= 2 44 | 45 | teacher_result = self.config.teacher(inputs)[0] 46 | 47 | pred = torch.argmax(res, dim=-1) 48 | teacher_pred = torch.argmax(teacher_result, dim=-1) 49 | # distillation uses no ground-truth labels, so 'accuracy' here is student-teacher agreement 50 | return (pred == teacher_pred).float().mean() 51 | -------------------------------------------------------------------------------- /src/modelinversion/train/mapping.py:
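Before moving on to the mapping trainer below, here is a small self-contained sketch of the distillation objective used in `DistillTrainer.calc_loss` above; the `student_logits` and `teacher_logits` tensors are made up purely for illustration:

import torch
from torch.nn import functional as F

student_logits = torch.randn(8, 10)  # hypothetical student outputs: batch of 8, 10 classes
teacher_logits = torch.randn(8, 10)  # hypothetical teacher outputs for the same batch

# KL divergence between the teacher and student distributions, averaged over
# the batch; this mirrors the reduction='batchmean' call in calc_loss
loss = F.kl_div(
    F.log_softmax(student_logits, dim=-1),
    F.softmax(teacher_logits, dim=-1),
    reduction='batchmean',
)
print(loss.item())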
-------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import nn, Tensor 6 | from torch.optim import Optimizer 7 | from torch.nn import Module 8 | from torch.utils.data import DataLoader 9 | 10 | from ..models.base import ModelMixin 11 | from ..utils import unwrapped_parallel_module 12 | 13 | def _get_first(data): 14 | if not isinstance(data, Tensor): 15 | return data[0] 16 | return data 17 | 18 | def train_mapping_model( 19 | epoch_num: int, 20 | mapping_module: ModelMixin, 21 | optimizer: Optimizer, 22 | src_model: Module, 23 | dst_model: Module, 24 | dataloader: DataLoader, 25 | device: torch.device, 26 | save_path: str, 27 | scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, 28 | show_info_iters: int = 100, 29 | ): 30 | src_model.eval() 31 | dst_model.eval() 32 | 33 | loss_fn = nn.MSELoss() 34 | 35 | for epoch in range(epoch_num): 36 | 37 | bar = tqdm(dataloader, leave=False) 38 | for i, data in enumerate(bar): 39 | data = _get_first(data).to(device) 40 | with torch.no_grad(): 41 | inputs = _get_first(src_model(data)).softmax(dim=-1) 42 | labels = _get_first(dst_model(data)) 43 | 44 | map_result = mapping_module(inputs) 45 | loss = loss_fn(map_result, labels) 46 | 47 | optimizer.zero_grad() 48 | loss.backward() 49 | optimizer.step() 50 | 51 | if i % show_info_iters == 0: 52 | bar.set_description_str(f'epoch: {epoch} loss: {loss.item():.5f}') 53 | 54 | if scheduler is not None: 55 | scheduler.step() 56 | 57 | unwrapped_parallel_module(mapping_module).save_pretrained(save_path) -------------------------------------------------------------------------------- /src/modelinversion/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .log import Logger 2 | from .random import set_random_seed, get_random_string 3 | from .accumulator import Accumulator, DictAccumulator 4 | from .torchutil import * 5 | from .io import ( 6 | safe_save, 7 | safe_save_csv, 8 | walk_imgs, 9 | print_as_yaml, 10 | print_split_line, 11 | obj_to_yaml, 12 | ) 13 | from .config import ConfigMixin 14 | from .losses import ( 15 | TorchLoss, 16 | LabelSmoothingCrossEntropyLoss, 17 | max_margin_loss, 18 | poincare_loss, 19 | ) 20 | from .check import check_shape 21 | from .batch import batch_apply 22 | from .hook import ( 23 | BaseHook, 24 | OutputHook, 25 | InputHook, 26 | FirstInputHook, 27 | DeepInversionBNFeatureHook, 28 | ) 29 | from .constraint import BaseConstraint, MinMaxConstraint, L1ballConstraint 30 | from .outputs import BaseOutput 31 | 32 | ClassificationLoss = TorchLoss 33 | Tee = Logger 34 | -------------------------------------------------------------------------------- /src/modelinversion/utils/accumulator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import defaultdict, OrderedDict 3 | 4 | import torch 5 | 6 | 7 | class Accumulator: 8 | """For accumulating sums over `n` variables.""" 9 | 10 | def __init__(self, n): 11 | self.data = [0] * n 12 | self.num = 0 13 | 14 | def add(self, *args, add_num=1, add_type='mean'): 15 | """Add the given values to the accumulated data.""" 16 | assert len(args) == len(self.data) 17 | mul_coef = add_num if add_type == 'mean' else 1 18 | self.num += add_num 19 | for i, add_item in enumerate(args): 20 | if isinstance(add_item, torch.Tensor): 21 | add_item = add_item.item() 22 | self.data[i] += add_item * mul_coef 23 | 24 | def
reset(self): 25 | """Reset all data to 0.""" 26 | self.data = [0] * len(self.data) 27 | self.num = 0 28 | 29 | def __getitem__(self, idx): 30 | return self.data[idx] 31 | 32 | def __len__(self): 33 | return len(self.data) 34 | 35 | def avg(self, idx=None): 36 | """Calculate the average of the data specified by `idx`. If `idx` is None, it will calculate the averages of all data. 37 | 38 | Args: 39 | idx (int, optional): index into the data list. Defaults to None. 40 | 41 | Returns: 42 | float | list: a list of averages if `idx` is None, else a single average 43 | """ 44 | num = 1 if self.num == 0 else self.num 45 | if idx is None: 46 | return [d / num for d in self.data] 47 | else: 48 | return self.data[idx] / num 49 | 50 | 51 | class DictAccumulator: 52 | def __init__(self) -> None: 53 | self.data = OrderedDict() # defaultdict(lambda : 0) 54 | self.num = 0 55 | 56 | def reset(self): 57 | """Reset all data to 0.""" 58 | self.data = OrderedDict() # defaultdict(lambda : 0) 59 | self.num = 0 60 | 61 | def add(self, add_dic: OrderedDict, add_num=1, add_type='mean'): 62 | mul_coef = add_num if add_type == 'mean' else 1 63 | self.num += add_num 64 | for key, val in add_dic.items(): 65 | if isinstance(val, torch.Tensor): 66 | val = val.item() 67 | if key not in self.data.keys(): 68 | self.data[key] = 0 69 | self.data[key] += val * mul_coef 70 | 71 | def __getitem__(self, key): 72 | return self.data[key] 73 | 74 | def __len__(self): 75 | return len(self.data) 76 | 77 | def avg(self, key=None): 78 | num = 1 if self.num == 0 else self.num 79 | if key is None: 80 | res = copy.deepcopy(self.data) 81 | for k in self.data: 82 | res[k] /= num 83 | return res 84 | else: 85 | return self.data[key] / num 86 | -------------------------------------------------------------------------------- /src/modelinversion/utils/batch.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | from functools import reduce 3 | 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from .io import print_split_line 8 | from .outputs import BaseOutput 9 | 10 | 11 | def _is_namedtuple(obj): 12 | # Check if type was created from collections.namedtuple or a typing.NamedTuple. 13 | return ( 14 | isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") 15 | ) 16 | 17 | 18 | def _gather(outputs, dim=0): 19 | """Gather a list of per-batch outputs into a single output. 20 | 21 | Args: 22 | outputs (list): The list of per-batch outputs to gather. 23 | dim (int, optional): The dimension along which tensors are concatenated. Defaults to 0.
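        Example (a small sketch of the tensor case; `batch_apply` below relies on this function to merge its per-batch results):

            _gather([torch.zeros(2, 3), torch.ones(4, 3)]).shape  # torch.Size([6, 3])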
24 | """ 25 | 26 | def gather_map(outputs): 27 | out = outputs[0] 28 | if isinstance(out, torch.Tensor): 29 | return torch.cat(outputs, dim=dim) 30 | if isinstance(out, (list, tuple)) and isinstance(out[0], str): 31 | return list(reduce(lambda x, y: x + y, outputs)) 32 | if out is None: 33 | return None 34 | if isinstance(out, BaseOutput): 35 | # print((out.keys())) 36 | # exit() 37 | return type(out)(*gather_map([d.to_tuple() for d in outputs])) 38 | 39 | if isinstance(out, dict): 40 | if not all(len(out) == len(d) for d in outputs): 41 | raise ValueError('All dicts must have the same number of keys') 42 | return type(out)((k, gather_map([d[k] for d in outputs])) for k in out) 43 | if _is_namedtuple(out): 44 | return type(out)._make(map(gather_map, zip(*outputs))) 45 | return type(out)(map(gather_map, zip(*outputs))) 46 | 47 | try: 48 | res = gather_map(outputs) 49 | finally: 50 | gather_map = None 51 | return res 52 | 53 | 54 | def batch_apply( 55 | fn: Callable, 56 | *inputs, 57 | batch_size: int, 58 | description: Optional[str] = None, 59 | use_tqdm: bool = False, 60 | **other_input_kwargs, 61 | ): 62 | """Apply the given function to the input data in chunks of the specified batch size and gather the results. 63 | 64 | Args: 65 | fn (Callable): The given function. 66 | *inputs: The collected input data. 67 | batch_size (int): The specified batch size. 68 | description (Optional[str], optional): The content to print when processing the input data. Defaults to None. 69 | use_tqdm (bool, optional): Whether to show a tqdm progress bar. Defaults to False. 70 | """ 71 | 72 | def _check_valid(inputs): 73 | if len(inputs) == 0: 74 | return 75 | lens = [] 76 | for i, inp in enumerate(inputs): 77 | try: 78 | lens.append(len(inp)) 79 | except TypeError: 80 | raise RuntimeError(f'input {i} has no attribute `len`') 81 | valid_len = lens[0] 82 | if not all(map(lambda x: x == valid_len, lens)): 83 | raise RuntimeError('lengths of all inputs are not the same') 84 | 85 | _check_valid(inputs) 86 | 87 | total_len = len(inputs[0]) 88 | 89 | results = [] 90 | starts = list(range(0, total_len, batch_size)) 91 | iter_times = len(starts) 92 | 93 | if use_tqdm: 94 | if description is not None: 95 | print_split_line(description) 96 | starts = tqdm(starts, leave=False) 97 | 98 | for i, start in enumerate(starts, start=1): 99 | 100 | if description is not None and not use_tqdm: 101 | print_split_line(f'{description}: {i} / {iter_times}') 102 | 103 | end = min(total_len, start + batch_size) 104 | res = fn(*[p[start:end] for p in inputs], **other_input_kwargs) 105 | # print(res.device) 106 | results.append(res) 107 | return _gather(results) 108 | -------------------------------------------------------------------------------- /src/modelinversion/utils/check.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional 2 | 3 | import torch 4 | 5 | 6 | class ShapeException(Exception): 7 | pass 8 | 9 | 10 | def check_shape( 11 | tensor: torch.Tensor, 12 | expect_shape: Union[list[Optional[int]], tuple[Optional[int], ...]], 13 | raise_exception=True, 14 | ) -> bool: 15 | """Check if the shape of the tensor matches expectations. 16 | 17 | Args: 18 | tensor (torch.Tensor): The tensor to check. 19 | expect_shape (Union[list[Optional[int]], tuple[Optional[int], ...]]): The expected shape of the trailing dimensions; `None` entries match any size. 20 | raise_exception (bool, optional): Whether to raise an exception on mismatch. Defaults to True. 21 | 22 | Returns: 23 | bool: The check result.
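    Example (sketch; `expect_shape` is matched against the trailing dimensions):

        check_shape(torch.zeros(8, 3, 64, 64), [3, 64, 64])  # True
        check_shape(torch.zeros(8, 3, 32, 32), [None, 64, 64], raise_exception=False)  # False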
24 | """ 25 | 26 | tensor_shape = tensor.shape 27 | 28 | if len(tensor_shape) < len(expect_shape): 29 | if raise_exception: 30 | raise ShapeException( 31 | f'expect ndim >= {len(expect_shape)}, but found {len(tensor_shape)}' 32 | ) 33 | return False 34 | 35 | tensor_shape_raw = tensor_shape 36 | tensor_shape = tensor_shape[-len(expect_shape) :] 37 | 38 | for i in range(len(expect_shape)): 39 | if expect_shape[i] is None: 40 | continue 41 | 42 | if expect_shape[i] != tensor_shape[i]: 43 | if raise_exception: 44 | display_shape = [ 45 | '*' if s is None else str(s) for s in expect_shape 46 | ] 47 | expect_shape_str = ', '.join(display_shape) 48 | tensor_shape_str = ', '.join(str(s) for s in tensor_shape_raw) 49 | raise ShapeException( 50 | f'expect shape [..., {expect_shape_str}], but found [{tensor_shape_str}]' 51 | ) 52 | 53 | return False 54 | return True 55 | -------------------------------------------------------------------------------- /src/modelinversion/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import inspect 3 | import functools 4 | import torch 5 | 6 | from .io import safe_save 7 | 8 | 9 | class ConfigMixin: 10 | """ 11 | A Mixin to save parameters from the `__init__` function. Inherit the `ConfigMixin` class and add the decorator `@register_to_config_init` to the `__init__` function. 12 | 13 | The workflow of the class is as follows. 14 | +------------------------------+ 15 | | | 16 | | Initial Parameters | 17 | | | 18 | +-----------+-----^------------+ 19 | | | 20 | register_to_config_init | | __init__ 21 | | | 22 | +-----------v-----+------------+ 23 | | | 24 | | Loaded Config | 25 | | | 26 | +-----------+-----^------------+ 27 | | | 28 | preprocess_config_before_save | | postprocess_config_after_load 29 | | | 30 | +-----------v-----+------------+ 31 | | | 32 | | Saved Config | 33 | | | 34 | +------------------------------+ 35 | """ 36 | 37 | def preprocess_config_before_save(self, config): 38 | return config 39 | 40 | @staticmethod 41 | def postprocess_config_after_load(config): 42 | return config 43 | 44 | def register_to_config(self, **config_dict): 45 | self._config_mixin_dict = config_dict 46 | 47 | def save_config(self, save_path: str): 48 | # os.makedirs(save_path, exist_ok=True) 49 | safe_save( 50 | self.preprocess_config_before_save(self._config_mixin_dict), save_path 51 | ) 52 | 53 | @staticmethod 54 | def load_config(config_path: str): 55 | if not os.path.exists(config_path): 56 | raise RuntimeError(f'config_path {config_path} does not exist.') 57 | 58 | kwargs = torch.load(config_path, map_location='cpu') 59 | return ConfigMixin.postprocess_config_after_load(kwargs) 60 | 61 | @staticmethod 62 | def register_to_config_init(init): 63 | """Decorator for the `__init__` method of classes that inherit from `ConfigMixin`. Automatically saves the init parameters.""" 64 | 65 | @functools.wraps(init) 66 | def inner_init(self, *args, **kwargs): 67 | 68 | # Ignore private kwargs in the init. 69 | init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} 70 | config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} 71 | if not isinstance(self, ConfigMixin): 72 | raise RuntimeError( 73 | f"`@register_to_config_init` was applied to {self.__class__.__name__} init method, but this class does " 74 | "not inherit from `ConfigMixin`."
75 | ) 76 | 77 | # Get positional arguments aligned with kwargs 78 | new_kwargs = {} 79 | signature = inspect.signature(init) 80 | parameters = { 81 | name: p.default 82 | for i, (name, p) in enumerate(signature.parameters.items()) 83 | if i > 0 84 | } 85 | for arg, name in zip(args, parameters.keys()): 86 | new_kwargs[name] = arg 87 | 88 | # Then add all kwargs 89 | new_kwargs.update( 90 | { 91 | k: init_kwargs.get(k, default) 92 | for k, default in parameters.items() 93 | if k not in new_kwargs 94 | } 95 | ) 96 | 97 | # Take note of the parameters that were not present in the loaded config 98 | # if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: 99 | # new_kwargs["_use_default_values"] = list( 100 | # set(new_kwargs.keys()) - set(init_kwargs) 101 | # ) 102 | 103 | new_kwargs = {**config_init_kwargs, **new_kwargs} 104 | # getattr(self, "register_to_config")(**new_kwargs) 105 | self.register_to_config(**new_kwargs) 106 | init(self, *args, **init_kwargs) 107 | 108 | return inner_init 109 | -------------------------------------------------------------------------------- /src/modelinversion/utils/constraint.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | from abc import ABC, abstractmethod 6 | 7 | 8 | def copy_or_set_(dest, source): 9 | """ 10 | A workaround to respect strides of :code:`dest` when copying :code:`source` 11 | (https://github.com/geoopt/geoopt/issues/70) 12 | Parameters 13 | ---------- 14 | dest : torch.Tensor 15 | Destination tensor where to store new data 16 | source : torch.Tensor 17 | Source data to put in the new tensor 18 | Returns 19 | ------- 20 | dest 21 | torch.Tensor, modified inplace 22 | """ 23 | if dest.stride() != source.stride(): 24 | return dest.copy_(source) 25 | else: 26 | # return dest.set_(source) 27 | dest.data = source.data 28 | return dest 29 | 30 | 31 | class BaseConstraint(ABC): 32 | """Constraints that restrict tensors to a certain domain.""" 33 | 34 | def __init__(self) -> None: 35 | self.center_tensor = None 36 | 37 | def register_center(self, tensor: Tensor): 38 | self.center_tensor = tensor 39 | 40 | @abstractmethod 41 | def __call__(self, tensor: Tensor, *args: Any, **kwds: Any) -> Any: 42 | return tensor 43 | 44 | 45 | class MinMaxConstraint(BaseConstraint): 46 | """Restrict the input tensor between the minimum tensor and the maximum tensor.""" 47 | 48 | def __init__(self, min_tensor, max_tensor) -> None: 49 | super().__init__() 50 | 51 | self.min_tensor = min_tensor 52 | self.max_tensor = max_tensor 53 | 54 | def register_center(self, tensor: Tensor): 55 | pass 56 | 57 | def __call__(self, tensor: Tensor, *args: Any, **kwds: Any) -> Any: 58 | max_tensor = self.max_tensor 59 | min_tensor = self.min_tensor 60 | if isinstance(max_tensor, (int, float)): 61 | max_tensor = torch.tensor( 62 | max_tensor, dtype=tensor.dtype, device=tensor.device 63 | ) 64 | 65 | if isinstance(min_tensor, (int, float)): 66 | min_tensor = torch.tensor( 67 | min_tensor, dtype=tensor.dtype, device=tensor.device 68 | ) 69 | 70 | res = torch.min(tensor, max_tensor) 71 | res = torch.max(res, min_tensor) 72 | tensor.data = res.data 73 | return tensor 74 | # return copy_or_set_(tensor, res.detach().requires_grad_(False)) 75 | 76 | 77 | class L1ballConstraint(BaseConstraint): 78 | """Restrict the input tensor into an L1-ball centered at the specified tensor.""" 79 | 80 | def __init__(self, bias: float) -> None: 81 |
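        # `bias` is the radius eps of the L1 ball; `__call__` below leaves any
        # tensor whose L1 norm is already below eps unchanged and projects the
        # rest back onto the ball with the standard sorting-based projection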
super().__init__() 82 | 83 | self.bias = bias 84 | 85 | def register_center(self, tensor: Tensor): 86 | pass 87 | 88 | def __call__(self, tensor: Tensor, *args: Any, **kwds: Any) -> Any: 89 | x = tensor 90 | eps = self.bias 91 | original_shape = x.shape 92 | x = x.view(x.shape[0], -1) 93 | mask = (torch.norm(x, p=1, dim=1) < eps).float().unsqueeze(1) 94 | mu, _ = torch.sort(torch.abs(x), dim=1, descending=True) 95 | cumsum = torch.cumsum(mu, dim=1) 96 | arange = torch.arange(1, x.shape[1] + 1, device=x.device) 97 | rho, _ = torch.max((mu * arange > (cumsum - eps)) * arange, dim=1) 98 | theta = (cumsum[torch.arange(x.shape[0]), rho.cpu() - 1] - eps) / rho 99 | proj = (torch.abs(x) - theta.unsqueeze(1)).clamp(min=0) 100 | x = mask * x + (1 - mask) * proj * torch.sign(x) 101 | return copy_or_set_( 102 | tensor, x.view(original_shape).detach().requires_grad_(False) 103 | ) 104 | -------------------------------------------------------------------------------- /src/modelinversion/utils/hook.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Module, parallel 6 | 7 | 8 | class BaseHook(metaclass=ABCMeta): 9 | """Monitor the module during the forward pass""" 10 | 11 | def __init__(self, module: Module) -> None: 12 | self.hook = module.register_forward_hook(self._hook_gather_impl) 13 | self.features = None 14 | 15 | def _hook_gather_impl(self, module, input, output): 16 | feature = self.hook_fn(module, input, output) 17 | self.features = feature 18 | 19 | @abstractmethod 20 | def hook_fn(self, module, input, output): 21 | raise NotImplementedError() 22 | 23 | def get_feature(self) -> Tensor: 24 | """ 25 | Returns: 26 | Tensor: the value that the hook monitors. 27 | """ 28 | return self.features 29 | 30 | def close(self): 31 | self.hook.remove() 32 | 33 | 34 | class OutputHook(BaseHook): 35 | """Monitor the output of the module""" 36 | 37 | def __init__(self, module: Module) -> None: 38 | super().__init__(module) 39 | 40 | def hook_fn(self, module, input, output): 41 | return output 42 | 43 | 44 | class InputHook(BaseHook): 45 | """Monitor the input of the module""" 46 | 47 | def __init__(self, module: Module) -> None: 48 | super().__init__(module) 49 | 50 | def hook_fn(self, module, input, output): 51 | return input 52 | 53 | 54 | class FirstInputHook(BaseHook): 55 | """Monitor the first input of the module""" 56 | 57 | def __init__(self, module: Module) -> None: 58 | super().__init__(module) 59 | 60 | def hook_fn(self, module, input, output): 61 | return input[0] 62 | 63 | 64 | class DeepInversionBNFeatureHook(BaseHook): 65 | ''' 66 | Implementation of the forward hook to track feature statistics and compute a loss on them. 67 | Will compute mean and variance, and will use l2 as a loss 68 | ''' 69 | 70 | def __init__(self, module): 71 | super().__init__(module) 72 | 73 | def hook_fn(self, module, input, output): 74 | # hook to compute DeepInversion's feature distribution regularization 75 | nch = input[0].shape[1] 76 | mean = input[0].mean([0, 2, 3]) 77 | var = ( 78 | input[0] 79 | .permute(1, 0, 2, 3) 80 | .contiguous() 81 | .view([nch, -1]) 82 | .var(1, unbiased=False) 83 | ) 84 | 85 | # forcing mean and variance to match between two distributions 86 | # other ways might work better, e.g.
KL divergence 87 | r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm( 88 | module.running_mean.data - mean, 2 89 | ) 90 | 91 | return r_feature 92 | # must have no output 93 | -------------------------------------------------------------------------------- /src/modelinversion/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from typing import Optional 4 | from collections import OrderedDict 5 | import torch 6 | import pandas as pd 7 | 8 | 9 | def safe_save(obj, save_dir: str, save_name: Optional[str] = None): 10 | """Save the obj by using torch.save function. 11 | 12 | Args: 13 | obj (_type_): The objective to save. 14 | save_dir (str): The directory path. 15 | save_name (Optional[str], optional): The file name for the objective to save. Defaults to None. 16 | """ 17 | 18 | if save_name is None: 19 | save_dir, save_name = os.path.split(save_dir) 20 | if save_dir.strip() != '': 21 | os.makedirs(save_dir, exist_ok=True) 22 | torch.save(obj, os.path.join(save_dir, save_name)) 23 | 24 | 25 | def safe_save_csv(df: pd.DataFrame, save_dir: str, save_name: Optional[str] = None): 26 | """Save the data in csv format. 27 | 28 | Args: 29 | df (pd.DataFrame): The data to save. 30 | save_dir (str): The directory path. 31 | save_name (Optional[str], optional): The file name for the data to save. Defaults to None. 32 | """ 33 | 34 | if save_name is None: 35 | save_dir, save_name = os.path.split(save_dir) 36 | if save_dir.strip() != '': 37 | os.makedirs(save_dir, exist_ok=True) 38 | # torch.save(obj, os.path.join(save_dir, save_name)) 39 | df.to_csv(os.path.join(save_dir, save_name), index=None) 40 | 41 | 42 | IMG_EXTENSIONS = ( 43 | ".jpg", 44 | ".jpeg", 45 | ".png", 46 | ".ppm", 47 | ".bmp", 48 | ".pgm", 49 | ".tif", 50 | ".tiff", 51 | ".webp", 52 | ) 53 | 54 | 55 | def walk_imgs(path): 56 | """Traverse all images in the specified path. 57 | 58 | Args: 59 | path (_type_): The specified path. 60 | 61 | Returns: 62 | List: The list that collects the paths for all the images. 63 | """ 64 | 65 | img_paths = [] 66 | for root, dirs, files in os.walk(path): 67 | for file in files: 68 | if file.endswith(IMG_EXTENSIONS): 69 | img_paths.append(os.path.join(root, file)) 70 | return img_paths 71 | 72 | 73 | yaml.add_representer( 74 | OrderedDict, 75 | lambda dumper, data: dumper.represent_mapping( 76 | 'tag:yaml.org,2002:map', data.items() 77 | ), 78 | ) 79 | yaml.add_representer( 80 | tuple, lambda dumper, data: dumper.represent_sequence('tag:yaml.org,2002:seq', data) 81 | ) 82 | 83 | 84 | def obj_to_yaml(obj) -> str: 85 | return yaml.dump(obj) 86 | 87 | 88 | def print_as_yaml(obj, stdout=True, file=None, mode='w'): 89 | """Print the obj in the yaml format and Save the obj if the file path is specified. 90 | 91 | Args: 92 | obj (_type_): The objective to save. 93 | stdout (bool, optional): Whether to print in the stdout. Defaults to True. 94 | file (_type_, optional): The file path for the obj to save. Defaults to None. 95 | mode (str, optional): An optional string that specifies the mode in which the file is opened. Defaults to 'w'. 96 | """ 97 | 98 | s = yaml.dump(obj) 99 | 100 | if stdout: 101 | print(s) 102 | if file: 103 | with open(file, mode) as f: 104 | f.write(s) 105 | 106 | 107 | def print_split_line(content=None, length=60): 108 | """Print the content and surround it with '-' character for alignment. 109 | 110 | Args: 111 | content (_type_, optional): The content to print. Defaults to None. 
112 | length (int, optional): The total length of content and '-' characters. Defaults to 60. 113 | """ 114 | 115 | if content is None: 116 | print('-' * length) 117 | return 118 | if len(content) > length - 4: 119 | length = len(content) + 4 120 | 121 | total_num = length - len(content) - 2 122 | left_num = total_num // 2 123 | right_num = total_num - left_num 124 | print('-' * left_num, end=' ') 125 | print(content, end=' ') 126 | print('-' * right_num) 127 | -------------------------------------------------------------------------------- /src/modelinversion/utils/log.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from typing import Any, List, Tuple, Union 4 | 5 | # class Tee(object): 6 | # """A workaround method to print in console and write to log file 7 | # """ 8 | # def __init__(self, name, mode): 9 | # self.file = open(name, mode) 10 | # self.stdout = sys.stdout 11 | # sys.stdout = self 12 | # def __del__(self): 13 | # sys.stdout = self.stdout 14 | # self.file.close() 15 | # def write(self, data): 16 | # if not '...' in data: 17 | # self.file.write(data) 18 | # self.stdout.write(data) 19 | # def flush(self): 20 | # self.file.flush() 21 | 22 | 23 | class Logger(object): 24 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 25 | 26 | def __init__( 27 | self, 28 | file_dir: str = None, 29 | file_name: str = None, 30 | file_mode: str = "w", 31 | should_flush: bool = True, 32 | ): 33 | self.file = None 34 | 35 | if file_name is not None: 36 | if file_dir is not None: 37 | os.makedirs(file_dir, exist_ok=True) 38 | file_name = os.path.join(file_dir, file_name) 39 | self.file = open(file_name, file_mode) 40 | 41 | self.should_flush = should_flush 42 | self.stdout = sys.stdout 43 | # self.stderr = sys.stderr 44 | 45 | sys.stdout = self 46 | # sys.stderr = self 47 | 48 | def __enter__(self) -> "Logger": 49 | return self 50 | 51 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 52 | self.close() 53 | 54 | def write(self, text: Union[str, bytes]) -> None: 55 | """Write text to stdout (and a file) and optionally flush.""" 56 | if isinstance(text, bytes): 57 | text = text.decode() 58 | if ( 59 | len(text) == 0 60 | ): # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 61 | return 62 | 63 | if self.file is not None: 64 | self.file.write(text) 65 | 66 | self.stdout.write(text) 67 | 68 | if self.should_flush: 69 | self.flush() 70 | 71 | def flush(self) -> None: 72 | """Flush written text to both stdout and a file, if open.""" 73 | if self.file is not None: 74 | self.file.flush() 75 | 76 | self.stdout.flush() 77 | 78 | def close(self) -> None: 79 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 80 | self.flush() 81 | 82 | # if using multiple loggers, prevent closing in wrong order 83 | if sys.stdout is self: 84 | sys.stdout = self.stdout 85 | # if sys.stderr is self: 86 | # sys.stderr = self.stderr 87 | 88 | if self.file is not None: 89 | self.file.close() 90 | self.file = None 91 | -------------------------------------------------------------------------------- /src/modelinversion/utils/losses.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import Callable 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | def max_margin_loss(out, iden): 10 | 
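    # A worked instance of this max-margin objective: for a single sample with
    # logits out = [2.0, 5.0, 1.0] and target iden = 1, `real` is 5.0 and the
    # runner-up logit `margin` is 2.0, so the loss is -5.0 + 2.0 = -3.0;
    # minimizing it pushes the target logit above every competing logit.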
real = out.gather(1, iden.unsqueeze(1)).squeeze(1) 11 | tmp1 = torch.argsort(out, dim=1)[:, -2:] 12 | new_y = torch.where(tmp1[:, -1] == iden, tmp1[:, -2], tmp1[:, -1]) 13 | margin = out.gather(1, new_y.unsqueeze(1)).squeeze(1) 14 | 15 | return (-1 * real).mean() + margin.mean() 16 | 17 | 18 | def poincare_loss(outputs, targets, xi=1e-4): 19 | # Normalize logits 20 | u = outputs / torch.norm(outputs, p=1, dim=-1).unsqueeze(1) 21 | # Create one-hot encoded target vector 22 | v = torch.clip(torch.eye(outputs.shape[-1])[targets.detach().cpu()] - xi, 0, 1) 23 | v = v.to(u.device) 24 | # Compute squared norms 25 | u_norm_squared = torch.norm(u, p=2, dim=1) ** 2 26 | v_norm_squared = torch.norm(v, p=2, dim=1) ** 2 27 | diff_norm_squared = torch.norm(u - v, p=2, dim=1) ** 2 28 | # Compute delta 29 | delta = 2 * diff_norm_squared / ((1 - u_norm_squared) * (1 - v_norm_squared)) 30 | # Compute distance 31 | loss = torch.arccosh(1 + delta) 32 | return loss.mean() 33 | 34 | 35 | _LOSS_MAPPING = { 36 | 'ce': F.cross_entropy, 37 | 'poincare': poincare_loss, 38 | 'max_margin': max_margin_loss, 39 | } 40 | 41 | 42 | class LabelSmoothingCrossEntropyLoss: 43 | """The Cross Entropy Loss with label smoothing technique. Used in the LS defense method.""" 44 | 45 | def __init__(self, label_smoothing: float = 0.0) -> None: 46 | self.label_smoothing = label_smoothing 47 | 48 | def __call__(self, inputs, labels): 49 | ls = self.label_smoothing 50 | confidence = 1.0 - ls 51 | logprobs = F.log_softmax(inputs, dim=-1) 52 | nll_loss = -logprobs.gather(dim=-1, index=labels.unsqueeze(1)) 53 | nll_loss = nll_loss.squeeze(1) 54 | smooth_loss = -logprobs.mean(dim=-1) 55 | loss = confidence * nll_loss + ls * smooth_loss 56 | return torch.mean(loss, dim=0).sum() 57 | 58 | 59 | class TorchLoss: 60 | """Find loss function from 'torch.nn.functional' and 'torch.nn'""" 61 | 62 | def __init__(self, loss_fn: str | Callable, *args, **kwargs) -> None: 63 | # super().__init__() 64 | self.fn = None 65 | if isinstance(loss_fn, str): 66 | if loss_fn.lower() in _LOSS_MAPPING: 67 | self.fn = _LOSS_MAPPING[loss_fn.lower()] 68 | else: 69 | module = importlib.import_module('torch.nn.functional') 70 | fn = getattr(module, loss_fn, None) 71 | if fn is not None: 72 | self.fn = lambda *arg, **kwd: fn(*arg, *args, **kwd, **kwargs) 73 | else: 74 | module = importlib.import_module('torch.nn') 75 | t = getattr(module, loss_fn, None) 76 | if t is not None: 77 | self.fn = t(*args, **kwargs) 78 | if self.fn is None: 79 | raise RuntimeError(f'loss_fn {loss_fn} not found.') 80 | else: 81 | self.fn = loss_fn 82 | 83 | def __call__(self, *args, **kwargs): 84 | return self.fn(*args, **kwargs) 85 | -------------------------------------------------------------------------------- /src/modelinversion/utils/outputs.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from dataclasses import fields, is_dataclass 3 | from typing import Any, Tuple 4 | 5 | import numpy as np 6 | 7 | 8 | class BaseOutput(OrderedDict): 9 | """ 10 | Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/utils/outputs.py. 11 | 12 | Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a 13 | tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular 14 | Python dictionary. 15 | """ 16 | 17 | def __init_subclass__(cls) -> None: 18 | """Register subclasses as pytree nodes. 
19 | 20 | This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with 21 | `static_graph=True` with modules that output `ModelOutput` subclasses. 22 | """ 23 | # if is_torch_available(): 24 | import torch.utils._pytree 25 | 26 | # if is_torch_version("<", "2.2"): 27 | torch.utils._pytree._register_pytree_node( 28 | cls, 29 | torch.utils._pytree._dict_flatten, 30 | lambda values, context: cls( 31 | **torch.utils._pytree._dict_unflatten(values, context) 32 | ), 33 | ) 34 | # else: 35 | # torch.utils._pytree.register_pytree_node( 36 | # cls, 37 | # torch.utils._pytree._dict_flatten, 38 | # lambda values, context: cls( 39 | # **torch.utils._pytree._dict_unflatten(values, context) 40 | # ), 41 | # ) 42 | 43 | def __post_init__(self) -> None: 44 | class_fields = fields(self) 45 | 46 | # Safety and consistency checks 47 | if not len(class_fields): 48 | raise ValueError(f"{self.__class__.__name__} has no fields.") 49 | 50 | first_field = getattr(self, class_fields[0].name) 51 | other_fields_are_none = all( 52 | getattr(self, field.name) is None for field in class_fields[1:] 53 | ) 54 | 55 | if other_fields_are_none and isinstance(first_field, dict): 56 | for key, value in first_field.items(): 57 | self[key] = value 58 | else: 59 | for field in class_fields: 60 | v = getattr(self, field.name) 61 | # if v is not None: 62 | self[field.name] = v 63 | 64 | def __delitem__(self, *args, **kwargs): 65 | raise Exception( 66 | f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." 67 | ) 68 | 69 | def setdefault(self, *args, **kwargs): 70 | raise Exception( 71 | f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." 72 | ) 73 | 74 | def pop(self, *args, **kwargs): 75 | raise Exception( 76 | f"You cannot use ``pop`` on a {self.__class__.__name__} instance." 77 | ) 78 | 79 | def update(self, *args, **kwargs): 80 | raise Exception( 81 | f"You cannot use ``update`` on a {self.__class__.__name__} instance." 82 | ) 83 | 84 | def __getitem__(self, k: Any) -> Any: 85 | if isinstance(k, str): 86 | inner_dict = dict(self.items()) 87 | return inner_dict[k] 88 | else: 89 | return self.to_tuple()[k] 90 | 91 | def __setattr__(self, name: Any, value: Any) -> None: 92 | if name in self.keys() and value is not None: 93 | # Don't call self.__setitem__ to avoid recursion errors 94 | super().__setitem__(name, value) 95 | super().__setattr__(name, value) 96 | 97 | def __setitem__(self, key, value): 98 | # Will raise a KeyException if needed 99 | super().__setitem__(key, value) 100 | # Don't call self.__setattr__ to avoid recursion errors 101 | super().__setattr__(key, value) 102 | 103 | def __reduce__(self): 104 | if not is_dataclass(self): 105 | return super().__reduce__() 106 | callable, _args, *remaining = super().__reduce__() 107 | args = tuple(getattr(self, field.name) for field in fields(self)) 108 | return callable, args, *remaining 109 | 110 | def to_tuple(self) -> Tuple[Any, ...]: 111 | """ 112 | Convert self to a tuple containing all the attributes/keys that are not `None`. 
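        Example (a hypothetical subclass sketch; `imgs` and `zs` are assumed tensors):

            from dataclasses import dataclass

            @dataclass
            class SampleOutput(BaseOutput):
                images: torch.Tensor = None
                latents: torch.Tensor = None

            out = SampleOutput(images=imgs, latents=zs)
            out.images is out['images'] is out[0]  # attribute, key and index access agree
            images, latents = out.to_tuple()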
113 | """ 114 | return tuple(self[k] for k in self.keys()) 115 | -------------------------------------------------------------------------------- /src/modelinversion/utils/random.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import time 5 | 6 | 7 | def set_random_seed(random_seed): 8 | random.seed(random_seed) 9 | np.random.seed(random_seed) 10 | torch.manual_seed(random_seed) 11 | if torch.cuda.is_available(): 12 | torch.cuda.manual_seed(random_seed) 13 | torch.cuda.manual_seed_all(random_seed) 14 | torch.backends.cudnn.deterministic = True 15 | torch.backends.cudnn.benchmark = False 16 | 17 | 18 | _ALL_LOGITS = '0123456789qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM' 19 | _ALL_LOGITS_INDICES = np.arange(len(_ALL_LOGITS), dtype=np.int32) 20 | 21 | 22 | def get_random_string(length: int = 6): 23 | """Generate a random string with the specified length. 24 | 25 | Args: 26 | length (int, optional): The string length. Defaults to 6. 27 | 28 | Returns: 29 | str: The randomly generated string. 30 | """ 31 | 32 | seed = int(time.time() * 1000) % (2**30) ^ random.randint(0, 2**30) 33 | # print(seed) 34 | 35 | resindices = np.random.RandomState(seed).choice(_ALL_LOGITS_INDICES, length) 36 | return ''.join(map(lambda x: _ALL_LOGITS[x], resindices)) 37 | -------------------------------------------------------------------------------- /src/modelinversion/utils/torchutil.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn.parallel import DataParallel, DistributedDataParallel 7 | 8 | 9 | def traverse_module(module: nn.Module, fn: Callable, call_middle=False): 10 | """Use DFS to traverse the module and visit submodules by function `fn`. 11 | 12 | Args: 13 | module (nn.Module): the module to be traversed 14 | fn (Callable): visit function 15 | call_middle (bool, optional): If true, it will visit both intermediate nodes and leaf nodes, else, it will only visit leaf nodes. Defaults to False. 16 | """ 17 | 18 | children = list(module.children()) 19 | if len(children) == 0: 20 | fn(module) 21 | else: 22 | if call_middle: 23 | fn(module) 24 | for child in children: 25 | traverse_module(child, fn) 26 | 27 | 28 | def _traverse_name_module_impl(module_tuple: list, fn: Callable, call_middle=False): 29 | name, module = module_tuple 30 | children = list(module.named_children()) 31 | if len(children) == 0: 32 | fn(module_tuple) 33 | else: 34 | if call_middle: 35 | fn(module_tuple) 36 | for child in children: 37 | _traverse_name_module_impl(child, fn) 38 | 39 | 40 | def traverse_name_module(module: nn.Module, fn: Callable, call_middle=False): 41 | """Use DFS to traverse the module and visit submodules by function `fn`. 42 | 43 | Args: 44 | module (nn.Module): the module to be traversed 45 | fn (Callable): visit function 46 | call_middle (bool, optional): If true, it will visit both intermediate nodes and leaf nodes, else, it will only visit leaf nodes. Defaults to False. 
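        Example (sketch; collects the local name of every leaf module of an assumed `model`):

            leaf_names = []
            traverse_name_module(model, lambda t: leaf_names.append(t[0]))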
47 | """ 48 | children = list(module.named_children()) 49 | for child in children: 50 | _traverse_name_module_impl(child, fn, call_middle=call_middle) 51 | 52 | 53 | def freeze(module): 54 | for p in module.parameters(): 55 | p.requires_grad_(False) 56 | 57 | 58 | def unfreeze(module): 59 | for p in module.parameters(): 60 | p.requires_grad_(True) 61 | 62 | 63 | def freeze_front_layers(module, ratio=0.5): 64 | 65 | if ratio < 0 or ratio > 1: 66 | raise RuntimeError('Ratio should be in [0, 1]') 67 | 68 | if ratio == 0: 69 | unfreeze(module) 70 | return 71 | 72 | if ratio == 1: 73 | freeze(module) 74 | return 75 | 76 | all_modules = [] 77 | 78 | def _visit_fn(module): 79 | all_modules.append(module) 80 | 81 | traverse_module(module, _visit_fn) 82 | length = len(all_modules) 83 | if length == 0: 84 | return 85 | 86 | freeze_line = ratio * length 87 | for i, m in enumerate(all_modules): 88 | if i < freeze_line: 89 | m.requires_grad_(False) 90 | else: 91 | m.requires_grad_(True) 92 | 93 | 94 | def unwrapped_parallel_module(module): 95 | 96 | if isinstance(module, (DataParallel, DistributedDataParallel)): 97 | return module.module 98 | return module 99 | 100 | 101 | def reparameterize(mu, std): 102 | """ 103 | Reparameterization trick to sample from N(mu, std^2) using 104 | N(0, 1). 105 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 106 | :param std: (Tensor) Standard deviation of the latent Gaussian [B x D] 107 | :return: (Tensor) [B x D] 108 | """ 109 | 110 | # std = torch.exp(0.5 * std) 111 | eps = torch.randn_like(std) 112 | 113 | return eps * std + mu 114 | 115 | 116 | def augment_images_fn_generator( 117 | initial_transform: Optional[Callable] = None, 118 | add_origin_image=True, 119 | augment: Optional[Callable] = None, 120 | augment_times: int = 0, 121 | ): 122 | """Return a function for image augmentation. 123 | 124 | Args: 125 | initial_transform (Optional[Callable], optional): The first transformation to perform. Defaults to None. 126 | add_origin_image (bool, optional): Whether to yield the original image first. Defaults to True. 127 | augment (Optional[Callable], optional): The augmentation to perform. Defaults to None. 128 | augment_times (int, optional): The number of augmented copies to yield. Defaults to 0. 129 | """ 130 | 131 | def fn(image): 132 | if initial_transform is not None: 133 | image = initial_transform(image) 134 | 135 | if add_origin_image: 136 | yield image 137 | 138 | if augment is not None: 139 | for i in range(augment_times): 140 | yield augment(image) 141 | 142 | return fn 143 | --------------------------------------------------------------------------------
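As a closing illustration, here is a minimal sketch of how `augment_images_fn_generator` plugs into the `create_aug_images_fn` argument used throughout `scores/functional.py`; the classifier `clf`, the batch `imgs`/`labels` and `device` are assumptions standing in for real objects:

import torch
from torchvision import transforms

# yields the original batch plus two randomly augmented views per call
create_aug = augment_images_fn_generator(
    add_origin_image=True,
    augment=transforms.RandomHorizontalFlip(p=0.5),
    augment_times=2,
)

# averages the target-class confidence over the three views of each image
scores = specific_image_augment_scores(clf, device, create_aug, imgs, labels)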