├── .gitignore ├── LICENSE ├── README.md ├── assets └── teaser.png ├── config ├── EGSDE │ ├── cat2dog-img256-lr-linear.yaml │ ├── cat2dog-img256.yaml │ ├── edge_map_cat2dog-img256-high-real.yaml │ ├── edge_map_cat2dog-img256.yaml │ └── multi-afhq-img256.yaml ├── ILVR │ ├── cat2dog-img256-respace100-down32-Rt20.yaml │ ├── cat2dog-img256-respace250-down32-Rt20.yaml │ ├── ffhq-p2-img256-respace100-down32-Rt20.yaml │ ├── wild2dog-img256-respace100-down32-Rt20.yaml │ └── wild2dog-img256-respace250-down32-Rt20.yaml ├── SDEdit │ ├── cat2dog-img256.yaml │ └── iter-cat2dog-img256-p400-k33-dN32.yaml ├── clfguided │ ├── imagenet-img128-respace250.yaml │ ├── imagenet-img256-respace1000.yaml │ ├── imagenet-img256-respace250.yaml │ ├── imagenet-img256-uncond-respace250.yaml │ └── imagenet-img64-respace250.yaml └── inversion │ ├── afhq-cat2dog-ADM.yaml │ └── afhq-wild2dog-ADM.yaml ├── data └── afhq_demo │ ├── cat │ └── flickr_cat_000033.jpg │ └── dog_sketch │ └── flickr_dog_000005_out.png ├── dataset └── README.md ├── dpm_nn ├── __init__.py ├── dpm_solver │ ├── __init__.py │ └── dpm_solver_pp.py ├── guided_dpm │ ├── ADMs.py │ ├── __init__.py │ ├── beta_schedule.py │ ├── build_ADMs.py │ ├── clf_guided_sampler.py │ ├── clf_guided_trainer.py │ ├── distribution.py │ ├── gaussian_diffusion.py │ ├── loss_resample.py │ └── spaced_diffusion.py ├── inversion │ ├── ILVR.py │ ├── SDEdit.py │ ├── __init__.py │ ├── egsde.py │ └── egsde_model.py └── utils.py ├── libs ├── __init__.py ├── engine │ ├── __init__.py │ └── model_state.py ├── metric │ ├── __init__.py │ ├── accuracy.py │ ├── lpips_origin │ │ ├── __init__.py │ │ ├── lpips.py │ │ ├── pretrained_networks.py │ │ └── weights │ │ │ └── v0.1 │ │ │ ├── alex.pth │ │ │ ├── squeeze.pth │ │ │ └── vgg.pth │ ├── piq │ │ ├── __init__.py │ │ ├── functional │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── colour_conversion.py │ │ │ ├── filters.py │ │ │ ├── layers.py │ │ │ └── resize.py │ │ ├── perceptual.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ └── common.py │ └── pytorch_fid │ │ ├── __init__.py │ │ ├── fid_score.py │ │ └── inception.py ├── modules │ ├── __init__.py │ ├── ema.py │ ├── resizer │ │ ├── __init__.py │ │ ├── interp_methods.py │ │ └── resizer.py │ └── vision │ │ ├── __init__.py │ │ ├── inception.py │ │ └── vgg.py ├── solver │ ├── __init__.py │ └── lr_scheduler.py └── utils │ ├── __init__.py │ ├── argparse.py │ ├── imshow.py │ ├── lazy.py │ ├── logging.py │ ├── meter.py │ ├── misc.py │ ├── model_summary.py │ └── tqdm.py ├── pipelines ├── __init__.py └── inversion │ ├── ILVR.py │ ├── ILVR_mixup.py │ ├── SDEdit_iter_pipeline.py │ ├── SDEdit_pipeline.py │ ├── __init__.py │ ├── egsde_pipeline.py │ └── invbyinv_pipeline.py ├── requirements.txt ├── run ├── run_ADM.py ├── run_ILVR.py ├── run_SDEdit.py ├── run_egsde.py └── run_invbyinv.py ├── setup.py ├── sketch_nn ├── __init__.py ├── augment │ ├── __init__.py │ ├── mixup.py │ └── resizer.py ├── dataset │ ├── __init__.py │ ├── base_dataset.py │ ├── build.py │ ├── imagenet.py │ └── utils.py ├── edge_map │ ├── DoG │ │ ├── XDoG.py │ │ └── __init__.py │ ├── __init__.py │ ├── canny │ │ └── __init__.py │ └── image_grads │ │ ├── __init__.py │ │ └── laplacian.py ├── methods │ ├── __init__.py │ └── inversion │ │ ├── ILVR_mixup.py │ │ ├── SDEdit_iter.py │ │ ├── __init__.py │ │ ├── diffsketching.py │ │ └── invbyinv.py ├── model │ ├── __init__.py │ └── ddpm_model.py ├── photo2sketch │ ├── InformativeDrawings │ │ ├── __init__.py │ │ ├── default_config.yaml │ │ └── model.py │ ├── PhotoSketching │ │ ├── __init__.py │ │ ├── 
base_model.py │ │ ├── base_options.py │ │ ├── default_config.yaml │ │ ├── image_pool.py │ │ ├── networks.py │ │ ├── pix2pix_model.py │ │ ├── test_options.py │ │ └── util.py │ └── __init__.py └── rasterize │ ├── __init__.py │ ├── bresenham.py │ └── rasterize.py └── style_transfer ├── AdaIN ├── README.md ├── __init__.py ├── function.py ├── net.py └── test.py ├── STROTSS ├── README.md ├── __init__.py ├── loss_utils.py ├── style_transfer.py ├── test.py ├── utils.py └── vgg_pt.py └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # .idea 132 | .idea/ 133 | /idea/ 134 | *.ipr 135 | *.iml 136 | *.iws 137 | 138 | # system 139 | .DS_Store 140 | 141 | # pytorch-lightning logs 142 | lightning_logs/* 143 | 144 | # Edit settings 145 | .editorconfig 146 | 147 | # dataset 148 | /dataset/*.png 149 | 150 | # dataset unignore 151 | !/dataset/README.md 152 | !/dataset/afhq 153 | 154 | # model checkpoint 155 | !checkpoint/README.md 156 | checkpoint/*.pt 157 | checkpoint/* 158 | 159 | # local results 160 | ./workdir/ 161 | ./workdir/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 XiMing Xing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Inversion-by-Inversion: Exemplar-based Sketch-to-Photo Synthesis via Stochastic Differential Equations without Training 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2308.07665-b31b1b.svg)](https://arxiv.org/abs/2308.07665) 4 | 5 | #### [Project Link](https://ximinng.github.io/inversion-by-inversion-project/) 6 | 7 | Our Inversion-by-Inversion method for exemplar-based sketch-to-photo synthesis addresses the challenge of generating 8 | photo-realistic images from mostly white-space sketches. It consists of a shape-enhancing inversion, which generates an 9 | uncolored photo for shape control, and a full-control inversion, which adds color and texture using an appearance-energy 10 | function to create the final RGB photo. Our pipeline works with different exemplars and requires neither task-specific 11 | training nor a trainable hyper-network, making it a versatile solution. 12 |
13 |
14 | 15 | ![teaser](assets/teaser.png?raw=true) 16 | 17 | ## Setup 18 | 19 | To set up the environment, please run 20 | 21 | ```bash 22 | conda create -n inv-by-inv python=3.10 23 | conda activate inv-by-inv 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | We tested our method on both Nvidia RTX 3090 and V100 GPUs. However, it should work on any GPU with at least 4 GB of 28 | memory (when `valid_batch_size=1`). 29 | 30 | ## Dataset 31 | 32 | Please download the AFHQ dataset and put it in `dataset/`. 33 | 34 | > Download Link: [afhq dataset](https://drive.google.com/file/d/18b0cz38KugVrqFgEe0lZvZgO5ACDdhHZ/view?usp=sharing) 35 | 36 | We also provide some demo images in `data/afhq_demo/` for a quick start. 37 | 38 | ## Pretrained Models 39 | 40 | To synthesize images, a pre-trained diffusion model is required. 41 | 42 | > Download 43 | > Link: [pretrained models](https://drive.google.com/drive/folders/1zt2YzcJUPTxWNKAm2wX4GBK4BB3gQD0C?usp=sharing) 44 | 45 | - Download the models pretrained on the datasets you want to use and put them 46 | in the `./checkpoint/InvSDE/` folder. 47 | - You can manually revise the checkpoint paths and names in the `./config/inversion/afhq-cat2dog-ADM.yaml` file. 48 | 49 | ## Usage 50 | 51 | After downloading the dataset, to use inv-by-inv for the **cat-to-dog** task, please run 52 | 53 | ```bash 54 | python run/run_invbyinv.py -c inversion/afhq-cat2dog-ADM.yaml -respath ./workdir/invbyinv/ -vbz 8 55 | ``` 56 | 57 | The `-vbz` flag indicates `--valid_batch_size`. 58 | 59 | **Note: This version generates more detailed content and improves image quality, so sampling will take longer.** 60 | 61 | To specify the data paths yourself, run 62 | 63 | ```bash 64 | python run/run_invbyinv.py -c inversion/afhq-cat2dog-ADM.yaml \ 65 | -dpath ./data/afhq_demo/cat \ 66 | -rdpath ./data/afhq_demo/dog_sketch \ 67 | -respath ./workdir/invbyinv/ \ 68 | -vbz 8 69 | ``` 70 | 71 | Please put your exemplar images into `-dpath` and your sketch images into `-rdpath`. 72 | The translated images will be saved to `-respath`. 73 | 74 | If you need to speed up sampling, the dpm-solver can be enabled as follows: 75 | 76 | ```bash 77 | python run/run_invbyinv.py -c inversion/afhq-cat2dog-ADM.yaml -respath ./workdir/invbyinv/ -vbz 8 -uds 78 | ``` 79 | 80 | The `-uds` flag indicates `--use_dpm_solver`. 81 | 82 | To generate a fixed total number of samples and keep only the final results, run 83 | ```bash 84 | python run/run_invbyinv.py -c inversion/afhq-cat2dog-ADM.yaml -respath ./workdir/invbyinv/ -vbz 8 -ts 30000 -final 85 | ``` 86 | The `-ts` flag sets the total number of samples (e.g., 30000), and `-final` indicates that intermediate results are skipped.
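These flags can be combined. Below is a minimal sketch, using only the options documented above and the demo images shipped in `data/afhq_demo/` (the paths and batch size are illustrative), that runs the cat-to-dog demo on a low-memory GPU with dpm-solver acceleration:

```bash
# cat-to-dog demo: custom exemplar/sketch paths, batch size 1, dpm-solver enabled
python run/run_invbyinv.py -c inversion/afhq-cat2dog-ADM.yaml \
    -dpath ./data/afhq_demo/cat \
    -rdpath ./data/afhq_demo/dog_sketch \
    -respath ./workdir/invbyinv/ \
    -vbz 1 -uds
```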
87 | 88 | To use the inv-by-inv for **wild-to-dog** tasks, please run, 89 | 90 | ```bash 91 | python run/run_invbyinv.py -c inversion/afhq-wild2dog-ADM.yaml -dpath ./dataset/afhq/train/wild -respath ./workdir/invbyinv/ -vbz 8 92 | ``` 93 | 94 | ## Citation 95 | 96 | If this code is useful for your work, please cite our paper: 97 | 98 | ``` 99 | @article{xing2023inversion, 100 | title={Inversion-by-Inversion: Exemplar-based Sketch-to-Photo Synthesis via Stochastic Differential Equations without Training}, 101 | author={Xing, Ximing and Wang, Chuang and Zhou, Haitao and Hu, Zhihao and Li, Chongxuan and Xu, Dong and Yu, Qian}, 102 | journal={arXiv preprint arXiv:2308.07665}, 103 | year={2023} 104 | } 105 | ``` -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ximinng/inversion-by-inversion/0d74d60567edb8f96c8c15341ec8446f170c9a62/assets/teaser.png -------------------------------------------------------------------------------- /config/EGSDE/cat2dog-img256-lr-linear.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "cat2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "linear" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | perturb_step: 500 22 | repeat_step: 1 # K 23 | 24 | expert: 25 | lam_s: 500 26 | lam_i: 2 27 | s1: "cosine" 28 | s2: "neg_l2" 29 | down_N: 32 30 | 31 | model: 32 | type: "ADM" 33 | use_fp16: False 34 | in_channels: 3 35 | num_channels: 128 36 | out_channels: 3 37 | num_res_blocks: 1 38 | class_cond: False 39 | use_checkpoint: False 40 | attention_resolutions: "16" 41 | num_heads: 4 42 | num_head_channels: 64 43 | num_heads_upsample: -1 44 | use_scale_shift_norm: True 45 | dropout: 0 46 | resblock_updown: True 47 | use_new_attention_order: False 48 | 49 | diffusion: 50 | beta_schedule: "linear" 51 | beta_start: 0.0001 52 | beta_end: 0.02 53 | timesteps: 1000 54 | var_type: fixedsmall # here 55 | clip_denoised: True 56 | learn_sigma: True # out_channels * 2 57 | sigma_small: False 58 | use_kl: False 59 | predict_xstart: False 60 | rescale_timesteps: False 61 | rescale_learned_sigmas: False 62 | use_ddim: False 63 | timestep_respacing: "" 64 | 65 | dse: 66 | load_share_weights: True 67 | use_fp16: False 68 | model_channels: 128 # width 69 | num_res_blocks: 2 # depth 70 | attention_resolutions: '32,16,8' 71 | use_scale_shift_norm: True 72 | resblock_updown: True 73 | pool: 'attention' 74 | -------------------------------------------------------------------------------- /config/EGSDE/cat2dog-img256.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "cat2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "cosine" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 
18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | perturb_step: 500 22 | repeat_step: 1 # K 23 | 24 | expert: 25 | lam_s: 500 26 | lam_i: 2 27 | s1: "cosine" 28 | s2: "neg_l2" 29 | down_N: 32 30 | 31 | model: 32 | type: "ADM" 33 | use_fp16: False 34 | in_channels: 3 35 | num_channels: 128 36 | out_channels: 3 37 | num_res_blocks: 1 38 | class_cond: False 39 | use_checkpoint: False 40 | attention_resolutions: "16" 41 | num_heads: 4 42 | num_head_channels: 64 43 | num_heads_upsample: -1 44 | use_scale_shift_norm: True 45 | dropout: 0 46 | resblock_updown: True 47 | use_new_attention_order: False 48 | 49 | diffusion: 50 | beta_schedule: "linear" 51 | beta_start: 0.0001 52 | beta_end: 0.02 53 | timesteps: 1000 54 | var_type: fixedsmall # here 55 | clip_denoised: True 56 | learn_sigma: True # out_channels * 2 57 | sigma_small: False 58 | use_kl: False 59 | predict_xstart: False 60 | rescale_timesteps: False 61 | rescale_learned_sigmas: False 62 | use_ddim: False 63 | timestep_respacing: "" 64 | 65 | dse: 66 | load_share_weights: True 67 | use_fp16: False 68 | model_channels: 128 # width 69 | num_res_blocks: 2 # depth 70 | attention_resolutions: '32,16,8' 71 | use_scale_shift_norm: True 72 | resblock_updown: True 73 | pool: 'attention' 74 | -------------------------------------------------------------------------------- /config/EGSDE/edge_map_cat2dog-img256-high-real.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "cat2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "cosine" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | perturb_step: 500 22 | repeat_step: 1 # K 23 | 24 | expert: 25 | lam_s: 1 26 | lam_i: 0.001 27 | s1: "cosine" 28 | s2: "neg_l2" 29 | down_N: 48 30 | 31 | model: 32 | type: "ADM" 33 | use_fp16: False 34 | in_channels: 3 35 | num_channels: 128 36 | out_channels: 3 37 | num_res_blocks: 1 38 | class_cond: False 39 | use_checkpoint: False 40 | attention_resolutions: "16" 41 | num_heads: 4 42 | num_head_channels: 64 43 | num_heads_upsample: -1 44 | use_scale_shift_norm: True 45 | dropout: 0 46 | resblock_updown: True 47 | use_new_attention_order: False 48 | 49 | diffusion: 50 | beta_schedule: "linear" 51 | beta_start: 0.0001 52 | beta_end: 0.02 53 | timesteps: 1000 54 | var_type: fixedsmall # here 55 | clip_denoised: True 56 | learn_sigma: True # out_channels * 2 57 | sigma_small: False 58 | use_kl: False 59 | predict_xstart: False 60 | rescale_timesteps: False 61 | rescale_learned_sigmas: False 62 | use_ddim: False 63 | timestep_respacing: "" 64 | 65 | dse: 66 | load_share_weights: True 67 | use_fp16: False 68 | model_channels: 128 # width 69 | num_res_blocks: 2 # depth 70 | attention_resolutions: '32,16,8' 71 | use_scale_shift_norm: True 72 | resblock_updown: True 73 | pool: 'attention' 74 | -------------------------------------------------------------------------------- /config/EGSDE/edge_map_cat2dog-img256.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "cat2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | 
weight_decay: 0.05 11 | lr_scheduler: "cosine" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | perturb_step: 500 22 | repeat_step: 1 # K 23 | 24 | expert: 25 | lam_s: 100 26 | lam_i: 0.01 27 | s1: "cosine" 28 | s2: "neg_l2" 29 | down_N: 32 30 | 31 | model: 32 | type: "ADM" 33 | use_fp16: False 34 | in_channels: 3 35 | num_channels: 128 36 | out_channels: 3 37 | num_res_blocks: 1 38 | class_cond: False 39 | use_checkpoint: False 40 | attention_resolutions: "16" 41 | num_heads: 4 42 | num_head_channels: 64 43 | num_heads_upsample: -1 44 | use_scale_shift_norm: True 45 | dropout: 0 46 | resblock_updown: True 47 | use_new_attention_order: False 48 | 49 | diffusion: 50 | beta_schedule: "linear" 51 | beta_start: 0.0001 52 | beta_end: 0.02 53 | timesteps: 1000 54 | var_type: fixedsmall # here 55 | clip_denoised: True 56 | learn_sigma: True # out_channels * 2 57 | sigma_small: False 58 | use_kl: False 59 | predict_xstart: False 60 | rescale_timesteps: False 61 | rescale_learned_sigmas: False 62 | use_ddim: False 63 | timestep_respacing: "" 64 | 65 | dse: 66 | load_share_weights: True 67 | use_fp16: False 68 | model_channels: 128 # width 69 | num_res_blocks: 2 # depth 70 | attention_resolutions: '32,16,8' 71 | use_scale_shift_norm: True 72 | resblock_updown: True 73 | pool: 'attention' 74 | -------------------------------------------------------------------------------- /config/EGSDE/multi-afhq-img256.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "afhq_multi2dog" 4 | num_classes: 3 # three domain 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "cosine" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | perturb_step: 500 22 | repeat_step: 1 # K 23 | 24 | expert: 25 | lam_s: 500 26 | lam_i: 2 27 | s1: "cosine" 28 | s2: "neg_l2" 29 | down_N: 32 30 | 31 | model: 32 | type: "ADM" 33 | use_fp16: False 34 | in_channels: 3 35 | num_channels: 128 36 | out_channels: 3 37 | num_res_blocks: 1 38 | class_cond: False 39 | use_checkpoint: False 40 | attention_resolutions: "16" 41 | num_heads: 4 42 | num_head_channels: 64 43 | num_heads_upsample: -1 44 | use_scale_shift_norm: True 45 | dropout: 0 46 | resblock_updown: True 47 | use_new_attention_order: False 48 | 49 | diffusion: 50 | beta_schedule: "linear" 51 | beta_start: 0.0001 52 | beta_end: 0.02 53 | timesteps: 1000 54 | var_type: fixedsmall # here 55 | clip_denoised: True 56 | learn_sigma: True # out_channels * 2 57 | sigma_small: False 58 | use_kl: False 59 | predict_xstart: False 60 | rescale_timesteps: False 61 | rescale_learned_sigmas: False 62 | use_ddim: False 63 | timestep_respacing: "" 64 | 65 | dse: 66 | load_share_weights: True 67 | use_fp16: False 68 | model_channels: 128 69 | num_res_blocks: 2 70 | attention_resolutions: '32,16,8' 71 | use_scale_shift_norm: True 72 | resblock_updown: True 73 | pool: 'attention' 74 | -------------------------------------------------------------------------------- 
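The config files in `config/` share a similar flat layout: top-level data/optimizer/sampling keys plus nested `model`, `diffusion`, and (for EGSDE) `expert`/`dse` groups. A minimal sketch of inspecting one of them with OmegaConf, a dependency the repo already uses (`dpm_nn/guided_dpm/build_ADMs.py` imports `DictConfig` from it); the file path and printed keys are just examples:

```python
# Load an EGSDE config and read a few nested keys (illustrative only, not part of the repo).
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/EGSDE/cat2dog-img256.yaml")
print(cfg.image_size)           # 256
print(cfg.expert.lam_s)         # 500
print(cfg.diffusion.timesteps)  # 1000
```

Note that `${channels}`-style values in the `config/clfguided/` files are OmegaConf interpolations, which resolve against the top-level keys of the same file when accessed.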
/config/ILVR/cat2dog-img256-respace100-down32-Rt20.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "cat2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "cosine" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | down_N: 32 22 | range_t: 20 23 | 24 | model: 25 | type: "ADM" 26 | use_fp16: False 27 | in_channels: 3 28 | num_channels: 128 29 | out_channels: 3 30 | num_res_blocks: 1 31 | class_cond: False 32 | use_checkpoint: False 33 | attention_resolutions: "16" 34 | num_heads: 4 35 | num_head_channels: 64 36 | num_heads_upsample: -1 37 | use_scale_shift_norm: True 38 | dropout: 0 39 | resblock_updown: True 40 | use_new_attention_order: False 41 | 42 | diffusion: 43 | beta_schedule: "linear" 44 | beta_start: 0.0001 45 | beta_end: 0.02 46 | timesteps: 1000 47 | var_type: fixedsmall # here 48 | clip_denoised: True 49 | learn_sigma: True # out_channels * 2 50 | sigma_small: False 51 | use_kl: False 52 | predict_xstart: False 53 | rescale_timesteps: False 54 | rescale_learned_sigmas: False 55 | use_ddim: False 56 | timestep_respacing: "100" -------------------------------------------------------------------------------- /config/ILVR/cat2dog-img256-respace250-down32-Rt20.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "cat2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "cosine" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | down_N: 32 22 | range_t: 20 23 | 24 | model: 25 | type: "ADM" 26 | use_fp16: False 27 | in_channels: 3 28 | num_channels: 128 29 | out_channels: 3 30 | num_res_blocks: 1 31 | class_cond: False 32 | use_checkpoint: False 33 | attention_resolutions: "16" 34 | num_heads: 4 35 | num_head_channels: 64 36 | num_heads_upsample: -1 37 | use_scale_shift_norm: True 38 | dropout: 0 39 | resblock_updown: True 40 | use_new_attention_order: False 41 | 42 | diffusion: 43 | beta_schedule: "linear" 44 | beta_start: 0.0001 45 | beta_end: 0.02 46 | timesteps: 1000 47 | var_type: fixedsmall # here 48 | clip_denoised: True 49 | learn_sigma: True # out_channels * 2 50 | sigma_small: False 51 | use_kl: False 52 | predict_xstart: False 53 | rescale_timesteps: False 54 | rescale_learned_sigmas: False 55 | use_ddim: False 56 | timestep_respacing: "250" -------------------------------------------------------------------------------- /config/ILVR/ffhq-p2-img256-respace100-down32-Rt20.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "cat2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "cosine" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # 
train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | down_N: 32 22 | range_t: 20 23 | 24 | model: 25 | type: "ADM" 26 | use_fp16: False 27 | in_channels: 3 28 | num_channels: 128 29 | out_channels: 3 30 | num_res_blocks: 1 31 | class_cond: False 32 | use_checkpoint: False 33 | attention_resolutions: "16" 34 | num_heads: 4 35 | num_head_channels: 64 36 | num_heads_upsample: -1 37 | use_scale_shift_norm: True 38 | dropout: 0 39 | resblock_updown: True 40 | use_new_attention_order: False 41 | 42 | diffusion: 43 | beta_schedule: "linear" 44 | beta_start: 0.0001 45 | beta_end: 0.02 46 | timesteps: 1000 47 | var_type: fixedsmall # here 48 | clip_denoised: True 49 | learn_sigma: True # out_channels * 2 50 | sigma_small: False 51 | use_kl: False 52 | predict_xstart: False 53 | rescale_timesteps: False 54 | rescale_learned_sigmas: False 55 | use_ddim: False 56 | timestep_respacing: "100" -------------------------------------------------------------------------------- /config/ILVR/wild2dog-img256-respace100-down32-Rt20.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "cat2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "cosine" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | down_N: 32 22 | range_t: 20 23 | 24 | model: 25 | type: "ADM" 26 | use_fp16: False 27 | in_channels: 3 28 | num_channels: 128 29 | out_channels: 3 30 | num_res_blocks: 1 31 | class_cond: False 32 | use_checkpoint: False 33 | attention_resolutions: "16" 34 | num_heads: 4 35 | num_head_channels: 64 36 | num_heads_upsample: -1 37 | use_scale_shift_norm: True 38 | dropout: 0 39 | resblock_updown: True 40 | use_new_attention_order: False 41 | 42 | diffusion: 43 | beta_schedule: "linear" 44 | beta_start: 0.0001 45 | beta_end: 0.02 46 | timesteps: 1000 47 | var_type: fixedsmall # here 48 | clip_denoised: True 49 | learn_sigma: True # out_channels * 2 50 | sigma_small: False 51 | use_kl: False 52 | predict_xstart: False 53 | rescale_timesteps: False 54 | rescale_learned_sigmas: False 55 | use_ddim: False 56 | timestep_respacing: "100" -------------------------------------------------------------------------------- /config/ILVR/wild2dog-img256-respace250-down32-Rt20.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "cat2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "cosine" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | down_N: 32 22 | range_t: 20 23 | 24 | model: 25 | type: "ADM" 26 | use_fp16: False 27 | in_channels: 3 28 | num_channels: 128 29 | out_channels: 3 30 | num_res_blocks: 1 31 | class_cond: False 32 | use_checkpoint: False 33 | attention_resolutions: "16" 34 | num_heads: 4 35 | 
num_head_channels: 64 36 | num_heads_upsample: -1 37 | use_scale_shift_norm: True 38 | dropout: 0 39 | resblock_updown: True 40 | use_new_attention_order: False 41 | 42 | diffusion: 43 | beta_schedule: "linear" 44 | beta_start: 0.0001 45 | beta_end: 0.02 46 | timesteps: 1000 47 | var_type: fixedsmall # here 48 | clip_denoised: True 49 | learn_sigma: True # out_channels * 2 50 | sigma_small: False 51 | use_kl: False 52 | predict_xstart: False 53 | rescale_timesteps: False 54 | rescale_learned_sigmas: False 55 | use_ddim: False 56 | timestep_respacing: "250" -------------------------------------------------------------------------------- /config/SDEdit/cat2dog-img256.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "cat2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "linear" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | perturb_step: 400 22 | repeat_step: 1 # K 23 | 24 | model: 25 | type: "ADM" 26 | use_fp16: False 27 | in_channels: 3 28 | num_channels: 128 29 | out_channels: 3 30 | num_res_blocks: 1 31 | class_cond: False 32 | use_checkpoint: False 33 | attention_resolutions: "16" 34 | num_heads: 4 35 | num_head_channels: 64 36 | num_heads_upsample: -1 37 | use_scale_shift_norm: True 38 | dropout: 0 39 | resblock_updown: True 40 | use_new_attention_order: False 41 | 42 | diffusion: 43 | beta_schedule: "linear" 44 | beta_start: 0.0001 45 | beta_end: 0.02 46 | timesteps: 1000 47 | var_type: fixedsmall # here 48 | clip_denoised: True 49 | learn_sigma: True # out_channels * 2 50 | sigma_small: False 51 | use_kl: False 52 | predict_xstart: False 53 | rescale_timesteps: False 54 | rescale_learned_sigmas: False 55 | use_ddim: False 56 | timestep_respacing: "" 57 | -------------------------------------------------------------------------------- /config/SDEdit/iter-cat2dog-img256-p400-k33-dN32.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "cat2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "linear" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | perturb_step: [ 400, 400 ] 22 | repeat_step: [ 3, 3 ] # K 23 | iter_step: 2 24 | src_down_N: 32 25 | fusion_scale: 0.5 26 | 27 | model: 28 | type: "ADM" 29 | use_fp16: False 30 | in_channels: 3 31 | num_channels: 128 32 | out_channels: 3 33 | num_res_blocks: 1 34 | class_cond: False 35 | use_checkpoint: False 36 | attention_resolutions: "16" 37 | num_heads: 4 38 | num_head_channels: 64 39 | num_heads_upsample: -1 40 | use_scale_shift_norm: True 41 | dropout: 0 42 | resblock_updown: True 43 | use_new_attention_order: False 44 | 45 | diffusion: 46 | beta_schedule: "linear" 47 | beta_start: 0.0001 48 | beta_end: 0.02 49 | timesteps: 1000 50 | var_type: fixedsmall # here 51 | clip_denoised: True 52 | learn_sigma: True # out_channels * 
2 53 | sigma_small: False 54 | use_kl: False 55 | predict_xstart: False 56 | rescale_timesteps: False 57 | rescale_learned_sigmas: False 58 | use_ddim: False 59 | timestep_respacing: "" 60 | -------------------------------------------------------------------------------- /config/clfguided/imagenet-img128-respace250.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | num_samples: 25 3 | # data 4 | dataset: "imagenet-1k" 5 | num_classes: 1000 6 | image_size: 128 7 | channels: 3 8 | # optimizer 9 | lr: 2e-5 10 | adam_betas: [ 0.9, 0.999 ] 11 | weight_decay: 0 12 | # train: 13 | train_num_steps: 700000 14 | train_batch_size: 16 15 | noised: True 16 | schedule_sampler: "uniform" 17 | # sample 18 | sample_batch_size: 5 19 | # params 20 | classifier_scale: 0.5 21 | 22 | model: 23 | type: "ADM" 24 | use_fp16: True 25 | load_share_weights: False # use openai ckpt as init 26 | in_channels: 3 27 | num_channels: 256 28 | out_channels: 3 29 | num_res_blocks: 2 30 | class_cond: True # involved label embedding, as well as `y` 31 | use_checkpoint: False 32 | attention_resolutions: "32,16,8" 33 | num_heads: 4 34 | num_head_channels: 64 35 | num_heads_upsample: -1 36 | use_scale_shift_norm: True 37 | dropout: 0 38 | resblock_updown: True 39 | use_new_attention_order: False 40 | 41 | diffusion: 42 | beta_schedule: "linear" 43 | beta_start: 0.0001 44 | beta_end: 0.02 45 | timesteps: 1000 46 | clip_denoised: True 47 | learn_sigma: True # out_channels * 2 48 | sigma_small: False 49 | use_kl: False 50 | predict_xstart: False 51 | rescale_timesteps: False 52 | rescale_learned_sigmas: False 53 | use_ddim: False 54 | timestep_respacing: "250" 55 | 56 | classifier: 57 | use_fp16: False 58 | load_share_weights: False # use openai ckpt as init 59 | in_channels: ${channels} 60 | out_channels: ${num_classes} 61 | model_channels: 128 # width 62 | num_res_blocks: 2 # depth 63 | attention_resolutions: "32,16,8" 64 | dropout: 0 65 | num_head_channels: 64 66 | use_scale_shift_norm: True 67 | resblock_updown: True 68 | use_new_attention_order: False 69 | pool: "attention" -------------------------------------------------------------------------------- /config/clfguided/imagenet-img256-respace1000.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | num_samples: 25 3 | # data 4 | dataset: "imagenet-1k" 5 | num_classes: 1000 6 | image_size: 256 7 | channels: 3 8 | # optimizer 9 | lr: 2e-5 10 | adam_betas: [ 0.9, 0.999 ] 11 | weight_decay: 0 12 | # train: 13 | train_num_steps: 700000 14 | train_batch_size: 16 15 | noised: True 16 | schedule_sampler: "uniform" 17 | # sample 18 | sample_batch_size: 5 19 | # params 20 | classifier_scale: 1.0 21 | 22 | model: 23 | type: "ADM" 24 | use_fp16: True 25 | load_share_weights: False # use openai ckpt as init 26 | in_channels: 3 27 | num_channels: 256 28 | out_channels: 3 29 | num_res_blocks: 2 30 | class_cond: True # involved label embedding, as well as `y` 31 | use_checkpoint: False 32 | attention_resolutions: "32,16,8" 33 | num_heads: 4 34 | num_head_channels: 64 35 | num_heads_upsample: -1 36 | use_scale_shift_norm: True 37 | dropout: 0 38 | resblock_updown: True 39 | use_new_attention_order: False 40 | 41 | diffusion: 42 | beta_schedule: "linear" 43 | beta_start: 0.0001 44 | beta_end: 0.02 45 | timesteps: 1000 46 | clip_denoised: True 47 | learn_sigma: True # out_channels * 2 48 | sigma_small: False 49 | use_kl: False 50 | predict_xstart: False 51 | rescale_timesteps: False 52 | 
rescale_learned_sigmas: False 53 | use_ddim: False 54 | timestep_respacing: "1000" 55 | 56 | classifier: 57 | use_fp16: False 58 | load_share_weights: False # use openai ckpt as init 59 | in_channels: ${channels} 60 | out_channels: ${num_classes} 61 | model_channels: 128 # width 62 | num_res_blocks: 2 # depth 63 | attention_resolutions: "32,16,8" 64 | dropout: 0 65 | num_head_channels: 64 66 | use_scale_shift_norm: True 67 | resblock_updown: True 68 | use_new_attention_order: False 69 | pool: "attention" -------------------------------------------------------------------------------- /config/clfguided/imagenet-img256-respace250.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | num_samples: 25 3 | # data 4 | dataset: "imagenet-1k" 5 | num_classes: 1000 6 | image_size: 256 7 | channels: 3 8 | # optimizer 9 | lr: 2e-5 10 | adam_betas: [ 0.9, 0.999 ] 11 | weight_decay: 0 12 | # train: 13 | train_num_steps: 700000 14 | train_batch_size: 16 15 | noised: True 16 | schedule_sampler: "uniform" 17 | # sample 18 | sample_batch_size: 5 19 | # params 20 | classifier_scale: 1.0 21 | 22 | model: 23 | type: "ADM" 24 | use_fp16: True 25 | load_share_weights: False # use openai ckpt as init 26 | in_channels: 3 27 | num_channels: 256 28 | out_channels: 3 29 | num_res_blocks: 2 30 | class_cond: True # involved label embedding, as well as `y` 31 | use_checkpoint: False 32 | attention_resolutions: "32,16,8" 33 | num_heads: 4 34 | num_head_channels: 64 35 | num_heads_upsample: -1 36 | use_scale_shift_norm: True 37 | dropout: 0 38 | resblock_updown: True 39 | use_new_attention_order: False 40 | 41 | diffusion: 42 | beta_schedule: "linear" 43 | beta_start: 0.0001 44 | beta_end: 0.02 45 | timesteps: 1000 46 | clip_denoised: True 47 | learn_sigma: True # out_channels * 2 48 | sigma_small: False 49 | use_kl: False 50 | predict_xstart: False 51 | rescale_timesteps: False 52 | rescale_learned_sigmas: False 53 | use_ddim: True 54 | timestep_respacing: "ddim25" #"250" 55 | 56 | classifier: 57 | use_fp16: False 58 | load_share_weights: False # use openai ckpt as init 59 | in_channels: ${channels} 60 | out_channels: ${num_classes} 61 | model_channels: 128 # width 62 | num_res_blocks: 2 # depth 63 | attention_resolutions: "32,16,8" 64 | dropout: 0 65 | num_head_channels: 64 66 | use_scale_shift_norm: True 67 | resblock_updown: True 68 | use_new_attention_order: False 69 | pool: "attention" -------------------------------------------------------------------------------- /config/clfguided/imagenet-img256-uncond-respace250.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | num_samples: 25 3 | # data 4 | dataset: "imagenet-1k" 5 | num_classes: ~ 6 | image_size: 256 7 | channels: 3 8 | # optimizer 9 | lr: 2e-5 10 | adam_betas: [ 0.9, 0.999 ] 11 | weight_decay: 0 12 | # train: 13 | train_num_steps: 700000 14 | train_batch_size: 16 15 | noised: True 16 | schedule_sampler: "uniform" 17 | # sample 18 | sample_batch_size: 5 19 | # params 20 | classifier_scale: 1.0 21 | 22 | model: 23 | type: "ADM" 24 | use_fp16: True 25 | load_share_weights: False # use openai ckpt as init 26 | in_channels: 3 27 | num_channels: 256 28 | out_channels: 3 29 | num_res_blocks: 2 30 | class_cond: False # involved label embedding, as well as `y` 31 | use_checkpoint: False 32 | attention_resolutions: "32,16,8" 33 | num_heads: 4 34 | num_head_channels: 64 35 | num_heads_upsample: -1 36 | use_scale_shift_norm: True 37 | dropout: 0 38 | resblock_updown: 
True 39 | use_new_attention_order: False 40 | 41 | diffusion: 42 | beta_schedule: "linear" 43 | beta_start: 0.0001 44 | beta_end: 0.02 45 | timesteps: 1000 46 | clip_denoised: True 47 | learn_sigma: True # out_channels * 2 48 | sigma_small: False 49 | use_kl: False 50 | predict_xstart: False 51 | rescale_timesteps: False 52 | rescale_learned_sigmas: False 53 | use_ddim: False 54 | timestep_respacing: "250" 55 | 56 | classifier: 57 | use_fp16: True 58 | load_share_weights: False # use openai ckpt as init 59 | in_channels: ${channels} 60 | out_channels: ${num_classes} 61 | model_channels: 128 # width 62 | num_res_blocks: 2 # depth 63 | attention_resolutions: "32,16,8" 64 | dropout: 0 65 | num_head_channels: 64 66 | use_scale_shift_norm: True 67 | resblock_updown: True 68 | use_new_attention_order: False 69 | pool: "attention" -------------------------------------------------------------------------------- /config/clfguided/imagenet-img64-respace250.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | num_samples: 25 3 | # data 4 | dataset: "imagenet-1k" 5 | num_classes: 1000 6 | image_size: 64 7 | channels: 3 8 | # optimizer 9 | lr: 2e-5 10 | adam_betas: [ 0.9, 0.999 ] 11 | weight_decay: 0 12 | # train: 13 | train_num_steps: 700000 14 | train_batch_size: 16 15 | noised: True 16 | schedule_sampler: "uniform" 17 | # sample 18 | sample_batch_size: 5 19 | # params 20 | classifier_scale: 1.0 21 | 22 | model: 23 | type: "ADM" 24 | use_fp16: True 25 | load_share_weights: False # use openai ckpt as init 26 | in_channels: 3 27 | num_channels: 192 28 | out_channels: 3 29 | num_res_blocks: 3 30 | class_cond: True # involved label embedding, as well as `y` 31 | use_checkpoint: False 32 | attention_resolutions: "32,16,8" 33 | num_heads: 4 34 | num_head_channels: 64 35 | num_heads_upsample: -1 36 | use_scale_shift_norm: True 37 | dropout: 0.1 38 | resblock_updown: True 39 | use_new_attention_order: True 40 | 41 | diffusion: 42 | beta_schedule: "cosine" 43 | beta_start: 0.0001 44 | beta_end: 0.02 45 | timesteps: 1000 46 | clip_denoised: True 47 | learn_sigma: True # out_channels * 2 48 | sigma_small: False 49 | use_kl: False 50 | predict_xstart: False 51 | rescale_timesteps: False 52 | rescale_learned_sigmas: False 53 | use_ddim: False 54 | timestep_respacing: "250" 55 | 56 | classifier: 57 | use_fp16: False 58 | load_share_weights: False # use openai ckpt as init 59 | in_channels: ${channels} 60 | out_channels: ${num_classes} 61 | model_channels: 128 # width 62 | num_res_blocks: 4 # depth 63 | attention_resolutions: "32,16,8" 64 | dropout: 0 65 | num_head_channels: 64 66 | use_scale_shift_norm: True 67 | resblock_updown: True 68 | use_new_attention_order: False 69 | pool: "attention" -------------------------------------------------------------------------------- /config/inversion/afhq-cat2dog-ADM.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "cat2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "linear" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | perturb_step: [ 400, 400, 400 ] 22 | repeat_step: [ 3, 2, 1 ] # K 23 | n_stage: 3 24 | 
src_down_N: 64 25 | sdepath: './checkpoint/InvSDE/afhq_dog_4m.pt' 26 | # shape control 27 | use_shape: [ True, True, True ] 28 | shape_metric: "l2" # "l1", "lpips", "l2" , "cosine" 29 | lam_shape: [ 0.5,0.6,0.5 ] 30 | ys_rescale: True 31 | shape_expert_root: './checkpoint/InvSDE/info-drawing/' 32 | shape_style: 'anime_style' 33 | # texture control 34 | pixel_texture: [ False, True, False ] 35 | lam_pixel_texture: [ 0, 2, 0 ] 36 | blur_y: True 37 | blur_xt: True 38 | feature_texture: [ False, True, False ] 39 | feature_texture_model: "CLIP" # "VGG", "CLIP", "inceptionV3" 40 | feature_texture_metric: "l2" # "l1", "cosine" 41 | lam_feature_texture: [ 0, 0.5, 0 ] 42 | # style control 43 | use_style: [ False, False, False ] 44 | preserve_color: False 45 | style_decoder: './checkpoint/style/decoder.pth' 46 | style_vgg: './checkpoint/style/vgg_normalised.pth' 47 | # domain-specific features 48 | use_dse: [ False, False, False ] 49 | dse_metric: "neg_l2" # "cosine" 50 | lam_dse: [ 0,0,0 ] 51 | dsepath: '' # './checkpoint/InvSDE/cat2dog_dse.pt' 52 | 53 | dpm_solver: 54 | t_guided: 55 | 1: [ 300, 200, 100, 0 ] 56 | 2: [ 300, 200, 100, 0 ] 57 | 3: [ 300, 200, 100, 0 ] 58 | t_dpm_solver_dense: 59 | 1: [ 399, 350 ] # range 60 | 2: [ 399, 300 ] # range 61 | 3: [ 399, 350 ] # range 62 | t_dpm_solver_spare: 63 | 1: [ 299, 298, 297, 296, 199, 198, 197, 196, 99, 98, 97, 96 ] 64 | 2: [ 299, 298, 297, 296, 199, 198, 197, 196, 99, 98, 97, 96 ] 65 | 3: [ 299, 298, 297, 296, 199, 198, 197, 196, 99, 98, 97, 96 ] 66 | 67 | model: 68 | type: "ADM" 69 | use_fp16: False 70 | in_channels: 3 71 | num_channels: 128 72 | out_channels: 3 73 | num_res_blocks: 1 74 | class_cond: False 75 | use_checkpoint: False 76 | attention_resolutions: "16" 77 | num_heads: 4 78 | num_head_channels: 64 79 | num_heads_upsample: -1 80 | use_scale_shift_norm: True 81 | dropout: 0 82 | resblock_updown: True 83 | use_new_attention_order: False 84 | 85 | diffusion: 86 | beta_schedule: "linear" 87 | beta_start: 0.0001 88 | beta_end: 0.02 89 | timesteps: 1000 90 | var_type: fixedsmall # here 91 | clip_denoised: True 92 | learn_sigma: True # out_channels * 2 93 | sigma_small: False 94 | use_kl: False 95 | predict_xstart: False 96 | rescale_timesteps: False 97 | rescale_learned_sigmas: False 98 | use_ddim: False 99 | timestep_respacing: "" 100 | -------------------------------------------------------------------------------- /config/inversion/afhq-wild2dog-ADM.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | # data 3 | dataset: "wild2dog" 4 | num_classes: 2 5 | image_size: 256 6 | channels: 3 7 | # optimizer 8 | lr: 3e-4 9 | adam_betas: [ 0.9, 0.999 ] 10 | weight_decay: 0.05 11 | lr_scheduler: "linear" # "linear", "cosine" ,"cosine_with_restarts", "constant_with_warmup" 12 | lr_warmup_steps: 100 13 | # train: 14 | train_num_steps: 700000 15 | train_batch_size: 32 16 | noised: True 17 | schedule_sampler: "uniform" 18 | # sample 19 | valid_batch_size: 8 20 | # methods 21 | perturb_step: [ 400, 400, 400 ] 22 | repeat_step: [ 3, 3, 1 ] # K 23 | n_stage: 3 24 | src_down_N: 64 25 | sdepath: './checkpoint/InvSDE/afhq_dog_4m.pt' 26 | # shape control 27 | use_shape: [ True, True, True ] 28 | shape_metric: "l2" # "l1", "lpips", "l2" , "cosine" 29 | lam_shape: [ 0.1, 0.3, 0.1 ] 30 | ys_rescale: True 31 | shape_expert_root: './checkpoint/InvSDE/info-drawing/' 32 | shape_style: 'anime_style' 33 | # texture control 34 | pixel_texture: [ False, True, False ] 35 | lam_pixel_texture: [ 0, 2, 0 ] 36 | blur_y: 
True 37 | blur_xt: True 38 | feature_texture: [ False, False, False ] 39 | feature_texture_model: "inceptionV3" # "VGG", "CLIP" 40 | feature_texture_metric: "l2" 41 | lam_feature_texture: [ 0, 0, 0 ] 42 | # style control 43 | use_style: [ False, False, False ] 44 | preserve_color: False 45 | style_decoder: './checkpoint/style/decoder.pth' 46 | style_vgg: './checkpoint/style/vgg_normalised.pth' 47 | # domain-specific features 48 | use_dse: [ False, False, False ] 49 | dse_metric: "neg_l2" 50 | lam_dse: [ 0,0,0 ] 51 | dsepath: '' # './checkpoint/InvSDE/cat2dog_dse.pt' 52 | 53 | dpm_solver: 54 | t_guided: 55 | 1: [ 300, 200, 100, 0 ] 56 | 2: [ 300, 200, 100, 0 ] 57 | 3: [ 300, 200, 100, 0 ] 58 | t_dpm_solver_dense: 59 | 1: [ 399, 350 ] # range 60 | 2: [ 399, 300 ] # range 61 | 3: [ 399, 350 ] # range 62 | t_dpm_solver_spare: 63 | 1: [ 299, 298, 297, 296, 199, 198, 197, 196, 99, 98, 97, 96 ] 64 | 2: [ 299, 298, 297, 296, 199, 198, 197, 196, 99, 98, 97, 96 ] 65 | 3: [ 299, 298, 297, 296, 199, 198, 197, 196, 99, 98, 97, 96 ] 66 | 67 | model: 68 | type: "ADM" 69 | use_fp16: False 70 | in_channels: 3 71 | num_channels: 128 72 | out_channels: 3 73 | num_res_blocks: 1 74 | class_cond: False 75 | use_checkpoint: False 76 | attention_resolutions: "16" 77 | num_heads: 4 78 | num_head_channels: 64 79 | num_heads_upsample: -1 80 | use_scale_shift_norm: True 81 | dropout: 0 82 | resblock_updown: True 83 | use_new_attention_order: False 84 | 85 | diffusion: 86 | beta_schedule: "linear" 87 | beta_start: 0.0001 88 | beta_end: 0.02 89 | timesteps: 1000 90 | var_type: fixedsmall # here 91 | clip_denoised: True 92 | learn_sigma: True # out_channels * 2 93 | sigma_small: False 94 | use_kl: False 95 | predict_xstart: False 96 | rescale_timesteps: False 97 | rescale_learned_sigmas: False 98 | use_ddim: False 99 | timestep_respacing: "" 100 | -------------------------------------------------------------------------------- /data/afhq_demo/cat/flickr_cat_000033.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ximinng/inversion-by-inversion/0d74d60567edb8f96c8c15341ec8446f170c9a62/data/afhq_demo/cat/flickr_cat_000033.jpg -------------------------------------------------------------------------------- /data/afhq_demo/dog_sketch/flickr_dog_000005_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ximinng/inversion-by-inversion/0d74d60567edb8f96c8c15341ec8446f170c9a62/data/afhq_demo/dog_sketch/flickr_dog_000005_out.png -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | Place the dataset in this folder -------------------------------------------------------------------------------- /dpm_nn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) XiMing Xing. All rights reserved. 3 | # Author: XiMing Xing 4 | # Description: 5 | 6 | from . import dpm_solver 7 | from . import guided_dpm # openAI guided-diffusion 8 | -------------------------------------------------------------------------------- /dpm_nn/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) XiMing Xing. All rights reserved. 
3 | # Author: XiMing Xing 4 | # Description: 5 | from .dpm_solver_pp import NoiseScheduleVP, DPM_Solver, model_wrapper 6 | -------------------------------------------------------------------------------- /dpm_nn/guided_dpm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) XiMing Xing. All rights reserved. 3 | # Author: XiMing Xing 4 | # Description: 5 | 6 | from .build_ADMs import ADMs_build_util 7 | from .spaced_diffusion import build_spaced_gaussian_diffusion 8 | from .beta_schedule import get_named_beta_schedule 9 | from .gaussian_diffusion import GaussianDiffusion -------------------------------------------------------------------------------- /dpm_nn/guided_dpm/beta_schedule.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Description: 3 | 4 | import math 5 | import numpy as np 6 | 7 | 8 | def get_named_beta_schedule( 9 | schedule_name: str, 10 | timesteps: int, 11 | beta_start: float = 0.0001, 12 | beta_end: float = 0.02, 13 | ): 14 | """ 15 | Get a pre-defined beta schedule for the given name. 16 | 17 | The beta schedule library consists of beta schedules which remain similar 18 | in the limit of num_diffusion_timesteps. 19 | Beta schedules may be added, but should not be removed or changed once 20 | they are committed to maintain backwards compatibility. 21 | """ 22 | if schedule_name == "scaled_linear": 23 | # Linear schedule from Ho et al, extended to work for any number of 24 | # diffusion steps. 25 | scale = 1000 / timesteps 26 | beta_start = scale * beta_start 27 | beta_end = scale * beta_end 28 | betas = np.linspace( 29 | beta_start, beta_end, timesteps, dtype=np.float64 30 | ) 31 | elif schedule_name == "linear": 32 | betas = np.linspace( 33 | beta_start, beta_end, timesteps, dtype=np.float64 34 | ) 35 | elif schedule_name == "cosine": 36 | betas = _betas_for_alpha_bar( 37 | timesteps, 38 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 39 | ) 40 | elif schedule_name == "warmup10": 41 | betas = _warmup_beta(beta_start, beta_end, timesteps, 0.1) 42 | elif schedule_name == "warmup50": 43 | betas = _warmup_beta(beta_start, beta_end, timesteps, 0.5) 44 | elif schedule_name == "quad": 45 | betas = ( 46 | np.linspace( 47 | beta_start ** 0.5, beta_end ** 0.5, timesteps, 48 | dtype=np.float64 49 | ) ** 2 50 | ) 51 | elif schedule_name == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 52 | betas = 1.0 / np.linspace( 53 | timesteps, 1, timesteps, dtype=np.float64 54 | ) 55 | elif schedule_name == "sigmoid": 56 | def sigmoid(x): 57 | s = 1 / (1 + np.exp(-x)) 58 | return s 59 | 60 | betas = np.linspace(-6, 6, timesteps) 61 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 62 | else: 63 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 64 | 65 | assert betas.shape == (timesteps,) 66 | return betas 67 | 68 | 69 | def _betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 70 | """ 71 | Create a beta schedule that discretizes the given alpha_t_bar function, 72 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 73 | 74 | :param num_diffusion_timesteps: the number of betas to produce. 75 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 76 | produces the cumulative product of (1-beta) up to that 77 | part of the diffusion process. 78 | :param max_beta: the maximum beta to use; use values lower than 1 to 79 | prevent singularities. 
80 | """ 81 | betas = [] 82 | for i in range(num_diffusion_timesteps): 83 | t1 = i / num_diffusion_timesteps 84 | t2 = (i + 1) / num_diffusion_timesteps 85 | betas.append( 86 | min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta) 87 | ) 88 | return np.array(betas) 89 | 90 | 91 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 92 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 93 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 94 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) 95 | return betas 96 | -------------------------------------------------------------------------------- /dpm_nn/guided_dpm/build_ADMs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) XiMing Xing. All rights reserved. 3 | # Author: XiMing Xing 4 | # Description: 5 | 6 | from omegaconf import DictConfig 7 | 8 | from .ADMs import UNetModel as ClfGuidedUnet, EncoderUNetModel 9 | 10 | __all__ = ['ADMs_build_util'] 11 | 12 | 13 | def ADMs_build_util( 14 | image_size, 15 | num_classes, # class conditional 16 | model_cfg: DictConfig, 17 | dpm_cfg: DictConfig, 18 | build_clf: bool = False, 19 | clf_cfg: DictConfig = None 20 | ): 21 | if image_size == 512: 22 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 23 | elif image_size == 256: 24 | channel_mult = (1, 1, 2, 2, 4, 4) 25 | elif image_size in [128, 96]: 26 | channel_mult = (1, 1, 2, 3, 4) 27 | elif image_size == 64: 28 | channel_mult = (1, 2, 3, 4) 29 | else: 30 | raise ValueError(f"unsupported image size: {image_size}") 31 | model_cfg.channel_mult = channel_mult # add key 32 | 33 | _attention_resolutions = model_cfg.attention_resolutions # example: [32, 16, 8] 34 | attention_ds = [] 35 | for res in model_cfg.attention_resolutions.split(","): 36 | attention_ds.append(image_size // int(res)) 37 | model_cfg.attention_ds = attention_ds # add key 38 | 39 | model_cfg.out_channels = (3 if not dpm_cfg.learn_sigma else 6) # update key 40 | 41 | eps_model = ClfGuidedUnet( 42 | image_size=image_size, 43 | in_channels=model_cfg.in_channels, 44 | num_res_blocks=model_cfg.num_res_blocks, # depth 45 | model_channels=model_cfg.num_channels, # width 46 | out_channels=model_cfg.out_channels, 47 | attention_resolutions=tuple(model_cfg.attention_ds), 48 | dropout=model_cfg.dropout, 49 | channel_mult=model_cfg.channel_mult, 50 | num_classes=(num_classes if model_cfg.class_cond else None), 51 | num_heads=model_cfg.num_heads, 52 | num_head_channels=model_cfg.num_head_channels, 53 | use_scale_shift_norm=model_cfg.use_scale_shift_norm, 54 | resblock_updown=model_cfg.resblock_updown, 55 | use_new_attention_order=model_cfg.use_new_attention_order, 56 | use_fp16=model_cfg.use_fp16 57 | ) 58 | 59 | if (clf_cfg is not None) and build_clf: 60 | clf_cfg.channel_mult = channel_mult # add key 61 | 62 | clf_attention_ds = [] 63 | for res in clf_cfg.attention_resolutions.split(","): 64 | clf_attention_ds.append(image_size // int(res)) 65 | clf_cfg.attention_ds = attention_ds # add key 66 | 67 | clf_model = EncoderUNetModel( 68 | image_size=image_size, 69 | in_channels=clf_cfg.in_channels, 70 | model_channels=clf_cfg.model_channels, # width 71 | num_res_blocks=clf_cfg.num_res_blocks, # depth 72 | out_channels=clf_cfg.out_channels, 73 | attention_resolutions=tuple(clf_cfg.attention_ds), 74 | dropout=clf_cfg.dropout, 75 | num_head_channels=clf_cfg.num_head_channels, 76 | channel_mult=clf_cfg.channel_mult, 77 | 
use_scale_shift_norm=clf_cfg.use_scale_shift_norm, 78 | resblock_updown=clf_cfg.resblock_updown, 79 | pool=clf_cfg.pool, 80 | use_fp16=clf_cfg.use_fp16 81 | ) 82 | return eps_model, clf_model 83 | else: 84 | return eps_model, None 85 | -------------------------------------------------------------------------------- /dpm_nn/guided_dpm/distribution.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | import torch as th 9 | 10 | 11 | def normal_kl(mean1, logvar1, mean2, logvar2): 12 | """ 13 | Compute the KL divergence between two gaussians. 14 | 15 | Shapes are automatically broadcasted, so batches can be compared to 16 | scalars, among other use cases. 17 | """ 18 | tensor = None 19 | for obj in (mean1, logvar1, mean2, logvar2): 20 | if isinstance(obj, th.Tensor): 21 | tensor = obj 22 | break 23 | assert tensor is not None, "at least one argument must be a Tensor" 24 | 25 | # Force variances to be Tensors. Broadcasting helps convert scalars to 26 | # Tensors, but it does not work for th.exp(). 27 | logvar1, logvar2 = [ 28 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 29 | for x in (logvar1, logvar2) 30 | ] 31 | 32 | # KL( N(mu1, sigma1), N(mu2, sigma2) ) = 33 | # \frac{1}{2} * [ log(sigma2 / sigma1) - d + tr(sigma2 * sigma1) + (mu2- mu1)^T * (mu2 - mu1) * sigma2 ] 34 | return 0.5 * ( 35 | -1.0 36 | + logvar2 - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /dpm_nn/guided_dpm/loss_resample.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) XiMing Xing. All rights reserved. 
3 | # Description: 4 | 5 | from abc import ABC, abstractmethod 6 | import numpy as np 7 | 8 | import torch as th 9 | import torch.distributed as dist 10 | 11 | 12 | def create_named_schedule_sampler(name, diffusion): 13 | """ 14 | Create a ScheduleSampler from a library of pre-defined samplers. 15 | 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | 32 | By default, samplers perform unbiased importance sampling, in which the 33 | objective's mean is unchanged. 34 | However, subclasses may override sample() to change how the resampled 35 | terms are reweighted, allowing for actual changes in the objective. 36 | """ 37 | 38 | @abstractmethod 39 | def weights(self): 40 | """ 41 | Get a numpy array of weights, one per diffusion step. 42 | 43 | The weights needn't be normalized, but must be positive. 44 | """ 45 | 46 | def sample(self, batch_size, device): 47 | """ 48 | Importance-sample timesteps for a batch. 49 | 50 | :param batch_size: the number of timesteps. 51 | :param device: the torch device to save to. 52 | :return: a tuple (timesteps, weights): 53 | - timesteps: a tensor of timestep indices. 54 | - weights: a tensor of weights to scale the resulting losses. 55 | """ 56 | w = self.weights() 57 | p = w / np.sum(w) 58 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 59 | indices = th.from_numpy(indices_np).long().to(device) 60 | weights_np = 1 / (len(p) * p[indices_np]) 61 | weights = th.from_numpy(weights_np).float().to(device) 62 | return indices, weights 63 | 64 | 65 | class UniformSampler(ScheduleSampler): 66 | def __init__(self, diffusion): 67 | self.diffusion = diffusion 68 | self._weights = np.ones([diffusion.num_timesteps]) 69 | 70 | def weights(self): 71 | return self._weights 72 | 73 | 74 | class LossAwareSampler(ScheduleSampler): 75 | def update_with_local_losses(self, local_ts, local_losses): 76 | """ 77 | Update the reweighting using losses from a model. 78 | 79 | Call this methods from each rank with a batch of timesteps and the 80 | corresponding losses for each of those timesteps. 81 | This methods will perform synchronization to make sure all of the ranks 82 | maintain the exact same reweighting. 83 | 84 | :param local_ts: an integer Tensor of timesteps. 85 | :param local_losses: a 1D Tensor of losses. 86 | """ 87 | batch_sizes = [ 88 | th.tensor([0], dtype=th.int32, device=local_ts.device) 89 | for _ in range(dist.get_world_size()) 90 | ] 91 | dist.all_gather( 92 | batch_sizes, 93 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 94 | ) 95 | 96 | # Pad all_gather batches to be the maximum batch size. 
97 | batch_sizes = [x.item() for x in batch_sizes] 98 | max_bs = max(batch_sizes) 99 | 100 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 101 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 102 | dist.all_gather(timestep_batches, local_ts) 103 | dist.all_gather(loss_batches, local_losses) 104 | timesteps = [ 105 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 106 | ] 107 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 108 | self.update_with_all_losses(timesteps, losses) 109 | 110 | @abstractmethod 111 | def update_with_all_losses(self, ts, losses): 112 | """ 113 | Update the reweighting using losses from a model. 114 | 115 | Sub-classes should override this methods to update the reweighting 116 | using losses from the model. 117 | 118 | This methods directly updates the reweighting without synchronizing 119 | between workers. It is called by update_with_local_losses from all 120 | ranks with identical arguments. Thus, it should have deterministic 121 | behavior to maintain state across workers. 122 | 123 | :param ts: a list of int timesteps. 124 | :param losses: a list of float losses, one per timestep. 125 | """ 126 | 127 | 128 | class LossSecondMomentResampler(LossAwareSampler): 129 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 130 | self.diffusion = diffusion 131 | self.history_per_term = history_per_term 132 | self.uniform_prob = uniform_prob 133 | self._loss_history = np.zeros( 134 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 135 | ) 136 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 137 | 138 | def weights(self): 139 | if not self._warmed_up(): 140 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 141 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 142 | weights /= np.sum(weights) 143 | weights *= 1 - self.uniform_prob 144 | weights += self.uniform_prob / len(weights) 145 | return weights 146 | 147 | def update_with_all_losses(self, ts, losses): 148 | for t, loss in zip(ts, losses): 149 | if self._loss_counts[t] == self.history_per_term: 150 | # Shift out the oldest loss term. 151 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 152 | self._loss_history[t, -1] = loss 153 | else: 154 | self._loss_history[t, self._loss_counts[t]] = loss 155 | self._loss_counts[t] += 1 156 | 157 | def _warmed_up(self): 158 | return (self._loss_counts == self.history_per_term).all() 159 | -------------------------------------------------------------------------------- /dpm_nn/inversion/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) XiMing Xing. All rights reserved. 3 | # Author: XiMing Xing 4 | # Description: 5 | -------------------------------------------------------------------------------- /dpm_nn/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) XiMing Xing. All rights reserved. 
3 | # Author: XiMing Xing 4 | # Description: 5 | 6 | import torch 7 | 8 | 9 | def extract(a, t, x_shape): 10 | b, *_ = t.shape 11 | assert x_shape[0] == b 12 | out = a.gather(-1, t) # 1-D tensor, shape: (b,) 13 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) # shape: [b, 1, 1, 1] 14 | 15 | 16 | def unnormalize(x): 17 | """unnormalize_to_zero_to_one""" 18 | x = (x + 1) * 0.5 # Map the data interval to [0, 1] 19 | return torch.clamp(x, 0.0, 1.0) 20 | 21 | 22 | def normalize(x): 23 | """normalize_to_neg_one_to_one""" 24 | x = x * 2 - 1 # Map the data interval to [-1, 1] 25 | return torch.clamp(x, -1.0, 1.0) 26 | -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) XiMing Xing. All rights reserved. 3 | # Author: XiMing Xing 4 | # Description: a self consistent system, 5 | # including runner, trainer, loss function, EMA, optimizer, lr scheduler , and common utils. 6 | 7 | from .utils import lazy 8 | 9 | __getattr__, __dir__, __all__ = lazy.attach( 10 | __name__, 11 | submodules={'engine', 'metric', 'modules', 'solver', 'utils'}, 12 | submod_attrs={} 13 | ) 14 | 15 | __version__ = '0.0.1' 16 | -------------------------------------------------------------------------------- /libs/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) XiMing Xing. All rights reserved. 3 | # Author: XiMing Xing 4 | # Description: 5 | 6 | from .model_state import ModelState 7 | 8 | __all__ = [ 9 | 'ModelState' 10 | ] 11 | -------------------------------------------------------------------------------- /libs/metric/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) XiMing Xing. All rights reserved. 3 | # Author: XiMing Xing 4 | # Description: 5 | -------------------------------------------------------------------------------- /libs/metric/accuracy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) XiMing Xing. All rights reserved. 3 | # Author: XiMing Xing 4 | # Description: 5 | 6 | 7 | def accuracy(output, target, topk=(1,)): 8 | """ 9 | Computes the accuracy over the k top predictions for the specified values of k. 
10 | 11 | Args 12 | output: logits or probs (num of batch, num of classes) 13 | target: (num of batch, 1) or (num of batch, ) 14 | topk: list of returned k 15 | 16 | refer: https://github.com/pytorch/examples/blob/master/imagenet/main.py 17 | """ 18 | maxK = max(topk) # get k in top-k 19 | batch_size = target.size(0) 20 | 21 | _, pred = output.topk(k=maxK, dim=1, largest=True, sorted=True) # pred: [num of batch, k] 22 | pred = pred.t() # pred: [k, num of batch] 23 | 24 | # [1, num of batch] -> [k, num_of_batch] : bool 25 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 26 | 27 | res = [] 28 | for k in topk: 29 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 30 | res.append(correct_k.mul_(100.0 / batch_size)) 31 | return res # np.shape(res): [k, 1] 32 | -------------------------------------------------------------------------------- /libs/metric/lpips_origin/__init__.py: -------------------------------------------------------------------------------- 1 | from .lpips import LPIPS 2 | 3 | __all__ = ['LPIPS'] 4 | -------------------------------------------------------------------------------- /libs/metric/lpips_origin/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ximinng/inversion-by-inversion/0d74d60567edb8f96c8c15341ec8446f170c9a62/libs/metric/lpips_origin/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /libs/metric/lpips_origin/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ximinng/inversion-by-inversion/0d74d60567edb8f96c8c15341ec8446f170c9a62/libs/metric/lpips_origin/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /libs/metric/lpips_origin/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ximinng/inversion-by-inversion/0d74d60567edb8f96c8c15341ec8446f170c9a62/libs/metric/lpips_origin/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /libs/metric/piq/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) XiMing Xing. All rights reserved. 
3 | # Author: XiMing Xing 4 | # Description: 5 | 6 | # install: pip install piq 7 | # repo: https://github.com/photosynthesis-team/piq 8 | -------------------------------------------------------------------------------- /libs/metric/piq/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ifftshift, get_meshgrid, similarity_map, gradient_map, pow_for_complex, crop_patches 2 | from .colour_conversion import rgb2lmn, rgb2xyz, xyz2lab, rgb2lab, rgb2yiq, rgb2lhm 3 | from .filters import haar_filter, hann_filter, scharr_filter, prewitt_filter, gaussian_filter 4 | from .filters import binomial_filter1d, average_filter2d 5 | from .layers import L2Pool2d 6 | from .resize import imresize 7 | 8 | __all__ = [ 9 | 'ifftshift', 'get_meshgrid', 'similarity_map', 'gradient_map', 'pow_for_complex', 'crop_patches', 10 | 'rgb2lmn', 'rgb2xyz', 'xyz2lab', 'rgb2lab', 'rgb2yiq', 'rgb2lhm', 11 | 'haar_filter', 'hann_filter', 'scharr_filter', 'prewitt_filter', 'gaussian_filter', 12 | 'binomial_filter1d', 'average_filter2d', 13 | 'L2Pool2d', 14 | 'imresize', 15 | ] 16 | -------------------------------------------------------------------------------- /libs/metric/piq/functional/base.py: -------------------------------------------------------------------------------- 1 | r"""General purpose functions""" 2 | from typing import Tuple, Union, Optional 3 | import torch 4 | from ..utils import _parse_version 5 | 6 | 7 | def ifftshift(x: torch.Tensor) -> torch.Tensor: 8 | r""" Similar to np.fft.ifftshift but applies to PyTorch Tensors""" 9 | shift = [-(ax // 2) for ax in x.size()] 10 | return torch.roll(x, shift, tuple(range(len(shift)))) 11 | 12 | 13 | def get_meshgrid(size: Tuple[int, int], device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: 14 | r"""Return coordinate grid matrices centered at zero point. 15 | Args: 16 | size: Shape of meshgrid to create 17 | device: device to use for creation 18 | dtype: dtype to use for creation 19 | Returns: 20 | Meshgrid of size on device with dtype values. 21 | """ 22 | if size[0] % 2: 23 | # Odd 24 | x = torch.arange(-(size[0] - 1) / 2, size[0] / 2, device=device, dtype=dtype) / (size[0] - 1) 25 | else: 26 | # Even 27 | x = torch.arange(- size[0] / 2, size[0] / 2, device=device, dtype=dtype) / size[0] 28 | 29 | if size[1] % 2: 30 | # Odd 31 | y = torch.arange(-(size[1] - 1) / 2, size[1] / 2, device=device, dtype=dtype) / (size[1] - 1) 32 | else: 33 | # Even 34 | y = torch.arange(- size[1] / 2, size[1] / 2, device=device, dtype=dtype) / size[1] 35 | # Use indexing param depending on torch version 36 | recommended_torch_version = _parse_version("1.10.0") 37 | torch_version = _parse_version(torch.__version__) 38 | if len(torch_version) > 0 and torch_version >= recommended_torch_version: 39 | return torch.meshgrid(x, y, indexing='ij') 40 | return torch.meshgrid(x, y) 41 | 42 | 43 | def similarity_map(map_x: torch.Tensor, map_y: torch.Tensor, constant: float, alpha: float = 0.0) -> torch.Tensor: 44 | r""" Compute similarity_map between two tensors using Dice-like equation. 45 | 46 | Args: 47 | map_x: Tensor with map to be compared 48 | map_y: Tensor with map to be compared 49 | constant: Used for numerical stability 50 | alpha: Masking coefficient. 
Subtracts - `alpha` * map_x * map_y from denominator and nominator 51 | """ 52 | return (2.0 * map_x * map_y - alpha * map_x * map_y + constant) / \ 53 | (map_x ** 2 + map_y ** 2 - alpha * map_x * map_y + constant) 54 | 55 | 56 | def gradient_map(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor: 57 | r""" Compute gradient map for a given tensor and stack of kernels. 58 | 59 | Args: 60 | x: Tensor with shape (N, C, H, W). 61 | kernels: Stack of tensors for gradient computation with shape (k_N, k_H, k_W) 62 | Returns: 63 | Gradients of x per-channel with shape (N, C, H, W) 64 | """ 65 | padding = kernels.size(-1) // 2 66 | grads = torch.nn.functional.conv2d(x, kernels, padding=padding) 67 | 68 | return torch.sqrt(torch.sum(grads ** 2, dim=-3, keepdim=True)) 69 | 70 | 71 | def pow_for_complex(base: torch.Tensor, exp: Union[int, float]) -> torch.Tensor: 72 | r""" Takes the power of each element in a 4D tensor with negative values or 5D tensor with complex values. 73 | Complex numbers are represented by modulus and argument: r * \exp(i * \phi). 74 | 75 | It will likely to be redundant with introduction of torch.ComplexTensor. 76 | 77 | Args: 78 | base: Tensor with shape (N, C, H, W) or (N, C, H, W, 2). 79 | exp: Exponent 80 | Returns: 81 | Complex tensor with shape (N, C, H, W, 2). 82 | """ 83 | if base.dim() == 4: 84 | x_complex_r = base.abs() 85 | x_complex_phi = torch.atan2(torch.zeros_like(base), base) 86 | elif base.dim() == 5 and base.size(-1) == 2: 87 | x_complex_r = base.pow(2).sum(dim=-1).sqrt() 88 | x_complex_phi = torch.atan2(base[..., 1], base[..., 0]) 89 | else: 90 | raise ValueError(f'Expected real or complex tensor, got {base.size()}') 91 | 92 | x_complex_pow_r = x_complex_r ** exp 93 | x_complex_pow_phi = x_complex_phi * exp 94 | x_real_pow = x_complex_pow_r * torch.cos(x_complex_pow_phi) 95 | x_imag_pow = x_complex_pow_r * torch.sin(x_complex_pow_phi) 96 | return torch.stack((x_real_pow, x_imag_pow), dim=-1) 97 | 98 | 99 | def crop_patches(x: torch.Tensor, size=64, stride=32) -> torch.Tensor: 100 | r"""Crop tensor with images into small patches 101 | Args: 102 | x: Tensor with shape (N, C, H, W), expected to be images-like entities 103 | size: Size of a square patch 104 | stride: Step between patches 105 | """ 106 | assert (x.shape[2] >= size) and (x.shape[3] >= size), \ 107 | f"Images must be bigger than patch size. Got ({x.shape[2], x.shape[3]}) and ({size}, {size})" 108 | channels = x.shape[1] 109 | patches = x.unfold(1, channels, channels).unfold(2, size, stride).unfold(3, size, stride) 110 | patches = patches.reshape(-1, channels, size, size) 111 | return patches 112 | -------------------------------------------------------------------------------- /libs/metric/piq/functional/colour_conversion.py: -------------------------------------------------------------------------------- 1 | r"""Colour space conversion functions""" 2 | from typing import Union, Dict 3 | import torch 4 | 5 | 6 | def rgb2lmn(x: torch.Tensor) -> torch.Tensor: 7 | r"""Convert a batch of RGB images to a batch of LMN images 8 | 9 | Args: 10 | x: Batch of images with shape (N, 3, H, W). RGB colour space. 11 | 12 | Returns: 13 | Batch of images with shape (N, 3, H, W). LMN colour space. 
14 | """ 15 | weights_rgb_to_lmn = torch.tensor([[0.06, 0.63, 0.27], 16 | [0.30, 0.04, -0.35], 17 | [0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t() 18 | x_lmn = torch.matmul(x.permute(0, 2, 3, 1), weights_rgb_to_lmn).permute(0, 3, 1, 2) 19 | return x_lmn 20 | 21 | 22 | def rgb2xyz(x: torch.Tensor) -> torch.Tensor: 23 | r"""Convert a batch of RGB images to a batch of XYZ images 24 | 25 | Args: 26 | x: Batch of images with shape (N, 3, H, W). RGB colour space. 27 | 28 | Returns: 29 | Batch of images with shape (N, 3, H, W). XYZ colour space. 30 | """ 31 | mask_below = (x <= 0.04045).type(x.dtype) 32 | mask_above = (x > 0.04045).type(x.dtype) 33 | 34 | tmp = x / 12.92 * mask_below + torch.pow((x + 0.055) / 1.055, 2.4) * mask_above 35 | 36 | weights_rgb_to_xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375], 37 | [0.2126729, 0.7151522, 0.0721750], 38 | [0.0193339, 0.1191920, 0.9503041]], dtype=x.dtype, device=x.device) 39 | 40 | x_xyz = torch.matmul(tmp.permute(0, 2, 3, 1), weights_rgb_to_xyz.t()).permute(0, 3, 1, 2) 41 | return x_xyz 42 | 43 | 44 | def xyz2lab(x: torch.Tensor, illuminant: str = 'D50', observer: str = '2') -> torch.Tensor: 45 | r"""Convert a batch of XYZ images to a batch of LAB images 46 | 47 | Args: 48 | x: Batch of images with shape (N, 3, H, W). XYZ colour space. 49 | illuminant: {“A”, “D50”, “D55”, “D65”, “D75”, “E”}, optional. The name of the illuminant. 50 | observer: {“2”, “10”}, optional. The aperture angle of the observer. 51 | 52 | Returns: 53 | Batch of images with shape (N, 3, H, W). LAB colour space. 54 | """ 55 | epsilon = 0.008856 56 | kappa = 903.3 57 | illuminants: Dict[str, Dict] = \ 58 | {"A": {'2': (1.098466069456375, 1, 0.3558228003436005), 59 | '10': (1.111420406956693, 1, 0.3519978321919493)}, 60 | "D50": {'2': (0.9642119944211994, 1, 0.8251882845188288), 61 | '10': (0.9672062750333777, 1, 0.8142801513128616)}, 62 | "D55": {'2': (0.956797052643698, 1, 0.9214805860173273), 63 | '10': (0.9579665682254781, 1, 0.9092525159847462)}, 64 | "D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white` 65 | '10': (0.94809667673716, 1, 1.0730513595166162)}, 66 | "D75": {'2': (0.9497220898840717, 1, 1.226393520724154), 67 | '10': (0.9441713925645873, 1, 1.2064272211720228)}, 68 | "E": {'2': (1.0, 1.0, 1.0), 69 | '10': (1.0, 1.0, 1.0)}} 70 | 71 | illuminants_to_use = torch.tensor(illuminants[illuminant][observer], 72 | dtype=x.dtype, device=x.device).view(1, 3, 1, 1) 73 | 74 | tmp = x / illuminants_to_use 75 | 76 | mask_below = (tmp <= epsilon).type(x.dtype) 77 | mask_above = (tmp > epsilon).type(x.dtype) 78 | tmp = torch.pow(tmp, 1. / 3.) * mask_above + (kappa * tmp + 16.) / 116. * mask_below 79 | 80 | weights_xyz_to_lab = torch.tensor([[0, 116., 0], 81 | [500., -500., 0], 82 | [0, 200., -200.]], dtype=x.dtype, device=x.device) 83 | bias_xyz_to_lab = torch.tensor([-16., 0., 0.], dtype=x.dtype, device=x.device).view(1, 3, 1, 1) 84 | 85 | x_lab = torch.matmul(tmp.permute(0, 2, 3, 1), weights_xyz_to_lab.t()).permute(0, 3, 1, 2) + bias_xyz_to_lab 86 | return x_lab 87 | 88 | 89 | def rgb2lab(x: torch.Tensor, data_range: Union[int, float] = 255) -> torch.Tensor: 90 | r"""Convert a batch of RGB images to a batch of LAB images 91 | 92 | Args: 93 | x: Batch of images with shape (N, 3, H, W). RGB colour space. 94 | data_range: dynamic range of the input image. 95 | 96 | Returns: 97 | Batch of images with shape (N, 3, H, W). LAB colour space. 
98 | """ 99 | return xyz2lab(rgb2xyz(x / float(data_range))) 100 | 101 | 102 | def rgb2yiq(x: torch.Tensor) -> torch.Tensor: 103 | r"""Convert a batch of RGB images to a batch of YIQ images 104 | 105 | Args: 106 | x: Batch of images with shape (N, 3, H, W). RGB colour space. 107 | 108 | Returns: 109 | Batch of images with shape (N, 3, H, W). YIQ colour space. 110 | """ 111 | yiq_weights = torch.tensor([ 112 | [0.299, 0.587, 0.114], 113 | [0.5959, -0.2746, -0.3213], 114 | [0.2115, -0.5227, 0.3112]], dtype=x.dtype, device=x.device).t() 115 | x_yiq = torch.matmul(x.permute(0, 2, 3, 1), yiq_weights).permute(0, 3, 1, 2) 116 | return x_yiq 117 | 118 | 119 | def rgb2lhm(x: torch.Tensor) -> torch.Tensor: 120 | r"""Convert a batch of RGB images to a batch of LHM images 121 | 122 | Args: 123 | x: Batch of images with shape (N, 3, H, W). RGB colour space. 124 | 125 | Returns: 126 | Batch of images with shape (N, 3, H, W). LHM colour space. 127 | 128 | Reference: 129 | https://arxiv.org/pdf/1608.07433.pdf 130 | """ 131 | lhm_weights = torch.tensor([ 132 | [0.2989, 0.587, 0.114], 133 | [0.3, 0.04, -0.35], 134 | [0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t() 135 | x_lhm = torch.matmul(x.permute(0, 2, 3, 1), lhm_weights).permute(0, 3, 1, 2) 136 | return x_lhm 137 | -------------------------------------------------------------------------------- /libs/metric/piq/functional/filters.py: -------------------------------------------------------------------------------- 1 | r"""Filters for gradient computation, bluring, etc.""" 2 | import torch 3 | import numpy as np 4 | from typing import Optional 5 | 6 | 7 | def haar_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: 8 | r"""Creates Haar kernel 9 | 10 | Args: 11 | kernel_size: size of the kernel 12 | device: target device for kernel generation 13 | dtype: target data type for kernel generation 14 | Returns: 15 | kernel: Tensor with shape (1, kernel_size, kernel_size) 16 | """ 17 | kernel = torch.ones((kernel_size, kernel_size), device=device, dtype=dtype) / kernel_size 18 | kernel[kernel_size // 2:, :] = - kernel[kernel_size // 2:, :] 19 | return kernel.unsqueeze(0) 20 | 21 | 22 | def hann_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: 23 | r"""Creates Hann kernel 24 | Args: 25 | kernel_size: size of the kernel 26 | device: target device for kernel generation 27 | dtype: target data type for kernel generation 28 | Returns: 29 | kernel: Tensor with shape (1, kernel_size, kernel_size) 30 | """ 31 | # Take bigger window and drop borders 32 | window = torch.hann_window(kernel_size + 2, periodic=False, device=device, dtype=dtype)[1:-1] 33 | kernel = window[:, None] * window[None, :] 34 | # Normalize and reshape kernel 35 | return kernel.view(1, kernel_size, kernel_size) / kernel.sum() 36 | 37 | 38 | def gaussian_filter(kernel_size: int, sigma: float, device: Optional[str] = None, 39 | dtype: Optional[type] = None) -> torch.Tensor: 40 | r"""Returns 2D Gaussian kernel N(0,`sigma`^2) 41 | Args: 42 | size: Size of the kernel 43 | sigma: Std of the distribution 44 | device: target device for kernel generation 45 | dtype: target data type for kernel generation 46 | Returns: 47 | gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size) 48 | """ 49 | coords = torch.arange(kernel_size, dtype=dtype, device=device) 50 | coords -= (kernel_size - 1) / 2. 
51 | 52 | g = coords ** 2 53 | g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp() 54 | 55 | g /= g.sum() 56 | return g.unsqueeze(0) 57 | 58 | 59 | # Gradient operator kernels 60 | def scharr_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: 61 | r"""Utility function that returns a normalized 3x3 Scharr kernel in X direction 62 | 63 | Args: 64 | device: target device for kernel generation 65 | dtype: target data type for kernel generation 66 | Returns: 67 | kernel: Tensor with shape (1, 3, 3) 68 | """ 69 | return torch.tensor([[[-3., 0., 3.], [-10., 0., 10.], [-3., 0., 3.]]], device=device, dtype=dtype) / 16 70 | 71 | 72 | def prewitt_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: 73 | r"""Utility function that returns a normalized 3x3 Prewitt kernel in X direction 74 | 75 | Args: 76 | device: target device for kernel generation 77 | dtype: target data type for kernel generation 78 | Returns: 79 | kernel: Tensor with shape (1, 3, 3)""" 80 | return torch.tensor([[[-1., 0., 1.], [-1., 0., 1.], [-1., 0., 1.]]], device=device, dtype=dtype) / 3 81 | 82 | 83 | def binomial_filter1d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: 84 | r"""Creates 1D normalized binomial filter 85 | 86 | Args: 87 | kernel_size (int): kernel size 88 | device: target device for kernel generation 89 | dtype: target data type for kernel generation 90 | 91 | Returns: 92 | Binomial kernel with shape (1, 1, kernel_size) 93 | """ 94 | kernel = np.poly1d([0.5, 0.5]) ** (kernel_size - 1) 95 | return torch.tensor(kernel.c, dtype=dtype, device=device).view(1, 1, kernel_size) 96 | 97 | 98 | def average_filter2d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: 99 | r"""Creates 2D normalized average filter 100 | 101 | Args: 102 | kernel_size (int): kernel size 103 | device: target device for kernel generation 104 | dtype: target data type for kernel generation 105 | 106 | Returns: 107 | kernel: Tensor with shape (1, kernel_size, kernel_size) 108 | """ 109 | window = torch.ones(kernel_size, dtype=dtype, device=device) / kernel_size 110 | kernel = window[:, None] * window[None, :] 111 | return kernel.unsqueeze(0) 112 | -------------------------------------------------------------------------------- /libs/metric/piq/functional/layers.py: -------------------------------------------------------------------------------- 1 | r"""Custom layers used in metrics computations""" 2 | import torch 3 | from typing import Optional 4 | 5 | from .filters import hann_filter 6 | 7 | 8 | class L2Pool2d(torch.nn.Module): 9 | r"""Applies L2 pooling with Hann window of size 3x3 10 | Args: 11 | x: Tensor with shape (N, C, H, W)""" 12 | EPS = 1e-12 13 | 14 | def __init__(self, kernel_size: int = 3, stride: int = 2, padding=1) -> None: 15 | super().__init__() 16 | self.kernel_size = kernel_size 17 | self.stride = stride 18 | self.padding = padding 19 | 20 | self.kernel: Optional[torch.Tensor] = None 21 | 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: 23 | if self.kernel is None: 24 | C = x.size(1) 25 | self.kernel = hann_filter(self.kernel_size).repeat((C, 1, 1, 1)).to(x) 26 | 27 | out = torch.nn.functional.conv2d( 28 | x ** 2, self.kernel, 29 | stride=self.stride, 30 | padding=self.padding, 31 | groups=x.shape[1] 32 | ) 33 | return (out + self.EPS).sqrt() 34 | -------------------------------------------------------------------------------- 
/libs/metric/piq/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import _validate_input, _reduce, _parse_version 2 | 3 | __all__ = [ 4 | "_validate_input", 5 | "_reduce", 6 | '_parse_version' 7 | ] 8 | -------------------------------------------------------------------------------- /libs/metric/piq/utils/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import re 3 | import warnings 4 | 5 | from typing import Tuple, List, Optional, Union, Dict, Any 6 | 7 | SEMVER_VERSION_PATTERN = re.compile( 8 | r""" 9 | ^ 10 | (?P<major>0|[1-9]\d*) 11 | \. 12 | (?P<minor>0|[1-9]\d*) 13 | \. 14 | (?P<patch>0|[1-9]\d*) 15 | (?:-(?P<prerelease> 16 | (?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*) 17 | (?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))* 18 | ))? 19 | (?:\+(?P<buildmetadata> 20 | [0-9a-zA-Z-]+ 21 | (?:\.[0-9a-zA-Z-]+)* 22 | ))? 23 | $ 24 | """, 25 | re.VERBOSE, 26 | ) 27 | 28 | 29 | PEP_440_VERSION_PATTERN = r""" 30 | v? 31 | (?: 32 | (?:(?P<epoch>[0-9]+)!)? # epoch 33 | (?P<release>[0-9]+(?:\.[0-9]+)*) # release segment 34 | (?P
<pre>                                      # pre-release
 35 |             [-_\.]?
 36 |             (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
 37 |             [-_\.]?
 38 |             (?P<pre_n>[0-9]+)?
 39 |         )?
 40 |         (?P<post>                                         # post release
 41 |             (?:-(?P<post_n1>[0-9]+))
 42 |             |
 43 |             (?:
 44 |                 [-_\.]?
 45 |                 (?P<post_l>post|rev|r)
 46 |                 [-_\.]?
 47 |                 (?P<post_n2>[0-9]+)?
 48 |             )
 49 |         )?
 50 |         (?P<dev>                                          # dev release
 51 |             [-_\.]?
 52 |             (?P<dev_l>dev)
 53 |             [-_\.]?
 54 |             (?P<dev_n>[0-9]+)?
 55 |         )?
 56 |     )
 57 |     (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
 58 | """
 59 | 
 60 | 
 61 | def _validate_input(
 62 |         tensors: List[torch.Tensor],
 63 |         dim_range: Tuple[int, int] = (0, -1),
 64 |         data_range: Tuple[float, float] = (0., -1.),
 65 |         # size_dim_range: Tuple[float, float] = (0., -1.),
 66 |         size_range: Optional[Tuple[int, int]] = None,
 67 | ) -> None:
 68 |     r"""Check that the input(s) satisfy the requirements
 69 |     Args:
 70 |         tensors: Tensors to check
 71 |         dim_range: Allowed number of dimensions. (min, max)
 72 |         data_range: Allowed range of values in tensors. (min, max)
 73 |         size_range: Dimensions to include in size comparison. (start_dim, end_dim + 1)
 74 |     """
 75 | 
 76 |     if not __debug__:
 77 |         return
 78 | 
 79 |     x = tensors[0]
 80 | 
 81 |     for t in tensors:
 82 |         assert torch.is_tensor(t), f'Expected torch.Tensor, got {type(t)}'
 83 |         assert t.device == x.device, f'Expected tensors to be on {x.device}, got {t.device}'
 84 | 
 85 |         if size_range is None:
 86 |             assert t.size() == x.size(), f'Expected tensors with same size, got {t.size()} and {x.size()}'
 87 |         else:
 88 |             assert t.size()[size_range[0]: size_range[1]] == x.size()[size_range[0]: size_range[1]], \
 89 |                 f'Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}'
 90 | 
 91 |         if dim_range[0] == dim_range[1]:
 92 |             assert t.dim() == dim_range[0], f'Expected number of dimensions to be {dim_range[0]}, got {t.dim()}'
 93 |         elif dim_range[0] < dim_range[1]:
 94 |             assert dim_range[0] <= t.dim() <= dim_range[1], \
 95 |                 f'Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
 96 | 
 97 |         if data_range[0] < data_range[1]:
 98 |             assert data_range[0] <= t.min(), \
 99 |                 f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}'
100 |             assert t.max() <= data_range[1], \
101 |                 f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}'
102 | 
103 | 
104 | def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
105 |     r"""Reduce input in batch dimension if needed.
106 | 
107 |     Args:
108 |         x: Tensor with shape (N, *).
109 |         reduction: Specifies the reduction type:
110 |             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
111 |     """
112 |     if reduction == 'none':
113 |         return x
114 |     elif reduction == 'mean':
115 |         return x.mean(dim=0)
116 |     elif reduction == 'sum':
117 |         return x.sum(dim=0)
118 |     else:
119 |         raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
120 | 
121 | 
122 | def _parse_version(version: Union[str, bytes]) -> Tuple[int, ...]:
123 |     """ Parses valid Python versions according to Semver and PEP 440 specifications.
124 |     For more on Semver check: https://semver.org/
125 |     For more on PEP 440 check: https://www.python.org/dev/peps/pep-0440/.
126 | 
127 |     Implementation is inspired by:
128 |     - https://github.com/python-semver
129 |     - https://github.com/pypa/packaging
130 | 
131 |     Args:
132 |         version: unparsed information about the library of interest.
133 | 
134 |     Returns:
135 |         parsed information about the library of interest.
136 |     """
137 |     if isinstance(version, bytes):
138 |         version = version.decode("UTF-8")
139 |     elif not isinstance(version, str) and not isinstance(version, bytes):
140 |         raise TypeError(f"not expecting type {type(version)}")
141 | 
142 |     # Semver processing
143 |     match = SEMVER_VERSION_PATTERN.match(version)
144 |     if match:
145 |         matched_version_parts: Dict[str, Any] = match.groupdict()
146 |         release = tuple([int(matched_version_parts[k]) for k in ['major', 'minor', 'patch']])
147 |         return release
148 | 
149 |     # PEP 440 processing
150 |     regex = re.compile(r"^\s*" + PEP_440_VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
151 |     match = regex.search(version)
152 | 
153 |     if match is None:
154 |         warnings.warn(f"{version} is not a valid SemVer or PEP 440 string")
155 |         return tuple()
156 | 
157 |     release = tuple(int(i) for i in match.group("release").split("."))
158 |     return release
159 | 
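A minimal usage sketch for _parse_version (the inputs below are illustrative and assume the repository root is on PYTHONPATH; they are not part of the file above):

from libs.metric.piq.utils import _parse_version

print(_parse_version("1.10.2"))         # SemVer match -> (1, 10, 2)
print(_parse_version("2.0.0rc1"))       # falls through to PEP 440 -> (2, 0, 0), pre-release part dropped
print(_parse_version(b"1.13.0+cu117"))  # bytes are decoded; build metadata is ignored -> (1, 13, 0)
print(_parse_version("not-a-version"))  # warns and returns an empty tuple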


--------------------------------------------------------------------------------
/libs/metric/pytorch_fid/__init__.py:
--------------------------------------------------------------------------------
 1 | __version__ = '0.3.0'
 2 | 
 3 | import torch
 4 | from einops import rearrange, repeat
 5 | 
 6 | from .inception import InceptionV3
 7 | from .fid_score import calculate_frechet_distance
 8 | 
 9 | 
10 | class PytorchFIDFactory(torch.nn.Module):
11 |     """Computes the FID score between real and fake image batches using InceptionV3 features.
12 | 
13 |     Args:
14 |         channels: number of image channels; single-channel inputs are repeated to 3 channels.
15 |         inception_block_idx: InceptionV3 feature dimension to use; must be a key of InceptionV3.BLOCK_INDEX_BY_DIM.
16 | 
17 |     Examples:
18 |     >>> fid_factory = PytorchFIDFactory()
19 |     >>> fid_score = fid_factory.score(real_samples=data, fake_samples=all_images)
20 |     >>> print(fid_score)
21 |     """
22 | 
23 |     def __init__(self, channels: int = 3, inception_block_idx: int = 2048):
24 |         super().__init__()
25 |         self.channels = channels
26 | 
27 |         # load models
28 |         assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
29 |         block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
30 |         self.inception_v3 = InceptionV3([block_idx])
31 | 
32 |     @torch.no_grad()
33 |     def calculate_activation_statistics(self, samples):
34 |         features = self.inception_v3(samples)[0]
35 |         features = rearrange(features, '... 1 1 -> ...')
36 | 
37 |         mu = torch.mean(features, dim=0).cpu()
38 |         sigma = torch.cov(features.t()).cpu()  # torch.cov treats rows as variables, so transpose the (N, D) features
39 |         return mu, sigma
40 | 
41 |     def score(self, real_samples, fake_samples):
42 |         if self.channels == 1:
43 |             real_samples, fake_samples = map(
44 |                 lambda t: repeat(t, 'b 1 ... -> b c ...', c=3), (real_samples, fake_samples)
45 |             )
46 | 
47 |         min_batch = min(real_samples.shape[0], fake_samples.shape[0])
48 |         real_samples, fake_samples = map(lambda t: t[:min_batch], (real_samples, fake_samples))
49 | 
50 |         m1, s1 = self.calculate_activation_statistics(real_samples)
51 |         m2, s2 = self.calculate_activation_statistics(fake_samples)
52 | 
53 |         fid_value = calculate_frechet_distance(m1, s1, m2, s2)
54 |         return fid_value
55 | 
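A hedged usage sketch for PytorchFIDFactory (random tensors stand in for real data here; a meaningful FID needs hundreds of images or more):

import torch
from libs.metric.pytorch_fid import PytorchFIDFactory

fid_factory = PytorchFIDFactory(channels=3, inception_block_idx=2048)
real = torch.rand(8, 3, 256, 256)   # images in [0, 1], shape (N, C, H, W)
fake = torch.rand(8, 3, 256, 256)
print(fid_factory.score(real_samples=real, fake_samples=fake))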


--------------------------------------------------------------------------------
/libs/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 


--------------------------------------------------------------------------------
/libs/modules/resizer/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | from .resizer import resize
 7 | from . import interp_methods
 8 | 
 9 | __all__ = ['resize', 'interp_methods']
10 | 
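This package re-exports a ResizeRight-style resize function together with the interpolation kernels below; a rough sketch of how they are typically combined (the keyword names follow the upstream ResizeRight interface and are an assumption here, since resizer.py itself is not shown above):

import torch
from libs.modules.resizer import resize, interp_methods

x = torch.rand(1, 3, 256, 256)
y = resize(x, scale_factors=0.25, interp_method=interp_methods.cubic)  # assumed ResizeRight-style signature
print(y.shape)  # expected: torch.Size([1, 3, 64, 64])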


--------------------------------------------------------------------------------
/libs/modules/resizer/interp_methods.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Description:
 3 | 
 4 | from math import pi
 5 | 
 6 | try:
 7 |     import torch
 8 | except ImportError:
 9 |     torch = None
10 | 
11 | try:
12 |     import numpy
13 | except ImportError:
14 |     numpy = None
15 | 
16 | if numpy is None and torch is None:
17 |     raise ImportError("Must have either Numpy or PyTorch, but neither was found")
18 | 
19 | 
20 | def set_framework_dependencies(x):
21 |     if type(x) is numpy.ndarray:
22 |         to_dtype = lambda a: a
23 |         fw = numpy
24 |     else:
25 |         to_dtype = lambda a: a.to(x.dtype)
26 |         fw = torch
27 |     eps = fw.finfo(fw.float32).eps
28 |     return fw, to_dtype, eps
29 | 
30 | 
31 | def support_sz(sz):
32 |     def wrapper(f):
33 |         f.support_sz = sz
34 |         return f
35 | 
36 |     return wrapper
37 | 
38 | 
39 | @support_sz(4)
40 | def cubic(x):
41 |     fw, to_dtype, eps = set_framework_dependencies(x)
42 |     absx = fw.abs(x)
43 |     absx2 = absx ** 2
44 |     absx3 = absx ** 3
45 |     return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +
46 |             (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *
47 |             to_dtype((1. < absx) & (absx <= 2.)))
48 | 
49 | 
50 | @support_sz(4)
51 | def lanczos2(x):
52 |     fw, to_dtype, eps = set_framework_dependencies(x)
53 |     return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /
54 |              ((pi ** 2 * x ** 2 / 2) + eps)) * to_dtype(abs(x) < 2))
55 | 
56 | 
57 | @support_sz(6)
58 | def lanczos3(x):
59 |     fw, to_dtype, eps = set_framework_dependencies(x)
60 |     return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /
61 |              ((pi ** 2 * x ** 2 / 3) + eps)) * to_dtype(abs(x) < 3))
62 | 
63 | 
64 | @support_sz(2)
65 | def linear(x):
66 |     fw, to_dtype, eps = set_framework_dependencies(x)
67 |     return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) *
68 |             to_dtype((0 <= x) & (x <= 1)))
69 | 
70 | 
71 | @support_sz(1)
72 | def box(x):
73 |     fw, to_dtype, eps = set_framework_dependencies(x)
74 |     return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))
75 | 
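A quick sanity check of the kernels above (a sketch; the expected values follow directly from the formulas):

import torch
from libs.modules.resizer import interp_methods

x = torch.linspace(-2.0, 2.0, steps=5)       # [-2, -1, 0, 1, 2]
print(interp_methods.cubic(x))               # tensor([0., 0., 1., 0., 0.])
print(interp_methods.linear(x))              # tensor([0., 0., 1., 0., 0.])
print(interp_methods.cubic.support_sz, interp_methods.box.support_sz)  # 4 1, set by @support_sz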


--------------------------------------------------------------------------------
/libs/modules/vision/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | from .inception import inception_v3
 7 | from .vgg import vgg16, vgg19
 8 | 
 9 | __all__ = [
10 |     'inception_v3',
11 |     'vgg16',
12 |     'vgg19'
13 | ]
14 | 


--------------------------------------------------------------------------------
/libs/solver/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 


--------------------------------------------------------------------------------
/libs/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | from . import lazy
 6 | 
 7 | # __getattr__, __dir__, __all__ = lazy.attach(
 8 | #     __name__,
 9 | #     submodules={},
10 | #     submod_attrs={
11 | #         'misc': ['identity', 'exists', 'default', 'has_int_squareroot', 'sum_params', 'cycle', 'num_to_groups',
12 | #                  'extract', 'normalize', 'unnormalize'],
13 | #         'tqdm': ['tqdm_decorator'],
14 | #         'lazy': ['load']
15 | #     }
16 | # )
17 | 
18 | from .misc import (
19 |     identity,
20 |     exists,
21 |     default,
22 |     has_int_squareroot,
23 |     sum_params,
24 |     cycle,
25 |     num_to_groups,
26 |     extract,
27 |     normalize,
28 |     unnormalize
29 | )
30 | from .tqdm import tqdm_decorator
31 | 


--------------------------------------------------------------------------------
/libs/utils/imshow.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | # Copyright (c) XiMing Xing. All rights reserved.
  3 | # Author: XiMing Xing
  4 | # Description:
  5 | 
  6 | import pathlib
  7 | from pathlib import Path
  8 | from typing import Union, List, Text, BinaryIO
  9 | 
 10 | import matplotlib.pyplot as plt
 11 | import torch
 12 | import torchvision.transforms as transforms
 13 | 
 14 | __all__ = [
 15 |     'show_tensor_image',
 16 |     'show_images',
 17 |     'simulate_forward_diffusion',
 18 |     'save_grid_images_and_labels',
 19 |     'save_grid_images_and_captions'
 20 | ]
 21 | 
 22 | reverse_transforms = transforms.Compose([
 23 |     # unnormalizing to [0,1]
 24 |     transforms.Lambda(lambda t: torch.clamp((t + 1) / 2, min=0.0, max=1.0)),
 25 |     # Add 0.5 after unnormalizing to [0, 255]
 26 |     transforms.Lambda(lambda t: torch.clamp(t * 255. + 0.5, min=0, max=255)),
 27 |     # CHW to HWC
 28 |     transforms.Lambda(lambda t: t.permute(1, 2, 0)),
 29 |     # to numpy ndarray, dtype int8
 30 |     transforms.Lambda(lambda t: t.to('cpu', torch.uint8).numpy()),
 31 |     # Converts a numpy ndarray of shape H x W x C to a PIL Image
 32 |     transforms.ToPILImage(),
 33 | ])
 34 | 
 35 | 
 36 | def show_tensor_image(image, title="", f_name=None):
 37 |     # Take first image of batch
 38 |     if len(image.shape) == 4:
 39 |         image = image[0, :, :, :]
 40 |     plt.imshow(reverse_transforms(image))
 41 |     plt.title(title)
 42 | 
 43 |     if f_name is not None and Path(f_name).is_file():
 44 |         plt.savefig(f_name)
 45 |     plt.close()
 46 | 
 47 | 
 48 | def show_images(dataset, num_samples=20, cols=4):
 49 |     """ Plots some samples from the dataset """
 50 |     plt.figure(figsize=(15, 15))
 51 |     for i, img in enumerate(dataset):
 52 |         if i == num_samples:
 53 |             break
 54 |         plt.subplot(num_samples // cols + 1, cols, i + 1)  # subplot expects integer grid dimensions
 55 |         plt.imshow(img[0])
 56 |     plt.close()
 57 | 
 58 | 
 59 | def simulate_forward_diffusion(
 60 |         image,
 61 |         dataloader: torch.utils.data.DataLoader,
 62 |         T: int,
 63 |         ddpm: torch.nn.Module,
 64 |         num_images: int,
 65 | ):
 66 |     """ Simulate forward diffusion
 67 |     Args:
 68 |         image: add noise to this image
 69 |                image = next(iter(dataloader))[0]
 70 |         dataloader: data loader that ``image`` was drawn from (not used inside this function)
 71 |         T: total number of diffusion timesteps
 72 |         ddpm: diffusion model exposing ``q_sample``
 73 |         num_images: number of noisy snapshots to plot
 74 |     """
 75 |     plt.figure(figsize=(15, 15))
 76 |     plt.axis('off')
 77 | 
 78 |     stepsize = int(T / num_images)
 79 | 
 80 |     for idx in range(0, T, stepsize):
 81 |         t = torch.Tensor([idx]).type(torch.int64)
 82 |         plt.subplot(1, num_images + 1, (idx // stepsize) + 1)  # subplot index must be an int
 83 |         image, noise = ddpm.q_sample(image, t)
 84 |         show_tensor_image(image)
 85 | 
 86 |     plt.savefig(f"forward-step-{stepsize}.png")
 87 |     plt.close()
 88 | 
 89 | 
 90 | @torch.no_grad()
 91 | def save_grid_images_and_labels(
 92 |         images: Union[torch.Tensor, List[torch.Tensor]],
 93 |         probs: Union[torch.Tensor, List[torch.Tensor]],
 94 |         labels: Union[torch.Tensor, List[torch.Tensor]],
 95 |         classes: Union[torch.Tensor, List[torch.Tensor]],
 96 |         fp: Union[Text, pathlib.Path, BinaryIO],
 97 |         nrow: int = 4,
 98 |         normalize: bool = True
 99 | ) -> None:
100 |     """Save a given Tensor into an image file.
101 |     """
102 |     num_images = len(images)
103 |     num_rows, num_cols = get_subplot_shape(num_images, nrow)
104 | 
105 |     fig = plt.figure(figsize=(25, 20))
106 | 
107 |     for i in range(num_images):
108 |         ax = fig.add_subplot(num_rows, num_cols, i + 1)
109 | 
110 |         image, true_label, prob = images[i], labels[i], probs[i]
111 | 
112 |         true_prob = prob[true_label]
113 |         incorrect_prob, incorrect_label = torch.max(prob, dim=0)
114 |         true_class = classes[true_label]
115 | 
116 |         incorrect_class = classes[incorrect_label]
117 | 
118 |         if normalize:
119 |             image = reverse_transforms(image)
120 | 
121 |         ax.imshow(image)
122 |         title = f'true label: {true_class} ({true_prob:.3f})\n ' \
123 |                 f'pred label: {incorrect_class} ({incorrect_prob:.3f})'
124 |         ax.set_title(title, fontsize=20)
125 |         ax.axis('off')
126 | 
127 |     fig.subplots_adjust(hspace=0.3)
128 | 
129 |     plt.savefig(fp)
130 |     plt.close()
131 | 
132 | 
133 | @torch.no_grad()
134 | def save_grid_images_and_captions(
135 |         images: Union[torch.Tensor, List[torch.Tensor]],
136 |         captions: List,
137 |         fp: Union[Text, pathlib.Path, BinaryIO],
138 |         nrow: int = 4,
139 |         normalize: bool = True
140 | ) -> None:
141 |     """
142 |     Save a grid of images and their captions into an image file.
143 | 
144 |     Args:
145 |         images (Union[torch.Tensor, List[torch.Tensor]]): A list of images to display.
146 |         captions (List): A list of captions for each image.
147 |         fp (Union[Text, pathlib.Path, BinaryIO]): The file path to save the image to.
148 |         nrow (int, optional): The number of images to display in each row. Defaults to 4.
149 |         normalize (bool, optional): Whether to map the tensor back to a displayable image via ``reverse_transforms``. Defaults to True.
150 |     """
151 |     num_images = len(images)
152 |     num_rows, num_cols = get_subplot_shape(num_images, nrow)
153 | 
154 |     fig = plt.figure(figsize=(25, 20))
155 | 
156 |     for i in range(num_images):
157 |         ax = fig.add_subplot(num_rows, num_cols, i + 1)
158 |         image, caption = images[i], captions[i]
159 | 
160 |         if normalize:
161 |             image = reverse_transforms(image)
162 | 
163 |         ax.imshow(image)
164 |         title = f'"{caption}"' if num_images > 1 else f'"{captions}"'
165 |         title = insert_newline(title)
166 |         ax.set_title(title, fontsize=20)
167 |         ax.axis('off')
168 | 
169 |     fig.subplots_adjust(hspace=0.3)
170 | 
171 |     plt.savefig(fp)
172 |     plt.close()
173 | 
174 | 
175 | def get_subplot_shape(num_images, nrow):
176 |     """
177 |     Calculate the number of rows and columns required to display images in a grid.
178 | 
179 |     Args:
180 |         num_images (int): The total number of images to display.
181 |         nrow (int): The maximum number of images to display in each row.
182 | 
183 |     Returns:
184 |         Tuple[int, int]: The number of rows and columns required to display images in a grid.
185 |     """
186 |     num_cols = min(num_images, nrow)
187 |     num_rows = (num_images + num_cols - 1) // num_cols
188 |     return num_rows, num_cols
189 | 
190 | 
191 | def insert_newline(string, point=9):
192 |     # split by blank
193 |     words = string.split()
194 |     if len(words) <= point:
195 |         return string
196 | 
197 |     word_chunks = [words[i:i + point] for i in range(0, len(words), point)]
198 |     new_string = "\n".join(" ".join(chunk) for chunk in word_chunks)
199 |     return new_string
200 | 
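A small check of the layout helpers above (a sketch; the expected outputs follow from the code):

from libs.utils.imshow import get_subplot_shape, insert_newline

print(get_subplot_shape(10, nrow=4))  # (3, 4): ten images need three rows of up to four columns
print(get_subplot_shape(3, nrow=4))   # (1, 3): never more columns than images
print(insert_newline("one two three four five six seven eight nine ten eleven"))  # newline after every 9 words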


--------------------------------------------------------------------------------
/libs/utils/lazy.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | # Copyright (c) XiMing Xing. All rights reserved.
  3 | # Author: XiMing Xing
  4 | # Description:
  5 | 
  6 | import importlib
  7 | import importlib.util
  8 | import os
  9 | import sys
 10 | 
 11 | 
 12 | def attach(package_name, submodules=None, submod_attrs=None):
 13 |     """Attach lazily loaded submodules, functions, or other attributes.
 14 | 
 15 |     Typically, modules import submodules and attributes as follows::
 16 | 
 17 |       import mysubmodule
 18 |       import anothersubmodule
 19 | 
 20 |       from .foo import someattr
 21 | 
 22 |     The idea is to replace a package's `__getattr__`, `__dir__`, and
 23 |     `__all__`, such that all imports work exactly the way they did
 24 |     before, except that they are only imported when used.
 25 | 
 26 |     The typical way to call this function, replacing the above imports, is::
 27 | 
 28 |       __getattr__, __lazy_dir__, __all__ = lazy.attach(
 29 |         __name__,
 30 |         ['mysubmodule', 'anothersubmodule'],
 31 |         {'foo': 'someattr'}
 32 |       )
 33 | 
 34 |     This functionality requires Python 3.7 or higher.
 35 | 
 36 |     Parameters
 37 |     ----------
 38 |     package_name : str
 39 |         Typically use ``__name__``.
 40 |     submodules : set
 41 |         List of submodules to attach.
 42 |     submod_attrs : dict
 43 |         Dictionary of submodule -> list of attributes / functions.
 44 |         These attributes are imported as they are used.
 45 | 
 46 |     Returns
 47 |     -------
 48 |     __getattr__, __dir__, __all__
 49 | 
 50 |     """
 51 |     if submod_attrs is None:
 52 |         submod_attrs = {}
 53 | 
 54 |     if submodules is None:
 55 |         submodules = set()
 56 |     else:
 57 |         submodules = set(submodules)
 58 | 
 59 |     attr_to_modules = {
 60 |         attr: mod for mod, attrs in submod_attrs.items() for attr in attrs
 61 |     }
 62 | 
 63 |     __all__ = list(submodules | attr_to_modules.keys())
 64 | 
 65 |     def __getattr__(name):
 66 |         if name in submodules:
 67 |             return importlib.import_module(f'{package_name}.{name}')
 68 |         elif name in attr_to_modules:
 69 |             submod = importlib.import_module(
 70 |                 f'{package_name}.{attr_to_modules[name]}'
 71 |             )
 72 |             return getattr(submod, name)
 73 |         else:
 74 |             raise AttributeError(f'No {package_name} attribute {name}')
 75 | 
 76 |     def __dir__():
 77 |         return __all__
 78 | 
 79 |     eager_import = os.environ.get('EAGER_IMPORT', '')
 80 |     if eager_import not in ['', '0', 'false']:
 81 |         for attr in set(attr_to_modules.keys()) | submodules:
 82 |             __getattr__(attr)
 83 | 
 84 |     return __getattr__, __dir__, list(__all__)
 85 | 
 86 | 
 87 | def load(fullname):
 88 |     """Return a lazily imported proxy for a module.
 89 | 
 90 |     We often see the following pattern::
 91 | 
 92 |       def myfunc():
 93 |           import scipy as sp
 94 |           sp.argmin(...)
 95 |           ....
 96 | 
 97 |     This is to prevent a module, in this case `scipy`, from being
 98 |     imported at function definition time, since that can be slow.
 99 | 
100 |     This function provides a proxy module that, upon access, imports
101 |     the actual module.  So the idiom equivalent to the above example is::
102 | 
103 |       sp = lazy.load("scipy")
104 | 
105 |       def myfunc():
106 |           sp.argmin(...)
107 |           ....
108 | 
109 |     The initial import time is fast because the actual import is delayed
110 |     until the first attribute is requested. The overall import time may
111 |     decrease as well for users that don't make use of large portions
112 |     of the library.
113 | 
114 |     Parameters
115 |     ----------
116 |     fullname : str
117 |         The full name of the module or submodule to import.  For example::
118 | 
119 |           sp = lazy.load('scipy')  # import scipy as sp
120 |           spla = lazy.load('scipy.linalg')  # import scipy.linalg as spla
121 | 
122 |     Returns
123 |     -------
124 |     pm : importlib.util._LazyModule
125 |         Proxy module.  Can be used like any regularly imported module.
126 |         Actual loading of the module occurs upon first attribute request.
127 | 
128 |     """
129 |     try:
130 |         return sys.modules[fullname]
131 |     except KeyError:
132 |         pass
133 | 
134 |     spec = importlib.util.find_spec(fullname)
135 |     if spec is None:
136 |         raise ModuleNotFoundError(f"No module named '{fullname}'")
137 | 
138 |     module = importlib.util.module_from_spec(spec)
139 |     sys.modules[fullname] = module
140 | 
141 |     loader = importlib.util.LazyLoader(spec.loader)
142 |     loader.exec_module(module)
143 | 
144 |     return module
145 | 
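
Usage sketch of the two helpers above (a minimal sketch; the package name, submodules, and attribute names are illustrative)::

    # hypothetical mypkg/__init__.py
    from libs.utils import lazy

    __getattr__, __dir__, __all__ = lazy.attach(
        __name__,
        submodules=['metrics', 'plotting'],                  # imported on first access
        submod_attrs={'io': ['load_image', 'save_image']},   # exposed lazily as mypkg.load_image, etc.
    )

    # heavy third-party imports can be proxied the same way
    np = lazy.load('numpy')  # the real import happens on first attribute access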


--------------------------------------------------------------------------------
/libs/utils/logging.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | import os
 7 | import sys
 8 | import errno
 9 | 
10 | 
11 | def get_logger(logs_dir: str, file_name: str = "log.txt"):
12 |     logger = PrintLogger(os.path.join(logs_dir, file_name))
13 |     sys.stdout = logger  # record all python print
14 |     return logger
15 | 
16 | 
17 | class PrintLogger(object):
18 | 
19 |     def __init__(self, fpath=None):
20 |         """
21 |         python standard input/output records
22 |         """
23 |         self.console = sys.stdout
24 |         self.file = None
25 |         if fpath is not None:
26 |             mkdir_if_missing(os.path.dirname(fpath))
27 |             self.file = open(fpath, 'w')
28 | 
29 |     def __del__(self):
30 |         self.close()
31 | 
32 |     def __enter__(self):
 33 |         return self  # return the logger so it can be bound in a `with` statement
34 | 
35 |     def __exit__(self, *args):
36 |         self.close()
37 | 
38 |     def write(self, msg):
39 |         self.console.write(msg)
40 |         if self.file is not None:
41 |             self.file.write(msg)
42 | 
43 |     def write_in(self, msg):
44 |         """write in log only, not console"""
45 |         if self.file is not None:
46 |             self.file.write(msg)
47 | 
48 |     def flush(self):
49 |         self.console.flush()
50 |         if self.file is not None:
51 |             self.file.flush()
52 |             os.fsync(self.file.fileno())
53 | 
54 |     def close(self):
55 |         self.console.close()
56 |         if self.file is not None:
57 |             self.file.close()
58 | 
59 | 
60 | def mkdir_if_missing(dir_path):
61 |     try:
62 |         os.makedirs(dir_path)
63 |     except OSError as e:
64 |         if e.errno != errno.EEXIST:
65 |             raise
66 | 
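
Usage sketch (the log directory is illustrative). ``get_logger`` redirects ``sys.stdout``, so ordinary ``print`` calls are mirrored into the log file::

    from libs.utils.logging import get_logger

    logger = get_logger("./workdir/logs", file_name="log.txt")
    print("goes to the console and to ./workdir/logs/log.txt")
    logger.write_in("written to the log file only\n")
    logger.flush()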


--------------------------------------------------------------------------------
/libs/utils/meter.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | from enum import Enum
 7 | 
 8 | import torch
 9 | import torch.distributed as dist
10 | 
11 | 
12 | class Summary(Enum):
13 |     NONE = 0
14 |     AVERAGE = 1
15 |     SUM = 2
16 |     COUNT = 3
17 | 
18 | 
19 | class AverageMeter(object):
20 |     """Computes and stores the average and current value"""
21 | 
22 |     def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
23 |         self.name = name
24 |         self.fmt = fmt
25 |         self.summary_type = summary_type
26 |         self.reset()
27 | 
28 |     def reset(self):
29 |         self.val = 0
30 |         self.avg = 0
31 |         self.sum = 0
32 |         self.count = 0
33 | 
34 |     def update(self, val, n=1):
35 |         self.val = val
36 |         self.sum += val * n
37 |         self.count += n
38 |         self.avg = self.sum / self.count
39 | 
40 |     def all_reduce(self):
41 |         if torch.cuda.is_available():
42 |             device = torch.device("cuda")
43 |         elif torch.backends.mps.is_available():
44 |             device = torch.device("mps")
45 |         else:
46 |             device = torch.device("cpu")
47 | 
48 |         total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
49 |         dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
50 |         self.sum, self.count = total.tolist()
51 |         self.avg = self.sum / self.count
52 | 
53 |     def __str__(self):
54 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
55 |         return fmtstr.format(**self.__dict__)
56 | 
57 |     def summary(self):
58 |         fmtstr = ''
59 |         if self.summary_type is Summary.NONE:
60 |             fmtstr = ''
61 |         elif self.summary_type is Summary.AVERAGE:
62 |             fmtstr = '{name} {avg:.3f}'
63 |         elif self.summary_type is Summary.SUM:
64 |             fmtstr = '{name} {sum:.3f}'
65 |         elif self.summary_type is Summary.COUNT:
66 |             fmtstr = '{name} {count:.3f}'
67 |         else:
68 |             raise ValueError('invalid summary type %r' % self.summary_type)
69 | 
70 |         return fmtstr.format(**self.__dict__)
71 | 
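
Usage sketch of ``AverageMeter`` (the loss values and batch sizes are illustrative)::

    from libs.utils.meter import AverageMeter, Summary

    loss_meter = AverageMeter('loss', fmt=':.4f', summary_type=Summary.AVERAGE)
    for batch_loss, batch_size in [(0.92, 32), (0.85, 32), (0.80, 16)]:
        loss_meter.update(batch_loss, n=batch_size)

    print(loss_meter)            # "loss 0.8000 (0.8680)": current value and running average
    print(loss_meter.summary())  # "loss 0.868"
    # loss_meter.all_reduce()    # aggregate across processes when torch.distributed is initialized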


--------------------------------------------------------------------------------
/libs/utils/misc.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | import math
 7 | 
 8 | import torch
 9 | 
10 | 
11 | def identity(t, *args, **kwargs):
12 |     """return t"""
13 |     return t
14 | 
15 | 
16 | def exists(x):
17 |     """whether x is None or not"""
18 |     return x is not None
19 | 
20 | 
21 | def default(val, d):
22 |     """ternary judgment: val != None ? val : d"""
23 |     if exists(val):
24 |         return val
25 |     return d() if callable(d) else d
26 | 
27 | 
28 | def has_int_squareroot(num):
29 |     return (math.sqrt(num) ** 2) == num
30 | 
31 | 
32 | def num_to_groups(num, divisor):
33 |     groups = num // divisor
34 |     remainder = num % divisor
35 |     arr = [divisor] * groups
36 |     if remainder > 0:
37 |         arr.append(remainder)
38 |     return arr
39 | 
40 | 
41 | #################################################################################
42 | #                             Model Utils                                       #
43 | #################################################################################
44 | 
45 | def sum_params(model: torch.nn.Module, eps: float = 1e6):
 46 |     return sum(p.numel() for p in model.parameters()) / eps  # total parameter count scaled by eps (1e6 -> millions)
47 | 
48 | 
49 | #################################################################################
50 | #                            DataLoader Utils                                   #
51 | #################################################################################
52 | 
53 | def cycle(dl):
54 |     while True:
55 |         for data in dl:
56 |             yield data
57 | 
58 | 
59 | #################################################################################
60 | #                            Diffusion Model Utils                              #
61 | #################################################################################
62 | 
63 | def extract(a, t, x_shape):
64 |     b, *_ = t.shape
65 |     assert x_shape[0] == b
66 |     out = a.gather(-1, t)  # 1-D tensor, shape: (b,)
67 |     return out.reshape(b, *((1,) * (len(x_shape) - 1)))  # shape: [b, 1, 1, 1]
68 | 
69 | 
70 | def unnormalize(x):
71 |     """unnormalize_to_zero_to_one"""
72 |     x = (x + 1) * 0.5  # Map the data interval to [0, 1]
73 |     return torch.clamp(x, 0.0, 1.0)
74 | 
75 | 
76 | def normalize(x):
77 |     """normalize_to_neg_one_to_one"""
78 |     x = x * 2 - 1  # Map the data interval to [-1, 1]
79 |     return torch.clamp(x, -1.0, 1.0)
80 | 
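
A minimal sketch of how ``extract`` broadcasts per-timestep diffusion coefficients over an image batch (the beta schedule and tensor sizes are illustrative)::

    import torch
    from libs.utils.misc import extract, normalize, unnormalize

    T = 1000
    betas = torch.linspace(1e-4, 0.02, T)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    x0 = normalize(torch.rand(4, 3, 64, 64))   # [0, 1] images -> [-1, 1]
    t = torch.randint(0, T, (4,))              # one timestep per sample
    noise = torch.randn_like(x0)

    # q(x_t | x_0): sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise
    sqrt_ab = extract(alphas_cumprod.sqrt(), t, x0.shape)                 # shape (4, 1, 1, 1)
    sqrt_one_minus_ab = extract((1.0 - alphas_cumprod).sqrt(), t, x0.shape)
    x_t = sqrt_ab * x0 + sqrt_one_minus_ab * noise

    imgs = unnormalize(x_t)                    # back to [0, 1] for saving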


--------------------------------------------------------------------------------
/libs/utils/model_summary.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | # Copyright (c) XiMing Xing. All rights reserved.
  3 | # Author: XiMing Xing
  4 | # Description:
  5 | 
  6 | import sys
  7 | from collections import OrderedDict
  8 | 
  9 | import numpy as np
 10 | import torch
 11 | 
 12 | layer_modules = (torch.nn.MultiheadAttention,)
 13 | 
 14 | 
 15 | def summary(model, input_data=None, input_data_args=None, input_shape=None, input_dtype=torch.FloatTensor,
 16 |             batch_size=-1,
 17 |             *args, **kwargs):
 18 |     """
  19 |     Give example input data in at least one of the following ways:
 20 |     ① input_data ---> model.forward(input_data)
 21 |     ② input_data_args ---> model.forward(*input_data_args)
 22 |     ③ input_shape & input_dtype ---> model.forward(*[torch.rand(2, *size).type(input_dtype) for size in input_shape])
 23 |     """
 24 | 
 25 |     hooks = []
 26 |     summary = OrderedDict()
 27 | 
 28 |     def register_hook(module):
 29 |         def hook(module, inputs, outputs):
 30 | 
 31 |             class_name = str(module.__class__).split(".")[-1].split("'")[0]
 32 |             module_idx = len(summary)
 33 | 
 34 |             key = "%s-%i" % (class_name, module_idx + 1)
 35 | 
 36 |             info = OrderedDict()
 37 |             info["id"] = id(module)
 38 |             if isinstance(outputs, (list, tuple)):
 39 |                 try:
 40 |                     info["out"] = [batch_size] + list(outputs[0].size())[1:]
 41 |                 except AttributeError:
 42 |                     # pack_padded_seq and pad_packed_seq store feature into data attribute
 43 |                     info["out"] = [batch_size] + list(outputs[0].data.size())[1:]
 44 |             else:
 45 |                 info["out"] = [batch_size] + list(outputs.size())[1:]
 46 | 
 47 |             info["params_nt"], info["params"] = 0, 0
 48 |             for name, param in module.named_parameters():
 49 |                 info["params"] += param.nelement() * param.requires_grad
 50 |                 info["params_nt"] += param.nelement() * (not param.requires_grad)
 51 | 
 52 |             summary[key] = info
 53 | 
 54 |         # ignore Sequential and ModuleList and other containers
 55 |         if isinstance(module, layer_modules) or not module._modules:
 56 |             hooks.append(module.register_forward_hook(hook))
 57 | 
 58 |     model.apply(register_hook)
 59 | 
 60 |     # multiple inputs to the network
 61 |     if isinstance(input_shape, tuple):
 62 |         input_shape = [input_shape]
 63 | 
 64 |     if input_data is not None:
 65 |         x = [input_data]
 66 |     elif input_shape is not None:
 67 |         # batch_size of 2 for batchnorm
 68 |         x = [torch.rand(2, *size).type(input_dtype) for size in input_shape]
 69 |     elif input_data_args is not None:
 70 |         x = input_data_args
 71 |     else:
 72 |         x = []
 73 |     try:
 74 |         with torch.no_grad():
 75 |             model(*x) if not (kwargs or args) else model(*x, *args, **kwargs)
 76 |     except Exception:
  77 |         # This can be useful for debugging
 78 |         print("Failed to run summary...")
 79 |         raise
 80 |     finally:
 81 |         for hook in hooks:
 82 |             hook.remove()
 83 |     summary_logs = []
 84 |     summary_logs.append("--------------------------------------------------------------------------")
 85 |     line_new = "{:<30}  {:>20} {:>20}".format("Layer (type)", "Output Shape", "Param #")
 86 |     summary_logs.append(line_new)
 87 |     summary_logs.append("==========================================================================")
 88 |     total_params = 0
 89 |     total_output = 0
 90 |     trainable_params = 0
 91 |     for layer in summary:
 92 |         # layer, output_shape, params
 93 |         line_new = "{:<30}  {:>20} {:>20}".format(
 94 |             layer,
 95 |             str(summary[layer]["out"]),
 96 |             "{0:,}".format(summary[layer]["params"] + summary[layer]["params_nt"])
 97 |         )
 98 |         total_params += (summary[layer]["params"] + summary[layer]["params_nt"])
 99 |         total_output += np.prod(summary[layer]["out"])
100 |         trainable_params += summary[layer]["params"]
101 |         summary_logs.append(line_new)
102 | 
103 |     # assume 4 bytes/number
104 |     if input_data is not None:
105 |         total_input_size = abs(sys.getsizeof(input_data) / (1024 ** 2.))
106 |     elif input_shape is not None:
107 |         total_input_size = abs(np.prod(input_shape) * batch_size * 4. / (1024 ** 2.))
108 |     else:
109 |         total_input_size = 0.0
110 |     total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
111 |     total_params_size = abs(total_params * 4. / (1024 ** 2.))
112 |     total_size = total_params_size + total_output_size + total_input_size
113 | 
114 |     summary_logs.append("==========================================================================")
115 |     summary_logs.append("Total params: {0:,}".format(total_params))
116 |     summary_logs.append("Trainable params: {0:,}".format(trainable_params))
117 |     summary_logs.append("Non-trainable params: {0:,}".format(total_params - trainable_params))
118 |     summary_logs.append("--------------------------------------------------------------------------")
119 |     summary_logs.append("Input size (MB): %0.6f" % total_input_size)
120 |     summary_logs.append("Forward/backward pass size (MB): %0.6f" % total_output_size)
121 |     summary_logs.append("Params size (MB): %0.6f" % total_params_size)
122 |     summary_logs.append("Estimated Total Size (MB): %0.6f" % total_size)
123 |     summary_logs.append("--------------------------------------------------------------------------")
124 | 
125 |     summary_info = "\n".join(summary_logs)
126 | 
127 |     print(summary_info)
128 |     return summary_info
129 | 
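
Usage sketch of ``summary`` (the model and input sizes are illustrative), covering ways ① and ③ from the docstring::

    import torch
    from libs.utils.model_summary import summary

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, 3, padding=1),
        torch.nn.ReLU(),
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(),
        torch.nn.Linear(16, 10),
    )

    summary(model, input_data=torch.rand(2, 3, 64, 64))                      # way ①: concrete input tensor
    summary(model, input_shape=(3, 64, 64), input_dtype=torch.FloatTensor)   # way ③: shapes only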


--------------------------------------------------------------------------------
/libs/utils/tqdm.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | from typing import Callable
 7 | from tqdm.auto import tqdm
 8 | 
 9 | 
10 | def tqdm_decorator(func: Callable):
 11 |     """Wrap the input function with a tqdm progress bar and return the wrapped function.
 12 | 
 13 |     Note: **the wrapped function is assumed to take an object self as its first argument**, which provides a step
 14 |     attribute, an args attribute with a train_num_steps attribute, and an accelerator attribute with an
 15 |     is_main_process attribute.
 16 | 
 17 |     Args:
 18 |         func: the function to be wrapped with a progress bar.
 19 | 
 20 |     Returns:
 21 |         a new function that wraps the input function with a tqdm progress bar, passed in via the pbar keyword.
22 |     """
23 | 
24 |     def wrapper(*args, **kwargs):
25 |         with tqdm(initial=args[0].step,
26 |                   total=args[0].args.train_num_steps,
27 |                   disable=not args[0].accelerator.is_main_process) as pbar:
28 |             func(*args, **kwargs, pbar=pbar)
29 | 
30 |     return wrapper
31 | 
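
A minimal sketch of the host object the decorator expects (``Trainer`` and its attributes are illustrative)::

    from libs.utils.tqdm import tqdm_decorator

    class Trainer:
        def __init__(self, args, accelerator):
            self.step = 0
            self.args = args                # must expose args.train_num_steps
            self.accelerator = accelerator  # must expose accelerator.is_main_process

        @tqdm_decorator
        def train(self, pbar=None):         # the decorator injects the tqdm bar as `pbar`
            while self.step < self.args.train_num_steps:
                # ... one optimization step ...
                self.step += 1
                pbar.update(1)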


--------------------------------------------------------------------------------
/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 


--------------------------------------------------------------------------------
/pipelines/inversion/ILVR.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | # Copyright (c) XiMing Xing. All rights reserved.
  3 | # Author: XiMing Xing
  4 | # Description:
  5 | 
  6 | from argparse import Namespace
  7 | 
  8 | import torch
  9 | import torch.nn as nn
 10 | from torchvision import utils as tv_util
 11 | from tqdm import tqdm
 12 | 
 13 | from libs.engine import ModelState
 14 | from libs.utils import cycle
 15 | from sketch_nn.augment.resizer import Resizer
 16 | 
 17 | 
 18 | class ILVRPipeline(ModelState):
 19 | 
 20 |     def __init__(
 21 |             self,
 22 |             args: Namespace,
 23 |             eps_model: nn.Module,
 24 |             eps_model_path: str,
 25 |             diffusion: nn.Module,
 26 |             dataloader: torch.utils.data.DataLoader
 27 |     ):
 28 |         super().__init__(args)
 29 |         self.args = args
 30 | 
 31 |         # set log path
 32 |         self.results_path = self.results_path.joinpath(f"{args.task}-sample")
 33 |         self.results_path.mkdir(exist_ok=True)
 34 | 
 35 |         self.diffusion = diffusion
 36 | 
 37 |         # create eps_model
 38 |         self.print(f"loading SDE from `{eps_model_path}` ....")
 39 |         self.eps_model = self.load_ckpt_model_only(eps_model, eps_model_path)
 40 |         if args.model.use_fp16:
 41 |             self.eps_model.convert_to_fp16()
 42 |         self.eps_model.eval()
 43 |         self.print(f"-> eps_model Params: {(sum(p.numel() for p in self.eps_model.parameters()) / 1e6):.3f}M")
 44 | 
 45 |         self.eps_model, self.dataloader = self.accelerator.prepare(self.eps_model, dataloader)
 46 |         self.dataloader = cycle(self.dataloader)
 47 | 
 48 |     def sample(self):
 49 |         device = self.accelerator.device
 50 |         accelerator = self.accelerator
 51 | 
 52 |         sample = next(iter(self.dataloader))
 53 |         batch_size = sample["image"].shape[0]  # get real batch_size
 54 |         image_size = self.args.image_size
 55 | 
 56 |         down_N = self.args.down_N
 57 |         shape = (batch_size, 3, image_size, image_size)
 58 |         shape_d = (
 59 |             batch_size, 3, int(image_size / down_N), int(image_size / down_N)
 60 |         )
 61 |         down = Resizer(shape, 1 / down_N).to(device)
 62 |         up = Resizer(shape_d, down_N).to(device)
 63 |         resizers = (down, up)
 64 | 
 65 |         extra_kwargs = {}
 66 |         model_kwargs = {}
 67 | 
 68 |         i = 0
 69 |         with tqdm(initial=i, total=self.args.total_samples, disable=not accelerator.is_main_process) as pbar:
  70 |             while i < self.args.total_samples:
 71 |                 sample = next(self.dataloader)
 72 |                 ref_img, name = sample["image"], sample["fname"]
 73 |                 extra_kwargs["ref_img"] = ref_img
 74 | 
 75 |                 sample = self.diffusion.p_sample_loop(
 76 |                     self.eps_model,
 77 |                     (batch_size, 3, image_size, image_size),
 78 |                     clip_denoised=self.args.diffusion.clip_denoised,
 79 |                     model_kwargs=model_kwargs,
 80 |                     resizers=resizers,
 81 |                     range_t=self.args.range_t,
 82 |                     extra_kwargs=extra_kwargs
 83 |                 )
 84 | 
 85 |                 if self.accelerator.is_main_process:
 86 |                     sample = self.accelerator.gather(sample)
 87 |                     sample = (sample + 1) / 2
 88 | 
 89 |                     if self.args.get_final_results:
 90 |                         for b in range(sample.shape[0]):
 91 |                             name_ = name[b].split(".")[0]  # Remove file suffixes
 92 |                             save_path = self.results_path / f"{int(self.step + b)}-{name_}.png"
 93 |                             tv_util.save_image(sample[b], save_path)
 94 |                     else:
 95 |                         for b in range(sample.shape[0]):
 96 |                             save_path = self.results_path.joinpath(
 97 |                                 f"i-{i}-b-{b}-t-down_N-{down_N}-rt-{self.args.range_t}.png"
 98 |                             )
 99 |                             # (x0, sampled)
100 |                             save_grids = torch.cat(
101 |                                 [ref_img[b].unsqueeze_(0), sample[b].unsqueeze_(0)],
102 |                                 dim=0
103 |                             )
104 |                             tv_util.save_image(save_grids.float(), save_path, nrow=sample.shape[0])
105 | 
106 |                 i += batch_size
107 |                 pbar.update(1)
108 |         self.close()
109 | 


--------------------------------------------------------------------------------
/pipelines/inversion/ILVR_mixup.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | # Copyright (c) XiMing Xing. All rights reserved.
  3 | # Author: XiMing Xing
  4 | # Description:
  5 | 
  6 | from argparse import Namespace
  7 | 
  8 | import torch
  9 | import torch.nn as nn
 10 | from torchvision import utils as tv_util
 11 | from tqdm import tqdm
 12 | 
 13 | from libs.engine import ModelState
 14 | from libs.utils import cycle
 15 | from sketch_nn.augment.resizer import Resizer
 16 | 
 17 | class ILVRMixupPipeline(ModelState):
 18 | 
 19 |     def __init__(
 20 |             self,
 21 |             args: Namespace,
 22 |             eps_model: nn.Module,
 23 |             eps_model_path: str,
 24 |             diffusion: nn.Module,
 25 |             src_dataloader: torch.utils.data.DataLoader,
 26 |             ref_dataloader: torch.utils.data.DataLoader
 27 |     ):
 28 |         super().__init__(args)
 29 |         self.args = args
 30 | 
 31 |         # set log path
 32 |         self.results_path = self.results_path.joinpath(f"{args.task}-sample")
 33 |         self.results_path.mkdir(exist_ok=True)
 34 | 
 35 |         self.diffusion = diffusion
 36 | 
 37 |         # create eps_model
 38 |         self.print(f"loading SDE from `{eps_model_path}` ....")
 39 |         self.eps_model = self.load_ckpt_model_only(eps_model, eps_model_path)
 40 |         if args.model.use_fp16:
 41 |             self.eps_model.convert_to_fp16()
 42 |         self.eps_model.eval()
 43 |         self.print(f"-> eps_model Params: {(sum(p.numel() for p in self.eps_model.parameters()) / 1e6):.3f}M")
 44 | 
 45 |         self.eps_model = self.accelerator.prepare(self.eps_model)
 46 |         self.src_dataloader, self.ref_dataloader = self.accelerator.prepare(src_dataloader, ref_dataloader)
 47 |         self.src_dataloader = cycle(self.src_dataloader)
 48 |         self.ref_dataloader = cycle(self.ref_dataloader)
 49 | 
 50 |     def sample(self):
 51 |         device = self.accelerator.device
 52 | 
 53 |         sample = next(iter(self.src_dataloader))
 54 |         batch_size = sample["image"].shape[0]  # get real batch_size
 55 |         image_size = self.args.image_size
 56 | 
 57 |         down_N = self.args.down_N
 58 |         shape = (batch_size, 3, image_size, image_size)
 59 |         shape_d = (
 60 |             batch_size, 3, int(image_size / down_N), int(image_size / down_N)
 61 |         )
 62 |         down = Resizer(shape, 1 / down_N).to(device)
 63 |         up = Resizer(shape_d, down_N).to(device)
 64 |         resizers = (down, up)
 65 | 
 66 |         model_kwargs = {}
 67 |         i = 0
 68 | 
 69 |         with tqdm(initial=i, total=self.args.total_samples, disable=not self.accelerator.is_local_main_process) as pbar:
 70 |             while i < self.args.total_samples:
 71 |                 src_sample = next(self.src_dataloader)
 72 |                 src_input, src_name = src_sample["image"], src_sample["fname"]
 73 |                 ref_sample = next(self.ref_dataloader)
 74 |                 ref_input, ref_name = ref_sample["image"], ref_sample["fname"]
 75 | 
 76 |                 extra_kwargs = {
 77 |                     "src_input": src_input,
 78 |                     "ref_input": ref_input,
 79 |                     "fuse_scale": self.args.fuse_scale
 80 |                 }
 81 | 
 82 |                 sample = self.diffusion.p_sample_loop(
 83 |                     self.eps_model,
 84 |                     (batch_size, 3, image_size, image_size),
 85 |                     clip_denoised=self.args.diffusion.clip_denoised,
 86 |                     model_kwargs=model_kwargs,
 87 |                     resizers=resizers,
 88 |                     range_t=self.args.range_t,
 89 |                     extra_kwargs=extra_kwargs
 90 |                 )
 91 | 
 92 |                 if self.accelerator.is_main_process:
 93 |                     sample = self.accelerator.gather(sample)
 94 |                     sample = (sample + 1) / 2
 95 | 
 96 |                     if self.args.get_final_results:
 97 |                         for b in range(sample.shape[0]):
 98 |                             s_name_ = src_name[b].split(".")[0]  # Remove file suffixes
 99 |                             r_name_ = ref_name[b].split(".")[0]  # Remove file suffixes
100 |                             save_path = self.results_path / f"{int(self.step + b)}-{s_name_}_to_{r_name_}.png"
101 |                             tv_util.save_image(sample[b], save_path)
102 |                     else:
103 |                         for b in range(sample.shape[0]):
104 |                             save_path = self.results_path.joinpath(
105 |                                 f"i-{i}-b-{b}-t-down_N-{down_N}-rt-{self.args.range_t}.png"
106 |                             )
107 |                             # (x0, sampled)
108 |                             save_grids = torch.cat(
109 |                                 [ref_input[b].unsqueeze_(0), sample[b].unsqueeze_(0)],
110 |                                 dim=0
111 |                             )
112 |                             tv_util.save_image(save_grids.float(), save_path, nrow=sample.shape[0])
113 | 
114 |                 i += batch_size
115 |                 pbar.update(1)
116 | 
117 |         self.close()
118 | 


--------------------------------------------------------------------------------
/pipelines/inversion/SDEdit_iter_pipeline.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | # Copyright (c) XiMing Xing. All rights reserved.
  3 | # Author: XiMing Xing
  4 | # Description:
  5 | 
  6 | from datetime import datetime
  7 | 
  8 | import torch
  9 | from torchvision import utils as tv_util
 10 | from tqdm import tqdm
 11 | 
 12 | from libs.engine import ModelState
 13 | from libs.utils import cycle
 14 | from sketch_nn.augment.resizer import Resizer
 15 | from sketch_nn.methods.inversion.SDEdit_iter import IterativeSDEdit
 16 | 
 17 | 
 18 | class IterativeSDEditPipeline(ModelState):
 19 | 
 20 |     def __init__(self, args, sde_model, sde_path, src_dataloader, ref_dataloader):
 21 |         super().__init__(args)
 22 |         self.args = args
 23 | 
 24 |         self.print(f"loading SDE from `{sde_path}` ....")
 25 |         self.sde_model = self.load_ckpt_model_only(sde_model, sde_path)
 26 |         self.print(f"-> SDE Params: {(sum(p.numel() for p in sde_model.parameters()) / 1e6):.3f}M")
 27 | 
 28 |         self.results_path = self.results_path.joinpath(f"{args.dataset}-{args.task}-sample-seed-{args.seed}")
 29 |         self.results_path.mkdir(exist_ok=True)
 30 | 
 31 |         dpm_cfg = args.diffusion
 32 |         self.SDEdit = IterativeSDEdit(dpm_cfg.timesteps, dpm_cfg.beta_schedule, dpm_cfg.var_type)
 33 | 
 34 |         self.SDEdit, self.sde_model = self.accelerator.prepare(self.SDEdit, self.sde_model)
 35 |         self.src_dataloader, self.ref_dataloader = self.accelerator.prepare(src_dataloader, ref_dataloader)
 36 |         self.src_dataloader = cycle(self.src_dataloader)
 37 |         self.ref_dataloader = cycle(self.ref_dataloader)
 38 | 
 39 |         self.print()
 40 | 
 41 |     def sample(self):
 42 |         accelerator = self.accelerator
 43 |         device = self.accelerator.device
 44 | 
 45 |         sample = next(iter(self.src_dataloader))
 46 |         batch_size = sample["image"].shape[0]  # online batch_size
 47 |         image_size = self.args.image_size
 48 | 
 49 |         s_down_N = self.args.src_down_N
 50 |         shape = (batch_size, 3, image_size, image_size)
 51 |         s_shape_d = (
 52 |             batch_size, 3, int(image_size / s_down_N), int(image_size / s_down_N)
 53 |         )
 54 |         src_down = Resizer(shape, 1 / s_down_N).to(device)
 55 |         src_up = Resizer(s_shape_d, s_down_N).to(device)
 56 |         low_passer = (src_down, src_up)
 57 | 
 58 |         model_kwargs = {}
 59 |         iter_kwargs = {
 60 |             'low_passer': low_passer,
 61 |             'fusion_scale': self.args.fusion_scale
 62 |         }
 63 |         i = 0
 64 |         with tqdm(initial=i, total=self.args.total_samples, disable=not accelerator.is_main_process) as pbar:
 65 |             while i < self.args.total_samples:
 66 |                 src_sample = next(self.src_dataloader)
 67 |                 src_input, name = src_sample["image"], src_sample["fname"]
 68 |                 ref_sample = next(self.ref_dataloader)
 69 |                 ref_input = ref_sample["image"]
 70 | 
 71 |                 model_kwargs['step'] = i
 72 | 
 73 |                 start_time = datetime.now()
 74 |                 results = self.SDEdit.iterative_sampling_progressive(
 75 |                     src_input,
 76 |                     ref_input,
 77 |                     self.args.iter_step,
 78 |                     iter_kwargs,
 79 |                     list(self.args.repeat_step),
 80 |                     list(self.args.perturb_step),
 81 |                     model=self.sde_model,
 82 |                     model_kwargs=model_kwargs,
 83 |                     device=device,
 84 |                     recorder=pbar
 85 |                 )
 86 |                 pbar.set_description(f"one batch time: {datetime.now() - start_time}, "
 87 |                                      f"total_iter: {self.args.iter_step}")
 88 | 
 89 |                 if accelerator.is_main_process:
 90 |                     results = accelerator.gather(results)
 91 |                     # gather final result
 92 |                     for b in range(batch_size):
 93 |                         all_iter_grids = []
 94 |                         for ith in range(self.args.iter_step):
 95 |                             for kth in range(self.args.repeat_step[ith]):
 96 |                                 x0, final, perturb_x0, blurred_x0 = results[f"{ith}-{kth}-th"]
 97 |                                 # (x0, perturbed_x0, kth_translated_x0)
 98 |                                 # (x0, perturbed_x0, src, blurred_x0, kth_translated_x0)
 99 |                                 save_grids = torch.cat([x0[b].unsqueeze_(0),
100 |                                                         perturb_x0[b].unsqueeze_(0),
101 |                                                         src_input[b].unsqueeze_(0)
102 |                                                         if ith != 0 else torch.zeros_like(src_input[b]).unsqueeze_(0),
103 |                                                         blurred_x0[b].unsqueeze_(0),
104 |                                                         final[b].unsqueeze_(0)], dim=0)
105 |                                 all_iter_grids.append(save_grids)
106 |                         # visual
107 |                         img_name = name[b].split(".")[0]  # Remove file suffixes
108 |                         save_path = self.results_path.joinpath(
109 |                             f"{i}-{img_name}-iter-{ith + 1}-K-{kth + 1}-t-{self.args.perturb_step}.png"
110 |                         )
111 |                         tv_util.save_image(torch.cat(all_iter_grids, dim=0), save_path, nrow=5)
112 | 
113 |                 i += 1
114 |                 pbar.update(1)
115 | 
116 |         self.close()
117 | 


--------------------------------------------------------------------------------
/pipelines/inversion/SDEdit_pipeline.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | from datetime import datetime
 7 | 
 8 | import torch
 9 | from torchvision import utils as tv_util
10 | from tqdm import tqdm
11 | 
12 | from libs.engine import ModelState
13 | from dpm_nn.inversion.SDEdit import SDEdit
14 | 
15 | 
16 | class SDEditPipeline(ModelState):
17 | 
18 |     def __init__(self, args, sde_model, sde_path, dataloader, use_dpm_solver: bool = False):
19 |         super().__init__(args)
20 |         self.args = args
21 |         self.use_dpm_solver = use_dpm_solver
22 | 
23 |         self.print(f"loading SDE from `{sde_path}` ....")
24 |         self.sde_model = self.load_ckpt_model_only(sde_model, sde_path)
25 |         self.print(f"-> SDE Params: {(sum(p.numel() for p in sde_model.parameters()) / 1e6):.3f}M")
26 | 
27 |         dpm_cfg = args.diffusion
28 |         self.SDEdit = SDEdit(dpm_cfg.timesteps, dpm_cfg.beta_schedule, dpm_cfg.var_type)
29 | 
30 |         self.SDEdit, self.sde_model = \
31 |             self.accelerator.prepare(self.SDEdit, self.sde_model)
32 |         self.dataloader = self.accelerator.prepare(dataloader)
33 | 
34 |         self.print()
35 | 
36 |     def sample(self):
37 |         accelerator = self.accelerator
38 |         device = self.accelerator.device
39 | 
40 |         sample = next(iter(self.dataloader))
41 |         batch_size = sample["image"].shape[0]  # online batch_size
42 |         image_size = self.args.image_size
43 | 
44 |         model_kwargs = {}
45 |         with tqdm(self.dataloader, disable=not accelerator.is_local_main_process) as pbar:
46 |             for i, sample in enumerate(pbar):
47 |                 start_time = datetime.now()
48 | 
49 |                 src_input, name = sample["image"], sample["fname"]
50 | 
51 |                 model_kwargs['step'] = i
52 |                 results = self.SDEdit.sampling_progressive(
53 |                     src_input,
54 |                     mask=sample.get('mask', None),  # editing mask
55 |                     repeat_step=self.args.repeat_step,
56 |                     perturb_step=self.args.perturb_step,
57 |                     model=self.sde_model,
58 |                     model_kwargs=model_kwargs,
59 |                     device=device,
60 |                     recorder=pbar,
61 |                     use_dpm_solver=self.use_dpm_solver
62 |                 )
63 | 
64 |                 pbar.set_description(f"time per batch: {datetime.now() - start_time}")
65 |                 # pbar.write(f"Running time: {datetime.now() - start_time} | batch_size: {batch_size} \n")
66 | 
67 |                 if accelerator.is_main_process:
68 |                     results = accelerator.gather(results)
69 |                     # gather final result
70 |                     for b in range(batch_size):
71 |                         all_iter_grids = []
72 |                         for kth in range(len(results)):
73 |                             x0, final, perturb_x0 = results[f"{kth}-th"]
74 |                             # (x0, perturbed_x0, kth_translated_x0)
75 |                             save_grids = torch.cat(
76 |                                 [x0[b].unsqueeze_(0), perturb_x0[b].unsqueeze_(0), final[b].unsqueeze_(0)],
77 |                                 dim=0
78 |                             )
79 |                             all_iter_grids.append(save_grids)
80 |                         # visual
81 |                         img_name = name[b].split(".")[0]  # Remove file suffixes
82 |                         save_path = self.results_path.joinpath(
83 |                             f"i-{i}-{img_name}-B-{b}-K-{kth + 1}-t-{self.args.perturb_step}.png"
84 |                         )
85 |                         tv_util.save_image(torch.cat(all_iter_grids, dim=0), save_path, nrow=3)
86 | 
87 |         self.close()
88 | 


--------------------------------------------------------------------------------
/pipelines/inversion/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
  1 | absl-py @ file:///opt/conda/conda-bld/absl-py_1639803114343/work
  2 | accelerate==0.18.0
  3 | aiohttp @ file:///tmp/build/80754af9/aiohttp_1646806366512/work
  4 | aiosignal @ file:///tmp/build/80754af9/aiosignal_1637843061372/work
  5 | antlr4-python3-runtime==4.9.3
  6 | async-timeout @ file:///tmp/build/80754af9/async-timeout_1637851218186/work
  7 | attrs @ file:///opt/conda/conda-bld/attrs_1642510447205/work
  8 | beautifulsoup4==4.12.2
  9 | blinker==1.4
 10 | Bottleneck @ file:///opt/conda/conda-bld/bottleneck_1657175564434/work
 11 | brotlipy==0.7.0
 12 | cachetools @ file:///tmp/build/80754af9/cachetools_1619597386817/work
 13 | certifi @ file:///croot/certifi_1665076670883/work/certifi
 14 | cffi @ file:///tmp/abs_98z5h56wf8/croots/recipe/cffi_1659598650955/work
 15 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
 16 | click @ file:///tmp/build/80754af9/click_1646038465422/work
 17 | clip @ git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
 18 | cmake==3.26.1
 19 | contourpy==1.0.6
 20 | cryptography @ file:///tmp/build/80754af9/cryptography_1652083738073/work
 21 | cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
 22 | docker-pycreds==0.4.0
 23 | einops==0.5.0
 24 | filelock==3.9.0
 25 | fonttools==4.25.0
 26 | frozenlist @ file:///tmp/build/80754af9/frozenlist_1637767111923/work
 27 | fsspec==2023.4.0
 28 | ftfy==6.1.1
 29 | future==0.18.2
 30 | gitdb==4.0.10
 31 | GitPython==3.1.30
 32 | google-auth @ file:///opt/conda/conda-bld/google-auth_1646735974934/work
 33 | google-auth-oauthlib @ file:///tmp/build/80754af9/google-auth-oauthlib_1617120569401/work
 34 | googledrivedownloader @ file:///home/conda/feedstock_root/build_artifacts/googledrivedownloader_1619807768586/work
 35 | grpcio @ file:///tmp/build/80754af9/grpcio_1637590823556/work
 36 | huggingface-hub==0.12.0
 37 | idna @ file:///tmp/build/80754af9/idna_1637925883363/work
 38 | importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1648562408398/work
 39 | Jinja2 @ file:///opt/conda/conda-bld/jinja2_1647436528585/work
 40 | joblib @ file:///tmp/build/80754af9/joblib_1635411271373/work
 41 | kiwisolver @ file:///opt/conda/conda-bld/kiwisolver_1653292039266/work
 42 | kornia==0.6.11
 43 | lightning-utilities==0.8.0
 44 | lit==16.0.0
 45 | Markdown @ file:///tmp/build/80754af9/markdown_1614363528767/work
 46 | MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work
 47 | matplotlib==3.6.2
 48 | mkl-fft==1.3.1
 49 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work
 50 | mkl-service==2.4.0
 51 | mpmath==1.3.0
 52 | multidict @ file:///opt/conda/conda-bld/multidict_1662369340274/work
 53 | munkres==1.1.4
 54 | networkx @ file:///opt/conda/conda-bld/networkx_1657784097507/work
 55 | numexpr @ file:///opt/conda/conda-bld/numexpr_1656940300424/work
 56 | numpy @ file:///tmp/abs_653_j00fmm/croots/recipe/numpy_and_numpy_base_1659432701727/work
 57 | nvidia-cublas-cu11==11.10.3.66
 58 | nvidia-cuda-cupti-cu11==11.7.101
 59 | nvidia-cuda-nvrtc-cu11==11.7.99
 60 | nvidia-cuda-runtime-cu11==11.7.99
 61 | nvidia-cudnn-cu11==8.5.0.96
 62 | nvidia-cufft-cu11==10.9.0.58
 63 | nvidia-curand-cu11==10.2.10.91
 64 | nvidia-cusolver-cu11==11.4.0.1
 65 | nvidia-cusparse-cu11==11.7.4.91
 66 | nvidia-nccl-cu11==2.14.3
 67 | nvidia-nvtx-cu11==11.7.91
 68 | oauthlib @ file:///tmp/abs_08ngfezid4/croots/recipe/oauthlib_1659642459222/work
 69 | omegaconf==2.3.0
 70 | opencv-python==4.7.0.72
 71 | packaging @ file:///tmp/build/80754af9/packaging_1637314298585/work
 72 | pandas==1.4.3
 73 | pathtools==0.1.2
 74 | Pillow==9.2.0
 75 | ply==3.11
 76 | promise==2.3
 77 | protobuf==3.20.1
 78 | psutil==5.9.2
 79 | pyasn1 @ file:///Users/ktietz/demo/mc3/conda-bld/pyasn1_1629708007385/work
 80 | pyasn1-modules==0.2.8
 81 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
 82 | PyJWT @ file:///opt/conda/conda-bld/pyjwt_1657544592787/work
 83 | pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work
 84 | pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work
 85 | PyQt5-sip==12.11.0
 86 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
 87 | python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
 88 | python-louvain @ file:///tmp/build/80754af9/python-louvain_1612304551119/work
 89 | pytz @ file:///opt/conda/conda-bld/pytz_1654762638606/work
 90 | PyYAML==6.0
 91 | regex==2022.10.31
 92 | requests @ file:///opt/conda/conda-bld/requests_1657734628632/work
 93 | requests-oauthlib==1.3.0
 94 | rsa @ file:///tmp/build/80754af9/rsa_1614366226499/work
 95 | scikit-learn @ file:///tmp/abs_d76175bc-917a-47d4-9994-b56265948a6328vmoe2o/croots/recipe/scikit-learn_1658419412415/work
 96 | scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy_1653073867187/work
 97 | sentencepiece==0.1.99
 98 | sentry-sdk==1.12.1
 99 | setproctitle==1.3.2
100 | shortuuid==1.0.11
101 | sip @ file:///tmp/abs_44cd77b_pu/croots/recipe/sip_1659012365470/work
102 | six @ file:///tmp/build/80754af9/six_1644875935023/work
103 | smmap==5.0.0
104 | soupsieve==2.4.1
105 | sympy==1.11.1
106 | tensorboard @ file:///home/builder/stiwari/miniconda3/envs/tf_new_env/conda-bld/tensorboard_1661447826088/work/tensorboard-2.9.0-py3-none-any.whl
107 | tensorboard-data-server @ file:///tmp/build/80754af9/tensorboard-data-server_1633035064162/work/tensorboard_data_server-0.6.0-py3-none-manylinux2010_x86_64.whl
108 | tensorboard-plugin-wit @ file:///home/builder/tkoch/workspace/tensorflow/tensorboard-plugin-wit_1658918494740/work/tensorboard_plugin_wit-1.8.1-py3-none-any.whl
109 | threadpoolctl @ file:///Users/ktietz/demo/mc3/conda-bld/threadpoolctl_1629802263681/work
110 | tokenizers==0.13.2
111 | toml @ file:///tmp/build/80754af9/toml_1616166611790/work
112 | torch==1.13.1+cu116
113 | torch-cluster @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-cluster_1631029005429/work
114 | torch-geometric @ file:///usr/share/miniconda/envs/test/conda-bld/pyg_1640156451028/work
115 | torch-scatter @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-scatter_1634900577572/work
116 | torch-sparse @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-sparse_1631173533284/work
117 | torch-spline-conv @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-spline-conv_1631007898768/work
118 | torchaudio==0.13.1+cu116
119 | torchmetrics==0.11.4
120 | torchvision==0.14.1+cu116
121 | tornado @ file:///opt/conda/conda-bld/tornado_1662061693373/work
122 | tqdm @ file:///opt/conda/conda-bld/tqdm_1650891076910/work
123 | transformers==4.26.0
124 | triton==2.0.0
125 | typing_extensions @ file:///tmp/abs_ben9emwtky/croots/recipe/typing_extensions_1659638822008/work
126 | urllib3 @ file:///tmp/abs_5dhwnz6atv/croots/recipe/urllib3_1659110457909/work
127 | wandb==0.13.7
128 | wcwidth==0.2.5
129 | Werkzeug @ file:///opt/conda/conda-bld/werkzeug_1645628268370/work
130 | yacs @ file:///tmp/build/80754af9/yacs_1634047592950/work
131 | yarl @ file:///opt/conda/conda-bld/yarl_1661437085904/work
132 | zipp @ file:///opt/conda/conda-bld/zipp_1652341764480/work
133 | 


--------------------------------------------------------------------------------
/run/run_SDEdit.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | # Copyright (c) XiMing Xing. All rights reserved.
  3 | # Author: XiMing Xing
  4 | # Description:
  5 | 
  6 | import os
  7 | import sys
  8 | import argparse
  9 | 
 10 | from accelerate.utils import set_seed
 11 | 
 12 | sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
 13 | 
 14 | from libs.utils.argparse import (merge_and_update_config, accelerate_parser, base_data_parser, base_sampling_parser)
 15 | from dpm_nn.guided_dpm.build_ADMs import ADMs_build_util
 16 | from sketch_nn.dataset.build import build_image2image_translation_dataset
 17 | 
 18 | 
 19 | def main(args):
 20 |     assert len(args.data_folder) > 0, "Insufficient dataset entry!"
 21 | 
 22 |     args.batch_size = args.valid_batch_size
 23 | 
 24 |     if args.task == "base":  # SDEdit - image to image translation
 25 |         from pipelines.inversion.SDEdit_pipeline import SDEditPipeline
 26 | 
 27 |         dataloader = build_image2image_translation_dataset(args.dataset, args.data_folder,
 28 |                                                            split=args.split,
 29 |                                                            image_size=args.image_size,
 30 |                                                            batch_size=args.valid_batch_size,
 31 |                                                            shuffle=args.shuffle, drop_last=True,
 32 |                                                            num_workers=args.num_workers)
 33 | 
 34 |         sde_model, _ = ADMs_build_util(args.image_size, args.num_classes, args.model, args.diffusion)
 35 | 
 36 |         SDEdit = SDEditPipeline(args, sde_model, args.sdepath, dataloader, args.use_dpm_solver)
 37 |         SDEdit.sample()
 38 | 
 39 |     elif args.task == "mask":  # TODO: SDEdit - image editing
 40 |         pass
 41 | 
 42 |     elif args.task == "ref":
 43 |         from pipelines.inversion.SDEdit_iter_pipeline import IterativeSDEditPipeline
 44 | 
 45 |         src_dataloader = build_image2image_translation_dataset(args.dataset, args.data_folder,
 46 |                                                                split=args.split,
 47 |                                                                image_size=args.image_size,
 48 |                                                                batch_size=args.valid_batch_size,
 49 |                                                                shuffle=args.shuffle, drop_last=True,
 50 |                                                                num_workers=args.num_workers)
 51 |         ref_dataloader = build_image2image_translation_dataset(args.dataset, args.ref_data_folder,
 52 |                                                                split=args.split,
 53 |                                                                image_size=args.image_size,
 54 |                                                                batch_size=args.valid_batch_size,
 55 |                                                                shuffle=args.shuffle, drop_last=True,
 56 |                                                                num_workers=args.num_workers)
 57 | 
 58 |         sde_model, _ = ADMs_build_util(args.image_size, args.num_classes, args.model, args.diffusion)
 59 | 
 60 |         SDEdit = IterativeSDEditPipeline(args, sde_model, args.sdepath, src_dataloader, ref_dataloader)
 61 |         SDEdit.sample()
 62 | 
 63 | 
 64 | if __name__ == '__main__':
 65 |     """ 
 66 |     ## cat2dog, base sampling, SDEdit:
 67 |     CUDA_VISIBLE_DEVICES=0 python run/run_SDEdit.py -c SDEdit/cat2dog-img256.yaml -sdepath ./checkpoint/InvSDE/afhq_dog_4m.pt -dpath ./dataset/afhq/val/cat -respath ./workdir/sdedit_cat -vbz 32 -final -ts 500
 68 |     CUDA_VISIBLE_DEVICES=0 python run/run_SDEdit.py -c SDEdit/cat2dog-img256.yaml -sdepath ./checkpoint/InvSDE/afhq_dog_4m.pt -dpath ./dataset/afhq/val/dog -respath ./workdir/sdedit_dog -vbz 32 -final -ts 500
 69 |     
 70 |     ## SDEdit + ref:
 71 |     CUDA_VISIBLE_DEVICES=0 python run/run_SDEdit.py -c SDEdit/iter-cat2dog-img256-p400-k33-dN32.yaml --task ref -sdepath ./checkpoint/afhq_dog_4m.pt -dpath ./dataset/afhq/train/cat -rdpath ./dataset/afhq/train_edge_map/dog -respath /data2/xingxm/skgruns/ -vbz 8
 72 |     """
 73 | 
 74 |     parser = argparse.ArgumentParser(
 75 |         description="SDEdit",
 76 |         parents=[accelerate_parser(), base_data_parser(), base_sampling_parser()]
 77 |     )
 78 | 
 79 |     # flag
 80 |     parser.add_argument("-tk", "--task",
 81 |                         default="base", type=str, choices=["base", "mask", "ref"],
 82 |                         help="guided image synthesis and editing.")
 83 |     # config
 84 |     parser.add_argument("-c", "--config",
 85 |                         required=True, type=str,
 86 |                         default="SDEdit/cat2dog-img256.yaml",
 87 |                         help="YAML/YML file for configuration.")
 88 |     # data path
 89 |     parser.add_argument("-dpath", "--data_folder",
 90 |                         nargs="+", type=str,
  91 |                         # default=['./dataset/afhq/val/cat'],
 92 |                         # default=['./dataset/afhq/train/cat', './dataset/afhq/train/dog'],
 93 |                         # default=['./dataset/afhq/train/cat', './dataset/afhq/train/wild', './dataset/afhq/train/dog'],
 94 |                         help="single input for single-domain, multi inputs for multi-domain")
 95 |     parser.add_argument("-rdpath", "--ref_data_folder",
 96 |                         nargs="+", type=str, default=None,
  97 |                         # default=['./dataset/afhq/val/cat'],
 98 |                         # default=['./dataset/afhq/train/cat', './dataset/afhq/train/dog'],
 99 |                         # default=['./dataset/afhq/train/cat', './dataset/afhq/train/wild', './dataset/afhq/train/dog'],
100 |                         help="single input for single-domain, multi inputs for multi-domain")
101 |     # model path
102 |     parser.add_argument("-sdepath",
103 |                         default="./checkpoint/afhq_dog_4m.pt", type=str,
104 |                         help="place pretrained model in `./checkpoint/afhq_dog_4m.pt`, "
105 |                              "if None, then train from scratch")
106 |     # use dpm-solver
107 |     parser.add_argument("-uds", "--use_dpm_solver",
108 |                         action='store_true',
109 |                         help="use dpm_solver accelerates sampling.")
110 |     # sampling mode
111 |     parser.add_argument("-final", "--get_final_results",
112 |                         action='store_true',
113 |                         help="visualize intermediate results or just get final output.")
114 | 
115 |     args = parser.parse_args()
116 |     args = merge_and_update_config(args)
117 | 
118 |     set_seed(args.seed)
119 |     main(args)
120 | 


--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | 
 5 | """
 6 | Description: How to install
 7 | # all:
 8 | pip install omegaconf tqdm scipy opencv-python einops BeautifulSoup4 timm matplotlib torchmetrics accelerate diffusers triton transformers -i https://pypi.tuna.tsinghua.edu.cn/simple
 9 | 
10 | # CLIP:
11 | pip install git+https://github.com/openai/CLIP.git -i https://pypi.tuna.tsinghua.edu.cn/simple
12 | 
13 | # torch 1.13.1:
14 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
15 | 
16 | # xformers (python=3.10):
17 | conda install xformers -c xformers
18 | xFormers - Toolbox to Accelerate Research on Transformers:
19 | https://github.com/facebookresearch/xformers
20 | """
21 | 
22 | from setuptools import setup, find_packages
23 | 
24 | setup(
25 |     name='SketchGuidedGeneration',
26 |     packages=find_packages(),
27 |     version='0.0.13',
28 |     license='MIT',
29 |     description='Sketch Guided Content Generation',
30 |     author='XiMing Xing',
31 |     author_email='ximingxing@gmail.com',
32 |     url='https://github.com/ximinng/SketchGeneration/',
33 |     long_description_content_type='text/markdown',
34 |     keywords=[
35 |         'artificial intelligence',
36 |         'generative models',
37 |         'sketch'
38 |     ],
39 |     install_requires=[
40 |         'omegaconf',  # YAML processor
41 |         'accelerate',  # Hugging Face - pytorch distributed configuration
42 |         'diffusers',  # Hugging Face - diffusion models
43 |         'transformers',  # Hugging Face - transformers
44 |         'einops',
45 |         'pillow',
46 |         'torch>=1.13.1',
47 |         'torchvision',
48 |         'tensorboard',
49 |         'torchmetrics',
50 |         'tqdm',  # progress bar
51 |         'timm',  # computer vision models
52 |         "numpy",  # numpy
53 |         'matplotlib',
54 |         'scikit-learn',
55 |         'omegaconf',  # configs
56 |         'Pillow',  # keep the PIL.Image.Resampling deprecation away,
57 |         'wandb',  # weights & Biases
58 |         'opencv-python',  # cv2
59 |         'BeautifulSoup4'
60 |     ],
61 |     classifiers=[
62 |         'Development Status :: 4 - Beta',
63 |         'Intended Audience :: Developers',
64 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
65 |         'License :: OSI Approved :: MIT License',
66 |         'Programming Language :: Python :: 3.8',
67 |     ],
68 | )
69 | 


--------------------------------------------------------------------------------
/sketch_nn/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | from . import augment
 7 | from . import dataset
 8 | from . import edge_map
 9 | from . import methods
10 | from . import model
11 | from . import photo2sketch
12 | from . import rasterize
13 | 
14 | __version__ = '0.0.12'
15 | 


--------------------------------------------------------------------------------
/sketch_nn/augment/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 
6 | from .mixup import Mixup
7 | 


--------------------------------------------------------------------------------
/sketch_nn/augment/mixup.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | import torch
 7 | import numpy as np
 8 | 
 9 | 
10 | class Mixup(object):
11 |     """
12 |     "Mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)". In ICLR, 2018.
13 |     https://github.com/facebookresearch/mixup-cifar10
14 |     """
15 | 
16 |     def single_domain_mix(self, x, y, alpha=1.0, device='cpu'):
17 |         if alpha > 0:
18 |             lam = np.random.beta(alpha, alpha)
19 |         else:
20 |             lam = 1
21 | 
22 |         batch_size = x.size()[0]
23 |         index = torch.randperm(batch_size).to(device)
24 | 
25 |         mixed_x = lam * x + (1 - lam) * x[index, :]
26 |         y_a, y_b = y, y[index]
27 |         return mixed_x, y_a, y_b, lam
28 | 
29 |     def dual_domains_mix(self, x1, x2, y1, y2, alpha=1.0, device='cpu'):
30 |         if alpha > 0:
31 |             lam = np.random.beta(alpha, alpha)
32 |         else:
33 |             lam = 1
34 | 
35 |         mixed_x = lam * x1 + (1 - lam) * x2
36 |         return mixed_x, y1, y2, lam
37 | 
38 |     def criterion(self, criterion, pred, y_a, y_b, lam):
39 |         return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
40 | 
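
Usage sketch (tensors and the loss function are illustrative; the logits stand in for a model forward pass on the mixed batch)::

    import torch
    import torch.nn.functional as F
    from sketch_nn.augment.mixup import Mixup

    mixup = Mixup()
    x = torch.randn(8, 3, 32, 32)
    y = torch.randint(0, 10, (8,))

    mixed_x, y_a, y_b, lam = mixup.single_domain_mix(x, y, alpha=1.0)
    logits = torch.randn(8, 10)  # stand-in for model(mixed_x)
    loss = mixup.criterion(F.cross_entropy, logits, y_a, y_b, lam)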


--------------------------------------------------------------------------------
/sketch_nn/dataset/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | # sketch dataset
 7 | from .mnist import MNISTDataset
 8 | from .sketchx_shoe_chairV2 import SketchXShoeAndChairCoordDataset, SketchXShoeAndChairPhotoDataset
 9 | from .sketchy import SketchyDataset
10 | # real image dataset
11 | from .cifar10 import CIFAR10Dataset
12 | from .imagenet import ImageNetDataset
13 | # common
14 | from .base_dataset import MultiDomainDataset, SingleDomainDataset, SingleDomainWithFileNameDataset
15 | 
16 | # utils
17 | from .base_dataset import is_image_file
18 | 


--------------------------------------------------------------------------------
/sketch_nn/dataset/utils.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | 
 7 | ImageSuffices = [
 8 |     '.jpg', '.JPG', '.jpeg', '.JPEG',
 9 |     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
10 |     '.tif', '.TIF', '.tiff', '.TIFF',
11 | ]
12 | 
13 | 
14 | def is_image_file(filename):
15 |     return any(filename.endswith(extension) for extension in ImageSuffices)
16 | 


--------------------------------------------------------------------------------
/sketch_nn/edge_map/DoG/XDoG.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | import numpy as np
 7 | import cv2
 8 | from scipy import ndimage as ndi
 9 | from skimage import filters
10 | 
11 | 
12 | class XDoG:
13 | 
14 |     def __init__(self,
15 |                  gamma=0.98,
16 |                  phi=200,
17 |                  eps=-0.1,
18 |                  sigma=0.8,
19 |                  k=10,
20 |                  binarize: bool = True):
21 |         """
22 |         XDoG algorithm.
23 | 
24 |         Args:
25 |             gamma: Weight of the larger Gaussian in the difference-of-Gaussians (controls edge sensitivity)
26 |             phi: Steepness of the tanh that sharpens the edge response
27 |             eps: Threshold on the DoG response separating edges from flat regions
28 |             sigma: Standard deviation of the smaller Gaussian filter (controls smoothness)
29 |             k: Ratio between the two Gaussian standard deviations (k=10 or k=1.6)
30 |             binarize(bool): Whether to binarize the output with Otsu's threshold
31 |         """
32 | 
33 |         super(XDoG, self).__init__()
34 | 
35 |         self.gamma = gamma
36 |         assert 0 <= self.gamma <= 1
37 | 
38 |         self.phi = phi
39 |         assert 0 <= self.phi <= 1500
40 | 
41 |         self.eps = eps
42 |         assert -1 <= self.eps <= 1
43 | 
44 |         self.sigma = sigma
45 |         assert 0.1 <= self.sigma <= 10
46 | 
47 |         self.k = k
48 |         assert 1 <= self.k <= 100
49 | 
50 |         self.binarize = binarize
51 | 
52 |     def __call__(self, img):
53 |         # to gray if image is not already grayscale
54 |         if len(img.shape) == 3 and img.shape[2] == 3:
55 |             img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
56 |         elif len(img.shape) == 3 and img.shape[2] == 4:
57 |             img = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)
58 | 
59 |         if np.isnan(img).any():
60 |             img[np.isnan(img)] = np.mean(img[~np.isnan(img)])
61 | 
62 |         # gaussian filter
63 |         imf1 = ndi.gaussian_filter(img, self.sigma)
64 |         imf2 = ndi.gaussian_filter(img, self.sigma * self.k)
65 |         imdiff = imf1 - self.gamma * imf2
66 | 
67 |         # XDoG
68 |         imdiff = (imdiff < self.eps) * 1.0 + (imdiff >= self.eps) * (1.0 + np.tanh(self.phi * imdiff))
69 | 
70 |         # normalize
71 |         imdiff -= imdiff.min()
72 |         imdiff /= imdiff.max()
73 | 
74 |         if self.binarize:
75 |             th = filters.threshold_otsu(imdiff)
76 |             imdiff = (imdiff >= th).astype('float32')
77 | 
78 |         return imdiff
79 | 
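A minimal usage sketch of the detector above; the input path is a placeholder, and the image is converted to float before filtering.

```
import cv2
import numpy as np

from sketch_nn.edge_map.DoG import XDoG

xdog = XDoG(gamma=0.98, phi=200, eps=-0.1, sigma=0.8, k=10, binarize=True)

# 'photo.png' is a placeholder path; __call__ accepts BGR, BGRA or grayscale arrays
img = cv2.imread('photo.png').astype(np.float32) / 255.0
edge_map = xdog(img)  # float array in [0, 1]; binary when binarize=True
cv2.imwrite('photo_xdog.png', (edge_map * 255).astype(np.uint8))
```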


--------------------------------------------------------------------------------
/sketch_nn/edge_map/DoG/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 
6 | from .XDoG import XDoG
7 | 
8 | __all__ = ['XDoG']
9 | 


--------------------------------------------------------------------------------
/sketch_nn/edge_map/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 


--------------------------------------------------------------------------------
/sketch_nn/edge_map/canny/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | import cv2
 7 | 
 8 | 
 9 | class CannyDetector:
10 | 
11 |     def __call__(self, img, low_threshold, high_threshold, L2gradient=False):
12 |         return cv2.Canny(img, low_threshold, high_threshold, L2gradient=L2gradient)
13 | 
14 | 
15 | __all__ = ['CannyDetector']
16 | 
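A short usage sketch in the same spirit; the file path is a placeholder.

```
import cv2

from sketch_nn.edge_map.canny import CannyDetector

detector = CannyDetector()

# 'photo.png' is a placeholder path; cv2.Canny expects an 8-bit grayscale image
gray = cv2.imread('photo.png', cv2.IMREAD_GRAYSCALE)
edges = detector(gray, low_threshold=100, high_threshold=200, L2gradient=True)
cv2.imwrite('photo_canny.png', edges)
```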


--------------------------------------------------------------------------------
/sketch_nn/edge_map/image_grads/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 
6 | from .laplacian import LaplacianDetector
7 | 
8 | __all__ = ['LaplacianDetector']
9 | 


--------------------------------------------------------------------------------
/sketch_nn/edge_map/image_grads/laplacian.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | 
 7 | import cv2
 8 | 
 9 | 
10 | class LaplacianDetector:
11 | 
12 |     def __call__(self, img):
13 |         return cv2.Laplacian(img, cv2.CV_64F)
14 | 


--------------------------------------------------------------------------------
/sketch_nn/methods/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 


--------------------------------------------------------------------------------
/sketch_nn/methods/inversion/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 


--------------------------------------------------------------------------------
/sketch_nn/model/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 


--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/InformativeDrawings/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 


--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/InformativeDrawings/default_config.yaml:
--------------------------------------------------------------------------------
1 | input_nc: 3
2 | output_nc: 1
3 | n_blocks: 3


--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 


--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/base_model.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | import torch
 4 | 
 5 | 
 6 | class BaseModel:
 7 | 
 8 |     def name(self):
 9 |         return 'BaseModel'
10 | 
11 |     def initialize(self, opt):
12 |         self.opt = opt
13 |         self.isTrain = opt.isTrain
14 |         self.device = torch.device("cuda" if opt.use_cuda else "cpu")
15 |         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
16 | 
17 |     def set_input(self, input):
18 |         self.input = input
19 | 
20 |     def forward(self):
21 |         pass
22 | 
23 |     # used in test time, no backprop
24 |     def test(self):
25 |         pass
26 | 
27 |     def get_image_paths(self):
28 |         pass
29 | 
30 |     def optimize_parameters(self):
31 |         pass
32 | 
33 |     def get_current_visuals(self):
34 |         return self.input
35 | 
36 |     def get_current_errors(self):
37 |         return {}
38 | 
39 |     def save(self, label):
40 |         pass
41 | 
42 |     # helper saving function that can be used by subclasses
43 |     def save_network(self, network, network_label, epoch_label):
44 |         save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
45 |         save_path = os.path.join(self.save_dir, save_filename)
46 |         torch.save(network.cpu().state_dict(), save_path)
47 |         network = network.to(self.device)
48 | 
49 |     # helper loading function that can be used by subclasses
50 |     def load_network(self, network, network_label, epoch_label):
51 |         save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
52 |         if self.opt.pretrain_path:
53 |             save_path = os.path.join(self.opt.pretrain_path, save_filename)
54 |         else:
55 |             save_path = os.path.join(self.save_dir, save_filename)
56 |         network.load_state_dict(torch.load(save_path))
57 | 
58 |     # update learning rate (called once every epoch)
59 |     def update_learning_rate(self):
60 |         for scheduler in self.schedulers:
61 |             scheduler.step()
62 |         lr = self.optimizers[0].param_groups[0]['lr']
63 |         print('learning rate = %.7f' % lr)
64 | 
65 |     def set_requires_grad(self, nets, requires_grad=False):
66 |         if not isinstance(nets, list):
67 |             nets = [nets]
68 |         for net in nets:
69 |             if net is not None:
70 |                 for param in net.parameters():
71 |                     param.requires_grad = requires_grad
72 | 


--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/base_options.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import os
 3 | 
 4 | import torch
 5 | 
 6 | from .util import mkdirs
 7 | 
 8 | 
 9 | class BaseOptions():
10 |     def __init__(self):
11 |         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
12 |         self.initialized = False
13 | 
14 |     def initialize(self):
15 |         self.parser.add_argument('--dataroot', required=True,
16 |                                  help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
17 |         self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
18 |         self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
19 |         self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
20 |         self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
21 |         self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
22 |         self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
23 |         self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
24 |         self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
25 |         self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks',
26 |                                  help='selects model to use for netG')
27 |         self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
28 |         self.parser.add_argument('--no-cuda', action='store_true', default=False,
29 |                                  help='disable CUDA training (please use CUDA_VISIBLE_DEVICES to select GPU)')
30 |         self.parser.add_argument('--name', type=str, default='experiment_name',
31 |                                  help='name of the experiment. It decides where to store samples and models')
32 |         self.parser.add_argument('--dataset_mode', type=str, default='unaligned',
33 |                                  help='chooses how datasets are loaded. [unaligned | aligned | single]')
34 |         self.parser.add_argument('--model', type=str, default='cycle_gan',
35 |                                  help='chooses which model to use. cycle_gan, pix2pix, test')
36 |         self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
37 |         self.parser.add_argument('--nThreads', default=6, type=int, help='# threads for loading data')
38 |         self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
39 |         self.parser.add_argument('--norm', type=str, default='instance',
40 |                                  help='instance normalization or batch normalization')
41 |         self.parser.add_argument('--serial_batches', action='store_true',
42 |                                  help='if true, takes images in order to make batches, otherwise takes them randomly')
43 |         self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
44 |         self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
45 |         self.parser.add_argument('--display_server', type=str, default="http://localhost",
46 |                                  help='visdom server of the web display')
47 |         self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
48 |         self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
49 |         self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
50 |                                  help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
51 |         self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop',
52 |                                  help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
53 |         self.parser.add_argument('--no_flip', action='store_true',
54 |                                  help='if specified, do not flip the images for data augmentation')
55 |         self.parser.add_argument('--init_type', type=str, default='normal',
56 |                                  help='network initialization [normal|xavier|kaiming|orthogonal]')
57 |         self.parser.add_argument('--render_dir', type=str, default='sketch-rendered')
58 |         self.parser.add_argument('--aug_folder', type=str, default='width-5')
59 |         self.parser.add_argument('--stroke_dir', type=str, default='')
60 |         self.parser.add_argument('--crop', action='store_true')
61 |         self.parser.add_argument('--rotate', action='store_true')
62 |         self.parser.add_argument('--color_jitter', action='store_true')
63 |         self.parser.add_argument('--stroke_no_couple', action='store_true', help='')
64 |         self.parser.add_argument('--pretrain_path', type=str, default='')
65 |         self.parser.add_argument('--nGT', type=int, default=5)
66 |         self.parser.add_argument('--rot_int_max', type=int, default=3)
67 |         self.parser.add_argument('--jitter_amount', type=float, default=0.02)
68 |         self.parser.add_argument('--inverse_gamma', action='store_true')
69 |         self.parser.add_argument('--img_mean', type=float, nargs='+')
70 |         self.parser.add_argument('--img_std', type=float, nargs='+')
71 |         self.parser.add_argument('--lst_file', type=str)
72 |         self.initialized = True
73 | 
74 |     def parse(self):
75 |         if not self.initialized:
76 |             self.initialize()
77 |         self.opt = self.parser.parse_args()
78 |         self.opt.isTrain = self.isTrain  # train or test
79 | 
80 |         self.opt.use_cuda = not self.opt.no_cuda and torch.cuda.is_available()
81 |         args = vars(self.opt)
82 | 
83 |         print('------------ Options -------------')
84 |         for k, v in sorted(args.items()):
85 |             print('%s: %s' % (str(k), str(v)))
86 |         print('-------------- End ----------------')
87 | 
88 |         # save to the disk
89 |         expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
90 |         mkdirs(expr_dir)
91 |         file_name = os.path.join(expr_dir, 'opt.txt')
92 |         with open(file_name, 'wt') as opt_file:
93 |             opt_file.write('------------ Options -------------\n')
94 |             for k, v in sorted(args.items()):
95 |                 opt_file.write('%s: %s\n' % (str(k), str(v)))
96 |             opt_file.write('-------------- End ----------------\n')
97 |         return self.opt
98 | 


--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/default_config.yaml:
--------------------------------------------------------------------------------
1 | which_model_netG: "resnet_9blocks"
2 | input_nc: 3
3 | output_nc: 1
4 | norm: "batch"
5 | use_dropout: False
6 | n_blocks: 9
7 | ngf: 64


--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/image_pool.py:
--------------------------------------------------------------------------------
 1 | import random
 2 | import numpy as np
 3 | import torch
 4 | 
 5 | 
 6 | class ImagePool():
 7 |     def __init__(self, pool_size):
 8 |         self.pool_size = pool_size
 9 |         if self.pool_size > 0:
10 |             self.num_imgs = 0
11 |             self.images = []
12 | 
13 |     def query(self, images):
14 |         if self.pool_size == 0:
15 |             return images
16 |         return_images = []
17 |         for image in images:
18 |             image = torch.unsqueeze(image, 0)
19 |             if self.num_imgs < self.pool_size:
20 |                 self.num_imgs = self.num_imgs + 1
21 |                 self.images.append(image)
22 |                 return_images.append(image)
23 |             else:
24 |                 p = random.uniform(0, 1)
25 |                 if p > 0.5:
26 |                     random_id = random.randint(0, self.pool_size-1)
27 |                     tmp = self.images[random_id].clone()
28 |                     self.images[random_id] = image
29 |                     return_images.append(tmp)
30 |                 else:
31 |                     return_images.append(image)
32 |         return_images = torch.cat(return_images, 0)
33 |         return return_images
34 | 
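A short usage sketch of the pool above: in GAN training, it is queried with freshly generated fakes and returns a mix of current and historical fakes for the discriminator. The tensors here are stand-ins for generator output.

```
import torch

from sketch_nn.photo2sketch.PhotoSketching.image_pool import ImagePool

pool = ImagePool(pool_size=50)

for step in range(3):
    fake_images = torch.randn(4, 3, 64, 64)   # stand-in for generator output
    fakes_for_D = pool.query(fake_images)     # same batch size, possibly older samples
    print(step, fakes_for_D.shape)            # torch.Size([4, 3, 64, 64])
```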


--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/test_options.py:
--------------------------------------------------------------------------------
 1 | from .base_options import BaseOptions
 2 | 
 3 | 
 4 | class TestOptions(BaseOptions):
 5 |     def initialize(self):
 6 |         BaseOptions.initialize(self)
 7 |         self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
 8 |         self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
 9 |         self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
10 |         self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
11 |         self.parser.add_argument('--which_epoch', type=str, default='latest',
12 |                                  help='which epoch to load? set to latest to use latest cached model')
13 |         self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
14 |         self.parser.add_argument('--file_name', type=str, default='')
15 |         self.parser.add_argument('--suffix', type=str, default='')
16 |         self.isTrain = False
17 | 


--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/util.py:
--------------------------------------------------------------------------------
 1 | from __future__ import print_function
 2 | import torch
 3 | import numpy as np
 4 | from PIL import Image
 5 | import inspect
 6 | import re
 7 | import numpy as np
 8 | import os
 9 | import collections
10 | 
11 | # Converts a Tensor into a Numpy array
12 | # |imtype|: the desired type of the converted numpy array
13 | def tensor2im(image_tensor, imtype=np.uint8):
14 |     image_numpy = image_tensor[0].cpu().float().numpy()
15 |     if image_numpy.shape[0] == 1:
16 |         image_numpy = np.tile(image_numpy, (3, 1, 1))
17 |     image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
18 |     return image_numpy.astype(imtype)
19 | 
20 | def tensor2im2(image_tensor, imtype=np.uint8):
21 |     image_numpy = image_tensor.detach().cpu().float().numpy()
22 |     image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
23 |     return image_numpy.astype(imtype)
24 | 
25 | def tensor2im3(image_tensor, imtype=np.uint8):
26 |     image_numpy = 1.0 - image_tensor.detach().cpu().float().numpy()
27 |     if image_numpy.shape[0] == 1:
28 |         image_numpy = np.tile(image_numpy, (3, 1, 1))
29 |     image_numpy = 255.0 * np.transpose(image_numpy, (1, 2, 0))
30 |     return image_numpy.astype(imtype)
31 | 
32 | def tensor2im4(image_tensor, img_mean, img_std, imtype=np.uint8):
33 |     image_numpy = image_tensor.detach().cpu().float().numpy()
34 |     n_channel = len(img_mean)
35 |     for c in range(n_channel):
36 |         image_numpy[c, :, :] = image_numpy[c, :, :]*img_std[c] + img_mean[c]
37 |     if image_numpy.shape[0] == 1:
38 |         image_numpy = np.tile(image_numpy, (3, 1, 1))
39 |     image_numpy = 255.0 * np.transpose(image_numpy, (1, 2, 0))
40 |     return image_numpy.astype(imtype)
41 | 
42 | def diagnose_network(net, name='network'):
43 |     mean = 0.0
44 |     count = 0
45 |     for param in net.parameters():
46 |         if param.grad is not None:
47 |             mean += torch.mean(torch.abs(param.grad.detach()))
48 |             count += 1
49 |     if count > 0:
50 |         mean = mean / count
51 |     print(name)
52 |     print(mean)
53 | 
54 | 
55 | def save_image(image_numpy, image_path):
56 |     image_pil = Image.fromarray(image_numpy)
57 |     image_pil.save(image_path)
58 | 
59 | 
60 | def print_numpy(x, val=True, shp=False):
61 |     x = x.astype(np.float64)
62 |     if shp:
63 |         print('shape,', x.shape)
64 |     if val:
65 |         x = x.flatten()
66 |         print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
67 |             np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
68 | 
69 | 
70 | def mkdirs(paths):
71 |     if isinstance(paths, list) and not isinstance(paths, str):
72 |         for path in paths:
73 |             mkdir(path)
74 |     else:
75 |         mkdir(paths)
76 | 
77 | 
78 | def mkdir(path):
79 |     if not os.path.exists(path):
80 |         os.makedirs(path)
81 | 


--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | 
 6 | import os
 7 | from argparse import Namespace
 8 | from typing import Union, Dict
 9 | from functools import lru_cache
10 | from omegaconf import OmegaConf
11 | 
12 | __all__ = ["photo2sketch_model_build_util", "photo2sketch_available_models"]
13 | 
14 | _METHODS = ["PhotoSketching", "InformativeDrawings"]
15 | 
16 | 
17 | def photo2sketch_available_models():
18 |     return _METHODS
19 | 
20 | 
21 | @lru_cache()
22 | def default_config_path(dir_name: str) -> str:
23 |     return os.path.join(os.path.dirname(os.path.abspath(__file__)), dir_name, "default_config.yaml")
24 | 
25 | 
26 | def photo2sketch_model_build_util(
27 |         method: str = "PhotoSketching",
28 |         model_config: Union[Namespace, Dict] = None
29 | ):
30 |     assert method in _METHODS, f"Model {method} not recognized."
31 | 
32 |     if model_config is None:  # load default configuration
33 |         config_path = default_config_path(method)
34 |         model_config = OmegaConf.load(config_path)
35 | 
36 |     if method == "PhotoSketching":
37 |         from .PhotoSketching.networks import ResnetGenerator, get_norm_layer
38 |         norm_layer = get_norm_layer(norm_type=model_config.norm)
39 |         model = ResnetGenerator(model_config.input_nc, model_config.output_nc,
40 |                                 model_config.ngf, norm_layer, model_config.use_dropout,
41 |                                 model_config.n_blocks)
42 |         return model
43 |     elif method == "InformativeDrawings":
44 |         from .InformativeDrawings.model import Generator
45 |         model = Generator(model_config.input_nc, model_config.output_nc, model_config.n_blocks)
46 |         return model
47 |     else:
48 |         raise ModuleNotFoundError("Model [%s] not recognized." % method)
49 | 
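A sketch of how the builder above might be used for inference. The checkpoint path is a placeholder, the input tensor is random for illustration, and the forward signature of the underlying generator is assumed to take a single image batch.

```
import torch

from sketch_nn.photo2sketch import (
    photo2sketch_available_models,
    photo2sketch_model_build_util,
)

print(photo2sketch_available_models())  # ['PhotoSketching', 'InformativeDrawings']

# build PhotoSketching with its bundled default_config.yaml
model = photo2sketch_model_build_util(method="PhotoSketching")
# model.load_state_dict(torch.load('photosketch.pth', map_location='cpu'))  # placeholder checkpoint
model.eval()

with torch.no_grad():
    photo = torch.randn(1, 3, 256, 256)  # stand-in for a preprocessed RGB photo
    sketch = model(photo)                # 1-channel output per default_config.yaml
print(sketch.shape)
```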


--------------------------------------------------------------------------------
/sketch_nn/rasterize/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 
6 | from .rasterize import sketch_vector_rasterize


--------------------------------------------------------------------------------
/sketch_nn/rasterize/bresenham.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description: Implementation of Bresenham's line drawing algorithm.
 5 | #              See en.wikipedia.org/wiki/Bresenham's_line_algorithm
 6 | 
 7 | 
 8 | def bresenham_algo(x0, y0, x1, y1):
 9 |     """
10 |     Yield integer coordinates on the line from (x0, y0) to (x1, y1).
11 |     Input coordinates should be integers.
12 | 
13 |     Examples:
14 |     >>> from sketch_nn.rasterize.bresenham import bresenham_algo
15 |     >>> list(bresenham_algo(-1, -4, 3, 2))
16 |     [(-1, -4), (0, -3), (0, -2), (1, -1), (2, 0), (2, 1), (3, 2)]
17 | 
18 |     Args:
19 |         x0: x-coordinate of the start point
20 |         y0: y-coordinate of the start point
21 |         x1: x-coordinate of the end point
22 |         y1: y-coordinate of the end point
23 | 
24 |     Returns:
25 |         A generator of (x, y) tuples; the result contains both the start and the end point.
26 |     """
27 |     dx = x1 - x0
28 |     dy = y1 - y0
29 | 
30 |     xsign = 1 if dx > 0 else -1
31 |     ysign = 1 if dy > 0 else -1
32 | 
33 |     dx = abs(dx)
34 |     dy = abs(dy)
35 | 
36 |     if dx > dy:
37 |         xx, xy, yx, yy = xsign, 0, 0, ysign
38 |     else:
39 |         dx, dy = dy, dx
40 |         xx, xy, yx, yy = 0, ysign, xsign, 0
41 | 
42 |     D = 2 * dy - dx
43 |     y = 0
44 | 
45 |     for x in range(dx + 1):
46 |         yield x0 + x * xx + y * yx, y0 + x * xy + y * yy
47 |         if D >= 0:
48 |             y += 1
49 |             D -= 2 * dx
50 |         D += 2 * dy
51 | 
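A small sketch of rasterizing one line segment onto a toy canvas with the generator above.

```
import numpy as np

from sketch_nn.rasterize.bresenham import bresenham_algo

# rasterize a single line segment onto an 8x8 canvas
canvas = np.zeros((8, 8), dtype=np.uint8)
for x, y in bresenham_algo(0, 0, 7, 3):
    canvas[y, x] = 1
print(canvas)
```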


--------------------------------------------------------------------------------
/sketch_nn/rasterize/rasterize.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | # Copyright (c) XiMing Xing. All rights reserved.
  3 | # Author: XiMing Xing
  4 | # Description:
  5 | 
  6 | import numpy as np
  7 | import scipy.ndimage
  8 | 
  9 | from .bresenham import bresenham_algo
 10 | 
 11 | 
 12 | def get_stroke_num(vector_image):
 13 |     return len(np.split(vector_image[:, :2], np.where(vector_image[:, 2])[0] + 1, axis=0)[:-1])
 14 | 
 15 | 
 16 | def select_strokes(vector_image, strokes):
 17 |     """
 18 |     Keep only the selected strokes of a vector sketch.
 19 |     Args:
 20 |         vector_image: (x, y, p) coordinate array; p == 1 marks the last point of a stroke
 21 |         strokes: indices of the strokes to keep
 22 | 
 23 |     Returns:
 24 |         An (x, y, p) array containing only the selected strokes.
 25 |     """
 26 |     c = vector_image
 27 |     c_split = np.split(c[:, :2], np.where(c[:, 2])[0] + 1, axis=0)[:-1]
 28 | 
 29 |     c_selected = []
 30 |     for i in strokes:
 31 |         c_selected.append(c_split[i])
 32 | 
 33 |     xyp = []
 34 |     for i in c_selected:
 35 |         p = np.zeros((len(i), 1))
 36 |         p[-1] = 1
 37 |         xyp.append(np.hstack((i, p)))
 38 |     xyp = np.concatenate(xyp)
 39 |     return xyp
 40 | 
 41 | 
 42 | def batch_points2png(vector_images, Side=256):
 43 |     for vector_image in vector_images:
 44 |         pixel_length = 0
 45 |         # number_of_samples = random
 46 |         sample_freq = list(np.round(np.linspace(0, len(vector_image), 18)[1:]))
 47 |         Sample_len = []
 48 |         raster_images = []
 49 |         raster_image = np.zeros((int(Side), int(Side)), dtype=np.float32)
 50 |         initX, initY = int(vector_image[0, 0]), int(vector_image[0, 1])
 51 |         for i in range(0, len(vector_image)):
 52 |             if i > 0:
 53 |                 if vector_image[i - 1, 2] == 1:
 54 |                     initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1])
 55 | 
 56 |             cordList = list(bresenham_algo(initX, initY, int(vector_image[i, 0]), int(vector_image[i, 1])))
 57 |             pixel_length += len(cordList)
 58 | 
 59 |             for cord in cordList:
 60 |                 if (cord[0] > 0 and cord[1] > 0) and (cord[0] < Side and cord[1] < Side):
 61 |                     raster_image[cord[1], cord[0]] = 255.0
 62 |             initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1])
 63 | 
 64 |             if i in sample_freq:
 65 |                 raster_images.append(scipy.ndimage.binary_dilation(raster_image, iterations=2) * 255.0)
 66 |                 Sample_len.append(pixel_length)
 67 | 
 68 |         raster_images.append(scipy.ndimage.binary_dilation(raster_image, iterations=3) * 255.0)
 69 |         Sample_len.append(pixel_length)
 70 | 
 71 |     return raster_images
 72 | 
 73 | 
 74 | def points2png(vector_image, Side=256):
 75 |     raster_image = np.zeros((int(Side), int(Side)), dtype=np.float32)
 76 |     initX, initY = int(vector_image[0, 0]), int(vector_image[0, 1])
 77 |     pixel_length = 0
 78 | 
 79 |     for i in range(0, len(vector_image)):
 80 |         if i > 0:
 81 |             if vector_image[i - 1, 2] == 1:
 82 |                 initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1])
 83 | 
 84 |         cordList = list(bresenham_algo(initX, initY, int(vector_image[i, 0]), int(vector_image[i, 1])))
 85 |         pixel_length += len(cordList)
 86 | 
 87 |         for cord in cordList:
 88 |             if (cord[0] > 0 and cord[1] > 0) and (cord[0] < Side and cord[1] < Side):
 89 |                 raster_image[cord[1], cord[0]] = 255.0
 90 |         initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1])
 91 | 
 92 |     raster_image = scipy.ndimage.binary_dilation(raster_image) * 255.0
 93 |     return raster_image
 94 | 
 95 | 
 96 | def preprocess(sketch_points, side=256.0):
 97 |     sketch_points = sketch_points.astype(np.float64)
 98 |     sketch_points[:, :2] = sketch_points[:, :2] / np.array([256, 256])
 99 |     sketch_points[:, :2] = sketch_points[:, :2] * side
100 |     sketch_points = np.round(sketch_points)
101 |     return sketch_points
102 | 
103 | 
104 | def sketch_vector_rasterize(sketch_points):
105 |     sketch_points = preprocess(sketch_points)
106 |     raster_images = points2png(sketch_points)
107 |     return raster_images
108 | 
109 | 
110 | def convert_to_red(image):
111 |     l = image.shape[1]
112 |     image[1] = np.zeros((l, l))
113 |     image[2] = np.zeros((l, l))
114 |     return image
115 | 
116 | 
117 | def convert_to_green(image):
118 |     l = image.shape[1]
119 |     image[0] = np.zeros((l, l))
120 |     image[2] = np.zeros((l, l))
121 |     return image
122 | 
123 | 
124 | def convert_to_blue(image):
125 |     l = image.shape[1]
126 |     image[0] = np.zeros((l, l))
127 |     image[1] = np.zeros((l, l))
128 |     return image
129 | 
130 | 
131 | def convert_to_black(image):
132 |     l = image.shape[1]
133 |     image[0] = np.zeros((l, l))
134 |     image[1] = np.zeros((l, l))
135 |     image[2] = np.zeros((l, l))
136 |     return image
137 | 
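A minimal sketch of rasterizing a toy vector sketch with the helpers above. The two strokes below are made up for illustration; real inputs are (x, y, pen_state) arrays with coordinates in roughly [0, 256).

```
import numpy as np

from sketch_nn.rasterize.rasterize import get_stroke_num, sketch_vector_rasterize

# two toy strokes in (x, y, p) format, where p == 1 marks the last point of a stroke
sketch_points = np.array([
    [10, 10, 0],
    [120, 40, 0],
    [200, 200, 1],   # end of stroke 1
    [30, 220, 0],
    [220, 30, 1],    # end of stroke 2
], dtype=np.float64)

print(get_stroke_num(sketch_points))             # 2
raster = sketch_vector_rasterize(sketch_points)  # 256x256 array with values {0, 255}
print(raster.shape, raster.max())
```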


--------------------------------------------------------------------------------
/style_transfer/AdaIN/README.md:
--------------------------------------------------------------------------------
 1 | # pytorch-AdaIN
 2 | 
 3 | This is an unofficial PyTorch implementation of the paper, Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization [Huang+, ICCV2017].
 4 | I'm really grateful to the [original implementation](https://github.com/xunhuang1995/AdaIN-style) in Torch by the authors, which is very useful.
 5 | 
 6 | ![Results](results.png)
 7 | 
 8 | ## Requirements
 9 | Please install requirements by `pip install -r requirements.txt`
10 | 
11 | - Python 3.5+
12 | - PyTorch 0.4+
13 | - TorchVision
14 | - Pillow
15 | 
16 | (optional, for training)
17 | - tqdm
18 | - TensorboardX
19 | 
20 | ## Usage
21 | 
22 | ### Download models
23 | Download [decoder.pth](https://drive.google.com/file/d/1bMfhMMwPeXnYSQI6cDWElSZxOxc6aVyr/view?usp=sharing)/[vgg_normalized.pth](https://drive.google.com/file/d/1EpkBA2K2eYILDSyPTt0fztz59UjAIpZU/view?usp=sharing) and put them under `models/`.
24 | 
25 | ### Test
26 | Use `--content` and `--style` to provide the respective path to the content and style image.
27 | ```
28 | CUDA_VISIBLE_DEVICES= python test.py --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg
29 | ```
30 | 
31 | You can also run the code on directories of content and style images using `--content_dir` and `--style_dir`. It will save every possible combination of content and styles to the output directory.
32 | ```
33 | CUDA_VISIBLE_DEVICES= python test.py --content_dir input/content --style_dir input/style
34 | ```
35 | 
36 | This is an example of mixing four styles by specifying `--style` and `--style_interpolation_weights` option.
37 | ```
38 | CUDA_VISIBLE_DEVICES= python test.py --content input/content/avril.jpg --style input/style/picasso_self_portrait.jpg,input/style/impronte_d_artista.jpg,input/style/trial.jpg,input/style/antimonocromatismo.jpg --style_interpolation_weights 1,1,1,1 --content_size 512 --style_size 512 --crop
39 | ```
40 | 
41 | Some other options:
42 | * `--content_size`: New (minimum) size for the content image. Keeping the original size if set to 0.
43 | * `--style_size`: New (minimum) size for the style image. Keeping the original size if set to 0.
44 | * `--alpha`: Adjust the degree of stylization. It should be a value between 0.0 and 1.0 (default).
45 | * `--preserve_color`: Preserve the color of the content image.
46 | 
47 | 
48 | ### Train
49 | Use `--content_dir` and `--style_dir` to provide the respective directory to the content and style images.
50 | ```
51 | CUDA_VISIBLE_DEVICES= python train.py --content_dir  --style_dir 
52 | ```
53 | 
54 | For more details and parameters, please refer to --help option.
55 | 
56 | I share the model trained by this code [here](https://drive.google.com/file/d/1YIBRdgGBoVllLhmz_N7PwfeP5V9Vz2Nr/view?usp=sharing)
57 | 
58 | ## References
59 | - [1]: X. Huang and S. Belongie. "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization.", in ICCV, 2017.
60 | - [2]: [Original implementation in Torch](https://github.com/xunhuang1995/AdaIN-style)
61 | 


--------------------------------------------------------------------------------
/style_transfer/AdaIN/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # Copyright (c) XiMing Xing. All rights reserved.
 3 | # Author: XiMing Xing
 4 | # Description:
 5 | # URL: https://github.com/naoto0804/pytorch-AdaIN
 6 | 
 7 | from .function import coral, adaptive_instance_normalization
 8 | 
 9 | __all__ = ['coral', 'adaptive_instance_normalization']
10 | 


--------------------------------------------------------------------------------
/style_transfer/AdaIN/function.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | def calc_mean_std(feat, eps=1e-5):
 5 |     # eps is a small value added to the variance to avoid divide-by-zero.
 6 |     size = feat.size()
 7 |     assert (len(size) == 4)
 8 |     N, C = size[:2]
 9 |     feat_var = feat.view(N, C, -1).var(dim=2) + eps
10 |     feat_std = feat_var.sqrt().view(N, C, 1, 1)
11 |     feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
12 |     return feat_mean, feat_std
13 | 
14 | 
15 | def adaptive_instance_normalization(content_feat, style_feat):
16 |     assert (content_feat.size()[:2] == style_feat.size()[:2])
17 |     size = content_feat.size()
18 |     style_mean, style_std = calc_mean_std(style_feat)
19 |     content_mean, content_std = calc_mean_std(content_feat)
20 | 
21 |     normalized_feat = (content_feat - content_mean.expand(
22 |         size)) / content_std.expand(size)
23 |     return normalized_feat * style_std.expand(size) + style_mean.expand(size)
24 | 
25 | 
26 | def _calc_feat_flatten_mean_std(feat):
27 |     # takes 3D feat (C, H, W), return mean and std of array within channels
28 |     assert (feat.size()[0] == 3)
29 |     assert (isinstance(feat, torch.FloatTensor))
30 |     feat_flatten = feat.view(3, -1)
31 |     mean = feat_flatten.mean(dim=-1, keepdim=True)
32 |     std = feat_flatten.std(dim=-1, keepdim=True)
33 |     return feat_flatten, mean, std
34 | 
35 | 
36 | def _mat_sqrt(x):
37 |     U, D, V = torch.svd(x)
38 |     return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
39 | 
40 | 
41 | def coral(source, target):
42 |     # assume both source and target are 3D array (C, H, W)
43 |     # Note: flatten -> f
44 | 
45 |     source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
46 |     source_f_norm = (source_f - source_f_mean.expand_as(
47 |         source_f)) / source_f_std.expand_as(source_f)
48 |     source_f_cov_eye = \
49 |         torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)
50 | 
51 |     target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
52 |     target_f_norm = (target_f - target_f_mean.expand_as(
53 |         target_f)) / target_f_std.expand_as(target_f)
54 |     target_f_cov_eye = \
55 |         torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)
56 | 
57 |     source_f_norm_transfer = torch.mm(
58 |         _mat_sqrt(target_f_cov_eye),
59 |         torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
60 |                  source_f_norm)
61 |     )
62 | 
63 |     source_f_transfer = source_f_norm_transfer * \
64 |                         target_f_std.expand_as(source_f_norm) + \
65 |                         target_f_mean.expand_as(source_f_norm)
66 | 
67 |     return source_f_transfer.view(source.size())
68 | 
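A small sketch showing `adaptive_instance_normalization` on random feature maps; the shapes are arbitrary here, whereas in the full pipeline the features come from the relu4_1 layer of the VGG encoder in `net.py`.

```
import torch

from style_transfer.AdaIN import adaptive_instance_normalization

# pretend these are relu4_1 feature maps of a content and a style image
content_feat = torch.randn(1, 512, 32, 32)
style_feat = torch.randn(1, 512, 32, 32)

# per-channel mean/std of the result match the style features
t = adaptive_instance_normalization(content_feat, style_feat)
print(t.shape)  # torch.Size([1, 512, 32, 32])
```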


--------------------------------------------------------------------------------
/style_transfer/AdaIN/net.py:
--------------------------------------------------------------------------------
  1 | import torch.nn as nn
  2 | 
  3 | from .function import calc_mean_std, adaptive_instance_normalization as adain
  4 | 
  5 | decoder = nn.Sequential(
  6 |     nn.ReflectionPad2d((1, 1, 1, 1)),
  7 |     nn.Conv2d(512, 256, (3, 3)),
  8 |     nn.ReLU(),
  9 |     nn.Upsample(scale_factor=2, mode='nearest'),
 10 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 11 |     nn.Conv2d(256, 256, (3, 3)),
 12 |     nn.ReLU(),
 13 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 14 |     nn.Conv2d(256, 256, (3, 3)),
 15 |     nn.ReLU(),
 16 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 17 |     nn.Conv2d(256, 256, (3, 3)),
 18 |     nn.ReLU(),
 19 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 20 |     nn.Conv2d(256, 128, (3, 3)),
 21 |     nn.ReLU(),
 22 |     nn.Upsample(scale_factor=2, mode='nearest'),
 23 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 24 |     nn.Conv2d(128, 128, (3, 3)),
 25 |     nn.ReLU(),
 26 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 27 |     nn.Conv2d(128, 64, (3, 3)),
 28 |     nn.ReLU(),
 29 |     nn.Upsample(scale_factor=2, mode='nearest'),
 30 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 31 |     nn.Conv2d(64, 64, (3, 3)),
 32 |     nn.ReLU(),
 33 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 34 |     nn.Conv2d(64, 3, (3, 3)),
 35 | )
 36 | 
 37 | vgg = nn.Sequential(
 38 |     nn.Conv2d(3, 3, (1, 1)),
 39 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 40 |     nn.Conv2d(3, 64, (3, 3)),
 41 |     nn.ReLU(),  # relu1-1
 42 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 43 |     nn.Conv2d(64, 64, (3, 3)),
 44 |     nn.ReLU(),  # relu1-2
 45 |     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
 46 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 47 |     nn.Conv2d(64, 128, (3, 3)),
 48 |     nn.ReLU(),  # relu2-1
 49 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 50 |     nn.Conv2d(128, 128, (3, 3)),
 51 |     nn.ReLU(),  # relu2-2
 52 |     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
 53 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 54 |     nn.Conv2d(128, 256, (3, 3)),
 55 |     nn.ReLU(),  # relu3-1
 56 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 57 |     nn.Conv2d(256, 256, (3, 3)),
 58 |     nn.ReLU(),  # relu3-2
 59 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 60 |     nn.Conv2d(256, 256, (3, 3)),
 61 |     nn.ReLU(),  # relu3-3
 62 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 63 |     nn.Conv2d(256, 256, (3, 3)),
 64 |     nn.ReLU(),  # relu3-4
 65 |     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
 66 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 67 |     nn.Conv2d(256, 512, (3, 3)),
 68 |     nn.ReLU(),  # relu4-1, this is the last layer used
 69 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 70 |     nn.Conv2d(512, 512, (3, 3)),
 71 |     nn.ReLU(),  # relu4-2
 72 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 73 |     nn.Conv2d(512, 512, (3, 3)),
 74 |     nn.ReLU(),  # relu4-3
 75 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 76 |     nn.Conv2d(512, 512, (3, 3)),
 77 |     nn.ReLU(),  # relu4-4
 78 |     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
 79 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 80 |     nn.Conv2d(512, 512, (3, 3)),
 81 |     nn.ReLU(),  # relu5-1
 82 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 83 |     nn.Conv2d(512, 512, (3, 3)),
 84 |     nn.ReLU(),  # relu5-2
 85 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 86 |     nn.Conv2d(512, 512, (3, 3)),
 87 |     nn.ReLU(),  # relu5-3
 88 |     nn.ReflectionPad2d((1, 1, 1, 1)),
 89 |     nn.Conv2d(512, 512, (3, 3)),
 90 |     nn.ReLU()  # relu5-4
 91 | )
 92 | 
 93 | 
 94 | class Net(nn.Module):
 95 |     def __init__(self, encoder, decoder):
 96 |         super(Net, self).__init__()
 97 |         enc_layers = list(encoder.children())
 98 |         self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
 99 |         self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
100 |         self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
101 |         self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
102 |         self.decoder = decoder
103 |         self.mse_loss = nn.MSELoss()
104 | 
105 |         # fix the encoder
106 |         for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
107 |             for param in getattr(self, name).parameters():
108 |                 param.requires_grad = False
109 | 
110 |     # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
111 |     def encode_with_intermediate(self, input):
112 |         results = [input]
113 |         for i in range(4):
114 |             func = getattr(self, 'enc_{:d}'.format(i + 1))
115 |             results.append(func(results[-1]))
116 |         return results[1:]
117 | 
118 |     # extract relu4_1 from input image
119 |     def encode(self, input):
120 |         for i in range(4):
121 |             input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
122 |         return input
123 | 
124 |     def calc_content_loss(self, input, target):
125 |         assert (input.size() == target.size())
126 |         assert (target.requires_grad is False)
127 |         return self.mse_loss(input, target)
128 | 
129 |     def calc_style_loss(self, input, target):
130 |         assert (input.size() == target.size())
131 |         assert (target.requires_grad is False)
132 |         input_mean, input_std = calc_mean_std(input)
133 |         target_mean, target_std = calc_mean_std(target)
134 |         return self.mse_loss(input_mean, target_mean) + \
135 |                self.mse_loss(input_std, target_std)
136 | 
137 |     def forward(self, content, style, alpha=1.0):
138 |         assert 0 <= alpha <= 1
139 |         style_feats = self.encode_with_intermediate(style)
140 |         content_feat = self.encode(content)
141 |         t = adain(content_feat, style_feats[-1])
142 |         t = alpha * t + (1 - alpha) * content_feat
143 | 
144 |         g_t = self.decoder(t)
145 |         g_t_feats = self.encode_with_intermediate(g_t)
146 | 
147 |         loss_c = self.calc_content_loss(g_t_feats[-1], t)
148 |         loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
149 |         for i in range(1, 4):
150 |             loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
151 |         return loss_c, loss_s
152 | 
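`Net.forward` returns the training losses; for inference, stylization follows the same encode-AdaIN-decode path. A hedged sketch, assuming the pretrained `vgg_normalized.pth` and `decoder.pth` checkpoints linked in the README (the paths and input tensors below are placeholders):

```
import torch
import torch.nn as nn

from style_transfer.AdaIN import net
from style_transfer.AdaIN.function import adaptive_instance_normalization

vgg, decoder = net.vgg, net.decoder
vgg.load_state_dict(torch.load('models/vgg_normalized.pth', map_location='cpu'))
decoder.load_state_dict(torch.load('models/decoder.pth', map_location='cpu'))

# keep layers up to relu4_1, matching Net.enc_1 .. Net.enc_4
encoder = nn.Sequential(*list(vgg.children())[:31])
encoder.eval()
decoder.eval()

with torch.no_grad():
    content = torch.randn(1, 3, 256, 256)  # stand-ins for preprocessed images
    style = torch.randn(1, 3, 256, 256)
    alpha = 0.8                            # degree of stylization
    f_c, f_s = encoder(content), encoder(style)
    t = adaptive_instance_normalization(f_c, f_s)
    t = alpha * t + (1 - alpha) * f_c
    stylized = decoder(t)
print(stylized.shape)  # torch.Size([1, 3, 256, 256])
```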


--------------------------------------------------------------------------------
/style_transfer/STROTSS/README.md:
--------------------------------------------------------------------------------
 1 | # PyTorch implementation of Style Transfer by Relaxed Optimal Transport and Self-Similarity (STROTSS) with improvements
 2 | 
 3 | Implements [STROTSS](https://arxiv.org/abs/1904.12785) with sinkhorn EMD as introduced in the paper [Interactive Neural Style Transfer with artists](https://arxiv.org/pdf/2003.06659).
 4 | 
 5 | This code is inspired by [the original implementation](https://github.com/nkolkin13/STROTSS) released by the authors of STROTSS.
 6 | 
 7 | 
 8 | ## Dependencies:
 9 | * python3 >= 3.6
10 | * pytorch >= 1.0
11 | * torchvision >= 0.4
12 | * imageio >= 2.2
13 | * numpy >= 1.1
14 | 
15 | ## Usage:
16 | 
17 |   * standard
18 |     ```
19 |     python test.py -c images/content_im.jpg -s images/style_im.jpg
20 |     ```
21 |   * sinkhorn earth movers distance
22 |     ```
23 |     python test.py -c images/content_im.jpg -s images/style_im.jpg --use_sinkhorn
24 |     ```
25 |   * guidance masks
26 |     ```
27 |     python test.py -c images/content_im.jpg -s images/style_im.jpg --content_guidance images/content_guidance.jpg --style_guidance images/style_guidance
28 |     ```
29 | General usage
30 | ```
31 | python test.py
32 |     --content CONTENT
33 |     --style STYLE
34 |     [--output OUTPUT]
35 |     [--content_weight CONTENT_WEIGHT]
36 |     [--max_scale MAX_SCALE]
37 |     [--seed SEED]
38 |     [--content_guidance CONTENT_GUIDANCE]
39 |     [--style_guidance STYLE_GUIDANCE]
40 |     [--print_freq PRINT_FREQ]
41 |     [--use_sinkhorn]
42 |     [--sinkhorn_reg SINKHORN_REG]
43 |     [--sinkhorn_maxiter SINKHORN_MAXITER]
44 | ```
45 | 
46 | ## Citation
47 | 
48 | If you use this code, please cite [the original STROTSS paper](https://arxiv.org/abs/1904.12785) and
49 | ```
50 | @article{kerdreux2020interactive,
51 |   title={Interactive Neural Style Transfer with Artists},
52 |   author={Kerdreux, Thomas and Thiry, Louis and Kerdreux, Erwan},
53 |   journal={arXiv preprint arXiv:2003.06659},
54 |   year={2020}
55 | }
56 | ```
57 | 


--------------------------------------------------------------------------------
/style_transfer/STROTSS/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 


--------------------------------------------------------------------------------
/style_transfer/STROTSS/style_transfer.py:
--------------------------------------------------------------------------------
  1 | import glob
  2 | import time
  3 | 
  4 | import imageio
  5 | import torch
  6 | 
  7 | from . import utils
  8 | from . import vgg_pt
  9 | from . import loss_utils
 10 | 
 11 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
 12 | 
 13 | 
 14 | def style_transfer(stylized_im, content_im, style_path, output_path,
 15 |                    long_side, content_weight, content_regions, style_regions,
 16 |                    lr, print_freq=100, max_iter=250,
 17 |                    resample_freq=1, optimize_laplacian_pyramid=True,
 18 |                    use_sinkhorn=False, sinkhorn_reg=0.1, sinkhorn_maxiter=30):
 19 |     cnn = vgg_pt.Vgg16_pt().to(device)
 20 | 
 21 |     phi = lambda x: cnn.forward(x)
 22 |     phi2 = lambda x, y, z: cnn.forward_cat(x, z, samps=y, forward_func=cnn.forward)
 23 | 
 24 |     if optimize_laplacian_pyramid:
 25 |         laplacian_pyramid = utils.create_laplacian_pyramid(stylized_im, pyramid_depth=5)
 26 |         parameters = [torch.nn.Parameter(li.data, requires_grad=True) for li in laplacian_pyramid]
 27 |     else:
 28 |         parameters = [torch.nn.Parameter(stylized_im.data, requires_grad=True)]
 29 | 
 30 |     optimizer = torch.optim.RMSprop(parameters, lr=lr)
 31 | 
 32 |     content_im_cnn_features = cnn(content_im)
 33 | 
 34 |     style_image_paths = glob.glob(style_path + '*')[::3]
 35 | 
 36 |     strotss_loss = loss_utils.RelaxedOptimalTransportSelfSimilarityLoss(
 37 |         use_sinkhorn=use_sinkhorn, sinkhorn_reg=sinkhorn_reg, sinkhorn_maxiter=sinkhorn_maxiter)
 38 | 
 39 |     style_features = []
 40 |     for style_region in style_regions:
 41 |         style_features.append(utils.load_style_features(phi2, style_image_paths, style_region,
 42 |                                                         subsamps=1000, scale=long_side, inner=5))
 43 | 
 44 |     if optimize_laplacian_pyramid:
 45 |         stylized_im = utils.synthetize_image_from_laplacian_pyramid(parameters)
 46 |     else:
 47 |         stylized_im = parameters[0]
 48 | 
 49 |     resized_content_regions = []
 50 |     for content_region in content_regions:
 51 |         resized_content_region = utils.resize(torch.from_numpy(content_region),
 52 |                                               (stylized_im.size(3), stylized_im.size(2)), mode='nearest').numpy()
 53 |         resized_content_regions.append(resized_content_region.astype('bool'))
 54 | 
 55 |     for i in range(max_iter):
 56 |         if i == 200:
 57 |             optimizer = torch.optim.RMSprop(parameters, lr=0.1 * lr)
 58 | 
 59 |         optimizer.zero_grad()
 60 |         if optimize_laplacian_pyramid:
 61 |             stylized_im = utils.synthetize_image_from_laplacian_pyramid(parameters)
 62 |         else:
 63 |             stylized_im = parameters[0]
 64 | 
 65 |         if i == 0 or i % (resample_freq * 10) == 0:
 66 |             for i_region, resized_content_region in enumerate(resized_content_regions):
 67 |                 strotss_loss.init_inds(content_im_cnn_features, style_features[i_region], resized_content_region,
 68 |                                        i_region)
 69 | 
 70 |         if i == 0 or i % resample_freq == 0:
 71 |             strotss_loss.shuffle_feature_inds()
 72 | 
 73 |         stylized_im_cnn_features = cnn(stylized_im)
 74 | 
 75 |         loss = strotss_loss.eval(stylized_im_cnn_features,
 76 |                                  content_im_cnn_features, style_features,
 77 |                                  content_weight=content_weight, moment_weight=1.0)
 78 | 
 79 |         loss.backward()
 80 |         optimizer.step()
 81 | 
 82 |         if i % print_freq == 0:
 83 |             print(f'step {i}/{max_iter}, loss {loss.item():.6f}')
 84 | 
 85 |     return stylized_im, loss
 86 | 
 87 | 
 88 | def run_style_transfer(content_path, style_path, content_weight, max_scale, content_regions, style_regions,
 89 |                        output_path='./output.png', print_freq=100, use_sinkhorn=False, sinkhorn_reg=0.1,
 90 |                        sinkhorn_maxiter=30):
 91 |     smallest_size = 64
 92 |     start = time.time()
 93 | 
 94 |     content_image, style_image = utils.load_img(content_path), utils.load_img(style_path)
 95 |     _, content_H, content_W = content_image.size()
 96 |     _, style_H, style_W = style_image.size()
 97 |     print(f'content image size {content_H}x{content_W}, style image size {style_H}x{style_W}')
 98 | 
 99 |     for scale in range(1, max_scale + 1):
100 |         t0 = time.time()
101 | 
102 |         scaled_size = smallest_size * (2 ** (scale - 1))
103 | 
104 |         print('Processing scale {}/{}, size {}...'.format(scale, max_scale, scaled_size))
105 | 
106 |         content_scaled_size = (int(content_H * scaled_size / content_W), scaled_size) if content_H < content_W else (
107 |             scaled_size, int(content_W * scaled_size / content_H))
108 |         content_image_scaled = utils.resize(content_image.unsqueeze(0), content_scaled_size).to(device)
109 |         bottom_laplacian = content_image_scaled - utils.resize(utils.downsample(content_image_scaled),
110 |                                                                content_scaled_size)
111 | 
112 |         lr = 2e-3
113 |         if scale == 1:
114 |             style_image_mean = style_image.unsqueeze(0).mean(dim=(2, 3), keepdim=True).to(device)
115 |             stylized_im = style_image_mean + bottom_laplacian
116 |         elif scale > 1 and scale < max_scale:
117 |             stylized_im = utils.resize(stylized_im.clone(), content_scaled_size) + bottom_laplacian
118 |         elif scale == max_scale:
119 |             stylized_im = utils.resize(stylized_im.clone(), content_scaled_size)
120 |             lr = 1e-3
121 | 
122 |         stylized_im, final_loss = style_transfer(stylized_im, content_image_scaled, style_path, output_path,
123 |                                                  scaled_size, content_weight, content_regions, style_regions, lr,
124 |                                                  print_freq=print_freq, use_sinkhorn=use_sinkhorn,
125 |                                                  sinkhorn_reg=sinkhorn_reg, sinkhorn_maxiter=sinkhorn_maxiter)
126 | 
127 |         content_weight /= 2.0
128 |         print('...done in {:.1f} sec, final loss {:.4f}'.format(time.time() - t0, final_loss.item()))
129 | 
130 |     print('Finished in {:.1f} secs'.format(time.time() - start))
131 | 
132 |     canvas = torch.clamp(stylized_im[0], -0.5, 0.5).data.cpu().numpy().transpose(1, 2, 0)
133 |     print(f'Saving output to {output_path}.')
134 |     imageio.imwrite(output_path, canvas)
135 | 
136 |     return final_loss, stylized_im
137 | 


--------------------------------------------------------------------------------
/style_transfer/STROTSS/test.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | 
 3 | import imageio
 4 | import numpy as np
 5 | import torch
 6 | 
 7 | from .style_transfer import run_style_transfer
 8 | from .utils import extract_regions
 9 | 
10 | if __name__ == '__main__':
11 |     parser = argparse.ArgumentParser('Style transfer by relaxed optimal transport with sinkhorn distance')
12 |     parser.add_argument('--content', '-c', help="path of content img", required=True)
13 |     parser.add_argument('--style', '-s', help="path of style img", required=True)
14 |     parser.add_argument('--output', '-o', help="path of output img", default='output.png')
15 |     parser.add_argument('--content_weight', type=float, help='relative weight of the content loss', default=0.5)
16 |     parser.add_argument('--max_scale', type=int, help='max scale for the style transfer', default=4)
17 |     parser.add_argument('--seed', type=int, help='random seed', default=0)
18 |     parser.add_argument('--content_guidance', default='', help="path of the content guidance region image")
19 |     parser.add_argument('--style_guidance', default='', help="path of the style guidance region image")
20 |     parser.add_argument('--print_freq', type=int, default=100, help='print frequency for the loss')
21 |     parser.add_argument('--use_sinkhorn', action='store_true', help='use the Sinkhorn algorithm for the earth mover distance')
22 |     parser.add_argument('--sinkhorn_reg', type=float, help='regularization parameter for the Sinkhorn algorithm', default=0.1)
23 |     parser.add_argument('--sinkhorn_maxiter', type=int, default=30, help='number of iterations for the Sinkhorn algorithm')
24 | 
25 |     args = parser.parse_args()
26 | 
27 |     torch.manual_seed(args.seed)
28 |     np.random.seed(args.seed)
29 |     content_weight = 16 * args.content_weight
30 |     max_scale = args.max_scale
31 |     use_guidance_region = args.content_guidance and args.style_guidance
32 | 
33 |     if use_guidance_region:
34 |         content_regions, style_regions = extract_regions(args.content_guidance, args.style_guidance)
35 |     else:
36 |         content_img, style_img = imageio.imread(args.content), imageio.imread(args.style)
37 |         content_regions, style_regions = [np.ones(content_img.shape[:2], dtype=np.float32)], [
38 |             np.ones(style_img.shape[:2], dtype=np.float32)]
39 | 
40 |     loss, canvas = run_style_transfer(args.content, args.style, content_weight,
41 |                                       max_scale, content_regions, style_regions, args.output,
42 |                                       print_freq=args.print_freq, use_sinkhorn=args.use_sinkhorn,
43 |                                       sinkhorn_reg=args.sinkhorn_reg,
44 |                                       sinkhorn_maxiter=args.sinkhorn_maxiter)
45 | 
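Because of the package-relative imports, this script is presumably meant to be launched as a module from the repository root (e.g. via python -m style_transfer.STROTSS.test). The same pipeline can also be driven programmatically, mirroring the no-guidance branch above; a minimal sketch, where the image paths are placeholders and the keyword defaults match the argparse values:

import imageio
import numpy as np

from style_transfer.STROTSS.style_transfer import run_style_transfer

content_path, style_path = 'content.jpg', 'style.jpg'    # placeholder paths
content_img = imageio.imread(content_path)
style_img = imageio.imread(style_path)

# one region per image covering everything, as in the no-guidance branch above
content_regions = [np.ones(content_img.shape[:2], dtype=np.float32)]
style_regions = [np.ones(style_img.shape[:2], dtype=np.float32)]

loss, stylized = run_style_transfer(content_path, style_path, 16 * 0.5,   # same 16x scaling as the script
                                    4,                                    # max_scale
                                    content_regions, style_regions, 'output.png',
                                    print_freq=100, use_sinkhorn=False,
                                    sinkhorn_reg=0.1, sinkhorn_maxiter=30)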


--------------------------------------------------------------------------------
/style_transfer/STROTSS/utils.py:
--------------------------------------------------------------------------------
  1 | from PIL import Image
  2 | 
  3 | import imageio
  4 | import numpy as np
  5 | import torch
  6 | import torch.nn.functional as F
  7 | import torchvision
  8 | 
  9 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
 10 | 
 11 | 
 12 | def downsample(img, factor=2, mode='bilinear'):
 13 |     img_H, img_W = img.size(2), img.size(3)
 14 |     return F.interpolate(img, (max(img_H // factor, 1), max(img_W // factor, 1)), mode=mode)
 15 | 
 16 | 
 17 | def resize(img, size, mode='bilinear'):
 18 |     if len(img.shape) == 2:
 19 |         return F.interpolate(img.unsqueeze(0).unsqueeze(0), size, mode=mode)[0, 0]
 20 |     elif len(img.shape) == 3:
 21 |         return F.interpolate(img.unsqueeze(0), size, mode=mode)[0]
 22 |     return F.interpolate(img, size, mode=mode)
 23 | 
 24 | 
 25 | def load_img(img_path, size=None):
 26 |     img = torchvision.transforms.functional.to_tensor(Image.open(img_path).convert('RGB')) - 0.5
 27 |     if size is None:
 28 |         return img
 29 |     elif isinstance(size, (int, float)):
 30 |         return F.interpolate(img.unsqueeze(0), scale_factor=size / img.size(1), mode='bilinear')[0]
 31 |     else:
 32 |         return F.interpolate(img.unsqueeze(0), size, mode='bilinear')[0]
 33 | 
 34 | 
 35 | def create_laplacian_pyramid(image, pyramid_depth):
 36 |     laplacian_pyramid = []
 37 |     current_image = image
 38 |     for i in range(pyramid_depth):
 39 |         laplacian_pyramid.append(current_image - resize(downsample(current_image), current_image.shape[2:4]))
 40 |         current_image = downsample(current_image)
 41 |     laplacian_pyramid.append(current_image)
 42 | 
 43 |     return laplacian_pyramid
 44 | 
 45 | 
 46 | def synthetize_image_from_laplacian_pyramid(laplacian_pyramid):
 47 |     current_image = laplacian_pyramid[-1]
 48 |     for i in range(len(laplacian_pyramid) - 2, -1, -1):
 49 |         up_x = laplacian_pyramid[i].size(2)
 50 |         up_y = laplacian_pyramid[i].size(3)
 51 |         current_image = laplacian_pyramid[i] + resize(current_image, (up_x, up_y))
 52 | 
 53 |     return current_image
 54 | 
 55 | 
 56 | YUV_transform = torch.from_numpy(np.float32([  # approximately orthonormal colour transform (not the standard ITU-R YUV matrix)
 57 |     [0.577350, 0.577350, 0.577350],
 58 |     [-0.577350, 0.788675, -0.211325],
 59 |     [-0.577350, -0.211325, 0.788675]
 60 | ])).to(device)
 61 | 
 62 | 
 63 | def rgb_to_yuv(rgb):
 64 |     global YUV_transform
 65 |     return torch.mm(YUV_transform, rgb)
 66 | 
 67 | 
 68 | def extract_regions(content_path, style_path, min_count=10000):
 69 |     style_guidance_img = imageio.imread(style_path).transpose(1, 0, 2)
 70 |     content_guidance_img = imageio.imread(content_path).transpose(1, 0, 2)
 71 | 
 72 |     color_codes, color_counts = np.unique(style_guidance_img.reshape(-1, style_guidance_img.shape[2]), axis=0,
 73 |                                           return_counts=True)
 74 | 
 75 |     color_codes = color_codes[color_counts > min_count]
 76 | 
 77 |     content_regions = []
 78 |     style_regions = []
 79 | 
 80 |     for color_code in color_codes:
 81 |         color_code = color_code[np.newaxis, np.newaxis, :]
 82 | 
 83 |         style_regions.append((np.abs(style_guidance_img - color_code).sum(axis=2) == 0).astype(np.float32))
 84 |         content_regions.append((np.abs(content_guidance_img - color_code).sum(axis=2) == 0).astype(np.float32))
 85 | 
 86 |     return [content_regions, style_regions]
 87 | 
 88 | 
 89 | def load_style_features(features_extractor, paths, style_region, subsamps=-1, scale=-1, inner=1):
 90 |     features = []
 91 | 
 92 |     for p in paths:
 93 |         style_im = load_img(p, size=scale).unsqueeze(0).to(device)
 94 | 
 95 |         r = resize(torch.from_numpy(style_region), (style_im.size(3), style_im.size(2))).numpy()
 96 | 
 97 |         # repeat the extraction 'inner' times; each pass draws a new random subset of positions, concatenated along the sample dim
 98 |         for j in range(inner):
 99 |             with torch.no_grad():
100 |                 features_j = features_extractor(style_im, subsamps, r)
101 | 
102 |             features_j = [feat_j.view(feat_j.size(0), feat_j.size(1), -1, 1) for feat_j in features_j]
103 | 
104 |             if len(features) == 0:
105 |                 features = features_j
106 |             else:
107 |                 features = [torch.cat([features_j[i], features[i]], 2) for i in range(len(features))]
108 | 
109 |     return features
110 | 
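A small round-trip sketch for the pyramid helpers above: decomposing an image and re-synthesising it recovers the input up to floating-point rounding. The random tensor stands in for a loaded image; the import path is an assumption.

import torch
from style_transfer.STROTSS.utils import (create_laplacian_pyramid,
                                           synthetize_image_from_laplacian_pyramid)

img = torch.rand(1, 3, 256, 256) - 0.5                     # stand-in for a batched load_img output
pyramid = create_laplacian_pyramid(img, pyramid_depth=4)   # 4 detail bands + 1 low-pass residual
reconstructed = synthetize_image_from_laplacian_pyramid(pyramid)
print((img - reconstructed).abs().max())                   # ~0 (floating-point error only)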


--------------------------------------------------------------------------------
/style_transfer/STROTSS/vgg_pt.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | 
  3 | import torch
  4 | import torch.nn.functional as F
  5 | import torchvision
  6 | 
  7 | 
  8 | class Vgg16_pt(torch.nn.Module):
  9 |     def __init__(self, requires_grad=False, use_random=True):
 10 |         super(Vgg16_pt, self).__init__()
 11 |         # load pretrained model
 12 |         self.vgg_layers = torchvision.models.vgg16(
 13 |             weights=torchvision.models.VGG16_Weights.DEFAULT
 14 |         ).features
 15 |         self.use_random = use_random
 16 | 
 17 |         if not requires_grad:
 18 |             for param in self.parameters():
 19 |                 param.requires_grad = False
 20 | 
 21 |         self.inds = range(10)  # forward_base returns the input plus the 9 tapped layers below
 22 |         self.layer_indices = [1, 3, 6, 8, 11, 13, 15, 22, 29]
 23 | 
 24 |     def forward_base(self, X):
 25 |         l2 = [X]
 26 |         x = X
 27 |         for i in range(30):
 28 |             x = self.vgg_layers[i].forward(x)
 29 |             if i in self.layer_indices:
 30 |                 l2.append(x)
 31 | 
 32 |         return l2
 33 | 
 34 |     def forward(self, X):
 35 |         return self.forward_base(X)
 36 | 
 37 |     def forward_cat(self, X, r, samps=100, forward_func=None):
 38 | 
 39 |         if not forward_func:
 40 |             forward_func = self.forward
 41 | 
 42 |         x = X
 43 |         out2 = forward_func(X)
 44 | 
 45 |         # the region mask may be H x W or H x W x C; keep a single channel
 46 |         if r.ndim == 3:
 47 |             r = r[:, :, 0]
 48 | 
 49 | 
 50 |         if r.max() < 0.1:
 51 |             region_mask = np.greater(r.flatten() + 1., 0.5)  # empty mask: fall back to every position
 52 |         else:
 53 |             region_mask = np.greater(r.flatten(), 0.5)
 54 | 
 55 |         xx, xy = np.meshgrid(np.array(range(x.size(2))), np.array(range(x.size(3))))
 56 |         xx = np.expand_dims(xx.flatten(), 1)
 57 |         xy = np.expand_dims(xy.flatten(), 1)
 58 |         xc = np.concatenate([xx, xy], 1)
 59 |         xc = xc[region_mask, :]
 60 | 
 61 |         const2 = min(samps, xc.shape[0])
 62 | 
 63 |         if self.use_random:
 64 |             np.random.shuffle(xc)
 65 |         else:
 66 |             xc = xc[::(xc.shape[0] // const2), :]
 67 | 
 68 |         xx = xc[:const2, 0]
 69 |         yy = xc[:const2, 1]
 70 | 
 71 |         temp = X
 72 |         temp_list = [temp[:, :, xx[j], yy[j]].unsqueeze(2).unsqueeze(3) for j in range(const2)]
 73 |         temp = torch.cat(temp_list, 2)
 74 | 
 75 |         l2 = []
 76 |         for i in range(len(out2)):
 77 | 
 78 |             temp = out2[i]
 79 | 
 80 |             if i > 0 and out2[i].size(2) < out2[i - 1].size(2):  # spatial resolution dropped after pooling: rescale the sampled coordinates
 81 |                 xx = xx / 2.0
 82 |                 yy = yy / 2.0
 83 | 
 84 |             xx = np.clip(xx, 0, temp.size(2) - 1).astype(np.int32)
 85 |             yy = np.clip(yy, 0, temp.size(3) - 1).astype(np.int32)
 86 | 
 87 |             temp_list = [temp[:, :, xx[j], yy[j]].unsqueeze(2).unsqueeze(3) for j in range(const2)]
 88 |             temp = torch.cat(temp_list, 2)
 89 | 
 90 |             l2.append(temp.clone().detach())
 91 | 
 92 |         out2 = [torch.cat([li.contiguous() for li in l2], 1)]
 93 | 
 94 |         return out2
 95 | 
 96 |     def forward_diff(self, X):
 97 |         l2 = self.forward_base(X)
 98 | 
 99 |         out2 = [l2[i].contiguous() for i in self.inds]
100 | 
101 |         for i in range(len(out2)):
102 |             temp = out2[i]
103 |             temp2 = F.pad(temp, (2, 2, 0, 0), value=1.)
104 |             temp3 = F.pad(temp, (0, 0, 2, 2), value=1.)
105 |             out2[i] = torch.cat(
106 |                 [temp, temp2[:, :, :, 4:], temp2[:, :, :, :-4], temp3[:, :, 4:, :], temp3[:, :, :-4, :]], 1)
107 | 
108 |         return out2
109 | 
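A minimal sketch of feature sampling with the extractor above: a full-image region mask plus random spatial subsampling yields one concatenated hypercolumn tensor per call. The input tensor and its range are illustrative assumptions, and constructing Vgg16_pt downloads the pretrained VGG-16 weights on first use.

import numpy as np
import torch

from style_transfer.STROTSS.vgg_pt import Vgg16_pt

extractor = Vgg16_pt().eval()
img = torch.rand(1, 3, 256, 256) - 0.5           # stand-in image, centred at 0
region = np.ones((256, 256), dtype=np.float32)   # the whole image belongs to the region

with torch.no_grad():
    feats = extractor.forward_cat(img, region, samps=100)

print(feats[0].shape)  # (1, sum of tapped channel dims, 100, 1)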


--------------------------------------------------------------------------------
/style_transfer/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | 


--------------------------------------------------------------------------------