├── .gitignore ├── README.md ├── app ├── README.md ├── app.py ├── components │ ├── app.js │ ├── brush_size_slider.js │ ├── canvas.js │ ├── display_size_nag_bar.js │ ├── download_button.js │ ├── fancy_button.js │ ├── output_picture.js │ ├── purecanvas.js │ ├── speech_bubble.js │ ├── toolbox.js │ ├── toolbox_label_button.js │ ├── toolbox_tool_button.js │ └── tweet_button.js ├── index.html ├── models │ ├── 102.yaml │ ├── 62.yaml │ └── 79.yaml ├── package-lock.json ├── package.json ├── requirements.txt ├── robots.txt ├── scripts │ ├── upload_to_gh.sh │ └── upload_to_s3.sh └── static │ ├── fonts.css │ ├── img │ ├── default.png │ ├── example.png │ ├── favicon.png │ ├── favicon.svg │ └── social.jpg │ ├── index.css │ └── index.js ├── imgs ├── model-lineage.png ├── model-lineage.svg └── screenshot.png ├── lib ├── .gitignore ├── README.md ├── SPADE-master │ ├── .gitignore │ ├── LICENSE.md │ ├── README.md │ ├── data │ │ ├── __init__.py │ │ ├── ade20k_dataset.py │ │ ├── base_dataset.py │ │ ├── cityscapes_dataset.py │ │ ├── coco_dataset.py │ │ ├── custom_dataset.py │ │ ├── facades_dataset.py │ │ ├── image_folder.py │ │ └── pix2pix_dataset.py │ ├── datasets │ │ └── coco_generate_instance_map.py │ ├── models │ │ ├── __init__.py │ │ ├── networks │ │ │ ├── __init__.py │ │ │ ├── architecture.py │ │ │ ├── base_network.py │ │ │ ├── discriminator.py │ │ │ ├── encoder.py │ │ │ ├── generator.py │ │ │ ├── loss.py │ │ │ ├── normalization.py │ │ │ └── sync_batchnorm │ │ │ │ ├── __init__.py │ │ │ │ ├── batchnorm.py │ │ │ │ ├── batchnorm_reimpl.py │ │ │ │ ├── comm.py │ │ │ │ ├── replicate.py │ │ │ │ └── unittest.py │ │ └── pix2pix_model.py │ ├── options │ │ ├── __init__.py │ │ ├── base_options.py │ │ ├── test_options.py │ │ └── train_options.py │ ├── requirements.txt │ ├── test.py │ ├── train.py │ ├── trainers │ │ ├── __init__.py │ │ └── pix2pix_trainer.py │ └── util │ │ ├── __init__.py │ │ ├── coco.py │ │ ├── html.py │ │ ├── iter_counter.py │ │ ├── util.py │ │ └── visualizer.py └── 
model_server │ ├── latest_net_G.yaml │ ├── requirements.txt │ ├── scripts │ └── load_test.py │ ├── server.py │ └── test_payload.json ├── models ├── README.md ├── model_0.py ├── model_1.py ├── model_10.py ├── model_11.py ├── model_12.py ├── model_13.py ├── model_14.py ├── model_2.py ├── model_3.py ├── model_4.py ├── model_5.py ├── model_6.py ├── model_7.py ├── model_8.py └── model_9.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | lib/SPADE-master/datasets/coco_stuff/ 2 | lib/SPADE-master/docs/ 3 | .ipynb_checkpoints 4 | .DS_Store 5 | .vscode 6 | *.pyc 7 | __pycache__/ 8 | node_modules/ 9 | dist/ 10 | checkpoints/ 11 | temp 12 | *.pt 13 | *.egg-info/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # paint-with-ml 2 | 3 | ![](imgs/screenshot.png) 4 | 5 | This app uses a version of the GauGAN deep learning model to transform an input segmentation mask into a landscape painting. 6 | 7 | GauGAN, released by NVIDIA in 2019, was one of the most capable image-to-image translation models available at the time of writing. This version of the model was trained on the popular ADE20K dataset, then fine-tuned on a dataset of 250 paintings from Bob Ross's beloved PBS series, "The Joy of Painting". 8 | 9 | Choose from nine different semantic brushes to craft your painting. Then click the "Run" button to generate a result! Here are some example outputs: 10 | 11 | ![](https://i.imgur.com/QzNFAV6.png) 12 | 13 | Try it out yourself by visiting [paintwith.spell.ml](http://paintwith.spell.ml/). To learn more about how this model was built and served, check out [the accompanying article](https://spell.ml/blog/paint-with-machine-learning-X10i3BAAACQAUI_o) on the Spell blog. 14 | 15 | ## Project hierarchy 16 | 17 | ``` 18 | ├── LICENSE 19 | ├── README.md <- You're reading it! 
20 | ├── lib/ <- Vendored model code (Python). 21 | │ ├── SPADE-master/ <- Vendored copy of NVlabs/SPADE.* ** 22 | │ └── model_server/ <- Model server code (uses Spell model serving). 23 | ├── models/ <- Model assets. 24 | │ ├── model_0.py <- The build script for the first model trained. 25 | │ ├── ... 26 | │ ├── model_N.py <- The build script for the last model trained. 27 | │ └── README.md <- Reference for the model builds. 28 | ├── notebooks/ <- Jupyter notebooks discussing the model build process. 29 | ├── requirements.txt <- Project environment requirements, installable with pip. 30 | ├── app/ <- User-facing demo React web app. 31 | │ ├── README.md <- Reference on how to build and serve the web app. 32 | │ ├── index.html 33 | │ ├── components/ <- React components. 34 | │ ├── models/ <- Model configs (see model_server/). 35 | │ ├── static/ 36 | │ ├── requirements.txt <- Web app Python requirements, installable with pip. 37 | │ └── package.json <- Web app JS requirements, installable with npm. 38 | ├── Dockerfile <- Dockerfile bundling the web application. 39 | └── .gitignore 40 | 41 | * Also contains a copy of vacancy/Synchronized-BatchNorm-PyTorch, an NVlabs/SPADE requirement. 42 | ** Code has minor modifications made for compatibility with the Jupyter environment. 43 | ``` 44 | -------------------------------------------------------------------------------- /app/README.md: -------------------------------------------------------------------------------- 1 | # app 2 | 3 | This folder defines our consumer-facing web application. 4 | 5 | The application frontend is written in React. The core of the application is a drawable HTML Canvas with mouse-driven interaction logic. The application backend is a [model server](https://spell.ml/docs/model_servers) running on Spell. 6 | 7 | ## Deployment 8 | 9 | ### Local 10 | 11 | 1. Run `npm install` to install the JS packages. 12 | 2. Run `npm run-script build` to build the app JS assets. 13 | 3.
Run `pip install -r requirements.txt` (preferably in a `virtualenv` or `conda` environment) to install the Python packages. 14 | 4. Export the Flask environment variables (you may set `FLASK_ENV=production` instead, if you are so inclined): 15 | 16 | ```bash 17 | export FLASK_APP=app.py 18 | export FLASK_ENV=development 19 | ``` 20 | 5. Start the web service using `flask run --no-reload`. 21 | 22 | ### Remote 23 | 24 | We're hosting this site statically on AWS S3. See [Enabling website hosting](https://docs.aws.amazon.com/AmazonS3/latest/dev/EnableWebsiteHosting.html) in the AWS documentation for details, and `app/scripts/upload_to_s3.sh` for the deploy script. 25 | -------------------------------------------------------------------------------- /app/app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, send_file 2 | 3 | # Serve index.html at the root; built JS and other assets are served from static/. 4 | app = Flask('neural-painter', static_folder='static') 5 | 6 | @app.route('/', methods=['GET']) 7 | def index(): 8 | return send_file('index.html') 9 | 10 | print('App ready!') 11 | -------------------------------------------------------------------------------- /app/components/brush_size_slider.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Component } from 'react'; 3 | 4 | 5 | class BrushSizeSlider extends Component { 6 | render() { 7 | return
8 | this.props.onChange(e.target.value)} 12 | /> 13 |
14 | } 15 | } 16 | 17 | export default BrushSizeSlider; -------------------------------------------------------------------------------- /app/components/display_size_nag_bar.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Component } from 'react'; 3 | 4 | 5 | class DisplaySizeNagBar extends Component { 6 | render() { 7 | return
8 |
9 | This app is optimized for large screens and may not work in your current browser! 10 |
11 |
12 | } 13 | } 14 | 15 | export default DisplaySizeNagBar; 16 | -------------------------------------------------------------------------------- /app/components/download_button.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Component } from 'react'; 3 | 4 | 5 | class DownloadButton extends Component { 6 | render() { 7 | // NOTE(aleksey): Chrome does not support target="_blank" for data URIs (Firefox does). 8 | return 9 | 10 | 11 | 12 | 13 | } 14 | } 15 | 16 | export default DownloadButton; -------------------------------------------------------------------------------- /app/components/fancy_button.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Component } from 'react'; 3 | 4 | 5 | class FancyButton extends Component { 6 | constructor() { 7 | super(); 8 | this.state = {'hover': false} 9 | this.onMouseEnter = this.onMouseEnter.bind(this); 10 | this.onMouseLeave = this.onMouseLeave.bind(this); 11 | } 12 | 13 | // NOTE: JS hover events are necessary because the CSS :hover event handler cannot propagate 14 | // from the button element to the SVG element directly, as needed here. We have to handle this 15 | // transition in JS in getButtonImage(). 16 | onMouseEnter() { this.setState({'hover': true}); } 17 | onMouseLeave() { this.setState({'hover': false}); } 18 | 19 | getButtonImage() { 20 | const imageFill = (this.state.hover && this.props.visualType == "unfilled") ? 'black' : 'white'; 21 | 22 | switch (this.props.buttonFunction) { 23 | case 'run': 24 | return 25 | 26 | 27 | ; 28 | case 'reset': 29 | return 30 | 31 | ; 32 | } 33 | } 34 | 35 | render() { 36 | const imageHoverClass = this.state.hover ? 
'hover' : ''; 37 | const imageFilledClass = this.props.visualType; 38 | const containerClassNames = `fancy-button ${this.props.visualType}` 39 | const imageClassNames = `fancy-button-image-container ${imageHoverClass} ${imageFilledClass}` 40 | return
46 |
47 | {this.getButtonImage()} 48 |
49 |
50 | {this.props.buttonFunction} 51 |
52 |
; 53 | } 54 | } 55 | 56 | export default FancyButton; -------------------------------------------------------------------------------- /app/components/output_picture.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Component } from 'react'; 3 | 4 | 5 | class OutputPicture extends Component { 6 | render() { 7 | return 12 | } 13 | } 14 | 15 | export default OutputPicture; -------------------------------------------------------------------------------- /app/components/purecanvas.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Component } from 'react'; 3 | 4 | 5 | class PureCanvas extends Component { 6 | shouldComponentUpdate() { 7 | return false; 8 | } 9 | 10 | render() { 11 | return ( 12 | node ? this.props.contextRef(node) : null } 15 | width="512" 16 | height="512" 17 | onClick={this.props.onClick} 18 | onMouseDown={this.props.onMouseDown} 19 | onMouseUp={this.props.onMouseUp} 20 | onMouseMove={this.props.onMouseMove} 21 | onMouseOut={this.props.onMouseOut} 22 | /> 23 | ); 24 | } 25 | } 26 | 27 | export default PureCanvas; -------------------------------------------------------------------------------- /app/components/speech_bubble.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Component } from 'react'; 3 | 4 | 5 | class SpeechBubble extends Component { 6 | render() { 7 | return
8 |
9 | {this.props.message} 10 |
11 |
12 |
; 13 | } 14 | } 15 | 16 | export default SpeechBubble; -------------------------------------------------------------------------------- /app/components/toolbox_tool_button.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Component } from 'react'; 3 | 4 | 5 | class ToolboxToolButton extends Component { 6 | getToolImage() { 7 | const fill = this.props.activeTool == this.props.tool ? 'black' : 'white'; 8 | switch (this.props.tool) { 9 | case 'brush': 10 | return 11 | 12 | 13 | case 'eraser': 14 | return 15 | 16 | 17 | case 'fill': 18 | return 19 | 20 | 21 | } 22 | } 23 | 24 | render() { 25 | const className = this.props.activeTool == this.props.tool ? 'selector selected' : 'selector unselected'; 26 | return
30 |
31 | {this.getToolImage()} 32 |
33 |
34 | {this.props.tool} 35 |
36 | 37 |
38 | } 39 | } 40 | 41 | export default ToolboxToolButton; -------------------------------------------------------------------------------- /app/components/tweet_button.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Component } from 'react'; 3 | 4 | 5 | class TweetButton extends Component { 6 | render() { 7 | return 8 | 9 | 10 | 11 | ; 12 | } 13 | } 14 | 15 | export default TweetButton; -------------------------------------------------------------------------------- /app/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Paint with Machine Learning 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 24 | 38 | 39 | 40 | 41 |
42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /app/models/102.yaml: -------------------------------------------------------------------------------- 1 | # GuaGAN_Bob_Ross_From_ADE20K_Landscapes_No_VAE 2 | name: '102' 3 | aspect_ratio: 1.0 4 | batchSize: 1 5 | checkpoints_dir: checkpoints/ 6 | contain_dontcare_label: True 7 | crop_size: 256 8 | dataset_mode: custom 9 | gan_mode: hinge 10 | gpu_ids: [] 11 | init_type: xavier 12 | init_variance: 0.02 13 | isTrain: False 14 | label_nc: 150 15 | load_size: 256 16 | model: pix2pix 17 | netG: spade 18 | ngf: 64 19 | no_instance: True 20 | norm_D: spectralinstance 21 | norm_E: spectralinstance 22 | norm_G: spectralspadesyncbatch3x3 23 | num_upsampling_layers: normal 24 | phase: test 25 | preprocess_mode: resize_and_crop 26 | semantic_nc: 151 27 | use_vae: False 28 | which_epoch: latest 29 | z_dim: 256 -------------------------------------------------------------------------------- /app/models/62.yaml: -------------------------------------------------------------------------------- 1 | # GuaGAN_Bob_Ross_From_Scratch_Best 2 | name: '62' 3 | aspect_ratio: 1.0 4 | batchSize: 1 5 | checkpoints_dir: checkpoints/ 6 | contain_dontcare_label: False 7 | crop_size: 256 8 | dataset_mode: custom 9 | gan_mode: hinge 10 | gpu_ids: [] 11 | init_type: xavier 12 | init_variance: 0.02 13 | isTrain: False 14 | label_nc: 9 15 | load_size: 256 16 | model: pix2pix 17 | netG: spade 18 | ngf: 64 19 | no_instance: True 20 | norm_D: spectralinstance 21 | norm_E: spectralinstance 22 | norm_G: spectralspadesyncbatch3x3 23 | num_upsampling_layers: normal 24 | phase: test 25 | preprocess_mode: resize_and_crop 26 | semantic_nc: 151 27 | use_vae: False 28 | which_epoch: latest 29 | z_dim: 256 -------------------------------------------------------------------------------- /app/models/79.yaml: -------------------------------------------------------------------------------- 1 | # 
GuaGAN_ADE20K_Landscapes 2 | name: '79' 3 | aspect_ratio: 1.0 4 | batchSize: 1 5 | checkpoints_dir: checkpoints/ 6 | contain_dontcare_label: True 7 | crop_size: 256 8 | dataset_mode: custom 9 | gan_mode: hinge 10 | gpu_ids: [] 11 | init_type: xavier 12 | init_variance: 0.02 13 | isTrain: False 14 | label_nc: 150 15 | load_size: 256 16 | model: pix2pix 17 | netG: spade 18 | ngf: 64 19 | no_instance: True 20 | norm_D: spectralinstance 21 | norm_E: spectralinstance 22 | norm_G: spectralspadesyncbatch3x3 23 | num_upsampling_layers: normal 24 | phase: test 25 | preprocess_mode: resize_and_crop 26 | semantic_nc: 151 27 | use_vae: False 28 | which_epoch: latest 29 | z_dim: 256 -------------------------------------------------------------------------------- /app/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "bob-ross-neural-painter", 3 | "version": "1.0.0", 4 | "description": "Paint your own Bob Ross.", 5 | "main": "index.js", 6 | "scripts": { 7 | "build": "npx browserify -t [babelify --presets [ @babel/preset-env @babel/preset-react ] ] static/index.js -o static/dist/index.js", 8 | "test": "echo \"Error: no test specified\" && exit 1" 9 | }, 10 | "repository": { 11 | "type": "git", 12 | "url": "git+https://github.com/ResidentMario/bob-ross-neural-painter.git" 13 | }, 14 | "author": "Aleksey Bilogur", 15 | "license": "ISC", 16 | "bugs": { 17 | "url": "https://github.com/ResidentMario/bob-ross-neural-painter/issues" 18 | }, 19 | "homepage": "https://github.com/ResidentMario/bob-ross-neural-painter#readme", 20 | "dependencies": { 21 | "@babel/core": "^7.10.4", 22 | "@babel/preset-env": "^7.10.4", 23 | "@babel/preset-react": "^7.10.4", 24 | "babelify": "^10.0.0", 25 | "browserify": "^16.5.1", 26 | "core-js": "^3.6.5", 27 | "react": "^16.13.1", 28 | "react-dom": "^16.13.1", 29 | "request": "^2.88.2", 30 | "request-promise-native": "^1.0.8" 31 | } 32 | } 33 | 
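The model configs in `app/models/` (`62.yaml`, `79.yaml`, `102.yaml`) differ mainly in `label_nc` and `contain_dontcare_label`. In upstream SPADE, `semantic_nc` is not an independent setting: `options/base_options.py` derives it from `label_nc`, `contain_dontcare_label`, and `no_instance` when parsing command-line options. The sketch below re-applies that rule to values copied by hand from the YAML files. The `CONFIGS` dict and the helper names `derived_semantic_nc` and `check_configs` are illustrative only and are not part of this repository.

```python
from argparse import Namespace

# Fields copied from app/models/{102,62,79}.yaml. Only the options that
# determine the semantic channel count are reproduced here.
CONFIGS = {
    "102": {"label_nc": 150, "contain_dontcare_label": True,
            "no_instance": True, "semantic_nc": 151},
    "62": {"label_nc": 9, "contain_dontcare_label": False,
           "no_instance": True, "semantic_nc": 151},
    "79": {"label_nc": 150, "contain_dontcare_label": True,
           "no_instance": True, "semantic_nc": 151},
}


def derived_semantic_nc(opt: Namespace) -> int:
    """Re-apply SPADE's rule for the width of the one-hot semantic input.

    One channel per label, one extra for the "don't care" label when
    contain_dontcare_label is set, and one extra for the instance-edge
    map unless no_instance is set.
    """
    return (opt.label_nc
            + (1 if opt.contain_dontcare_label else 0)
            + (0 if opt.no_instance else 1))


def check_configs(configs: dict) -> dict:
    """Map each config name to a (derived, stored) semantic_nc pair."""
    results = {}
    for name, fields in configs.items():
        opt = Namespace(**fields)
        results[name] = (derived_semantic_nc(opt), opt.semantic_nc)
    return results


if __name__ == "__main__":
    for name, (derived, stored) in check_configs(CONFIGS).items():
        flag = "ok" if derived == stored else "MISMATCH"
        print(f"{name}.yaml: derived={derived} stored={stored} {flag}")
```

For `79.yaml` and `102.yaml` the rule reproduces the stored value of 151 (150 labels plus one don't-care channel). For `62.yaml`, which sets `label_nc: 9` with no don't-care or instance channel, the rule gives 9 rather than the stored 151. Whether that difference matters depends on how `model_server` consumes the file (for example, whether it passes `semantic_nc` through verbatim), and that code is not shown here.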
-------------------------------------------------------------------------------- /app/requirements.txt: -------------------------------------------------------------------------------- 1 | # model serving requirements 2 | spell 3 | pillow==7.1.0 4 | flask==1.1.1 -------------------------------------------------------------------------------- /app/robots.txt: -------------------------------------------------------------------------------- 1 | User-agent: * 2 | Allow: / -------------------------------------------------------------------------------- /app/scripts/upload_to_gh.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | set -e 3 | npm run build 4 | cd ../ && \ 5 | git checkout -B gh-pages && \ 6 | rm -rf * && \ 7 | git restore --source master -- app/index.html app/static/ && \ 8 | cp -rf app/* . && \ 9 | rm -rf app && \ 10 | git commit -a -m "Publish website." && \ 11 | git push origin gh-pages -------------------------------------------------------------------------------- /app/scripts/upload_to_s3.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | # Replace s3://paintwith.spell.ml/ with your target bucket name 3 | set -e 4 | pushd ../ && PAINT_WITH_ML_HOME=$PWD && popd 5 | aws s3 cp --acl public-read \ 6 | $PAINT_WITH_ML_HOME/index.html \ 7 | s3://paintwith.spell.ml/index.html 8 | aws s3 sync --acl public-read \ 9 | $PAINT_WITH_ML_HOME/static/ \ 10 | s3://paintwith.spell.ml/static/ 11 | aws s3 cp --acl public-read \ 12 | $PAINT_WITH_ML_HOME/robots.txt \ 13 | s3://paintwith.spell.ml/robots.txt -------------------------------------------------------------------------------- /app/static/img/default.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spellml/paint-with-ml/263145bd78e4ae0c37c8a1bc0f072ac1df1aeead/app/static/img/default.png -------------------------------------------------------------------------------- /app/static/img/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spellml/paint-with-ml/263145bd78e4ae0c37c8a1bc0f072ac1df1aeead/app/static/img/example.png -------------------------------------------------------------------------------- /app/static/img/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spellml/paint-with-ml/263145bd78e4ae0c37c8a1bc0f072ac1df1aeead/app/static/img/favicon.png -------------------------------------------------------------------------------- /app/static/img/favicon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /app/static/img/social.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spellml/paint-with-ml/263145bd78e4ae0c37c8a1bc0f072ac1df1aeead/app/static/img/social.jpg 
-------------------------------------------------------------------------------- /app/static/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { render } from 'react-dom'; 3 | import App from '../components/app.js'; 4 | 5 | render(<App />, document.getElementById('root')); -------------------------------------------------------------------------------- /imgs/model-lineage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spellml/paint-with-ml/263145bd78e4ae0c37c8a1bc0f072ac1df1aeead/imgs/model-lineage.png -------------------------------------------------------------------------------- /imgs/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spellml/paint-with-ml/263145bd78e4ae0c37c8a1bc0f072ac1df1aeead/imgs/screenshot.png -------------------------------------------------------------------------------- /lib/.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | results/ 3 | .idea/ 4 | *.tar.gz 5 | *.zip 6 | *.pkl 7 | *.pyc 8 | -------------------------------------------------------------------------------- /lib/README.md: -------------------------------------------------------------------------------- 1 | # lib 2 | 3 | This directory contains the project's pip-installable packages: a vendored copy of the [NVlabs/SPADE](https://github.com/NVlabs/SPADE) GitHub repository, and `model_server`, a basic custom model server designed to work with Spell and GauGAN.
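The vendored SPADE package chooses its dataset class by naming convention. `find_dataset_using_name()` in `lib/SPADE-master/data/__init__.py` imports `data.<dataset_mode>_dataset` and then scans that module for a `BaseDataset` subclass whose lowercased name equals the mode with underscores removed, plus `dataset`. For example, `--dataset_mode ade20k` resolves to `ADE20KDataset`. The torch-free sketch below reproduces only the matching step, using a stand-in module instead of `importlib.import_module`. `BaseDataset` here is a placeholder, and `resolve_dataset_class` is a hypothetical name rather than the repository's function.

```python
import types


class BaseDataset:
    """Placeholder for data.base_dataset.BaseDataset (no torch dependency)."""


def resolve_dataset_class(dataset_mode: str, module: types.ModuleType) -> type:
    """Pick the dataset class from `module` using SPADE's naming convention.

    Underscores in the mode are dropped and "dataset" is appended; the
    comparison against class names is case-insensitive.
    """
    target = dataset_mode.replace("_", "") + "dataset"
    found = None
    for name, obj in vars(module).items():
        # Name match first, then the subclass check, mirroring the original loop.
        if name.lower() == target.lower() and issubclass(obj, BaseDataset):
            found = obj
    if found is None:
        raise ValueError(
            f"In data.{dataset_mode}_dataset, there should be a subclass of "
            f"BaseDataset whose lowercased name is {target}."
        )
    return found


# Stand-in for the module that importlib.import_module("data.ade20k_dataset")
# would return in the real code.
fake_module = types.ModuleType("data.ade20k_dataset")


class ADE20KDataset(BaseDataset):
    pass


fake_module.ADE20KDataset = ADE20KDataset
fake_module.BaseDataset = BaseDataset

if __name__ == "__main__":
    print(resolve_dataset_class("ade20k", fake_module).__name__)
```

Because the match is by name, adding a new dataset to the vendored code only requires a `data/<mode>_dataset.py` file that defines a correspondingly named `BaseDataset` subclass. No registry needs updating.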
4 | -------------------------------------------------------------------------------- /lib/SPADE-master/.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | results/ 3 | .idea/ 4 | *.tar.gz 5 | *.zip 6 | *.pkl 7 | *.pyc 8 | -------------------------------------------------------------------------------- /lib/SPADE-master/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import importlib 7 | import torch.utils.data 8 | from data.base_dataset import BaseDataset 9 | 10 | 11 | def find_dataset_using_name(dataset_name): 12 | # Given the option --dataset_mode [datasetname], 13 | # the file "data/datasetname_dataset.py" 14 | # will be imported. 15 | dataset_filename = "data." + dataset_name + "_dataset" 16 | datasetlib = importlib.import_module(dataset_filename) 17 | 18 | # In the file, the class called DatasetNameDataset() will 19 | # be returned. It has to be a subclass of BaseDataset, 20 | # and the name match is case-insensitive. 21 | dataset = None 22 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 23 | for name, cls in datasetlib.__dict__.items(): 24 | if name.lower() == target_dataset_name.lower() \ 25 | and issubclass(cls, BaseDataset): 26 | dataset = cls 27 | 28 | if dataset is None: 29 | raise ValueError("In %s.py, there should be a subclass of BaseDataset " 30 | "with class name that matches %s in lowercase."
% 31 | (dataset_filename, target_dataset_name)) 32 | 33 | return dataset 34 | 35 | 36 | def get_option_setter(dataset_name): 37 | dataset_class = find_dataset_using_name(dataset_name) 38 | return dataset_class.modify_commandline_options 39 | 40 | 41 | def create_dataloader(opt): 42 | dataset = find_dataset_using_name(opt.dataset_mode) 43 | instance = dataset() 44 | instance.initialize(opt) 45 | print("dataset [%s] of size %d was created" % 46 | (type(instance).__name__, len(instance))) 47 | dataloader = torch.utils.data.DataLoader( 48 | instance, 49 | batch_size=opt.batchSize, 50 | shuffle=not opt.serial_batches, 51 | num_workers=int(opt.nThreads), 52 | drop_last=opt.isTrain 53 | ) 54 | return dataloader 55 | -------------------------------------------------------------------------------- /lib/SPADE-master/data/ade20k_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | from data.pix2pix_dataset import Pix2pixDataset 7 | from data.image_folder import make_dataset 8 | 9 | 10 | class ADE20KDataset(Pix2pixDataset): 11 | 12 | @staticmethod 13 | def modify_commandline_options(parser, is_train): 14 | parser = Pix2pixDataset.modify_commandline_options(parser, is_train) 15 | parser.set_defaults(preprocess_mode='resize_and_crop') 16 | if is_train: 17 | parser.set_defaults(load_size=286) 18 | else: 19 | parser.set_defaults(load_size=256) 20 | parser.set_defaults(crop_size=256) 21 | parser.set_defaults(display_winsize=256) 22 | parser.set_defaults(label_nc=150) 23 | parser.set_defaults(contain_dontcare_label=True) 24 | parser.set_defaults(cache_filelist_read=False) 25 | parser.set_defaults(cache_filelist_write=False) 26 | parser.set_defaults(no_instance=True) 27 | return parser 28 | 29 | def get_paths(self, opt): 30 | root = opt.dataroot 31 | phase = 'val' if opt.phase == 'test' else 'train' 32 | 33 | all_images = make_dataset(root, recursive=True, read_cache=False, write_cache=False) 34 | image_paths = [] 35 | label_paths = [] 36 | for p in all_images: 37 | if '_%s_' % phase not in p: 38 | continue 39 | if p.endswith('.jpg'): 40 | image_paths.append(p) 41 | elif p.endswith('.png'): 42 | label_paths.append(p) 43 | 44 | instance_paths = [] # don't use instance map for ade20k 45 | 46 | return label_paths, image_paths, instance_paths 47 | 48 | # In ADE20k, 'unknown' label is of value 0. 49 | # Change the 'unknown' label to the last label to match other datasets. 50 | def postprocess(self, input_dict): 51 | label = input_dict['label'] 52 | label = label - 1 53 | label[label == -1] = self.opt.label_nc 54 | -------------------------------------------------------------------------------- /lib/SPADE-master/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch.utils.data as data 7 | from PIL import Image 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | import random 11 | 12 | 13 | class BaseDataset(data.Dataset): 14 | def __init__(self): 15 | super(BaseDataset, self).__init__() 16 | 17 | @staticmethod 18 | def modify_commandline_options(parser, is_train): 19 | return parser 20 | 21 | def initialize(self, opt): 22 | pass 23 | 24 | 25 | def get_params(opt, size): 26 | w, h = size 27 | new_h = h 28 | new_w = w 29 | if opt.preprocess_mode == 'resize_and_crop': 30 | new_h = new_w = opt.load_size 31 | elif opt.preprocess_mode == 'scale_width_and_crop': 32 | new_w = opt.load_size 33 | new_h = opt.load_size * h // w 34 | elif opt.preprocess_mode == 'scale_shortside_and_crop': 35 | ss, ls = min(w, h), max(w, h) # shortside and longside 36 | width_is_shorter = w == ss 37 | ls = int(opt.load_size * ls / ss) 38 | new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss) 39 | 40 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 41 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 42 | 43 | flip = random.random() > 0.5 44 | return {'crop_pos': (x, y), 'flip': flip} 45 | 46 | 47 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True): 48 | transform_list = [] 49 | if 'resize' in opt.preprocess_mode: 50 | osize = [opt.load_size, opt.load_size] 51 | transform_list.append(transforms.Resize(osize, interpolation=method)) 52 | elif 'scale_width' in opt.preprocess_mode: 53 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 54 | elif 'scale_shortside' in opt.preprocess_mode: 55 | transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method))) 56 | 57 | if 'crop' in opt.preprocess_mode: 58 | transform_list.append(transforms.Lambda(lambda img: 
__crop(img, params['crop_pos'], opt.crop_size))) 59 | 60 | if opt.preprocess_mode == 'none': 61 | base = 32 62 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 63 | 64 | if opt.preprocess_mode == 'fixed': 65 | w = opt.crop_size 66 | h = round(opt.crop_size / opt.aspect_ratio) 67 | transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method))) 68 | 69 | if opt.isTrain and not opt.no_flip: 70 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 71 | 72 | if toTensor: 73 | transform_list += [transforms.ToTensor()] 74 | 75 | if normalize: 76 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 77 | (0.5, 0.5, 0.5))] 78 | return transforms.Compose(transform_list) 79 | 80 | 81 | def normalize(): 82 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 83 | 84 | 85 | def __resize(img, w, h, method=Image.BICUBIC): 86 | return img.resize((w, h), method) 87 | 88 | 89 | def __make_power_2(img, base, method=Image.BICUBIC): 90 | ow, oh = img.size 91 | h = int(round(oh / base) * base) 92 | w = int(round(ow / base) * base) 93 | if (h == oh) and (w == ow): 94 | return img 95 | return img.resize((w, h), method) 96 | 97 | 98 | def __scale_width(img, target_width, method=Image.BICUBIC): 99 | ow, oh = img.size 100 | if (ow == target_width): 101 | return img 102 | w = target_width 103 | h = int(target_width * oh / ow) 104 | return img.resize((w, h), method) 105 | 106 | 107 | def __scale_shortside(img, target_width, method=Image.BICUBIC): 108 | ow, oh = img.size 109 | ss, ls = min(ow, oh), max(ow, oh) # shortside and longside 110 | width_is_shorter = ow == ss 111 | if (ss == target_width): 112 | return img 113 | ls = int(target_width * ls / ss) 114 | nw, nh = (ss, ls) if width_is_shorter else (ls, ss) 115 | return img.resize((nw, nh), method) 116 | 117 | 118 | def __crop(img, pos, size): 119 | ow, oh = img.size 120 | x1, y1 = pos 121 | tw = th = size 122 | return img.crop((x1, y1, 
x1 + tw, y1 + th)) 123 | 124 | 125 | def __flip(img, flip): 126 | if flip: 127 | return img.transpose(Image.FLIP_LEFT_RIGHT) 128 | return img 129 | -------------------------------------------------------------------------------- /lib/SPADE-master/data/cityscapes_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os.path 7 | from data.pix2pix_dataset import Pix2pixDataset 8 | from data.image_folder import make_dataset 9 | 10 | 11 | class CityscapesDataset(Pix2pixDataset): 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | parser = Pix2pixDataset.modify_commandline_options(parser, is_train) 16 | parser.set_defaults(preprocess_mode='fixed') 17 | parser.set_defaults(load_size=512) 18 | parser.set_defaults(crop_size=512) 19 | parser.set_defaults(display_winsize=512) 20 | parser.set_defaults(label_nc=35) 21 | parser.set_defaults(aspect_ratio=2.0) 22 | parser.set_defaults(batchSize=16) 23 | opt, _ = parser.parse_known_args() 24 | if hasattr(opt, 'num_upsampling_layers'): 25 | parser.set_defaults(num_upsampling_layers='more') 26 | return parser 27 | 28 | def get_paths(self, opt): 29 | root = opt.dataroot 30 | phase = 'val' if opt.phase == 'test' else 'train' 31 | 32 | label_dir = os.path.join(root, 'gtFine', phase) 33 | label_paths_all = make_dataset(label_dir, recursive=True) 34 | label_paths = [p for p in label_paths_all if p.endswith('_labelIds.png')] 35 | 36 | image_dir = os.path.join(root, 'leftImg8bit', phase) 37 | image_paths = make_dataset(image_dir, recursive=True) 38 | 39 | if not opt.no_instance: 40 | instance_paths = [p for p in label_paths_all if p.endswith('_instanceIds.png')] 41 | else: 42 | instance_paths = [] 43 | 44 | return label_paths, image_paths, instance_paths 45 | 46 | def 
paths_match(self, path1, path2): 47 | name1 = os.path.basename(path1) 48 | name2 = os.path.basename(path2) 49 | # compare the first 3 components, [city]_[id1]_[id2] 50 | return '_'.join(name1.split('_')[:3]) == \ 51 | '_'.join(name2.split('_')[:3]) 52 | -------------------------------------------------------------------------------- /lib/SPADE-master/data/coco_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os.path 7 | from data.pix2pix_dataset import Pix2pixDataset 8 | from data.image_folder import make_dataset 9 | 10 | 11 | class CocoDataset(Pix2pixDataset): 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | parser = Pix2pixDataset.modify_commandline_options(parser, is_train) 16 | parser.add_argument('--coco_no_portraits', action='store_true') 17 | parser.set_defaults(preprocess_mode='resize_and_crop') 18 | if is_train: 19 | parser.set_defaults(load_size=286) 20 | else: 21 | parser.set_defaults(load_size=256) 22 | parser.set_defaults(crop_size=256) 23 | parser.set_defaults(display_winsize=256) 24 | parser.set_defaults(label_nc=182) 25 | parser.set_defaults(contain_dontcare_label=True) 26 | parser.set_defaults(cache_filelist_read=True) 27 | parser.set_defaults(cache_filelist_write=True) 28 | return parser 29 | 30 | def get_paths(self, opt): 31 | root = opt.dataroot 32 | phase = 'val' if opt.phase == 'test' else opt.phase 33 | 34 | label_dir = os.path.join(root, '%s_label' % phase) 35 | label_paths = make_dataset(label_dir, recursive=False, read_cache=True) 36 | 37 | if not opt.coco_no_portraits and opt.isTrain: 38 | label_portrait_dir = os.path.join(root, '%s_label_portrait' % phase) 39 | if os.path.isdir(label_portrait_dir): 40 | label_portrait_paths = make_dataset(label_portrait_dir, 
recursive=False, read_cache=True) 41 | label_paths += label_portrait_paths 42 | 43 | image_dir = os.path.join(root, '%s_img' % phase) 44 | image_paths = make_dataset(image_dir, recursive=False, read_cache=True) 45 | 46 | if not opt.coco_no_portraits and opt.isTrain: 47 | image_portrait_dir = os.path.join(root, '%s_img_portrait' % phase) 48 | if os.path.isdir(image_portrait_dir): 49 | image_portrait_paths = make_dataset(image_portrait_dir, recursive=False, read_cache=True) 50 | image_paths += image_portrait_paths 51 | 52 | if not opt.no_instance: 53 | instance_dir = os.path.join(root, '%s_inst' % phase) 54 | instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True) 55 | 56 | if not opt.coco_no_portraits and opt.isTrain: 57 | instance_portrait_dir = os.path.join(root, '%s_inst_portrait' % phase) 58 | if os.path.isdir(instance_portrait_dir): 59 | instance_portrait_paths = make_dataset(instance_portrait_dir, recursive=False, read_cache=True) 60 | instance_paths += instance_portrait_paths 61 | 62 | else: 63 | instance_paths = [] 64 | 65 | return label_paths, image_paths, instance_paths 66 | -------------------------------------------------------------------------------- /lib/SPADE-master/data/custom_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from data.pix2pix_dataset import Pix2pixDataset 7 | from data.image_folder import make_dataset 8 | 9 | 10 | class CustomDataset(Pix2pixDataset): 11 | """ Dataset that loads images from directories 12 | Use option --label_dir, --image_dir, --instance_dir to specify the directories. 13 | The images in the directories are sorted in alphabetical order and paired in order. 
14 | """ 15 | 16 | @staticmethod 17 | def modify_commandline_options(parser, is_train): 18 | parser = Pix2pixDataset.modify_commandline_options(parser, is_train) 19 | parser.set_defaults(preprocess_mode='resize_and_crop') 20 | load_size = 286 if is_train else 256 21 | parser.set_defaults(load_size=load_size) 22 | parser.set_defaults(crop_size=256) 23 | parser.set_defaults(display_winsize=256) 24 | parser.set_defaults(label_nc=13) 25 | parser.set_defaults(contain_dontcare_label=False) 26 | 27 | parser.add_argument('--label_dir', type=str, required=True, 28 | help='path to the directory that contains label images') 29 | parser.add_argument('--image_dir', type=str, required=True, 30 | help='path to the directory that contains photo images') 31 | parser.add_argument('--instance_dir', type=str, default='', 32 | help='path to the directory that contains instance maps. Leave blank if it does not exist') 33 | return parser 34 | 35 | def get_paths(self, opt): 36 | label_dir = opt.label_dir 37 | label_paths = make_dataset(label_dir, recursive=False, read_cache=True) 38 | 39 | image_dir = opt.image_dir 40 | image_paths = make_dataset(image_dir, recursive=False, read_cache=True) 41 | 42 | if len(opt.instance_dir) > 0: 43 | instance_dir = opt.instance_dir 44 | instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True) 45 | else: 46 | instance_paths = [] 47 | 48 | assert len(label_paths) == len(image_paths), "The #images in %s and %s do not match. Is there something wrong?" % (label_dir, image_dir) 49 | 50 | return label_paths, image_paths, instance_paths 51 | -------------------------------------------------------------------------------- /lib/SPADE-master/data/facades_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """ 5 | 6 | import os.path 7 | from data.pix2pix_dataset import Pix2pixDataset 8 | from data.image_folder import make_dataset 9 | 10 | 11 | class FacadesDataset(Pix2pixDataset): 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | parser = Pix2pixDataset.modify_commandline_options(parser, is_train) 16 | parser.set_defaults(dataroot='./dataset/facades/') 17 | parser.set_defaults(preprocess_mode='resize_and_crop') 18 | load_size = 286 if is_train else 256 19 | parser.set_defaults(load_size=load_size) 20 | parser.set_defaults(crop_size=256) 21 | parser.set_defaults(display_winsize=256) 22 | parser.set_defaults(label_nc=13) 23 | parser.set_defaults(contain_dontcare_label=False) 24 | parser.set_defaults(no_instance=True) 25 | return parser 26 | 27 | def get_paths(self, opt): 28 | root = opt.dataroot 29 | phase = 'val' if opt.phase == 'test' else opt.phase 30 | 31 | label_dir = os.path.join(root, '%s_label' % phase) 32 | label_paths = make_dataset(label_dir, recursive=False, read_cache=True) 33 | 34 | image_dir = os.path.join(root, '%s_img' % phase) 35 | image_paths = make_dataset(image_dir, recursive=False, read_cache=True) 36 | 37 | instance_paths = [] 38 | 39 | return label_paths, image_paths, instance_paths 40 | -------------------------------------------------------------------------------- /lib/SPADE-master/data/image_folder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | ############################################################################### 7 | # Code from 8 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 9 | # Modified the original code so that it also loads images from the current 10 | # directory as well as the subdirectories 11 | ############################################################################### 12 | import torch.utils.data as data 13 | from PIL import Image 14 | import os 15 | 16 | IMG_EXTENSIONS = [ 17 | '.jpg', '.JPG', '.jpeg', '.JPEG', 18 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.webp' 19 | ] 20 | 21 | 22 | def is_image_file(filename): 23 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 24 | 25 | 26 | def make_dataset_rec(dir, images): 27 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 28 | 29 | for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)): 30 | for fname in fnames: 31 | if is_image_file(fname): 32 | path = os.path.join(root, fname) 33 | images.append(path) 34 | 35 | 36 | def make_dataset(dir, recursive=False, read_cache=False, write_cache=False): 37 | images = [] 38 | 39 | if read_cache: 40 | possible_filelist = os.path.join(dir, 'files.list') 41 | if os.path.isfile(possible_filelist): 42 | with open(possible_filelist, 'r') as f: 43 | images = f.read().splitlines() 44 | return images 45 | 46 | if recursive: 47 | make_dataset_rec(dir, images) 48 | else: 49 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir 50 | 51 | for root, dnames, fnames in sorted(os.walk(dir)): 52 | for fname in fnames: 53 | if is_image_file(fname): 54 | path = os.path.join(root, fname) 55 | images.append(path) 56 | 57 | if write_cache: 58 | filelist_cache = os.path.join(dir, 'files.list') 59 | with open(filelist_cache, 'w') as f: 60 | for path in images: 61 | f.write("%s\n" % path) 62 | print('wrote filelist cache at %s' % filelist_cache) 63 | 64 | return images 
65 | 66 | 67 | def default_loader(path): 68 | return Image.open(path).convert('RGB') 69 | 70 | 71 | class ImageFolder(data.Dataset): 72 | 73 | def __init__(self, root, transform=None, return_paths=False, 74 | loader=default_loader): 75 | imgs = make_dataset(root) 76 | if len(imgs) == 0: 77 | raise(RuntimeError("Found 0 images in: " + root + "\n" 78 | "Supported image extensions are: " + 79 | ",".join(IMG_EXTENSIONS))) 80 | 81 | self.root = root 82 | self.imgs = imgs 83 | self.transform = transform 84 | self.return_paths = return_paths 85 | self.loader = loader 86 | 87 | def __getitem__(self, index): 88 | path = self.imgs[index] 89 | img = self.loader(path) 90 | if self.transform is not None: 91 | img = self.transform(img) 92 | if self.return_paths: 93 | return img, path 94 | else: 95 | return img 96 | 97 | def __len__(self): 98 | return len(self.imgs) 99 | -------------------------------------------------------------------------------- /lib/SPADE-master/data/pix2pix_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | from data.base_dataset import BaseDataset, get_params, get_transform 7 | from PIL import Image 8 | import util.util as util 9 | import os 10 | 11 | 12 | class Pix2pixDataset(BaseDataset): 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | parser.add_argument('--no_pairing_check', action='store_true', 16 | help='If specified, skip sanity check of correct label-image file pairing') 17 | return parser 18 | 19 | def initialize(self, opt): 20 | self.opt = opt 21 | 22 | label_paths, image_paths, instance_paths = self.get_paths(opt) 23 | 24 | util.natural_sort(label_paths) 25 | util.natural_sort(image_paths) 26 | if not opt.no_instance: 27 | util.natural_sort(instance_paths) 28 | 29 | label_paths = label_paths[:opt.max_dataset_size] 30 | image_paths = image_paths[:opt.max_dataset_size] 31 | instance_paths = instance_paths[:opt.max_dataset_size] 32 | 33 | if not opt.no_pairing_check: 34 | for path1, path2 in zip(label_paths, image_paths): 35 | assert self.paths_match(path1, path2), \ 36 | "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this." 
% (path1, path2) 37 | 38 | self.label_paths = label_paths 39 | self.image_paths = image_paths 40 | self.instance_paths = instance_paths 41 | 42 | size = len(self.label_paths) 43 | self.dataset_size = size 44 | 45 | def get_paths(self, opt): 46 | label_paths = [] 47 | image_paths = [] 48 | instance_paths = [] 49 | assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)" 50 | return label_paths, image_paths, instance_paths 51 | 52 | def paths_match(self, path1, path2): 53 | filename1_without_ext = os.path.splitext(os.path.basename(path1))[0] 54 | filename2_without_ext = os.path.splitext(os.path.basename(path2))[0] 55 | return filename1_without_ext == filename2_without_ext 56 | 57 | def __getitem__(self, index): 58 | # Label Image 59 | label_path = self.label_paths[index] 60 | label = Image.open(label_path) 61 | params = get_params(self.opt, label.size) 62 | transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 63 | label_tensor = transform_label(label) * 255.0 64 | label_tensor[label_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc 65 | 66 | # input image (real images) 67 | image_path = self.image_paths[index] 68 | assert self.paths_match(label_path, image_path), \ 69 | "The label_path %s and image_path %s don't match." 
% \ 70 | (label_path, image_path) 71 | image = Image.open(image_path) 72 | image = image.convert('RGB') 73 | 74 | transform_image = get_transform(self.opt, params) 75 | image_tensor = transform_image(image) 76 | 77 | # if using instance maps 78 | if self.opt.no_instance: 79 | instance_tensor = 0 80 | else: 81 | instance_path = self.instance_paths[index] 82 | instance = Image.open(instance_path) 83 | if instance.mode == 'L': 84 | instance_tensor = transform_label(instance) * 255 85 | instance_tensor = instance_tensor.long() 86 | else: 87 | instance_tensor = transform_label(instance) 88 | 89 | input_dict = {'label': label_tensor, 90 | 'instance': instance_tensor, 91 | 'image': image_tensor, 92 | 'path': image_path, 93 | } 94 | 95 | # Give subclasses a chance to modify the final output 96 | self.postprocess(input_dict) 97 | 98 | return input_dict 99 | 100 | def postprocess(self, input_dict): 101 | return input_dict 102 | 103 | def __len__(self): 104 | return self.dataset_size 105 | -------------------------------------------------------------------------------- /lib/SPADE-master/datasets/coco_generate_instance_map.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | import argparse 8 | from pycocotools.coco import COCO 9 | import numpy as np 10 | import skimage.io as io 11 | from skimage.draw import polygon 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--annotation_file', type=str, default="./annotations/instances_train2017.json", 15 | help="Path to the annotation file. It can be downloaded at http://images.cocodataset.org/annotations/annotations_trainval2017.zip. 
Should be either instances_train2017.json or instances_val2017.json") 16 | parser.add_argument('--input_label_dir', type=str, default="./train_label/", 17 | help="Path to the directory containing label maps. It can be downloaded at http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip") 18 | parser.add_argument('--output_instance_dir', type=str, default="./train_inst/", 19 | help="Path to the output directory of instance maps") 20 | 21 | opt = parser.parse_args() 22 | 23 | print("annotation file at {}".format(opt.annotation_file)) 24 | print("input label maps at {}".format(opt.input_label_dir)) 25 | print("output dir at {}".format(opt.output_instance_dir)) 26 | 27 | # initialize COCO api for instance annotations 28 | coco = COCO(opt.annotation_file) 29 | 30 | 31 | # display COCO categories and supercategories 32 | cats = coco.loadCats(coco.getCatIds()) 33 | imgIds = coco.getImgIds(catIds=coco.getCatIds(cats)) 34 | for ix, id in enumerate(imgIds): 35 | if ix % 50 == 0: 36 | print("{} / {}".format(ix, len(imgIds))) 37 | img_dict = coco.loadImgs(id)[0] 38 | filename = img_dict["file_name"].replace("jpg", "png") 39 | label_name = os.path.join(opt.input_label_dir, filename) 40 | inst_name = os.path.join(opt.output_instance_dir, filename) 41 | img = io.imread(label_name, as_gray=True) 42 | 43 | annIds = coco.getAnnIds(imgIds=id, catIds=[], iscrowd=None) 44 | anns = coco.loadAnns(annIds) 45 | count = 0 46 | for ann in anns: 47 | if "segmentation" in ann: 48 | if type(ann["segmentation"]) == list: 49 | for seg in ann["segmentation"]: 50 | poly = np.array(seg).reshape((int(len(seg) / 2), 2)) 51 | rr, cc = polygon(poly[:, 1] - 1, poly[:, 0] - 1) 52 | img[rr, cc] = count 53 | count += 1 54 | 55 | io.imsave(inst_name, img) 56 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/__init__.py: -------------------------------------------------------------------------------- 1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import importlib 7 | import torch 8 | 9 | 10 | def find_model_using_name(model_name): 11 | # Given the option --model [modelname], 12 | # the file "models/modelname_model.py" 13 | # will be imported. 14 | model_filename = "models." + model_name + "_model" 15 | modellib = importlib.import_module(model_filename) 16 | 17 | # In the file, the class called ModelNameModel() will 18 | # be instantiated. It has to be a subclass of torch.nn.Module, 19 | # and it is case-insensitive. 20 | model = None 21 | target_model_name = model_name.replace('_', '') + 'model' 22 | for name, cls in modellib.__dict__.items(): 23 | if name.lower() == target_model_name.lower() \ 24 | and issubclass(cls, torch.nn.Module): 25 | model = cls 26 | 27 | if model is None: 28 | print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name)) 29 | exit(0) 30 | 31 | return model 32 | 33 | 34 | def get_option_setter(model_name): 35 | model_class = find_model_using_name(model_name) 36 | return model_class.modify_commandline_options 37 | 38 | 39 | def create_model(opt): 40 | model = find_model_using_name(opt.model) 41 | instance = model(opt) 42 | print("model [%s] was created" % (type(instance).__name__)) 43 | 44 | return instance 45 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import torch 7 | from models.networks.base_network import BaseNetwork 8 | from models.networks.loss import * 9 | from models.networks.discriminator import * 10 | from models.networks.generator import * 11 | from models.networks.encoder import * 12 | import util.util as util 13 | 14 | 15 | def find_network_using_name(target_network_name, filename): 16 | target_class_name = target_network_name + filename 17 | module_name = 'models.networks.' + filename 18 | network = util.find_class_in_module(target_class_name, module_name) 19 | 20 | assert issubclass(network, BaseNetwork), \ 21 | "Class %s should be a subclass of BaseNetwork" % network 22 | 23 | return network 24 | 25 | 26 | def modify_commandline_options(parser, is_train): 27 | opt, _ = parser.parse_known_args() 28 | 29 | netG_cls = find_network_using_name(opt.netG, 'generator') 30 | parser = netG_cls.modify_commandline_options(parser, is_train) 31 | if is_train: 32 | netD_cls = find_network_using_name(opt.netD, 'discriminator') 33 | parser = netD_cls.modify_commandline_options(parser, is_train) 34 | netE_cls = find_network_using_name('conv', 'encoder') 35 | parser = netE_cls.modify_commandline_options(parser, is_train) 36 | 37 | return parser 38 | 39 | 40 | def create_network(cls, opt): 41 | net = cls(opt) 42 | net.print_network() 43 | if len(opt.gpu_ids) > 0: 44 | assert(torch.cuda.is_available()) 45 | net.cuda() 46 | net.init_weights(opt.init_type, opt.init_variance) 47 | return net 48 | 49 | 50 | def define_G(opt): 51 | netG_cls = find_network_using_name(opt.netG, 'generator') 52 | return create_network(netG_cls, opt) 53 | 54 | 55 | def define_D(opt): 56 | netD_cls = find_network_using_name(opt.netD, 'discriminator') 57 | return create_network(netD_cls, opt) 58 | 59 | 60 | def define_E(opt): 61 | # there exists only one encoder type 62 | netE_cls = find_network_using_name('conv', 'encoder') 63 | return create_network(netE_cls, opt) 64 | 
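Both `find_model_using_name` in `models/__init__.py` and `find_network_using_name` above turn a command-line string into a class by name. For example, `--model pix2pix` should resolve to a class in `models/pix2pix_model.py` whose lowercased name is `pix2pixmodel`, and `--netG spade` should resolve to a class in `models/networks/generator.py` matching `spadegenerator`, such as `SPADEGenerator`. The sketch below reproduces that case-insensitive lookup on its own. It is an illustration under stated assumptions and is not SPADE code: `find_class_case_insensitive` is a hypothetical helper, because the body of `util.find_class_in_module` is not shown here. It borrows the underscore-stripping rule from `find_model_using_name`, and it fakes the module in memory with `types.ModuleType` so it runs without the SPADE package on the path.

```python
import types


def find_class_case_insensitive(target_name, module):
    """Hypothetical helper: return the class in `module` whose name equals
    `target_name` once underscores are dropped and case is ignored, mirroring
    the matching rule in find_model_using_name."""
    wanted = target_name.replace('_', '').lower()
    for name, obj in vars(module).items():
        # Only classes are candidates; module attributes such as __name__ are skipped.
        if isinstance(obj, type) and name.lower() == wanted:
            return obj
    raise LookupError('no class matching %r in module %s' % (wanted, module.__name__))


# In-memory stand-in for models/networks/generator.py (assumption: the real
# module defines SPADEGenerator, as this repository's generator.py does).
generator_module = types.ModuleType('models.networks.generator')


class SPADEGenerator:
    pass


generator_module.SPADEGenerator = SPADEGenerator

# `--netG spade` plus the file name 'generator' gives the target 'spadegenerator'.
netG_cls = find_class_case_insensitive('spade' + 'generator', generator_module)
print(netG_cls.__name__)  # SPADEGenerator
```

This sketch raises `LookupError` for an unknown name. By contrast, `find_model_using_name` prints a message and calls `exit(0)`, which ends the process with a success status even though no model was found.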
-------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/architecture.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | import torch.nn.utils.spectral_norm as spectral_norm 11 | from models.networks.normalization import SPADE 12 | 13 | 14 | # ResNet block that uses SPADE. 15 | # It differs from the ResNet block of pix2pixHD in that 16 | # it takes in the segmentation map as input, learns the skip connection if necessary, 17 | # and applies normalization first and then convolution. 18 | # This appears to be the standard residual-block design for 19 | # unconditional and class-conditional GANs. 20 | # The code was inspired by https://github.com/LMescheder/GAN_stability.
21 | class SPADEResnetBlock(nn.Module): 22 | def __init__(self, fin, fout, opt): 23 | super().__init__() 24 | # Attributes 25 | self.learned_shortcut = (fin != fout) 26 | fmiddle = min(fin, fout) 27 | 28 | # create conv layers 29 | self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) 30 | self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) 31 | if self.learned_shortcut: 32 | self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) 33 | 34 | # apply spectral norm if specified 35 | if 'spectral' in opt.norm_G: 36 | self.conv_0 = spectral_norm(self.conv_0) 37 | self.conv_1 = spectral_norm(self.conv_1) 38 | if self.learned_shortcut: 39 | self.conv_s = spectral_norm(self.conv_s) 40 | 41 | # define normalization layers 42 | spade_config_str = opt.norm_G.replace('spectral', '') 43 | self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc) 44 | self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc) 45 | if self.learned_shortcut: 46 | self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc) 47 | 48 | # note the resnet block with SPADE also takes in |seg|, 49 | # the semantic segmentation map as input 50 | def forward(self, x, seg): 51 | x_s = self.shortcut(x, seg) 52 | 53 | dx = self.conv_0(self.actvn(self.norm_0(x, seg))) 54 | dx = self.conv_1(self.actvn(self.norm_1(dx, seg))) 55 | 56 | out = x_s + dx 57 | 58 | return out 59 | 60 | def shortcut(self, x, seg): 61 | if self.learned_shortcut: 62 | x_s = self.conv_s(self.norm_s(x, seg)) 63 | else: 64 | x_s = x 65 | return x_s 66 | 67 | def actvn(self, x): 68 | return F.leaky_relu(x, 2e-1) 69 | 70 | 71 | # ResNet block used in pix2pixHD 72 | # We keep the same architecture as pix2pixHD. 
73 | class ResnetBlock(nn.Module): 74 | def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3): 75 | super().__init__() 76 | 77 | pw = (kernel_size - 1) // 2 78 | self.conv_block = nn.Sequential( 79 | nn.ReflectionPad2d(pw), 80 | norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)), 81 | activation, 82 | nn.ReflectionPad2d(pw), 83 | norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)) 84 | ) 85 | 86 | def forward(self, x): 87 | y = self.conv_block(x) 88 | out = x + y 89 | return out 90 | 91 | 92 | # VGG architecture, used for the perceptual loss computed with a pretrained VGG network 93 | class VGG19(torch.nn.Module): 94 | def __init__(self, requires_grad=False): 95 | super().__init__() 96 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features 97 | self.slice1 = torch.nn.Sequential() 98 | self.slice2 = torch.nn.Sequential() 99 | self.slice3 = torch.nn.Sequential() 100 | self.slice4 = torch.nn.Sequential() 101 | self.slice5 = torch.nn.Sequential() 102 | for x in range(2): 103 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 104 | for x in range(2, 7): 105 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 106 | for x in range(7, 12): 107 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 108 | for x in range(12, 21): 109 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(21, 30): 111 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 112 | if not requires_grad: 113 | for param in self.parameters(): 114 | param.requires_grad = False 115 | 116 | def forward(self, X): 117 | h_relu1 = self.slice1(X) 118 | h_relu2 = self.slice2(h_relu1) 119 | h_relu3 = self.slice3(h_relu2) 120 | h_relu4 = self.slice4(h_relu3) 121 | h_relu5 = self.slice5(h_relu4) 122 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 123 | return out 124 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/base_network.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch.nn as nn 7 | from torch.nn import init 8 | 9 | 10 | class BaseNetwork(nn.Module): 11 | def __init__(self): 12 | super(BaseNetwork, self).__init__() 13 | 14 | @staticmethod 15 | def modify_commandline_options(parser, is_train): 16 | return parser 17 | 18 | def print_network(self): 19 | if isinstance(self, list): 20 | self = self[0] 21 | num_params = 0 22 | for param in self.parameters(): 23 | num_params += param.numel() 24 | print('Network [%s] was created. Total number of parameters: %.1f million. ' 25 | 'To see the architecture, do print(network).' 26 | % (type(self).__name__, num_params / 1000000)) 27 | 28 | def init_weights(self, init_type='normal', gain=0.02): 29 | def init_func(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('BatchNorm2d') != -1: 32 | if hasattr(m, 'weight') and m.weight is not None: 33 | init.normal_(m.weight.data, 1.0, gain) 34 | if hasattr(m, 'bias') and m.bias is not None: 35 | init.constant_(m.bias.data, 0.0) 36 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 37 | if init_type == 'normal': 38 | init.normal_(m.weight.data, 0.0, gain) 39 | elif init_type == 'xavier': 40 | init.xavier_normal_(m.weight.data, gain=gain) 41 | elif init_type == 'xavier_uniform': 42 | init.xavier_uniform_(m.weight.data, gain=1.0) 43 | elif init_type == 'kaiming': 44 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 45 | elif init_type == 'orthogonal': 46 | init.orthogonal_(m.weight.data, gain=gain) 47 | elif init_type == 'none': # uses pytorch's default init method 48 | m.reset_parameters() 49 | else: 50 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 51 | if 
hasattr(m, 'bias') and m.bias is not None: 52 | init.constant_(m.bias.data, 0.0) 53 | 54 | self.apply(init_func) 55 | 56 | # propagate to children 57 | for m in self.children(): 58 | if hasattr(m, 'init_weights'): 59 | m.init_weights(init_type, gain) 60 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch.nn as nn 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from models.networks.base_network import BaseNetwork 10 | from models.networks.normalization import get_nonspade_norm_layer 11 | import util.util as util 12 | 13 | 14 | class MultiscaleDiscriminator(BaseNetwork): 15 | @staticmethod 16 | def modify_commandline_options(parser, is_train): 17 | parser.add_argument('--netD_subarch', type=str, default='n_layer', 18 | help='architecture of each discriminator') 19 | parser.add_argument('--num_D', type=int, default=2, 20 | help='number of discriminators to be used in multiscale') 21 | opt, _ = parser.parse_known_args() 22 | 23 | # define properties of each discriminator of the multiscale discriminator 24 | subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator', 25 | 'models.networks.discriminator') 26 | subnetD.modify_commandline_options(parser, is_train) 27 | 28 | return parser 29 | 30 | def __init__(self, opt): 31 | super().__init__() 32 | self.opt = opt 33 | 34 | for i in range(opt.num_D): 35 | subnetD = self.create_single_discriminator(opt) 36 | self.add_module('discriminator_%d' % i, subnetD) 37 | 38 | def create_single_discriminator(self, opt): 39 | subarch = opt.netD_subarch 40 | if subarch == 'n_layer': 41 | netD = NLayerDiscriminator(opt) 42 | else: 43 | raise 
ValueError('unrecognized discriminator subarchitecture %s' % subarch) 44 | return netD 45 | 46 | def downsample(self, input): 47 | return F.avg_pool2d(input, kernel_size=3, 48 | stride=2, padding=[1, 1], 49 | count_include_pad=False) 50 | 51 | # Returns list of lists of discriminator outputs. 52 | # The final result is of size opt.num_D x opt.n_layers_D 53 | def forward(self, input): 54 | result = [] 55 | get_intermediate_features = not self.opt.no_ganFeat_loss 56 | for name, D in self.named_children(): 57 | out = D(input) 58 | if not get_intermediate_features: 59 | out = [out] 60 | result.append(out) 61 | input = self.downsample(input) 62 | 63 | return result 64 | 65 | 66 | # Defines the PatchGAN discriminator with the specified arguments. 67 | class NLayerDiscriminator(BaseNetwork): 68 | @staticmethod 69 | def modify_commandline_options(parser, is_train): 70 | parser.add_argument('--n_layers_D', type=int, default=4, 71 | help='# layers in each discriminator') 72 | return parser 73 | 74 | def __init__(self, opt): 75 | super().__init__() 76 | self.opt = opt 77 | 78 | kw = 4 79 | padw = int(np.ceil((kw - 1.0) / 2)) 80 | nf = opt.ndf 81 | input_nc = self.compute_D_input_nc(opt) 82 | 83 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_D) 84 | sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), 85 | nn.LeakyReLU(0.2, False)]] 86 | 87 | for n in range(1, opt.n_layers_D): 88 | nf_prev = nf 89 | nf = min(nf * 2, 512) 90 | stride = 1 if n == opt.n_layers_D - 1 else 2 91 | sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, 92 | stride=stride, padding=padw)), 93 | nn.LeakyReLU(0.2, False) 94 | ]] 95 | 96 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 97 | 98 | # We divide the layers into groups to extract intermediate layer outputs 99 | for n in range(len(sequence)): 100 | self.add_module('model' + str(n), nn.Sequential(*sequence[n])) 101 | 102 | def compute_D_input_nc(self, opt): 103 | input_nc = 
opt.label_nc + opt.output_nc 104 | if opt.contain_dontcare_label: 105 | input_nc += 1 106 | if not opt.no_instance: 107 | input_nc += 1 108 | return input_nc 109 | 110 | def forward(self, input): 111 | results = [input] 112 | for submodel in self.children(): 113 | intermediate_output = submodel(results[-1]) 114 | results.append(intermediate_output) 115 | 116 | get_intermediate_features = not self.opt.no_ganFeat_loss 117 | if get_intermediate_features: 118 | return results[1:] 119 | else: 120 | return results[-1] 121 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch.nn as nn 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from models.networks.base_network import BaseNetwork 10 | from models.networks.normalization import get_nonspade_norm_layer 11 | 12 | 13 | class ConvEncoder(BaseNetwork): 14 | """ Same architecture as the image discriminator """ 15 | 16 | def __init__(self, opt): 17 | super().__init__() 18 | 19 | kw = 3 20 | pw = int(np.ceil((kw - 1.0) / 2)) 21 | ndf = opt.ngf 22 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_E) 23 | self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw)) 24 | self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw)) 25 | self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw)) 26 | self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw)) 27 | self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)) 28 | if opt.crop_size >= 256: 29 | self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)) 30 | 31 | self.so = s0 = 4 32 | 
self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256) 33 | self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256) 34 | 35 | self.actvn = nn.LeakyReLU(0.2, False) 36 | self.opt = opt 37 | 38 | def forward(self, x): 39 | if x.size(2) != 256 or x.size(3) != 256: 40 | x = F.interpolate(x, size=(256, 256), mode='bilinear') 41 | 42 | x = self.layer1(x) 43 | x = self.layer2(self.actvn(x)) 44 | x = self.layer3(self.actvn(x)) 45 | x = self.layer4(self.actvn(x)) 46 | x = self.layer5(self.actvn(x)) 47 | if self.opt.crop_size >= 256: 48 | x = self.layer6(self.actvn(x)) 49 | x = self.actvn(x) 50 | 51 | x = x.view(x.size(0), -1) 52 | mu = self.fc_mu(x) 53 | logvar = self.fc_var(x) 54 | 55 | return mu, logvar 56 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from models.networks.base_network import BaseNetwork 10 | from models.networks.normalization import get_nonspade_norm_layer 11 | from models.networks.architecture import ResnetBlock as ResnetBlock 12 | from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock 13 | 14 | 15 | class SPADEGenerator(BaseNetwork): 16 | @staticmethod 17 | def modify_commandline_options(parser, is_train): 18 | parser.set_defaults(norm_G='spectralspadesyncbatch3x3') 19 | parser.add_argument('--num_upsampling_layers', 20 | choices=('normal', 'more', 'most'), default='normal', 21 | help="If 'more', adds upsampling layer between the two middle resnet blocks. 
If 'most', also add one more upsampling + resnet layer at the end of the generator") 22 | 23 | return parser 24 | 25 | def __init__(self, opt): 26 | super().__init__() 27 | self.opt = opt 28 | nf = opt.ngf 29 | 30 | self.sw, self.sh = self.compute_latent_vector_size(opt) 31 | 32 | if opt.use_vae: 33 | # In case of VAE, we will sample from random z vector 34 | self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh) 35 | else: 36 | # Otherwise, we make the network deterministic by starting with 37 | # downsampled segmentation map instead of random z 38 | self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1) 39 | 40 | self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt) 41 | 42 | self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt) 43 | self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt) 44 | 45 | self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt) 46 | self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt) 47 | self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt) 48 | self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt) 49 | 50 | final_nc = nf 51 | 52 | if opt.num_upsampling_layers == 'most': 53 | self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt) 54 | final_nc = nf // 2 55 | 56 | self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1) 57 | 58 | self.up = nn.Upsample(scale_factor=2) 59 | 60 | def compute_latent_vector_size(self, opt): 61 | if opt.num_upsampling_layers == 'normal': 62 | num_up_layers = 5 63 | elif opt.num_upsampling_layers == 'more': 64 | num_up_layers = 6 65 | elif opt.num_upsampling_layers == 'most': 66 | num_up_layers = 7 67 | else: 68 | raise ValueError('opt.num_upsampling_layers [%s] not recognized' % 69 | opt.num_upsampling_layers) 70 | 71 | sw = opt.crop_size // (2**num_up_layers) 72 | sh = round(sw / opt.aspect_ratio) 73 | 74 | return sw, sh 75 | 76 | def forward(self, input, z=None): 77 | seg = input 78 | 79 | if self.opt.use_vae: 80 | # we sample z from unit normal and reshape the tensor 81 | if z is None: 82 | z = 
torch.randn(input.size(0), self.opt.z_dim, 83 | dtype=torch.float32, device=input.device) 84 | x = self.fc(z) 85 | x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw) 86 | else: 87 | # we downsample segmap and run convolution 88 | x = F.interpolate(seg, size=(self.sh, self.sw)) 89 | x = self.fc(x) 90 | 91 | x = self.head_0(x, seg) 92 | 93 | x = self.up(x) 94 | x = self.G_middle_0(x, seg) 95 | 96 | if self.opt.num_upsampling_layers == 'more' or \ 97 | self.opt.num_upsampling_layers == 'most': 98 | x = self.up(x) 99 | 100 | x = self.G_middle_1(x, seg) 101 | 102 | x = self.up(x) 103 | x = self.up_0(x, seg) 104 | x = self.up(x) 105 | x = self.up_1(x, seg) 106 | x = self.up(x) 107 | x = self.up_2(x, seg) 108 | x = self.up(x) 109 | x = self.up_3(x, seg) 110 | 111 | if self.opt.num_upsampling_layers == 'most': 112 | x = self.up(x) 113 | x = self.up_4(x, seg) 114 | 115 | x = self.conv_img(F.leaky_relu(x, 2e-1)) 116 | x = torch.tanh(x) 117 | 118 | return x 119 | 120 | 121 | class Pix2PixHDGenerator(BaseNetwork): 122 | @staticmethod 123 | def modify_commandline_options(parser, is_train): 124 | parser.add_argument('--resnet_n_downsample', type=int, default=4, help='number of downsampling layers in netG') 125 | parser.add_argument('--resnet_n_blocks', type=int, default=9, help='number of residual blocks in the global generator network') 126 | parser.add_argument('--resnet_kernel_size', type=int, default=3, 127 | help='kernel size of the resnet block') 128 | parser.add_argument('--resnet_initial_kernel_size', type=int, default=7, 129 | help='kernel size of the first convolution') 130 | parser.set_defaults(norm_G='instance') 131 | return parser 132 | 133 | def __init__(self, opt): 134 | super().__init__() 135 | input_nc = opt.label_nc + (1 if opt.contain_dontcare_label else 0) + (0 if opt.no_instance else 1) 136 | 137 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_G) 138 | activation = nn.ReLU(False) 139 | 140 | model = [] 141 | 142 | # initial conv 143 | model += 
[nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2), 144 | norm_layer(nn.Conv2d(input_nc, opt.ngf, 145 | kernel_size=opt.resnet_initial_kernel_size, 146 | padding=0)), 147 | activation] 148 | 149 | # downsample 150 | mult = 1 151 | for i in range(opt.resnet_n_downsample): 152 | model += [norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, 153 | kernel_size=3, stride=2, padding=1)), 154 | activation] 155 | mult *= 2 156 | 157 | # resnet blocks 158 | for i in range(opt.resnet_n_blocks): 159 | model += [ResnetBlock(opt.ngf * mult, 160 | norm_layer=norm_layer, 161 | activation=activation, 162 | kernel_size=opt.resnet_kernel_size)] 163 | 164 | # upsample 165 | for i in range(opt.resnet_n_downsample): 166 | nc_in = int(opt.ngf * mult) 167 | nc_out = int((opt.ngf * mult) / 2) 168 | model += [norm_layer(nn.ConvTranspose2d(nc_in, nc_out, 169 | kernel_size=3, stride=2, 170 | padding=1, output_padding=1)), 171 | activation] 172 | mult = mult // 2 173 | 174 | # final output conv 175 | model += [nn.ReflectionPad2d(3), 176 | nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0), 177 | nn.Tanh()] 178 | 179 | self.model = nn.Sequential(*model) 180 | 181 | def forward(self, input, z=None): 182 | return self.model(input) 183 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from models.networks.architecture import VGG19 10 | 11 | 12 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 
13 | # When LSGAN is used, it is basically same as MSELoss, 14 | # but it abstracts away the need to create the target label tensor 15 | # that has the same size as the input 16 | class GANLoss(nn.Module): 17 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, 18 | tensor=torch.FloatTensor, opt=None): 19 | super(GANLoss, self).__init__() 20 | self.real_label = target_real_label 21 | self.fake_label = target_fake_label 22 | self.real_label_tensor = None 23 | self.fake_label_tensor = None 24 | self.zero_tensor = None 25 | self.Tensor = tensor 26 | self.gan_mode = gan_mode 27 | self.opt = opt 28 | if gan_mode == 'ls': 29 | pass 30 | elif gan_mode == 'original': 31 | pass 32 | elif gan_mode == 'w': 33 | pass 34 | elif gan_mode == 'hinge': 35 | pass 36 | else: 37 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) 38 | 39 | def get_target_tensor(self, input, target_is_real): 40 | if target_is_real: 41 | if self.real_label_tensor is None: 42 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label) 43 | self.real_label_tensor.requires_grad_(False) 44 | return self.real_label_tensor.expand_as(input) 45 | else: 46 | if self.fake_label_tensor is None: 47 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label) 48 | self.fake_label_tensor.requires_grad_(False) 49 | return self.fake_label_tensor.expand_as(input) 50 | 51 | def get_zero_tensor(self, input): 52 | if self.zero_tensor is None: 53 | self.zero_tensor = self.Tensor(1).fill_(0) 54 | self.zero_tensor.requires_grad_(False) 55 | return self.zero_tensor.expand_as(input) 56 | 57 | def loss(self, input, target_is_real, for_discriminator=True): 58 | if self.gan_mode == 'original': # cross entropy loss 59 | target_tensor = self.get_target_tensor(input, target_is_real) 60 | loss = F.binary_cross_entropy_with_logits(input, target_tensor) 61 | return loss 62 | elif self.gan_mode == 'ls': 63 | target_tensor = self.get_target_tensor(input, target_is_real) 64 | return 
F.mse_loss(input, target_tensor) 65 | elif self.gan_mode == 'hinge': 66 | if for_discriminator: 67 | if target_is_real: 68 | minval = torch.min(input - 1, self.get_zero_tensor(input)) 69 | loss = -torch.mean(minval) 70 | else: 71 | minval = torch.min(-input - 1, self.get_zero_tensor(input)) 72 | loss = -torch.mean(minval) 73 | else: 74 | assert target_is_real, "The generator's hinge loss must be aiming for real" 75 | loss = -torch.mean(input) 76 | return loss 77 | else: 78 | # wgan 79 | if target_is_real: 80 | return -input.mean() 81 | else: 82 | return input.mean() 83 | 84 | def __call__(self, input, target_is_real, for_discriminator=True): 85 | # computing loss is a bit complicated because |input| may not be 86 | # a tensor, but list of tensors in case of multiscale discriminator 87 | if isinstance(input, list): 88 | loss = 0 89 | for pred_i in input: 90 | if isinstance(pred_i, list): 91 | pred_i = pred_i[-1] 92 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator) 93 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) 94 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 95 | loss += new_loss 96 | return loss / len(input) 97 | else: 98 | return self.loss(input, target_is_real, for_discriminator) 99 | 100 | 101 | # Perceptual loss that uses a pretrained VGG network 102 | class VGGLoss(nn.Module): 103 | def __init__(self, gpu_ids): 104 | super(VGGLoss, self).__init__() 105 | self.vgg = VGG19().cuda() 106 | self.criterion = nn.L1Loss() 107 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 108 | 109 | def forward(self, x, y): 110 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 111 | loss = 0 112 | for i in range(len(x_vgg)): 113 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 114 | return loss 115 | 116 | 117 | # KL Divergence loss used in VAE with an image encoder 118 | class KLDLoss(nn.Module): 119 | def forward(self, mu, logvar): 120 | return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 
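The hinge and KL terms in `loss.py` reduce to short scalar formulas. The sketch below restates them without a framework, with plain Python lists standing in for tensors (an assumption made for readability); the function names are hypothetical and are not part of SPADE.

```python
import math
from typing import Sequence


def kl_to_standard_normal(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL(N(mu, exp(logvar)) || N(0, 1)), summed over latent dimensions.

    Same closed form as KLDLoss: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
    """
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def _mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def hinge_d_real(logits: Sequence[float]) -> float:
    """Discriminator hinge loss on real samples: mean(max(0, 1 - D(x)))."""
    # Equivalent to -mean(min(D(x) - 1, 0)) as written in GANLoss.loss.
    return _mean([max(0.0, 1.0 - x) for x in logits])


def hinge_d_fake(logits: Sequence[float]) -> float:
    """Discriminator hinge loss on fake samples: mean(max(0, 1 + D(G(z))))."""
    # Equivalent to -mean(min(-D(G(z)) - 1, 0)).
    return _mean([max(0.0, 1.0 + x) for x in logits])


def hinge_g(logits: Sequence[float]) -> float:
    """Generator hinge loss: -mean(D(G(z)))."""
    return -_mean(logits)
```

For example, `kl_to_standard_normal([1.0], [0.0])` gives 0.5, the KL divergence of N(1, 1) from N(0, 1). `GANLoss.__call__` additionally averages these per-scale terms over the discriminator outputs when `input` is a list.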
121 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/normalization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import re 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from models.networks.sync_batchnorm import SynchronizedBatchNorm2d 11 | import torch.nn.utils.spectral_norm as spectral_norm 12 | 13 | 14 | # Returns a function that creates a normalization function 15 | # that does not condition on semantic map 16 | def get_nonspade_norm_layer(opt, norm_type='instance'): 17 | # helper function to get # output channels of the previous layer 18 | def get_out_channel(layer): 19 | if hasattr(layer, 'out_channels'): 20 | return getattr(layer, 'out_channels') 21 | return layer.weight.size(0) 22 | 23 | # this function will be returned 24 | def add_norm_layer(layer): 25 | nonlocal norm_type 26 | if norm_type.startswith('spectral'): 27 | layer = spectral_norm(layer) 28 | subnorm_type = norm_type[len('spectral'):] 29 | 30 | if subnorm_type == 'none' or len(subnorm_type) == 0: 31 | return layer 32 | 33 | # remove bias in the previous layer, which is meaningless 34 | # since it has no effect after normalization 35 | if getattr(layer, 'bias', None) is not None: 36 | delattr(layer, 'bias') 37 | layer.register_parameter('bias', None) 38 | 39 | if subnorm_type == 'batch': 40 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) 41 | elif subnorm_type == 'sync_batch': 42 | norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) 43 | elif subnorm_type == 'instance': 44 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) 45 | else: 46 | raise ValueError('normalization layer %s is not recognized' % 
subnorm_type) 47 | 48 | return nn.Sequential(layer, norm_layer) 49 | 50 | return add_norm_layer 51 | 52 | 53 | # Creates SPADE normalization layer based on the given configuration 54 | # SPADE consists of two steps. First, it normalizes the activations using 55 | # your favorite normalization method, such as Batch Norm or Instance Norm. 56 | # Second, it applies scale and bias to the normalized output, conditioned on 57 | # the segmentation map. 58 | # The format of |config_text| is spade(norm)(ks), where 59 | # (norm) specifies the type of parameter-free normalization. 60 | # (e.g. syncbatch, batch, instance) 61 | # (ks) specifies the size of kernel in the SPADE module (e.g. 3x3) 62 | # Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5. 63 | # Also, the other arguments are 64 | # |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE 65 | # |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE 66 | class SPADE(nn.Module): 67 | def __init__(self, config_text, norm_nc, label_nc): 68 | super().__init__() 69 | 70 | assert config_text.startswith('spade') 71 | parsed = re.search(r'spade(\D+)(\d)x\d', config_text) 72 | param_free_norm_type = str(parsed.group(1)) 73 | ks = int(parsed.group(2)) 74 | 75 | if param_free_norm_type == 'instance': 76 | self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) 77 | elif param_free_norm_type == 'syncbatch': 78 | self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False) 79 | elif param_free_norm_type == 'batch': 80 | self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) 81 | else: 82 | raise ValueError('%s is not a recognized param-free norm type in SPADE' 83 | % param_free_norm_type) 84 | 85 | # The dimension of the intermediate embedding space. Yes, hardcoded.
86 | nhidden = 128 87 | 88 | pw = ks // 2 89 | self.mlp_shared = nn.Sequential( 90 | nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), 91 | nn.ReLU() 92 | ) 93 | self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 94 | self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 95 | 96 | def forward(self, x, segmap): 97 | 98 | # Part 1. generate parameter-free normalized activations 99 | normalized = self.param_free_norm(x) 100 | 101 | # Part 2. produce scaling and bias conditioned on semantic map 102 | segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') 103 | actv = self.mlp_shared(segmap) 104 | gamma = self.mlp_gamma(actv) 105 | beta = self.mlp_beta(actv) 106 | 107 | # apply scale and bias 108 | out = normalized * (1 + gamma) + beta 109 | 110 | return out 111 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .batchnorm import patch_sync_batchnorm, convert_model 13 | from .replicate import DataParallelWithCallback, patch_replication_callback 14 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * 
unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger a callback of each module, all slave devices should 60 | call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register a slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually the device id.
90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each device (including the master device), and then 106 | a callback will be invoked to compute the message to be sent back to each device 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master wants to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belong to the master.'
124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphic, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies.
39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | the original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have a customized `DataParallel` implementation.
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /lib/SPADE-master/models/networks/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /lib/SPADE-master/options/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ -------------------------------------------------------------------------------- /lib/SPADE-master/options/test_options.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from .base_options import BaseOptions 7 | 8 | 9 | class TestOptions(BaseOptions): 10 | def initialize(self, parser): 11 | BaseOptions.initialize(self, parser) 12 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 13 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 14 | parser.add_argument('--how_many', type=int, default=float("inf"), help='how many test images to run') 15 | 16 | parser.set_defaults(preprocess_mode='scale_width_and_crop', crop_size=256, load_size=256, display_winsize=256) 17 | parser.set_defaults(serial_batches=True) 18 | parser.set_defaults(no_flip=True) 19 | parser.set_defaults(phase='test') 20 | self.isTrain = False 21 | return parser 22 | -------------------------------------------------------------------------------- /lib/SPADE-master/options/train_options.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | from .base_options import BaseOptions 7 | 8 | 9 | class TrainOptions(BaseOptions): 10 | def initialize(self, parser): 11 | BaseOptions.initialize(self, parser) 12 | # for displays 13 | parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 14 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 15 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 16 | parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs') 17 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 18 | parser.add_argument('--debug', action='store_true', help='only run one epoch and display at each iteration') 19 | parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed') 20 | 21 | # for training 22 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 23 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 24 | parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate. This is NOT the total #epochs.
Total #epochs is niter + niter_decay') 25 | parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero') 26 | parser.add_argument('--optimizer', type=str, default='adam') 27 | parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam') 28 | parser.add_argument('--beta2', type=float, default=0.9, help='second-moment decay term of adam') 29 | parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use TTUR training scheme') 30 | 31 | # the default values for beta1 and beta2 differ by TTUR option 32 | opt, _ = parser.parse_known_args() 33 | if opt.no_TTUR: 34 | parser.set_defaults(beta1=0.5, beta2=0.999) 35 | 36 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 37 | parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iteration.') 38 | 39 | # for discriminators 40 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 41 | parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') 42 | parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss') 43 | parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss') 44 | parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss') 45 | parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)') 46 | parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image)') 47 | parser.add_argument('--lambda_kld', type=float, default=0.05) 48 | self.isTrain = True 49 | return parser 50 | -------------------------------------------------------------------------------- /lib/SPADE-master/requirements.txt:
-------------------------------------------------------------------------------- 1 | torch>=1.0.0 2 | torchvision 3 | dominate>=2.3.1 4 | dill 5 | scikit-image 6 | -------------------------------------------------------------------------------- /lib/SPADE-master/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | from collections import OrderedDict 8 | 9 | import data 10 | from options.test_options import TestOptions 11 | from models.pix2pix_model import Pix2PixModel 12 | from util.visualizer import Visualizer 13 | from util import html 14 | 15 | opt = TestOptions().parse() 16 | 17 | dataloader = data.create_dataloader(opt) 18 | 19 | model = Pix2PixModel(opt) 20 | model.eval() 21 | 22 | visualizer = Visualizer(opt) 23 | 24 | # create a webpage that summarizes all the results 25 | web_dir = os.path.join(opt.results_dir, opt.name, 26 | '%s_%s' % (opt.phase, opt.which_epoch)) 27 | webpage = html.HTML(web_dir, 28 | 'Experiment = %s, Phase = %s, Epoch = %s' % 29 | (opt.name, opt.phase, opt.which_epoch)) 30 | 31 | # test 32 | for i, data_i in enumerate(dataloader): 33 | if i * opt.batchSize >= opt.how_many: 34 | break 35 | 36 | generated = model(data_i, mode='inference') 37 | 38 | img_path = data_i['path'] 39 | for b in range(generated.shape[0]): 40 | print('process image...
%s' % img_path[b]) 41 | visuals = OrderedDict([('input_label', data_i['label'][b]), 42 | ('synthesized_image', generated[b])]) 43 | visualizer.save_images(webpage, visuals, img_path[b:b + 1]) 44 | 45 | webpage.save() 46 | -------------------------------------------------------------------------------- /lib/SPADE-master/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import sys 7 | from collections import OrderedDict 8 | from options.train_options import TrainOptions 9 | import data 10 | from util.iter_counter import IterationCounter 11 | from util.visualizer import Visualizer 12 | from trainers.pix2pix_trainer import Pix2PixTrainer 13 | 14 | # parse options 15 | opt = TrainOptions().parse() 16 | 17 | # print options to help debugging 18 | print(' '.join(sys.argv)) 19 | 20 | # load the dataset 21 | dataloader = data.create_dataloader(opt) 22 | 23 | # create trainer for our model 24 | trainer = Pix2PixTrainer(opt) 25 | 26 | # create tool for counting iterations 27 | iter_counter = IterationCounter(opt, len(dataloader)) 28 | 29 | # create tool for visualization 30 | visualizer = Visualizer(opt) 31 | 32 | for epoch in iter_counter.training_epochs(): 33 | iter_counter.record_epoch_start(epoch) 34 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 35 | iter_counter.record_one_iteration() 36 | 37 | # Training 38 | # train generator 39 | if i % opt.D_steps_per_G == 0: 40 | trainer.run_generator_one_step(data_i) 41 | 42 | # train discriminator 43 | trainer.run_discriminator_one_step(data_i) 44 | 45 | # Visualizations 46 | if iter_counter.needs_printing(): 47 | losses = trainer.get_latest_losses() 48 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 49 | losses, iter_counter.time_per_iter) 50 | 
visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 51 | 52 | if iter_counter.needs_displaying(): 53 | visuals = OrderedDict([('input_label', data_i['label']), 54 | ('synthesized_image', trainer.get_latest_generated()), 55 | ('real_image', data_i['image'])]) 56 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 57 | 58 | if iter_counter.needs_saving(): 59 | print('saving the latest model (epoch %d, total_steps %d)' % 60 | (epoch, iter_counter.total_steps_so_far)) 61 | trainer.save('latest') 62 | iter_counter.record_current_iter() 63 | 64 | trainer.update_learning_rate(epoch) 65 | iter_counter.record_epoch_end() 66 | 67 | if epoch % opt.save_epoch_freq == 0 or \ 68 | epoch == iter_counter.total_epochs: 69 | print('saving the model at the end of epoch %d, iters %d' % 70 | (epoch, iter_counter.total_steps_so_far)) 71 | trainer.save('latest') 72 | trainer.save(epoch) 73 | 74 | print('Training was successfully finished.') 75 | -------------------------------------------------------------------------------- /lib/SPADE-master/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | -------------------------------------------------------------------------------- /lib/SPADE-master/trainers/pix2pix_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | from models.networks.sync_batchnorm import DataParallelWithCallback 7 | from models.pix2pix_model import Pix2PixModel 8 | 9 | 10 | class Pix2PixTrainer(): 11 | """ 12 | Trainer creates the model and optimizers, and uses them to 13 | updates the weights of the network while reporting losses 14 | and the latest visuals to visualize the progress in training. 15 | """ 16 | 17 | def __init__(self, opt): 18 | self.opt = opt 19 | self.pix2pix_model = Pix2PixModel(opt) 20 | if len(opt.gpu_ids) > 0: 21 | self.pix2pix_model = DataParallelWithCallback(self.pix2pix_model, 22 | device_ids=opt.gpu_ids) 23 | self.pix2pix_model_on_one_gpu = self.pix2pix_model.module 24 | else: 25 | self.pix2pix_model_on_one_gpu = self.pix2pix_model 26 | 27 | self.generated = None 28 | if opt.isTrain: 29 | self.optimizer_G, self.optimizer_D = \ 30 | self.pix2pix_model_on_one_gpu.create_optimizers(opt) 31 | self.old_lr = opt.lr 32 | 33 | def run_generator_one_step(self, data): 34 | self.optimizer_G.zero_grad() 35 | g_losses, generated = self.pix2pix_model(data, mode='generator') 36 | g_loss = sum(g_losses.values()).mean() 37 | g_loss.backward() 38 | self.optimizer_G.step() 39 | self.g_losses = g_losses 40 | self.generated = generated 41 | 42 | def run_discriminator_one_step(self, data): 43 | self.optimizer_D.zero_grad() 44 | d_losses = self.pix2pix_model(data, mode='discriminator') 45 | d_loss = sum(d_losses.values()).mean() 46 | d_loss.backward() 47 | self.optimizer_D.step() 48 | self.d_losses = d_losses 49 | 50 | def get_latest_losses(self): 51 | return {**self.g_losses, **self.d_losses} 52 | 53 | def get_latest_generated(self): 54 | return self.generated 55 | 56 | def update_learning_rate(self, epoch): 57 | self.update_learning_rate(epoch) 58 | 59 | def save(self, epoch): 60 | self.pix2pix_model_on_one_gpu.save(epoch) 61 | 62 | ################################################################## 63 | # Helper functions 64 | 
################################################################## 65 | 66 | def update_learning_rate(self, epoch): 67 | if epoch > self.opt.niter: 68 | lrd = self.opt.lr / self.opt.niter_decay 69 | new_lr = self.old_lr - lrd 70 | else: 71 | new_lr = self.old_lr 72 | 73 | if new_lr != self.old_lr: 74 | if self.opt.no_TTUR: 75 | new_lr_G = new_lr 76 | new_lr_D = new_lr 77 | else: 78 | new_lr_G = new_lr / 2 79 | new_lr_D = new_lr * 2 80 | 81 | for param_group in self.optimizer_D.param_groups: 82 | param_group['lr'] = new_lr_D 83 | for param_group in self.optimizer_G.param_groups: 84 | param_group['lr'] = new_lr_G 85 | print('update learning rate: %f -> %f' % (self.old_lr, new_lr)) 86 | self.old_lr = new_lr 87 | -------------------------------------------------------------------------------- /lib/SPADE-master/util/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | -------------------------------------------------------------------------------- /lib/SPADE-master/util/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | 7 | def id2label(id): 8 | if id == 182: 9 | id = 0 10 | else: 11 | id = id + 1 12 | labelmap = \ 13 | {0: 'unlabeled', 14 | 1: 'person', 15 | 2: 'bicycle', 16 | 3: 'car', 17 | 4: 'motorcycle', 18 | 5: 'airplane', 19 | 6: 'bus', 20 | 7: 'train', 21 | 8: 'truck', 22 | 9: 'boat', 23 | 10: 'traffic light', 24 | 11: 'fire hydrant', 25 | 12: 'street sign', 26 | 13: 'stop sign', 27 | 14: 'parking meter', 28 | 15: 'bench', 29 | 16: 'bird', 30 | 17: 'cat', 31 | 18: 'dog', 32 | 19: 'horse', 33 | 20: 'sheep', 34 | 21: 'cow', 35 | 22: 'elephant', 36 | 23: 'bear', 37 | 24: 'zebra', 38 | 25: 'giraffe', 39 | 26: 'hat', 40 | 27: 'backpack', 41 | 28: 'umbrella', 42 | 29: 'shoe', 43 | 30: 'eye glasses', 44 | 31: 'handbag', 45 | 32: 'tie', 46 | 33: 'suitcase', 47 | 34: 'frisbee', 48 | 35: 'skis', 49 | 36: 'snowboard', 50 | 37: 'sports ball', 51 | 38: 'kite', 52 | 39: 'baseball bat', 53 | 40: 'baseball glove', 54 | 41: 'skateboard', 55 | 42: 'surfboard', 56 | 43: 'tennis racket', 57 | 44: 'bottle', 58 | 45: 'plate', 59 | 46: 'wine glass', 60 | 47: 'cup', 61 | 48: 'fork', 62 | 49: 'knife', 63 | 50: 'spoon', 64 | 51: 'bowl', 65 | 52: 'banana', 66 | 53: 'apple', 67 | 54: 'sandwich', 68 | 55: 'orange', 69 | 56: 'broccoli', 70 | 57: 'carrot', 71 | 58: 'hot dog', 72 | 59: 'pizza', 73 | 60: 'donut', 74 | 61: 'cake', 75 | 62: 'chair', 76 | 63: 'couch', 77 | 64: 'potted plant', 78 | 65: 'bed', 79 | 66: 'mirror', 80 | 67: 'dining table', 81 | 68: 'window', 82 | 69: 'desk', 83 | 70: 'toilet', 84 | 71: 'door', 85 | 72: 'tv', 86 | 73: 'laptop', 87 | 74: 'mouse', 88 | 75: 'remote', 89 | 76: 'keyboard', 90 | 77: 'cell phone', 91 | 78: 'microwave', 92 | 79: 'oven', 93 | 80: 'toaster', 94 | 81: 'sink', 95 | 82: 'refrigerator', 96 | 83: 'blender', 97 | 84: 'book', 98 | 85: 'clock', 99 | 86: 'vase', 100 | 87: 'scissors', 101 | 88: 'teddy bear', 102 | 89: 'hair drier', 103 | 90: 'toothbrush', 104 | 91: 'hair brush', # Last class of Thing 105 | 92: 'banner', # Beginning of Stuff 106 | 93: 
'blanket', 107 | 94: 'branch', 108 | 95: 'bridge', 109 | 96: 'building-other', 110 | 97: 'bush', 111 | 98: 'cabinet', 112 | 99: 'cage', 113 | 100: 'cardboard', 114 | 101: 'carpet', 115 | 102: 'ceiling-other', 116 | 103: 'ceiling-tile', 117 | 104: 'cloth', 118 | 105: 'clothes', 119 | 106: 'clouds', 120 | 107: 'counter', 121 | 108: 'cupboard', 122 | 109: 'curtain', 123 | 110: 'desk-stuff', 124 | 111: 'dirt', 125 | 112: 'door-stuff', 126 | 113: 'fence', 127 | 114: 'floor-marble', 128 | 115: 'floor-other', 129 | 116: 'floor-stone', 130 | 117: 'floor-tile', 131 | 118: 'floor-wood', 132 | 119: 'flower', 133 | 120: 'fog', 134 | 121: 'food-other', 135 | 122: 'fruit', 136 | 123: 'furniture-other', 137 | 124: 'grass', 138 | 125: 'gravel', 139 | 126: 'ground-other', 140 | 127: 'hill', 141 | 128: 'house', 142 | 129: 'leaves', 143 | 130: 'light', 144 | 131: 'mat', 145 | 132: 'metal', 146 | 133: 'mirror-stuff', 147 | 134: 'moss', 148 | 135: 'mountain', 149 | 136: 'mud', 150 | 137: 'napkin', 151 | 138: 'net', 152 | 139: 'paper', 153 | 140: 'pavement', 154 | 141: 'pillow', 155 | 142: 'plant-other', 156 | 143: 'plastic', 157 | 144: 'platform', 158 | 145: 'playingfield', 159 | 146: 'railing', 160 | 147: 'railroad', 161 | 148: 'river', 162 | 149: 'road', 163 | 150: 'rock', 164 | 151: 'roof', 165 | 152: 'rug', 166 | 153: 'salad', 167 | 154: 'sand', 168 | 155: 'sea', 169 | 156: 'shelf', 170 | 157: 'sky-other', 171 | 158: 'skyscraper', 172 | 159: 'snow', 173 | 160: 'solid-other', 174 | 161: 'stairs', 175 | 162: 'stone', 176 | 163: 'straw', 177 | 164: 'structural-other', 178 | 165: 'table', 179 | 166: 'tent', 180 | 167: 'textile-other', 181 | 168: 'towel', 182 | 169: 'tree', 183 | 170: 'vegetable', 184 | 171: 'wall-brick', 185 | 172: 'wall-concrete', 186 | 173: 'wall-other', 187 | 174: 'wall-panel', 188 | 175: 'wall-stone', 189 | 176: 'wall-tile', 190 | 177: 'wall-wood', 191 | 178: 'water-other', 192 | 179: 'waterdrops', 193 | 180: 'window-blind', 194 | 181: 'window-other', 195 | 182: 
'wood'} 196 | if id in labelmap: 197 | return labelmap[id] 198 | else: 199 | return 'unknown' 200 | -------------------------------------------------------------------------------- /lib/SPADE-master/util/html.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import datetime 7 | import dominate 8 | from dominate.tags import * 9 | import os 10 | 11 | 12 | class HTML: 13 | def __init__(self, web_dir, title, refresh=0): 14 | if web_dir.endswith('.html'): 15 | web_dir, html_name = os.path.split(web_dir) 16 | else: 17 | web_dir, html_name = web_dir, 'index.html' 18 | self.title = title 19 | self.web_dir = web_dir 20 | self.html_name = html_name 21 | self.img_dir = os.path.join(self.web_dir, 'images') 22 | if len(self.web_dir) > 0 and not os.path.exists(self.web_dir): 23 | os.makedirs(self.web_dir) 24 | if len(self.web_dir) > 0 and not os.path.exists(self.img_dir): 25 | os.makedirs(self.img_dir) 26 | 27 | self.doc = dominate.document(title=title) 28 | with self.doc: 29 | h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")) 30 | if refresh > 0: 31 | with self.doc.head: 32 | meta(http_equiv="refresh", content=str(refresh)) 33 | 34 | def get_image_dir(self): 35 | return self.img_dir 36 | 37 | def add_header(self, str): 38 | with self.doc: 39 | h3(str) 40 | 41 | def add_table(self, border=1): 42 | self.t = table(border=border, style="table-layout: fixed;") 43 | self.doc.add(self.t) 44 | 45 | def add_images(self, ims, txts, links, width=512): 46 | self.add_table() 47 | with self.t: 48 | with tr(): 49 | for im, txt, link in zip(ims, txts, links): 50 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 51 | with p(): 52 | with a(href=os.path.join('images', link)): 53 | img(style="width:%dpx" % (width), 
src=os.path.join('images', im)) 54 | br() 55 | p(txt.encode('utf-8')) 56 | 57 | def save(self): 58 | html_file = os.path.join(self.web_dir, self.html_name) 59 | f = open(html_file, 'wt') 60 | f.write(self.doc.render()) 61 | f.close() 62 | 63 | 64 | if __name__ == '__main__': 65 | html = HTML('web/', 'test_html') 66 | html.add_header('hello world') 67 | 68 | ims = [] 69 | txts = [] 70 | links = [] 71 | for n in range(4): 72 | ims.append('image_%d.jpg' % n) 73 | txts.append('text_%d' % n) 74 | links.append('image_%d.jpg' % n) 75 | html.add_images(ims, txts, links) 76 | html.save() 77 | -------------------------------------------------------------------------------- /lib/SPADE-master/util/iter_counter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | import time 8 | import numpy as np 9 | 10 | 11 | # Helper class that keeps track of training iterations 12 | class IterationCounter(): 13 | def __init__(self, opt, dataset_size): 14 | self.opt = opt 15 | self.dataset_size = dataset_size 16 | 17 | self.first_epoch = 1 18 | self.total_epochs = opt.niter + opt.niter_decay 19 | self.epoch_iter = 0 # iter number within each epoch 20 | self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt') 21 | if opt.isTrain and opt.continue_train: 22 | try: 23 | self.first_epoch, self.epoch_iter = np.loadtxt( 24 | self.iter_record_path, delimiter=',', dtype=int) 25 | print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter)) 26 | except: 27 | print('Could not load iteration record at %s. Starting from beginning.' 
% 28 | self.iter_record_path) 29 | 30 | self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter 31 | 32 | # return the iterator of epochs for the training 33 | def training_epochs(self): 34 | return range(self.first_epoch, self.total_epochs + 1) 35 | 36 | def record_epoch_start(self, epoch): 37 | self.epoch_start_time = time.time() 38 | self.epoch_iter = 0 39 | self.last_iter_time = time.time() 40 | self.current_epoch = epoch 41 | 42 | def record_one_iteration(self): 43 | current_time = time.time() 44 | 45 | # the last remaining batch is dropped (see data/__init__.py), 46 | # so we can assume batch size is always opt.batchSize 47 | self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize 48 | self.last_iter_time = current_time 49 | self.total_steps_so_far += self.opt.batchSize 50 | self.epoch_iter += self.opt.batchSize 51 | 52 | def record_epoch_end(self): 53 | current_time = time.time() 54 | self.time_per_epoch = current_time - self.epoch_start_time 55 | print('End of epoch %d / %d \t Time Taken: %d sec' % 56 | (self.current_epoch, self.total_epochs, self.time_per_epoch)) 57 | if self.current_epoch % self.opt.save_epoch_freq == 0: 58 | np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), 59 | delimiter=',', fmt='%d') 60 | print('Saved current iteration count at %s.' % self.iter_record_path) 61 | 62 | def record_current_iter(self): 63 | np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), 64 | delimiter=',', fmt='%d') 65 | print('Saved current iteration count at %s.' 
% self.iter_record_path) 66 | 67 | def needs_saving(self): 68 | return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize 69 | 70 | def needs_printing(self): 71 | return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize 72 | 73 | def needs_displaying(self): 74 | return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize 75 | -------------------------------------------------------------------------------- /lib/SPADE-master/util/visualizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | import ntpath 8 | import time 9 | from . import util 10 | from . import html 11 | import scipy.misc 12 | try: 13 | from StringIO import StringIO # Python 2.7 14 | except ImportError: 15 | from io import BytesIO # Python 3.x 16 | 17 | class Visualizer(): 18 | def __init__(self, opt): 19 | self.opt = opt 20 | self.tf_log = opt.isTrain and opt.tf_log 21 | self.use_html = opt.isTrain and not opt.no_html 22 | self.win_size = opt.display_winsize 23 | self.name = opt.name 24 | if self.tf_log: 25 | import tensorflow as tf 26 | self.tf = tf 27 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') 28 | self.writer = tf.summary.FileWriter(self.log_dir) 29 | 30 | if self.use_html: 31 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 32 | self.img_dir = os.path.join(self.web_dir, 'images') 33 | print('create web directory %s...' 
% self.web_dir) 34 | util.mkdirs([self.web_dir, self.img_dir]) 35 | if opt.isTrain: 36 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 37 | with open(self.log_name, "a") as log_file: 38 | now = time.strftime("%c") 39 | log_file.write('================ Training Loss (%s) ================\n' % now) 40 | 41 | # |visuals|: dictionary of images to display or save 42 | def display_current_results(self, visuals, epoch, step): 43 | 44 | ## convert tensors to numpy arrays 45 | visuals = self.convert_visuals_to_numpy(visuals) 46 | 47 | if self.tf_log: # show images in tensorboard output 48 | img_summaries = [] 49 | for label, image_numpy in visuals.items(): 50 | # Write the image to a string 51 | try: 52 | s = StringIO() 53 | except: 54 | s = BytesIO() 55 | if len(image_numpy.shape) >= 4: 56 | image_numpy = image_numpy[0] 57 | scipy.misc.toimage(image_numpy).save(s, format="jpeg") 58 | # Create an Image object 59 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) 60 | # Create a Summary value 61 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) 62 | 63 | # Create and write Summary 64 | summary = self.tf.Summary(value=img_summaries) 65 | self.writer.add_summary(summary, step) 66 | 67 | if self.use_html: # save images to a html file 68 | for label, image_numpy in visuals.items(): 69 | if isinstance(image_numpy, list): 70 | for i in range(len(image_numpy)): 71 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s_%d.png' % (epoch, step, label, i)) 72 | util.save_image(image_numpy[i], img_path) 73 | else: 74 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s.png' % (epoch, step, label)) 75 | if len(image_numpy.shape) >= 4: 76 | image_numpy = image_numpy[0] 77 | util.save_image(image_numpy, img_path) 78 | 79 | # update website 80 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5) 81 | for n in 
range(epoch, 0, -1): 82 | webpage.add_header('epoch [%d]' % n) 83 | ims = [] 84 | txts = [] 85 | links = [] 86 | 87 | for label, image_numpy in visuals.items(): 88 | if isinstance(image_numpy, list): 89 | for i in range(len(image_numpy)): 90 | img_path = 'epoch%.3d_iter%.3d_%s_%d.png' % (n, step, label, i) 91 | ims.append(img_path) 92 | txts.append(label+str(i)) 93 | links.append(img_path) 94 | else: 95 | img_path = 'epoch%.3d_iter%.3d_%s.png' % (n, step, label) 96 | ims.append(img_path) 97 | txts.append(label) 98 | links.append(img_path) 99 | if len(ims) < 10: 100 | webpage.add_images(ims, txts, links, width=self.win_size) 101 | else: 102 | num = int(round(len(ims)/2.0)) 103 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) 104 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) 105 | webpage.save() 106 | 107 | # errors: dictionary of error labels and values 108 | def plot_current_errors(self, errors, step): 109 | if self.tf_log: 110 | for tag, value in errors.items(): 111 | value = value.mean().float() 112 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) 113 | self.writer.add_summary(summary, step) 114 | 115 | # errors: same format as |errors| of plotCurrentErrors 116 | def print_current_errors(self, epoch, i, errors, t): 117 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 118 | for k, v in errors.items(): 119 | #print(v) 120 | #if v != 0: 121 | v = v.mean().float() 122 | message += '%s: %.3f ' % (k, v) 123 | 124 | print(message) 125 | with open(self.log_name, "a") as log_file: 126 | log_file.write('%s\n' % message) 127 | 128 | def convert_visuals_to_numpy(self, visuals): 129 | for key, t in visuals.items(): 130 | tile = self.opt.batchSize > 8 131 | if 'input_label' == key: 132 | t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) 133 | else: 134 | t = util.tensor2im(t, tile=tile) 135 | visuals[key] = t 136 | return visuals 137 | 138 | # save 
image to the disk 139 | def save_images(self, webpage, visuals, image_path): 140 | visuals = self.convert_visuals_to_numpy(visuals) 141 | 142 | image_dir = webpage.get_image_dir() 143 | short_path = ntpath.basename(image_path[0]) 144 | name = os.path.splitext(short_path)[0] 145 | 146 | webpage.add_header(name) 147 | ims = [] 148 | txts = [] 149 | links = [] 150 | 151 | for label, image_numpy in visuals.items(): 152 | image_name = os.path.join(label, '%s.png' % (name)) 153 | save_path = os.path.join(image_dir, image_name) 154 | util.save_image(image_numpy, save_path, create_dir=True) 155 | 156 | ims.append(image_name) 157 | txts.append(label) 158 | links.append(image_name) 159 | webpage.add_images(ims, txts, links, width=self.win_size) 160 | -------------------------------------------------------------------------------- /lib/model_server/latest_net_G.yaml: -------------------------------------------------------------------------------- 1 | # run 102: https://web.spell.ml/ResidentMario/runs/102 2 | name: GuaGAN_Bob_Ross_From_ADE20K_Landscapes_No_VAE 3 | aspect_ratio: 1.0 4 | batchSize: 1 5 | checkpoints_dir: /model/ 6 | contain_dontcare_label: True 7 | crop_size: 256 8 | dataset_mode: custom 9 | gan_mode: hinge 10 | gpu_ids: [] 11 | init_type: xavier 12 | init_variance: 0.02 13 | isTrain: False 14 | label_nc: 150 15 | load_size: 256 16 | model: pix2pix 17 | netG: spade 18 | ngf: 64 19 | no_instance: True 20 | norm_D: spectralinstance 21 | norm_E: spectralinstance 22 | norm_G: spectralspadesyncbatch3x3 23 | num_upsampling_layers: normal 24 | phase: test 25 | preprocess_mode: resize_and_crop 26 | semantic_nc: 151 27 | use_vae: False 28 | which_epoch: latest 29 | z_dim: 256 -------------------------------------------------------------------------------- /lib/model_server/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.0.0 2 | torchvision 3 | dominate>=2.3.1 4 | dill 5 | scikit-image 6 | numpy 7 | Pillow 8 
| pyyaml 9 | # starlette -------------------------------------------------------------------------------- /lib/model_server/scripts/load_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | import pathlib 4 | import os 5 | import json 6 | 7 | from concurrent.futures import ThreadPoolExecutor 8 | from requests_futures.sessions import FuturesSession 9 | 10 | # To load test I ran this script whilst simultaneously interacting with the application 11 | # by hand to see how it was responding. 12 | # Tests tried: 13 | # * from 15 to 150 concurrent users (no autoscaling) 14 | # * from 100 to 300 concurrent users (autoscales to two pods) 15 | parser = argparse.ArgumentParser(description='Load test the paint_with_ml model server.') 16 | parser.add_argument('--max-concurrent-users', dest='max_concurrent_users', type=int, default=150) 17 | parser.add_argument('--min-concurrent-users', dest='min_concurrent_users', type=int, default=15) 18 | 19 | ENDPOINT = "http://spell-org.spell-org.spell.services/spell-org/paint_with_ml/predict" 20 | with open(pathlib.Path(os.getcwd()).parent / 'test_payload.json', 'r') as fp: 21 | PAYLOAD = json.load(fp) 22 | 23 | args = parser.parse_args() 24 | max_requests_per_second = args.max_concurrent_users // 15 25 | min_requests_per_second = args.min_concurrent_users // 15 26 | curr_requests_per_second = min_requests_per_second 27 | already_peeked = False 28 | t = 0 29 | 30 | # assuming a maximum of 8 seconds of latency (serving on GPU averages 4 seconds) 31 | session = FuturesSession(executor=ThreadPoolExecutor(max_workers=8 * max_requests_per_second)) 32 | 33 | while True: 34 | t += 1 35 | # we can't inspect the response for errors because .result() is a blocking function so 36 | # we're relying on model server metrics to tell us how we're doing 37 | for _ in range(curr_requests_per_second): 38 | _ = session.post(ENDPOINT, json=PAYLOAD) 39 | 40 | if t % 15 == 0: 41 | if not already_peeked: 42 | 
curr_requests_per_second += 1 43 | if curr_requests_per_second == max_requests_per_second: 44 | already_peeked = True 45 | else: 46 | curr_requests_per_second -= 1 47 | if curr_requests_per_second == 0: 48 | break 49 | 50 | print(f"Sent {curr_requests_per_second} requests at time {t}. Sleeping for 1 second...") 51 | 52 | # this assumes that making the request is instantaneous, which, for small enough volumes, it 53 | # basically is 54 | time.sleep(1) 55 | -------------------------------------------------------------------------------- /lib/model_server/server.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | 5 | import os 6 | import yaml 7 | import io 8 | import base64 9 | import json 10 | 11 | from spell.serving import BasePredictor 12 | 13 | if "MODEL_CONFIG_PATH" in os.environ: 14 | model_config_path = os.environ["MODEL_CONFIG_PATH"] 15 | else: 16 | model_config_path = "/model/latest_net_G.yaml" 17 | 18 | if "SPADE_CODE_DIR" in os.environ: 19 | spade_dir = os.environ["SPADE_CODE_DIR"] 20 | else: 21 | spade_dir = "../SPADE-master/" 22 | import sys; sys.path.append(spade_dir) 23 | 24 | from models.pix2pix_model import Pix2PixModel 25 | from options.test_options import TestOptions 26 | 27 | 28 | class Predictor(BasePredictor): 29 | def __init__(self): 30 | opt = TestOptions() 31 | with open(model_config_path, 'r') as fp: 32 | saved_opts = yaml.safe_load(fp) 33 | for _opt in saved_opts: 34 | setattr(opt, _opt, saved_opts[_opt]) 35 | 36 | model = Pix2PixModel(opt) 37 | model.eval() 38 | self.model = model 39 | 40 | # No longer need SPADE-master on PYTHONPATH. 41 | sys.path.pop() 42 | 43 | def png_data_uri_to_batch_tensor(self, data_uri): 44 | # NOTE(aleksey): keys must match those set in app.js. Values are labels from ADE20K. 
45 | color_key = { 46 | (245, 216, 122, 255): 3, # sky 47 | (36, 207, 156, 255): 5, # tree 48 | (236, 118, 142, 255): 18, # plant 49 | (210, 250, 255, 255): 10, # grass 50 | (174, 162, 177, 255): 14, # rock 51 | (245, 147, 34, 255): 17, # mountain 52 | (68, 202, 218, 255): 61, # river 53 | (176, 50, 235, 255): 22, # lake 54 | (138, 115, 227, 255): 27, # ocean 55 | } 56 | # base64 encoded PNG -> PIL image -> array 57 | data_uri = data_uri[data_uri.find(',') + 1:] 58 | segmap_c = np.array( 59 | Image.open(io.BytesIO(base64.b64decode(data_uri))) 60 | ) 61 | segmap = np.zeros((512, 512)) 62 | for x in range(512): 63 | for y in range(512): 64 | c = tuple(segmap_c[x][y]) 65 | segmap[x][y] = color_key[c] 66 | del segmap_c 67 | tensor = torch.tensor(segmap[None, None]).float() 68 | return tensor 69 | 70 | def batch_tensor_to_png_data_uri(self, tensor): 71 | # [-1, 1] batch tensor -> [0, 255] image array -> PIL image 72 | result = Image.fromarray( 73 | ((tensor.squeeze().numpy() + 1) / 2 * 255).transpose(1, 2, 0).astype('uint8') 74 | ) 75 | 76 | # Package into a base64 PNG and return 77 | out = io.BytesIO() 78 | result.save(out, format='PNG') 79 | out = 'data:image/png;base64,'.encode('utf-8') + base64.b64encode(out.getvalue()) 80 | return out 81 | 82 | def get_prediction(self, data_uri): 83 | label = self.png_data_uri_to_batch_tensor(data_uri) 84 | 85 | result = self.model({ 86 | 'label': label, 87 | 'instance': torch.tensor([0]), 88 | 'image': torch.tensor(np.zeros((1, 3, 512, 512))).float(), 89 | 'path': ['~/'] 90 | }, mode='inference') 91 | 92 | out = self.batch_tensor_to_png_data_uri(result) 93 | return out 94 | 95 | def predict(self, payload): 96 | # NOTE: this endpoint should return a body of {'image': }. 
97 | image = self.get_prediction(payload['segmap']) 98 | image = image.decode('utf-8') 99 | return json.dumps({"image": image}) 100 | -------------------------------------------------------------------------------- /lib/model_server/test_payload.json: -------------------------------------------------------------------------------- 1 | {"segmap":""} -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | # models 2 | 3 | This folder contains all of the models built over the course of this project. 4 | 5 | The GauGAN model is designed to be run from the command line and is configured through `argparse`. Internally, all the CLI does is create a `TrainOptions` or `TestOptions` object with the right fields attached. To run the model from a Python script, we instead create our own `TrainOptions` object and set the required options on it directly. 6 | 7 | A machine learning project will involve many iterations of model training. You will start with a simple model, one which either replicates an expected result or produces a simple output. That model grows in complexity as you try new things, keeping what works and discarding what doesn't. The collective list of models you train this way is your **model history**. 8 | -------------------------------------------------------------------------------- /models/model_0.py: -------------------------------------------------------------------------------- 1 | # This is a GauGAN training run with all-default settings, trained on the Bob Ross image corpus.
2 | 3 | import sys; sys.path.append('../lib/SPADE-master/') 4 | from options.train_options import TrainOptions 5 | from models.pix2pix_model import Pix2PixModel 6 | from collections import OrderedDict 7 | import data 8 | from util.iter_counter import IterationCounter 9 | from util.visualizer import Visualizer 10 | from trainers.pix2pix_trainer import Pix2PixTrainer 11 | import os 12 | 13 | opt = TrainOptions() 14 | opt.D_steps_per_G = 1 15 | opt.aspect_ratio = 1.0 16 | opt.batchSize = 1 17 | opt.beta1 = 0.0 18 | opt.beta2 = 0.9 19 | opt.cache_filelist_read = False 20 | opt.cache_filelist_write = False 21 | opt.checkpoints_dir = '/spell/checkpoints/' 22 | opt.contain_dontcare_label = False 23 | opt.continue_train = False 24 | opt.crop_size = 256 25 | opt.dataroot = '/spell/bob_ross_segmented/' # data mount point 26 | opt.dataset_mode = 'custom' 27 | opt.debug = False 28 | opt.display_freq = 100 29 | opt.display_winsize = 256 30 | opt.fff = 1 # junk value for the argparse 31 | opt.gan_mode = 'hinge' 32 | opt.init_type = 'xavier' 33 | opt.init_variance = 0.02 34 | opt.isTrain = True 35 | opt.lambda_feat = 10.0 36 | opt.lambda_kld = 0.05 37 | opt.lambda_vgg = 10.0 38 | opt.load_from_opt_file = False 39 | opt.load_size = 286 # should this be 256? 
40 | opt.lr = 0.0002 41 | opt.max_dataset_size = 9223372036854775807 42 | opt.model = 'pix2pix' 43 | opt.nThreads = 0 44 | opt.n_layers_D = 4 45 | opt.name = 'bob_ross' 46 | opt.ndf = 64 47 | opt.nef = 16 48 | opt.netD = 'multiscale' 49 | opt.netD_subarch = 'n_layer' 50 | opt.netG = 'spade' 51 | opt.ngf = 64 52 | opt.niter = 50 53 | opt.niter_decay = 0 54 | opt.no_TTUR = False 55 | opt.no_flip = False 56 | opt.no_ganFeat_loss = False 57 | opt.no_html = True 58 | opt.no_instance = True 59 | opt.no_pairing_check = False 60 | opt.no_vgg_loss = False 61 | opt.norm_D = 'spectralinstance' 62 | opt.norm_E = 'spectralinstance' 63 | opt.norm_G = 'spectralspadesyncbatch3x3' 64 | opt.num_D = 2 65 | opt.num_upsampling_layers = 'normal' 66 | opt.optimizer = 'adam' 67 | opt.output_nc = 3 68 | opt.phase = 'train' 69 | opt.preprocess_mode = 'resize_and_crop' 70 | opt.print_freq = 100 71 | opt.save_epoch_freq = 10 72 | opt.save_latest_freq = 5000 73 | opt.serial_batches = False 74 | opt.tf_log = False 75 | opt.use_vae = False 76 | opt.which_epoch = 'latest' 77 | opt.z_dim = 256 78 | opt.gpu_ids=[0] 79 | opt.results_dir='../data/SPADE_from_scratch_results/' 80 | opt.semantic_nc = 9 81 | opt.label_nc = 9 82 | opt.label_dir = '/spell/bob_ross_segmented/training/labels/' 83 | opt.image_dir = '/spell/bob_ross_segmented/training/images/' 84 | opt.instance_dir = '' 85 | 86 | # Create the folder structure expected by the model checkpointing feature. 
87 | if not os.path.exists('/spell/checkpoints/'): 88 | os.mkdir('/spell/checkpoints/') 89 | if not os.path.exists('/spell/checkpoints/bob_ross/'): 90 | os.mkdir('/spell/checkpoints/bob_ross/') 91 | 92 | model = Pix2PixModel(opt) 93 | model.train() 94 | 95 | def test_train(): 96 | # print options to help debugging 97 | # print(' '.join(sys.argv)) 98 | 99 | # load the dataset 100 | dataloader = data.create_dataloader(opt) 101 | 102 | # create trainer for our model 103 | trainer = Pix2PixTrainer(opt) 104 | 105 | # create tool for counting iterations 106 | iter_counter = IterationCounter(opt, len(dataloader)) 107 | 108 | # create tool for visualization 109 | visualizer = Visualizer(opt) 110 | 111 | for epoch in iter_counter.training_epochs(): 112 | iter_counter.record_epoch_start(epoch) 113 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 114 | iter_counter.record_one_iteration() 115 | 116 | # Training 117 | # train generator 118 | if i % opt.D_steps_per_G == 0: 119 | trainer.run_generator_one_step(data_i) 120 | 121 | # train discriminator 122 | trainer.run_discriminator_one_step(data_i) 123 | 124 | # Visualizations 125 | if iter_counter.needs_printing(): 126 | losses = trainer.get_latest_losses() 127 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 128 | losses, iter_counter.time_per_iter) 129 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 130 | 131 | if iter_counter.needs_displaying(): 132 | visuals = OrderedDict([('input_label', data_i['label']), 133 | ('synthesized_image', trainer.get_latest_generated()), 134 | ('real_image', data_i['image'])]) 135 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 136 | 137 | if iter_counter.needs_saving(): 138 | print('saving the latest model (epoch %d, total_steps %d)' % 139 | (epoch, iter_counter.total_steps_so_far)) 140 | trainer.save('latest') 141 | iter_counter.record_current_iter() 142 | 143 | 
trainer.update_learning_rate(epoch) 144 | iter_counter.record_epoch_end() 145 | 146 | if epoch % opt.save_epoch_freq == 0 or \ 147 | epoch == iter_counter.total_epochs: 148 | print('saving the model at the end of epoch %d, iters %d' % 149 | (epoch, iter_counter.total_steps_so_far)) 150 | trainer.save('latest') 151 | trainer.save(epoch) 152 | 153 | print('Training was successfully finished.') 154 | 155 | test_train() 156 | -------------------------------------------------------------------------------- /models/model_1.py: -------------------------------------------------------------------------------- 1 | # Changes relative to model_0.py: 2 | # * Decreased load_size from 286 to 256, matching crop_size. I do not think edge cropping is 3 | # useful with this dataset. 4 | # * Changed the results target to '/spell/bob_ross_segmented/results/' (though this value 5 | # doesn't do anything because opt.no_html = True). 6 | # * Set opt.use_vae = True to enable the VAE layer; the intent of this run is to evaluate this 7 | # training option. 
8 | 9 | import sys; sys.path.append('../lib/SPADE-master/') 10 | from options.train_options import TrainOptions 11 | from models.pix2pix_model import Pix2PixModel 12 | from collections import OrderedDict 13 | import data 14 | from util.iter_counter import IterationCounter 15 | from util.visualizer import Visualizer 16 | from trainers.pix2pix_trainer import Pix2PixTrainer 17 | import os 18 | 19 | opt = TrainOptions() 20 | opt.D_steps_per_G = 1 21 | opt.aspect_ratio = 1.0 22 | opt.batchSize = 1 23 | opt.beta1 = 0.0 24 | opt.beta2 = 0.9 25 | opt.cache_filelist_read = False 26 | opt.cache_filelist_write = False 27 | opt.checkpoints_dir = '/spell/checkpoints/' 28 | opt.contain_dontcare_label = False 29 | opt.continue_train = False 30 | opt.crop_size = 256 31 | opt.dataroot = '/spell/bob_ross_segmented/' # data mount point 32 | opt.dataset_mode = 'custom' 33 | opt.debug = False 34 | opt.display_freq = 100 35 | opt.display_winsize = 256 36 | opt.fff = 1 # junk value for the argparse 37 | opt.gan_mode = 'hinge' 38 | opt.init_type = 'xavier' 39 | opt.init_variance = 0.02 40 | opt.isTrain = True 41 | opt.lambda_feat = 10.0 42 | opt.lambda_kld = 0.05 43 | opt.lambda_vgg = 10.0 44 | opt.load_from_opt_file = False 45 | opt.load_size = 256 46 | opt.lr = 0.0002 47 | opt.max_dataset_size = 9223372036854775807 48 | opt.model = 'pix2pix' 49 | opt.nThreads = 0 50 | opt.n_layers_D = 4 51 | opt.name = 'bob_ross' 52 | opt.ndf = 64 53 | opt.nef = 16 54 | opt.netD = 'multiscale' 55 | opt.netD_subarch = 'n_layer' 56 | opt.netG = 'spade' 57 | opt.ngf = 64 58 | opt.niter = 50 59 | opt.niter_decay = 0 60 | opt.no_TTUR = False 61 | opt.no_flip = False 62 | opt.no_ganFeat_loss = False 63 | opt.no_html = True 64 | opt.no_instance = True 65 | opt.no_pairing_check = False 66 | opt.no_vgg_loss = False 67 | opt.norm_D = 'spectralinstance' 68 | opt.norm_E = 'spectralinstance' 69 | opt.norm_G = 'spectralspadesyncbatch3x3' 70 | opt.num_D = 2 71 | opt.num_upsampling_layers = 'normal' 72 | opt.optimizer 
= 'adam' 73 | opt.output_nc = 3 74 | opt.phase = 'train' 75 | opt.preprocess_mode = 'resize_and_crop' 76 | opt.print_freq = 100 77 | opt.save_epoch_freq = 10 78 | opt.save_latest_freq = 5000 79 | opt.serial_batches = False 80 | opt.tf_log = False 81 | opt.use_vae = True 82 | opt.which_epoch = 'latest' 83 | opt.z_dim = 256 84 | opt.gpu_ids=[0] 85 | opt.results_dir='/spell/bob_ross_segmented/results/' 86 | opt.semantic_nc = 9 87 | opt.label_nc = 9 88 | opt.label_dir = '/spell/bob_ross_segmented/training/labels/' 89 | opt.image_dir = '/spell/bob_ross_segmented/training/images/' 90 | opt.instance_dir = '' 91 | 92 | # Create the folder structure expected by the model checkpointing feature. 93 | if not os.path.exists('/spell/checkpoints/'): 94 | os.mkdir('/spell/checkpoints/') 95 | if not os.path.exists('/spell/checkpoints/bob_ross/'): 96 | os.mkdir('/spell/checkpoints/bob_ross/') 97 | if not os.path.exists('/spell/bob_ross_segmented/results/'): 98 | os.mkdir('/spell/bob_ross_segmented/results/') 99 | 100 | model = Pix2PixModel(opt) 101 | model.train() 102 | 103 | def test_train(): 104 | # print options to help debugging 105 | # print(' '.join(sys.argv)) 106 | 107 | # load the dataset 108 | dataloader = data.create_dataloader(opt) 109 | 110 | # create trainer for our model 111 | trainer = Pix2PixTrainer(opt) 112 | 113 | # create tool for counting iterations 114 | iter_counter = IterationCounter(opt, len(dataloader)) 115 | 116 | # create tool for visualization 117 | visualizer = Visualizer(opt) 118 | 119 | for epoch in iter_counter.training_epochs(): 120 | iter_counter.record_epoch_start(epoch) 121 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 122 | iter_counter.record_one_iteration() 123 | 124 | # Training 125 | # train generator 126 | if i % opt.D_steps_per_G == 0: 127 | trainer.run_generator_one_step(data_i) 128 | 129 | # train discriminator 130 | trainer.run_discriminator_one_step(data_i) 131 | 132 | # Visualizations 133 | if 
iter_counter.needs_printing(): 134 | losses = trainer.get_latest_losses() 135 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 136 | losses, iter_counter.time_per_iter) 137 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 138 | 139 | if iter_counter.needs_displaying(): 140 | visuals = OrderedDict([('input_label', data_i['label']), 141 | ('synthesized_image', trainer.get_latest_generated()), 142 | ('real_image', data_i['image'])]) 143 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 144 | 145 | if iter_counter.needs_saving(): 146 | print('saving the latest model (epoch %d, total_steps %d)' % 147 | (epoch, iter_counter.total_steps_so_far)) 148 | trainer.save('latest') 149 | iter_counter.record_current_iter() 150 | 151 | trainer.update_learning_rate(epoch) 152 | iter_counter.record_epoch_end() 153 | 154 | if epoch % opt.save_epoch_freq == 0 or \ 155 | epoch == iter_counter.total_epochs: 156 | print('saving the model at the end of epoch %d, iters %d' % 157 | (epoch, iter_counter.total_steps_so_far)) 158 | trainer.save('latest') 159 | trainer.save(epoch) 160 | 161 | print('Training was successfully finished.') 162 | 163 | test_train() 164 | -------------------------------------------------------------------------------- /models/model_10.py: -------------------------------------------------------------------------------- 1 | # This is the ADE20K pretrain model. This is meant to be trained on a machine with 2 | # 8 GPUs onboard. This model is a test version of the model which only trains for 3 | # a single epoch to verify that this model definition is, indeed, good. 4 | # 5 | # VAE is disabled for this run. Including it in the model would cause style transfer 6 | # to/from just the ADE20K corpus AFAIK, as previous experiments with training 7 | # this model from scratch demonstrated that the size of the Bob Ross corpus is not 8 | # sufficient for the style transfer elements to really kick in. 
9 | 10 | import sys; sys.path.append('../lib/SPADE-master/') 11 | from options.train_options import TrainOptions 12 | from models.pix2pix_model import Pix2PixModel 13 | from collections import OrderedDict 14 | import data 15 | from util.iter_counter import IterationCounter 16 | from util.visualizer import Visualizer 17 | from trainers.pix2pix_trainer import Pix2PixTrainer 18 | import os 19 | 20 | opt = TrainOptions() 21 | opt.D_steps_per_G = 1 22 | opt.aspect_ratio = 1.0 23 | opt.batchSize = 8 24 | opt.beta1 = 0.0 25 | opt.beta2 = 0.9 26 | opt.cache_filelist_read = False 27 | opt.cache_filelist_write = False 28 | opt.checkpoints_dir = '/spell/checkpoints/' 29 | opt.contain_dontcare_label = True 30 | opt.continue_train = False 31 | opt.crop_size = 256 32 | opt.dataroot = '/spell/adek20k' # data mount point 33 | opt.dataset_mode = "ade20k" 34 | opt.debug = False 35 | opt.display_freq = 100 36 | opt.display_winsize = 256 37 | opt.fff = 1 # junk value for the argparse 38 | opt.gan_mode = 'hinge' 39 | opt.init_type = 'xavier' 40 | opt.init_variance = 0.02 41 | opt.isTrain = True 42 | opt.label_nc = 150 43 | opt.lambda_feat = 10.0 44 | opt.lambda_kld = 0.05 45 | opt.lambda_vgg = 10.0 46 | opt.load_from_opt_file = False 47 | opt.load_size = 286 # should this be 256? 
48 | opt.lr = 0.0002 49 | opt.max_dataset_size = 9223372036854775807 50 | opt.model = 'pix2pix' 51 | opt.nThreads = 0 52 | opt.n_layers_D = 4 53 | opt.name = 'ade20k_pretrained' 54 | opt.ndf = 64 55 | opt.nef = 16 56 | opt.netD = 'multiscale' 57 | opt.netD_subarch = 'n_layer' 58 | opt.netG = 'spade' 59 | opt.ngf = 64 60 | opt.niter = 1 61 | opt.niter_decay = 0 62 | opt.no_TTUR = False 63 | opt.no_flip = False 64 | opt.no_ganFeat_loss = False 65 | opt.no_html = True 66 | opt.no_instance = True 67 | opt.no_pairing_check = False 68 | opt.no_vgg_loss = False 69 | opt.norm_D = 'spectralinstance' 70 | opt.norm_E = 'spectralinstance' 71 | opt.norm_G = 'spectralspadesyncbatch3x3' 72 | opt.num_D = 2 73 | opt.num_upsampling_layers = 'normal' 74 | opt.optimizer = 'adam' 75 | opt.output_nc = 3 76 | opt.phase = 'train' 77 | opt.preprocess_mode = 'resize_and_crop' 78 | opt.print_freq = 100 79 | opt.save_epoch_freq = 10 80 | opt.save_latest_freq = 5000 81 | opt.serial_batches = False 82 | opt.tf_log = True 83 | opt.use_vae = False 84 | opt.which_epoch = 'latest' 85 | opt.z_dim = 256 86 | opt.gpu_ids=[0,1,2,3,4,5,6,7] 87 | opt.results_dir='../data/SPADE_from_scratch_results/' 88 | opt.semantic_nc = 151 89 | 90 | # Create the folder structure expected by the model checkpointing feature. 
91 | if not os.path.exists('/spell/checkpoints/'): 92 | os.mkdir('/spell/checkpoints/') 93 | if not os.path.exists('/spell/checkpoints/ade20k_pretrained/'): 94 | os.mkdir('/spell/checkpoints/ade20k_pretrained/') 95 | 96 | model = Pix2PixModel(opt) 97 | model.train() 98 | 99 | def test_train(): 100 | # print options to help debugging 101 | # print(' '.join(sys.argv)) 102 | 103 | # load the dataset 104 | dataloader = data.create_dataloader(opt) 105 | 106 | # create trainer for our model 107 | trainer = Pix2PixTrainer(opt) 108 | 109 | # create tool for counting iterations 110 | iter_counter = IterationCounter(opt, len(dataloader)) 111 | 112 | # create tool for visualization 113 | visualizer = Visualizer(opt) 114 | 115 | for epoch in iter_counter.training_epochs(): 116 | iter_counter.record_epoch_start(epoch) 117 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 118 | iter_counter.record_one_iteration() 119 | 120 | # Training 121 | # train generator 122 | if i % opt.D_steps_per_G == 0: 123 | trainer.run_generator_one_step(data_i) 124 | 125 | # train discriminator 126 | trainer.run_discriminator_one_step(data_i) 127 | 128 | # Visualizations 129 | if iter_counter.needs_printing(): 130 | losses = trainer.get_latest_losses() 131 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 132 | losses, iter_counter.time_per_iter) 133 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 134 | 135 | if iter_counter.needs_displaying(): 136 | visuals = OrderedDict([('input_label', data_i['label']), 137 | ('synthesized_image', trainer.get_latest_generated()), 138 | ('real_image', data_i['image'])]) 139 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 140 | 141 | if iter_counter.needs_saving(): 142 | print('saving the latest model (epoch %d, total_steps %d)' % 143 | (epoch, iter_counter.total_steps_so_far)) 144 | trainer.save('latest') 145 | iter_counter.record_current_iter() 146 | 147 | 
trainer.update_learning_rate(epoch) 148 | iter_counter.record_epoch_end() 149 | 150 | if epoch % opt.save_epoch_freq == 0 or \ 151 | epoch == iter_counter.total_epochs: 152 | print('saving the model at the end of epoch %d, iters %d' % 153 | (epoch, iter_counter.total_steps_so_far)) 154 | trainer.save('latest') 155 | trainer.save(epoch) 156 | 157 | print('Training was successfully finished.') 158 | 159 | test_train() 160 | -------------------------------------------------------------------------------- /models/model_11.py: -------------------------------------------------------------------------------- 1 | # This is the ADE20K pretrain model. This is meant to be trained on a machine with 2 | # 8 GPUs onboard. This file is the full, 50-epoch model trainer. Warning: this model costs 3 | # ~$350 to train once! 4 | 5 | import sys; sys.path.append('../lib/SPADE-master/') 6 | from options.train_options import TrainOptions 7 | from models.pix2pix_model import Pix2PixModel 8 | from collections import OrderedDict 9 | import data 10 | from util.iter_counter import IterationCounter 11 | from util.visualizer import Visualizer 12 | from trainers.pix2pix_trainer import Pix2PixTrainer 13 | import os 14 | 15 | opt = TrainOptions() 16 | opt.D_steps_per_G = 1 17 | opt.aspect_ratio = 1.0 18 | opt.batchSize = 8 19 | opt.beta1 = 0.0 20 | opt.beta2 = 0.9 21 | opt.cache_filelist_read = False 22 | opt.cache_filelist_write = False 23 | opt.checkpoints_dir = '/spell/checkpoints/' 24 | opt.contain_dontcare_label = True 25 | opt.continue_train = False 26 | opt.crop_size = 256 27 | opt.dataroot = '/spell/adek20k' # data mount point 28 | opt.dataset_mode = "ade20k" 29 | opt.debug = False 30 | opt.display_freq = 100 31 | opt.display_winsize = 256 32 | opt.fff = 1 # junk value for the argparse 33 | opt.gan_mode = 'hinge' 34 | opt.init_type = 'xavier' 35 | opt.init_variance = 0.02 36 | opt.isTrain = True 37 | opt.label_nc = 150 38 | opt.lambda_feat = 10.0 39 | opt.lambda_kld = 0.05 40 |
opt.lambda_vgg = 10.0 41 | opt.load_from_opt_file = False 42 | opt.load_size = 286 # should this be 256? 43 | opt.lr = 0.0002 44 | opt.max_dataset_size = 9223372036854775807 45 | opt.model = 'pix2pix' 46 | opt.nThreads = 0 47 | opt.n_layers_D = 4 48 | opt.name = 'ade20k_pretrained' 49 | opt.ndf = 64 50 | opt.nef = 16 51 | opt.netD = 'multiscale' 52 | opt.netD_subarch = 'n_layer' 53 | opt.netG = 'spade' 54 | opt.ngf = 64 55 | opt.niter = 50 56 | opt.niter_decay = 0 57 | opt.no_TTUR = False 58 | opt.no_flip = False 59 | opt.no_ganFeat_loss = False 60 | opt.no_html = True 61 | opt.no_instance = True 62 | opt.no_pairing_check = False 63 | opt.no_vgg_loss = False 64 | opt.norm_D = 'spectralinstance' 65 | opt.norm_E = 'spectralinstance' 66 | opt.norm_G = 'spectralspadesyncbatch3x3' 67 | opt.num_D = 2 68 | opt.num_upsampling_layers = 'normal' 69 | opt.optimizer = 'adam' 70 | opt.output_nc = 3 71 | opt.phase = 'train' 72 | opt.preprocess_mode = 'resize_and_crop' 73 | opt.print_freq = 100 74 | opt.save_epoch_freq = 10 75 | opt.save_latest_freq = 5000 76 | opt.serial_batches = False 77 | opt.tf_log = True 78 | opt.use_vae = False 79 | opt.which_epoch = 'latest' 80 | opt.z_dim = 256 81 | opt.gpu_ids=[0,1,2,3,4,5,6,7] 82 | opt.results_dir='../data/SPADE_from_scratch_results/' 83 | opt.semantic_nc = 151 84 | 85 | # Create the folder structure expected by the model checkpointing feature. 
86 | if not os.path.exists('/spell/checkpoints/'): 87 | os.mkdir('/spell/checkpoints/') 88 | if not os.path.exists('/spell/checkpoints/ade20k_pretrained/'): 89 | os.mkdir('/spell/checkpoints/ade20k_pretrained/') 90 | 91 | model = Pix2PixModel(opt) 92 | model.train() 93 | 94 | def test_train(): 95 | # print options to help debugging 96 | # print(' '.join(sys.argv)) 97 | 98 | # load the dataset 99 | dataloader = data.create_dataloader(opt) 100 | 101 | # create trainer for our model 102 | trainer = Pix2PixTrainer(opt) 103 | 104 | # create tool for counting iterations 105 | iter_counter = IterationCounter(opt, len(dataloader)) 106 | 107 | # create tool for visualization 108 | visualizer = Visualizer(opt) 109 | 110 | for epoch in iter_counter.training_epochs(): 111 | iter_counter.record_epoch_start(epoch) 112 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 113 | iter_counter.record_one_iteration() 114 | 115 | # Training 116 | # train generator 117 | if i % opt.D_steps_per_G == 0: 118 | trainer.run_generator_one_step(data_i) 119 | 120 | # train discriminator 121 | trainer.run_discriminator_one_step(data_i) 122 | 123 | # Visualizations 124 | if iter_counter.needs_printing(): 125 | losses = trainer.get_latest_losses() 126 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 127 | losses, iter_counter.time_per_iter) 128 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 129 | 130 | if iter_counter.needs_displaying(): 131 | visuals = OrderedDict([('input_label', data_i['label']), 132 | ('synthesized_image', trainer.get_latest_generated()), 133 | ('real_image', data_i['image'])]) 134 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 135 | 136 | if iter_counter.needs_saving(): 137 | print('saving the latest model (epoch %d, total_steps %d)' % 138 | (epoch, iter_counter.total_steps_so_far)) 139 | trainer.save('latest') 140 | iter_counter.record_current_iter() 141 | 142 | 
trainer.update_learning_rate(epoch) 143 | iter_counter.record_epoch_end() 144 | 145 | if epoch % opt.save_epoch_freq == 0 or \ 146 | epoch == iter_counter.total_epochs: 147 | print('saving the model at the end of epoch %d, iters %d' % 148 | (epoch, iter_counter.total_steps_so_far)) 149 | trainer.save('latest') 150 | trainer.save(epoch) 151 | 152 | print('Training was successfully finished.') 153 | 154 | test_train() 155 | -------------------------------------------------------------------------------- /models/model_13.py: -------------------------------------------------------------------------------- 1 | # This is the Bob Ross model trained on the ADE20K pretrain output. No iterative freezing 2 | # and unfreezing is performed; instead, training is performed for 50 epochs with a fixed 3 | # learning rate which is 1/10th of the default. Compare with model_12.py. 4 | 5 | import sys; sys.path.append('../lib/SPADE-master/') 6 | from options.train_options import TrainOptions 7 | from models.pix2pix_model import Pix2PixModel 8 | from collections import OrderedDict 9 | import data 10 | from util.iter_counter import IterationCounter 11 | from util.visualizer import Visualizer 12 | from trainers.pix2pix_trainer import Pix2PixTrainer 13 | import os 14 | 15 | opt = TrainOptions() 16 | opt.D_steps_per_G = 1 17 | opt.aspect_ratio = 1.0 18 | opt.batchSize = 8 19 | opt.beta1 = 0.0 20 | opt.beta2 = 0.9 21 | opt.cache_filelist_read = False 22 | opt.cache_filelist_write = False 23 | opt.checkpoints_dir = '/spell/checkpoints/' 24 | opt.contain_dontcare_label = True 25 | opt.continue_train = True 26 | opt.crop_size = 256 27 | opt.dataroot = '/spell/bob_ross_segmented/' # data mount point 28 | opt.dataset_mode = 'custom' 29 | opt.debug = False 30 | opt.display_freq = 100 31 | opt.display_winsize = 256 32 | opt.fff = 1 # junk value for the argparse 33 | opt.gan_mode = 'hinge' 34 | opt.init_type = 'xavier' 35 | opt.init_variance = 0.02 36 | opt.isTrain = True 37 | opt.label_nc = 150 38 |
opt.lambda_feat = 10.0 39 | opt.lambda_kld = 0.05 40 | opt.lambda_vgg = 10.0 41 | opt.load_from_opt_file = False 42 | opt.load_size = 256 43 | opt.lr = 0.0002 44 | opt.max_dataset_size = 9223372036854775807 45 | opt.model = 'pix2pix' 46 | opt.nThreads = 0 47 | opt.n_layers_D = 4 48 | opt.name = 'bob_ross_x_ade20k_outdoors' 49 | opt.ndf = 64 50 | opt.nef = 16 51 | opt.netD = 'multiscale' 52 | opt.netD_subarch = 'n_layer' 53 | opt.netG = 'spade' 54 | opt.ngf = 64 55 | opt.niter = 50 56 | opt.niter_decay = 0 57 | opt.no_TTUR = False 58 | opt.no_flip = False 59 | opt.no_ganFeat_loss = False 60 | opt.no_html = True 61 | opt.no_instance = True 62 | opt.no_pairing_check = False 63 | opt.no_vgg_loss = False 64 | opt.norm_D = 'spectralinstance' 65 | opt.norm_E = 'spectralinstance' 66 | opt.norm_G = 'spectralspadesyncbatch3x3' 67 | opt.num_D = 2 68 | opt.num_upsampling_layers = 'normal' 69 | opt.optimizer = 'adam' 70 | opt.output_nc = 3 71 | opt.phase = 'train' 72 | opt.preprocess_mode = 'resize_and_crop' 73 | opt.print_freq = 100 74 | opt.save_epoch_freq = 10 75 | opt.save_latest_freq = 5000 76 | opt.serial_batches = False 77 | opt.tf_log = True 78 | opt.use_vae = False 79 | opt.which_epoch = 'latest' 80 | opt.z_dim = 256 81 | opt.gpu_ids=[0,1] 82 | opt.results_dir='../data/SPADE_from_scratch_results/' 83 | opt.semantic_nc = 151 84 | opt.label_nc = 150 85 | opt.label_dir = '/spell/bob_ross_segmented/labels/' 86 | opt.image_dir = '/spell/bob_ross_segmented/images/' 87 | opt.instance_dir = '' 88 | 89 | model = Pix2PixModel(opt) 90 | model.train() 91 | 92 | def train(): 93 | # extend the schedule and lower the learning rate; no layers are frozen in this run 94 | opt.niter = opt.niter + 20 # 20 more epochs of training 95 | opt.lr = 0.00002 # 1/10th of the original lr 96 | 97 | 98 | # Proceed with training.
99 | 100 | # load the dataset 101 | dataloader = data.create_dataloader(opt) 102 | 103 | trainer = Pix2PixTrainer(opt) 104 | 105 | # create tool for counting iterations 106 | iter_counter = IterationCounter(opt, len(dataloader)) 107 | 108 | # create tool for visualization 109 | visualizer = Visualizer(opt) 110 | 111 | for epoch in iter_counter.training_epochs(): 112 | iter_counter.record_epoch_start(epoch) 113 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 114 | iter_counter.record_one_iteration() 115 | 116 | # Training 117 | # train generator 118 | if i % opt.D_steps_per_G == 0: 119 | trainer.run_generator_one_step(data_i) 120 | 121 | # train discriminator 122 | trainer.run_discriminator_one_step(data_i) 123 | 124 | # Visualizations 125 | if iter_counter.needs_printing(): 126 | losses = trainer.get_latest_losses() 127 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 128 | losses, iter_counter.time_per_iter) 129 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 130 | 131 | if iter_counter.needs_displaying(): 132 | visuals = OrderedDict([('input_label', data_i['label']), 133 | ('synthesized_image', trainer.get_latest_generated()), 134 | ('real_image', data_i['image'])]) 135 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 136 | 137 | if iter_counter.needs_saving(): 138 | print('saving the latest model (epoch %d, total_steps %d)' % 139 | (epoch, iter_counter.total_steps_so_far)) 140 | trainer.save('latest') 141 | iter_counter.record_current_iter() 142 | 143 | trainer.update_learning_rate(epoch) 144 | iter_counter.record_epoch_end() 145 | 146 | if epoch % opt.save_epoch_freq == 0 or \ 147 | epoch == iter_counter.total_epochs: 148 | print('saving the model at the end of epoch %d, iters %d' % 149 | (epoch, iter_counter.total_steps_so_far)) 150 | trainer.save('latest') 151 | trainer.save(epoch) 152 | 153 | train() 154 | 
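model_13.py resumes from a checkpoint (`continue_train = True`), raises `opt.niter` by 20 and drops `opt.lr` to 1/10th of the default while `niter_decay` stays at 0. The header comment mentions 50 epochs, so it helps to spell out what the loop actually runs. The sketch below is a hypothetical plain-Python re-derivation, not code imported from SPADE. It assumes that SPADE's `IterationCounter` visits epochs from a start epoch (read from `iter.txt` under the experiment's checkpoint directory when `continue_train` is set, otherwise 1) through `niter + niter_decay`, and that with `no_TTUR = False` the generator and discriminator optimizers use `lr / 2` and `lr * 2`. The helper names `epoch_range` and `ttur_learning_rates` are illustrative only; check `lib/SPADE-master/util/iter_counter.py` and `lib/SPADE-master/models/pix2pix_model.py` before relying on these assumptions.

```python
# Hypothetical helpers: a plain-Python reading of how SPADE's IterationCounter and
# Pix2PixModel.create_optimizers are assumed to behave. Not imported from SPADE.

def epoch_range(niter, niter_decay, resume_epoch=None):
    """Epochs visited by the training loop: start epoch through niter + niter_decay.

    resume_epoch stands in for the epoch stored in checkpoints/<name>/iter.txt
    when continue_train is set; None means no iteration record was found.
    """
    first_epoch = resume_epoch if resume_epoch is not None else 1
    total_epochs = niter + niter_decay
    return range(first_epoch, total_epochs + 1)


def ttur_learning_rates(lr, no_ttur=False):
    """Return (generator_lr, discriminator_lr) for a base learning rate."""
    if no_ttur:
        return lr, lr
    # Two time-scale update rule: slower generator, faster discriminator.
    return lr / 2, lr * 2


# model_13.py settings: niter = 50 + 20, niter_decay = 0, lr = 0.00002, no_TTUR = False.
niter = 50 + 20
print(len(epoch_range(niter, 0)))                   # 70: no iteration record found
print(len(epoch_range(niter, 0, resume_epoch=51)))  # 20: hypothetical record of epoch 51
print(ttur_learning_rates(0.00002))                 # generator and discriminator rates
```

Under these assumptions, whether the run covers 70 epochs or 20 depends on whether an iteration record exists under `bob_ross_x_ade20k_outdoors`. Because `niter_decay` is 0, the epoch never exceeds `niter`, so the learning rates are expected to stay fixed for the whole run.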
-------------------------------------------------------------------------------- /models/model_2.py: -------------------------------------------------------------------------------- 1 | # Changes relative to model_1.py: 2 | # * Increased output size to 512x512. 3 | 4 | import sys; sys.path.append('../lib/SPADE-master/') 5 | from options.train_options import TrainOptions 6 | from models.pix2pix_model import Pix2PixModel 7 | from collections import OrderedDict 8 | import data 9 | from util.iter_counter import IterationCounter 10 | from util.visualizer import Visualizer 11 | from trainers.pix2pix_trainer import Pix2PixTrainer 12 | import os 13 | 14 | opt = TrainOptions() 15 | opt.D_steps_per_G = 1 16 | opt.aspect_ratio = 1.0 17 | opt.batchSize = 1 18 | opt.beta1 = 0.0 19 | opt.beta2 = 0.9 20 | opt.cache_filelist_read = False 21 | opt.cache_filelist_write = False 22 | opt.checkpoints_dir = '/spell/checkpoints/' 23 | opt.contain_dontcare_label = False 24 | opt.continue_train = False 25 | opt.crop_size = 512 26 | opt.dataroot = '/spell/bob_ross_segmented/' # data mount point 27 | opt.dataset_mode = 'custom' 28 | opt.debug = False 29 | opt.display_freq = 100 30 | opt.display_winsize = 512 31 | opt.fff = 1 # junk value for the argparse 32 | opt.gan_mode = 'hinge' 33 | opt.init_type = 'xavier' 34 | opt.init_variance = 0.02 35 | opt.isTrain = True 36 | opt.lambda_feat = 10.0 37 | opt.lambda_kld = 0.05 38 | opt.lambda_vgg = 10.0 39 | opt.load_from_opt_file = False 40 | opt.load_size = 512 41 | opt.lr = 0.0002 42 | opt.max_dataset_size = 9223372036854775807 43 | opt.model = 'pix2pix' 44 | opt.nThreads = 0 45 | opt.n_layers_D = 4 46 | opt.name = 'bob_ross' 47 | opt.ndf = 64 48 | opt.nef = 16 49 | opt.netD = 'multiscale' 50 | opt.netD_subarch = 'n_layer' 51 | opt.netG = 'spade' 52 | opt.ngf = 64 53 | opt.niter = 50 54 | opt.niter_decay = 0 55 | opt.no_TTUR = False 56 | opt.no_flip = False 57 | opt.no_ganFeat_loss = False 58 | opt.no_html = True 59 | opt.no_instance = True 60 | 
opt.no_pairing_check = False 61 | opt.no_vgg_loss = False 62 | opt.norm_D = 'spectralinstance' 63 | opt.norm_E = 'spectralinstance' 64 | opt.norm_G = 'spectralspadesyncbatch3x3' 65 | opt.num_D = 2 66 | opt.num_upsampling_layers = 'normal' 67 | opt.optimizer = 'adam' 68 | opt.output_nc = 3 69 | opt.phase = 'train' 70 | opt.preprocess_mode = 'resize_and_crop' 71 | opt.print_freq = 100 72 | opt.save_epoch_freq = 10 73 | opt.save_latest_freq = 5000 74 | opt.serial_batches = False 75 | opt.tf_log = False 76 | opt.use_vae = True 77 | opt.which_epoch = 'latest' 78 | opt.z_dim = 256 79 | opt.gpu_ids=[0] 80 | opt.results_dir='/spell/bob_ross_segmented/results/' 81 | opt.semantic_nc = 9 82 | opt.label_nc = 9 83 | opt.label_dir = '/spell/bob_ross_segmented/training/labels/' 84 | opt.image_dir = '/spell/bob_ross_segmented/training/images/' 85 | opt.instance_dir = '' 86 | 87 | # Create the folder structure expected by the model checkpointing feature. 88 | if not os.path.exists('/spell/checkpoints/'): 89 | os.mkdir('/spell/checkpoints/') 90 | if not os.path.exists('/spell/checkpoints/bob_ross/'): 91 | os.mkdir('/spell/checkpoints/bob_ross/') 92 | 93 | model = Pix2PixModel(opt) 94 | model.train() 95 | 96 | def test_train(): 97 | # print options to help debugging 98 | # print(' '.join(sys.argv)) 99 | 100 | # load the dataset 101 | dataloader = data.create_dataloader(opt) 102 | 103 | # create trainer for our model 104 | trainer = Pix2PixTrainer(opt) 105 | 106 | # create tool for counting iterations 107 | iter_counter = IterationCounter(opt, len(dataloader)) 108 | 109 | # create tool for visualization 110 | visualizer = Visualizer(opt) 111 | 112 | for epoch in iter_counter.training_epochs(): 113 | iter_counter.record_epoch_start(epoch) 114 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 115 | iter_counter.record_one_iteration() 116 | 117 | # Training 118 | # train generator 119 | if i % opt.D_steps_per_G == 0: 120 | trainer.run_generator_one_step(data_i) 121 
| 122 | # train discriminator 123 | trainer.run_discriminator_one_step(data_i) 124 | 125 | # Visualizations 126 | if iter_counter.needs_printing(): 127 | losses = trainer.get_latest_losses() 128 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 129 | losses, iter_counter.time_per_iter) 130 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 131 | 132 | if iter_counter.needs_displaying(): 133 | visuals = OrderedDict([('input_label', data_i['label']), 134 | ('synthesized_image', trainer.get_latest_generated()), 135 | ('real_image', data_i['image'])]) 136 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 137 | 138 | if iter_counter.needs_saving(): 139 | print('saving the latest model (epoch %d, total_steps %d)' % 140 | (epoch, iter_counter.total_steps_so_far)) 141 | trainer.save('latest') 142 | iter_counter.record_current_iter() 143 | 144 | trainer.update_learning_rate(epoch) 145 | iter_counter.record_epoch_end() 146 | 147 | if epoch % opt.save_epoch_freq == 0 or \ 148 | epoch == iter_counter.total_epochs: 149 | print('saving the model at the end of epoch %d, iters %d' % 150 | (epoch, iter_counter.total_steps_so_far)) 151 | trainer.save('latest') 152 | trainer.save(epoch) 153 | 154 | print('Training was successfully finished.') 155 | 156 | test_train() 157 | -------------------------------------------------------------------------------- /models/model_3.py: -------------------------------------------------------------------------------- 1 | # Changes relative to model_2.py: 2 | # * lambda_kld increased 10x (increases the weight of the VAE encoder's KL-divergence loss term) 3 | 4 | import sys; sys.path.append('../lib/SPADE-master/') 5 | from options.train_options import TrainOptions 6 | from models.pix2pix_model import Pix2PixModel 7 | from collections import OrderedDict 8 | import data 9 | from util.iter_counter import IterationCounter 10 | from util.visualizer import Visualizer 11 | from trainers.pix2pix_trainer import
Pix2PixTrainer 12 | import os 13 | 14 | opt = TrainOptions() 15 | opt.D_steps_per_G = 1 16 | opt.aspect_ratio = 1.0 17 | opt.batchSize = 1 18 | opt.beta1 = 0.0 19 | opt.beta2 = 0.9 20 | opt.cache_filelist_read = False 21 | opt.cache_filelist_write = False 22 | opt.checkpoints_dir = '/spell/checkpoints/' 23 | opt.contain_dontcare_label = False 24 | opt.continue_train = False 25 | opt.crop_size = 512 26 | opt.dataroot = '/spell/bob_ross_segmented/' # data mount point 27 | opt.dataset_mode = 'custom' 28 | opt.debug = False 29 | opt.display_freq = 100 30 | opt.display_winsize = 512 31 | opt.fff = 1 # junk value for the argparse 32 | opt.gan_mode = 'hinge' 33 | opt.init_type = 'xavier' 34 | opt.init_variance = 0.02 35 | opt.isTrain = True 36 | opt.lambda_feat = 10.0 37 | opt.lambda_kld = 0.5 38 | opt.lambda_vgg = 10.0 39 | opt.load_from_opt_file = False 40 | opt.load_size = 512 41 | opt.lr = 0.0002 42 | opt.max_dataset_size = 9223372036854775807 43 | opt.model = 'pix2pix' 44 | opt.nThreads = 0 45 | opt.n_layers_D = 4 46 | opt.name = 'bob_ross' 47 | opt.ndf = 64 48 | opt.nef = 16 49 | opt.netD = 'multiscale' 50 | opt.netD_subarch = 'n_layer' 51 | opt.netG = 'spade' 52 | opt.ngf = 64 53 | opt.niter = 50 54 | opt.niter_decay = 0 55 | opt.no_TTUR = False 56 | opt.no_flip = False 57 | opt.no_ganFeat_loss = False 58 | opt.no_html = True 59 | opt.no_instance = True 60 | opt.no_pairing_check = False 61 | opt.no_vgg_loss = False 62 | opt.norm_D = 'spectralinstance' 63 | opt.norm_E = 'spectralinstance' 64 | opt.norm_G = 'spectralspadesyncbatch3x3' 65 | opt.num_D = 2 66 | opt.num_upsampling_layers = 'normal' 67 | opt.optimizer = 'adam' 68 | opt.output_nc = 3 69 | opt.phase = 'train' 70 | opt.preprocess_mode = 'resize_and_crop' 71 | opt.print_freq = 100 72 | opt.save_epoch_freq = 10 73 | opt.save_latest_freq = 5000 74 | opt.serial_batches = False 75 | opt.tf_log = False 76 | opt.use_vae = True 77 | opt.which_epoch = 'latest' 78 | opt.z_dim = 256 79 | opt.gpu_ids=[0] 80 | 
opt.results_dir='/spell/bob_ross_segmented/results/' 81 | opt.semantic_nc = 9 82 | opt.label_nc = 9 83 | opt.label_dir = '/spell/bob_ross_segmented/training/labels/' 84 | opt.image_dir = '/spell/bob_ross_segmented/training/images/' 85 | opt.instance_dir = '' 86 | 87 | # Create the folder structure expected by the model checkpointing feature. 88 | if not os.path.exists('/spell/checkpoints/'): 89 | os.mkdir('/spell/checkpoints/') 90 | if not os.path.exists('/spell/checkpoints/bob_ross/'): 91 | os.mkdir('/spell/checkpoints/bob_ross/') 92 | 93 | model = Pix2PixModel(opt) 94 | model.train() 95 | 96 | def test_train(): 97 | # print options to help debugging 98 | # print(' '.join(sys.argv)) 99 | 100 | # load the dataset 101 | dataloader = data.create_dataloader(opt) 102 | 103 | # create trainer for our model 104 | trainer = Pix2PixTrainer(opt) 105 | 106 | # create tool for counting iterations 107 | iter_counter = IterationCounter(opt, len(dataloader)) 108 | 109 | # create tool for visualization 110 | visualizer = Visualizer(opt) 111 | 112 | for epoch in iter_counter.training_epochs(): 113 | iter_counter.record_epoch_start(epoch) 114 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 115 | iter_counter.record_one_iteration() 116 | 117 | # Training 118 | # train generator 119 | if i % opt.D_steps_per_G == 0: 120 | trainer.run_generator_one_step(data_i) 121 | 122 | # train discriminator 123 | trainer.run_discriminator_one_step(data_i) 124 | 125 | # Visualizations 126 | if iter_counter.needs_printing(): 127 | losses = trainer.get_latest_losses() 128 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 129 | losses, iter_counter.time_per_iter) 130 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 131 | 132 | if iter_counter.needs_displaying(): 133 | visuals = OrderedDict([('input_label', data_i['label']), 134 | ('synthesized_image', trainer.get_latest_generated()), 135 | ('real_image', data_i['image'])]) 136 | 
visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 137 | 138 | if iter_counter.needs_saving(): 139 | print('saving the latest model (epoch %d, total_steps %d)' % 140 | (epoch, iter_counter.total_steps_so_far)) 141 | trainer.save('latest') 142 | iter_counter.record_current_iter() 143 | 144 | trainer.update_learning_rate(epoch) 145 | iter_counter.record_epoch_end() 146 | 147 | if epoch % opt.save_epoch_freq == 0 or \ 148 | epoch == iter_counter.total_epochs: 149 | print('saving the model at the end of epoch %d, iters %d' % 150 | (epoch, iter_counter.total_steps_so_far)) 151 | trainer.save('latest') 152 | trainer.save(epoch) 153 | 154 | print('Training was successfully finished.') 155 | 156 | test_train() 157 | -------------------------------------------------------------------------------- /models/model_4.py: -------------------------------------------------------------------------------- 1 | # This is a dummy model that trains with TensorBoard enabled just to make sure I see 2 | # where the files land. 
3 | 4 | import sys; sys.path.append('../lib/SPADE-master/') 5 | from options.train_options import TrainOptions 6 | from models.pix2pix_model import Pix2PixModel 7 | from collections import OrderedDict 8 | import data 9 | from util.iter_counter import IterationCounter 10 | from util.visualizer import Visualizer 11 | from trainers.pix2pix_trainer import Pix2PixTrainer 12 | import os 13 | 14 | opt = TrainOptions() 15 | opt.D_steps_per_G = 1 16 | opt.aspect_ratio = 1.0 17 | opt.batchSize = 1 18 | opt.beta1 = 0.0 19 | opt.beta2 = 0.9 20 | opt.cache_filelist_read = False 21 | opt.cache_filelist_write = False 22 | opt.checkpoints_dir = '/spell/checkpoints/' 23 | opt.contain_dontcare_label = False 24 | opt.continue_train = False 25 | opt.crop_size = 512 26 | opt.dataroot = '/spell/bob_ross_segmented/' # data mount point 27 | opt.dataset_mode = 'custom' 28 | opt.debug = False 29 | opt.display_freq = 100 30 | opt.display_winsize = 512 31 | opt.fff = 1 # junk value for the argparse 32 | opt.gan_mode = 'hinge' 33 | opt.init_type = 'xavier' 34 | opt.init_variance = 0.02 35 | opt.isTrain = True 36 | opt.lambda_feat = 10.0 37 | opt.lambda_kld = 0.5 38 | opt.lambda_vgg = 10.0 39 | opt.load_from_opt_file = False 40 | opt.load_size = 512 41 | opt.lr = 0.0002 42 | opt.max_dataset_size = 9223372036854775807 43 | opt.model = 'pix2pix' 44 | opt.nThreads = 0 45 | opt.n_layers_D = 4 46 | opt.name = 'bob_ross' 47 | opt.ndf = 64 48 | opt.nef = 16 49 | opt.netD = 'multiscale' 50 | opt.netD_subarch = 'n_layer' 51 | opt.netG = 'spade' 52 | opt.ngf = 64 53 | opt.niter = 5 54 | opt.niter_decay = 0 55 | opt.no_TTUR = False 56 | opt.no_flip = False 57 | opt.no_ganFeat_loss = False 58 | opt.no_html = True 59 | opt.no_instance = True 60 | opt.no_pairing_check = False 61 | opt.no_vgg_loss = False 62 | opt.norm_D = 'spectralinstance' 63 | opt.norm_E = 'spectralinstance' 64 | opt.norm_G = 'spectralspadesyncbatch3x3' 65 | opt.num_D = 2 66 | opt.num_upsampling_layers = 'normal' 67 | opt.optimizer = 
'adam' 68 | opt.output_nc = 3 69 | opt.phase = 'train' 70 | opt.preprocess_mode = 'resize_and_crop' 71 | opt.print_freq = 100 72 | opt.save_epoch_freq = 10 73 | opt.save_latest_freq = 5000 74 | opt.serial_batches = False 75 | opt.tf_log = True 76 | opt.use_vae = True 77 | opt.which_epoch = 'latest' 78 | opt.z_dim = 256 79 | opt.gpu_ids=[0] 80 | opt.results_dir='/spell/bob_ross_segmented/results/' 81 | opt.semantic_nc = 9 82 | opt.label_nc = 9 83 | opt.label_dir = '/spell/bob_ross_segmented/training/labels/' 84 | opt.image_dir = '/spell/bob_ross_segmented/training/images/' 85 | opt.instance_dir = '' 86 | 87 | # Create the folder structure expected by the model checkpointing feature. 88 | if not os.path.exists('/spell/checkpoints/'): 89 | os.mkdir('/spell/checkpoints/') 90 | if not os.path.exists('/spell/checkpoints/bob_ross/'): 91 | os.mkdir('/spell/checkpoints/bob_ross/') 92 | 93 | model = Pix2PixModel(opt) 94 | model.train() 95 | 96 | def test_train(): 97 | # print options to help debugging 98 | # print(' '.join(sys.argv)) 99 | 100 | # load the dataset 101 | dataloader = data.create_dataloader(opt) 102 | 103 | # create trainer for our model 104 | trainer = Pix2PixTrainer(opt) 105 | 106 | # create tool for counting iterations 107 | iter_counter = IterationCounter(opt, len(dataloader)) 108 | 109 | # create tool for visualization 110 | visualizer = Visualizer(opt) 111 | 112 | for epoch in iter_counter.training_epochs(): 113 | iter_counter.record_epoch_start(epoch) 114 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 115 | iter_counter.record_one_iteration() 116 | 117 | # Training 118 | # train generator 119 | if i % opt.D_steps_per_G == 0: 120 | trainer.run_generator_one_step(data_i) 121 | 122 | # train discriminator 123 | trainer.run_discriminator_one_step(data_i) 124 | 125 | # Visualizations 126 | if iter_counter.needs_printing(): 127 | losses = trainer.get_latest_losses() 128 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 
129 | losses, iter_counter.time_per_iter) 130 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 131 | 132 | if iter_counter.needs_displaying(): 133 | visuals = OrderedDict([('input_label', data_i['label']), 134 | ('synthesized_image', trainer.get_latest_generated()), 135 | ('real_image', data_i['image'])]) 136 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 137 | 138 | if iter_counter.needs_saving(): 139 | print('saving the latest model (epoch %d, total_steps %d)' % 140 | (epoch, iter_counter.total_steps_so_far)) 141 | trainer.save('latest') 142 | iter_counter.record_current_iter() 143 | 144 | trainer.update_learning_rate(epoch) 145 | iter_counter.record_epoch_end() 146 | 147 | if epoch % opt.save_epoch_freq == 0 or \ 148 | epoch == iter_counter.total_epochs: 149 | print('saving the model at the end of epoch %d, iters %d' % 150 | (epoch, iter_counter.total_steps_so_far)) 151 | trainer.save('latest') 152 | trainer.save(epoch) 153 | 154 | print('Training was successfully finished.') 155 | 156 | test_train() 157 | -------------------------------------------------------------------------------- /models/model_5.py: -------------------------------------------------------------------------------- 1 | # This model trains for 100 epochs instead of the standard 50, and it uses the TensorBoard 2 | # integration. KLD is still set to 10x the default.
3 | 4 | import sys; sys.path.append('../lib/SPADE-master/') 5 | from options.train_options import TrainOptions 6 | from models.pix2pix_model import Pix2PixModel 7 | from collections import OrderedDict 8 | import data 9 | from util.iter_counter import IterationCounter 10 | from util.visualizer import Visualizer 11 | from trainers.pix2pix_trainer import Pix2PixTrainer 12 | import os 13 | 14 | opt = TrainOptions() 15 | opt.D_steps_per_G = 1 16 | opt.aspect_ratio = 1.0 17 | opt.batchSize = 1 18 | opt.beta1 = 0.0 19 | opt.beta2 = 0.9 20 | opt.cache_filelist_read = False 21 | opt.cache_filelist_write = False 22 | opt.checkpoints_dir = '/spell/checkpoints/' 23 | opt.contain_dontcare_label = False 24 | opt.continue_train = False 25 | opt.crop_size = 512 26 | opt.dataroot = '/spell/bob_ross_segmented/' # data mount point 27 | opt.dataset_mode = 'custom' 28 | opt.debug = False 29 | opt.display_freq = 100 30 | opt.display_winsize = 512 31 | opt.fff = 1 # junk value for the argparse 32 | opt.gan_mode = 'hinge' 33 | opt.init_type = 'xavier' 34 | opt.init_variance = 0.02 35 | opt.isTrain = True 36 | opt.lambda_feat = 10.0 37 | opt.lambda_kld = 0.5 38 | opt.lambda_vgg = 10.0 39 | opt.load_from_opt_file = False 40 | opt.load_size = 512 41 | opt.lr = 0.0002 42 | opt.max_dataset_size = 9223372036854775807 43 | opt.model = 'pix2pix' 44 | opt.nThreads = 0 45 | opt.n_layers_D = 4 46 | opt.name = 'bob_ross' 47 | opt.ndf = 64 48 | opt.nef = 16 49 | opt.netD = 'multiscale' 50 | opt.netD_subarch = 'n_layer' 51 | opt.netG = 'spade' 52 | opt.ngf = 64 53 | opt.niter = 100 54 | opt.niter_decay = 0 55 | opt.no_TTUR = False 56 | opt.no_flip = False 57 | opt.no_ganFeat_loss = False 58 | opt.no_html = True 59 | opt.no_instance = True 60 | opt.no_pairing_check = False 61 | opt.no_vgg_loss = False 62 | opt.norm_D = 'spectralinstance' 63 | opt.norm_E = 'spectralinstance' 64 | opt.norm_G = 'spectralspadesyncbatch3x3' 65 | opt.num_D = 2 66 | opt.num_upsampling_layers = 'normal' 67 | opt.optimizer = 
'adam' 68 | opt.output_nc = 3 69 | opt.phase = 'train' 70 | opt.preprocess_mode = 'resize_and_crop' 71 | opt.print_freq = 100 72 | opt.save_epoch_freq = 10 73 | opt.save_latest_freq = 5000 74 | opt.serial_batches = False 75 | opt.tf_log = True 76 | opt.use_vae = True 77 | opt.which_epoch = 'latest' 78 | opt.z_dim = 256 79 | opt.gpu_ids=[0] 80 | opt.results_dir='/spell/bob_ross_segmented/results/' 81 | opt.semantic_nc = 9 82 | opt.label_nc = 9 83 | opt.label_dir = '/spell/bob_ross_segmented/training/labels/' 84 | opt.image_dir = '/spell/bob_ross_segmented/training/images/' 85 | opt.instance_dir = '' 86 | 87 | # Create the folder structure expected by the model checkpointing feature. 88 | if not os.path.exists('/spell/checkpoints/'): 89 | os.mkdir('/spell/checkpoints/') 90 | if not os.path.exists('/spell/checkpoints/bob_ross/'): 91 | os.mkdir('/spell/checkpoints/bob_ross/') 92 | 93 | model = Pix2PixModel(opt) 94 | model.train() 95 | 96 | def test_train(): 97 | # print options to help debugging 98 | # print(' '.join(sys.argv)) 99 | 100 | # load the dataset 101 | dataloader = data.create_dataloader(opt) 102 | 103 | # create trainer for our model 104 | trainer = Pix2PixTrainer(opt) 105 | 106 | # create tool for counting iterations 107 | iter_counter = IterationCounter(opt, len(dataloader)) 108 | 109 | # create tool for visualization 110 | visualizer = Visualizer(opt) 111 | 112 | for epoch in iter_counter.training_epochs(): 113 | iter_counter.record_epoch_start(epoch) 114 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 115 | iter_counter.record_one_iteration() 116 | 117 | # Training 118 | # train generator 119 | if i % opt.D_steps_per_G == 0: 120 | trainer.run_generator_one_step(data_i) 121 | 122 | # train discriminator 123 | trainer.run_discriminator_one_step(data_i) 124 | 125 | # Visualizations 126 | if iter_counter.needs_printing(): 127 | losses = trainer.get_latest_losses() 128 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 
129 | losses, iter_counter.time_per_iter) 130 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 131 | 132 | if iter_counter.needs_displaying(): 133 | visuals = OrderedDict([('input_label', data_i['label']), 134 | ('synthesized_image', trainer.get_latest_generated()), 135 | ('real_image', data_i['image'])]) 136 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 137 | 138 | if iter_counter.needs_saving(): 139 | print('saving the latest model (epoch %d, total_steps %d)' % 140 | (epoch, iter_counter.total_steps_so_far)) 141 | trainer.save('latest') 142 | iter_counter.record_current_iter() 143 | 144 | trainer.update_learning_rate(epoch) 145 | iter_counter.record_epoch_end() 146 | 147 | if epoch % opt.save_epoch_freq == 0 or \ 148 | epoch == iter_counter.total_epochs: 149 | print('saving the model at the end of epoch %d, iters %d' % 150 | (epoch, iter_counter.total_steps_so_far)) 151 | trainer.save('latest') 152 | trainer.save(epoch) 153 | 154 | print('Training was successfully finished.') 155 | 156 | test_train() 157 | -------------------------------------------------------------------------------- /models/model_6.py: -------------------------------------------------------------------------------- 1 | # This model trains for 100 epochs instead of the standard 50, and it uses the TensorBoard 2 | # integration. KLD is reset back to its original value of 0.05; compare to model_5.py, where 3 | # KLD is set 10x higher, to 0.5. 
4 | 5 | import sys; sys.path.append('../lib/SPADE-master/') 6 | from options.train_options import TrainOptions 7 | from models.pix2pix_model import Pix2PixModel 8 | from collections import OrderedDict 9 | import data 10 | from util.iter_counter import IterationCounter 11 | from util.visualizer import Visualizer 12 | from trainers.pix2pix_trainer import Pix2PixTrainer 13 | import os 14 | 15 | opt = TrainOptions() 16 | opt.D_steps_per_G = 1 17 | opt.aspect_ratio = 1.0 18 | opt.batchSize = 1 19 | opt.beta1 = 0.0 20 | opt.beta2 = 0.9 21 | opt.cache_filelist_read = False 22 | opt.cache_filelist_write = False 23 | opt.checkpoints_dir = '/spell/checkpoints/' 24 | opt.contain_dontcare_label = False 25 | opt.continue_train = False 26 | opt.crop_size = 512 27 | opt.dataroot = '/spell/bob_ross_segmented/' # data mount point 28 | opt.dataset_mode = 'custom' 29 | opt.debug = False 30 | opt.display_freq = 100 31 | opt.display_winsize = 512 32 | opt.fff = 1 # junk value for the argparse 33 | opt.gan_mode = 'hinge' 34 | opt.init_type = 'xavier' 35 | opt.init_variance = 0.02 36 | opt.isTrain = True 37 | opt.lambda_feat = 10.0 38 | opt.lambda_kld = 0.05 39 | opt.lambda_vgg = 10.0 40 | opt.load_from_opt_file = False 41 | opt.load_size = 512 42 | opt.lr = 0.0002 43 | opt.max_dataset_size = 9223372036854775807 44 | opt.model = 'pix2pix' 45 | opt.nThreads = 0 46 | opt.n_layers_D = 4 47 | opt.name = 'bob_ross' 48 | opt.ndf = 64 49 | opt.nef = 16 50 | opt.netD = 'multiscale' 51 | opt.netD_subarch = 'n_layer' 52 | opt.netG = 'spade' 53 | opt.ngf = 64 54 | opt.niter = 100 55 | opt.niter_decay = 0 56 | opt.no_TTUR = False 57 | opt.no_flip = False 58 | opt.no_ganFeat_loss = False 59 | opt.no_html = True 60 | opt.no_instance = True 61 | opt.no_pairing_check = False 62 | opt.no_vgg_loss = False 63 | opt.norm_D = 'spectralinstance' 64 | opt.norm_E = 'spectralinstance' 65 | opt.norm_G = 'spectralspadesyncbatch3x3' 66 | opt.num_D = 2 67 | opt.num_upsampling_layers = 'normal' 68 | opt.optimizer = 
'adam' 69 | opt.output_nc = 3 70 | opt.phase = 'train' 71 | opt.preprocess_mode = 'resize_and_crop' 72 | opt.print_freq = 100 73 | opt.save_epoch_freq = 10 74 | opt.save_latest_freq = 5000 75 | opt.serial_batches = False 76 | opt.tf_log = True 77 | opt.use_vae = True 78 | opt.which_epoch = 'latest' 79 | opt.z_dim = 256 80 | opt.gpu_ids=[0] 81 | opt.results_dir='/spell/bob_ross_segmented/results/' 82 | opt.semantic_nc = 9 83 | opt.label_nc = 9 84 | opt.label_dir = '/spell/bob_ross_segmented/training/labels/' 85 | opt.image_dir = '/spell/bob_ross_segmented/training/images/' 86 | opt.instance_dir = '' 87 | 88 | # Create the folder structure expected by the model checkpointing feature. 89 | if not os.path.exists('/spell/checkpoints/'): 90 | os.mkdir('/spell/checkpoints/') 91 | if not os.path.exists('/spell/checkpoints/bob_ross/'): 92 | os.mkdir('/spell/checkpoints/bob_ross/') 93 | 94 | model = Pix2PixModel(opt) 95 | model.train() 96 | 97 | def test_train(): 98 | # print options to help debugging 99 | # print(' '.join(sys.argv)) 100 | 101 | # load the dataset 102 | dataloader = data.create_dataloader(opt) 103 | 104 | # create trainer for our model 105 | trainer = Pix2PixTrainer(opt) 106 | 107 | # create tool for counting iterations 108 | iter_counter = IterationCounter(opt, len(dataloader)) 109 | 110 | # create tool for visualization 111 | visualizer = Visualizer(opt) 112 | 113 | for epoch in iter_counter.training_epochs(): 114 | iter_counter.record_epoch_start(epoch) 115 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 116 | iter_counter.record_one_iteration() 117 | 118 | # Training 119 | # train generator 120 | if i % opt.D_steps_per_G == 0: 121 | trainer.run_generator_one_step(data_i) 122 | 123 | # train discriminator 124 | trainer.run_discriminator_one_step(data_i) 125 | 126 | # Visualizations 127 | if iter_counter.needs_printing(): 128 | losses = trainer.get_latest_losses() 129 | visualizer.print_current_errors(epoch, 
iter_counter.epoch_iter, 130 | losses, iter_counter.time_per_iter) 131 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 132 | 133 | if iter_counter.needs_displaying(): 134 | visuals = OrderedDict([('input_label', data_i['label']), 135 | ('synthesized_image', trainer.get_latest_generated()), 136 | ('real_image', data_i['image'])]) 137 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 138 | 139 | if iter_counter.needs_saving(): 140 | print('saving the latest model (epoch %d, total_steps %d)' % 141 | (epoch, iter_counter.total_steps_so_far)) 142 | trainer.save('latest') 143 | iter_counter.record_current_iter() 144 | 145 | trainer.update_learning_rate(epoch) 146 | iter_counter.record_epoch_end() 147 | 148 | if epoch % opt.save_epoch_freq == 0 or \ 149 | epoch == iter_counter.total_epochs: 150 | print('saving the model at the end of epoch %d, iters %d' % 151 | (epoch, iter_counter.total_steps_so_far)) 152 | trainer.save('latest') 153 | trainer.save(epoch) 154 | 155 | print('Training was successfully finished.') 156 | 157 | test_train() 158 | -------------------------------------------------------------------------------- /models/model_7.py: -------------------------------------------------------------------------------- 1 | # This is an ADE20K training script in its original configuration. 2 | # Used to test how long and how expensive this model would be to train from scratch.
3 | 4 | import sys; sys.path.append('../lib/SPADE-master/') 5 | from options.train_options import TrainOptions 6 | from models.pix2pix_model import Pix2PixModel 7 | from collections import OrderedDict 8 | import data 9 | from util.iter_counter import IterationCounter 10 | from util.visualizer import Visualizer 11 | from trainers.pix2pix_trainer import Pix2PixTrainer 12 | 13 | opt = TrainOptions() 14 | 15 | opt.D_steps_per_G = 1 16 | opt.aspect_ratio = 1.0 17 | opt.batchSize = 1 18 | opt.beta1 = 0.0 19 | opt.beta2 = 0.9 20 | opt.cache_filelist_read = False 21 | opt.cache_filelist_write = False 22 | opt.checkpoints_dir = '/spell/checkpoints/' 23 | opt.contain_dontcare_label = True 24 | opt.continue_train = False 25 | opt.crop_size = 256 26 | opt.dataroot = '/spell/ade20k_c/ADEChallengeData2016/' # data mount point 27 | opt.dataset_mode = "ade20k" 28 | opt.debug = False 29 | opt.display_freq = 100 30 | opt.display_winsize = 256 31 | opt.fff = 1 # junk value for the argparse 32 | opt.gan_mode = 'hinge' 33 | # opt.gpu_ids = [] 34 | opt.init_type = 'xavier' 35 | opt.init_variance = 0.02 36 | opt.isTrain = True 37 | opt.label_nc = 150 38 | opt.lambda_feat = 10.0 39 | opt.lambda_kld = 0.05 40 | opt.lambda_vgg = 10.0 41 | opt.load_from_opt_file = False 42 | opt.load_size = 286 43 | opt.lr = 0.0002 44 | opt.max_dataset_size = 9223372036854775807 45 | opt.model = 'pix2pix' 46 | opt.nThreads = 0 47 | opt.n_layers_D = 4 48 | opt.name = 'ade20k_pretrained' 49 | opt.ndf = 64 50 | opt.nef = 16 51 | opt.netD = 'multiscale' 52 | opt.netD_subarch = 'n_layer' 53 | opt.netG = 'spade' 54 | opt.ngf = 64 55 | opt.niter = 50 56 | opt.niter_decay = 0 57 | opt.no_TTUR = False 58 | opt.no_flip = False 59 | opt.no_ganFeat_loss = False 60 | opt.no_html = True 61 | opt.no_instance = True 62 | opt.no_pairing_check = False 63 | opt.no_vgg_loss = False 64 | opt.norm_D = 'spectralinstance' 65 | opt.norm_E = 'spectralinstance' 66 | opt.norm_G = 'spectralspadesyncbatch3x3' 67 | opt.num_D = 2 68 | 
opt.num_upsampling_layers = 'normal' 69 | opt.optimizer = 'adam' 70 | opt.output_nc = 3 71 | opt.phase = 'train' 72 | opt.preprocess_mode = 'resize_and_crop' 73 | opt.print_freq = 100 74 | opt.save_epoch_freq = 10 75 | opt.save_latest_freq = 5000 76 | opt.serial_batches = False 77 | opt.tf_log = False 78 | opt.use_vae = False 79 | opt.which_epoch = 'latest' 80 | opt.z_dim = 256 81 | 82 | # additional arguments copied over from the previous TestOptions declarer 83 | opt.gpu_ids=[0] 84 | opt.results_dir='../data/SPADE_from_scratch_results/' 85 | opt.semantic_nc = 151 -------------------------------------------------------------------------------- /models/model_8.py: -------------------------------------------------------------------------------- 1 | # This is an ADE20K training script in a 512x512 configuration. 2 | # Used to test how long and how expensive this model would be to train from scratch. 3 | 4 | import sys; sys.path.append('../lib/SPADE-master/') 5 | from options.train_options import TrainOptions 6 | from models.pix2pix_model import Pix2PixModel 7 | from collections import OrderedDict 8 | import data 9 | from util.iter_counter import IterationCounter 10 | from util.visualizer import Visualizer 11 | from trainers.pix2pix_trainer import Pix2PixTrainer 12 | 13 | opt = TrainOptions() 14 | 15 | opt.D_steps_per_G = 1 16 | opt.aspect_ratio = 1.0 17 | opt.batchSize = 1 18 | opt.beta1 = 0.0 19 | opt.beta2 = 0.9 20 | opt.cache_filelist_read = False 21 | opt.cache_filelist_write = False 22 | opt.checkpoints_dir = '/spell/checkpoints/' 23 | opt.contain_dontcare_label = True 24 | opt.continue_train = False 25 | opt.crop_size = 512 26 | opt.dataroot = '/spell/ade20k_c/ADEChallengeData2016/' # data mount point 27 | opt.dataset_mode = "ade20k" 28 | opt.debug = False 29 | opt.display_freq = 100 30 | opt.display_winsize = 512 31 | opt.fff = 1 # junk value for the argparse 32 | opt.gan_mode = 'hinge' 33 | # opt.gpu_ids = [] 34 | opt.init_type = 'xavier' 35 | opt.init_variance
= 0.02 36 | opt.isTrain = True 37 | opt.label_nc = 150 38 | opt.lambda_feat = 10.0 39 | opt.lambda_kld = 0.05 40 | opt.lambda_vgg = 10.0 41 | opt.load_from_opt_file = False 42 | opt.load_size = 512 43 | opt.lr = 0.0002 44 | opt.max_dataset_size = 9223372036854775807 45 | opt.model = 'pix2pix' 46 | opt.nThreads = 0 47 | opt.n_layers_D = 4 48 | opt.name = 'ade20k_pretrained' 49 | opt.ndf = 64 50 | opt.nef = 16 51 | opt.netD = 'multiscale' 52 | opt.netD_subarch = 'n_layer' 53 | opt.netG = 'spade' 54 | opt.ngf = 64 55 | opt.niter = 50 56 | opt.niter_decay = 0 57 | opt.no_TTUR = False 58 | opt.no_flip = False 59 | opt.no_ganFeat_loss = False 60 | opt.no_html = True 61 | opt.no_instance = True 62 | opt.no_pairing_check = False 63 | opt.no_vgg_loss = False 64 | opt.norm_D = 'spectralinstance' 65 | opt.norm_E = 'spectralinstance' 66 | opt.norm_G = 'spectralspadesyncbatch3x3' 67 | opt.num_D = 2 68 | opt.num_upsampling_layers = 'normal' 69 | opt.optimizer = 'adam' 70 | opt.output_nc = 3 71 | opt.phase = 'train' 72 | opt.preprocess_mode = 'resize_and_crop' 73 | opt.print_freq = 100 74 | opt.save_epoch_freq = 10 75 | opt.save_latest_freq = 5000 76 | opt.serial_batches = False 77 | opt.tf_log = False 78 | opt.use_vae = False 79 | opt.which_epoch = 'latest' 80 | opt.z_dim = 256 81 | 82 | # additional arguments copied over from the previous TestOptions declarer 83 | opt.gpu_ids=[0] 84 | opt.results_dir='../data/SPADE_from_scratch_results/' 85 | opt.semantic_nc = 151 -------------------------------------------------------------------------------- /models/model_9.py: -------------------------------------------------------------------------------- 1 | # This is model 6 with VAE turned off.
2 | 3 | import sys; sys.path.append('../lib/SPADE-master/') 4 | from options.train_options import TrainOptions 5 | from models.pix2pix_model import Pix2PixModel 6 | from collections import OrderedDict 7 | import data 8 | from util.iter_counter import IterationCounter 9 | from util.visualizer import Visualizer 10 | from trainers.pix2pix_trainer import Pix2PixTrainer 11 | import os 12 | 13 | opt = TrainOptions() 14 | opt.D_steps_per_G = 1 15 | opt.aspect_ratio = 1.0 16 | opt.batchSize = 1 17 | opt.beta1 = 0.0 18 | opt.beta2 = 0.9 19 | opt.cache_filelist_read = False 20 | opt.cache_filelist_write = False 21 | opt.checkpoints_dir = '/spell/checkpoints/' 22 | opt.contain_dontcare_label = False 23 | opt.continue_train = False 24 | opt.crop_size = 512 25 | opt.dataroot = '/spell/bob_ross_segmented/' # data mount point 26 | opt.dataset_mode = 'custom' 27 | opt.debug = False 28 | opt.display_freq = 100 29 | opt.display_winsize = 512 30 | opt.fff = 1 # junk value for the argparse 31 | opt.gan_mode = 'hinge' 32 | opt.init_type = 'xavier' 33 | opt.init_variance = 0.02 34 | opt.isTrain = True 35 | opt.lambda_feat = 10.0 36 | opt.lambda_kld = 0.05 37 | opt.lambda_vgg = 10.0 38 | opt.load_from_opt_file = False 39 | opt.load_size = 512 40 | opt.lr = 0.0002 41 | opt.max_dataset_size = 9223372036854775807 42 | opt.model = 'pix2pix' 43 | opt.nThreads = 0 44 | opt.n_layers_D = 4 45 | opt.name = 'bob_ross' 46 | opt.ndf = 64 47 | opt.nef = 16 48 | opt.netD = 'multiscale' 49 | opt.netD_subarch = 'n_layer' 50 | opt.netG = 'spade' 51 | opt.ngf = 64 52 | opt.niter = 100 53 | opt.niter_decay = 0 54 | opt.no_TTUR = False 55 | opt.no_flip = False 56 | opt.no_ganFeat_loss = False 57 | opt.no_html = True 58 | opt.no_instance = True 59 | opt.no_pairing_check = False 60 | opt.no_vgg_loss = False 61 | opt.norm_D = 'spectralinstance' 62 | opt.norm_E = 'spectralinstance' 63 | opt.norm_G = 'spectralspadesyncbatch3x3' 64 | opt.num_D = 2 65 | opt.num_upsampling_layers = 'normal' 66 | opt.optimizer = 
'adam' 67 | opt.output_nc = 3 68 | opt.phase = 'train' 69 | opt.preprocess_mode = 'resize_and_crop' 70 | opt.print_freq = 100 71 | opt.save_epoch_freq = 10 72 | opt.save_latest_freq = 5000 73 | opt.serial_batches = False 74 | opt.tf_log = True 75 | opt.use_vae = False 76 | opt.which_epoch = 'latest' 77 | opt.z_dim = 256 78 | opt.gpu_ids=[0] 79 | opt.results_dir='/spell/bob_ross_segmented/results/' 80 | opt.semantic_nc = 9 81 | opt.label_nc = 9 82 | opt.label_dir = '/spell/bob_ross_segmented/training/labels/' 83 | opt.image_dir = '/spell/bob_ross_segmented/training/images/' 84 | opt.instance_dir = '' 85 | 86 | # Create the folder structure expected by the model checkpointing feature. 87 | if not os.path.exists('/spell/checkpoints/'): 88 | os.mkdir('/spell/checkpoints/') 89 | if not os.path.exists('/spell/checkpoints/bob_ross/'): 90 | os.mkdir('/spell/checkpoints/bob_ross/') 91 | 92 | model = Pix2PixModel(opt) 93 | model.train() 94 | 95 | def test_train(): 96 | # print options to help debugging 97 | # print(' '.join(sys.argv)) 98 | 99 | # load the dataset 100 | dataloader = data.create_dataloader(opt) 101 | 102 | # create trainer for our model 103 | trainer = Pix2PixTrainer(opt) 104 | 105 | # create tool for counting iterations 106 | iter_counter = IterationCounter(opt, len(dataloader)) 107 | 108 | # create tool for visualization 109 | visualizer = Visualizer(opt) 110 | 111 | for epoch in iter_counter.training_epochs(): 112 | iter_counter.record_epoch_start(epoch) 113 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 114 | iter_counter.record_one_iteration() 115 | 116 | # Training 117 | # train generator 118 | if i % opt.D_steps_per_G == 0: 119 | trainer.run_generator_one_step(data_i) 120 | 121 | # train discriminator 122 | trainer.run_discriminator_one_step(data_i) 123 | 124 | # Visualizations 125 | if iter_counter.needs_printing(): 126 | losses = trainer.get_latest_losses() 127 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 
128 | losses, iter_counter.time_per_iter) 129 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 130 | 131 | if iter_counter.needs_displaying(): 132 | visuals = OrderedDict([('input_label', data_i['label']), 133 | ('synthesized_image', trainer.get_latest_generated()), 134 | ('real_image', data_i['image'])]) 135 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 136 | 137 | if iter_counter.needs_saving(): 138 | print('saving the latest model (epoch %d, total_steps %d)' % 139 | (epoch, iter_counter.total_steps_so_far)) 140 | trainer.save('latest') 141 | iter_counter.record_current_iter() 142 | 143 | trainer.update_learning_rate(epoch) 144 | iter_counter.record_epoch_end() 145 | 146 | if epoch % opt.save_epoch_freq == 0 or \ 147 | epoch == iter_counter.total_epochs: 148 | print('saving the model at the end of epoch %d, iters %d' % 149 | (epoch, iter_counter.total_steps_so_far)) 150 | trainer.save('latest') 151 | trainer.save(epoch) 152 | 153 | print('Training was successfully finished.') 154 | 155 | test_train() 156 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.0.0 2 | torchvision 3 | dominate>=2.3.1 4 | dill 5 | scikit-image 6 | jupyterlab 7 | requests_futures --------------------------------------------------------------------------------
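The training scripts above assign attributes directly on a `TrainOptions()` instance instead of parsing command-line flags, so derived fields such as `semantic_nc` are set by hand (9 for the Bob Ross scripts, 151 for the ADE20K scripts). The sketch below shows how that value can be derived from the other options, assuming it follows the same rule as SPADE's `options/base_options.py`: one channel per label class, plus one for a "don't care" label, plus one for the instance edge map when `no_instance` is False. The helper name `expected_semantic_nc` is hypothetical and not part of SPADE.

```python
from types import SimpleNamespace


def expected_semantic_nc(opt):
    """Derive the number of label-map input channels (hypothetical helper).

    Assumed rule: label classes, +1 if a "don't care" label is present,
    +1 if instance edge maps are used (i.e. no_instance is False).
    """
    return (opt.label_nc
            + (1 if opt.contain_dontcare_label else 0)
            + (0 if opt.no_instance else 1))


# Values copied from the Bob Ross scripts (e.g. models/model_3.py).
bob_ross = SimpleNamespace(label_nc=9, contain_dontcare_label=False,
                           no_instance=True, semantic_nc=9)
# Values copied from the ADE20K scripts (models/model_7.py, models/model_8.py).
ade20k = SimpleNamespace(label_nc=150, contain_dontcare_label=True,
                         no_instance=True, semantic_nc=151)

for name, opt in [("bob_ross", bob_ross), ("ade20k", ade20k)]:
    derived = expected_semantic_nc(opt)
    status = "ok" if derived == opt.semantic_nc else "MISMATCH"
    print(f"{name}: semantic_nc={opt.semantic_nc}, derived={derived} -> {status}")
```

Running a check like this before `Pix2PixModel(opt)` is one way to catch a hand-set `semantic_nc` that has drifted out of sync with `label_nc`, `contain_dontcare_label`, or `no_instance` when copying a script to make a new variant.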