├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data_process.py ├── evaluate.py ├── figs └── generation.png ├── generate_creative_birds.py ├── generate_creative_creatures.py ├── inception.py ├── part_generator.py ├── part_selector.py ├── requirements.txt ├── run_part_generator.py ├── run_part_selector.py └── training_scripts ├── train_creative_birds ├── bird_short_creative_clf_aug.sh ├── bird_short_creative_sequential_unet_partonly_beak.sh ├── bird_short_creative_sequential_unet_partonly_body.sh ├── bird_short_creative_sequential_unet_partonly_eye.sh ├── bird_short_creative_sequential_unet_partonly_head.sh ├── bird_short_creative_sequential_unet_partonly_legs.sh ├── bird_short_creative_sequential_unet_partonly_mouth.sh ├── bird_short_creative_sequential_unet_partonly_tail.sh └── bird_short_creative_sequential_unet_partonly_wings.sh └── train_creative_creatures ├── bird_short_creative_sequential_unet_partonly_arms.sh ├── bird_short_creative_sequential_unet_partonly_beak.sh ├── bird_short_creative_sequential_unet_partonly_body.sh ├── bird_short_creative_sequential_unet_partonly_ears.sh ├── bird_short_creative_sequential_unet_partonly_eye.sh ├── bird_short_creative_sequential_unet_partonly_feet.sh ├── bird_short_creative_sequential_unet_partonly_fin.sh ├── bird_short_creative_sequential_unet_partonly_hair.sh ├── bird_short_creative_sequential_unet_partonly_hands.sh ├── bird_short_creative_sequential_unet_partonly_head.sh ├── bird_short_creative_sequential_unet_partonly_horns.sh ├── bird_short_creative_sequential_unet_partonly_legs.sh ├── bird_short_creative_sequential_unet_partonly_mouth.sh ├── bird_short_creative_sequential_unet_partonly_nose.sh ├── bird_short_creative_sequential_unet_partonly_paws.sh ├── bird_short_creative_sequential_unet_partonly_tail.sh ├── bird_short_creative_sequential_unet_partonly_wings.sh └── generic_long_creative_clf_aug.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # personalized folders and files 2 | sketchrnn/tmp/ 3 | data/tmp/ 4 | stylegan2/debug/ 5 | stylegan2/__pycache__/ 6 | # Byte-compiled / optimized / DLL files 7 | *__pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | __pycache__/ 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | # pytype static type analyzer 140 | .pytype/ 141 | 142 | # Cython debug symbols 143 | cython_debug/ 144 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Open Source Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 
6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | Using welcoming and inclusive language 12 | Being respectful of differing viewpoints and experiences 13 | Gracefully accepting constructive criticism 14 | Focusing on what is best for the community 15 | Showing empathy towards other community members 16 | Examples of unacceptable behavior by participants include: 17 | 18 | The use of sexualized language or imagery and unwelcome sexual attention or advances 19 | Trolling, insulting/derogatory comments, and personal or political attacks 20 | Public or private harassment 21 | Publishing others’ private information, such as a physical or electronic address, without explicit permission 22 | Other conduct which could reasonably be considered inappropriate in a professional setting 23 | 24 | ## Our Responsibilities 25 | 26 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 27 | 28 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 29 | 30 | ## Scope 31 | 32 | This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 33 | 34 | ## Enforcement 35 | 36 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource-conduct@fb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 37 | 38 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership. 39 | 40 | ## Attribution 41 | 42 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 43 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 44 | 45 | [homepage]: https://www.contributor-covenant.org 46 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DoodlerGAN 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | 8 | We actively welcome your pull requests. 9 | 10 | 1. Fork the repo and create your branch from `master`. 11 | 2. If you've added code that should be tested, add tests. 12 | 3. If you've changed APIs, update the documentation. 13 | 4. Ensure the test suite passes. 14 | 5. 
Make sure your code lints. 15 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 16 | 17 | ## Contributor License Agreement ("CLA") 18 | 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Facebook's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | 26 | We use GitHub issues to track public bugs. Please ensure your description is 27 | clear and has sufficient instructions to be able to reproduce the issue. 28 | 29 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 30 | disclosure of security bugs. In those cases, please go through the process 31 | outlined on that page and do not file a public issue. 32 | 33 | ## Coding Style 34 | 35 | - 2 spaces for indentation rather than tabs 36 | - 80 character line length 37 | - ... 38 | 39 | ## License 40 | 41 | By contributing to DoodlerGAN, you agree that your contributions will be licensed 42 | under the LICENSE file in the root directory of this source tree. 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Creative Sketch Generation (DoodlerGAN) 2 | 3 | Paper: https://arxiv.org/abs/2011.10039 4 | 5 | Demos: http://doodlergan.cloudcv.org/ 6 | 7 | Datasets: https://songweige.github.io/projects/creative_sketech_generation/gallery_creatures.html 8 | 9 | Project Page: https://songweige.github.io/projects/creative_sketech_generation/home.html 10 | 11 | DoodlerGAN is a part-based Generative Adversarial Network (GAN) designed to generate creative sketches with unseen compositions of novel part appearances. Concretely, DoodlerGAN contains two modules: the part generator and the part selector. Given a part-based representation of a partial sketch, the part selector predicts which part category to draw next. Given a part-based representation of a partial sketch and a part category, the part generator generates a raster image of the part (which represents both the appearance and location of the part). 
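The two modules are applied iteratively at inference time. The following is an illustrative sketch of that loop, not the repository's API: `part_selector`, `part_generators`, and the single-array sketch representation are placeholders (the real implementation operates on a stack of per-part raster channels; see `generate_creative_birds.py` and `generate_creative_creatures.py`).

```python
import numpy as np

def doodle(initial_stroke, part_selector, part_generators, max_steps=10):
    """Illustrative, simplified DoodlerGAN inference loop."""
    partial_sketch = initial_stroke                 # raster partial sketch, ink = 1
    drawn_parts = []
    for _ in range(max_steps):
        part = part_selector(partial_sketch)              # next part category to draw
        if part == 'none':                                # selector decides the sketch is complete
            break
        new_part = part_generators[part](partial_sketch)  # raster image of the new part
        partial_sketch = np.maximum(partial_sketch, new_part)  # composite it onto the sketch
        drawn_parts.append(part)
    return partial_sketch, drawn_parts
```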
Some randomly selected generations from DoodlerGAN trained on the Creative Birds and Creative Creatures datasets are shown below. 12 | 13 | ![Generated Sketches](figs/generation.png) 14 | 15 | ## Preparation 16 | 17 | First, create the environment with Anaconda. Install PyTorch and the other packages listed in requirements.txt. The code is tested with PyTorch 1.3.1 and CUDA 10.0: 18 | 19 | ``` 20 | mkdir creative_sketch_generation creative_sketch_generation/data creative_sketch_generation/results creative_sketch_generation/models 21 | cd creative_sketch_generation 22 | git clone git@github.com:fairinternal/AI-doodler.git 23 | conda create -n doodler python=3.7 24 | conda activate doodler 25 | conda install pytorch==1.3.1 -c pytorch 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | Next, download our processed Creative Birds and Creative Creatures datasets from Google Drive: https://drive.google.com/drive/folders/14ZywlSE-khagmSz23KKFbLCQLoMOxPzl?usp=sharing and unzip the folders under the directory `creative_sketch_generation/data/`. 30 | 31 | To process the raw data from scratch, check the scripts in `data_process.py`. 32 | 33 | ## Usage 34 | 35 | ### Training 36 | 37 | Refer to the `training_scripts` folder for the scripts that reproduce our results. Example commands for training part generators and part selectors are shown below: 38 | 39 | ``` 40 | python run_part_generator.py --new --results_dir ../results --models_dir ../models --n_part 10 --data ../data/bird_short_wings_json_64 --name short_bird_creative_wings --num_train_steps 300000 --batch_size 40 --network_capacity 16 --gradient_accumulate_every 1 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 41 | python run_part_generator.py --new --results_dir ../results --models_dir ../models --large_aug --n_part 19 --data ../data/generic_long_legs_json_64 --name long_generic_creative_legs --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 42 | python run_part_selector.py --new --results_dir ../results --models_dir ../models --n_part 10 --data ../data/bird_short_ --name short_bird_creative_selector --batch_size 128 --save_every 1000 --image_size 64 43 | ``` 44 | 45 | ### Inference 46 | 47 | During inference, the part generators and the part selector are used iteratively to complete an entire sketch from a random initial stroke. To generate a `[num_image_tiles x num_image_tiles]` grid visualizing generations from the trained models, use the following scripts. We also release our trained models on Google Drive. 48 | 49 | ``` 50 | python generate_creative_birds.py --models_dir ../models --results_dir ../results/creative_bird_generation --data_dir ../data --num_image_tiles 8 51 | python generate_creative_creatures.py --models_dir ../models --results_dir ../results/creative_creature_generation --data_dir ../data --num_image_tiles 10 52 | ``` 53 | 54 | To generate 10,000 sketches for quantitative evaluation, use the `--generate_all` flag as below. The script will automatically create three folders under `results_dir`: `DoodlerGAN_all/bw`, `DoodlerGAN_all/color`, and `DoodlerGAN_all/color_initial`, which contain the generations in grayscale, with every part colored, and with only the initial stroke colored, respectively.
55 | 56 | ``` 57 | python generate_creative_birds.py --generate_all --models_dir ../models --results_dir ../results/creative_bird_generation --data_dir ../data --num_image_tiles 8 58 | python generate_creative_creatures.py --generate_all --models_dir ../models --results_dir ../results/creative_creature_generation --data_dir ../data --num_image_tiles 10 59 | ``` 60 | 61 | ### Quantitative Evaluation 62 | 63 | We analyze the quality and novelty of the generations with four metrics: Frechet inception distance (FID), generation diversity (GD), characteristic score (CS) and semantic diversity score (SDS). Please refer to our paper for more details on the metrics. To run the evaluation, use the following script with the real (training) image directory and the generation directory: 64 | 65 | ``` 66 | python evaluate.py training_dir generation_dir --gpu 1 --name birds 67 | ``` 68 | 69 | ### PNG to SVG Conversion 70 | 71 | To convert the PNG outputs to SVG, first install the following packages: 72 | 73 | ```bash 74 | apt-get install imagemagick 75 | apt-get install potrace 76 | ``` 77 | 78 | Once the packages are installed, a PNG image can be converted to SVG using the following command: 79 | 80 | ```bash 81 | 82 | convert input.png bmp:- | mkbitmap - -t 0.20 -o - | potrace --svg --group -o - > output.svg 83 | ``` 84 | 85 | ## Citation 86 | ``` 87 | @misc{ge2020creative, 88 | title={Creative Sketch Generation}, 89 | author={Songwei Ge and Vedanuj Goswami and C. Lawrence Zitnick and Devi Parikh}, 90 | year={2020}, 91 | eprint={2011.10039}, 92 | archivePrefix={arXiv}, 93 | primaryClass={cs.CV} 94 | } 95 | ``` 96 | 97 | ## License 98 | 99 | DoodlerGAN is MIT licensed, as found in the LICENSE file. 100 | -------------------------------------------------------------------------------- /data_process.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import cv2 8 | import json 9 | import copy 10 | import numpy as np 11 | import cairocffi as cairo 12 | 13 | 14 | def vector_to_raster(vector_images, part_label=False, nodetail=False, side=64, line_diameter=16, padding=16, bg_color=(1,1,1), fg_color=(0,0,0)): 15 | """ 16 | padding and line_diameter are relative to the original 512x512 image. 17 | """ 18 | original_side = 512. 19 | surface = cairo.ImageSurface(cairo.FORMAT_ARGB32, side, side) 20 | ctx = cairo.Context(surface) 21 | ctx.set_antialias(cairo.ANTIALIAS_BEST) 22 | ctx.set_line_cap(cairo.LINE_CAP_ROUND) 23 | ctx.set_line_join(cairo.LINE_JOIN_ROUND) 24 | ctx.set_line_width(line_diameter) 25 | # scale to match the new size 26 | # add padding at the edges for the line_diameter 27 | # and add additional padding to account for antialiasing 28 | total_padding = padding * 2. + line_diameter 29 | new_scale = float(side) / float(original_side + total_padding) 30 | ctx.scale(new_scale, new_scale) 31 | ctx.translate(total_padding / 2., total_padding / 2.)
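    # e.g. with the default arguments (side=64, padding=16, line_diameter=16):
    # total_padding = 2*16 + 16 = 48 and new_scale = 64 / (512 + 48) ≈ 0.114,
    # so the original 512-unit canvas plus padding maps onto the 64x64 surface,
    # and the translate shifts the origin by half the padding so strokes near
    # the border are not clipped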
32 | raster_images = [] 33 | for i, vector_data in enumerate(vector_images): 34 | # clear background 35 | ctx.set_source_rgb(*bg_color) 36 | ctx.paint() 37 | vector_image = [] 38 | x_max = y_max = 0 39 | for step in vector_data['all_strokes']: 40 | vector_image.append([]) # for each step 41 | for stroke in step: 42 | if len(stroke) == 0: # skip the empty stroke 43 | vector_image[-1].append([]) 44 | continue 45 | vector_image[-1].append(np.array([stroke[0][:2]]+[point[2:4] for point in stroke])) # add each stroke N x 2 46 | x_max_stroke, y_max_stroke = np.max(vector_image[-1][-1], 0) 47 | x_max = x_max_stroke if x_max_stroke>x_max else x_max 48 | y_max = y_max_stroke if y_max_stroke>y_max else y_max 49 | offset = ((original_side, original_side) - np.array([x_max, y_max])) / 2. 50 | offset = offset.reshape(1,2) 51 | for j in range(len(vector_image)): 52 | for k in range(len(vector_image[j])): 53 | vector_image[j][k] = vector_image[j][k]+offset if len(vector_image[j][k]) > 0 else vector_image[j][k] 54 | # draw strokes, this is the most cpu-intensive part 55 | ctx.set_source_rgb(*fg_color) 56 | for j, step in enumerate(vector_image): 57 | if part_label: 58 | ctx.set_source_rgb(*COLORS[vector_data['partsUsed'][j]]) 59 | if nodetail and j == len(vector_image)-1 and vector_data['partsUsed'][j] == 'details': 60 | continue 61 | for stroke in step: 62 | if len(stroke) == 0: 63 | continue 64 | ctx.move_to(stroke[0][0], stroke[0][1]) 65 | for x, y in stroke: 66 | ctx.line_to(x, y) 67 | ctx.stroke() 68 | surface_data = surface.get_data() 69 | if part_label: 70 | raster_image = np.copy(np.asarray(surface_data)).reshape(side, side, 4)[:, :, :3] 71 | else: 72 | raster_image = np.copy(np.asarray(surface_data))[::4].reshape(side, side) 73 | raster_images.append(raster_image) 74 | return raster_images 75 | 76 | 77 | def vector_image_to_vector_part(vector_images, target_part, side=64, line_diameter=16, padding=16, data_name='bird'): 78 | """ 79 | save processed vector image for target_parts: input partial images, input parts and target images with target parts 80 | """ 81 | original_side = 512. 82 | # scale to match the new size 83 | # add padding at the edges for the line_diameter 84 | # and add additional padding to account for antialiasing 85 | total_padding = padding * 2. 
+ line_diameter 86 | new_scale = float(side) / float(original_side + total_padding) 87 | processed_vector_input_parts = [] 88 | processed_vector_parts = [] 89 | # each item in processed_vector_images is a list that corresponds to all target parts that appear in that sketch 90 | for i, vector_data in enumerate(vector_images): 91 | # check if target part is drawn 92 | processed_vector_input_parts.append([]) 93 | processed_vector_parts.append([]) 94 | # store the strokes for each part 95 | if data_name == 'bird': 96 | strokes_input_parts = {'initial':[], 'eye':[], 'beak':[], 'body':[], 'head':[], 'legs':[], 'mouth':[], 'tail':[], 'wings':[]} 97 | elif data_name == 'creature': 98 | strokes_input_parts = {'initial':[], 'eye':[], 'arms':[], 'beak':[], 'mouth':[], 'body':[], 'ears':[], 'feet':[], 'fin':[], 'hair':[], 'hands':[], 99 | 'head':[], 'horns':[], 'legs':[], 'nose':[], 'paws':[], 'tail':[], 'wings':[]} 100 | if target_part not in vector_data['partsUsed']: 101 | continue 102 | vector_image = [] 103 | x_max = y_max = 0 104 | for step in vector_data['all_strokes']: 105 | vector_image.append([]) # for each step 106 | for stroke in step: 107 | if len(stroke) == 0: # skip the empty stroke 108 | vector_image[-1].append([]) 109 | continue 110 | vector_image[-1].append(np.array([stroke[0][:2]]+[point[2:4] for point in stroke])) # add each stroke N x 2 111 | x_max_stroke, y_max_stroke = np.max(vector_image[-1][-1], 0) 112 | x_max = x_max_stroke if x_max_stroke>x_max else x_max 113 | y_max = y_max_stroke if y_max_stroke>y_max else y_max 114 | offset = ((original_side, original_side) - np.array([x_max, y_max])) / 2. 115 | offset = offset.reshape(1,2) 116 | for j in range(len(vector_image)): 117 | for k in range(len(vector_image[j])): 118 | vector_image[j][k] = vector_image[j][k]+offset if len(vector_image[j][k]) > 0 else vector_image[j][k] 119 | # save strokes 120 | for j, step in enumerate(vector_image): 121 | if vector_data['partsUsed'][j] == target_part: # find one part 122 | processed_vector_input_parts[-1].append(copy.deepcopy(strokes_input_parts)) 123 | if j != len(vector_image)-1 and vector_data['partsUsed'][j] != 'details': # last one and details 124 | strokes_input_parts[vector_data['partsUsed'][j]] += step 125 | else: 126 | continue 127 | if vector_data['partsUsed'][j] == target_part: 128 | # record the input + part 129 | processed_vector_parts[-1].append(step) 130 | # record all the parts 131 | processed_vector_input_parts[-1].append(copy.deepcopy(strokes_input_parts)) 132 | return processed_vector_input_parts, processed_vector_parts 133 | 134 | 135 | 136 | ######################################################################################################################## 137 | ######################################################################################################################## 138 | # basic setups, load data 139 | data_name = 'bird' # or 'creature' 140 | side=64 # size of the rendered image 141 | 142 | ## data format: ['assignment_id', 'hit_id', 'worker_id', 'output', 'submit_time'] 143 | ## 'output' --> ['all_strokes', 'prompts', 'comment', 'description', 'partsUsed'] 144 | if data_name == 'bird': 145 | COLORS = {'initial':np.array([45, 169, 145])/255., 'eye':np.array([243, 156, 18])/255., 'none':np.array([149, 165, 166])/255., 146 | 'beak':np.array([211, 84, 0])/255., 'body':np.array([41, 128, 185])/255., 'details':np.array([171, 190, 191])/255., 147 | 'head':np.array([192, 57, 43])/255., 'legs':np.array([142, 68, 173])/255., 'mouth':np.array([39, 174, 96])/255., 
148 | 'tail':np.array([69, 85, 101])/255., 'wings':np.array([127, 140, 141])/255.} 149 | part_to_id = {'initial': 0, 'eye': 1, 'beak': 2, 'body': 3, 'head': 4, 'legs': 5, 'mouth': 6, 'tail': 7, 'wings': 8} 150 | target_parts = ['eye', 'beak', 'body', 'head', 'legs', 'mouth', 'tail', 'wings', 'details'] 151 | data = json.loads(open('raw_data_clean/creative_birds_json.txt').read()) 152 | elif data_name == 'creature': 153 | COLORS = {'initial':np.array([45, 169, 145])/255., 'eye':np.array([243, 156, 18])/255., 'none':np.array([149, 165, 166])/255., 154 | 'arms':np.array([211, 84, 0])/255., 'beak':np.array([41, 128, 185])/255., 'mouth':np.array([54, 153, 219])/255., 155 | 'body':np.array([192, 57, 43])/255., 'ears':np.array([142, 68, 173])/255., 'feet':np.array([39, 174, 96])/255., 156 | 'fin':np.array([69, 85, 101])/255., 'hair':np.array([127, 140, 141])/255., 'hands':np.array([45, 63, 81])/255., 157 | 'head':np.array([241, 197, 17])/255., 'horns':np.array([51, 205, 117])/255., 'legs':np.array([232, 135, 50])/255., 158 | 'nose':np.array([233, 90, 75])/255., 'paws':np.array([160, 98, 186])/255., 'tail':np.array([58, 78, 99])/255., 159 | 'wings':np.array([198, 203, 207])/255., 'details':np.array([171, 190, 191])/255.} 160 | part_to_id = {'initial': 0, 'eye': 1, 'arms': 2, 'beak': 3, 'mouth': 4, 'body': 5, 'ears': 6, 'feet': 7, 'fin': 8, 161 | 'hair': 9, 'hands': 10, 'head': 11, 'horns': 12, 'legs': 13, 'nose': 14, 'paws': 15, 'tail': 16, 'wings':17} 162 | target_parts = ['arms', 'beak', 'mouth', 'body', 'eye', 'ears', 'feet', 'fin', 'hair', 'hands', 163 | 'head', 'horns', 'legs', 'nose', 'paws', 'tail', 'wings', 'details'] 164 | data = json.loads(open('raw_data_clean/creative_creatures_json.txt').read()) 165 | data = [json.loads(line) for j in range(1, 12) for line in open('raw_data/doodle_generic_%d.txt'%j)] 166 | wid_rej = [line.rstrip() for line in open('raw_data/reject_generic_workids_all.txt')] 167 | 168 | 169 | ######################################################################################################################## 170 | # visualize all the sketches by rendering raster images 171 | raster_images_gs = vector_to_raster(data, part_label=False, nodetail=True, side=side, line_diameter=3, padding=16, bg_color=(0,0,0), fg_color=(1,1,1)) 172 | raster_images_rgb = vector_to_raster(data, part_label=True, nodetail=True, side=side, line_diameter=3, padding=16, bg_color=(1,1,1), fg_color=(0,0,0)) 173 | 174 | outpath = os.path.join('data/%s_short_full_%d'%(data_name, side)) 175 | outpath_rgb = os.path.join('data/%s_short_full_rgb_%d'%(data_name, side)) 176 | if not os.path.exists(outpath): 177 | os.mkdir(outpath) 178 | os.mkdir(outpath_rgb) 179 | 180 | 181 | for i, (raster_image, raster_image_rgb) in enumerate(zip(raster_images_gs[:100], raster_images_rgb[:100])): 182 | if not data[i]['good_sample']: 183 | continue 184 | cv2.imwrite(os.path.join(outpath, "sketch_%s.png"%i), raster_image) 185 | cv2.imwrite(os.path.join(outpath_rgb, "sketch_%s.png"%i), raster_image_rgb) 186 | 187 | 188 | descriptions = [item['description'].strip() for item in data if item['good_sample']] 189 | with open('%s_description.json'%data_name, 'w') as fp: 190 | json.dump(descriptions, fp) 191 | 192 | ######################################################################################################################## 193 | ## process vectors images for doodlerGAN 194 | for target_part in target_parts: 195 | print('rendering %s...'%target_part) 196 | vector_input_parts, vector_parts = 
vector_image_to_vector_part(data, target_part=target_part, side=side, line_diameter=16, padding=16, data_name=data_name) 197 | outpath_train = 'data/%s_short_%s_json_%d_train'%(data_name, target_part, side) 198 | outpath_test = 'data/%s_short_%s_json_%d_test'%(data_name, target_part, side) 199 | if not os.path.exists(outpath_test): 200 | os.mkdir(outpath_test) 201 | os.mkdir(outpath_train) 202 | for i in range(len(data)-500): 203 | if not data[i]['good_sample']: 204 | continue 205 | if len(vector_input_parts[i]) == 0: 206 | continue 207 | for j in range(len(vector_input_parts[i])-1): 208 | if data_name == 'bird': 209 | json_data = {'input_parts':{'initial': [], 'eye': [], 'head': [], 'body': [], 'beak': [], 'legs': [], 'wings': [], 'mouth': [], 'tail': []}, 'target_part':[]} 210 | elif data_name == 'creature': 211 | json_data = {'input_parts':{'initial':[], 'eye':[], 'arms':[], 'beak':[], 'mouth':[], 'body':[], 'ears':[], 'feet':[], 'fin':[], 'hair':[], 'hands':[], 212 | 'head':[], 'horns':[], 'legs':[], 'nose':[], 'paws':[], 'tail':[], 'wings':[]}, 'target_part':[]} 213 | if target_part != 'none': 214 | json_data['target_part'] = [item.tolist() for item in vector_parts[i][j] if len(item) > 0] 215 | for key in vector_input_parts[i][j].keys(): 216 | json_data['input_parts'][key] = [item.tolist() for item in vector_input_parts[i][j][key] if len(item) > 0] 217 | with open(outpath_train+"/sketch%d_%d.json"%(i, j), 'w') as fw: 218 | json.dump(json_data, fw) 219 | for i in range(len(data)-500, len(data)): 220 | if not data[i]['good_sample']: 221 | continue 222 | if len(vector_input_parts[i]) == 0: 223 | continue 224 | for j in range(len(vector_input_parts[i])-1): 225 | if data_name == 'bird': 226 | json_data = {'input_parts':{'initial': [], 'eye': [], 'head': [], 'body': [], 'beak': [], 'legs': [], 'wings': [], 'mouth': [], 'tail': []}, 'target_part':[]} 227 | elif data_name == 'creature': 228 | json_data = {'input_parts':{'initial':[], 'eye':[], 'arms':[], 'beak':[], 'mouth':[], 'body':[], 'ears':[], 'feet':[], 'fin':[], 'hair':[], 'hands':[], 229 | 'head':[], 'horns':[], 'legs':[], 'nose':[], 'paws':[], 'tail':[], 'wings':[]}, 'target_part':[]} 230 | if target_part != 'none': 231 | json_data['target_part'] = [item.tolist() for item in vector_parts[i][j] if len(item) > 0] 232 | for key in vector_input_parts[i][j].keys(): 233 | json_data['input_parts'][key] = [item.tolist() for item in vector_input_parts[i][j][key] if len(item) > 0] 234 | with open(outpath_test+"/sketch%d_%d.json"%(i, j), 'w') as fw: 235 | json.dump(json_data, fw) -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | 8 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs 9 | 10 | The FID metric calculates the distance between two distributions of images. 11 | Typically, we have summary statistics (mean & covariance matrix) of one 12 | of these distributions, while the 2nd distribution is given by a GAN. 13 | 14 | When run as a stand-alone program, it compares the distribution of 15 | images that are stored as PNG/JPEG at a specified location with a 16 | distribution given by summary statistics (in pickle format).
17 | 18 | The FID is calculated by assuming that X_1 and X_2 are the activations of 19 | the pool_3 layer of the inception net for generated samples and real world 20 | samples respectively. 21 | 22 | See --help to see further details. 23 | 24 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 25 | of Tensorflow 26 | 27 | Copyright 2018 Institute of Bioinformatics, JKU Linz 28 | 29 | Licensed under the Apache License, Version 2.0 (the "License"); 30 | you may not use this file except in compliance with the License. 31 | You may obtain a copy of the License at 32 | 33 | http://www.apache.org/licenses/LICENSE-2.0 34 | 35 | Unless required by applicable law or agreed to in writing, software 36 | distributed under the License is distributed on an "AS IS" BASIS, 37 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 38 | See the License for the specific language governing permissions and 39 | limitations under the License. 40 | """ 41 | import os 42 | import cv2 43 | import json 44 | import pathlib 45 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 46 | import torchvision 47 | 48 | import numpy as np 49 | import torch 50 | from scipy import linalg 51 | from torch.nn.functional import adaptive_avg_pool2d 52 | import torch.nn.functional as F 53 | 54 | from PIL import Image 55 | 56 | try: 57 | from tqdm import tqdm 58 | except ImportError: 59 | # If tqdm is not available, provide a mock version of it 60 | def tqdm(x): return x 61 | 62 | from inception import InceptionV3 63 | 64 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 65 | parser.add_argument('path', type=str, nargs=2, 66 | help=('Path to the generated images or ' 67 | 'to .npz statistic files')) 68 | parser.add_argument('--batch-size', type=int, default=50, 69 | help='Batch size to use') 70 | parser.add_argument('--dims', type=int, default=2048, 71 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 72 | help=('Dimensionality of Inception features to use. ' 73 | 'By default, uses pool3 features')) 74 | parser.add_argument('-c', '--gpu', default='', type=str, 75 | help='GPU to use (leave blank for CPU only)') 76 | parser.add_argument('--name', default='birds', type=str, 77 | help='which dataset to evaluate', choices=['birds', 'creatures']) 78 | 79 | 80 | with open('../data/id_to_class.json', 'r') as fp: 81 | ID2CLASS = json.load(fp) 82 | ID2CLASS = {int(k): v for k, v in ID2CLASS.items()} 83 | 84 | B_SET = ['bird', 'duck', 'flamingo', 'parrot'] 85 | C_SET = ['ant', 'bear', 'bee', 'bird', 'butterfly', 'camel', 'cat', 'cow', 'crab', 'crocodile', 'dog', 'dolphin', 'duck', 86 | 'elephant', 'fish', 'flamingo', 'frog', 'giraffe', 'hedgehog', 'horse', 'kangaroo', 'lion', 'lobster', 'monkey', 'mosquito', 87 | 'mouse', 'octopus', 'owl', 'panda', 'parrot', 'penguin', 'pig', 'rabbit', 'raccoon', 'rhinoceros', 'scorpion', 'sea_turtle', 88 | 'shark', 'sheep', 'snail', 'snake', 'spider', 'squirrel', 'swan', 'tiger', 'whale', 'zebra'] 89 | 90 | def imread(filename): 91 | """ 92 | Loads an image file into a (height, width, 3) uint8 ndarray.
93 | """ 94 | return np.asarray(Image.open(filename), dtype=np.uint8) 95 | 96 | 97 | def resize(sketch): 98 | x_nonzero, y_nonzero = np.where(sketch>0) 99 | try: 100 | coord_min = min(x_nonzero.min(), y_nonzero.min()) 101 | coord_max = max(x_nonzero.max(), y_nonzero.max()) 102 | sketch_new = np.zeros([64, 64]) 103 | sketch_cropped = cv2.resize(sketch[coord_min:coord_max, coord_min:coord_max], (60, 60)) 104 | sketch_new[2:-2, 2:-2] = sketch_cropped 105 | except: 106 | sketch_new = sketch 107 | return sketch_new 108 | 109 | def resize_batch(sketches): 110 | return np.array([resize(sketch) for sketch in sketches]) 111 | 112 | def get_acts_and_preds(files, model, batch_size=50, dims=2048, 113 | cuda=False, verbose=False, name='birds'): 114 | """Calculates the activations of the pool_3 layer for all images. 115 | 116 | Params: 117 | -- files : List of image files paths 118 | -- model : Instance of inception model 119 | -- batch_size : Batch size of images for the model to process at once. 120 | Make sure that the number of samples is a multiple of 121 | the batch size, otherwise some samples are ignored. This 122 | behavior is retained to match the original FID score 123 | implementation. 124 | -- dims : Dimensionality of features returned by Inception 125 | -- cuda : If set to True, use GPU 126 | -- verbose : If set to True and parameter out_step is given, the number 127 | of calculated batches is reported. 128 | -- name : The name of the dataset: for birds we calculate CS only and for creatures we also calculate SDS. 129 | """ 130 | if name == 'birds': 131 | target_set = B_SET 132 | elif name == 'creatures': 133 | target_set = C_SET 134 | 135 | model.eval() 136 | 137 | if batch_size > len(files): 138 | print(('Warning: batch size is bigger than the data size. ' 139 | 'Setting batch size to data size')) 140 | batch_size = len(files) 141 | 142 | pred_arr = np.empty((len(files), dims)) 143 | preds_final_arr = {} 144 | logits_arr = torch.zeros(345).cuda() 145 | 146 | for i in tqdm(range(0, len(files), batch_size)): 147 | if verbose: 148 | print('\rPropagating batch %d/%d' % (i + 1, n_batches), 149 | end='', flush=True) 150 | start = i 151 | end = i + batch_size 152 | 153 | images = np.array([imread(str(f)).astype(np.float32) for f in files[start:end]]) 154 | images = images/255. 155 | images = 1-images 156 | images[images<0.1] = 0 157 | 158 | # Reshape to (n_images, 3, height, width) 159 | if len(images.shape) == 4: 160 | images = images.transpose((0, 3, 1, 2)) 161 | elif len(images.shape) == 3: 162 | images = np.expand_dims(images, 1) 163 | 164 | batch = torch.from_numpy(images).type(torch.FloatTensor) 165 | if cuda: 166 | batch = batch.cuda() 167 | 168 | batch = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False) 169 | 170 | # store the model predictions 171 | logits = model.inception(batch) 172 | _, final_preds = torch.max(logits, 1) 173 | logits = F.softmax(logits, 1) 174 | for logit, final_pred in zip(logits, final_preds): 175 | logits_arr += logit 176 | pred_class = ID2CLASS[final_pred.item()] 177 | if pred_class in preds_final_arr: 178 | preds_final_arr[pred_class] += 1 179 | else: 180 | preds_final_arr[pred_class] = 1 181 | 182 | 183 | pred = model(batch)[0] 184 | 185 | # If model output is not scalar, apply global spatial average pooling. 186 | # This happens if you choose a dimensionality not equal 2048. 
187 | if pred.size(2) != 1 or pred.size(3) != 1: 188 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 189 | 190 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(pred.size(0), -1) 191 | 192 | # calculate CS and SDS 193 | characteristic_count = 0. 194 | total_count = 0. 195 | for class_name in preds_final_arr: 196 | total_count += preds_final_arr[class_name] 197 | if class_name not in target_set: 198 | continue 199 | characteristic_count += preds_final_arr[class_name] 200 | CS = characteristic_count/total_count 201 | probs_all = logits_arr / total_count 202 | # import ipdb;ipdb.set_trace() 203 | if name == 'creatures': 204 | C_prob = sum([probs_all[cl_id].item() for cl_id in range(345) if ID2CLASS[cl_id] in C_SET]) 205 | CCS = sum([-probs_all[cl_id].item()*np.log(probs_all[cl_id].item()/C_prob) for cl_id in range(345) if ID2CLASS[cl_id] in C_SET]) 206 | else: 207 | CCS = 0. 208 | 209 | if verbose: 210 | print(' done') 211 | return pred_arr, CS, CCS 212 | 213 | 214 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 215 | """Numpy implementation of the Frechet Distance. 216 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 217 | and X_2 ~ N(mu_2, C_2) is 218 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 219 | 220 | Stable version by Dougal J. Sutherland. 221 | 222 | Params: 223 | -- mu1 : Numpy array containing the activations of a layer of the 224 | inception net (as returned by the function 'get_acts_and_preds') 225 | for generated samples. 226 | -- mu2 : The sample mean over activations, precalculated on a 227 | representative data set. 228 | -- sigma1: The covariance matrix over activations for generated samples. 229 | -- sigma2: The covariance matrix over activations, precalculated on a 230 | representative data set. 231 | 232 | Returns: 233 | -- : The Frechet Distance. 234 | """ 235 | 236 | mu1 = np.atleast_1d(mu1) 237 | mu2 = np.atleast_1d(mu2) 238 | 239 | sigma1 = np.atleast_2d(sigma1) 240 | sigma2 = np.atleast_2d(sigma2) 241 | 242 | assert mu1.shape == mu2.shape, \ 243 | 'Training and test mean vectors have different lengths' 244 | assert sigma1.shape == sigma2.shape, \ 245 | 'Training and test covariances have different dimensions' 246 | 247 | diff = mu1 - mu2 248 | 249 | # Product might be almost singular 250 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 251 | if not np.isfinite(covmean).all(): 252 | msg = ('fid calculation produces singular product; ' 253 | 'adding %s to diagonal of cov estimates') % eps 254 | print(msg) 255 | offset = np.eye(sigma1.shape[0]) * eps 256 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 257 | 258 | # Numerical error might give slight imaginary component 259 | if np.iscomplexobj(covmean): 260 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 261 | m = np.max(np.abs(covmean.imag)) 262 | raise ValueError('Imaginary component {}'.format(m)) 263 | covmean = covmean.real 264 | 265 | tr_covmean = np.trace(covmean) 266 | 267 | return (diff.dot(diff) + np.trace(sigma1) + 268 | np.trace(sigma2) - 2 * tr_covmean) 269 | 270 | 271 | def calculate_acts_and_preds(files, model, batch_size=50, 272 | dims=2048, cuda=False, verbose=False, name='birds'): 273 | """Calculates the statistics used for the FID, together with diversity, CS, SDS. 274 | Params: 275 | -- files : List of image file paths 276 | -- model : Instance of inception model 277 | -- batch_size : The images numpy array is split into batches with 278 | batch size batch_size.
A reasonable batch size 279 | depends on the hardware. 280 | -- dims : Dimensionality of features returned by Inception 281 | -- cuda : If set to True, use GPU 282 | -- verbose : If set to True and parameter out_step is given, the 283 | number of calculated batches is reported. 284 | Returns: 285 | -- mu : The mean over samples of the activations of the pool_3 layer of 286 | the inception model. 287 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 288 | the inception model. 289 | -- diversity : average pairwise distances between samples. 290 | -- CS : characteristic score. 291 | -- SDS : semantic diversity score. 292 | """ 293 | assert name in ['birds', 'creatures'] 294 | act, CS, SDS = get_acts_and_preds(files, model, batch_size, dims, cuda, verbose, name) 295 | mu = np.mean(act, axis=0) 296 | sigma = np.cov(act, rowvar=False) 297 | diversity = cal_diversity(act) 298 | # import ipdb;ipdb.set_trace() 299 | return mu, sigma, diversity, CS, SDS 300 | 301 | 302 | def cal_diversity(act): 303 | n_sample = min(act.shape[0], 1000) 304 | act = act[:n_sample] 305 | n_part = n_sample*(n_sample-1)/2 306 | score = 0. 307 | for i in range(n_sample): 308 | for j in range(i+1, n_sample): 309 | score += np.sqrt(np.sum((act[i]-act[j])**2)) 310 | return score/n_part 311 | 312 | 313 | def _compute_statistics_of_path(path, model, batch_size, dims, cuda, name): 314 | if path.endswith('.npz'): 315 | f = np.load(path) 316 | m, s = f['mu'][:], f['sigma'][:] 317 | f.close() 318 | else: 319 | path = pathlib.Path(path) 320 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 321 | m, s, diversity, CS, SDS = calculate_acts_and_preds(files, model, batch_size, 322 | dims, cuda, False, name) 323 | return m, s, diversity, CS, SDS 324 | 325 | 326 | def calculate_scores_given_paths(paths, batch_size, cuda, dims, name): 327 | """Calculates the FID of two paths""" 328 | for p in paths: 329 | if not os.path.exists(p): 330 | raise RuntimeError('Invalid path: %s' % p) 331 | 332 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 333 | 334 | model = InceptionV3([block_idx], normalize_input=False, use_fid_inception=False) 335 | if cuda: 336 | model.cuda() 337 | 338 | m1, s1, d1, CS1, SDS1 = _compute_statistics_of_path(paths[0], model, batch_size, 339 | dims, cuda, name) 340 | m2, s2, d2, CS2, SDS2 = _compute_statistics_of_path(paths[1], model, batch_size, 341 | dims, cuda, name) 342 | # import ipdb;ipdb.set_trace() 343 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 344 | 345 | return fid_value, d1, d2, CS1, CS2, SDS1, SDS2 346 | 347 | 348 | if __name__ == '__main__': 349 | args = parser.parse_args() 350 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 351 | 352 | fid_value, d1, d2, CS1, CS2, SDS1, SDS2 = calculate_scores_given_paths(args.path, 353 | args.batch_size, 354 | args.gpu != '', 355 | args.dims, 356 | args.name) 357 | print('FID: ', fid_value) 358 | if args.name == 'birds': 359 | print('Diversity 1: %.2f, characteristic score 1: %.2f'%(d1, CS1)) 360 | print('Diversity 2: %.2f, characteristic score 2: %.2f'%(d2, CS2)) 361 | elif args.name == 'creatures': 362 | print('Diversity 1: %.2f, characteristic score 1: %.2f, semantic diversity score 1: %.2f'%(d1, CS1, SDS1)) 363 | print('Diversity 2: %.2f, characteristic score 2: %.2f, semantic diversity score 2: %.2f'%(d2, CS2, SDS2)) 364 | 365 | -------------------------------------------------------------------------------- /figs/generation.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/DoodlerGAN/f0b1f9ff936d0f7438146ff7a79174b76b5eaa10/figs/generation.png -------------------------------------------------------------------------------- /generate_creative_birds.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import cv2 8 | import torch 9 | import numpy as np 10 | import argparse 11 | import torchvision 12 | from PIL import Image 13 | from tqdm import tqdm 14 | from pathlib import Path 15 | from datetime import datetime 16 | from retry.api import retry_call 17 | from torch.utils import data 18 | from torchvision import transforms 19 | from part_selector import Trainer as Trainer_selector 20 | from part_generator import Trainer as Trainer_cond_unet 21 | from scipy.ndimage.morphology import distance_transform_edt 22 | 23 | COLORS = {'initial':1-torch.cuda.FloatTensor([45, 169, 145]).view(1, -1, 1, 1)/255., 'eye':1-torch.cuda.FloatTensor([243, 156, 18]).view(1, -1, 1, 1)/255., 'none':1-torch.cuda.FloatTensor([149, 165, 166]).view(1, -1, 1, 1)/255., 24 | 'beak':1-torch.cuda.FloatTensor([211, 84, 0]).view(1, -1, 1, 1)/255., 'body':1-torch.cuda.FloatTensor([41, 128, 185]).view(1, -1, 1, 1)/255., 'details':1-torch.cuda.FloatTensor([171, 190, 191]).view(1, -1, 1, 1)/255., 25 | 'head':1-torch.cuda.FloatTensor([192, 57, 43]).view(1, -1, 1, 1)/255., 'legs':1-torch.cuda.FloatTensor([142, 68, 173]).view(1, -1, 1, 1)/255., 'mouth':1-torch.cuda.FloatTensor([39, 174, 96]).view(1, -1, 1, 1)/255., 26 | 'tail':1-torch.cuda.FloatTensor([69, 85, 101]).view(1, -1, 1, 1)/255., 'wings':1-torch.cuda.FloatTensor([127, 140, 141]).view(1, -1, 1, 1)/255.} 27 | 28 | class Initialstroke_Dataset(data.Dataset): 29 | def __init__(self, folder, image_size): 30 | super().__init__() 31 | self.folder = folder 32 | self.image_size = image_size 33 | self.paths = [p for p in Path(f'{folder}').glob(f'**/*.png')] 34 | self.transform = transforms.Compose([ 35 | transforms.ToTensor(), 36 | ]) 37 | 38 | def __len__(self): 39 | return len(self.paths) 40 | 41 | def __getitem__(self, index): 42 | path = self.paths[index] 43 | img = self.transform(Image.open(path)) 44 | return img 45 | 46 | def sample(self, n): 47 | sample_ids = [np.random.randint(self.__len__()) for _ in range(n)] 48 | samples = [self.transform(Image.open(self.paths[sample_id])) for sample_id in sample_ids] 49 | return torch.stack(samples).cuda() 50 | 51 | def load_latest(model_dir, name): 52 | model_dir = Path(model_dir) 53 | file_paths = [p for p in Path(model_dir / name).glob('model_*.pt')] 54 | saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths)) 55 | if len(saved_nums) == 0: 56 | return 57 | name = saved_nums[-1] 58 | print(f'continuing from previous epoch - {name}') 59 | return name 60 | 61 | 62 | def noise(n, latent_dim): 63 | return torch.randn(n, latent_dim).cuda() 64 | 65 | def noise_list(n, layers, latent_dim): 66 | return [(noise(n, latent_dim), layers)] 67 | 68 | def mixed_list(n, layers, latent_dim): 69 | tt = int(torch.rand(()).numpy() * layers) 70 | return noise_list(n, tt, latent_dim) + noise_list(n, layers - tt, latent_dim) 71 | 72 | def image_noise(n, im_size): 73 | return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda() 74 | 75 | def evaluate_in_chunks(max_batch_size, model, *args): 76 | split_args = 
list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args)))) 77 | chunked_outputs = [model(*i) for i in split_args] 78 | if len(chunked_outputs) == 1: 79 | return chunked_outputs[0] 80 | return torch.cat(chunked_outputs, dim=0) 81 | 82 | def evaluate_in_chunks_unet(max_batch_size, model, map_feats, *args): 83 | split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args)))) 84 | split_map_feats = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), map_feats)))) 85 | chunked_outputs = [model(*i, j) for i, j in zip(split_args, split_map_feats)] 86 | if len(chunked_outputs) == 1: 87 | return chunked_outputs[0] 88 | return torch.cat(chunked_outputs, dim=0) 89 | 90 | def styles_def_to_tensor(styles_def): 91 | return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1) 92 | 93 | def gs_to_rgb(image, color): 94 | image_rgb = image.repeat(1, 3, 1, 1) 95 | return 1-image_rgb*color 96 | 97 | @torch.no_grad() 98 | def generate_truncated(S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8, bitmap_feats=None, batch_size=8): 99 | latent_dim = G.latent_dim 100 | z = noise(2000, latent_dim) 101 | samples = evaluate_in_chunks(batch_size, S, z).cpu().numpy() 102 | av = np.mean(samples, axis = 0) 103 | av = np.expand_dims(av, axis = 0) 104 | 105 | w_space = [] 106 | for tensor, num_layers in style: 107 | tmp = S(tensor) 108 | av_torch = torch.from_numpy(av).cuda() 109 | # import ipdb;ipdb.set_trace() 110 | tmp = trunc_psi * (tmp - av_torch) + av_torch 111 | w_space.append((tmp, num_layers)) 112 | 113 | w_styles = styles_def_to_tensor(w_space) 114 | generated_images = evaluate_in_chunks_unet(batch_size, G, bitmap_feats, w_styles, noi) 115 | return generated_images.clamp_(0., 1.) 116 | 117 | 118 | @torch.no_grad() 119 | def generate_part(model, partial_image, partial_rgb, color=None, percentage=20, num=0, num_image_tiles=8, trunc_psi=1., save_img=False, results_dir='../results', evolvement=False): 120 | model.eval() 121 | ext = 'png' 122 | num_rows = np.sqrt(num_image_tiles) 123 | latent_dim = model.G.latent_dim 124 | image_size = model.G.image_size 125 | num_layers = model.G.num_layers 126 | if percentage == 'eye': 127 | n_eye = 10 128 | generated_partial_images_candidates = [] 129 | scores = torch.zeros(n_eye) 130 | for _ in range(n_eye): 131 | latents_z = noise_list(num_image_tiles, num_layers, latent_dim) 132 | n = image_noise(num_image_tiles, image_size) 133 | image_partial_batch = partial_image[:, -1:, :, :] 134 | bitmap_feats = model.Enc(partial_image) 135 | generated_partial_images = generate_truncated(model.S, model.G, latents_z, n, trunc_psi = trunc_psi, bitmap_feats=bitmap_feats) 136 | generated_partial_images_candidates.append(generated_partial_images) 137 | generated_partial_images_candidates = torch.cat(generated_partial_images_candidates, 0) 138 | # eye size rank 139 | n_pixels = generated_partial_images_candidates.sum(-1).sum(-1).sum(-1) # B 140 | for rank, i_eye in enumerate(torch.argsort(n_pixels, descending=True)): 141 | scores[i_eye] += (rank+1)/n_eye 142 | # eye distance rank 143 | initial_stroke = partial_image[:, :1].cpu().data.numpy() 144 | initial_stroke_dt = torch.cuda.FloatTensor(distance_transform_edt(1-initial_stroke)) 145 | dt_pixels = (generated_partial_images_candidates*initial_stroke_dt).sum(-1).sum(-1).sum(-1) # B 146 | for rank, i_eye in enumerate(torch.argsort(dt_pixels, descending=False)): # the smaller the better 147 | if n_pixels[i_eye] > 3: 148 | scores[i_eye] += (rank+1)/n_eye 149 | generated_partial_images 
= generated_partial_images_candidates[torch.argsort(scores, descending=True)[0]].unsqueeze(0) 150 | else: 151 | # latents and noise 152 | latents_z = noise_list(num_image_tiles, num_layers, latent_dim) 153 | n = image_noise(num_image_tiles, image_size) 154 | image_partial_batch = partial_image[:, -1:, :, :] 155 | bitmap_feats = model.Enc(partial_image) 156 | generated_partial_images = generate_truncated(model.S, model.G, latents_z, n, trunc_psi = trunc_psi, bitmap_feats=bitmap_feats) 157 | # regular 158 | generated_partial_images = generate_truncated(model.S, model.G, latents_z, n, trunc_psi = trunc_psi, bitmap_feats=bitmap_feats) 159 | generated_partial_rgb = gs_to_rgb(generated_partial_images, color) 160 | generated_images = generated_partial_images + image_partial_batch 161 | generated_rgb = 1 - ((1-generated_partial_rgb)+(1-partial_rgb)) 162 | if save_img: 163 | torchvision.utils.save_image(generated_partial_rgb, os.path.join(results_dir, f'{str(num)}-{percentage}-comp.{ext}'), nrow=num_rows) 164 | torchvision.utils.save_image(generated_rgb, os.path.join(results_dir, f'{str(num)}-{percentage}.{ext}'), nrow=num_rows) 165 | return generated_partial_images.clamp_(0., 1.), generated_images.clamp_(0., 1.), generated_partial_rgb.clamp_(0., 1.), generated_rgb.clamp_(0., 1.) 166 | 167 | 168 | def train_from_folder( 169 | data_path = '../../data', 170 | results_dir = '../../results', 171 | models_dir = '../../models', 172 | n_part = 1, 173 | image_size = 128, 174 | network_capacity = 16, 175 | batch_size = 3, 176 | num_image_tiles = 8, 177 | trunc_psi = 0.75, 178 | generate_all=False, 179 | ): 180 | min_step = 299 181 | name_eye='short_bird_creative_sequential_r6_partstack_aug_eye_unet2_largeaug' 182 | load_from = load_latest(models_dir, name_eye) 183 | load_from = min(min_step, load_from) 184 | model_eye = Trainer_cond_unet(name_eye, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 185 | model_eye.load_config() 186 | model_eye.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_eye, load_from))) 187 | 188 | name_head='short_bird_creative_sequential_r6_partstack_aug_head_unet2' 189 | load_from = load_latest(models_dir, name_head) 190 | load_from = min(min_step, load_from) 191 | model_head = Trainer_cond_unet(name_head, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 192 | model_head.load_config() 193 | model_head.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_head, load_from))) 194 | 195 | name_body='short_bird_creative_sequential_r6_partstack_aug_body_unet2' 196 | load_from = load_latest(models_dir, name_body) 197 | load_from = min(min_step, load_from) 198 | model_body = Trainer_cond_unet(name_body, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 199 | model_body.load_config() 200 | model_body.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_body, load_from))) 201 | 202 | name_beak='short_bird_creative_sequential_r6_partstack_aug_beak_unet2' 203 | load_from = load_latest(models_dir, name_beak) 204 | load_from = min(min_step, load_from) 205 | model_beak = Trainer_cond_unet(name_beak, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 206 | model_beak.load_config() 207 | model_beak.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_beak, 
load_from))) 208 | 209 | name_legs='short_bird_creative_sequential_r6_partstack_aug_legs_unet2' 210 | load_from = load_latest(models_dir, name_legs) 211 | load_from = min(min_step, load_from) 212 | model_legs = Trainer_cond_unet(name_legs, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 213 | model_legs.load_config() 214 | model_legs.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_legs, load_from))) 215 | 216 | name_wings='short_bird_creative_sequential_r6_partstack_aug_wings_unet2' 217 | load_from = load_latest(models_dir, name_wings) 218 | load_from = min(min_step, load_from) 219 | model_wings = Trainer_cond_unet(name_wings, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 220 | model_wings.load_config() 221 | model_wings.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_wings, load_from))) 222 | 223 | name_mouth='short_bird_creative_sequential_r6_partstack_aug_mouth_unet2' 224 | load_from = load_latest(models_dir, name_mouth) 225 | load_from = min(min_step, load_from) 226 | model_mouth = Trainer_cond_unet(name_mouth, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 227 | model_mouth.load_config() 228 | model_mouth.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_mouth, load_from))) 229 | 230 | name_tail='short_bird_creative_sequential_r6_partstack_aug_tail_unet2' 231 | load_from = load_latest(models_dir, name_tail) 232 | load_from = min(min_step, load_from) 233 | model_tail = Trainer_cond_unet(name_tail, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 234 | model_tail.load_config() 235 | model_tail.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_tail, load_from))) 236 | 237 | 238 | name_selector='short_bird_creative_selector_aug' 239 | load_from = load_latest(models_dir, name_selector) 240 | part_selector = Trainer_selector(name_selector, results_dir, models_dir, n_part=n_part, batch_size = batch_size, image_size = image_size, network_capacity = network_capacity) 241 | part_selector.load_config() 242 | part_selector.clf.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_selector, load_from))) 243 | 244 | if not os.path.exists(results_dir): 245 | os.mkdir(results_dir) 246 | inital_dir = '%s/bird_short_test_init_strokes_%d'%(data_path, image_size) 247 | dataset = Initialstroke_Dataset(inital_dir, image_size=image_size) 248 | dataloader = data.DataLoader(dataset, num_workers=5, batch_size=batch_size, drop_last=False, shuffle=False, pin_memory=True) 249 | # import ipdb;ipdb.set_trace() 250 | 251 | models = [model_eye, model_head, model_body, model_beak, model_legs, model_wings, model_mouth, model_tail] 252 | target_parts = ['eye', 'head', 'body', 'beak', 'legs', 'wings', 'mouth', 'tail', 'none'] 253 | part_to_id = {'initial': 0, 'eye': 1, 'head': 4, 'body': 3, 'beak': 2, 'legs': 5, 'wings': 8, 'mouth': 6, 'tail': 7} 254 | max_iter = 10 255 | if generate_all: 256 | generation_dir = os.path.join(results_dir, 'DoodlerGAN_all') 257 | if not os.path.exists(generation_dir): 258 | os.mkdir(generation_dir) 259 | os.mkdir(os.path.join(generation_dir, 'bw')) 260 | os.mkdir(os.path.join(generation_dir, 'color_initial')) 261 | os.mkdir(os.path.join(generation_dir, 'color')) 262 | for count, initial_strokes in 
enumerate(dataloader): 263 | initial_strokes = initial_strokes.cuda() 264 | start_point = len(os.listdir(os.path.join(generation_dir, 'bw'))) 265 | print('%d sketches generated'%start_point) 266 | for i in range(batch_size): 267 | samples_name = f'generated-{start_point+i}' 268 | stack_parts = torch.zeros(1, 10, image_size, image_size).cuda() 269 | initial_strokes_rgb = gs_to_rgb(initial_strokes[i], COLORS['initial']) 270 | stack_parts[:, 0] = initial_strokes[i, 0] 271 | stack_parts[:, -1] = initial_strokes[i, 0] 272 | partial_rgbs = initial_strokes_rgb.clone() 273 | prev_part = [] 274 | for iter_i in range(max_iter): 275 | outputs = part_selector.clf.D(stack_parts) 276 | part_rgbs = torch.ones(1, 3, image_size, image_size).cuda() 277 | select_part_order = 0 278 | select_part_ids = torch.topk(outputs, k=8, dim=0)[1] 279 | select_part_id = select_part_ids[select_part_order].item() 280 | select_part = target_parts[select_part_id] 281 | while (select_part == 'none' and iter_i < 6 or select_part in prev_part): 282 | select_part_order += 1 283 | if select_part_order > 7: 284 | import ipdb;ipdb.set_trace() 285 | select_part_id = select_part_ids[select_part_order].item() 286 | select_part = target_parts[select_part_id] 287 | if select_part == 'none': 288 | break 289 | prev_part.append(select_part) 290 | sketch_rgb = partial_rgbs 291 | stack_part = stack_parts.clone() 292 | select_model = models[select_part_id] 293 | part, partial, part_rgb, partial_rgb = generate_part(select_model.GAN, stack_part, sketch_rgb, COLORS[select_part], select_part, samples_name, 1, results_dir=results_dir, trunc_psi=0.1) 294 | stack_parts[0, part_to_id[select_part]] = part[0, 0] 295 | partial_rgbs[0] = partial_rgb[0] 296 | stack_parts[0, -1] = partial[0, 0] 297 | part_rgbs[0] = part_rgb[0] 298 | initial_colored_full = np.tile(np.max(stack_parts.cpu().data.numpy()[:, 1:-1], 1), [3, 1, 1]) 299 | initial_colored_full = 1-np.max(np.stack([1-initial_strokes_rgb.cpu().data.numpy()[0], initial_colored_full]), 0) 300 | cv2.imwrite(os.path.join(generation_dir, 'bw', f'{str(samples_name)}.png'), (1-stack_parts[0, -1].cpu().data.numpy())*255.) 
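                    # Three renderings are written per generated sketch:
                    #   bw            - the inverted running composite channel stack_parts[:, -1]:
                    #                   the full drawing (initial strokes plus all generated parts)
                    #                   as black strokes on white.
                    #   color_initial - the initial strokes kept in COLORS['initial'] with every
                    #                   generated part drawn in black (built via initial_colored_full).
                    #   color         - partial_rgbs, where each generated part keeps its own part color.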
301 | cv2.imwrite(os.path.join(generation_dir, 'color_initial', f'{str(samples_name)}-color.png'), cv2.cvtColor(initial_colored_full.transpose(1, 2, 0)*255., cv2.COLOR_RGB2BGR)) 302 | cv2.imwrite(os.path.join(generation_dir, 'color', f'{str(samples_name)}-color.png'), cv2.cvtColor(partial_rgbs[0].cpu().data.numpy().transpose(1, 2, 0)*255., cv2.COLOR_RGB2BGR)) 303 | else: 304 | now = datetime.now() 305 | timestamp = now.strftime("%m-%d-%Y_%H-%M-%S") 306 | stack_parts = torch.zeros(num_image_tiles*num_image_tiles, 10, image_size, image_size).cuda() 307 | initial_strokes = dataset.sample(num_image_tiles*num_image_tiles).cuda() 308 | initial_strokes_rgb = gs_to_rgb(initial_strokes, COLORS['initial']) 309 | stack_parts[:, 0] = initial_strokes[:, 0] 310 | stack_parts[:, -1] = initial_strokes[:, 0] 311 | partial_rgbs = initial_strokes_rgb.clone() 312 | partial_rgbs_variation = initial_strokes_rgb.clone() 313 | prev_parts = [[] for _ in range(num_image_tiles**2)] 314 | samples_name = f'generated-{timestamp}-{min_step}' 315 | for iter_i in range(max_iter): 316 | outputs = part_selector.clf.D(stack_parts) 317 | part_rgbs = torch.ones(num_image_tiles*num_image_tiles, 3, image_size, image_size).cuda() 318 | for i in range(num_image_tiles**2): 319 | prev_part = prev_parts[i] 320 | select_part_order = 0 321 | select_part_ids = torch.topk(outputs[i], k=9, dim=0)[1] 322 | select_part_id = select_part_ids[select_part_order].item() 323 | select_part = target_parts[select_part_id] 324 | while (select_part == 'none' and iter_i < 6 or select_part in prev_part): 325 | select_part_order += 1 326 | select_part_id = select_part_ids[select_part_order].item() 327 | select_part = target_parts[select_part_id] 328 | if select_part == 'none': 329 | break 330 | prev_parts[i].append(select_part) 331 | sketch_rgb = partial_rgbs[i].clone().unsqueeze(0) 332 | stack_part = stack_parts[i].unsqueeze(0) 333 | select_model = models[select_part_id] 334 | part, partial, part_rgb, partial_rgb = generate_part(select_model.GAN, stack_part, sketch_rgb, COLORS[select_part], select_part, samples_name, 1, results_dir=results_dir, trunc_psi=0.1) 335 | stack_parts[i, part_to_id[select_part]] = part[0, 0] 336 | stack_parts[i, -1] = partial[0, 0] 337 | partial_rgbs[i] = partial_rgb[0] 338 | part_rgbs[i] = part_rgb[0] 339 | torchvision.utils.save_image(partial_rgbs, os.path.join(results_dir, f'{str(samples_name)}-{str(min_step)}-round{iter_i}.png'), nrow=num_image_tiles) 340 | torchvision.utils.save_image(part_rgbs, os.path.join(results_dir, f'{str(samples_name)}-{str(min_step)}-part-round{iter_i}.png'), nrow=num_image_tiles) 341 | torchvision.utils.save_image(1-stack_parts[:, -1:], os.path.join(results_dir, f'{str(samples_name)}-{str(min_step)}-final_pred.png'), nrow=num_image_tiles) 342 | 343 | if __name__ == "__main__": 344 | parser = argparse.ArgumentParser() 345 | parser.add_argument("--data_dir", type=str, default='../data') 346 | parser.add_argument("--results_dir", type=str, default='../results/creative_bird_generation') 347 | parser.add_argument("--models_dir", type=str, default='../models') 348 | parser.add_argument('--n_part', type=int, default=10) 349 | parser.add_argument('--image_size', type=int, default=64) 350 | parser.add_argument('--network_capacity', type=int, default=16) 351 | parser.add_argument('--batch_size', type=int, default=100) 352 | parser.add_argument('--num_image_tiles', type=int, default=8) 353 | parser.add_argument('--trunc_psi', type=float, default=1.) 
354 | parser.add_argument('--generate_all', action='store_true') 355 | 356 | args = parser.parse_args() 357 | print(args) 358 | 359 | train_from_folder(args.data_dir, args.results_dir, args.models_dir, args.n_part, args.image_size, args.network_capacity, 360 | args.batch_size, args.num_image_tiles, args.trunc_psi, args.generate_all) -------------------------------------------------------------------------------- /generate_creative_creatures.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import cv2 8 | import torch 9 | import numpy as np 10 | import argparse 11 | import torchvision 12 | from PIL import Image 13 | from tqdm import tqdm 14 | from pathlib import Path 15 | from datetime import datetime 16 | from retry.api import retry_call 17 | from torch.utils import data 18 | from torchvision import transforms 19 | from part_selector import Trainer as Trainer_selector 20 | from part_generator import Trainer as Trainer_cond_unet 21 | from scipy.ndimage.morphology import distance_transform_edt 22 | 23 | COLORS = {'initial':1-torch.cuda.FloatTensor([45, 169, 145]).view(1, -1, 1, 1)/255., 'eye':1-torch.cuda.FloatTensor([243, 156, 18]).view(1, -1, 1, 1)/255., 'none':1-torch.cuda.FloatTensor([149, 165, 166]).view(1, -1, 1, 1)/255., 24 | 'arms':1-torch.cuda.FloatTensor([211, 84, 0]).view(1, -1, 1, 1)/255., 'beak':1-torch.cuda.FloatTensor([41, 128, 185]).view(1, -1, 1, 1)/255., 'mouth':1-torch.cuda.FloatTensor([54, 153, 219]).view(1, -1, 1, 1)/255., 25 | 'body':1-torch.cuda.FloatTensor([192, 57, 43]).view(1, -1, 1, 1)/255., 'ears':1-torch.cuda.FloatTensor([142, 68, 173]).view(1, -1, 1, 1)/255., 'feet':1-torch.cuda.FloatTensor([39, 174, 96]).view(1, -1, 1, 1)/255., 26 | 'fin':1-torch.cuda.FloatTensor([69, 85, 101]).view(1, -1, 1, 1)/255., 'hair':1-torch.cuda.FloatTensor([127, 140, 141]).view(1, -1, 1, 1)/255., 'hands':1-torch.cuda.FloatTensor([45, 63, 81]).view(1, -1, 1, 1)/255., 27 | 'head':1-torch.cuda.FloatTensor([241, 197, 17]).view(1, -1, 1, 1)/255., 'horns':1-torch.cuda.FloatTensor([51, 205, 117]).view(1, -1, 1, 1)/255., 'legs':1-torch.cuda.FloatTensor([232, 135, 50]).view(1, -1, 1, 1)/255., 28 | 'nose':1-torch.cuda.FloatTensor([233, 90, 75]).view(1, -1, 1, 1)/255., 'paws':1-torch.cuda.FloatTensor([160, 98, 186]).view(1, -1, 1, 1)/255., 'tail':1-torch.cuda.FloatTensor([58, 78, 99]).view(1, -1, 1, 1)/255., 29 | 'wings':1-torch.cuda.FloatTensor([198, 203, 207]).view(1, -1, 1, 1)/255., 'details':1-torch.cuda.FloatTensor([171, 190, 191]).view(1, -1, 1, 1)/255.} 30 | 31 | 32 | class Initialstroke_Dataset(data.Dataset): 33 | def __init__(self, folder, image_size): 34 | super().__init__() 35 | self.folder = folder 36 | self.image_size = image_size 37 | self.paths = [p for p in Path(f'{folder}').glob(f'**/*.png')] 38 | self.transform = transforms.Compose([ 39 | transforms.ToTensor(), 40 | ]) 41 | 42 | def __len__(self): 43 | return len(self.paths) 44 | 45 | def __getitem__(self, index): 46 | path = self.paths[index] 47 | img = self.transform(Image.open(path)) 48 | return img 49 | 50 | def sample(self, n): 51 | sample_ids = [np.random.randint(self.__len__()) for _ in range(n)] 52 | samples = [self.transform(Image.open(self.paths[sample_id])) for sample_id in sample_ids] 53 | return torch.stack(samples).cuda() 54 | 55 | 56 | def load_latest(model_dir, name): 
57 | model_dir = Path(model_dir) 58 | file_paths = [p for p in Path(model_dir / name).glob('model_*.pt')] 59 | saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths)) 60 | if len(saved_nums) == 0: 61 | return 62 | num = saved_nums[-1] 63 | print(f'continuing -{name} from previous epoch - {num}') 64 | return num 65 | 66 | 67 | def noise(n, latent_dim): 68 | return torch.randn(n, latent_dim).cuda() 69 | 70 | def noise_list(n, layers, latent_dim): 71 | return [(noise(n, latent_dim), layers)] 72 | 73 | def mixed_list(n, layers, latent_dim): 74 | tt = int(torch.rand(()).numpy() * layers) 75 | return noise_list(n, tt, latent_dim) + noise_list(n, layers - tt, latent_dim) 76 | 77 | def image_noise(n, im_size): 78 | return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda() 79 | 80 | def evaluate_in_chunks(max_batch_size, model, *args): 81 | split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args)))) 82 | chunked_outputs = [model(*i) for i in split_args] 83 | if len(chunked_outputs) == 1: 84 | return chunked_outputs[0] 85 | return torch.cat(chunked_outputs, dim=0) 86 | 87 | def evaluate_in_chunks_unet(max_batch_size, model, map_feats, *args): 88 | split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args)))) 89 | split_map_feats = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), map_feats)))) 90 | chunked_outputs = [model(*i, j) for i, j in zip(split_args, split_map_feats)] 91 | if len(chunked_outputs) == 1: 92 | return chunked_outputs[0] 93 | return torch.cat(chunked_outputs, dim=0) 94 | 95 | def styles_def_to_tensor(styles_def): 96 | return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1) 97 | 98 | def gs_to_rgb(image, color): 99 | image_rgb = image.repeat(1, 3, 1, 1) 100 | return 1-image_rgb*color 101 | 102 | @torch.no_grad() 103 | def generate_truncated(S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8, bitmap_feats=None, batch_size=8): 104 | latent_dim = G.latent_dim 105 | z = noise(2000, latent_dim) 106 | samples = evaluate_in_chunks(batch_size, S, z).cpu().numpy() 107 | av = np.mean(samples, axis = 0) 108 | av = np.expand_dims(av, axis = 0) 109 | 110 | w_space = [] 111 | for tensor, num_layers in style: 112 | tmp = S(tensor) 113 | av_torch = torch.from_numpy(av).cuda() 114 | # import ipdb;ipdb.set_trace() 115 | tmp = trunc_psi * (tmp - av_torch) + av_torch 116 | w_space.append((tmp, num_layers)) 117 | 118 | w_styles = styles_def_to_tensor(w_space) 119 | generated_images = evaluate_in_chunks_unet(batch_size, G, bitmap_feats, w_styles, noi) 120 | return generated_images.clamp_(0., 1.) 
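# A compact restatement of the truncation trick used in generate_truncated() above: each mapped
# latent w = S(z) is pulled toward the mean latent by a factor trunc_psi before being expanded
# into per-layer styles. The helper below is an illustrative sketch only; the name
# `truncate_styles` (and passing the precomputed mean in directly) is not part of this repository.
def truncate_styles(S, style, w_mean, trunc_psi=0.75):
    # style: list of (z, num_layers) pairs; w_mean: average of S(z) over many random z
    w_space = []
    for z, num_layers in style:
        w = S(z)
        w = trunc_psi * (w - w_mean) + w_mean  # interpolate toward the mean style
        w_space.append((w, num_layers))
    return styles_def_to_tensor(w_space)  # expand to (batch, total_layers, latent_dim)
# With trunc_psi = 1.0 the latents are untouched; smaller values trade diversity for fidelity
# by sampling styles closer to the average.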
121 | 122 | @torch.no_grad() 123 | def generate_part(model, partial_image, partial_rgb, color=None, part_name=20, num=0, num_image_tiles=8, trunc_psi=1., save_img=False, trans_std=2, results_dir='../results/bird_seq_unet_5fold'): 124 | model.eval() 125 | ext = 'png' 126 | num_rows = num_image_tiles 127 | latent_dim = model.G.latent_dim 128 | image_size = model.G.image_size 129 | num_layers = model.G.num_layers 130 | def translate_image(image, trans_std=2, rot_std=3, scale_std=2): 131 | affine_image = torch.zeros_like(image) 132 | side = image.shape[-1] 133 | x_shift = np.random.normal(0, trans_std) 134 | y_shift = np.random.normal(0, trans_std) 135 | theta = np.random.normal(0, rot_std) 136 | scale = int(np.random.normal(0, scale_std)) 137 | T = np.float32([[1, 0, x_shift], [0, 1, y_shift]]) 138 | M = cv2.getRotationMatrix2D((side/2,side/2),theta,1) 139 | for i in range(image.shape[1]): 140 | sketch_channel = image[0, i].cpu().data.numpy() 141 | sketch_translation = cv2.warpAffine(sketch_channel, T, (side, side)) 142 | affine_image[0, i] = torch.cuda.FloatTensor(sketch_translation) 143 | return affine_image, x_shift, y_shift, theta, scale 144 | def recover_image(image, x_shift, y_shift, theta, scale): 145 | x_shift *= -1 146 | y_shift *= -1 147 | theta *= -1 148 | # scale *= -1 149 | affine_image = torch.zeros_like(image) 150 | side = image.shape[-1] 151 | T = np.float32([[1, 0, x_shift], [0, 1, y_shift]]) 152 | M = cv2.getRotationMatrix2D((side/2,side/2),theta,1) 153 | for i in range(image.shape[1]): 154 | sketch_channel = image[0, i].cpu().data.numpy() 155 | sketch_translation = cv2.warpAffine(sketch_channel, T, (side, side)) 156 | affine_image[0, i] = torch.cuda.FloatTensor(sketch_translation) 157 | return affine_image 158 | 159 | # latents and noise 160 | latents_z = noise_list(num_rows ** 2, num_layers, latent_dim) 161 | n = image_noise(num_rows ** 2, image_size) 162 | image_partial_batch = partial_image[:, -1:, :, :] 163 | translated_image, dx, dy, theta, scale = translate_image(partial_image, trans_std=trans_std) 164 | bitmap_feats = model.Enc(translated_image) 165 | # bitmap_feats = model.Enc(partial_image) 166 | # generated_partial_images = generate_truncated(model.S, model.G, latents_z, n, trunc_psi = trunc_psi, bitmap_feats=bitmap_feats) 167 | generated_partial_images = recover_image(generate_truncated(model.S, model.G, latents_z, n, trunc_psi = trunc_psi, bitmap_feats=bitmap_feats), dx, dy, theta, scale) 168 | # post process 169 | generated_partial_rgb = gs_to_rgb(generated_partial_images, color) 170 | generated_images = generated_partial_images + image_partial_batch 171 | generated_rgb = 1 - ((1-generated_partial_rgb)+(1-partial_rgb)) 172 | if save_img: 173 | torchvision.utils.save_image(generated_partial_rgb, os.path.join(results_dir, f'{str(num)}-{part_name}-comp.{ext}'), nrow=num_rows) 174 | torchvision.utils.save_image(generated_rgb, os.path.join(results_dir, f'{str(num)}-{part_name}.{ext}'), nrow=num_rows) 175 | return generated_partial_images.clamp_(0., 1.), generated_images.clamp_(0., 1.), generated_partial_rgb.clamp_(0., 1.), generated_rgb.clamp_(0., 1.) 
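# Note on generate_part() above: translate_image() samples a shift, rotation and scale, but only
# the translation matrix T is ever passed to cv2.warpAffine (the rotation matrix M is built and
# left unused), so the conditioning sketch is effectively jittered by a random Gaussian shift
# (std = trans_std pixels) before encoding, and recover_image() shifts the generated part back by
# the negated offsets. A self-contained round-trip sketch of that idea, with hypothetical toy
# values rather than repository code:
#
#   import cv2
#   import numpy as np
#   side, dx, dy = 64, 3.0, -2.0
#   T_fwd = np.float32([[1, 0,  dx], [0, 1,  dy]])
#   T_bwd = np.float32([[1, 0, -dx], [0, 1, -dy]])
#   img = np.zeros((side, side), np.float32); img[30:34, 30:34] = 1.
#   back = cv2.warpAffine(cv2.warpAffine(img, T_fwd, (side, side)), T_bwd, (side, side))
#   # `back` matches `img` up to border clipping / interpolation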
176 | 177 | 178 | def train_from_folder( 179 | data_path = '../../data', 180 | results_dir = '../../results', 181 | models_dir = '../../models', 182 | n_part = 1, 183 | image_size = 128, 184 | network_capacity = 16, 185 | batch_size = 3, 186 | num_image_tiles = 8, 187 | trunc_psi = 0.75, 188 | generate_all=False, 189 | ): 190 | min_step = 599 191 | name_eye='long_generic_creative_sequential_r6_partstack_aug_eye_unet_largeaug' 192 | load_from = load_latest(models_dir, name_eye) 193 | load_from = min(min_step, load_from) 194 | model_eye = Trainer_cond_unet(name_eye, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 195 | model_eye.load_config() 196 | model_eye.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_eye, load_from))) 197 | 198 | 199 | name_head='long_generic_creative_sequential_r6_partstack_aug_head_unet_largeaug' 200 | load_from = load_latest(models_dir, name_head) 201 | load_from = min(min_step, load_from) 202 | model_head = Trainer_cond_unet(name_head, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 203 | model_head.load_config() 204 | model_head.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_head, load_from))) 205 | 206 | 207 | name_body='long_generic_creative_sequential_r6_partstack_aug_body_unet_largeaug' 208 | load_from = load_latest(models_dir, name_body) 209 | load_from = min(min_step, load_from) 210 | model_body = Trainer_cond_unet(name_body, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 211 | model_body.load_config() 212 | model_body.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_body, load_from))) 213 | 214 | 215 | name_beak='long_generic_creative_sequential_r6_partstack_aug_beak_unet_largeaug' 216 | load_from = load_latest(models_dir, name_beak) 217 | load_from = min(min_step, load_from) 218 | model_beak = Trainer_cond_unet(name_beak, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 219 | model_beak.load_config() 220 | model_beak.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_beak, load_from))) 221 | 222 | 223 | name_ears='long_generic_creative_sequential_r6_partstack_aug_ears_unet_largeaug' 224 | load_from = load_latest(models_dir, name_ears) 225 | load_from = min(min_step, load_from) 226 | model_ears = Trainer_cond_unet(name_ears, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 227 | model_ears.load_config() 228 | model_ears.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_ears, load_from))) 229 | 230 | 231 | name_hands='long_generic_creative_sequential_r6_partstack_aug_hands_unet_largeaug' 232 | load_from = load_latest(models_dir, name_hands) 233 | load_from = min(min_step, load_from) 234 | model_hands = Trainer_cond_unet(name_hands, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 235 | model_hands.load_config() 236 | model_hands.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_hands, load_from))) 237 | 238 | 239 | name_legs='long_generic_creative_sequential_r6_partstack_aug_legs_unet_largeaug' 240 | load_from = load_latest(models_dir, name_legs) 241 | load_from = min(min_step, load_from) 242 | model_legs = 
Trainer_cond_unet(name_legs, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 243 | model_legs.load_config() 244 | model_legs.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_legs, load_from))) 245 | 246 | 247 | name_feet='long_generic_creative_sequential_r6_partstack_aug_feet_unet_largeaug' 248 | load_from = load_latest(models_dir, name_feet) 249 | load_from = min(min_step, load_from) 250 | model_feet = Trainer_cond_unet(name_feet, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 251 | model_feet.load_config() 252 | model_feet.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_feet, load_from))) 253 | 254 | 255 | name_wings='long_generic_creative_sequential_r6_partstack_aug_wings_unet_largeaug' 256 | load_from = load_latest(models_dir, name_wings) 257 | load_from = min(min_step, load_from) 258 | model_wings = Trainer_cond_unet(name_wings, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 259 | model_wings.load_config() 260 | 261 | model_wings.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_wings, load_from))) 262 | 263 | 264 | name_mouth='long_generic_creative_sequential_r6_partstack_aug_mouth_unet_largeaug' 265 | load_from = load_latest(models_dir, name_mouth) 266 | load_from = min(min_step, load_from) 267 | model_mouth = Trainer_cond_unet(name_mouth, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 268 | model_mouth.load_config() 269 | model_mouth.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_mouth, load_from))) 270 | 271 | 272 | name_nose='long_generic_creative_sequential_r6_partstack_aug_nose_unet_largeaug' 273 | load_from = load_latest(models_dir, name_nose) 274 | load_from = min(min_step, load_from) 275 | model_nose = Trainer_cond_unet(name_nose, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 276 | model_nose.load_config() 277 | model_nose.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_nose, load_from))) 278 | 279 | 280 | name_hair='long_generic_creative_sequential_r6_partstack_aug_hair_unet_largeaug' 281 | load_from = load_latest(models_dir, name_hair) 282 | load_from = min(min_step, load_from) 283 | model_hair = Trainer_cond_unet(name_hair, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 284 | model_hair.load_config() 285 | model_hair.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_hair, load_from))) 286 | 287 | 288 | name_tail='long_generic_creative_sequential_r6_partstack_aug_tail_unet_largeaug' 289 | load_from = load_latest(models_dir, name_tail) 290 | load_from = min(min_step, load_from) 291 | model_tail = Trainer_cond_unet(name_tail, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 292 | model_tail.load_config() 293 | model_tail.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_tail, load_from))) 294 | 295 | 296 | name_fin='long_generic_creative_sequential_r6_partstack_aug_fin_unet_largeaug' 297 | load_from = load_latest(models_dir, name_fin) 298 | load_from = min(min_step, load_from) 299 | model_fin = Trainer_cond_unet(name_fin, 
results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 300 | model_fin.load_config() 301 | model_fin.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_fin, load_from))) 302 | 303 | 304 | name_horns='long_generic_creative_sequential_r6_partstack_aug_horns_unet_largeaug' 305 | load_from = load_latest(models_dir, name_horns) 306 | load_from = min(min_step, load_from) 307 | model_horns = Trainer_cond_unet(name_horns, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 308 | model_horns.load_config() 309 | model_horns.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_horns, load_from))) 310 | 311 | 312 | name_paws='long_generic_creative_sequential_r6_partstack_aug_paws_unet_largeaug' 313 | load_from = load_latest(models_dir, name_paws) 314 | load_from = min(min_step, load_from) 315 | model_paws = Trainer_cond_unet(name_paws, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 316 | model_paws.load_config() 317 | model_paws.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_paws, load_from))) 318 | 319 | 320 | name_arms='long_generic_creative_sequential_r6_partstack_aug_arms_unet_largeaug' 321 | load_from = load_latest(models_dir, name_arms) 322 | load_from = min(min_step, load_from) 323 | model_arms = Trainer_cond_unet(name_arms, results_dir, models_dir, n_part=n_part, batch_size=batch_size, image_size=image_size, network_capacity=network_capacity) 324 | model_arms.load_config() 325 | model_arms.GAN.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_arms, load_from))) 326 | 327 | 328 | name_selector='long_generic_creative_selector_aug' 329 | 330 | load_from = load_latest(models_dir, name_selector) 331 | part_selector = Trainer_selector(name_selector, results_dir, models_dir, n_part = n_part, batch_size = batch_size, image_size = image_size, network_capacity=network_capacity) 332 | part_selector.load_config() 333 | part_selector.clf.load_state_dict(torch.load('%s/%s/model_%d.pt'%(models_dir, name_selector, load_from))) 334 | 335 | 336 | inital_dir = '%s/generic_long_test_init_strokes_%d'%(data_path, image_size) 337 | if not os.path.exists(results_dir): 338 | os.mkdir(results_dir) 339 | dataset = Initialstroke_Dataset(inital_dir, image_size=image_size) 340 | dataloader = data.DataLoader(dataset, num_workers=5, batch_size=batch_size, drop_last=False, shuffle=False, pin_memory=True) 341 | 342 | models = [model_eye, model_arms, model_beak, model_mouth, model_body, model_ears, model_feet, model_fin, model_hair, 343 | model_hands, model_head, model_horns, model_legs, model_nose, model_paws, model_tail, model_wings] 344 | target_parts = ['eye', 'arms', 'beak', 'mouth', 'body', 'ears', 'feet', 'fin', 345 | 'hair', 'hands', 'head', 'horns', 'legs', 'nose', 'paws', 'tail', 'wings', 'none'] 346 | part_to_id = {'initial': 0, 'eye': 1, 'arms': 2, 'beak': 3, 'mouth': 4, 'body': 5, 'ears': 6, 'feet': 7, 'fin': 8, 347 | 'hair': 9, 'hands': 10, 'head': 11, 'horns': 12, 'legs': 13, 'nose': 14, 'paws': 15, 'tail': 16, 'wings':17} 348 | max_iter = 10 349 | 350 | if generate_all: 351 | generation_dir = os.path.join(results_dir, 'DoodlerGAN_all') 352 | if not os.path.exists(generation_dir): 353 | os.mkdir(generation_dir) 354 | os.mkdir(os.path.join(generation_dir, 'bw')) 355 | os.mkdir(os.path.join(generation_dir, 'color_initial')) 356 | 
os.mkdir(os.path.join(generation_dir, 'color')) 357 | for count, initial_strokes in enumerate(dataloader): 358 | initial_strokes = initial_strokes.cuda() 359 | start_point = len(os.listdir(os.path.join(generation_dir, 'bw'))) 360 | print('%d sketches generated'%start_point) 361 | for i in range(batch_size): 362 | samples_name = f'generated-{start_point+i}' 363 | stack_parts = torch.zeros(1, 19, image_size, image_size).cuda() 364 | initial_strokes_rgb = gs_to_rgb(initial_strokes[i], COLORS['initial']) 365 | stack_parts[:, 0] = initial_strokes[i, 0] 366 | stack_parts[:, -1] = initial_strokes[i, 0] 367 | partial_rgbs = initial_strokes_rgb.clone() 368 | prev_part = [] 369 | for iter_i in range(max_iter): 370 | outputs = part_selector.clf.D(stack_parts) 371 | part_rgbs = torch.ones(1, 3, image_size, image_size).cuda() 372 | select_part_order = 0 373 | select_part_ids = torch.topk(outputs, k=10, dim=0)[1] 374 | select_part_id = select_part_ids[select_part_order].item() 375 | select_part = target_parts[select_part_id] 376 | while (select_part == 'none' and iter_i < 6 or select_part in prev_part): 377 | select_part_order += 1 378 | select_part_id = select_part_ids[select_part_order].item() 379 | select_part = target_parts[select_part_id] 380 | if select_part == 'none': 381 | break 382 | prev_part.append(select_part) 383 | sketch_rgb = partial_rgbs 384 | stack_part = stack_parts[0].unsqueeze(0) 385 | select_model = models[select_part_id] 386 | part, partial, part_rgb, partial_rgb = generate_part(select_model.GAN, stack_part, sketch_rgb, COLORS[select_part], select_part, samples_name, 1, trans_std=0, results_dir=results_dir) 387 | stack_parts[0, part_to_id[select_part]] = part[0, 0] 388 | stack_parts[0, -1] = partial[0, 0] 389 | partial_rgbs[0] = partial_rgb[0] 390 | part_rgbs[0] = part_rgb[0] 391 | initial_colored_full = np.tile(np.max(stack_parts.cpu().data.numpy()[:, 1:-1], 1), [3, 1, 1]) 392 | initial_colored_full = 1-np.max(np.stack([1-initial_strokes_rgb.cpu().data.numpy()[0], initial_colored_full]), 0) 393 | cv2.imwrite(os.path.join(generation_dir, 'bw', f'{str(samples_name)}.png'), (1-stack_parts[0, -1].cpu().data.numpy())*255.) 
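                    # stack_parts holds one binary channel per label: channel 0 is the initial
                    # strokes, channels 1-17 hold the individual generated parts (indexed via
                    # part_to_id), and the last channel is the running composite of everything
                    # drawn so far, which is what the 'bw' image above is rendered from.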
394 | cv2.imwrite(os.path.join(generation_dir, 'color_initial', f'{str(samples_name)}-color.png'), cv2.cvtColor(initial_colored_full.transpose(1, 2, 0)*255., cv2.COLOR_RGB2BGR)) 395 | cv2.imwrite(os.path.join(generation_dir, 'color', f'{str(samples_name)}-color.png'), cv2.cvtColor(partial_rgbs[0].cpu().data.numpy().transpose(1, 2, 0)*255., cv2.COLOR_RGB2BGR)) 396 | else: 397 | now = datetime.now() 398 | timestamp = now.strftime("%m-%d-%Y_%H-%M-%S") 399 | stack_parts = torch.zeros(num_image_tiles*num_image_tiles, 19, image_size, image_size).cuda() 400 | initial_strokes = dataset.sample(num_image_tiles*num_image_tiles).cuda() 401 | initial_strokes_rgb = gs_to_rgb(initial_strokes, COLORS['initial']) 402 | stack_parts[:, 0] = initial_strokes[:, 0] 403 | stack_parts[:, -1] = initial_strokes[:, 0] 404 | partial_rgbs = initial_strokes_rgb.clone() 405 | prev_parts = [[] for _ in range(num_image_tiles**2)] 406 | samples_name = f'generated-{timestamp}-{min_step}' 407 | for iter_i in range(max_iter): 408 | outputs = part_selector.clf.D(stack_parts) 409 | part_rgbs = torch.ones(num_image_tiles*num_image_tiles, 3, image_size, image_size).cuda() 410 | for i in range(num_image_tiles**2): 411 | prev_part = prev_parts[i] 412 | select_part_order = 0 413 | select_part_ids = torch.topk(outputs[i], k=16, dim=0)[1] 414 | select_part_id = select_part_ids[select_part_order].item() 415 | select_part = target_parts[select_part_id] 416 | while (select_part == 'none' and iter_i < 6 or select_part in prev_part): 417 | select_part_order += 1 418 | select_part_id = select_part_ids[select_part_order].item() 419 | select_part = target_parts[select_part_id] 420 | if select_part == 'none': 421 | continue 422 | prev_parts[i].append(select_part) 423 | sketch_rgb = partial_rgbs[i].clone().unsqueeze(0) 424 | stack_part = stack_parts[i].unsqueeze(0) 425 | select_model = models[select_part_id] 426 | part, partial, part_rgb, partial_rgb = generate_part(select_model.GAN, stack_part, sketch_rgb, COLORS[select_part], select_part, samples_name, 1, trans_std=2, results_dir=results_dir) 427 | stack_parts[i, part_to_id[select_part]] = part[0, 0] 428 | stack_parts[i, -1] = partial[0, 0] 429 | partial_rgbs[i] = partial_rgb[0] 430 | part_rgbs[i] = part_rgb[0] 431 | torchvision.utils.save_image(partial_rgbs, os.path.join(results_dir, f'{str(samples_name)}-round{iter_i}.png'), nrow=num_image_tiles) 432 | torchvision.utils.save_image(part_rgbs, os.path.join(results_dir, f'{str(samples_name)}-part-round{iter_i}.png'), nrow=num_image_tiles) 433 | torchvision.utils.save_image(1-stack_parts[:, -1:], os.path.join(results_dir, f'{str(samples_name)}-final_pred.png'), nrow=num_image_tiles) 434 | 435 | if __name__ == "__main__": 436 | parser = argparse.ArgumentParser() 437 | parser.add_argument("--data_dir", type=str, default='../data') 438 | parser.add_argument("--results_dir", type=str, default='../results/creative_creature_generation') 439 | parser.add_argument("--models_dir", type=str, default='../models') 440 | parser.add_argument('--n_part', type=int, default=19) 441 | parser.add_argument('--image_size', type=int, default=64) 442 | parser.add_argument('--network_capacity', type=int, default=16) 443 | parser.add_argument('--batch_size', type=int, default=100) 444 | parser.add_argument('--num_image_tiles', type=int, default=8) 445 | parser.add_argument('--trunc_psi', type=float, default=1.) 
446 | parser.add_argument('--generate_all', action='store_true') 447 | 448 | args = parser.parse_args() 449 | print(args) 450 | 451 | train_from_folder(args.data_dir, args.results_dir, args.models_dir, args.n_part, args.image_size, args.network_capacity, 452 | args.batch_size, args.num_image_tiles, args.trunc_psi, args.generate_all) 453 | -------------------------------------------------------------------------------- /inception.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | 11 | try: 12 | from torchvision.models.utils import load_state_dict_from_url 13 | except ImportError: 14 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 15 | 16 | # Inception weights ported to Pytorch from 17 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 18 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 19 | 20 | 21 | class InceptionV3(nn.Module): 22 | """Pretrained InceptionV3 network returning feature maps""" 23 | 24 | # Index of default block of inception to return, 25 | # corresponds to output of final average pooling 26 | DEFAULT_BLOCK_INDEX = 3 27 | 28 | # Maps feature dimensionality to their output blocks indices 29 | BLOCK_INDEX_BY_DIM = { 30 | 64: 0, # First max pooling features 31 | 192: 1, # Second max pooling featurs 32 | 768: 2, # Pre-aux classifier features 33 | 2048: 3 # Final average pooling features 34 | } 35 | 36 | def __init__(self, 37 | output_blocks=[DEFAULT_BLOCK_INDEX], 38 | resize_input=True, 39 | normalize_input=True, 40 | requires_grad=False, 41 | use_fid_inception=True): 42 | """Build pretrained InceptionV3 43 | 44 | Parameters 45 | ---------- 46 | output_blocks : list of int 47 | Indices of blocks to return features of. Possible values are: 48 | - 0: corresponds to output of first max pooling 49 | - 1: corresponds to output of second max pooling 50 | - 2: corresponds to output which is fed to aux classifier 51 | - 3: corresponds to output of final average pooling 52 | resize_input : bool 53 | If true, bilinearly resizes input to width and height 299 before 54 | feeding input to model. As the network without fully connected 55 | layers is fully convolutional, it should be able to handle inputs 56 | of arbitrary size, so resizing might not be strictly needed 57 | normalize_input : bool 58 | If true, scales the input from range (0, 1) to the range the 59 | pretrained Inception network expects, namely (-1, 1) 60 | requires_grad : bool 61 | If true, parameters of the model require gradients. Possibly useful 62 | for finetuning the network 63 | use_fid_inception : bool 64 | If true, uses the pretrained Inception model used in Tensorflow's 65 | FID implementation. If false, uses the pretrained Inception model 66 | available in torchvision. The FID Inception model has different 67 | weights and a slightly different structure from torchvision's 68 | Inception model. If you want to compute FID scores, you are 69 | strongly advised to set this parameter to true to get comparable 70 | results. 
71 | """ 72 | super(InceptionV3, self).__init__() 73 | 74 | self.resize_input = resize_input 75 | self.normalize_input = normalize_input 76 | self.output_blocks = sorted(output_blocks) 77 | self.last_needed_block = max(output_blocks) 78 | 79 | assert self.last_needed_block <= 3, \ 80 | 'Last possible output block index is 3' 81 | 82 | self.blocks = nn.ModuleList() 83 | 84 | if use_fid_inception: 85 | inception = fid_inception_v3() 86 | else: 87 | inception = sketch_inception_v3() 88 | 89 | self.inception = inception 90 | # Block 0: input to maxpool1 91 | block0 = [ 92 | inception.Conv2d_1a_3x3, 93 | inception.Conv2d_2a_3x3, 94 | inception.Conv2d_2b_3x3, 95 | nn.MaxPool2d(kernel_size=3, stride=2) 96 | ] 97 | self.blocks.append(nn.Sequential(*block0)) 98 | 99 | # Block 1: maxpool1 to maxpool2 100 | if self.last_needed_block >= 1: 101 | block1 = [ 102 | inception.Conv2d_3b_1x1, 103 | inception.Conv2d_4a_3x3, 104 | nn.MaxPool2d(kernel_size=3, stride=2) 105 | ] 106 | self.blocks.append(nn.Sequential(*block1)) 107 | 108 | # Block 2: maxpool2 to aux classifier 109 | if self.last_needed_block >= 2: 110 | block2 = [ 111 | inception.Mixed_5b, 112 | inception.Mixed_5c, 113 | inception.Mixed_5d, 114 | inception.Mixed_6a, 115 | inception.Mixed_6b, 116 | inception.Mixed_6c, 117 | inception.Mixed_6d, 118 | inception.Mixed_6e, 119 | ] 120 | self.blocks.append(nn.Sequential(*block2)) 121 | 122 | # Block 3: aux classifier to final avgpool 123 | if self.last_needed_block >= 3: 124 | block3 = [ 125 | inception.Mixed_7a, 126 | inception.Mixed_7b, 127 | inception.Mixed_7c, 128 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 129 | ] 130 | self.blocks.append(nn.Sequential(*block3)) 131 | 132 | for param in self.parameters(): 133 | param.requires_grad = requires_grad 134 | 135 | def forward(self, inp): 136 | """Get Inception feature maps 137 | 138 | Parameters 139 | ---------- 140 | inp : torch.autograd.Variable 141 | Input tensor of shape Bx3xHxW. Values are expected to be in 142 | range (0, 1) 143 | 144 | Returns 145 | ------- 146 | List of torch.autograd.Variable, corresponding to the selected output 147 | block, sorted ascending by index 148 | """ 149 | outp = [] 150 | x = inp 151 | 152 | # if self.resize_input: 153 | # x = F.interpolate(x, 154 | # size=(299, 299), 155 | # mode='bilinear', 156 | # align_corners=False) 157 | 158 | if self.normalize_input: 159 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 160 | 161 | for idx, block in enumerate(self.blocks): 162 | x = block(x) 163 | if idx in self.output_blocks: 164 | outp.append(x) 165 | 166 | if idx == self.last_needed_block: 167 | break 168 | 169 | return outp 170 | 171 | 172 | def _inception_v3(*args, **kwargs): 173 | """Wraps `torchvision.models.inception_v3` 174 | 175 | Skips default weight inititialization if supported by torchvision version. 176 | See https://github.com/mseitzer/pytorch-fid/issues/28. 177 | """ 178 | try: 179 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 180 | except ValueError: 181 | # Just a caution against weird version strings 182 | version = (0,) 183 | 184 | if version >= (0, 6): 185 | kwargs['init_weights'] = False 186 | 187 | return torchvision.models.inception_v3(*args, **kwargs) 188 | 189 | 190 | def fid_inception_v3(): 191 | """Build pretrained Inception model for FID computation 192 | 193 | The Inception model for FID computation uses a different set of weights 194 | and has a slightly different structure than torchvision's Inception. 
195 | 196 | This method first constructs torchvision's Inception and then patches the 197 | necessary parts that are different in the FID Inception model. 198 | """ 199 | inception = _inception_v3(num_classes=1008, 200 | aux_logits=False, 201 | pretrained=False) 202 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 203 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 204 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 205 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 206 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 207 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 208 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 209 | inception.Mixed_7b = FIDInceptionE_1(1280) 210 | inception.Mixed_7c = FIDInceptionE_2(2048) 211 | 212 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 213 | inception.load_state_dict(state_dict) 214 | return inception 215 | 216 | 217 | def sketch_inception_v3(): 218 | """Build pretrained Inception model for FID computation 219 | 220 | The Inception model for FID computation uses a different set of weights 221 | and has a slightly different structure than torchvision's Inception. 222 | 223 | This method first constructs torchvision's Inception and then patches the 224 | necessary parts that are different in the FID Inception model. 225 | """ 226 | inception = _inception_v3(num_classes=345, 227 | aux_logits=True, 228 | pretrained=False) 229 | inception.Conv2d_1a_3x3.conv = nn.Conv2d(1, 32, kernel_size=(3,3), stride=(2,2), bias=False) 230 | state_dict = torch.load('../models/inception/85.pt') 231 | 232 | state_dict_new = {key[7:]:state_dict[key] for key in state_dict} 233 | inception.load_state_dict(state_dict_new) 234 | return inception 235 | 236 | 237 | class FIDInceptionA(torchvision.models.inception.InceptionA): 238 | """InceptionA block patched for FID computation""" 239 | def __init__(self, in_channels, pool_features): 240 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 241 | 242 | def forward(self, x): 243 | branch1x1 = self.branch1x1(x) 244 | 245 | branch5x5 = self.branch5x5_1(x) 246 | branch5x5 = self.branch5x5_2(branch5x5) 247 | 248 | branch3x3dbl = self.branch3x3dbl_1(x) 249 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 250 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 251 | 252 | # Patch: Tensorflow's average pool does not use the padded zero's in 253 | # its average calculation 254 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 255 | count_include_pad=False) 256 | branch_pool = self.branch_pool(branch_pool) 257 | 258 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 259 | return torch.cat(outputs, 1) 260 | 261 | 262 | class FIDInceptionC(torchvision.models.inception.InceptionC): 263 | """InceptionC block patched for FID computation""" 264 | def __init__(self, in_channels, channels_7x7): 265 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 266 | 267 | def forward(self, x): 268 | branch1x1 = self.branch1x1(x) 269 | 270 | branch7x7 = self.branch7x7_1(x) 271 | branch7x7 = self.branch7x7_2(branch7x7) 272 | branch7x7 = self.branch7x7_3(branch7x7) 273 | 274 | branch7x7dbl = self.branch7x7dbl_1(x) 275 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 276 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 277 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 278 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 279 | 280 | # Patch: Tensorflow's average pool does not use the padded 
zero's in 281 | # its average calculation 282 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 283 | count_include_pad=False) 284 | branch_pool = self.branch_pool(branch_pool) 285 | 286 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 287 | return torch.cat(outputs, 1) 288 | 289 | 290 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 291 | """First InceptionE block patched for FID computation""" 292 | def __init__(self, in_channels): 293 | super(FIDInceptionE_1, self).__init__(in_channels) 294 | 295 | def forward(self, x): 296 | branch1x1 = self.branch1x1(x) 297 | 298 | branch3x3 = self.branch3x3_1(x) 299 | branch3x3 = [ 300 | self.branch3x3_2a(branch3x3), 301 | self.branch3x3_2b(branch3x3), 302 | ] 303 | branch3x3 = torch.cat(branch3x3, 1) 304 | 305 | branch3x3dbl = self.branch3x3dbl_1(x) 306 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 307 | branch3x3dbl = [ 308 | self.branch3x3dbl_3a(branch3x3dbl), 309 | self.branch3x3dbl_3b(branch3x3dbl), 310 | ] 311 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 312 | 313 | # Patch: Tensorflow's average pool does not use the padded zero's in 314 | # its average calculation 315 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 316 | count_include_pad=False) 317 | branch_pool = self.branch_pool(branch_pool) 318 | 319 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 320 | return torch.cat(outputs, 1) 321 | 322 | 323 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 324 | """Second InceptionE block patched for FID computation""" 325 | def __init__(self, in_channels): 326 | super(FIDInceptionE_2, self).__init__(in_channels) 327 | 328 | def forward(self, x): 329 | branch1x1 = self.branch1x1(x) 330 | 331 | branch3x3 = self.branch3x3_1(x) 332 | branch3x3 = [ 333 | self.branch3x3_2a(branch3x3), 334 | self.branch3x3_2b(branch3x3), 335 | ] 336 | branch3x3 = torch.cat(branch3x3, 1) 337 | 338 | branch3x3dbl = self.branch3x3dbl_1(x) 339 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 340 | branch3x3dbl = [ 341 | self.branch3x3dbl_3a(branch3x3dbl), 342 | self.branch3x3dbl_3b(branch3x3dbl), 343 | ] 344 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 345 | 346 | # Patch: The FID Inception model uses max pooling instead of average 347 | # pooling. This is likely an error in this specific Inception 348 | # implementation, as other Inception models use average pooling here 349 | # (which matches the description in the paper). 350 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 351 | branch_pool = self.branch_pool(branch_pool) 352 | 353 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 354 | return torch.cat(outputs, 1) 355 | -------------------------------------------------------------------------------- /part_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import os 7 | import math 8 | import json 9 | import torch 10 | import random 11 | import torchvision 12 | import multiprocessing 13 | import numpy as np 14 | import torch.nn.functional as F 15 | from math import floor, log2 16 | from shutil import rmtree 17 | from functools import partial 18 | 19 | from torch import nn 20 | from torch.utils import data 21 | from torch.optim import Adam 22 | from torch.autograd import grad as torch_grad 23 | from torchvision import transforms 24 | 25 | from PIL import Image 26 | from pathlib import Path 27 | import cairocffi as cairo 28 | 29 | assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.' 30 | 31 | COLORS_BIRD = {'initial':1-torch.cuda.FloatTensor([45, 169, 145]).view(1, -1, 1, 1)/255., 'eye':1-torch.cuda.FloatTensor([243, 156, 18]).view(1, -1, 1, 1)/255., 'none':1-torch.cuda.FloatTensor([149, 165, 166]).view(1, -1, 1, 1)/255., 32 | 'beak':1-torch.cuda.FloatTensor([211, 84, 0]).view(1, -1, 1, 1)/255., 'body':1-torch.cuda.FloatTensor([41, 128, 185]).view(1, -1, 1, 1)/255., 'details':1-torch.cuda.FloatTensor([171, 190, 191]).view(1, -1, 1, 1)/255., 33 | 'head':1-torch.cuda.FloatTensor([192, 57, 43]).view(1, -1, 1, 1)/255., 'legs':1-torch.cuda.FloatTensor([142, 68, 173]).view(1, -1, 1, 1)/255., 'mouth':1-torch.cuda.FloatTensor([39, 174, 96]).view(1, -1, 1, 1)/255., 34 | 'tail':1-torch.cuda.FloatTensor([69, 85, 101]).view(1, -1, 1, 1)/255., 'wings':1-torch.cuda.FloatTensor([127, 140, 141]).view(1, -1, 1, 1)/255.} 35 | 36 | COLORS_GENERIC = {'initial':1-torch.cuda.FloatTensor([45, 169, 145]).view(1, -1, 1, 1)/255., 'eye':1-torch.cuda.FloatTensor([243, 156, 18]).view(1, -1, 1, 1)/255., 'none':1-torch.cuda.FloatTensor([149, 165, 166]).view(1, -1, 1, 1)/255., 37 | 'arms':1-torch.cuda.FloatTensor([211, 84, 0]).view(1, -1, 1, 1)/255., 'beak':1-torch.cuda.FloatTensor([41, 128, 185]).view(1, -1, 1, 1)/255., 'mouth':1-torch.cuda.FloatTensor([54, 153, 219]).view(1, -1, 1, 1)/255., 38 | 'body':1-torch.cuda.FloatTensor([192, 57, 43]).view(1, -1, 1, 1)/255., 'ears':1-torch.cuda.FloatTensor([142, 68, 173]).view(1, -1, 1, 1)/255., 'feet':1-torch.cuda.FloatTensor([39, 174, 96]).view(1, -1, 1, 1)/255., 39 | 'fin':1-torch.cuda.FloatTensor([69, 85, 101]).view(1, -1, 1, 1)/255., 'hair':1-torch.cuda.FloatTensor([127, 140, 141]).view(1, -1, 1, 1)/255., 'hands':1-torch.cuda.FloatTensor([45, 63, 81]).view(1, -1, 1, 1)/255., 40 | 'head':1-torch.cuda.FloatTensor([241, 197, 17]).view(1, -1, 1, 1)/255., 'horns':1-torch.cuda.FloatTensor([51, 205, 117]).view(1, -1, 1, 1)/255., 'legs':1-torch.cuda.FloatTensor([232, 135, 50]).view(1, -1, 1, 1)/255., 41 | 'nose':1-torch.cuda.FloatTensor([233, 90, 75]).view(1, -1, 1, 1)/255., 'paws':1-torch.cuda.FloatTensor([160, 98, 186]).view(1, -1, 1, 1)/255., 'tail':1-torch.cuda.FloatTensor([58, 78, 99]).view(1, -1, 1, 1)/255., 42 | 'wings':1-torch.cuda.FloatTensor([198, 203, 207]).view(1, -1, 1, 1)/255., 'details':1-torch.cuda.FloatTensor([171, 190, 191]).view(1, -1, 1, 1)/255.} 43 | 44 | num_cores = multiprocessing.cpu_count() 45 | 46 | # constants 47 | 48 | EXTS = ['jpg', 'png', 'npy'] 49 | EPS = 1e-8 50 | 51 | # helper classes 52 | 53 | class NanException(Exception): 54 | pass 55 | 56 | class Flatten(nn.Module): 57 | def forward(self, x): 58 | return x.reshape(x.shape[0], -1) 59 | 60 | # helpers 61 | def gs_to_rgb(image, color): 62 | image_rgb = image.repeat(1, 3, 1, 1) 63 | return 1-image_rgb*color 64 | 65 | def default(value, d): 66 | return d if value is None else value 67 | 68 | def cycle(iterable): 
69 | while True: 70 | for i in iterable: 71 | yield i 72 | 73 | def is_empty(t): 74 | return t.nelement() == 1 75 | 76 | def raise_if_nan(t): 77 | if torch.isnan(t): 78 | raise NanException 79 | 80 | def loss_backwards(loss, optimizer, **kwargs): 81 | loss.backward(**kwargs) 82 | 83 | def gradient_penalty(images, output, weight = 10): 84 | batch_size = images.shape[0] 85 | gradients = torch_grad(outputs=output, inputs=images, 86 | grad_outputs=torch.ones(output.size()).cuda(), 87 | create_graph=True, retain_graph=True, only_inputs=True)[0] 88 | 89 | gradients = gradients.view(batch_size, -1) 90 | return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean() 91 | 92 | def calc_pl_lengths(styles, images): 93 | num_pixels = images.shape[2] * images.shape[3] 94 | pl_noise = torch.randn(images.shape).cuda() / math.sqrt(num_pixels) 95 | outputs = (images * pl_noise).sum() 96 | 97 | pl_grads = torch_grad(outputs=outputs, inputs=styles, 98 | grad_outputs=torch.ones(outputs.shape).cuda(), 99 | create_graph=True, retain_graph=True, only_inputs=True)[0] 100 | return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt() 101 | 102 | def noise(n, latent_dim): 103 | return torch.randn(n, latent_dim).cuda() 104 | 105 | def noise_list(n, layers, latent_dim): 106 | return [(noise(n, latent_dim), layers)] 107 | 108 | def mixed_list(n, layers, latent_dim): 109 | tt = int(torch.rand(()).numpy() * layers) 110 | return noise_list(n, tt, latent_dim) + noise_list(n, layers - tt, latent_dim) 111 | 112 | def latent_to_w(style_vectorizer, latent_descr): 113 | return [(style_vectorizer(z), num_layers) for z, num_layers in latent_descr] 114 | 115 | def image_noise(n, im_size): 116 | return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda() 117 | 118 | def leaky_relu(p=0.2): 119 | return nn.LeakyReLU(p, inplace=True) 120 | 121 | def evaluate_in_chunks(max_batch_size, model, *args): 122 | split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args)))) 123 | chunked_outputs = [model(*i) for i in split_args] 124 | if len(chunked_outputs) == 1: 125 | return chunked_outputs[0] 126 | return torch.cat(chunked_outputs, dim=0) 127 | 128 | def evaluate_in_chunks_unet(max_batch_size, model, map_feats, *args): 129 | split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args)))) 130 | split_map_feats = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), map_feats)))) 131 | chunked_outputs = [model(*i, j) for i, j in zip(split_args, split_map_feats)] 132 | if len(chunked_outputs) == 1: 133 | return chunked_outputs[0] 134 | return torch.cat(chunked_outputs, dim=0) 135 | 136 | def styles_def_to_tensor(styles_def): 137 | return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1) 138 | 139 | 140 | # dataset 141 | 142 | class Dataset_JSON(data.Dataset): 143 | def __init__(self, folder, image_size, large_aug=False): 144 | super().__init__() 145 | min_sample_num = 10000 146 | self.folder = folder 147 | self.image_size = image_size 148 | self.large_aug = large_aug 149 | self.paths = [p for p in Path(f'{folder}').glob(f'**/*.json')] 150 | while len(self.paths) < min_sample_num: 151 | self.paths.extend(self.paths) 152 | # notice the real influence of the trans / scale is side / 512 (original side) because of scalling in rendering 153 | if not large_aug: 154 | self.rotate = [-1/12*np.pi, 1/12*np.pi] 155 | self.trans = 0.01 156 | self.scale = [0.9, 1.1] 157 | else: 158 | self.rotate = [-1/4*np.pi, 1/4*np.pi] 159 | self.trans = 0.05 160 | self.scale = [0.75, 1.25] 161 | 
self.line_diameter_scale = [0.25, 1.25] 162 | if 'bird' in folder: 163 | self.id_to_part = {0:'initial', 1:'eye', 4:'head', 3:'body', 2:'beak', 5:'legs', 8:'wings', 6:'mouth', 7:'tail'} 164 | elif 'generic' in folder or 'fin' in folder or 'horn' in folder: 165 | self.id_to_part = { 0:'initial', 1:'eye', 2:'arms', 3:'beak', 4:'mouth', 5:'body', 6:'ears', 7:'feet', 8:'fin', 166 | 9:'hair', 10:'hands', 11:'head', 12:'horns', 13:'legs', 14:'nose', 15:'paws', 16:'tail', 17:'wings'} 167 | self.n_part = len(self.id_to_part) 168 | 169 | def __len__(self): 170 | return len(self.paths) 171 | 172 | def __getitem__(self, index): 173 | path = self.paths[index] 174 | json_data = json.load(open(path)) 175 | input_parts_json = json_data['input_parts'] 176 | target_part_json = json_data['target_part'] 177 | # sample random affine parameters 178 | theta = np.random.uniform(*self.rotate) 179 | trans_pixel = 512*self.trans 180 | translate_x = np.random.uniform(-trans_pixel, trans_pixel) 181 | translate_y = np.random.uniform(-trans_pixel, trans_pixel) 182 | scale = np.random.uniform(*self.scale) 183 | if self.large_aug: 184 | line_diameter = np.random.uniform(*self.line_diameter_scale)*16 185 | else: 186 | line_diameter = 16 187 | # apply random affine transformation 188 | affine_target_part_json= self.affine_trans(target_part_json, theta, translate_x, translate_y, scale) 189 | processed_img_partial = [] 190 | affine_vector_input_part = [] 191 | for i in range(self.n_part): 192 | key = self.id_to_part[i] 193 | affine_input_part_json = self.affine_trans(input_parts_json[key], theta, translate_x, translate_y, scale) 194 | affine_vector_input_part += affine_input_part_json 195 | processed_img_partial.append(self.processed_part_to_raster(affine_input_part_json, side=self.image_size, line_diameter=line_diameter)) 196 | processed_img_partial.append(self.processed_part_to_raster(affine_vector_input_part, side=self.image_size, line_diameter=line_diameter)) 197 | processed_img_partonly = self.processed_part_to_raster(affine_target_part_json, side=self.image_size, line_diameter=line_diameter) 198 | processed_img = self.processed_part_to_raster(affine_vector_input_part+affine_target_part_json, side=self.image_size, line_diameter=line_diameter) 199 | # RandomHorizontalFlip 200 | if np.random.random() > 0.5: 201 | processed_img = processed_img.flip(-1) 202 | processed_img_partial = torch.cat(processed_img_partial, 0).flip(-1) 203 | processed_img_partonly = processed_img_partonly.flip(-1) 204 | else: 205 | processed_img_partial = torch.cat(processed_img_partial, 0) 206 | return processed_img, processed_img_partial, processed_img_partonly 207 | 208 | def sample_partial_test(self, n): 209 | sample_ids = [np.random.randint(self.__len__()) for _ in range(n)] 210 | sample_jsons = [json.load(open(self.paths[sample_id]))for sample_id in sample_ids] 211 | samples = [] 212 | samples_partial = [] 213 | samples_partonly = [] 214 | for sample_json in sample_jsons: 215 | input_parts_json = sample_json['input_parts'] 216 | target_part_json = sample_json['target_part'] 217 | img_partial_test = [] 218 | vector_input_part = [] 219 | for i in range(self.n_part): 220 | key = self.id_to_part[i] 221 | vector_input_part += input_parts_json[key] 222 | img_partial_test.append(self.processed_part_to_raster(input_parts_json[key], side=self.image_size)) 223 | img_partial_test.append(self.processed_part_to_raster(vector_input_part, side=self.image_size)) 224 | samples_partial.append(torch.cat(img_partial_test, 0)) 225 | img_partonly_test = 
self.processed_part_to_raster(target_part_json, side=self.image_size) 226 | img_test = self.processed_part_to_raster(vector_input_part+target_part_json, side=self.image_size) 227 | samples.append(img_test) 228 | samples_partonly.append(img_partonly_test) 229 | return torch.stack(samples), torch.stack(samples_partial), torch.stack(samples_partonly) 230 | 231 | def affine_trans(self, data, theta, translate_x, translate_y, scale): 232 | rotate_mat = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]]) 233 | affine_data = [] 234 | for item in data: 235 | if len(item) == 0: 236 | continue 237 | affine_item = np.array(item) - 256. 238 | affine_item = np.transpose(np.matmul(rotate_mat, np.transpose(affine_item))) 239 | affine_item[:, 0] += translate_x 240 | affine_item[:, 1] += translate_y 241 | affine_item *= scale 242 | affine_data.append(affine_item + 256.) 243 | return affine_data 244 | 245 | def processed_part_to_raster(self, vector_part, side=64, line_diameter=16, padding=16, bg_color=(0,0,0), fg_color=(1,1,1)): 246 | """ 247 | render raster image based on the processed part 248 | """ 249 | original_side = 512. 250 | surface = cairo.ImageSurface(cairo.FORMAT_ARGB32, side, side) 251 | ctx = cairo.Context(surface) 252 | ctx.set_antialias(cairo.ANTIALIAS_BEST) 253 | ctx.set_line_cap(cairo.LINE_CAP_ROUND) 254 | ctx.set_line_join(cairo.LINE_JOIN_ROUND) 255 | ctx.set_line_width(line_diameter) 256 | # scale to match the new size 257 | # add padding at the edges for the line_diameter 258 | # and add additional padding to account for antialiasing 259 | total_padding = padding * 2. + line_diameter 260 | new_scale = float(side) / float(original_side + total_padding) 261 | ctx.scale(new_scale, new_scale) 262 | ctx.translate(total_padding / 2., total_padding / 2.) 
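        # With the defaults above (side=64, padding=16, line_diameter=16):
        #   total_padding = 2*16 + 16 = 48 and new_scale = 64 / (512 + 48) ~ 0.114,
        # so stroke coordinates, which stay in the original 512-pixel canvas space, are mapped
        # down onto the 64-pixel surface; since cairo transforms the pen width by the same CTM
        # at stroke time, the 16-unit line_diameter renders at roughly 1.8 device pixels.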
263 | raster_images = [] 264 | # clear background 265 | ctx.set_source_rgb(*bg_color) 266 | ctx.paint() 267 | # draw strokes, this is the most cpu-intensive part 268 | ctx.set_source_rgb(*fg_color) 269 | for stroke in vector_part: 270 | if len(stroke) == 0: 271 | continue 272 | ctx.move_to(stroke[0][0], stroke[0][1]) 273 | for x, y in stroke: 274 | ctx.line_to(x, y) 275 | ctx.stroke() 276 | surface_data = surface.get_data() 277 | raster_image = np.copy(np.asarray(surface_data))[::4].reshape(side, side) 278 | return torch.FloatTensor(raster_image/255.)[None, :, :] 279 | 280 | # exponential moving average helpers 281 | 282 | def ema_inplace(moving_avg, new, decay): 283 | if is_empty(moving_avg): 284 | moving_avg.data.copy_(new) 285 | return 286 | moving_avg.data.mul_(decay).add_(1 - decay, new) 287 | 288 | 289 | # Encoder 290 | 291 | class EncoderBlock_unet(nn.Module): 292 | def __init__(self, input_channels, filters, downsample=True): 293 | super().__init__() 294 | self.net = nn.Sequential( 295 | nn.Conv2d(input_channels, filters, 3, padding=1), 296 | leaky_relu(), 297 | nn.Conv2d(filters, filters, 3, padding=1), 298 | leaky_relu() 299 | ) 300 | 301 | self.downsample = nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) if downsample else None 302 | 303 | def forward(self, x): 304 | x = self.net(x) 305 | if self.downsample is not None: 306 | x = self.downsample(x) 307 | return x 308 | 309 | class Encoder_unet(nn.Module): 310 | def __init__(self, num_init_filters, image_size, network_capacity=16): 311 | super().__init__() 312 | num_layers = int(log2(image_size) - 1) 313 | 314 | blocks = [] 315 | filters = [num_init_filters] + [network_capacity*(2 ** (i)) for i in range(num_layers)] # 16, 32, 64, 128, 256, 512, 1024 316 | chan_in_out = list(zip(filters[0:-1], filters[1:])) 317 | 318 | for ind, (in_chan, out_chan) in enumerate(chan_in_out): # 128, 512, 2048, 4096, 16384, 65536, 262144 319 | is_not_last = ind < (len(chan_in_out) - 1) 320 | block = EncoderBlock_unet(in_chan, out_chan, downsample=is_not_last) 321 | blocks.append(block) 322 | self.blocks = nn.ModuleList(blocks) 323 | 324 | def forward(self, x): 325 | feats = [] 326 | for block in self.blocks: 327 | x = block(x) 328 | feats.append(x) 329 | return feats 330 | 331 | # stylegan2_cond_unet classes, stylegan2 code is adapted from https://github.com/lucidrains/stylegan2-pytorch 332 | class StyleVectorizer(nn.Module): 333 | def __init__(self, emb, depth): 334 | super().__init__() 335 | 336 | layers = [] 337 | for i in range(depth): 338 | layers.extend([nn.Linear(emb, emb), leaky_relu()]) 339 | 340 | self.net = nn.Sequential(*layers) 341 | 342 | def forward(self, x): 343 | return self.net(x) 344 | 345 | 346 | class RGBBlock(nn.Module): 347 | def __init__(self, latent_dim, input_channel, upsample, rgba = False): 348 | super().__init__() 349 | self.input_channel = input_channel 350 | self.to_style = nn.Linear(latent_dim, input_channel) 351 | 352 | # out_filters = 3 if not rgba else 4 353 | out_filters = 1 354 | self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False) 355 | 356 | self.upsample = nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False) if upsample else None 357 | 358 | def forward(self, x, prev_rgb, istyle): 359 | b, c, h, w = x.shape 360 | style = self.to_style(istyle) 361 | x = self.conv(x, style) 362 | 363 | if prev_rgb is not None: 364 | x = x + prev_rgb 365 | 366 | if self.upsample is not None: 367 | x = self.upsample(x) 368 | 369 | return x 370 | 371 | 372 | class Conv2DMod(nn.Module): 373 | def 
__init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs): 374 | super().__init__() 375 | self.filters = out_chan 376 | self.demod = demod 377 | self.kernel = kernel 378 | self.stride = stride 379 | self.dilation = dilation 380 | self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel))) 381 | nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 382 | 383 | def _get_same_padding(self, size, kernel, dilation, stride): 384 | return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2 385 | 386 | def forward(self, x, y): 387 | b, c, h, w = x.shape 388 | 389 | w1 = y[:, None, :, None, None] 390 | w2 = self.weight[None, :, :, :, :] 391 | weights = w2 * (w1 + 1) 392 | 393 | if self.demod: 394 | d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS) 395 | weights = weights * d 396 | 397 | x = x.reshape(1, -1, h, w) 398 | 399 | _, _, *ws = weights.shape 400 | weights = weights.reshape(b * self.filters, *ws) 401 | 402 | padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride) 403 | x = F.conv2d(x, weights, padding=padding, groups=b) 404 | 405 | x = x.reshape(-1, self.filters, h, w) 406 | return x 407 | 408 | class GeneratorBlock(nn.Module): 409 | def __init__(self, latent_dim, input_channels, filters, upsample = True, upsample_rgb = True, rgba = False): 410 | super().__init__() 411 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None 412 | 413 | self.to_style1 = nn.Linear(latent_dim, input_channels) 414 | self.to_noise1 = nn.Linear(1, filters) 415 | self.conv1 = Conv2DMod(input_channels, filters, 3) 416 | 417 | self.to_style2 = nn.Linear(latent_dim, filters) 418 | self.to_noise2 = nn.Linear(1, filters) 419 | self.conv2 = Conv2DMod(filters, filters, 3) 420 | 421 | self.activation = leaky_relu() 422 | self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba) 423 | 424 | def forward(self, x, prev_rgb, istyle, inoise): 425 | if self.upsample is not None: 426 | x = self.upsample(x) 427 | 428 | inoise = inoise[:, :x.shape[2], :x.shape[3], :] 429 | noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1)) 430 | noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1)) 431 | 432 | style1 = self.to_style1(istyle) 433 | x = self.conv1(x, style1) 434 | x = self.activation(x + noise1) 435 | 436 | style2 = self.to_style2(istyle) 437 | x = self.conv2(x, style2) 438 | x = self.activation(x + noise2) 439 | 440 | rgb = self.to_rgb(x, prev_rgb, istyle) 441 | return x, rgb 442 | 443 | class DiscriminatorBlock(nn.Module): 444 | def __init__(self, input_channels, filters, downsample=True): 445 | super().__init__() 446 | self.conv_res = nn.Conv2d(input_channels, filters, 1) 447 | 448 | self.net = nn.Sequential( 449 | nn.Conv2d(input_channels, filters, 3, padding=1), 450 | leaky_relu(), 451 | nn.Conv2d(filters, filters, 3, padding=1), 452 | leaky_relu() 453 | ) 454 | 455 | self.downsample = nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) if downsample else None 456 | 457 | def forward(self, x): 458 | res = self.conv_res(x) 459 | x = self.net(x) 460 | x = x + res 461 | if self.downsample is not None: 462 | x = self.downsample(x) 463 | return x 464 | 465 | 466 | class Generator_unet(nn.Module): 467 | def __init__(self, image_size, latent_dim, network_capacity=16): 468 | super().__init__() 469 | self.image_size = image_size 470 | self.latent_dim = latent_dim 471 | self.num_layers = int(log2(image_size) - 1) 472 | 473 | init_channels = 4 * 
network_capacity 474 | self.initial_block = nn.Parameter(torch.randn((init_channels, 4, 4))) 475 | filters = [init_channels] + [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1] 476 | in_out_pairs = zip([ch+network_capacity*(2 ** (self.num_layers-1-i)) for i, ch in enumerate(filters[0:-1])], filters[1:]) 477 | 478 | self.blocks = nn.ModuleList([]) 479 | for ind, (in_chan, out_chan) in enumerate(in_out_pairs): 480 | not_first = ind != 0 481 | not_last = ind != (self.num_layers - 1) 482 | 483 | block = GeneratorBlock( 484 | latent_dim, 485 | in_chan, 486 | out_chan, 487 | upsample = not_first, 488 | upsample_rgb = not_last 489 | ) 490 | self.blocks.append(block) 491 | 492 | def forward(self, styles, input_noise, cond_feat_maps): 493 | batch_size = styles.shape[0] 494 | image_size = self.image_size 495 | x = self.initial_block.expand(batch_size, -1, -1, -1) 496 | styles = styles.transpose(0, 1) 497 | 498 | rgb = None 499 | for style, block, feat_map in zip(styles, self.blocks, cond_feat_maps[::-1]): 500 | x = torch.cat([x, feat_map], 1) 501 | x, rgb = block(x, rgb, style, input_noise) 502 | return rgb 503 | 504 | class Discriminator(nn.Module): 505 | def __init__(self, image_size, network_capacity=16, n_part=1): 506 | super().__init__() 507 | num_layers = int(log2(image_size) - 1) 508 | num_init_filters = n_part 509 | 510 | filters = [num_init_filters] + [(network_capacity) * (2 ** i) for i in range(num_layers+1)] 511 | chan_in_out = list(zip(filters[0:-1], filters[1:])) 512 | blocks = [] 513 | 514 | for ind, (in_chan, out_chan) in enumerate(chan_in_out): 515 | num_layer = ind + 1 516 | is_not_last = ind < (len(chan_in_out) - 1) 517 | block = DiscriminatorBlock(in_chan, out_chan, downsample = is_not_last) 518 | blocks.append(block) 519 | 520 | self.blocks = nn.ModuleList(blocks) 521 | 522 | latent_dim = 2 * 2 * filters[-1] 523 | 524 | self.flatten = Flatten() 525 | self.to_logit = nn.Linear(latent_dim, 1) 526 | 527 | def forward(self, x): 528 | b, *_ = x.shape 529 | 530 | for block in self.blocks: 531 | x = block(x) 532 | 533 | x = self.flatten(x) 534 | x = self.to_logit(x) 535 | return x.squeeze() 536 | 537 | class StyleGAN2_cond_unet(nn.Module): 538 | def __init__(self, image_size, n_part=10, latent_dim=512, style_depth=8, network_capacity=16, steps=1, lr_D=1e-4, lr_G=1e-4): 539 | super().__init__() 540 | self.lr_D = lr_D 541 | self.lr_G = lr_G 542 | self.steps = steps 543 | self.ema_decay = 0.995 544 | 545 | self.S = StyleVectorizer(latent_dim, style_depth) 546 | self.G = Generator_unet(image_size, latent_dim, network_capacity) 547 | self.D = Discriminator(image_size, network_capacity, n_part=n_part) 548 | self.Enc = Encoder_unet(n_part, image_size, network_capacity) 549 | 550 | self.generator_params = list(self.G.parameters()) + list(self.S.parameters()) + list(self.Enc.parameters()) 551 | self.G_opt = Adam(self.generator_params, lr = self.lr_G, betas=(0., 0.99)) 552 | self.D_opt = Adam(self.D.parameters(), lr = self.lr_D, betas=(0., 0.99)) 553 | 554 | self._init_weights() 555 | self.cuda() 556 | 557 | def _init_weights(self): 558 | for m in self.modules(): 559 | if type(m) in {nn.Conv2d, nn.Linear}: 560 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 561 | 562 | for block in self.G.blocks: 563 | nn.init.zeros_(block.to_noise1.weight) 564 | nn.init.zeros_(block.to_noise2.weight) 565 | nn.init.zeros_(block.to_noise1.bias) 566 | nn.init.zeros_(block.to_noise2.bias) 567 | 568 | def forward(self, x): 569 | return x 570 | 571 | class 
Trainer(): 572 | def __init__(self, name, results_dir, models_dir, n_part, image_size, network_capacity, batch_size = 4, mixed_prob = 0.9, 573 | gradient_accumulate_every=1, lr_D = 2e-4, lr_G = 2e-4, num_workers = None, save_every = 1000, trunc_psi = 0.6, sparsity_penalty=0.): 574 | self.GAN = None 575 | 576 | self.name = name 577 | self.results_dir = Path(results_dir) 578 | self.models_dir = Path(models_dir) 579 | self.config_path = self.models_dir / name / '.config.json' 580 | 581 | assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)' 582 | self.n_part = n_part 583 | self.image_size = image_size 584 | self.network_capacity = network_capacity 585 | 586 | self.lr_D = lr_D 587 | self.lr_G = lr_G 588 | self.batch_size = batch_size 589 | self.num_workers = num_workers 590 | self.mixed_prob = mixed_prob 591 | self.sparsity_penalty = sparsity_penalty 592 | 593 | self.save_every = save_every 594 | self.steps = 0 595 | 596 | self.trunc_psi = trunc_psi 597 | 598 | self.gradient_accumulate_every = gradient_accumulate_every 599 | 600 | self.d_loss = 0 601 | self.g_loss = 0 602 | self.last_gp_loss = 0 603 | self.pl_loss = 0 604 | self.sparsity_loss = 0 605 | 606 | self.pl_mean = torch.empty(1).cuda() 607 | self.pl_ema_decay = 0.99 608 | 609 | self.loader_D = None 610 | self.loader_G = None 611 | self.av = None 612 | 613 | if 'bird' in self.name: 614 | self.part_to_id = {'initial': 0, 'eye': 1, 'head': 4, 'body': 3, 'beak': 2, 'legs': 5, 'wings': 8, 'mouth': 6, 'tail': 7} 615 | COLORS = COLORS_BIRD 616 | elif 'generic' in self.name or 'fin' in self.name or 'horn' in self.name: 617 | self.part_to_id = {'initial': 0, 'eye': 1, 'arms': 2, 'beak': 3, 'mouth': 4, 'body': 5, 'ears': 6, 'feet': 7, 'fin': 8, 618 | 'hair': 9, 'hands': 10, 'head': 11, 'horns': 12, 'legs': 13, 'nose': 14, 'paws': 15, 'tail': 16, 'wings':17} 619 | COLORS = COLORS_GENERIC 620 | 621 | self.color = 1-torch.cuda.FloatTensor([0, 0, 0]).view(1, -1, 1, 1) 622 | self.default_color = 1-torch.cuda.FloatTensor([0, 0, 0]).view(1, -1, 1, 1) 623 | for key in COLORS: 624 | if key in self.name: 625 | self.color = COLORS[key] 626 | break 627 | 628 | for partname in self.part_to_id.keys(): 629 | if partname in self.name: 630 | self.partid = self.part_to_id[partname] 631 | self.partname = partname 632 | 633 | def init_GAN(self): 634 | self.GAN = StyleGAN2_cond_unet(n_part=self.n_part, lr_G=self.lr_G, lr_D=self.lr_D, image_size = self.image_size, network_capacity = self.network_capacity) 635 | 636 | def write_config(self): 637 | self.config_path.write_text(json.dumps(self.config())) 638 | 639 | def load_config(self): 640 | config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text()) 641 | self.image_size = config['image_size'] 642 | self.network_capacity = config['network_capacity'] 643 | del self.GAN 644 | self.init_GAN() 645 | 646 | def config(self): 647 | return {'image_size': self.image_size, 'network_capacity': self.network_capacity} 648 | 649 | def set_data_src(self, folder, large_aug=False): 650 | self.dataset_D = Dataset_JSON(folder, self.image_size, large_aug=large_aug) 651 | self.dataset_G = Dataset_JSON(folder, self.image_size, large_aug=large_aug) 652 | self.loader_D = cycle(data.DataLoader(self.dataset_D, num_workers = default(self.num_workers, num_cores), batch_size = self.batch_size, drop_last = True, shuffle=True, pin_memory=True)) 653 | self.loader_G = cycle(data.DataLoader(self.dataset_G, num_workers = default(self.num_workers, num_cores), 
batch_size = self.batch_size, drop_last = True, shuffle=True, pin_memory=True)) 654 | 655 | def train(self): 656 | assert self.loader_G is not None, 'You must first initialize the data source with `.set_data_src()`' 657 | 658 | self.init_folders() 659 | 660 | if self.GAN is None: 661 | self.init_GAN() 662 | 663 | self.GAN.train() 664 | total_disc_loss = torch.tensor(0.).cuda() 665 | total_gen_loss = torch.tensor(0.).cuda() 666 | 667 | batch_size = self.batch_size 668 | 669 | image_size = self.GAN.G.image_size 670 | latent_dim = self.GAN.G.latent_dim 671 | num_layers = self.GAN.G.num_layers 672 | 673 | apply_gradient_penalty = self.steps % 4 == 0 674 | apply_path_penalty = self.steps % 32 == 0 675 | 676 | backwards = partial(loss_backwards) 677 | 678 | avg_pl_length = self.pl_mean 679 | self.GAN.D_opt.zero_grad() 680 | 681 | for i in range(self.gradient_accumulate_every): 682 | image_batch, image_cond_batch, part_only_batch = [item.cuda() for item in next(self.loader_D)] 683 | image_partial_batch = image_cond_batch[:, -1:, :, :] # the last channel is the composite of all input parts, i.e. the entire input partial sketch 684 | get_latents_fn = mixed_list if np.random.random() < self.mixed_prob else noise_list 685 | style = get_latents_fn(batch_size, num_layers, latent_dim) 686 | noise = image_noise(batch_size, image_size) 687 | 688 | bitmap_feats = self.GAN.Enc(image_cond_batch) 689 | 690 | w_space = latent_to_w(self.GAN.S, style) 691 | w_styles = styles_def_to_tensor(w_space) 692 | 693 | generated_partial_images = self.GAN.G(w_styles, noise, bitmap_feats) 694 | generated_images = torch.max(generated_partial_images, image_partial_batch) 695 | 696 | generated_image_stack_batch = torch.cat([image_cond_batch[:, :self.partid], torch.max(generated_partial_images, image_cond_batch[:, self.partid:self.partid+1]), 697 | image_cond_batch[:, self.partid+1:-1], generated_images], 1) 698 | fake_output = self.GAN.D(generated_image_stack_batch.clone().detach()) 699 | 700 | image_batch.requires_grad_() 701 | real_image_stack_batch = torch.cat([image_cond_batch[:, :self.partid], torch.max(part_only_batch, image_cond_batch[:, self.partid:self.partid+1]), 702 | image_cond_batch[:, self.partid+1:-1], image_batch], 1) 703 | real_image_stack_batch.requires_grad_() 704 | real_output = self.GAN.D(real_image_stack_batch) 705 | 706 | disc_loss = (F.relu(1 + real_output) + F.relu(1 - fake_output)).mean() 707 | 708 | if apply_gradient_penalty: 709 | gp = gradient_penalty(real_image_stack_batch, real_output) 710 | self.last_gp_loss = gp.clone().detach().item() 711 | disc_loss = disc_loss + gp 712 | 713 | disc_loss = disc_loss / self.gradient_accumulate_every 714 | disc_loss.register_hook(raise_if_nan) 715 | backwards(disc_loss, self.GAN.D_opt) 716 | 717 | total_disc_loss += disc_loss.detach().item() / self.gradient_accumulate_every 718 | 719 | self.d_loss = float(total_disc_loss) 720 | self.GAN.D_opt.step() 721 | 722 | # train generator 723 | 724 | self.GAN.G_opt.zero_grad() 725 | for i in range(self.gradient_accumulate_every): 726 | image_batch, image_cond_batch, part_only_batch = [item.cuda() for item in next(self.loader_G)] 727 | image_partial_batch = image_cond_batch[:, -1:, :, :] # the last channel is the composite of all input parts, i.e. the entire input partial sketch 728 | 729 | style = get_latents_fn(batch_size, num_layers, latent_dim) 730 | noise = image_noise(batch_size, image_size) 731 | 732 | bitmap_feats = self.GAN.Enc(image_cond_batch) 733 | 734 | w_space = latent_to_w(self.GAN.S, style) 735 | w_styles = styles_def_to_tensor(w_space) 736 | 737 | generated_partial_images =
self.GAN.G(w_styles, noise, bitmap_feats) 738 | generated_images = torch.max(generated_partial_images, image_partial_batch) 739 | 740 | generated_image_stack_batch = torch.cat([image_cond_batch[:, :self.partid], torch.max(generated_partial_images, image_cond_batch[:, self.partid:self.partid+1]), 741 | image_cond_batch[:, self.partid+1:-1], generated_images], 1) 742 | fake_output = self.GAN.D(generated_image_stack_batch) 743 | 744 | loss = fake_output.mean() 745 | gen_loss = loss 746 | 747 | if apply_path_penalty: 748 | pl_lengths = calc_pl_lengths(w_styles, generated_images) 749 | avg_pl_length = pl_lengths.detach().mean() 750 | 751 | if not is_empty(self.pl_mean): 752 | pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean() 753 | if not torch.isnan(pl_loss): 754 | gen_loss = gen_loss + pl_loss 755 | if getattr(self, 'similarity_penalty', 0.): # similarity_penalty is not set in __init__, so treat a missing attribute as disabled 756 | gen_loss = gen_loss - self.similarity_penalty*(pl_lengths ** 2).mean() 757 | 758 | if self.sparsity_penalty: 759 | generated_density = generated_partial_images.reshape(self.batch_size, -1).sum(1) 760 | target_density = part_only_batch.reshape(self.batch_size, -1).sum(1) # if we divide the sketch by parts 761 | self.sparsity_loss = ((generated_density-target_density)**2).mean() 762 | gen_loss = gen_loss + self.sparsity_loss*self.sparsity_penalty 763 | 764 | gen_loss = gen_loss / self.gradient_accumulate_every 765 | gen_loss.register_hook(raise_if_nan) 766 | backwards(gen_loss, self.GAN.G_opt) 767 | 768 | total_gen_loss += loss.detach().item() / self.gradient_accumulate_every 769 | 770 | self.g_loss = float(total_gen_loss) 771 | self.GAN.G_opt.step() 772 | 773 | # calculate moving averages 774 | 775 | if apply_path_penalty and not torch.isnan(avg_pl_length): 776 | ema_inplace(self.pl_mean, avg_pl_length, self.pl_ema_decay) 777 | self.pl_loss = self.pl_mean.item() 778 | 779 | # save from NaN errors 780 | 781 | checkpoint_num = floor(self.steps / self.save_every) 782 | 783 | if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)): 784 | print(f'NaN detected for generator or discriminator.
Loading from checkpoint #{checkpoint_num}') 785 | self.load(checkpoint_num) 786 | raise NanException 787 | 788 | # periodically save results 789 | 790 | if self.steps % self.save_every == 0: 791 | self.save(checkpoint_num) 792 | 793 | if self.steps % 1000 == 0 or (self.steps % 100 == 0 and self.steps < 2500): 794 | self.evaluate(floor(self.steps / 1000)) 795 | 796 | self.steps += 1 797 | self.av = None 798 | 799 | @torch.no_grad() 800 | def evaluate(self, num = 0, num_image_tiles = 8, trunc = 1.0, rgb = False): 801 | self.GAN.eval() 802 | ext = 'png' 803 | num_rows = num_image_tiles 804 | 805 | # latent_dim = self.GAN.G.latent_dim - self.GAN.Enc.feat_dim 806 | latent_dim = self.GAN.G.latent_dim 807 | image_size = self.GAN.G.image_size 808 | num_layers = self.GAN.G.num_layers 809 | 810 | # latents and noise 811 | 812 | latents_z = noise_list(num_rows ** 2, num_layers, latent_dim) 813 | n = image_noise(num_rows ** 2, image_size) 814 | 815 | image_batch, image_cond_batch, part_only_batch = [item.cuda() for item in self.dataset_G.sample_partial_test(num_rows ** 2)] 816 | image_partial_batch = image_cond_batch[:, -1:, :, :] # take the first one as the entire input partial sketch 817 | 818 | # concat the two latent vectors 819 | bitmap_feats = self.GAN.Enc(image_cond_batch) 820 | 821 | generated_partial_images = self.generate_truncated(self.GAN.S, self.GAN.G, latents_z, n, trunc_psi = self.trunc_psi, bitmap_feats=bitmap_feats) 822 | generated_images = torch.max(generated_partial_images, image_partial_batch) 823 | 824 | if not rgb: 825 | torchvision.utils.save_image(image_partial_batch, str(self.results_dir / self.name / f'{str(num)}-part.{ext}'), nrow=num_rows) 826 | # torchvision.utils.save_image((image_batch-image_partial_batch).clamp_(0., 1.), str(self.results_dir / self.name / f'{str(num)}-real.{ext}'), nrow=num_rows) 827 | torchvision.utils.save_image(part_only_batch, str(self.results_dir / self.name / f'{str(num)}-real.{ext}'), nrow=num_rows) 828 | torchvision.utils.save_image(image_batch, str(self.results_dir / self.name / f'{str(num)}-full.{ext}'), nrow=num_rows) 829 | # regular 830 | torchvision.utils.save_image(generated_partial_images, str(self.results_dir / self.name / f'{str(num)}-comp.{ext}'), nrow=num_rows) 831 | torchvision.utils.save_image(generated_images.clamp_(0., 1.), str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows) 832 | else: 833 | # part_batch = (image_batch-image_partial_batch).clamp_(0., 1.) 
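# In RGB mode the existing partial sketch and the (real or generated) part are rendered in
# different colours; 1 - ((1 - a) + (1 - b)) below sums the two layers' "ink" in inverted
# space, overlaying both sets of coloured strokes on a white background.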
834 | partial_rgb = gs_to_rgb(image_partial_batch, self.default_color) 835 | # part_rgb = gs_to_rgb(part_batch, self.color) 836 | part_rgb = gs_to_rgb(part_only_batch, self.color) 837 | torchvision.utils.save_image(partial_rgb, str(self.results_dir / self.name / f'{str(num)}-part.{ext}'), nrow=num_rows) 838 | torchvision.utils.save_image(part_rgb, str(self.results_dir / self.name / f'{str(num)}-real.{ext}'), nrow=num_rows) 839 | torchvision.utils.save_image(1-((1-part_rgb)+(1-partial_rgb).clamp_(0., 1.)), str(self.results_dir / self.name / f'{str(num)}-full.{ext}'), nrow=num_rows) 840 | # regular 841 | generated_part_rgb = gs_to_rgb(generated_partial_images, self.color) 842 | torchvision.utils.save_image(generated_part_rgb, str(self.results_dir / self.name / f'{str(num)}-comp.{ext}'), nrow=num_rows) 843 | torchvision.utils.save_image(1-((1-generated_part_rgb)+(1-partial_rgb).clamp_(0., 1.)), str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows) 844 | 845 | @torch.no_grad() 846 | def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8, bitmap_feats=None): 847 | latent_dim = G.latent_dim 848 | 849 | if self.av is None: 850 | z = noise(2000, latent_dim) 851 | samples = evaluate_in_chunks(self.batch_size, S, z).cpu().numpy() 852 | self.av = np.mean(samples, axis = 0) 853 | self.av = np.expand_dims(self.av, axis = 0) 854 | 855 | w_space = [] 856 | for tensor, num_layers in style: 857 | tmp = S(tensor) 858 | av_torch = torch.from_numpy(self.av).cuda() 859 | tmp = trunc_psi * (tmp - av_torch) + av_torch 860 | w_space.append((tmp, num_layers)) 861 | 862 | w_styles = styles_def_to_tensor(w_space) 863 | generated_images = evaluate_in_chunks_unet(self.batch_size, G, bitmap_feats, w_styles, noi) 864 | return generated_images.clamp_(0., 1.) 865 | 866 | def print_log(self): 867 | print(f'G: {self.g_loss:.2f} | D: {self.d_loss:.2f} | GP: {self.last_gp_loss:.2f} | PL: {self.pl_loss:.2f} | SP {self.sparsity_loss:.2f}') 868 | 869 | def model_name(self, num): 870 | return str(self.models_dir / self.name / f'model_{num}.pt') 871 | 872 | def init_folders(self): 873 | (self.results_dir / self.name).mkdir(parents=True, exist_ok=True) 874 | (self.models_dir / self.name).mkdir(parents=True, exist_ok=True) 875 | 876 | def clear(self): 877 | rmtree(str(self.models_dir / self.name), True) 878 | rmtree(str(self.results_dir / self.name), True) 879 | rmtree(str(self.config_path), True) 880 | self.init_folders() 881 | 882 | def save(self, num): 883 | torch.save(self.GAN.state_dict(), self.model_name(num)) 884 | self.write_config() 885 | 886 | def load(self, num = -1): 887 | self.load_config() 888 | 889 | name = num 890 | if num == -1: 891 | file_paths = [p for p in Path(self.models_dir / self.name).glob('model_*.pt')] 892 | saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths)) 893 | if len(saved_nums) == 0: 894 | return 895 | name = saved_nums[-1] 896 | print(f'continuing from previous epoch - {name}') 897 | self.steps = name * self.save_every 898 | self.GAN.load_state_dict(torch.load(self.model_name(name))) 899 | -------------------------------------------------------------------------------- /part_selector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
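# The part selector is a convolutional classifier: given the per-part raster channels of a
# partial sketch (plus the composite channel), it predicts which part should be drawn next,
# with 'none' as the final class.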
5 | 6 | import os 7 | import math 8 | import json 9 | from math import floor, log2 10 | import random 11 | from shutil import rmtree 12 | from functools import partial 13 | import multiprocessing 14 | 15 | import numpy as np 16 | import torch 17 | from torch import nn 18 | from torch.utils import data 19 | import torch.nn.functional as F 20 | 21 | from torch.optim import Adam 22 | 23 | import torchvision 24 | 25 | from PIL import Image 26 | from pathlib import Path 27 | import cairocffi as cairo 28 | 29 | assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.' 30 | 31 | num_cores = multiprocessing.cpu_count() 32 | 33 | # helper classes 34 | 35 | class NanException(Exception): 36 | pass 37 | 38 | class Flatten(nn.Module): 39 | def forward(self, x): 40 | return x.reshape(x.shape[0], -1) 41 | 42 | # helpers 43 | 44 | def default(value, d): 45 | return d if value is None else value 46 | 47 | def cycle(iterable): 48 | while True: 49 | for i in iterable: 50 | yield i 51 | 52 | def is_empty(t): 53 | return t.nelement() == 0 54 | 55 | def raise_if_nan(t): 56 | if torch.isnan(t): 57 | raise NanException 58 | 59 | def loss_backwards(loss, optimizer, **kwargs): 60 | loss.backward(**kwargs) 61 | 62 | def leaky_relu(p=0.2): 63 | return nn.LeakyReLU(p, inplace=True) 64 | 65 | 66 | class Dataset_JSON(data.Dataset): 67 | def __init__(self, base_path, name, image_size): 68 | super().__init__() 69 | self.image_size = image_size 70 | if 'bird' in name: 71 | self.target_parts = ['eye', 'head', 'body', 'beak', 'legs', 'wings', 'mouth', 'tail', 'none'] 72 | self.id_to_part = {0:'initial', 1:'eye', 4:'head', 3:'body', 2:'beak', 5:'legs', 8:'wings', 6:'mouth', 7:'tail'} 73 | elif 'generic' in name or 'fin' in name or 'horn' in name: 74 | self.target_parts = ['eye', 'arms', 'beak', 'mouth', 'body', 'ears', 'feet', 'fin', 75 | 'hair', 'hands', 'head', 'horns', 'legs', 'nose', 'paws', 'tail', 'wings', 'none'] 76 | self.id_to_part = { 0:'initial', 1:'eye', 2:'arms', 3:'beak', 4:'mouth', 5:'body', 6:'ears', 7:'feet', 8:'fin', 77 | 9:'hair', 10:'hands', 11:'head', 12:'horns', 13:'legs', 14:'nose', 15:'paws', 16:'tail', 17:'wings'} 78 | folder = base_path+'%s_json_'+'%d_train'%image_size 79 | self.paths = [] 80 | self.paths_test = [] 81 | # split the training data based on the ids of the eye sketches 82 | for i, p in enumerate(Path(f'{folder%self.target_parts[0]}').glob(f'**/*.json')): 83 | if i%5 == 0: 84 | self.paths_test.append(p) 85 | else: 86 | self.paths.append(p) 87 | for part in self.target_parts[1:]: 88 | for i, p in enumerate(Path(f'{folder%part}').glob(f'**/*.json')): 89 | if Path(str(p).replace('_'+part, '_'+self.target_parts[0])) in self.paths_test: 90 | self.paths_test.append(p) 91 | else: 92 | self.paths.append(p) 93 | self.parts_id = [self.target_parts.index(str(path).split('_')[-5]) for path in self.paths] 94 | self.parts_id_test = [self.target_parts.index(str(path).split('_')[-5]) for path in self.paths_test] 95 | self.rotate = [-1/12*np.pi, 1/12*np.pi] 96 | self.trans = 0.01 97 | self.scale = [0.9, 1.1] 98 | self.n_part = len(self.id_to_part) 99 | 100 | self.samples_partid_test = [torch.LongTensor([self.parts_id_test[sample_id]]) for sample_id in range(self.__len_test__())] 101 | self.samples_partial_test = [] 102 | for sample_id in range(self.__len_test__()): 103 | input_parts_json = json.load(open(self.paths_test[sample_id]))['input_parts'] 104 | img_partial_test = [] 105 | vector_input_part = [] 106 | for i in range(self.n_part): 107 | key = self.id_to_part[i] 108
| vector_input_part += input_parts_json[key] 109 | img_partial_test.append(self.processed_part_to_raster(input_parts_json[key], side=self.image_size)) 110 | img_partial_test.append(self.processed_part_to_raster(vector_input_part, side=self.image_size)) 111 | self.samples_partial_test.append(torch.cat(img_partial_test, 0)) 112 | # import ipdb;ipdb.set_trace() 113 | 114 | self.samples_partid_test = torch.stack(self.samples_partid_test) 115 | self.samples_partial_test = torch.stack(self.samples_partial_test) 116 | print(' | '.join(['%s : %d'%(target_part, (self.samples_partid_test==i).sum()) for i, target_part in enumerate(self.target_parts)])+ 117 | ' | overall : %d'%(len(self.samples_partid_test))) 118 | 119 | def __len__(self): 120 | return len(self.paths) 121 | 122 | def __len_test__(self): 123 | return len(self.paths_test) 124 | 125 | def __getitem__(self, index): 126 | path = self.paths[index] 127 | part_id = self.parts_id[index] 128 | json_data = json.load(open(path)) 129 | input_parts_json = json_data['input_parts'] 130 | img_partial_test = [] 131 | vector_input_part = [] 132 | for i in range(self.n_part): 133 | key = self.id_to_part[i] 134 | vector_input_part += input_parts_json[key] 135 | img_partial_test.append(self.processed_part_to_raster(input_parts_json[key], side=self.image_size)) 136 | img_partial_test.append(self.processed_part_to_raster(vector_input_part, side=self.image_size)) 137 | # random affine 138 | theta = np.random.uniform(*self.rotate) 139 | trans_pixel = 512*self.trans 140 | translate_x = np.random.uniform(-trans_pixel, trans_pixel) 141 | translate_y = np.random.uniform(-trans_pixel, trans_pixel) 142 | scale = np.random.uniform(*self.scale) 143 | # apply 144 | processed_img_partial = [] 145 | affine_vector_input_part = [] 146 | for i in range(self.n_part): 147 | key = self.id_to_part[i] 148 | affine_input_part_json = self.affine_trans(input_parts_json[key], theta, translate_x, translate_y, scale) 149 | affine_vector_input_part += affine_input_part_json 150 | processed_img_partial.append(self.processed_part_to_raster(affine_input_part_json, side=self.image_size)) 151 | processed_img_partial.append(self.processed_part_to_raster(affine_vector_input_part, side=self.image_size)) 152 | return part_id, torch.cat(processed_img_partial, 0), torch.cat(img_partial_test, 0) 153 | 154 | def sample_partial_test(self, n): 155 | sample_ids = [np.random.randint(self.__len__()) for _ in range(n)] 156 | samples_partid = [torch.LongTensor([self.parts_id[sample_id]]) for sample_id in sample_ids] 157 | sample_jsons = [json.load(open(self.paths[sample_id])) for sample_id in sample_ids] 158 | samples_partial = [] 159 | for sample_json in sample_jsons: 160 | input_parts_json = sample_json['input_parts'] 161 | img_partial_test = [] 162 | vector_input_part = [] 163 | for i in range(self.n_part): 164 | key = self.id_to_part[i] 165 | vector_input_part += input_parts_json[key] 166 | img_partial_test.append(self.processed_part_to_raster(input_parts_json[key], side=self.image_size)) 167 | img_partial_test.append(self.processed_part_to_raster(vector_input_part, side=self.image_size)) 168 | samples_partial.append(torch.cat(img_partial_test, 0)) 169 | return torch.stack(samples_partid), torch.stack(samples_partial) 170 | 171 | def affine_trans(self, data, theta, translate_x, translate_y, scale): 172 | rotate_mat = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]]) 173 | affine_data = [] 174 | for item in data: 175 | affine_item = np.array(item) - 256.
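# strokes are recorded on a 512 px canvas, so points are shifted by 256 to rotate about the
# canvas centre before the translation and scaling are applied, and shifted back afterwards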
176 | affine_item = np.transpose(np.matmul(rotate_mat, np.transpose(affine_item))) 177 | affine_item[:, 0] += translate_x 178 | affine_item[:, 1] += translate_y 179 | affine_item *= scale 180 | affine_data.append(affine_item + 256.) 181 | return affine_data 182 | 183 | def processed_part_to_raster(self, vector_part, side=64, line_diameter=16, padding=16, bg_color=(0,0,0), fg_color=(1,1,1)): 184 | """ 185 | render raster image based on the processed part 186 | """ 187 | original_side = 512. 188 | surface = cairo.ImageSurface(cairo.FORMAT_ARGB32, side, side) 189 | ctx = cairo.Context(surface) 190 | ctx.set_antialias(cairo.ANTIALIAS_BEST) 191 | ctx.set_line_cap(cairo.LINE_CAP_ROUND) 192 | ctx.set_line_join(cairo.LINE_JOIN_ROUND) 193 | ctx.set_line_width(line_diameter) 194 | # scale to match the new size 195 | # add padding at the edges for the line_diameter 196 | # and add additional padding to account for antialiasing 197 | total_padding = padding * 2. + line_diameter 198 | new_scale = float(side) / float(original_side + total_padding) 199 | ctx.scale(new_scale, new_scale) 200 | ctx.translate(total_padding / 2., total_padding / 2.) 201 | raster_images = [] 202 | # clear background 203 | ctx.set_source_rgb(*bg_color) 204 | ctx.paint() 205 | # draw strokes, this is the most cpu-intensive part 206 | ctx.set_source_rgb(*fg_color) 207 | for stroke in vector_part: 208 | if len(stroke) == 0: 209 | continue 210 | ctx.move_to(stroke[0][0], stroke[0][1]) 211 | for x, y in stroke: 212 | ctx.line_to(x, y) 213 | ctx.stroke() 214 | surface_data = surface.get_data() 215 | raster_image = np.copy(np.asarray(surface_data))[::4].reshape(side, side) 216 | return torch.FloatTensor(raster_image/255.)[None, :, :] 217 | 218 | # exponential moving average helpers 219 | 220 | def ema_inplace(moving_avg, new, decay): 221 | if is_empty(moving_avg): 222 | moving_avg.data.copy_(new) 223 | return 224 | moving_avg.data.mul_(decay).add_(1 - decay, new) 225 | 226 | 227 | class ClassifierBlock(nn.Module): 228 | def __init__(self, input_channels, filters, downsample=True): 229 | super().__init__() 230 | self.conv_res = nn.Conv2d(input_channels, filters, 1) 231 | 232 | self.net = nn.Sequential( 233 | nn.Conv2d(input_channels, filters, 3, padding=1), 234 | leaky_relu(), 235 | nn.Conv2d(filters, filters, 3, padding=1), 236 | leaky_relu() 237 | ) 238 | 239 | self.downsample = nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) if downsample else None 240 | 241 | def forward(self, x): 242 | res = self.conv_res(x) 243 | x = self.net(x) 244 | x = x + res 245 | if self.downsample is not None: 246 | x = self.downsample(x) 247 | return x 248 | 249 | 250 | class Classifier(nn.Module): 251 | def __init__(self, image_size, network_capacity=16, n_part=1): 252 | super().__init__() 253 | num_layers = int(log2(image_size) - 1) 254 | num_init_filters = n_part 255 | 256 | blocks = [] 257 | filters = [num_init_filters] + [(network_capacity) * (2 ** i) for i in range(num_layers+1)] 258 | chan_in_out = list(zip(filters[0:-1], filters[1:])) 259 | 260 | for ind, (in_chan, out_chan) in enumerate(chan_in_out): 261 | num_layer = ind + 1 262 | is_not_last = ind < (len(chan_in_out) - 1) 263 | block = ClassifierBlock(in_chan, out_chan, downsample = is_not_last) 264 | blocks.append(block) 265 | 266 | self.blocks = nn.ModuleList(blocks) 267 | 268 | latent_dim = 2 * 2 * filters[-1] 269 | 270 | self.flatten = Flatten() 271 | self.to_logit = nn.Linear(latent_dim, n_part-1) 272 | 273 | def forward(self, x): 274 | b, *_ = x.shape 275 | 276 | for block in 
self.blocks: 277 | x = block(x) 278 | 279 | x = self.flatten(x) 280 | x = self.to_logit(x) 281 | return x.squeeze() 282 | 283 | 284 | class part_selector(nn.Module): 285 | def __init__(self, image_size, n_part=10, network_capacity=16, steps=1, lr=1e-4): 286 | super().__init__() 287 | self.lr = lr 288 | self.steps = steps 289 | self.ema_decay = 0.995 290 | self.D = Classifier(image_size, network_capacity, n_part=n_part) 291 | self.D_opt = Adam(self.D.parameters(), lr = self.lr, betas=(0.5, 0.9)) 292 | 293 | self._init_weights() 294 | 295 | self.cuda() 296 | 297 | def _init_weights(self): 298 | for m in self.modules(): 299 | if type(m) in {nn.Conv2d, nn.Linear}: 300 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 301 | 302 | def forward(self, x): 303 | return x 304 | 305 | class Trainer(): 306 | def __init__(self, name, results_dir, models_dir, n_part, image_size, network_capacity, batch_size = 4, 307 | gradient_accumulate_every=1, lr = 2e-4, num_workers = None, save_every = 1000): 308 | self.clf = None 309 | 310 | self.name = name 311 | self.results_dir = Path(results_dir) 312 | self.models_dir = Path(models_dir) 313 | self.config_path = self.models_dir / name / '.config.json' 314 | 315 | assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)' 316 | self.n_part = n_part 317 | self.image_size = image_size 318 | self.network_capacity = network_capacity 319 | 320 | self.lr = lr 321 | self.batch_size = batch_size 322 | self.num_workers = num_workers 323 | 324 | self.save_every = save_every 325 | self.steps = 0 326 | 327 | self.gradient_accumulate_every = gradient_accumulate_every 328 | 329 | self.d_loss = 0 330 | self.d_acc = 0 331 | 332 | self.loader = None 333 | 334 | self.criterion = nn.CrossEntropyLoss() 335 | 336 | if 'bird' in name: 337 | self.target_parts = ['eye', 'head', 'body', 'beak', 'legs', 'wings', 'mouth', 'tail', 'none'] 338 | elif 'generic' in name or 'fin' in name or 'horn' in name: 339 | self.target_parts = ['eye', 'arms', 'beak', 'mouth', 'body', 'ears', 'feet', 'fin', 340 | 'hair', 'hands', 'head', 'horns', 'legs', 'nose', 'paws', 'tail', 'wings', 'none'] 341 | self.n_part_class = len(self.target_parts) 342 | 343 | def init_clf(self): 344 | self.clf = part_selector(n_part=self.n_part, lr=self.lr, image_size=self.image_size, network_capacity=self.network_capacity) 345 | 346 | def write_config(self): 347 | self.config_path.write_text(json.dumps(self.config())) 348 | 349 | def load_config(self): 350 | config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text()) 351 | self.image_size = config['image_size'] 352 | self.network_capacity = config['network_capacity'] 353 | del self.clf 354 | self.init_clf() 355 | 356 | def config(self): 357 | return {'image_size': self.image_size, 'network_capacity': self.network_capacity} 358 | 359 | def set_data_src(self, folder, name): 360 | self.dataset = Dataset_JSON(folder, name, self.image_size) 361 | print('Number of training samples: %d'%(len(self.dataset))) 362 | self.loader = cycle(data.DataLoader(self.dataset, num_workers=default(self.num_workers, num_cores), batch_size=self.batch_size, drop_last=True, shuffle=True, pin_memory=True)) 363 | 364 | def train(self): 365 | self.init_folders() 366 | if self.clf is None: 367 | self.init_clf() 368 | 369 | self.clf.train() 370 | total_disc_loss = torch.tensor(0.).cuda() 371 | total_acc = torch.tensor(0.).cuda() 372 | batch_size = self.batch_size 373 | 374 | backwards = partial(loss_backwards)
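# one optimisation step: accumulate the cross-entropy loss (and running accuracy) over
# `gradient_accumulate_every` mini-batches before a single Adam update on the classifier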
375 | 376 | self.clf.D_opt.zero_grad() 377 | 378 | for i in range(self.gradient_accumulate_every): 379 | part_id_batch, image_cond_batch, _ = [item.cuda() for item in next(self.loader)] 380 | outputs = self.clf.D(image_cond_batch) 381 | _, predicts = torch.max(outputs, 1) 382 | acc = (predicts == part_id_batch).sum().float() / part_id_batch.size(0) / self.gradient_accumulate_every 383 | disc_loss = self.criterion(outputs, part_id_batch) 384 | disc_loss = disc_loss / self.gradient_accumulate_every 385 | disc_loss.register_hook(raise_if_nan) 386 | backwards(disc_loss, self.clf.D_opt) 387 | total_disc_loss += disc_loss.detach().item() 388 | total_acc += acc.detach().item() 389 | 390 | self.d_loss = float(total_disc_loss) 391 | self.d_acc = float(total_acc) 392 | self.clf.D_opt.step() 393 | 394 | # save from NaN errors 395 | 396 | checkpoint_num = floor(self.steps / self.save_every) 397 | 398 | if torch.isnan(total_disc_loss): 399 | print(f'NaN detected. Loading from checkpoint #{checkpoint_num}') 400 | self.load(checkpoint_num) 401 | raise NanException 402 | 403 | # periodically save results 404 | if self.steps % self.save_every == 0: 405 | self.save(checkpoint_num) 406 | 407 | if self.steps % 1000 == 0 or (self.steps % 100 == 0 and self.steps < 2500): 408 | self.evaluate(floor(self.steps / 1000)) 409 | 410 | self.steps += 1 411 | 412 | @torch.no_grad() 413 | def evaluate(self, num = 0, num_image_tiles = 8): 414 | self.clf.eval() 415 | ext = 'png' 416 | num_rows = num_image_tiles 417 | part_id_batch, image_cond_batch = [item.cuda() for item in self.dataset.sample_partial_test(num_rows ** 2)] 418 | outputs = self.clf.D(image_cond_batch.clone().detach()) 419 | _, predicted = torch.max(outputs, 1) 420 | with open(str(self.results_dir / self.name / f'{str(num)}-pred.txt'), 'w') as fw: 421 | for i in range(num_rows): 422 | for j in range(num_rows): 423 | fw.write('%s\t'%self.target_parts[predicted[i*num_rows+j]]) 424 | fw.write('\n') 425 | with open(str(self.results_dir / self.name / f'{str(num)}-real.txt'), 'w') as fw: 426 | for i in range(num_rows): 427 | for j in range(num_rows): 428 | fw.write('%s\t'%self.target_parts[part_id_batch[i*num_rows+j]]) 429 | fw.write('\n') 430 | torchvision.utils.save_image(image_cond_batch[:, -1:], str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows) 431 | part_id_test, image_cond_test = self.dataset.samples_partid_test.cuda(), self.dataset.samples_partial_test.cuda() 432 | class_correct = list(0. for i in range(self.n_part_class)) 433 | class_total = list(0. 
for i in range(self.n_part_class)) 434 | n_batch = self.dataset.__len_test__()//256 435 | for i in range(n_batch+1): 436 | if i == n_batch: 437 | part_id_batch, image_cond_batch = part_id_test[i*256:], image_cond_test[i*256:] 438 | else: 439 | part_id_batch, image_cond_batch = part_id_test[i*256:(i+1)*256], image_cond_test[i*256:(i+1)*256] 440 | outputs = self.clf.D(image_cond_batch.clone().detach()) 441 | _, predicts = torch.max(outputs, 1) 442 | with torch.no_grad(): 443 | for part_id, pred_id in zip(part_id_batch, predicts): 444 | c = (part_id == pred_id).squeeze() 445 | class_correct[part_id] += c 446 | class_total[part_id] += 1 447 | print(' | '.join(['%s: %.2f'%(target_part, 100*class_correct[i]/(class_total[i]+1e-6)) for i, target_part in enumerate(self.target_parts)])+ 448 | ' | overall : %.2f'%(100*sum(class_correct)/(sum(class_total)+1e-6))) 449 | 450 | def print_log(self): 451 | print(f'training loss: {self.d_loss:.2f} | training acc: {self.d_acc:.2f}') 452 | 453 | def model_name(self, num): 454 | return str(self.models_dir / self.name / f'model_{num}.pt') 455 | 456 | def init_folders(self): 457 | (self.results_dir / self.name).mkdir(parents=True, exist_ok=True) 458 | (self.models_dir / self.name).mkdir(parents=True, exist_ok=True) 459 | 460 | def clear(self): 461 | rmtree(str(self.models_dir / self.name), True) 462 | rmtree(str(self.results_dir / self.name), True) 463 | rmtree(str(self.config_path), True) 464 | self.init_folders() 465 | 466 | def save(self, num): 467 | torch.save(self.clf.state_dict(), self.model_name(num)) 468 | self.write_config() 469 | 470 | def load(self, num = -1): 471 | self.load_config() 472 | name = num 473 | if num == -1: 474 | file_paths = [p for p in Path(self.models_dir / self.name).glob('model_*.pt')] 475 | saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths)) 476 | if len(saved_nums) == 0: 477 | return 478 | name = saved_nums[-1] 479 | print(f'continuing from previous epoch - {name}') 480 | self.steps = name * self.save_every 481 | self.clf.load_state_dict(torch.load(self.model_name(name))) 482 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backcall==0.2.0 2 | cairocffi==1.1.0 3 | certifi==2020.6.20 4 | decorator==4.4.2 5 | fire==0.3.1 6 | ipdb==0.13.3 7 | ipython==7.18.1 8 | ipython-genutils==0.2.0 9 | jedi==0.17.2 10 | mkl-fft==1.1.0 11 | mkl-random==1.1.1 12 | mkl-service==2.3.0 13 | olefile==0.46 14 | opencv-python==4.4.0.42 15 | parso==0.7.1 16 | pexpect==4.8.0 17 | pickleshare==0.7.5 18 | prompt-toolkit==3.0.7 19 | ptyprocess==0.6.0 20 | py==1.9.0 21 | Pygments==2.6.1 22 | pytorch-ranger==0.1.1 23 | rdp==0.8 24 | retry==0.9.2 25 | six==1.15.0 26 | svgwrite==1.4 27 | termcolor==1.1.0 28 | torch==1.3.1 29 | torch-optimizer==0.0.1a15 30 | torchvision==0.4.2 31 | tqdm==4.48.2 32 | traitlets==5.0.4 33 | wcwidth==0.2.5 34 | -------------------------------------------------------------------------------- /run_part_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
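# Command-line entry point for training (or sampling from) a single part generator;
# the shell scripts under training_scripts/ invoke this file once per part.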
5 | 6 | import argparse 7 | from retry.api import retry_call 8 | from tqdm import tqdm 9 | from part_generator import Trainer, NanException 10 | from datetime import datetime 11 | 12 | def train_from_folder( 13 | data = '../../data', 14 | results_dir = '../../results', 15 | models_dir = '../../models', 16 | name = 'default', 17 | new = False, 18 | large_aug = False, 19 | load_from = -1, 20 | n_part = 1, 21 | image_size = 128, 22 | network_capacity = 16, 23 | batch_size = 3, 24 | gradient_accumulate_every = 5, 25 | num_train_steps = 150000, 26 | learning_rate_D = 2e-4, 27 | learning_rate_G = 2e-4, 28 | num_workers = None, 29 | save_every = 1000, 30 | generate = False, 31 | num_image_tiles = 8, 32 | trunc_psi = 0.75, 33 | sparsity_penalty = 0., 34 | ): 35 | model = Trainer( 36 | name, 37 | results_dir, 38 | models_dir, 39 | batch_size = batch_size, 40 | gradient_accumulate_every = gradient_accumulate_every, 41 | n_part = n_part, 42 | image_size = image_size, 43 | network_capacity = network_capacity, 44 | lr_D = learning_rate_D, 45 | lr_G = learning_rate_G, 46 | num_workers = num_workers, 47 | save_every = save_every, 48 | trunc_psi = trunc_psi, 49 | sparsity_penalty = sparsity_penalty, 50 | ) 51 | 52 | if not new: 53 | model.load(load_from) 54 | else: 55 | model.clear() 56 | 57 | model.set_data_src(data, large_aug) 58 | 59 | if generate: 60 | now = datetime.now() 61 | timestamp = now.strftime("%m-%d-%Y_%H-%M-%S") 62 | samples_name = f'generated-{timestamp}' 63 | model.evaluate(samples_name, num_image_tiles, rgb=True) 64 | print(f'sample images generated at {results_dir}/{name}/{samples_name}') 65 | return 66 | 67 | for _ in tqdm(range(num_train_steps - model.steps), mininterval=10., desc=f'{name}<{data}>'): 68 | retry_call(model.train, tries=3, exceptions=NanException) 69 | if _ % 50 == 0: 70 | model.print_log() 71 | 72 | if __name__ == "__main__": 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument("--data", type=str, default='../../data') 75 | parser.add_argument("--results_dir", type=str, default='../../results') 76 | parser.add_argument("--models_dir", type=str, default='../../models') 77 | parser.add_argument("--name", type=str, default='default') 78 | parser.add_argument("--load_from", type=int, default=-1) 79 | 80 | parser.add_argument('--new', action='store_true') 81 | parser.add_argument('--large_aug', action='store_true') 82 | parser.add_argument('--generate', action='store_true') 83 | 84 | parser.add_argument('--n_part', type=int, default=1) 85 | parser.add_argument('--image_size', type=int, default=128) 86 | parser.add_argument('--network_capacity', type=int, default=16) 87 | parser.add_argument('--batch_size', type=int, default=3) 88 | parser.add_argument('--gradient_accumulate_every', type=int, default=5) 89 | parser.add_argument('--num_train_steps', type=int, default=150000) 90 | parser.add_argument('--num_workers', type=int, default=None) 91 | parser.add_argument('--save_every', type=int, default=1000) 92 | parser.add_argument('--num_image_tiles', type=int, default=8) 93 | 94 | parser.add_argument('--learning_rate_D', type=float, default=1e-4) 95 | parser.add_argument('--learning_rate_G', type=float, default=1e-4) 96 | parser.add_argument('--sparsity_penalty', type=float, default=0.) 97 | parser.add_argument('--trunc_psi', type=float, default=1.) 
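# example invocation (mirrors training_scripts/train_creative_birds/bird_short_creative_sequential_unet_partonly_beak.sh, paths shortened):
# python run_part_generator.py --new --n_part 10 --data ../../../data/bird_short_beak_json_64 --name short_bird_creative_beak \
#   --batch_size 40 --network_capacity 16 --gradient_accumulate_every 1 --save_every 2000 --image_size 64 \
#   --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 300000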
98 | 99 | args = parser.parse_args() 100 | print(args) 101 | 102 | train_from_folder(args.data, args.results_dir, args.models_dir, args.name, args.new, args.large_aug, args.load_from, args.n_part, 103 | args.image_size, args.network_capacity, args.batch_size, args.gradient_accumulate_every, args.num_train_steps, args.learning_rate_D, 104 | args.learning_rate_G, args.num_workers, args.save_every, args.generate, args.num_image_tiles, args.trunc_psi, args.sparsity_penalty) 105 | -------------------------------------------------------------------------------- /run_part_selector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | from retry.api import retry_call 8 | from tqdm import tqdm 9 | from part_selector import Trainer, NanException 10 | from datetime import datetime 11 | 12 | def train_from_folder( 13 | data = '../../data', 14 | results_dir = '../../results', 15 | models_dir = '../../models', 16 | name = 'default', 17 | new = False, 18 | load_from = -1, 19 | n_part = 1, 20 | image_size = 128, 21 | network_capacity = 16, 22 | batch_size = 3, 23 | gradient_accumulate_every = 5, 24 | num_train_steps = 150000, 25 | learning_rate = 2e-4, 26 | num_workers = None, 27 | save_every = 1000, 28 | num_image_tiles = 8, 29 | ): 30 | model = Trainer( 31 | name, 32 | results_dir, 33 | models_dir, 34 | batch_size = batch_size, 35 | gradient_accumulate_every = gradient_accumulate_every, 36 | n_part = n_part, 37 | image_size = image_size, 38 | network_capacity = network_capacity, 39 | lr = learning_rate, 40 | num_workers = num_workers, 41 | save_every = save_every, 42 | ) 43 | 44 | if not new: 45 | model.load(load_from) 46 | else: 47 | model.clear() 48 | 49 | model.set_data_src(data, name) 50 | 51 | for _ in tqdm(range(num_train_steps - model.steps), mininterval=10., desc=f'{name}<{data}>'): 52 | retry_call(model.train, tries=3, exceptions=NanException) 53 | if _ % 50 == 0: 54 | model.print_log() 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument("--data", type=str, default='../../data') 59 | parser.add_argument("--results_dir", type=str, default='../results') 60 | parser.add_argument("--models_dir", type=str, default='../models') 61 | parser.add_argument("--name", type=str, default='default') 62 | parser.add_argument("--load_from", type=int, default=-1) 63 | parser.add_argument('--new', action='store_true') 64 | parser.add_argument('--n_part', type=int, default=1) 65 | parser.add_argument('--image_size', type=int, default=128) 66 | parser.add_argument('--network_capacity', type=int, default=16) 67 | parser.add_argument('--batch_size', type=int, default=64) 68 | parser.add_argument('--gradient_accumulate_every', type=int, default=1) 69 | parser.add_argument('--num_train_steps', type=int, default=200000) 70 | parser.add_argument('--num_workers', type=int, default=None) 71 | parser.add_argument('--save_every', type=int, default=1000) 72 | parser.add_argument('--num_image_tiles', type=int, default=8) 73 | parser.add_argument('--learning_rate', type=float, default=2e-4) 74 | 75 | args = parser.parse_args() 76 | print(args) 77 | 78 | train_from_folder(args.data, args.results_dir, args.models_dir, args.name, args.new, args.load_from, args.n_part, 79 | args.image_size, args.network_capacity, args.batch_size, 
args.gradient_accumulate_every, args.num_train_steps, 80 | args.learning_rate, args.num_workers, args.save_every, args.num_image_tiles) 81 | -------------------------------------------------------------------------------- /training_scripts/train_creative_birds/bird_short_creative_clf_aug.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=short_bird_creative_selector 8 | #SBATCH --output=../../../jobs/sample-short_bird_creative_selector-%j.out 9 | #SBATCH --error=../../../jobs/sample-short_bird_creative_selector-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --time=40:00:00 12 | #SBATCH --ntasks-per-node=1 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=10 15 | 16 | python ../../run_part_selector.py --new --results_dir ../../../results --models_dir ../../../models --n_part 10 --data ../../../data/bird_short_ --name short_bird_creative_selector --batch_size 128 --save_every 1000 --image_size 64 17 | -------------------------------------------------------------------------------- /training_scripts/train_creative_birds/bird_short_creative_sequential_unet_partonly_beak.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=short_bird_creative_beak_unet 8 | #SBATCH --output=../../../jobs/sample-short_bird_creative_beak-%j.out 9 | #SBATCH --error=../../../jobs/sample-short_bird_creative_beak-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --time=72:00:00 12 | #SBATCH --ntasks-per-node=1 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=10 15 | 16 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --n_part 10 --data ../../../data/bird_short_beak_json_64 --name short_bird_creative_beak --batch_size 40 --network_capacity 16 --gradient_accumulate_every 1 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 300000 17 | -------------------------------------------------------------------------------- /training_scripts/train_creative_birds/bird_short_creative_sequential_unet_partonly_body.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=short_bird_creative_body_unet 8 | #SBATCH --output=../../../jobs/sample-short_bird_creative_body-%j.out 9 | #SBATCH --error=../../../jobs/sample-short_bird_creative_body-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --time=72:00:00 12 | #SBATCH --ntasks-per-node=1 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=10 15 | 16 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --n_part 10 --data ../../../data/bird_short_body_json_64 --name short_bird_creative_body --batch_size 40 --network_capacity 16 --gradient_accumulate_every 1 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 300000 17 | -------------------------------------------------------------------------------- /training_scripts/train_creative_birds/bird_short_creative_sequential_unet_partonly_eye.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=short_bird_creative_eye_unet 8 | #SBATCH --output=../../../jobs/sample-short_bird_creative_eye-%j.out 9 | #SBATCH --error=../../../jobs/sample-short_bird_creative_eye-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --time=72:00:00 12 | #SBATCH --ntasks-per-node=1 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=10 15 | 16 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --n_part 10 --data ../../../data/bird_short_eye_json_64 --name short_bird_creative_eye --batch_size 40 --network_capacity 16 --gradient_accumulate_every 1 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 300000 17 | -------------------------------------------------------------------------------- /training_scripts/train_creative_birds/bird_short_creative_sequential_unet_partonly_head.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=short_bird_creative_head_unet 8 | #SBATCH --output=../../../jobs/sample-short_bird_creative_head-%j.out 9 | #SBATCH --error=../../../jobs/sample-short_bird_creative_head-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --time=72:00:00 12 | #SBATCH --ntasks-per-node=1 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=10 15 | 16 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --n_part 10 --data ../../../data/bird_short_head_json_64 --name short_bird_creative_head --batch_size 40 --network_capacity 16 --gradient_accumulate_every 1 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 300000 17 | -------------------------------------------------------------------------------- /training_scripts/train_creative_birds/bird_short_creative_sequential_unet_partonly_legs.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=short_bird_creative_legs_unet 8 | #SBATCH --output=../../../jobs/sample-short_bird_creative_legs-%j.out 9 | #SBATCH --error=../../../jobs/sample-short_bird_creative_legs-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --time=72:00:00 12 | #SBATCH --gres=gpu:1 13 | #SBATCH --cpus-per-task=10 14 | 15 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --n_part 10 --data ../../../data/bird_short_legs_json_64 --name short_bird_creative_legs --batch_size 40 --network_capacity 16 --gradient_accumulate_every 1 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 300000 16 | -------------------------------------------------------------------------------- /training_scripts/train_creative_birds/bird_short_creative_sequential_unet_partonly_mouth.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=short_bird_creative_mouth_unet 8 | #SBATCH --output=../../../jobs/sample-short_bird_creative_mouth-%j.out 9 | #SBATCH --error=../../../jobs/sample-short_bird_creative_mouth-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --time=72:00:00 12 | #SBATCH --ntasks-per-node=1 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=10 15 | 16 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --n_part 10 --data ../../../data/bird_short_mouth_json_64 --name short_bird_creative_mouth --batch_size 40 --network_capacity 16 --gradient_accumulate_every 1 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 300000 17 | -------------------------------------------------------------------------------- /training_scripts/train_creative_birds/bird_short_creative_sequential_unet_partonly_tail.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=short_bird_creative_tail_unet 8 | #SBATCH --output=../../../jobs/sample-short_bird_creative_tail-%j.out 9 | #SBATCH --error=../../../jobs/sample-short_bird_creative_tail-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --time=72:00:00 12 | #SBATCH --ntasks-per-node=1 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=10 15 | 16 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --n_part 10 --data ../../../data/bird_short_tail_json_64 --name short_bird_creative_tail --batch_size 40 --network_capacity 16 --gradient_accumulate_every 1 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 300000 17 | -------------------------------------------------------------------------------- /training_scripts/train_creative_birds/bird_short_creative_sequential_unet_partonly_wings.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=short_bird_creative_wings_unet 8 | #SBATCH --output=../../../jobs/sample-short_bird_creative_wings-%j.out 9 | #SBATCH --error=../../../jobs/sample-short_bird_creative_wings-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --time=72:00:00 12 | #SBATCH --ntasks-per-node=1 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=10 15 | 16 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --n_part 10 --data ../../../data/bird_short_wings_json_64 --name short_bird_creative_wings --num_train_steps 300000 --batch_size 40 --network_capacity 16 --gradient_accumulate_every 1 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 17 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_arms.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_arms_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_arms-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_arms-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_arms_json_64_ --partial ../../../data/generic_long_arms_input_parts_64_split_ --name long_generic_creative_arms --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 17 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_beak.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_beak_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_beak-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_beak-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_beak_json_64_ --name long_generic_creative_beak --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_body.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_body_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_body-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_body-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_body_json_64 --name long_generic_creative_body --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_ears.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_ears_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_ears-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_ears-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_ears_json_64 --name long_generic_creative_ears --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_eye.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_eye_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_eye-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_eye-%j.err 10 | #SBATCH --partition=short 11 | 12 | #SBATCH --gpus-per-node=1 13 | #SBATCH --time=72:00:00 14 | 15 | #SBATCH --gres=gpu:1 16 | 17 | #SBATCH --cpus-per-task=20 18 | 19 | 20 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_eye_json_64_ --partial ../../../data/generic_long_eye_input_parts_64_split_ --name long_generic_creative_eye --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 21 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_feet.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_feet_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_feet-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_feet-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_feet_json_64 --name long_generic_creative_feet --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_fin.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_fin_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_fin-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_fin-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_fin_json_64 --name long_generic_creative_fin --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_hair.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_hair_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_hair-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_hair-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_hair_json_64 --name long_generic_creative_hair --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_hands.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_hands_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_hands-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_hands-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_hands_json_64 --name long_generic_creative_hands --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_head.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_head_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_head-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_head-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_head_json_64 --name long_generic_creative_head --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_horns.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_horns_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_horns-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_horns-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_horns_json_64 --name long_generic_creative_horns --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_legs.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_legs_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_legs-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_legs-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_legs_json_64 --name long_generic_creative_legs --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_mouth.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_mouth_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_mouth-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_mouth-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_mouth_json_64 --name long_generic_creative_mouth --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_nose.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_nose_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_nose-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_nose-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_nose_json_64 --name long_generic_creative_nose --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_paws.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_paws_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_paws-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_paws-%j.err 10 | #SBATCH --partition=short 11 | 12 | #SBATCH --gpus-per-node=1 13 | #SBATCH --time=72:00:00 14 | 15 | #SBATCH --gres=gpu:1 16 | 17 | #SBATCH --cpus-per-task=20 18 | 19 | 20 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_paws_json_64 --name long_generic_creative_paws --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 21 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_tail.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_tail_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_tail-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_tail-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_tail_json_64 --name long_generic_creative_tail --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/bird_short_creative_sequential_unet_partonly_wings.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_wings_unet 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_wings-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_wings-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=72:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=20 15 | 16 | 17 | python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models --large_aug --n_part 19 --data ../../../data/generic_long_wings_json_64 --name long_generic_creative_wings --batch_size 40 --gradient_accumulate_every 1 --network_capacity 16 --save_every 2000 --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 --num_train_steps 600000 18 | -------------------------------------------------------------------------------- /training_scripts/train_creative_creatures/generic_long_creative_clf_aug.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/bin/bash 7 | #SBATCH --job-name=long_generic_creative_selector 8 | #SBATCH --output=../../../jobs/sample-long_generic_creative_selector_64_split_aug-%j.out 9 | #SBATCH --error=../../../jobs/sample-long_generic_creative_selector_64_split_aug-%j.err 10 | #SBATCH --partition=short 11 | #SBATCH --gpus-per-node=1 12 | #SBATCH --time=60:00:00 13 | #SBATCH --gres=gpu:1 14 | #SBATCH --cpus-per-task=10 15 | 16 | python ../../run_part_selector.py --new --results_dir ../../../results --models_dir ../../../models --n_part 19 --data ../../../data/generic_long_ --name long_generic_creative_selector_64 --batch_size 128 --save_every 1000 --image_size 64 17 | --------------------------------------------------------------------------------
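
Usage note: the scripts above are SLURM batch scripts and are normally submitted with sbatch from inside their own directory, since every path they reference (../../../jobs, ../../../data, ../../../models, ../../../results) is relative. A minimal sketch of launching the bird part-generator and part-selector training follows, assuming a SLURM cluster and that those sibling directories already exist; the direct python invocation is copied from bird_short_creative_sequential_unet_partonly_beak.sh and can be run without SLURM.

# Minimal sketch (assumptions: SLURM is available and the ../../../jobs,
# ../../../data, ../../../models, ../../../results directories exist).
cd training_scripts/train_creative_birds
mkdir -p ../../../jobs ../../../results ../../../models
sbatch bird_short_creative_sequential_unet_partonly_beak.sh   # train one part generator
sbatch bird_short_creative_clf_aug.sh                         # train the part selector

# Without SLURM, the same training command can be run directly:
python ../../run_part_generator.py --new --results_dir ../../../results --models_dir ../../../models \
    --n_part 10 --data ../../../data/bird_short_beak_json_64 --name short_bird_creative_beak \
    --batch_size 40 --network_capacity 16 --gradient_accumulate_every 1 --save_every 2000 \
    --image_size 64 --sparsity_penalty 0.01 --learning_rate_D 1e-4 --learning_rate_G 1e-4 \
    --num_train_steps 300000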