├── .gitignore ├── README.md ├── docs ├── .DS_Store ├── 404.html ├── Gemfile ├── _config.yml ├── about │ └── index.html ├── assets │ ├── css │ │ └── app.css │ ├── fig1.jpg │ ├── fig2.jpg │ ├── fig3.jpg │ └── js │ │ ├── app.js │ │ └── getImages.js ├── dataset.html ├── development │ └── 2018 │ │ └── 05 │ │ └── 28 │ │ └── welcome-to-jekyll.html ├── feed.xml ├── index.html ├── index2.html ├── robots.txt └── sitemap.xml ├── edgegan ├── __init__.py ├── models │ ├── __init__.py │ ├── classifier.py │ ├── discriminator.py │ ├── edgegan.py │ ├── encoder.py │ └── generator.py ├── nn │ ├── __init__.py │ ├── functional.py │ └── modules │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── conv.py │ │ ├── linear.py │ │ ├── normalization.py │ │ ├── pooling.py │ │ └── upsampling.py ├── test.py ├── train.py └── utils │ ├── __init__.py │ ├── data │ ├── __init__.py │ └── dataset.py │ └── utils.py ├── images └── dataset_example │ ├── test │ ├── 14809.png │ ├── 14810.png │ ├── 14811.png │ └── 14812.png │ └── train │ ├── 60975.png │ ├── 60981.png │ ├── 60987.png │ ├── 60991.png │ └── 60994.png └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | sample*/ 132 | checkpoint*/ 133 | logs/ 134 | .vscode/ 135 | checksum/ 136 | *.pkl 137 | backup/ 138 | data/ 139 | !edgegan/utils/data/ 140 | data 141 | outputs 142 | outputs/ 143 | localdata 144 | !images/* 145 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EdgeGAN 2 | ### [Project Page](https://sysu-imsl.com/EdgeGAN/) | [Paper](https://arxiv.org/abs/2003.02683) 3 | SketchyCOCO: Image Generation from Freehand Scene Sketches 4 | Chengying Gao, Qi Liu, Qi Xu, Limin Wang, Jianzhuang Liu, Changqing Zou 5 | 6 | # Installation 7 | Clone this repo: 8 | ``` 9 | git clone git@github.com:sysu-imsl/EdgeGAN.git 10 | cd EdgeGAN 11 | ``` 12 | This repo requires TensorFlow 1.14.0 and Python 3+. 13 | `conda create`/`conda activate` is suggested for managing multiple versions of TensorFlow. 14 | After switching to the proper conda environment, run `conda install --file requirements.txt`. 15 | 16 | # Dataset 17 | Our dataset can be found in [SketchyCOCO](https://github.com/sysu-imsl/SketchyCOCO). Follow the guide there and prepare the dataset. 18 | 19 | ## Directory Structure 20 | For a single-class dataset: 21 | ``` 22 | EdgeGAN 23 | └───data 24 | └───train 25 | | | <name>.png 26 | | | <name>.png 27 | | | ... 28 | | 29 | └───test 30 | | <name>.png 31 | | <name>.png 32 | | ... 33 | ``` 34 | For a multi-class dataset: 35 | 36 | ``` 37 | EdgeGAN 38 | └───data 39 | └───train 40 | | └───<0> 41 | | | | <name>.png 42 | | | | ... 43 | | └───<1> 44 | | | ... 45 | | 46 | └───test 47 | | └───<0> 48 | | | | <name>.png 49 | | | | ... 50 | | └───<1> 51 | | | ... 52 | ``` 53 | For our pretrained model, the class labels 0 to 13 correspond to "airplane, cat, giraffe, zebra, dog, elephant, fire hydrant, horse, bicycle, car, traffic light, cow, motorcycle, sheep". Please prepare the input similarly to the example training and test images in [images/dataset_example](https://github.com/sysu-imsl/EdgeGAN/tree/master/images/dataset_example). 54 | 55 | ## Example 56 | ### Train 57 | ![60975.png](images/dataset_example/train/60975.png?raw=true) 58 | ![60981.png](images/dataset_example/train/60981.png?raw=true) 59 | ![60987.png](images/dataset_example/train/60987.png?raw=true) 60 | ![60991.png](images/dataset_example/train/60991.png?raw=true) 61 | ![60994.png](images/dataset_example/train/60994.png?raw=true) 62 | ### Test 63 | ![14809.png](images/dataset_example/test/14809.png?raw=true) 64 | ![14810.png](images/dataset_example/test/14810.png?raw=true) 65 | ![14811.png](images/dataset_example/test/14811.png?raw=true) 66 | ![14812.png](images/dataset_example/test/14812.png?raw=true) 67 | 68 | # Testing 69 | 1. Download the pretrained model (trained with 14 classes) from [Google Drive](https://drive.google.com/file/d/1ilxx_mLKaiMRhwzzcrXjIaNlsmfqR6MT/view?usp=sharing), and run: 70 | ``` bash 71 | mkdir -p outputs/edgegan 72 | cd outputs/edgegan 73 | cp <path of the downloaded checkpoints.zip> .
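# e.g., assuming checkpoints.zip was saved to ~/Downloads (hypothetical path):
# cp ~/Downloads/checkpoints.zip .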
74 | unzip checkpoints.zip 75 | cd ../.. 76 | ``` 77 | 2. Generate images with the models: 78 | ``` bash 79 | python -m edgegan.test --name=edgegan --dataroot=<path of dataset> --dataset=<dataset name> --gpu=<gpu id> #(model trained with multiple classes) 80 | python -m edgegan.test --name=<model name> --dataroot=<path of dataset> --dataset=<dataset name> --nomulticlasses --gpu=<gpu id> #(model trained with a single class) 81 | ``` 82 | 3. The outputs will be located at `outputs/edgegan/test_output/` by default. 83 | 84 | # Training 85 | Training takes about fifteen hours on a single NVIDIA RTX 2080 Ti card. 86 | 87 | ``` bash 88 | python -m edgegan.train --name=<model name> --dataroot=<path of dataset> --dataset=<dataset name> --gpu=<gpu id> #(with multiple classes) 89 | python -m edgegan.train --name=<model name> --dataroot=<path of dataset> --dataset=<dataset name> --nomulticlasses --gpu=<gpu id> #(with a single class) 90 | 91 | ``` 92 | 93 | # Citation 94 | If you use this code for your research, please cite our paper. 95 | ``` 96 | @inproceedings{gao2020sketchycoco, 97 | title={SketchyCOCO: Image Generation From Freehand Scene Sketches}, 98 | author={Gao, Chengying and Liu, Qi and Xu, Qi and Wang, Limin and Liu, Jianzhuang and Zou, Changqing}, 99 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 100 | pages={5174--5183}, 101 | year={2020} 102 | } 103 | ``` 104 | -------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/docs/.DS_Store -------------------------------------------------------------------------------- /docs/404.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | - SketchyCOCO 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | SketchyCOCO | Write an awesome description for your new site here. You can edit this line in _config.yml. It will appear in your document head meta (for Google search results) and in your feed.xml site description. 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 28 | 29 | 30 | 31 | 32 | 33 | 64 | 65 | 66 |
67 |
68 |
69 |

70 |

71 | 72 |
73 |
74 |
75 | 76 | 77 | 78 | 79 |
80 |
81 |
82 | 83 |
84 | 85 | 86 | 87 | 100 | 101 |
102 |

404

103 | 104 |

Page not found :(

105 |

The requested page could not be found.

106 |
107 | 108 |
109 | 110 |
111 |

Latest Posts

112 | 113 |
114 | 115 |
116 |
117 | 118 |
119 | Why use a static site generator 120 |
121 | 122 |
123 |
124 | 125 | Why use a static site generator 126 | 127 |

You’ll find this post in your _posts directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run jekyll serve, which launches a web server and auto-regenerates your site when a file is updated. Jekyll requires blog post files to be named according to the following format:

128 | 129 |

130 |
131 |
132 | Read more 133 |
134 |
135 |
136 | 137 |
138 |
139 |
140 | 141 |
142 | 143 | 144 | 145 | 146 |
147 | 148 |
149 |
150 |
151 | 152 |
153 |
154 | 155 | 156 | 157 |
158 |

Theme built by C.S. Rhymes

159 |
160 |
161 |
162 | 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /docs/Gemfile: -------------------------------------------------------------------------------- 1 | source 'https://rubygems.org' 2 | gem "bulma-clean-theme", '0.7.2' 3 | gem 'github-pages', group: :jekyll_plugins -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /docs/about/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | About - SketchyCOCO 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | About | SketchyCOCO 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 28 | 29 | 30 | 31 | 32 | 33 | 64 | 65 | 66 |
67 |
68 |
69 |

About

70 |

71 | 72 |
73 |
74 |
75 | 76 | 77 | 78 | 79 |
80 |
81 |
82 | 83 |
84 | 85 | 86 | 87 | 88 |
89 |

This is the base Jekyll theme. You can find out more info about customizing your Jekyll theme, as well as basic Jekyll usage documentation at jekyllrb.com

90 | 91 |

You can find the source code for Minima at GitHub: 92 | jekyll / 93 | minima

94 | 95 |

You can find the source code for Jekyll at GitHub: 96 | jekyll / 97 | jekyll

98 | 99 | 100 |
101 |
102 | 103 |
104 |

Latest Posts

105 | 106 |
107 | 108 |
109 |
110 | 111 |
112 | Why use a static site generator 113 |
114 | 115 |
116 |
117 | 118 | Why use a static site generator 119 | 120 |

You’ll find this post in your _posts directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run jekyll serve, which launches a web server and auto-regenerates your site when a file is updated. Jekyll requires blog post files to be named according to the following format:

121 | 122 |

123 |
124 |
125 | Read more 126 |
127 |
128 |
129 | 130 |
131 |
132 |
133 | 134 |
135 | 136 | 137 | 138 | 139 |
140 | 141 |
142 |
143 |
144 | 145 |
146 |
147 | 148 | 149 | 150 |
151 |

Theme built by C.S. Rhymes

152 |
153 |
154 |
155 | 156 | 157 | 158 | 159 | 160 | -------------------------------------------------------------------------------- /docs/assets/fig1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/docs/assets/fig1.jpg -------------------------------------------------------------------------------- /docs/assets/fig2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/docs/assets/fig2.jpg -------------------------------------------------------------------------------- /docs/assets/fig3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/docs/assets/fig3.jpg -------------------------------------------------------------------------------- /docs/assets/js/app.js: -------------------------------------------------------------------------------- 1 | document.addEventListener('DOMContentLoaded', () => { 2 | 3 | // Get all "navbar-burger" elements 4 | const $navbarBurgers = Array.prototype.slice.call(document.querySelectorAll('.navbar-burger'), 0); 5 | 6 | // Check if there are any navbar burgers 7 | if ($navbarBurgers.length > 0) { 8 | 9 | // Add a click event on each of them 10 | $navbarBurgers.forEach( el => { 11 | el.addEventListener('click', () => { 12 | 13 | // Get the target from the "data-target" attribute 14 | const target = el.dataset.target; 15 | const $target = document.getElementById(target); 16 | 17 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 18 | el.classList.toggle('is-active'); 19 | $target.classList.toggle('is-active'); 20 | 21 | }); 22 | }); 23 | } 24 | 25 | }); 26 | -------------------------------------------------------------------------------- /docs/assets/js/getImages.js: -------------------------------------------------------------------------------- 1 | function onClick(arg) { 2 | var data_ = arg.getAttribute('data'); 3 | var data_type = arg.getAttribute('data-type'); 4 | var output = document.getElementById("output"); 5 | var sketch_path = ''; 6 | var gt_path = ''; 7 | var json_path = ''; 8 | var result_array = []; 9 | if (data_ == "Objects") { 10 | sketch_path = 'Object/Sketch/'; 11 | gt_path = 'Object/GT/'; 12 | json_path = 'https://cdn.jsdelivr.net/gh/sysu-imsl/CDN-for-SketchyCOCO@v1.0/Object.json' 13 | json_path = 'https://cdn.jsdelivr.net/gh/sysu-imsl/CDN-for-SketchyCOCO-added@1.5/Object_total.json' 14 | } else { 15 | sketch_path = 'Scene/Sketch/'; 16 | gt_path = 'Scene/GT/'; 17 | json_path = 'https://cdn.jsdelivr.net/gh/sysu-imsl/CDN-for-SketchyCOCO@v1.0/Scene.json' 18 | } 19 | $.ajax({ 20 | type:'get', 21 | url:json_path, 22 | dataType:'json', 23 | success:function(data){ 24 | console.log(data_type) 25 | if (data_ == "Objects") { 26 | if (data_type.toString() in data['added']) { 27 | var tmp = data['added'][data_type.toString()]; 28 | var tmp1 = data['origin'][data_type.toString()]; 29 | } else { 30 | var tmp = data['origin'][data_type.toString()]; 31 | var tmp1 = 0; 32 | } 33 | 34 | } else { 35 | var tmp = data.data; 36 | var tmp1 = 0; 37 | } 38 | // console.log(tmp) 39 | var counter = 0; 40 | var index = 0; 41 | var names = new Array() 42 | if (tmp1 != 0) { 43 | var pre_head = "https://cdn.jsdelivr.net/gh/sysu-imsl/CDN-for-SketchyCOCO-added@1.5/"; 44 | } else { 
45 | var pre_head = "https://cdn.jsdelivr.net/gh/sysu-imsl/CDN-for-SketchyCOCO@v1.0/data/"; 46 | } 47 | for (var n = 0; n\n \n \n " + "\"\"" + "\n "); 53 | result_array.push(" \n " + "\"\"" + "\n "); 54 | } else { 55 | result_array.push(" \n " + "\"\"" + "\n "); 56 | result_array.push(" \n " + "\"\"" + "\n "); 57 | if (counter != 0) { 58 | // result_array.push(" \n sketch(" + names[0] + ")\n ground-truth(" + names[0] + ") sketch("+(names[1])+")\n ground-truth("+(names[1])+")"); 59 | result_array.push(" \n sketch\n ground truth sketch\n ground-truth"); 60 | // result_array.push(" \n " + names[0] + " "+(names[1])+""); 61 | index = index + 2; 62 | } 63 | } 64 | counter++; 65 | } 66 | 67 | if (tmp1 != 0) { 68 | var pre_head = "https://cdn.jsdelivr.net/gh/sysu-imsl/CDN-for-SketchyCOCO@v1.0/data/"; 69 | for (var n = 0; n\n \n \n " + "\"\"" + "\n "); 75 | result_array.push(" \n " + "\"\"" + "\n "); 76 | } else { 77 | result_array.push(" \n " + "\"\"" + "\n "); 78 | result_array.push(" \n " + "\"\"" + "\n "); 79 | if (counter != 0) { 80 | // result_array.push(" \n sketch(" + names[0] + ")\n ground-truth(" + names[0] + ") sketch("+(names[1])+")\n ground-truth("+(names[1])+")"); 81 | result_array.push(" \n sketch\n ground truth sketch\n ground-truth"); 82 | // result_array.push(" \n " + names[0] + " "+(names[1])+""); 83 | index = index + 2; 84 | } 85 | } 86 | counter++; 87 | } 88 | } 89 | if (counter % 2 != 0) { 90 | result_array.push(" \n sketch\n ground truth "); 91 | } 92 | console.log(counter) 93 | output.innerHTML = " \n \n" + result_array.join("\n") + " \n
"; 94 | } 95 | }) 96 | 97 | } 98 | 99 | 100 | $(document).ready(function () { 101 | // document.getElementById("first").click(); 102 | $('#start').trigger("onclick") 103 | }); 104 | 105 | // ele_1009 106 | // ele_67 107 | // ele_1030 108 | // ele_1179 109 | // ele_1080 110 | // ele_1052 111 | // ele_26 112 | // ele_1174 113 | // ele_1029 114 | // ele_758 115 | // ele_62 -------------------------------------------------------------------------------- /docs/dataset.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | SketchyCOCO 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | SketchyCOCO 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 63 | 64 | 65 |
66 | 73 |
74 | 75 | 76 | 77 | 78 |
79 |
80 |
81 | 82 |
83 | 84 | 85 | 136 |
137 | 138 |
139 | 140 |
141 | 142 | 143 |
144 |
145 | 146 |
147 |
148 |
149 | 150 |
151 |
152 | 153 | 154 | 155 |
156 |

Theme built by C.S. Rhymes

157 |
158 |
159 |
160 | 161 | 162 | 163 | 164 | 165 | -------------------------------------------------------------------------------- /docs/development/2018/05/28/welcome-to-jekyll.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Why use a static site generator - SketchyCOCO 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | Why use a static site generator | SketchyCOCO 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 32 | 33 | 34 | 35 | 36 | 37 | 68 | 69 | 70 |
71 |
72 |
73 |

Why use a static site generator

74 |

75 | 76 |
77 |
78 |
79 | 80 | 81 | 82 | 83 |
84 |
85 |
86 | 87 |
88 | 89 | 90 | 91 |
92 | 93 |

Published: May 28, 2018 by C.S. Rhymes

94 | 95 |

You’ll find this post in your _posts directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run jekyll serve, which launches a web server and auto-regenerates your site when a file is updated. Jekyll requires blog post files to be named according to the following format:

96 | 97 |

YEAR-MONTH-DAY-title.MARKUP

98 | 99 |

Where YEAR is a four-digit number, MONTH and DAY are both two-digit numbers, and MARKUP is the file extension representing the format used in the file. After that, include the necessary front matter. Take a look at the source for this post to get an idea about how it works.

100 | 101 |
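For example, the post on this site, published May 28, 2018, would live in a source file named 2018-05-28-welcome-to-jekyll.md (assuming Markdown as the markup format).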

Jekyll also offers powerful support for code snippets:

102 | 103 |
def print_hi(name)
104 |   puts "Hi, #{name}"
105 | end
106 | print_hi('Tom')
107 | #=> prints 'Hi, Tom' to STDOUT.
108 | 109 |

Check out the Jekyll docs for more info on how to get the most out of Jekyll. File all bugs/feature requests at Jekyll’s GitHub repo. If you have questions, you can ask them on Jekyll Talk.

110 | 111 | 112 |
113 | 114 |
115 | 116 |
117 | 118 | 119 |

Share

120 | 138 | 139 | 140 | 141 | 142 |
143 | 144 |
145 |

Latest Posts

146 | 147 |
148 | 149 |
150 |
151 | 152 |
153 | Why use a static site generator 154 |
155 | 156 |
157 |
158 | 159 | Why use a static site generator 160 | 161 |

You’ll find this post in your _posts directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run jekyll serve, which launches a web server and auto-regenerates your site when a file is updated. Jekyll requires blog post files to be named according to the following format:

162 | 163 |

164 |
165 |
166 | Read more 167 |
168 |
169 |
170 | 171 |
172 |
173 |
174 | 175 |
176 | 177 | 178 | 179 | 180 |
181 | 182 |
183 |
184 |
185 | 186 |
187 |
188 | 189 | 190 | 191 |
192 |

Theme built by C.S. Rhymes

193 |
194 |
195 |
196 | 197 | 198 | 199 | 200 | 201 | -------------------------------------------------------------------------------- /docs/feed.xml: -------------------------------------------------------------------------------- 1 | Jekyll2020-06-11T20:21:45+08:00http://localhost:4000/test_for_building_pages/feed.xmlSketchyCOCOWrite an awesome description for your new site here. You can edit this line in _config.yml. It will appear in your document head meta (for Google search results) and in your feed.xml site description.Why use a static site generator2018-05-28T18:50:07+08:002018-05-28T18:50:07+08:00http://localhost:4000/test_for_building_pages/development/2018/05/28/welcome-to-jekyll<p>You’ll find this post in your <code class="language-plaintext highlighter-rouge">_posts</code> directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run <code class="language-plaintext highlighter-rouge">jekyll serve</code>, which launches a web server and auto-regenerates your site when a file is updated.Jekyll requires blog post files to be named according to the following format: be named according to the following format: be named according to the following format: be named according to the following format: be named according to the following format: be named according to the following format: be named according to the following format: be named according to the following format: be named according to the following format:</p> 2 | 3 | <p><code class="language-plaintext highlighter-rouge">YEAR-MONTH-DAY-title.MARKUP</code></p> 4 | 5 | <p>Where <code class="language-plaintext highlighter-rouge">YEAR</code> is a four-digit number, <code class="language-plaintext highlighter-rouge">MONTH</code> and <code class="language-plaintext highlighter-rouge">DAY</code> are both two-digit numbers, and <code class="language-plaintext highlighter-rouge">MARKUP</code> is the file extension representing the format used in the file. After that, include the necessary front matter. Take a look at the source for this post to get an idea about how it works.</p> 6 | 7 | <p>Jekyll also offers powerful support for code snippets:</p> 8 | 9 | <figure class="highlight"><pre><code class="language-ruby" data-lang="ruby"><span class="k">def</span> <span class="nf">print_hi</span><span class="p">(</span><span class="nb">name</span><span class="p">)</span> 10 | <span class="nb">puts</span> <span class="s2">"Hi, </span><span class="si">#{</span><span class="nb">name</span><span class="si">}</span><span class="s2">"</span> 11 | <span class="k">end</span> 12 | <span class="n">print_hi</span><span class="p">(</span><span class="s1">'Tom'</span><span class="p">)</span> 13 | <span class="c1">#=&gt; prints 'Hi, Tom' to STDOUT.</span></code></pre></figure> 14 | 15 | <p>Check out the <a href="https://jekyllrb.com/docs/home">Jekyll docs</a> for more info on how to get the most out of Jekyll. File all bugs/feature requests at <a href="https://github.com/jekyll/jekyll">Jekyll’s GitHub repo</a>. If you have questions, you can ask them on <a href="https://talk.jekyllrb.com/">Jekyll Talk</a>.</p>C.S. RhymesYou’ll find this post in your _posts directory. Go ahead and edit it and re-build the site to see your changes. 
You can rebuild the site in many different ways, but the most common way is to run jekyll serve, which launches a web server and auto-regenerates your site when a file is updated. Jekyll requires blog post files to be named according to the following format: -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | SketchyCOCO: Image Generation from Freehand Scene Sketches - SketchyCOCO 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | SketchyCOCO: Image Generation from Freehand Scene Sketches | SketchyCOCO 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 28 | 29 | 30 | 31 | 32 | 33 | 64 | 65 | 66 |
67 |
68 |
69 |

SketchyCOCO: Image Generation from Freehand Scene Sketches

70 |

71 | 72 |
73 |
74 |
75 | 76 | 77 | 78 | 79 |
80 |
81 |
82 | 83 |
84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 |
92 | 93 | 94 | 95 |
96 | 97 |
98 |
99 |
100 | 101 |
102 | 103 | 104 |
105 |
106 | 107 |

SketchyCOCO: Image Generation from Freehand Scene Sketches

108 | 109 |

We introduce the first method for automatic image generation from scene-level freehand sketches. Our model allows for controllable image generation by specifying the synthesis goal via freehand sketches. The key contribution is an attribute vector bridged Generative Adversarial Network called EdgeGAN, which supports high visual-quality object-level image content generation without using freehand sketches as training data. We have built a large-scale composite dataset called SketchyCOCO to support and evaluate the solution. We validate our approach on the tasks of both object-level and scene-level image generation on SketchyCOCO. Through quantitative and qualitative results, human evaluation, and ablation studies, we demonstrate the method’s capacity to generate realistic complex scene-level images from various freehand sketches.

110 | 111 |

112 |
113 |
114 | Github 115 | Paper Download 116 |
117 | 118 |
119 |
120 | 121 |
122 |
123 |
124 | 125 |
126 |
127 |
128 | 129 |
130 | 131 | 132 |
133 |
134 | 135 |

SketchyCOCO Dataset

136 | 137 |

In summary, we collect 20,198 (18,869 + 1,329) triplets of {foreground sketch, foreground image, foreground edge map} examples covering 14 classes, 27,683 (22,171 + 5,512) pairs of {background sketch, background image} examples covering 3 classes, 14,081 (11,265 + 2,816) pairs of {foreground image & background sketch, scene image} examples, 14,081 (11,265 + 2,816) pairs of {scene sketch, scene image} examples, and the segmentation ground truth for 14,081 (11,265 + 2,816) scene sketches.

138 | 139 |

140 |
141 |
142 | Github 143 | Data Exploration 144 |
145 | 146 |
147 |
148 | 149 |
150 |
151 |
152 | 153 | 154 | 155 | 156 |
157 |
158 | 164 |
165 |
166 |
167 | 168 | 169 | 170 |
171 | 172 |
173 |
174 | 175 | 176 | 177 |
178 |

Theme built by C.S. Rhymes

179 |
180 |
181 |
182 | 183 | 184 | 185 | 186 | 187 | -------------------------------------------------------------------------------- /docs/index2.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | SketchyCOCO - SketchyCOCO 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | SketchyCOCO | Write an awesome description for your new site here. You can edit this line in _config.yml. It will appear in your document head meta (for Google search results) and in your feed.xml site description. 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 28 | 29 | 30 | 31 | 32 | 33 | 64 | 65 | 66 |
67 |
68 |
69 |

SketchyCOCO

70 |

Page Subtitle

71 | 72 |
73 |
74 |
75 | 76 | 77 | 78 | 79 |
80 |
81 |
82 | 83 |
84 | 85 | 86 | 87 | 88 | 89 | 90 |
91 |
92 | 103 | 104 |
105 | 106 |
107 | 118 | 119 |
120 |
121 | 127 |
128 |
129 |
130 | 131 |
132 |

Latest Posts

133 | 134 |
135 | 136 |
137 |
138 | 139 |
140 | Why use a static site generator 141 |
142 | 143 |
144 |
145 | 146 | Why use a static site generator 147 | 148 |

You’ll find this post in your _posts directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run jekyll serve, which launches a web server and auto-regenerates your site when a file is updated. Jekyll requires blog post files to be named according to the following format:

149 | 150 |

151 |
152 |
153 | Read more 154 |
155 |
156 |
157 | 158 |
159 |
160 |
161 | 162 |
163 | 164 | 165 | 166 | 167 |
168 | 169 |
170 |
171 |
172 | 173 |
174 |
175 | 176 | 177 | 178 |
179 |

Theme built by C.S. Rhymes

180 |
181 |
182 |
183 | 184 | 185 | 186 | 187 | 188 | -------------------------------------------------------------------------------- /docs/robots.txt: -------------------------------------------------------------------------------- 1 | Sitemap: http://localhost:4000/test_for_building_pages/sitemap.xml 2 | -------------------------------------------------------------------------------- /docs/sitemap.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | http://localhost:4000/test_for_building_pages/development/2018/05/28/welcome-to-jekyll.html 5 | 2018-05-28T18:50:07+08:00 6 | 7 | 8 | http://localhost:4000/test_for_building_pages/about/ 9 | 10 | 11 | http://localhost:4000/test_for_building_pages/dataset.html 12 | 13 | 14 | http://localhost:4000/test_for_building_pages/ 15 | 16 | 17 | http://localhost:4000/test_for_building_pages/index2.html 18 | 19 | 20 | -------------------------------------------------------------------------------- /edgegan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/edgegan/__init__.py -------------------------------------------------------------------------------- /edgegan/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator import Generator 2 | from .discriminator import Discriminator 3 | from .encoder import Encoder 4 | from .classifier import Classifier 5 | from .edgegan import EdgeGAN 6 | -------------------------------------------------------------------------------- /edgegan/models/classifier.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from edgegan import nn 4 | 5 | 6 | class Classifier(object): 7 | def __init__(self, name, SPECTRAL_NORM_UPDATE_OPS): 8 | print(' [*] Init Discriminator %s', name) 9 | self.name = name 10 | self.SPECTRAL_NORM_UPDATE_OPS = SPECTRAL_NORM_UPDATE_OPS 11 | 12 | def __call__(self, x, num_classes, labels=None, reuse=False, data_format='NCHW'): 13 | assert data_format == 'NCHW' 14 | size = 64 15 | num_blocks = 1 16 | resize_func = tf.image.resize_bilinear 17 | sn = True 18 | 19 | if data_format == 'NCHW': 20 | channel_axis = 1 21 | else: 22 | channel_axis = 3 23 | if type(x) is list: 24 | x = x[-1] 25 | 26 | if data_format == 'NCHW': 27 | x_list = [] 28 | resized_ = x 29 | x_list.append(resized_) 30 | 31 | for i in range(5): 32 | resized_ = nn.mean_pool( 33 | resized_, data_format=data_format) 34 | x_list.append(resized_) 35 | x_list = x_list[::-1] 36 | else: 37 | raise NotImplementedError 38 | 39 | output_dim = 1 40 | 41 | activation_fn_d = nn.prelu 42 | normalizer_fn_d = None 43 | normalizer_params_d = None 44 | weight_initializer = tf.random_normal_initializer(0, 0.02) 45 | 46 | with tf.variable_scope(self.name) as scope: 47 | if reuse: 48 | scope.reuse_variables() 49 | 50 | h0 = nn.conv2d2(x_list[-1], 8, kernel_size=7, sn=sn, stride=1, data_format=data_format, 51 | activation_fn=activation_fn_d, 52 | normalizer_fn=normalizer_fn_d, 53 | normalizer_params=normalizer_params_d, 54 | weights_initializer=weight_initializer) 55 | 56 | # Initial memory state 57 | hidden_state_shape = h0.get_shape().as_list() 58 | batch_size = hidden_state_shape[0] 59 | hidden_state_shape[0] = 1 60 | hts_0 = [h0] 61 | for i in range(1, num_blocks): 62 | h0 = tf.tile(tf.get_variable("initial_hidden_state_%d" % i, shape=hidden_state_shape, 
dtype=tf.float32, 63 | initializer=tf.zeros_initializer()), [batch_size, 1, 1, 1]) 64 | hts_0.append(h0) 65 | 66 | hts_1 = nn.mru_conv(x_list[-1], hts_0, 67 | size * 2, sn=sn, stride=2, dilate_rate=1, 68 | data_format=data_format, num_blocks=num_blocks, 69 | last_unit=False, 70 | activation_fn=activation_fn_d, 71 | normalizer_fn=normalizer_fn_d, 72 | normalizer_params=normalizer_params_d, 73 | weights_initializer=weight_initializer, 74 | unit_num=1) 75 | hts_2 = nn.mru_conv(x_list[-2], hts_1, 76 | size * 4, sn=sn, stride=2, dilate_rate=1, 77 | data_format=data_format, num_blocks=num_blocks, 78 | last_unit=False, 79 | activation_fn=activation_fn_d, 80 | normalizer_fn=normalizer_fn_d, 81 | normalizer_params=normalizer_params_d, 82 | weights_initializer=weight_initializer, 83 | unit_num=2) 84 | hts_3 = nn.mru_conv(x_list[-3], hts_2, 85 | size * 8, sn=sn, stride=2, dilate_rate=1, 86 | data_format=data_format, num_blocks=num_blocks, 87 | last_unit=False, 88 | activation_fn=activation_fn_d, 89 | normalizer_fn=normalizer_fn_d, 90 | normalizer_params=normalizer_params_d, 91 | weights_initializer=weight_initializer, 92 | unit_num=3) 93 | hts_4 = nn.mru_conv(x_list[-4], hts_3, 94 | size * 12, sn=sn, stride=2, dilate_rate=1, 95 | data_format=data_format, num_blocks=num_blocks, 96 | last_unit=True, 97 | activation_fn=activation_fn_d, 98 | normalizer_fn=normalizer_fn_d, 99 | normalizer_params=normalizer_params_d, 100 | weights_initializer=weight_initializer, 101 | unit_num=4) 102 | 103 | img = hts_4[-1] 104 | img_shape = img.get_shape().as_list() 105 | 106 | # discriminator end 107 | disc = nn.conv2d2(img, output_dim, kernel_size=1, sn=sn, stride=1, data_format=data_format, 108 | activation_fn=None, normalizer_fn=None, 109 | weights_initializer=weight_initializer) 110 | 111 | # classification end 112 | img = tf.reduce_mean(img, axis=( 113 | 2, 3) if data_format == 'NCHW' else (1, 2)) 114 | logits = nn.fully_connected(img, num_classes, sn=sn, activation_fn=None, 115 | normalizer_fn=None, SPECTRAL_NORM_UPDATE_OPS=self.SPECTRAL_NORM_UPDATE_OPS) 116 | 117 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 118 | self.name) 119 | return disc, tf.nn.sigmoid(logits), logits 120 | -------------------------------------------------------------------------------- /edgegan/models/discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from edgegan import nn 3 | 4 | 5 | class Discriminator(object): 6 | def __init__(self, name, is_train, norm='batch', activation='lrelu', 7 | num_filters=64, use_resnet=False): 8 | print(' [*] Init Discriminator %s', name) 9 | self._num_filters = num_filters 10 | self.name = name 11 | self._is_train = is_train 12 | self._norm = norm 13 | self._activation = activation 14 | self._use_resnet = use_resnet 15 | self._reuse = False 16 | 17 | def __call__(self, input, reuse=False): 18 | if self._use_resnet: 19 | return self._resnet(input) 20 | else: 21 | return self._convnet(input) 22 | 23 | def _resnet(self, input): 24 | # return None 25 | with tf.variable_scope(self.name, reuse=self._reuse): 26 | D = nn.residual2(input, self._num_filters, 'd_resnet_0', 3, 1, 27 | self._is_train, self._reuse, norm=None, 28 | activation=self._activation) 29 | D = tf.nn.avg_pool(D, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME') 30 | 31 | D = nn.residual2(D, self._num_filters*2, 'd_resnet_1', 3, 1, 32 | self._is_train, self._reuse, self._norm, 33 | self._activation) 34 | D = tf.nn.avg_pool(D, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME') 
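# Each residual block below doubles the filter count (64 -> 128 -> 256 -> 512), and the 2x2 average pooling after each block halves the spatial resolution.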
35 | 36 | D = nn.residual2(D, self._num_filters*4, 'd_resnet_3', 3, 1, 37 | self._is_train, self._reuse, self._norm, 38 | self._activation) 39 | D = tf.nn.avg_pool(D, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME') 40 | 41 | D = nn.residual2(D, self._num_filters*8, 'd_resnet_4', 3, 1, 42 | self._is_train, self._reuse, self._norm, 43 | self._activation) 44 | D = tf.nn.avg_pool(D, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME') 45 | 46 | D = nn.activation_fn(D, self._activation) 47 | D = tf.nn.avg_pool(D, [1, 8, 8, 1], [1, 8, 8, 1], 'SAME') 48 | 49 | D = nn.linear(tf.reshape(D, [input.get_shape()[0], -1]), 1, 50 | name='d_linear_resnet_5') 51 | 52 | self._reuse = True 53 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 54 | self.name) 55 | 56 | return tf.nn.sigmoid(D), D 57 | 58 | def _convnet(self, input): 59 | 60 | with tf.variable_scope(self.name, reuse=self._reuse): 61 | D = nn.conv_block(input, self._num_filters, 'd_conv_0', 4, 2, 62 | self._is_train, self._reuse, norm=None, 63 | activation=self._activation) 64 | D = nn.conv_block(D, self._num_filters*2, 'd_conv_1', 4, 2, 65 | self._is_train, self._reuse, self._norm, 66 | self._activation) 67 | D = nn.conv_block(D, self._num_filters*4, 'd_conv_3', 4, 2, 68 | self._is_train, self._reuse, self._norm, 69 | self._activation) 70 | D = nn.conv_block(D, self._num_filters*8, 'd_conv_4', 4, 2, 71 | self._is_train, self._reuse, self._norm, 72 | self._activation) 73 | 74 | D = nn.linear(tf.reshape(D, [input.get_shape()[0], -1]), 1, 75 | name='d_linear_5') 76 | 77 | self._reuse = True 78 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 79 | self.name) 80 | 81 | return tf.nn.sigmoid(D), D 82 | -------------------------------------------------------------------------------- /edgegan/models/edgegan.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | import math 4 | import os 5 | import pickle 6 | import sys 7 | import time 8 | from glob import glob 9 | from functools import partial 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | import edgegan.nn.functional as F 15 | from edgegan import nn, utils 16 | 17 | from .classifier import Classifier 18 | from .discriminator import Discriminator 19 | from .encoder import Encoder 20 | from .generator import Generator 21 | 22 | 23 | def pathsplit(path): 24 | path = os.path.normpath(path) 25 | return path.split(os.sep) 26 | 27 | 28 | def channel_first(input): 29 | return tf.transpose(input, [0, 3, 1, 2]) 30 | 31 | 32 | def random_blend(a, b, batchsize): 33 | alpha_dist = tf.contrib.distributions.Uniform(low=0., high=1.) 34 | alpha = alpha_dist.sample((batchsize, 1, 1, 1)) 35 | return b + alpha * (a - b) 36 | 37 | 38 | def penalty(synthesized, real, nn_func, batchsize, weight): 39 | assert callable(nn_func) 40 | interpolated = random_blend(synthesized, real, batchsize) 41 | inte_logit = nn_func(interpolated, reuse=True) 42 | return weight * F.gradient_penalty(inte_logit, interpolated) 43 | 44 | 45 | class EdgeGAN(object): 46 | def __init__(self, sess, config, dataset, 47 | z_dim=100, gf_dim=64, df_dim=64, 48 | gfc_dim=1024, dfc_dim=1024, c_dim=3): 49 | """ 50 | 51 | Args: 52 | sess: TensorFlow session 53 | batch_size: The size of batch. Should be specified before training. 54 | z_dim: (optional) Dimension of dim for Z. [100] 55 | gf_dim: (optional) Dimension of gen filters in first conv layer. [64] 56 | df_dim: (optional) Dimension of discrim filters in first conv layer. 
[64] 57 | gfc_dim: (optional) Dimension of gen units for for fully connected layer. [1024] 58 | dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024] 59 | c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3] 60 | """ 61 | self.sess = sess 62 | self.config = config 63 | 64 | self.z_dim = z_dim 65 | 66 | self.gf_dim = gf_dim 67 | self.df_dim = df_dim 68 | 69 | self.gfc_dim = gfc_dim 70 | self.dfc_dim = dfc_dim 71 | 72 | self.c_dim = c_dim 73 | self.optimizers = [] 74 | self.dataset = dataset 75 | 76 | def register_optim_if(self, name, optims, cond=True, repeat=1): 77 | if not cond: 78 | return 79 | optims = optims if isinstance(optims, list) else [optims, ] 80 | optims = [item() for item in optims] 81 | self.optimizers.append({ 82 | 'name': name, 83 | 'optims': optims, 84 | 'repeat': repeat, 85 | }) 86 | 87 | def construct_optimizers(self): 88 | class Empty: 89 | pass 90 | 91 | def var_list(module_name): 92 | try: 93 | return getattr(self, module_name).var_list 94 | except: 95 | return Empty() 96 | 97 | def loss(loss_name): 98 | try: 99 | return getattr(self, loss_name) 100 | except: 101 | return Empty() 102 | 103 | def optim_creator(loss_name, module_name): 104 | return partial( 105 | tf.train.RMSPropOptimizer(self.config.learning_rate).minimize, 106 | loss=loss(loss_name), var_list=var_list(module_name) 107 | ) 108 | 109 | self.register_optim_if('d_optim', optim_creator( 110 | 'joint_dis_dloss', 'joint_discriminator')) 111 | self.register_optim_if('d_optim_patch2', optim_creator( 112 | 'image_dis_dloss', 'image_discriminator'), self.config.use_image_discriminator) 113 | self.register_optim_if('d_optim_patch3', optim_creator( 114 | 'edge_dis_dloss', 'edge_discriminator'), self.config.use_edge_discriminator) 115 | self.register_optim_if('d_optim2', optim_creator( 116 | 'loss_d_ac', 'classifier'), self.config.multiclasses) 117 | edge_goptim = tf.train.RMSPropOptimizer(self.config.learning_rate).minimize( 118 | self.edge_gloss, var_list=self.edge_generator.var_list) 119 | image_goptim = tf.train.RMSPropOptimizer(self.config.learning_rate).minimize( 120 | self.image_gloss, var_list=self.image_generator.var_list) 121 | g_optim = [edge_goptim, image_goptim] 122 | self.register_optim_if('g_optim_u', lambda: g_optim) 123 | self.register_optim_if('e_optim', optim_creator('zl_loss', 'encoder')) 124 | self.register_optim_if('g_optim_b', lambda: g_optim) 125 | 126 | def update_model(self, images, z): 127 | for optim_param in self.optimizers: 128 | optims, repeat = optim_param['optims'], optim_param['repeat'] 129 | for _ in range(repeat): 130 | _ = self.sess.run(optims, {self.inputs: images, self.z: z}) 131 | 132 | def build_networks(self): 133 | self.edge_generator = Generator('G1', is_train=True, 134 | norm=self.config.G_norm, 135 | batch_size=self.config.batch_size, 136 | output_height=self.config.output_height, 137 | output_width=int( 138 | self.config.output_width/2), 139 | input_dim=self.gf_dim, 140 | output_dim=self.c_dim, 141 | use_resnet=self.config.if_resnet_g) 142 | self.image_generator = Generator('G2', is_train=True, 143 | norm=self.config.G_norm, 144 | batch_size=self.config.batch_size, 145 | output_height=self.config.output_height, 146 | output_width=int( 147 | self.config.output_width/2), 148 | input_dim=self.gf_dim, 149 | output_dim=self.c_dim, 150 | use_resnet=self.config.if_resnet_g) 151 | 152 | if self.config.multiclasses: 153 | self.classifier = Classifier( 154 | 'D2', self.config.SPECTRAL_NORM_UPDATE_OPS) 155 | 156 | 
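# The joint discriminator 'D' sees the full concatenated [edge | image] input; the optional 'D_patch2' and 'D_patch3' discriminators below see only the resized image half and edge half, respectively.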
self.joint_discriminator = Discriminator('D', is_train=True, 157 | norm=self.config.D_norm, 158 | num_filters=self.df_dim, 159 | use_resnet=self.config.if_resnet_d) 160 | 161 | if self.config.use_image_discriminator is True: 162 | self.image_discriminator = Discriminator('D_patch2', is_train=True, 163 | norm=self.config.D_norm, 164 | num_filters=self.df_dim, 165 | use_resnet=self.config.if_resnet_d) 166 | 167 | if self.config.use_edge_discriminator is True: 168 | self.edge_discriminator = Discriminator('D_patch3', is_train=True, 169 | norm=self.config.D_norm, 170 | num_filters=self.df_dim, 171 | use_resnet=self.config.if_resnet_d) 172 | 173 | self.encoder = Encoder('E', is_train=True, 174 | norm=self.config.E_norm, 175 | image_size=self.config.input_height, 176 | latent_dim=self.z_dim, 177 | use_resnet=self.config.if_resnet_e) 178 | 179 | def define_inputs(self): 180 | image_size = ( 181 | [self.config.output_height, self.config.output_width, ] if self.config.crop 182 | else [self.config.input_height, self.config.input_width, ] 183 | ) 184 | self.image_dims = image_size + [self.c_dim] 185 | self.inputs = tf.placeholder( 186 | tf.float32, [self.config.batch_size] + self.image_dims, name='real_images') 187 | 188 | if self.config.multiclasses: 189 | self.z = tf.placeholder( 190 | tf.float32, [None, self.z_dim+1], name='z') 191 | class_onehot = tf.one_hot( 192 | tf.cast(self.z[:, -1], dtype=tf.int32), 193 | self.config.num_classes, 194 | on_value=1., off_value=0., dtype=tf.float32 195 | ) 196 | self.z_onehot = tf.concat( 197 | [self.z[:, 0:self.z_dim], class_onehot], 1) 198 | else: 199 | self.z = tf.placeholder( 200 | tf.float32, [None, self.z_dim], name='z') 201 | 202 | def forward(self): 203 | def split_inputs(inputs): 204 | begin = int(self.config.output_width / 2) 205 | end = self.config.output_width 206 | return ( 207 | inputs[:, :, :begin, :], 208 | inputs[:, :, begin: end:] 209 | ) 210 | 211 | def resize(inputs, size): 212 | return tf.image.resize_images( 213 | inputs, [size, size], method=2) 214 | 215 | def delete_it_later(): 216 | pictures = self.inputs[:, :, int( 217 | self.config.output_width / 2):self.config.output_width, :] 218 | _ = tf.image.resize_images(pictures, [self.config.edge_dis_size, self.config.edge_dis_size], 219 | method=2) 220 | 221 | if self.config.multiclasses: 222 | self.edge_output = self.edge_generator(self.z_onehot) 223 | self.image_output = self.image_generator(self.z_onehot) 224 | else: 225 | self.edge_output = self.edge_generator(self.z) 226 | self.image_output = self.image_generator(self.z) 227 | 228 | if self.config.multiclasses: 229 | def classify(inputs, reuse): 230 | _, _, result = self.classifier( 231 | channel_first(inputs), 232 | num_classes=self.config.num_classes, 233 | labels=self.z[:, -1], 234 | reuse=reuse, data_format='NCHW' 235 | ) 236 | return result 237 | 238 | self.trueimage_class_output = classify( 239 | self.inputs[:, :, int(self.image_dims[1]/2):, :], reuse=False) 240 | self.fakeimage_class_output = classify( 241 | self.image_output, reuse=True) 242 | 243 | self.D, self.truejoint_dis_output = self.joint_discriminator( 244 | self.inputs) 245 | self.joint_output = tf.concat([self.edge_output, self.image_output], 2) 246 | self.D_, self.fakejoint_dis_output = self.joint_discriminator( 247 | self.joint_output, reuse=True) 248 | 249 | if self.config.use_image_discriminator: 250 | pictures = self.inputs[:, :, int( 251 | self.config.output_width / 2):self.config.output_width, :] 252 | self.resized_inputs = resize(pictures, 
self.config.image_dis_size) 253 | self.imageD, self.trueimage_dis_output = self.image_discriminator( 254 | self.resized_inputs) 255 | 256 | self.resized_image_output = resize( 257 | self.image_output, self.config.image_dis_size) 258 | self.imageDfake, self.fakeimage_dis_output = self.image_discriminator( 259 | self.resized_image_output, reuse=True) 260 | 261 | if self.config.use_edge_discriminator: 262 | edges = self.inputs[:, :, 0:int( 263 | self.config.output_width / 2), :] 264 | self.resized_edges = resize(edges, self.config.edge_dis_size) 265 | 266 | delete_it_later() 267 | self.edgeD, self.trueedge_dis_output = self.edge_discriminator( 268 | self.resized_edges) 269 | 270 | self.resized_edge_output = resize( 271 | self.edge_output, self.config.edge_dis_size) 272 | self.edgeDfake, self.fakeedge_dis_output = self.edge_discriminator( 273 | self.resized_edge_output, reuse=True) 274 | 275 | self.z_recon, _, _ = self.encoder(self.edge_output) 276 | 277 | def define_losses(self): 278 | self.joint_dis_dloss = ( 279 | F.discriminator_ganloss(self.fakejoint_dis_output, self.truejoint_dis_output) + 280 | penalty( 281 | self.joint_output, self.inputs, self.joint_discriminator, 282 | self.config.batch_size, self.config.lambda_gp 283 | ) 284 | ) 285 | self.joint_dis_gloss = F.generator_ganloss(self.fakejoint_dis_output) 286 | 287 | if self.config.use_image_discriminator: 288 | self.image_dis_dloss = ( 289 | F.discriminator_ganloss(self.fakeimage_dis_output, self.trueimage_dis_output) + 290 | penalty( 291 | self.resized_image_output, self.resized_inputs, self.image_discriminator, 292 | self.config.batch_size, self.config.lambda_gp 293 | ) 294 | ) 295 | self.image_dis_gloss = F.generator_ganloss( 296 | self.fakeimage_dis_output) 297 | else: 298 | self.image_dis_dloss = 0 299 | self.image_dis_gloss = 0 300 | 301 | if self.config.use_edge_discriminator: 302 | self.edge_dis_dloss = ( 303 | F.discriminator_ganloss(self.fakeedge_dis_output, self.trueedge_dis_output) + 304 | penalty( 305 | self.resized_edge_output, self.resized_edges, self.edge_discriminator, 306 | self.config.batch_size, self.config.lambda_gp 307 | ) 308 | ) 309 | self.edge_dis_gloss = F.generator_ganloss(self.fakeedge_dis_output) 310 | else: 311 | self.edge_dis_dloss = 0 312 | self.edge_dis_gloss = 0 313 | 314 | self.edge_gloss = ( 315 | self.config.joint_dweight * self.joint_dis_gloss + 316 | self.config.edge_dweight * self.edge_dis_gloss + 0 317 | ) 318 | self.image_gloss = ( 319 | self.config.joint_dweight * self.joint_dis_gloss + 0 + 320 | self.config.image_dweight * self.image_dis_gloss + 0 + 0 321 | ) 322 | 323 | # focal loss 324 | if self.config.multiclasses: 325 | self.loss_g_ac, self.loss_d_ac = F.get_acgan_loss_focal( 326 | self.trueimage_class_output, tf.cast( 327 | self.z[:, -1], dtype=tf.int32), 328 | self.fakeimage_class_output, tf.cast( 329 | self.z[:, -1], dtype=tf.int32), 330 | num_classes=self.config.num_classes) 331 | 332 | self.image_gloss += self.loss_g_ac 333 | else: 334 | self.loss_g_ac = 0 335 | self.loss_d_ac = 0 336 | 337 | z_target = self.z[:, 338 | :self.z_dim] if self.config.multiclasses else self.z 339 | self.zl_loss = F.l1loss( 340 | z_target, self.z_recon, 341 | weight=self.config.stage1_zl_loss 342 | ) 343 | 344 | def define_summaries(self): 345 | self.z_sum = nn.histogram_summary("z", self.z) 346 | self.inputs_sum = nn.image_summary("inputs", self.inputs) 347 | 348 | self.G1_sum = nn.image_summary("G1", self.edge_output) 349 | self.G2_sum = nn.image_summary("G2", self.image_output) 350 | 351 | 
self.g1_loss_sum = nn.scalar_summary("edge_gloss", self.edge_gloss) 352 | self.g2_loss_sum = nn.scalar_summary("image_gloss", self.image_gloss) 353 | 354 | self.g_loss_sum = nn.scalar_summary( 355 | "joint_dis_gloss", self.joint_dis_gloss) 356 | 357 | self.d_loss_sum = nn.scalar_summary( 358 | "joint_dis_dloss", self.joint_dis_dloss) 359 | 360 | self.zl_loss_sum = nn.scalar_summary("zl_loss", self.zl_loss) 361 | 362 | self.loss_g_ac_sum = nn.scalar_summary( 363 | "loss_g_ac", self.loss_g_ac) 364 | self.loss_d_ac_sum = nn.scalar_summary( 365 | "loss_d_ac", self.loss_d_ac) 366 | 367 | self.g_sum = nn.merge_summary([self.z_sum, self.G1_sum, self.G2_sum, 368 | self.zl_loss_sum, self.g_loss_sum, 369 | self.loss_g_ac_sum, self.g1_loss_sum, self.g2_loss_sum]) 370 | self.d_sum = nn.merge_summary([self.z_sum, self.inputs_sum, 371 | self.d_loss_sum, self.loss_d_ac_sum]) 372 | self.d_sum_tmp = nn.histogram_summary("d", self.D) 373 | self.d__sum_tmp = nn.histogram_summary("d_", self.D_) 374 | self.g_sum = nn.merge_summary([self.g_sum, self.d__sum_tmp]) 375 | self.d_sum = nn.merge_summary([self.d_sum, self.d_sum_tmp]) 376 | 377 | if self.config.use_image_discriminator: 378 | self.d_patch2_sum = nn.histogram_summary( 379 | "imageD", self.imageD) 380 | self.d__patch2_sum = nn.histogram_summary( 381 | "imageDfake", self.imageDfake) 382 | self.resized_inputs_sum = nn.image_summary( 383 | "resized_inputs_image", self.resized_inputs) 384 | self.resized_G_sum = nn.image_summary( 385 | "resized_G_image", self.resized_image_output) 386 | self.d_loss_patch2_sum = nn.scalar_summary( 387 | "image_dis_dloss", self.image_dis_dloss) 388 | self.g_loss_patch2_sum = nn.scalar_summary( 389 | "image_dis_gloss", self.image_dis_gloss) 390 | self.g_sum = nn.merge_summary( 391 | [self.g_sum, self.d__patch2_sum, self.resized_G_sum, self.g_loss_patch2_sum]) 392 | self.d_sum = nn.merge_summary( 393 | [self.d_sum, self.d_patch2_sum, self.resized_inputs_sum, self.d_loss_patch2_sum]) 394 | 395 | if self.config.use_edge_discriminator: 396 | self.d_patch3_sum = nn.histogram_summary( 397 | "edgeD", self.edgeD) 398 | self.d__patch3_sum = nn.histogram_summary( 399 | "edgeDfake", self.edgeDfake) 400 | self.resized_inputs_p3_sum = nn.image_summary( 401 | "resized_inputs_p3_image", self.resized_edges) 402 | self.resized_G_p3_sum = nn.image_summary( 403 | "resized_G_p3_image", self.resized_edge_output) 404 | self.d_loss_patch3_sum = nn.scalar_summary( 405 | "edge_dis_dloss", self.edge_dis_dloss) 406 | self.g_loss_patch3_sum = nn.scalar_summary( 407 | "edge_dis_gloss", self.edge_dis_gloss) 408 | self.g_sum = nn.merge_summary( 409 | [self.g_sum, self.d__patch3_sum, self.resized_G_p3_sum, self.g_loss_patch3_sum]) 410 | self.d_sum = nn.merge_summary( 411 | [self.d_sum, self.d_patch3_sum, self.resized_inputs_p3_sum, self.d_loss_patch3_sum]) 412 | 413 | def build_train_model(self): 414 | self.build_networks() 415 | self.define_inputs() 416 | self.forward() 417 | self.define_losses() 418 | self.construct_optimizers() 419 | self.define_summaries() 420 | 421 | self.saver = tf.train.Saver() 422 | 423 | utils.show_all_variables() 424 | 425 | def train(self): 426 | 427 | def add_summary(images, z, counter): 428 | discriminator_summary = self.sess.run( 429 | self.d_sum, feed_dict={self.inputs: images, self.z: z}) 430 | self.writer.add_summary(discriminator_summary, counter) 431 | generator_summary = self.sess.run( 432 | self.g_sum, feed_dict={self.inputs: images, self.z: z}) 433 | self.writer.add_summary(generator_summary, counter) 434 | 435 | 
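# Assemble the complete training graph (networks, inputs, losses, optimizers, summaries) before initializing variables and restoring any existing checkpoint.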
self.build_train_model() 436 | 437 | try: 438 | tf.global_variables_initializer().run() 439 | except: 440 | tf.initialize_all_variables().run() 441 | 442 | # init summary writer 443 | self.writer = nn.SummaryWriter(self.config.logdir, self.sess.graph) 444 | 445 | counter = 1 446 | start_time = time.time() 447 | loaded, checkpoint_counter = self.load( 448 | self.saver, self.config.checkpoint_dir) 449 | if loaded: 450 | counter = checkpoint_counter 451 | print(" [*] Load SUCCESS") 452 | else: 453 | print(" [!] Load failed...") 454 | 455 | # train 456 | for epoch in range(self.config.epoch): 457 | self.dataset.shuffle() 458 | for idx in range(len(self.dataset)): 459 | batch_images, batch_z, batch_files = self.dataset[idx] 460 | 461 | self.update_model(batch_images, batch_z) 462 | add_summary(batch_images, batch_z, counter) 463 | 464 | def evaluate(obj): 465 | return getattr(obj, 'eval')( 466 | {self.inputs: batch_images, self.z: batch_z}) 467 | 468 | joint_dloss = evaluate(self.joint_dis_dloss) 469 | if self.config.use_image_discriminator: 470 | image_dloss = evaluate(self.image_dis_dloss) 471 | else: 472 | image_dloss = 0 473 | if self.config.use_edge_discriminator: 474 | edge_dloss = evaluate(self.edge_dis_dloss) 475 | else: 476 | edge_dloss = 0 477 | discriminator_err = joint_dloss + 0 + image_dloss + edge_dloss 478 | 479 | g1loss = evaluate(self.edge_gloss) 480 | g2loss = evaluate(self.image_gloss) 481 | generator_err = g1loss + g2loss 482 | 483 | counter += 1 484 | print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, joint_dis_dloss: %.8f, joint_dis_gloss: %.8f" 485 | % (epoch, self.config.epoch, idx, len(self.dataset), 486 | time.time() - start_time, 2 * discriminator_err, generator_err)) 487 | if np.mod(counter, self.config.save_checkpoint_frequency) == 2: 488 | self.save(self.saver, self.config.checkpoint_dir, 489 | counter) 490 | 491 | 492 | def define_test_input(self): 493 | if self.config.crop: 494 | self.image_dims = [self.config.output_height, self.config.output_width, 495 | self.c_dim] 496 | else: 497 | self.image_dims = [self.config.input_height, self.config.input_width, 498 | self.c_dim] 499 | self.inputs = tf.placeholder( 500 | tf.float32, [None] + self.image_dims, name='real_images') 501 | 502 | self.input_left = self.inputs[ 503 | ..., :self.image_dims[0], 504 | :self.image_dims[1]//2, :self.image_dims[2] 505 | ] 506 | self.z, _, _ = self.encoder(self.input_left) 507 | if self.config.multiclasses: 508 | self.classes = tf.placeholder( 509 | tf.int32, shape=[None], name='classes') 510 | self.class_onehot = tf.one_hot( 511 | self.classes, self.config.num_classes, 512 | on_value=1., off_value=0., dtype=tf.float32 513 | ) 514 | self.z = tf.concat([self.z, self.class_onehot], 1) 515 | 516 | self.edge_output = self.edge_generator(self.z) 517 | self.image_output = self.image_generator(self.z) 518 | 519 | def build_test_model(self): 520 | self.encoder = Encoder('E', is_train=True, 521 | norm=self.config.E_norm, 522 | image_size=self.config.input_height, 523 | latent_dim=self.z_dim, 524 | use_resnet=self.config.if_resnet_e) 525 | 526 | self.edge_generator = Generator('G1', is_train=False, 527 | norm=self.config.G_norm, 528 | batch_size=self.config.batch_size, 529 | output_height=self.config.output_height, 530 | output_width=int( 531 | self.config.output_width/2), 532 | input_dim=self.gf_dim, 533 | output_dim=self.c_dim, 534 | use_resnet=self.config.if_resnet_g) 535 | self.image_generator = Generator('G2', is_train=False, 536 | norm=self.config.G_norm, 537 | 
batch_size=self.config.batch_size, 538 | output_height=self.config.output_height, 539 | output_width=int( 540 | self.config.output_width/2), 541 | input_dim=self.gf_dim, 542 | output_dim=self.c_dim, 543 | use_resnet=self.config.if_resnet_g) 544 | 545 | self.define_test_input() 546 | 547 | self.saver = tf.train.Saver() 548 | 549 | utils.show_all_variables() 550 | 551 | def test(self): 552 | def name_with_class(filename): 553 | parts = pathsplit(filename) 554 | return os.path.join(*parts[parts.index('test') + 1:]) 555 | 556 | def classes(filenames): 557 | result = [] 558 | mask = [] 559 | for path in filenames: 560 | try: 561 | classid = int(pathsplit(path)[-2]) 562 | if classid >= self.config.num_classes: 563 | mask.append(False) 564 | continue 565 | result.append(classid) 566 | mask.append(True) 567 | except (ValueError, IndexError): 568 | mask.append(False) 569 | 570 | return result, np.array(mask, dtype=np.bool) 571 | 572 | self.build_test_model() 573 | 574 | # init var 575 | try: 576 | tf.global_variables_initializer().run() 577 | except: 578 | tf.initialize_all_variables().run() 579 | 580 | counter = 1 581 | start_time = time.time() 582 | loaded, checkpoint_counter = self.load( 583 | self.saver, self.config.checkpoint_dir) 584 | if loaded: 585 | counter = checkpoint_counter 586 | print(" [*] Load SUCCESS") 587 | else: 588 | print(" [!] Load failed...") 589 | return 590 | 591 | for idx in range(len(self.dataset)): 592 | batch_images, filenames = self.dataset[idx] 593 | feed_dict = {self.inputs: batch_images} 594 | if self.config.multiclasses: 595 | class_of_images, mask = classes(filenames) 596 | if len(class_of_images) == 0: 597 | continue 598 | batch_classes = np.array(class_of_images, dtype=np.int) 599 | feed_dict[self.inputs] = feed_dict[self.inputs][mask] 600 | feed_dict.update({ 601 | self.classes: batch_classes, 602 | }) 603 | 604 | inputL = batch_images[:, :, 0:int(self.config.output_width / 2), :] 605 | inputR = batch_images[:, :, int(self.config.output_width / 2):, :] 606 | outputL = self.sess.run(self.edge_output, 607 | feed_dict=feed_dict) 608 | outputR = self.sess.run(self.image_output, 609 | feed_dict=feed_dict) 610 | 611 | if self.config.output_combination == "inputL_outputR": 612 | results = np.append(inputL, outputR, axis=2) 613 | elif self.config.output_combination == "outputL_inputR": 614 | results = np.append(outputL, inputR, axis=2) 615 | elif self.config.output_combination == "outputR": 616 | results = outputR 617 | else: 618 | results = np.append(batch_images, outputL, axis=2) 619 | results = np.append(results, outputR, axis=2) 620 | 621 | assert results.shape[0] == len(filenames) 622 | for fname, img in zip(filenames, results): 623 | name = name_with_class(fname) 624 | img = img[np.newaxis, ...]
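# save_images expects a 4-D batch, hence the added leading axis; the [1, 1]
# grid below writes each result as a single image under
# test_output_dir/<dataset>/<class>/<filename>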
625 | utils.save_images( 626 | img, [1, 1], 627 | os.path.join( 628 | self.config.test_output_dir, 629 | self.config.dataset, name, 630 | ) 631 | ) 632 | 633 | print("Test: [%4d/%4d]" % (idx, len(self.dataset))) 634 | 635 | def save(self, saver, checkpoint_dir, step): 636 | print(" [*] Saving checkpoints...") 637 | # utils.makedirs(checkpoint_dir) 638 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_name) 639 | saver.save(self.sess, checkpoint_dir, global_step=step) 640 | 641 | def load(self, saver, checkpoint_dir): 642 | import re 643 | 644 | checkpoint_dir = os.path.join(checkpoint_dir) 645 | print(" [*] Reading checkpoints {}...".format(checkpoint_dir)) 646 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 647 | if ckpt and ckpt.model_checkpoint_path: 648 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 649 | saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 650 | 651 | counter = int( 652 | next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0)) 653 | print(" [*] Successfully read {}".format(ckpt_name)) 654 | return True, counter 655 | else: 656 | print(" [*] Failed to find a checkpoint") 657 | return False, 0 658 | 659 | @property 660 | def model_name(self): 661 | return 'EdgeGAN-Model' 662 | -------------------------------------------------------------------------------- /edgegan/models/encoder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import edgegan.nn.functional as F 4 | from edgegan import nn 5 | 6 | 7 | class Encoder(object): 8 | def __init__(self, name, is_train, norm='batch', activation='relu', 9 | image_size=128, latent_dim=8, 10 | use_resnet=True): 11 | print(' [*] Init Encoder %s' % name) 12 | self.name = name 13 | self._is_train = is_train 14 | self._norm = norm 15 | self._activation = activation 16 | self._image_size = image_size 17 | self._latent_dim = latent_dim 18 | self._use_resnet = use_resnet 19 | self._reuse = False 20 | 21 | def __call__(self, input): 22 | if self._use_resnet: 23 | return self._resnet(input) 24 | else: 25 | return self._convnet(input) 26 | 27 | def _convnet(self, input): 28 | with tf.variable_scope(self.name, reuse=self._reuse): 29 | num_filters = [64, 128, 256, 512, 512, 512, 512] 30 | if self._image_size == 256: 31 | num_filters.append(512) 32 | 33 | E = input 34 | for i, n in enumerate(num_filters): 35 | E = nn.conv_block(E, n, 'e_convnet_{}_{}'.format(n, i), 4, 36 | 2, self._is_train, self._reuse, 37 | norm=self._norm if i else None, 38 | activation=self._activation) 39 | E = F.flatten(E) 40 | mu = nn.mlp(E, self._latent_dim, 'FC8_mu', self._is_train, 41 | self._reuse, norm=None, activation=None) 42 | log_sigma = nn.mlp(E, self._latent_dim, 'FC8_sigma', 43 | self._is_train, self._reuse, 44 | norm=None, activation=None) 45 | # reparameterization trick: z = mu + eps * exp(log_sigma), eps ~ N(0, I) 46 | z = mu + tf.random_normal(shape=tf.shape(mu)) \ 47 | * tf.exp(log_sigma) 48 | 49 | self._reuse = True 50 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 51 | self.name) 52 | return z, mu, log_sigma 53 | 54 | def _resnet(self, input): 55 | with tf.variable_scope(self.name, reuse=self._reuse): 56 | num_filters = [128, 256, 512, 512] 57 | if self._image_size == 256: 58 | num_filters.append(512) 59 | 60 | E = input 61 | E = nn.conv_block(E, 64, 'e_resnet_{}_{}'.format(64, 0), 4, 2, 62 | self._is_train, self._reuse, norm=None, 63 | activation=self._activation, bias=True) 64 | for i, n in enumerate(num_filters): 65 | E = nn.residual(E, n, 'e_resnet_{}_{}'.format(n, i + 1), 66
| self._is_train, self._reuse, 67 | norm=self._norm, bias=True) 68 | E = tf.nn.avg_pool(E, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME') 69 | E = nn.activation_fn(E, 'relu') 70 | E = tf.nn.avg_pool(E, [1, 8, 8, 1], [1, 8, 8, 1], 'SAME') 71 | E = F.flatten(E) 72 | mu = nn.mlp(E, self._latent_dim, 'FC8_mu', self._is_train, 73 | self._reuse, norm=None, activation=None) 74 | log_sigma = nn.mlp(E, self._latent_dim, 'FC8_sigma', 75 | self._is_train, self._reuse, norm=None, 76 | activation=None) 77 | # reparameterization trick, as in _convnet above 78 | z = mu + tf.random_normal(shape=tf.shape(mu)) \ 79 | * tf.exp(log_sigma) 80 | 81 | self._reuse = True 82 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 83 | self.name) 84 | return z, mu, log_sigma 85 | -------------------------------------------------------------------------------- /edgegan/models/generator.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import tensorflow as tf 4 | 5 | import edgegan.nn.functional as F 6 | from edgegan import nn 7 | 8 | 9 | class Generator(object): 10 | def __init__(self, name, is_train, norm='batch', activation='relu', 11 | batch_size=64, output_height=64, output_width=128, 12 | input_dim=64, output_dim=3, use_resnet=False): 13 | print(' [*] Init Generator %s' % name) 14 | self.name = name 15 | self._is_train = is_train 16 | self._norm = norm 17 | self._activation = activation 18 | self._batch_size = batch_size 19 | self._output_height = output_height 20 | self._output_width = output_width 21 | self._input_dim = input_dim 22 | self._output_dim = output_dim 23 | self._use_resnet = use_resnet 24 | self._reuse = False 25 | 26 | def _conv_out_size_same(self, size, stride): 27 | return int(math.ceil(float(size) / float(stride))) 28 | 29 | def __call__(self, z): 30 | if self._use_resnet: 31 | return self._resnet(z) 32 | else: 33 | return self._convnet(z) 34 | 35 | def _convnet(self, z): 36 | with tf.variable_scope(self.name, reuse=self._reuse): 37 | s_h, s_w = self._output_height, self._output_width 38 | s_h2, s_w2 = self._conv_out_size_same( 39 | s_h, 2), self._conv_out_size_same(s_w, 2) 40 | s_h4, s_w4 = self._conv_out_size_same( 41 | s_h2, 2), self._conv_out_size_same(s_w2, 2) 42 | s_h8, s_w8 = self._conv_out_size_same( 43 | s_h4, 2), self._conv_out_size_same(s_w4, 2) 44 | s_h16, s_w16 = self._conv_out_size_same( 45 | s_h8, 2), self._conv_out_size_same(s_w8, 2) 46 | 47 | # project `z` and reshape 48 | z_ = nn.linear(z, self._input_dim*8 * 49 | s_h16*s_w16, name='g_lin_0') 50 | h0 = tf.reshape(z_, [-1, s_h16, s_w16, self._input_dim * 8]) 51 | h0 = nn.activation_fn(nn.norm( 52 | h0, self._norm), self._activation) 53 | 54 | h1 = nn.deconv_block(h0, [self._batch_size, s_h8, s_w8, self._input_dim*4], 55 | 'g_dconv_1', 5, 2, self._is_train, 56 | self._reuse, self._norm, self._activation) 57 | 58 | h2 = nn.deconv_block(h1, [self._batch_size, s_h4, s_w4, self._input_dim*2], 59 | 'g_dconv_2', 5, 2, self._is_train, 60 | self._reuse, self._norm, self._activation) 61 | 62 | h3 = nn.deconv_block(h2, [self._batch_size, s_h2, s_w2, self._input_dim], 63 | 'g_dconv_3', 5, 2, self._is_train, 64 | self._reuse, self._norm, self._activation) 65 | 66 | h4 = nn.deconv_block(h3, [self._batch_size, s_h, s_w, self._output_dim], 67 | 'g_dconv_4', 5, 2, self._is_train, 68 | self._reuse, None, None) 69 | 70 | self._reuse = True 71 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 72 | self.name) 73 | 74 | return tf.nn.tanh(h4) 75 | 76 | def _resnet(self, z): 77 | # mirrors _convnet, but with residual blocks followed by 2x upsampling 78 | with 
tf.variable_scope(self.name, reuse=self._reuse): 79 | s_h, s_w = self._output_height, self._output_width 80 | s_h2, s_w2 = self._conv_out_size_same( 81 | s_h, 2), self._conv_out_size_same(s_w, 2) 82 | s_h4, s_w4 = self._conv_out_size_same( 83 | s_h2, 2), self._conv_out_size_same(s_w2, 2) 84 | s_h8, s_w8 = self._conv_out_size_same( 85 | s_h4, 2), self._conv_out_size_same(s_w4, 2) 86 | s_h16, s_w16 = self._conv_out_size_same( 87 | s_h8, 2), self._conv_out_size_same(s_w8, 2) 88 | 89 | # project `z` and reshape 90 | z_ = nn.linear(z, self._input_dim*8 * 91 | s_h16*s_w16, name='g_lin_resnet_0') 92 | h0 = nn.activation_fn(nn.norm( 93 | z_, self._norm), self._activation) 94 | h0 = tf.reshape(h0, [-1, s_h16, s_w16, self._input_dim * 8]) 95 | 96 | h1 = nn.deresidual2(h0, [self._batch_size, s_h8//2, s_w8//2, self._input_dim*4], 97 | 'g_resnet_1', 3, 1, self._is_train, 98 | self._reuse, self._norm, self._activation) 99 | h1 = nn.upsample2(h1, "NHWC") 100 | 101 | h2 = nn.deresidual2(h1, [self._batch_size, s_h4//2, s_w4//2, self._input_dim*2], 102 | 'g_resnet_2', 3, 1, self._is_train, 103 | self._reuse, self._norm, self._activation) 104 | h2 = nn.upsample2(h2, "NHWC") 105 | 106 | h3 = nn.deresidual2(h2, [self._batch_size, s_h2//2, s_w2//2, self._input_dim], 107 | 'g_resnet_3', 3, 1, self._is_train, 108 | self._reuse, self._norm, self._activation) 109 | h3 = nn.upsample2(h3, "NHWC") 110 | 111 | h4 = nn.deresidual2(h3, [self._batch_size, s_h//2, s_w//2, self._output_dim], 112 | 'g_resnet_4', 3, 1, self._is_train, 113 | self._reuse, None, None) 114 | h4 = nn.upsample2(h4, "NHWC") 115 | 116 | self._reuse = True 117 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 118 | self.name) 119 | 120 | return tf.nn.tanh(h4) 121 | -------------------------------------------------------------------------------- /edgegan/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | import tensorflow as tf 3 | 4 | try: 5 | image_summary = tf.image_summary 6 | scalar_summary = tf.scalar_summary 7 | histogram_summary = tf.histogram_summary 8 | merge_summary = tf.merge_summary 9 | SummaryWriter = tf.train.SummaryWriter 10 | except: 11 | image_summary = tf.summary.image 12 | scalar_summary = tf.summary.scalar 13 | histogram_summary = tf.summary.histogram 14 | merge_summary = tf.summary.merge 15 | SummaryWriter = tf.summary.FileWriter 16 | -------------------------------------------------------------------------------- /edgegan/nn/functional.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def get_acgan_loss_focal(real_image_logits_out, real_image_label, 6 | disc_image_logits_out, condition, 7 | num_classes, ld1=1, ld2=0.5, ld_focal=2.): 8 | loss_ac_d = tf.reduce_mean((1 - tf.reduce_sum(tf.nn.softmax(real_image_logits_out) * tf.squeeze( 9 | tf.one_hot(real_image_label, num_classes, on_value=1., off_value=0., dtype=tf.float32)), axis=1)) ** ld_focal * 10 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=real_image_logits_out, labels=real_image_label)) 11 | loss_ac_d = ld1 * loss_ac_d 12 | 13 | loss_ac_g = tf.reduce_mean( 14 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=disc_image_logits_out, labels=condition)) 15 | loss_ac_g = ld2 * loss_ac_g 16 | return loss_ac_g, loss_ac_d 17 | 18 | 19 | def get_class_loss(logits_out, label, num_classes, ld_focal=2.0): 20 | loss = tf.reduce_mean((1 - tf.reduce_sum(tf.nn.softmax(logits_out) * 
tf.squeeze( 21 | tf.one_hot(label, num_classes, on_value=1., off_value=0., dtype=tf.float32)), axis=1)) ** ld_focal * 22 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits_out, labels=label)) 23 | return loss 24 | 25 | 26 | def gradient_penalty(output, on): 27 | gradients = tf.gradients(output, [on, ])[0] 28 | grad_l2 = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3])) 29 | return tf.reduce_mean((grad_l2-1)**2) 30 | 31 | 32 | def discriminator_ganloss(output, target): 33 | return tf.reduce_mean(output - target) 34 | 35 | 36 | def generator_ganloss(output): 37 | return tf.reduce_mean(output * -1) 38 | 39 | 40 | def l1loss(output, target, weight): 41 | return weight * tf.reduce_mean(tf.abs(output - target)) 42 | 43 | 44 | def flatten(input): 45 | return tf.reshape(input, [-1, np.prod(input.get_shape().as_list()[1:])]) 46 | -------------------------------------------------------------------------------- /edgegan/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .activation import activation_fn, miu_relu, prelu, lrelu 2 | from .normalization import norm, spectral_normed_weight 3 | from .conv import * 4 | from .linear import * 5 | from .pooling import mean_pool 6 | from .upsampling import upsample, upsample2 7 | -------------------------------------------------------------------------------- /edgegan/nn/modules/activation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def activation_fn(input, name='lrelu'): 5 | assert name in ['relu', 'lrelu', 'tanh', 'sigmoid', None] 6 | if name == 'relu': 7 | return tf.nn.relu(input) 8 | elif name == 'lrelu': 9 | return tf.maximum(input, 0.2*input) 10 | elif name == 'tanh': 11 | return tf.tanh(input) 12 | elif name == 'sigmoid': 13 | return tf.sigmoid(input) 14 | else: 15 | return input 16 | 17 | 18 | def miu_relu(x, miu=0.7, name="miu_relu"): 19 | with tf.variable_scope(name): 20 | return (x + tf.sqrt((1 - miu) ** 2 + x ** 2)) / 2. 
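# miu_relu is a smooth ReLU variant: for large positive x it approaches x,
# for large negative x it decays to 0, and at x = 0 it returns (1 - miu) / 2
# (0.15 for the default miu = 0.7). Worked values:
#   miu_relu(0.)  = (0 + sqrt(0.3 ** 2 + 0)) / 2   = 0.15
#   miu_relu(10.) = (10 + sqrt(0.09 + 100)) / 2   ~= 10.0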
21 | 22 | 23 | def prelu(x, name="prelu"): 24 | with tf.variable_scope(name): 25 | leak = tf.get_variable("param", shape=None, initializer=0.2, regularizer=None, 26 | trainable=True, caching_device=None) 27 | return tf.maximum(leak * x, x) 28 | 29 | 30 | def lrelu(x, leak=0.2, name="lrelu"): 31 | with tf.variable_scope(name): 32 | return tf.maximum(leak * x, x) 33 | -------------------------------------------------------------------------------- /edgegan/nn/modules/conv.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import tensorflow as tf 4 | import tensorflow.contrib.layers as ly 5 | from tensorflow.python.ops import init_ops 6 | 7 | from .activation import activation_fn as _activation 8 | from .normalization import norm as _norm, spectral_normed_weight 9 | from .pooling import mean_pool 10 | from .activation import lrelu 11 | 12 | 13 | def conv2d(input, output_dim, filter_size=5, stride=2, reuse=False, 14 | pad='SAME', bias=True, name=None): 15 | stride_shape = [1, stride, stride, 1] 16 | filter_shape = [filter_size, filter_size, input.get_shape()[-1], 17 | output_dim] 18 | 19 | with tf.variable_scope(name or 'conv2d', reuse=reuse): 20 | w = tf.get_variable('w', filter_shape, 21 | initializer=tf.truncated_normal_initializer( 22 | stddev=0.02)) 23 | if pad == 'REFLECT': 24 | p = (filter_size - 1) // 2 25 | x = tf.pad(input, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT') 26 | conv = tf.nn.conv2d(x, w, stride_shape, padding='VALID') 27 | else: 28 | assert pad in ['SAME', 'VALID'] 29 | conv = tf.nn.conv2d(input, w, stride_shape, padding=pad) 30 | 31 | if bias: 32 | b = tf.get_variable('b', [output_dim], 33 | initializer=tf.constant_initializer(0.0)) 34 | conv = tf.reshape(tf.nn.bias_add(conv, b), tf.shape(conv)) 35 | 36 | return conv 37 | 38 | 39 | def deconv2d(input_, output_shape, with_w=False, 40 | filter_size=5, stride=2, reuse=False, name=None): 41 | with tf.variable_scope(name or 'deconv2d', reuse=reuse): 42 | stride_shape = [1, stride, stride, 1] 43 | filter_shape = [filter_size, filter_size, output_shape[-1], 44 | input_.get_shape()[-1]] 45 | 46 | w = tf.get_variable('w', filter_shape, 47 | initializer=tf.random_normal_initializer( 48 | stddev=0.02)) 49 | deconv = tf.nn.conv2d_transpose( 50 | input_, w, output_shape=output_shape, strides=stride_shape) 51 | b = tf.get_variable('b', [output_shape[-1]], 52 | initializer=tf.constant_initializer(0.0)) 53 | deconv = tf.reshape(tf.nn.bias_add(deconv, b), deconv.get_shape()) 54 | 55 | if with_w: 56 | return deconv, w, b 57 | else: 58 | return deconv 59 | 60 | 61 | def conv_block(input, num_filters, name, k_size, stride, is_train, reuse, norm, 62 | activation, pad='SAME', bias=False): 63 | with tf.variable_scope(name, reuse=reuse): 64 | out = conv2d(input, num_filters, k_size, stride, reuse, pad, bias) 65 | out = _norm(out, is_train, norm) 66 | out = _activation(out, activation) 67 | return out 68 | 69 | 70 | def residual(input, num_filters, name, is_train, reuse, norm, pad='REFLECT', 71 | bias=False): 72 | with tf.variable_scope(name, reuse=reuse): 73 | with tf.variable_scope('res1', reuse=reuse): 74 | out = conv2d(input, num_filters, 3, 1, reuse, pad, bias) 75 | out = _norm(out, is_train, norm) 76 | out = tf.nn.relu(out) 77 | 78 | with tf.variable_scope('res2', reuse=reuse): 79 | out = conv2d(out, num_filters, 3, 1, reuse, pad, bias) 80 | out = _norm(out, is_train, norm) 81 | 82 | with tf.variable_scope('shortcut', reuse=reuse): 83 | shortcut = conv2d(input, num_filters, 1, 
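# kernel size 1, stride 1: a pointwise projection that matches the
# shortcut's channel count to num_filters so it can be added to the
# residual branch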
1, reuse, pad, bias) 84 | 85 | return tf.nn.relu(shortcut + out) 86 | 87 | 88 | def residual2(input, num_filters, name, k_size, stride, is_train, reuse, norm, 89 | activation, pad='SAME', bias=False): 90 | with tf.variable_scope(name, reuse=reuse): 91 | with tf.variable_scope('res1', reuse=reuse): 92 | out = conv2d(input, num_filters, k_size, stride, reuse, pad, bias) 93 | out = _norm(out, is_train, norm) 94 | out = _activation(out, activation) 95 | 96 | with tf.variable_scope('res2', reuse=reuse): 97 | out = conv2d(out, num_filters, k_size, stride, reuse, pad, bias) 98 | out = _norm(out, is_train, norm) 99 | 100 | with tf.variable_scope('shortcut', reuse=reuse): 101 | shortcut = conv2d(input, num_filters, 1, 1, reuse, pad, bias) 102 | 103 | return _activation(shortcut + out, activation) 104 | 105 | 106 | def deresidual2(input, num_filters, name, k_size, stride, is_train, reuse, 107 | norm, activation, with_w=False): 108 | with tf.variable_scope(name, reuse=reuse): 109 | with tf.variable_scope('res1', reuse=reuse): 110 | out = deconv2d(input, num_filters, with_w, k_size, stride, reuse) 111 | out = _norm(out, is_train, norm) 112 | out = _activation(out, activation) 113 | 114 | with tf.variable_scope('res2', reuse=reuse): 115 | out = deconv2d(out, num_filters, with_w, k_size, stride, reuse) 116 | out = _norm(out, is_train, norm) 117 | 118 | with tf.variable_scope('shortcut', reuse=reuse): 119 | shortcut = deconv2d(input, num_filters, with_w, 1, 1, reuse) 120 | 121 | return _activation(shortcut + out, activation) 122 | 123 | 124 | def deconv_block(input, output_shape, name, k_size, stride, is_train, reuse, 125 | norm, activation, with_w=False): 126 | with tf.variable_scope(name, reuse=reuse): 127 | out = deconv2d(input, output_shape, with_w, k_size, stride, reuse) 128 | out = _norm(out, is_train, norm) 129 | out = _activation(out, activation) 130 | return out 131 | 132 | 133 | def mru_conv_block_v3(inp, ht, filter_depth, sn, 134 | stride, dilate=1, 135 | activation_fn=tf.nn.relu, 136 | normalizer_fn=None, 137 | normalizer_params=None, 138 | weights_initializer=ly.xavier_initializer_conv2d(), 139 | biases_initializer_mask=tf.constant_initializer( 140 | value=0.5), 141 | biases_initializer_h=tf.constant_initializer(value=-1), 142 | data_format='NCHW', 143 | weight_decay_rate=1e-8, 144 | norm_mask=False, 145 | norm_input=True, 146 | deconv=False): 147 | 148 | def norm_activ(tensor_in): 149 | if normalizer_fn is not None: 150 | _normalizer_params = normalizer_params or {} 151 | tensor_normed = normalizer_fn(tensor_in, **_normalizer_params) 152 | else: 153 | tensor_normed = tf.identity(tensor_in) 154 | if activation_fn is not None: 155 | tensor_normed = activation_fn(tensor_normed) 156 | 157 | return tensor_normed 158 | 159 | channel_index = 1 if data_format == 'NCHW' else 3 160 | reduce_dim = [2, 3] if data_format == 'NCHW' else [1, 2] 161 | hidden_depth = ht.get_shape().as_list()[channel_index] 162 | regularizer = ly.l2_regularizer( 163 | weight_decay_rate) if weight_decay_rate > 0 else None 164 | weights_initializer_mask = weights_initializer 165 | biases_initializer = tf.zeros_initializer() 166 | 167 | if norm_mask: 168 | mask_normalizer_fn = normalizer_fn 169 | mask_normalizer_params = normalizer_params 170 | else: 171 | mask_normalizer_fn = None 172 | mask_normalizer_params = None 173 | 174 | if deconv: 175 | if stride == 2: 176 | ht = upsample(ht, data_format=data_format) 177 | elif stride != 1: 178 | raise NotImplementedError 179 | 180 | ht_orig = tf.identity(ht) 181 | 182 | # Normalize hidden state 183 | with 
tf.variable_scope('norm_activation_in') as sc: 184 | if norm_input: 185 | full_inp = tf.concat([norm_activ(ht), inp], axis=channel_index) 186 | else: 187 | full_inp = tf.concat([ht, inp], axis=channel_index) 188 | 189 | # update gate 190 | rg = conv2d2(full_inp, hidden_depth, 3, sn=sn, stride=1, rate=dilate, 191 | data_format=data_format, activation_fn=lrelu, 192 | normalizer_fn=mask_normalizer_fn, normalizer_params=mask_normalizer_params, 193 | weights_regularizer=regularizer, 194 | weights_initializer=weights_initializer_mask, 195 | biases_initializer=biases_initializer_mask, 196 | scope='update_gate') 197 | rg = (rg - tf.reduce_min(rg, axis=reduce_dim, keep_dims=True)) / ( 198 | tf.reduce_max(rg, axis=reduce_dim, keep_dims=True) - tf.reduce_min(rg, axis=reduce_dim, keep_dims=True)) 199 | 200 | # Input Image conv 201 | img_new = conv2d2(inp, hidden_depth, 3, sn=sn, stride=1, rate=dilate, 202 | data_format=data_format, activation_fn=None, 203 | normalizer_fn=None, normalizer_params=None, 204 | biases_initializer=biases_initializer, 205 | weights_regularizer=regularizer, 206 | weights_initializer=weights_initializer) 207 | 208 | ht_plus = ht + rg * img_new 209 | with tf.variable_scope('norm_activation_merge_1') as sc: 210 | ht_new_in = norm_activ(ht_plus) 211 | 212 | # new hidden state 213 | h_new = conv2d2(ht_new_in, filter_depth, 3, sn=sn, stride=1, rate=dilate, 214 | data_format=data_format, activation_fn=activation_fn, 215 | normalizer_fn=normalizer_fn, normalizer_params=normalizer_params, 216 | biases_initializer=biases_initializer, 217 | weights_regularizer=regularizer, 218 | weights_initializer=weights_initializer) 219 | h_new = conv2d2(h_new, filter_depth, 3, sn=sn, stride=1, rate=dilate, 220 | data_format=data_format, activation_fn=None, 221 | normalizer_fn=None, normalizer_params=None, 222 | biases_initializer=biases_initializer, 223 | weights_regularizer=regularizer, 224 | weights_initializer=weights_initializer) 225 | 226 | # new hidden state out 227 | # linear project for filter depth 228 | if ht.get_shape().as_list()[channel_index] != filter_depth: 229 | ht_orig = conv2d2(ht_orig, filter_depth, 1, sn=sn, stride=1, 230 | data_format=data_format, activation_fn=None, 231 | normalizer_fn=None, normalizer_params=None, 232 | biases_initializer=biases_initializer, 233 | weights_regularizer=regularizer, 234 | weights_initializer=weights_initializer) 235 | ht_new = ht_orig + h_new 236 | 237 | if not deconv: 238 | if stride == 2: 239 | ht_new = mean_pool(ht_new, data_format=data_format) 240 | elif stride != 1: 241 | raise NotImplementedError 242 | 243 | return ht_new 244 | 245 | 246 | def conv2d2(inputs, num_outputs, kernel_size, sn, stride=1, rate=1, 247 | data_format='NCHW', activation_fn=tf.nn.relu, 248 | normalizer_fn=None, normalizer_params=None, 249 | weights_regularizer=None, 250 | weights_initializer=ly.xavier_initializer(), 251 | biases_initializer=init_ops.zeros_initializer(), 252 | biases_regularizer=None, 253 | reuse=None, scope=None, 254 | SPECTRAL_NORM_UPDATE_OPS='spectral_norm_update_ops'): 255 | assert data_format == 'NCHW' 256 | assert rate == 1 257 | if data_format == 'NCHW': 258 | channel_axis = 1 259 | stride = [1, 1, stride, stride] 260 | rate = [1, 1, rate, rate] 261 | else: 262 | channel_axis = 3 263 | stride = [1, stride, stride, 1] 264 | rate = [1, rate, rate, 1] 265 | input_dim = inputs.get_shape().as_list()[channel_axis] 266 | 267 | with tf.variable_scope(scope, 'Conv', [inputs], reuse=reuse) as sc: 268 | inputs = tf.convert_to_tensor(inputs) 269 | 270 | 
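# tf.nn.conv2d always takes its filter in HWIO layout
# (kernel, kernel, in_channels, out_channels), even when the activations
# use the NCHW data format asserted above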
weights = tf.get_variable(name="weights", shape=(kernel_size, kernel_size, input_dim, num_outputs), 271 | initializer=weights_initializer, regularizer=weights_regularizer, 272 | trainable=True, dtype=inputs.dtype.base_dtype) 273 | # Spectral Normalization 274 | if sn: 275 | weights = spectral_normed_weight( 276 | weights, num_iters=1, update_collection=SPECTRAL_NORM_UPDATE_OPS) 277 | 278 | conv_out = tf.nn.conv2d( 279 | inputs, weights, strides=stride, padding='SAME', data_format=data_format) 280 | 281 | if biases_initializer is not None: 282 | biases = tf.get_variable(name='biases', shape=(1, num_outputs, 1, 1), 283 | initializer=biases_initializer, regularizer=biases_regularizer, 284 | trainable=True, dtype=inputs.dtype.base_dtype) 285 | conv_out += biases 286 | 287 | if normalizer_fn is not None: 288 | normalizer_params = normalizer_params or {} 289 | conv_out = normalizer_fn( 290 | conv_out, activation_fn=None, **normalizer_params) 291 | 292 | if activation_fn is not None: 293 | conv_out = activation_fn(conv_out) 294 | 295 | return conv_out 296 | 297 | 298 | def mru_conv(x, ht, filter_depth, sn, stride=2, dilate_rate=1, 299 | num_blocks=5, last_unit=False, 300 | activation_fn=tf.nn.relu, 301 | normalizer_fn=None, 302 | normalizer_params=None, 303 | weights_initializer=ly.xavier_initializer_conv2d(), 304 | weight_decay_rate=1e-5, 305 | unit_num=0, data_format='NCHW'): 306 | assert len(ht) == num_blocks 307 | 308 | def norm_activ(tensor_in): 309 | if normalizer_fn is not None: 310 | _normalizer_params = normalizer_params or {} 311 | tensor_normed = normalizer_fn(tensor_in, **_normalizer_params) 312 | else: 313 | tensor_normed = tf.identity(tensor_in) 314 | if activation_fn is not None: 315 | tensor_normed = activation_fn(tensor_normed) 316 | 317 | return tensor_normed 318 | 319 | if dilate_rate != 1: 320 | stride = 1 321 | 322 | cell_block = functools.partial(mru_conv_block_v3, deconv=False) 323 | 324 | hts_new = [] 325 | inp = x 326 | with tf.variable_scope('mru_conv_unit_t_%d_layer_0' % unit_num): 327 | ht_new = cell_block(inp, ht[0], filter_depth, sn=sn, stride=stride, 328 | dilate=dilate_rate, 329 | activation_fn=activation_fn, 330 | normalizer_fn=normalizer_fn, 331 | normalizer_params=normalizer_params, 332 | weights_initializer=weights_initializer, 333 | data_format=data_format, 334 | weight_decay_rate=weight_decay_rate) 335 | hts_new.append(ht_new) 336 | inp = ht_new 337 | 338 | for i in range(1, num_blocks): 339 | if stride == 2: 340 | ht[i] = mean_pool(ht[i], data_format=data_format) 341 | with tf.variable_scope('mru_conv_unit_t_%d_layer_%d' % (unit_num, i)): 342 | ht_new = cell_block(inp, ht[i], filter_depth, sn=sn, stride=1, 343 | dilate=dilate_rate, 344 | activation_fn=activation_fn, 345 | normalizer_fn=normalizer_fn, 346 | normalizer_params=normalizer_params, 347 | weights_initializer=weights_initializer, 348 | data_format=data_format, 349 | weight_decay_rate=weight_decay_rate) 350 | hts_new.append(ht_new) 351 | inp = ht_new 352 | 353 | if hasattr(cell_block, 'func') and cell_block.func == mru_conv_block_v3 and last_unit: 354 | with tf.variable_scope('mru_conv_unit_last_norm'): 355 | hts_new[-1] = norm_activ(hts_new[-1]) 356 | 357 | return hts_new 358 | -------------------------------------------------------------------------------- /edgegan/nn/modules/linear.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as ly 3 | from tensorflow.python.ops import init_ops 4 | 5 | from 
.activation import activation_fn as _activation 6 | from .normalization import norm as _norm 7 | from .normalization import spectral_normed_weight 8 | 9 | 10 | def linear(input_, output_size, with_w=False, reuse=False, name=None): 11 | shape = input_.get_shape().as_list() 12 | 13 | with tf.variable_scope(name or "linear", reuse=reuse): 14 | try: 15 | matrix = tf.get_variable( 16 | "Matrix", [shape[1], output_size], 17 | tf.float32, 18 | tf.random_normal_initializer(stddev=0.02)) 19 | except ValueError as err: 20 | msg = "NOTE: Usually, this is due to an issue with the image \ 21 | dimensions. Did you correctly set '--crop' or '--input_height' or \ 22 | '--output_height'?" 23 | err.args = err.args + (msg, ) 24 | raise 25 | bias = tf.get_variable( 26 | "bias", [output_size], 27 | initializer=tf.constant_initializer(0.0)) 28 | if with_w: 29 | return tf.matmul(input_, matrix) + bias, matrix, bias 30 | else: 31 | return tf.matmul(input_, matrix) + bias 32 | 33 | 34 | def fully_connected(inputs, num_outputs, sn, activation_fn=None, 35 | normalizer_fn=None, normalizer_params=None, 36 | weights_initializer=ly.xavier_initializer(), 37 | weight_decay_rate=1e-6, 38 | biases_initializer=init_ops.zeros_initializer(), 39 | biases_regularizer=None, 40 | reuse=None, scope=None, SPECTRAL_NORM_UPDATE_OPS='spectral_norm_update_ops'): 41 | # TODO move regularizer definitions to model 42 | weights_regularizer = ly.l2_regularizer(weight_decay_rate) 43 | 44 | input_dim = inputs.get_shape().as_list()[1] 45 | 46 | with tf.variable_scope(scope, 'fully_connected', [inputs], reuse=reuse) as sc: 47 | inputs = tf.convert_to_tensor(inputs) 48 | 49 | weights = tf.get_variable(name="weights", shape=(input_dim, num_outputs), 50 | initializer=weights_initializer, regularizer=weights_regularizer, 51 | trainable=True, dtype=inputs.dtype.base_dtype) 52 | 53 | # Spectral Normalization 54 | if sn: 55 | weights = spectral_normed_weight( 56 | weights, num_iters=1, update_collection=(SPECTRAL_NORM_UPDATE_OPS)) 57 | 58 | linear_out = tf.matmul(inputs, weights) 59 | 60 | if biases_initializer is not None: 61 | biases = tf.get_variable(name="biases", shape=(num_outputs,), 62 | initializer=biases_initializer, regularizer=biases_regularizer, 63 | trainable=True, dtype=inputs.dtype.base_dtype) 64 | 65 | linear_out = tf.nn.bias_add(linear_out, biases) 66 | 67 | # Apply normalizer function / layer. 
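# order below mirrors conv2d2: bias, then the optional normalizer (invoked
# with activation_fn=None so normalization and activation stay separate),
# then the optional activation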
68 | if normalizer_fn is not None: 69 | normalizer_params = normalizer_params or {} 70 | linear_out = normalizer_fn( 71 | linear_out, activation_fn=None, **normalizer_params) 72 | 73 | if activation_fn is not None: 74 | linear_out = activation_fn(linear_out) 75 | 76 | return linear_out 77 | 78 | 79 | def mlp(input, out_dim, name, is_train, reuse, norm=None, activation=None, 80 | dtype=tf.float32, bias=True): 81 | with tf.variable_scope(name, reuse=reuse): 82 | _, n = input.get_shape() 83 | w = tf.get_variable('w', [n, out_dim], dtype, 84 | tf.random_normal_initializer(0.0, 0.02)) 85 | out = tf.matmul(input, w) 86 | if bias: 87 | b = tf.get_variable('b', [out_dim], 88 | initializer=tf.constant_initializer(0.0)) 89 | out = out + b 90 | out = _activation(out, activation) 91 | out = _norm(out, is_train, norm) 92 | return out 93 | -------------------------------------------------------------------------------- /edgegan/nn/modules/normalization.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import warnings 3 | 4 | __all__ = [ 5 | 'norm', 6 | 'spectral_normed_weight', 7 | ] 8 | 9 | 10 | def norm(input, is_train, norm='batch', 11 | epsilon=1e-5, momentum=0.9, name=None): 12 | assert norm in ['instance', 'batch', None] 13 | if norm == 'instance': 14 | with tf.variable_scope(name or 'instance_norm'): 15 | eps = epsilon 16 | mean, sigma = tf.nn.moments(input, [1, 2], keep_dims=True) 17 | normalized = (input - mean) / (tf.sqrt(sigma) + eps) 18 | out = normalized 19 | elif norm == 'batch': 20 | with tf.variable_scope(name or 'batch_norm'): 21 | out = tf.contrib.layers.batch_norm(input, 22 | decay=momentum, center=True, 23 | updates_collections=None, 24 | epsilon=epsilon, 25 | scale=True, is_training=True) 26 | else: 27 | out = input 28 | 29 | return out 30 | 31 | 32 | NO_OPS = 'NO_OPS' 33 | 34 | def _l2normalize(v, eps=1e-12): 35 | return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps) 36 | 37 | 38 | def spectral_normed_weight(W, u=None, num_iters=1, update_collection=None, with_sigma=False): 39 | W_shape = W.shape.as_list() 40 | W_reshaped = tf.reshape(W, [-1, W_shape[-1]]) 41 | if u is None: 42 | with tf.variable_scope(W.name.rsplit('/', 1)[0]) as sc: 43 | u = tf.get_variable( 44 | "u", [1, W_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False) 45 | 46 | def power_iteration(i, u_i, v_i): 47 | v_ip1 = _l2normalize(tf.matmul(u_i, tf.transpose(W_reshaped))) 48 | u_ip1 = _l2normalize(tf.matmul(v_ip1, W_reshaped)) 49 | return i + 1, u_ip1, v_ip1 50 | 51 | _, u_final, v_final = tf.while_loop( 52 | cond=lambda i, _1, _2: i < num_iters, 53 | body=power_iteration, 54 | loop_vars=(tf.constant(0, dtype=tf.int32), 55 | u, tf.zeros(dtype=tf.float32, shape=[1, W_reshaped.shape.as_list()[0]])) 56 | ) 57 | if update_collection is None: 58 | warnings.warn( 59 | 'Setting update_collection to None will make u update on every execution of W. This may be undesirable' 60 | '. 
Please consider using an update collection instead.') 61 | sigma = tf.matmul(tf.matmul(v_final, W_reshaped), 62 | tf.transpose(u_final))[0, 0] 63 | W_bar = W_reshaped / sigma 64 | with tf.control_dependencies([u.assign(u_final)]): 65 | W_bar = tf.reshape(W_bar, W_shape) 66 | else: 67 | sigma = tf.matmul(tf.matmul(v_final, W_reshaped), 68 | tf.transpose(u_final))[0, 0] 69 | W_bar = W_reshaped / sigma 70 | W_bar = tf.reshape(W_bar, W_shape) 71 | if update_collection != NO_OPS: 72 | tf.add_to_collection(update_collection, u.assign(u_final)) 73 | if with_sigma: 74 | return W_bar, sigma 75 | else: 76 | return W_bar 77 | -------------------------------------------------------------------------------- /edgegan/nn/modules/pooling.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def mean_pool(input, data_format): 5 | assert data_format == 'NCHW' 6 | output = tf.add_n( 7 | [input[:, :, ::2, ::2], input[:, :, 1::2, ::2], input[:, :, ::2, 1::2], input[:, :, 1::2, 1::2]]) / 4. 8 | return output 9 | -------------------------------------------------------------------------------- /edgegan/nn/modules/upsampling.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def upsample(input, data_format): 5 | assert data_format == 'NCHW' 6 | output = tf.concat([input, input, input, input], axis=1) 7 | output = tf.transpose(output, [0, 2, 3, 1]) 8 | output = tf.depth_to_space(output, 2) 9 | output = tf.transpose(output, [0, 3, 1, 2]) 10 | return output 11 | 12 | 13 | def upsample2(input, data_format): 14 | assert data_format == 'NHWC' 15 | output = tf.transpose(input, [0, 3, 1, 2]) 16 | output = tf.concat([output, output, output, output], axis=1) 17 | output = tf.transpose(output, [0, 2, 3, 1]) 18 | output = tf.depth_to_space(output, 2) 19 | return output 20 | -------------------------------------------------------------------------------- /edgegan/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from numpy.random import seed 6 | 7 | from edgegan.models import EdgeGAN 8 | from edgegan.utils import makedirs, pp 9 | from edgegan.utils.data import Dataset 10 | 11 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 12 | phase = 'test' 13 | 14 | seed(2333) 15 | tf.set_random_seed(6666) 16 | 17 | _FLAGS = tf.app.flags 18 | _FLAGS.DEFINE_string("gpu", "0", "Gpu ID") 19 | _FLAGS.DEFINE_string("name", "edgegan", "Folder for all outputs") 20 | _FLAGS.DEFINE_string("outputsroot", "outputs", "Outputs root") 21 | _FLAGS.DEFINE_integer("epoch", 100, "Epoch to train [100]") 22 | _FLAGS.DEFINE_float("learning_rate", 0.0002, "") 23 | _FLAGS.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]") 24 | _FLAGS.DEFINE_integer("batch_size", 64, "The size of batch images [64]") 25 | _FLAGS.DEFINE_integer( 26 | "input_height", 64, "The size of image to use (will be center cropped). [64]") 27 | _FLAGS.DEFINE_integer( 28 | "input_width", 128, "The size of image to use (will be center cropped). If None, same value as input_height [None]") 29 | _FLAGS.DEFINE_integer("output_height", 64, 30 | "The size of the output images to produce [64]") 31 | _FLAGS.DEFINE_integer("output_width", 128, 32 | "The size of the output images to produce. 
If None, same value as output_height [None]") 33 | _FLAGS.DEFINE_string("dataset", "class14", "") 34 | _FLAGS.DEFINE_string("input_fname_pattern", "*png", 35 | "Glob pattern of filename of input images [*]") 36 | _FLAGS.DEFINE_string("checkpoint_dir", None, "") 37 | _FLAGS.DEFINE_string("logdir", None, "") 38 | _FLAGS.DEFINE_string("dataroot", "./data", "Root directory of dataset [data]") 39 | _FLAGS.DEFINE_string("test_output_dir", "test_output", 40 | "Directory name to save the image samples [samples]") 41 | _FLAGS.DEFINE_boolean("crop", False, "") 42 | 43 | 44 | # setting of testing 45 | _FLAGS.DEFINE_string("output_combination", "full", 46 | "The combination of output image: full(input+output), inputL_outputR(the left of input combine the right of output),outputL_inputR, outputR") 47 | 48 | # # multi class 49 | _FLAGS.DEFINE_boolean("multiclasses", True, "if use focal loss") 50 | _FLAGS.DEFINE_integer("num_classes", 14, "num of classes") 51 | 52 | _FLAGS.DEFINE_string("type", "gpwgan", "gan type: [dcgan | wgan | gpwgan]") 53 | _FLAGS.DEFINE_string("optim", "rmsprop", "optimizer type: [adam | rmsprop]") 54 | _FLAGS.DEFINE_string("model", "old", "which base model(G and D): [old | new]") 55 | 56 | 57 | _FLAGS.DEFINE_boolean("if_resnet_e", True, "if use resnet for E") 58 | _FLAGS.DEFINE_boolean("if_resnet_g", False, "if use resnet for G") 59 | _FLAGS.DEFINE_boolean("if_resnet_d", False, "if use resnet for origin D") 60 | _FLAGS.DEFINE_string("E_norm", "instance", 61 | "normalization options:[instance, batch, norm]") 62 | _FLAGS.DEFINE_string("G_norm", "instance", 63 | "normalization options:[instance, batch, norm]") 64 | _FLAGS.DEFINE_string("D_norm", "instance", 65 | "normalization options:[instance, batch, norm]") 66 | 67 | FLAGS = _FLAGS.FLAGS 68 | 69 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 70 | 71 | 72 | def subdirs(root): 73 | return [name for name in os.listdir(root) 74 | if os.path.isdir(os.path.join(root, name))] 75 | 76 | 77 | def make_outputs_dir(flags): 78 | makedirs(os.path.join(flags.test_output_dir, flags.dataset)) 79 | for path in subdirs(os.path.join(flags.dataroot, flags.dataset, phase)): 80 | makedirs(os.path.join(flags.test_output_dir, flags.dataset, path)) 81 | 82 | 83 | def update_flags(flags): 84 | if flags.input_width is None: 85 | flags.input_width = flags.input_height 86 | if flags.output_width is None: 87 | flags.output_width = flags.output_height 88 | 89 | flags.batch_size = 1 90 | 91 | path = os.path.join(flags.outputsroot, flags.name) 92 | setattr(flags, 'checkpoint_dir', os.path.join(path, 'checkpoints')) 93 | setattr(flags, 'logdir', os.path.join(path, 'logs')) 94 | setattr(flags, 'test_output_dir', os.path.join(path, 'test_output')) 95 | 96 | return flags 97 | 98 | 99 | def create_dataset(flags): 100 | dataset_config = { 101 | 'input_height': flags.input_height, 102 | 'input_width': flags.input_width, 103 | 'output_height': flags.output_height, 104 | 'output_width': flags.output_width, 105 | 'crop': flags.crop, 106 | 'grayscale': False, 107 | } 108 | return Dataset( 109 | flags.dataroot, flags.dataset, 110 | flags.train_size, 1, 111 | dataset_config, None, phase 112 | ) 113 | 114 | 115 | def main(_): 116 | flags = update_flags(FLAGS) 117 | pp.pprint(flags.__flags) 118 | make_outputs_dir(flags) 119 | 120 | run_config = tf.ConfigProto() 121 | run_config.gpu_options.allow_growth = True 122 | 123 | with tf.Session(config=run_config) as sess: 124 | edgegan_model = EdgeGAN(sess, flags, None) 125 | edgegan_model.dataset = create_dataset(flags) 126 | 
edgegan_model.test() 127 | 128 | 129 | if __name__ == '__main__': 130 | tf.app.run() 131 | -------------------------------------------------------------------------------- /edgegan/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | from numpy.random import seed 7 | 8 | from edgegan.models import EdgeGAN 9 | from edgegan.utils import makedirs, pp 10 | from edgegan.utils.data import Dataset 11 | 12 | 13 | _FLAGS = tf.app.flags 14 | _FLAGS.DEFINE_string("gpu", "0", "Gpu ID") 15 | _FLAGS.DEFINE_string("name", "edgegan", "Folder for all outputs") 16 | _FLAGS.DEFINE_string("outputsroot", "outputs", "Outputs root") 17 | _FLAGS.DEFINE_integer("epoch", 100, "Epoch to train [100]") 18 | _FLAGS.DEFINE_float("learning_rate", 0.0002, "") 19 | _FLAGS.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]") 20 | _FLAGS.DEFINE_integer("batch_size", 64, "The size of batch images [64]") 21 | _FLAGS.DEFINE_integer( 22 | "input_height", 64, "The size of image to use (will be center cropped). [64]") 23 | _FLAGS.DEFINE_integer( 24 | "input_width", 128, "The size of image to use (will be center cropped). If None, same value as input_height [None]") 25 | _FLAGS.DEFINE_integer("output_height", 64, 26 | "The size of the output images to produce [64]") 27 | _FLAGS.DEFINE_integer("output_width", 128, 28 | "The size of the output images to produce. If None, same value as output_height [None]") 29 | _FLAGS.DEFINE_string("dataset", "class14", "") 30 | _FLAGS.DEFINE_string("input_fname_pattern", "*.png", 31 | "Glob pattern of filename of input images [*]") 32 | _FLAGS.DEFINE_string("checkpoint_dir", None, "") 33 | _FLAGS.DEFINE_string("logdir", None, "") 34 | _FLAGS.DEFINE_string("dataroot", "./data", "Root directory of dataset [data]") 35 | _FLAGS.DEFINE_integer("save_checkpoint_frequency", 500, 36 | "frequency for saving checkpoint") 37 | _FLAGS.DEFINE_boolean("crop", False, "") 38 | 39 | 40 | # weight of loss 41 | _FLAGS.DEFINE_float("stage1_zl_loss", 10.0, "weight of z l1 loss") 42 | 43 | # multi class 44 | _FLAGS.DEFINE_boolean("multiclasses", True, "whether the dataset is organized into per-class subfolders") 45 | _FLAGS.DEFINE_integer("num_classes", 14, "num of classes") 46 | _FLAGS.DEFINE_string("SPECTRAL_NORM_UPDATE_OPS", 47 | "spectral_norm_update_ops", "") 48 | 49 | _FLAGS.DEFINE_boolean("if_resnet_e", True, "if use resnet for E") 50 | _FLAGS.DEFINE_boolean("if_resnet_g", False, "if use resnet for G") 51 | _FLAGS.DEFINE_boolean("if_resnet_d", False, "if use resnet for D") 52 | _FLAGS.DEFINE_float("lambda_gp", 10.0, "") 53 | 54 | _FLAGS.DEFINE_string("E_norm", "instance", 55 | "normalization options: [instance, batch, none]") 56 | _FLAGS.DEFINE_string("G_norm", "instance", 57 | "normalization options: [instance, batch, none]") 58 | _FLAGS.DEFINE_string("D_norm", "instance", 59 | "normalization options: [instance, batch, none]") 60 | 61 | 62 | _FLAGS.DEFINE_boolean("use_image_discriminator", True, 63 | "True to use the image patch discriminator; image_dis_size controls its input size") 64 | _FLAGS.DEFINE_integer("image_dis_size", 128, "The size of input for image discriminator") 65 | _FLAGS.DEFINE_boolean("use_edge_discriminator", True, 66 | "True to use the edge patch discriminator; edge_dis_size controls its input size") 67 | _FLAGS.DEFINE_integer("edge_dis_size", 128, "The size of input for edge discriminator") 68 | _FLAGS.DEFINE_float("joint_dweight", 1.0, 69 | "weight of 
origin discriminative loss") 70 | _FLAGS.DEFINE_float("image_dweight", 1.0, 71 | "weight of image discriminative loss, is ineffective when use_image_discriminator is false") 72 | _FLAGS.DEFINE_float("edge_dweight", 1.0, 73 | "weight of edge discriminative loss, is ineffective when use_edge_discriminator is false") 74 | _FLAGS.DEFINE_integer("z_dim", 100, "dimension of random vector z") 75 | FLAGS = _FLAGS.FLAGS 76 | 77 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 78 | 79 | def make_outputs_dir(flags): 80 | makedirs(flags.outputsroot) 81 | makedirs(flags.checkpoint_dir) 82 | makedirs(flags.logdir) 83 | 84 | 85 | def update_flags(flags): 86 | if flags.input_width is None: 87 | flags.input_width = flags.input_height 88 | if flags.output_width is None: 89 | flags.output_width = flags.output_height 90 | 91 | if not flags.multiclasses: 92 | flags.num_classes = None 93 | 94 | path = os.path.join(flags.outputsroot, flags.name) 95 | setattr(flags, 'checkpoint_dir', os.path.join(path, 'checkpoints')) 96 | setattr(flags, 'logdir', os.path.join(path, 'logs')) 97 | 98 | return flags 99 | 100 | def save_flags(flags): 101 | path = os.path.join(flags.outputsroot, flags.name) 102 | 103 | flag_dict = flags.flag_values_dict() 104 | with open(os.path.join(path, 'flags.json'), 'w') as f: 105 | json.dump(flag_dict, f, indent=4) 106 | 107 | return flags 108 | 109 | def main(_): 110 | phase = 'train' 111 | flags = update_flags(FLAGS) 112 | pp.pprint(flags.__flags) 113 | make_outputs_dir(flags) 114 | save_flags(flags) 115 | 116 | run_config = tf.ConfigProto() 117 | run_config.gpu_options.allow_growth = True 118 | dataset_config = { 119 | 'input_height': flags.input_height, 120 | 'input_width': flags.input_width, 121 | 'output_height': flags.output_height, 122 | 'output_width': flags.output_width, 123 | 'crop': flags.crop, 124 | 'grayscale': False, 125 | 'z_dim': flags.z_dim, 126 | } 127 | 128 | with tf.Session(config=run_config) as sess: 129 | dataset = Dataset( 130 | flags.dataroot, flags.dataset, 131 | flags.train_size, flags.batch_size, 132 | dataset_config, flags.num_classes, phase) 133 | edgegan_model = EdgeGAN(sess, flags, dataset, z_dim=flags.z_dim) 134 | edgegan_model.train() 135 | 136 | 137 | if __name__ == '__main__': 138 | tf.app.run() 139 | -------------------------------------------------------------------------------- /edgegan/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | -------------------------------------------------------------------------------- /edgegan/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | -------------------------------------------------------------------------------- /edgegan/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | 7 | from edgegan.utils import get_image 8 | 9 | 10 | def extension_match_recursive(root, exts): 11 | result = [] 12 | for ext in exts: 13 | paths = [str(p) for p in Path(root).rglob(ext)] 14 | result.extend(paths) 15 | return result 16 | 17 | 18 | class Dataset(): 19 | def __init__(self, dataroot, name, size, batchsize, config, num_classes=None, phase='train'): 20 | assert phase in ['train', 'test', ] 21 | self.batchsize = batchsize 22 | self.num_classes = num_classes 23 | self.config = config 24 | self.phase = phase 25 | if phase 
== 'train': 26 | if num_classes is not None: 27 | self.data = [] 28 | for i in range(num_classes): 29 | for ext in ['*.png', '*.jpg']: 30 | data_path = os.path.join( 31 | dataroot, name, phase, str(i), ext) 32 | self.data.extend(glob(data_path)) 33 | else: 34 | data_path = os.path.join( 35 | dataroot, name, phase, '*.png') 36 | self.data = glob(data_path) 37 | else: 38 | data_path = os.path.join(dataroot, name, phase) 39 | self.data = extension_match_recursive( 40 | data_path, 41 | ['*.png', '*.jpg'] 42 | ) 43 | self.data = sorted(self.data) 44 | 45 | if len(self.data) == 0: 46 | raise Exception("[!] No data found in '" + data_path + "'") 47 | if len(self.data) < self.batchsize: 48 | raise Exception( 49 | "[!] Entire dataset size is less than the configured batch_size") 50 | self.size = min(len(self.data), size) 51 | 52 | def shuffle(self): 53 | np.random.shuffle(self.data) 54 | 55 | def __len__(self): 56 | return self.size // self.batchsize 57 | 58 | def __getitem__(self, idx): 59 | filenames = self.data[idx * self.batchsize: (idx+1)*self.batchsize] 60 | batch = [ 61 | get_image(filename, 62 | input_height=self.config['input_height'], 63 | input_width=self.config['input_width'], 64 | resize_height=self.config['output_height'], 65 | resize_width=self.config['output_width'], 66 | crop=self.config['crop'], 67 | grayscale=self.config['grayscale']) for filename in filenames] 68 | 69 | batch_images = np.array(batch).astype(np.float32) 70 | 71 | if self.phase == 'train': 72 | batch_z = np.random.normal( 73 | size=(self.batchsize, self.config['z_dim'])) 74 | 75 | if self.num_classes is not None: 76 | def get_class(filePath): 77 | # the class id is the name of the image's parent directory 78 | parent = os.path.dirname(filePath) 79 | return int(os.path.basename(parent)) 80 | batch_classes = [get_class(batch_file) 81 | for batch_file in filenames] 82 | batch_classes = np.array( 83 | batch_classes).reshape((self.batchsize, 1)) 84 | batch_z = np.concatenate((batch_z, batch_classes), axis=1) 85 | 86 | if self.phase == 'test': 87 | assert batch_images.shape[0] == len(filenames) 88 | 89 | return (batch_images, batch_z, filenames) if self.phase == 'train' else (batch_images, filenames) 90 | -------------------------------------------------------------------------------- /edgegan/utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import math 4 | import pprint 5 | import numpy as np 6 | import imageio 7 | import scipy.misc 8 | from PIL import Image 9 | 10 | import tensorflow as tf 11 | import tensorflow.contrib.slim as slim 12 | 13 | 14 | def makedirs(path): 15 | if not os.path.exists(path): 16 | os.makedirs(path) 17 | 18 | 19 | 20 | checksum_path = 'checksum' 21 | 22 | pp = pprint.PrettyPrinter() 23 | 24 | 25 | def get_stddev(x, k_h, k_w): return 1 / \ 26 | math.sqrt(k_w * k_h * x.get_shape()[-1]) 27 | 28 | 29 | def image_manifold_size(num_images): 30 | manifold_h = int(np.floor(np.sqrt(num_images))) 31 | manifold_w = int(np.ceil(np.sqrt(num_images))) 32 | assert manifold_h * manifold_w == num_images 33 | return manifold_h, manifold_w 34 | 35 | 36 | def show_all_variables(): 37 | model_vars = tf.trainable_variables() 38 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 39 | 40 | 41 | def get_image(image_path, 42 | input_height, 43 | input_width, 44 | resize_height=64, 45 | resize_width=64, 46 | crop=True, 47 | grayscale=False): 48 | image = imread(image_path, grayscale) 49 | return 
transform(image, input_height, input_width, resize_height, 50 | resize_width, crop) 51 | 52 | 53 | def save_images(images, size, image_path): 54 | return imsave(inverse_transform(images), size, image_path) 55 | 56 | 57 | # def imread(path, grayscale=False): 58 | # assert grayscale == False 59 | # # return pyplot.imread(path).astype(np.float) * 255 60 | # return imageio.imread(path) 61 | 62 | 63 | def merge_images(images, size): 64 | return inverse_transform(images) 65 | 66 | 67 | def merge(images, size): 68 | h, w = images.shape[1], images.shape[2] 69 | if (images.shape[3] in (3, 4)): 70 | c = images.shape[3] 71 | img = np.zeros((h * size[0], w * size[1], c)) 72 | for idx, image in enumerate(images): 73 | i = idx % size[1] 74 | j = idx // size[1] 75 | img[j * h:j * h + h, i * w:i * w + w, :] = image 76 | return img 77 | elif images.shape[3] == 1: 78 | img = np.zeros((h * size[0], w * size[1])) 79 | for idx, image in enumerate(images): 80 | i = idx % size[1] 81 | j = idx // size[1] 82 | img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0] 83 | return img 84 | else: 85 | raise ValueError('in merge(images,size) images parameter ' 86 | 'must have dimensions: HxW or HxWx3 or HxWx4') 87 | 88 | 89 | # def imsave(images, size, path): 90 | # image = np.squeeze(merge(images, size)) 91 | # return imageio.imsave(path, (image * 255).astype(np.uint8)) 92 | 93 | 94 | # def imresize(img, size): 95 | # size[0], size[1] = size[1], size[0] 96 | # img = Image.fromarray(img) 97 | # resized = img.resize(size, Image.BILINEAR) 98 | # return np.array(resized) 99 | 100 | 101 | # def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64): 102 | # if crop_w is None: 103 | # crop_w = crop_h 104 | # h, w = x.shape[:2] 105 | # j = int(round((h - crop_h) / 2.)) 106 | # i = int(round((w - crop_w) / 2.)) 107 | # return imresize(x[j:j + crop_h, i:i + crop_w], 108 | # [resize_h, resize_w]) 109 | 110 | 111 | # def transform(image, 112 | # input_height, 113 | # input_width, 114 | # resize_height=64, 115 | # resize_width=64, 116 | # crop=True): 117 | # if crop: 118 | # cropped_image = center_crop(image, input_height, input_width, 119 | # resize_height, resize_width) 120 | # else: 121 | # cropped_image = imresize(image, 122 | # [resize_height, resize_width]) 123 | # return cropped_image.astype(np.float32) / 127.5 - 1. 124 | 125 | 126 | def imread(path, grayscale=False): 127 | if (grayscale): 128 | return scipy.misc.imread(path, flatten=True).astype(np.float) 129 | else: 130 | return scipy.misc.imread(path).astype(np.float) 131 | 132 | 133 | def imsave(images, size, path): 134 | image = np.squeeze(merge(images, size)) 135 | return scipy.misc.imsave(path, image) 136 | 137 | 138 | def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64): 139 | if crop_w is None: 140 | crop_w = crop_h 141 | h, w = x.shape[:2] 142 | j = int(round((h - crop_h) / 2.)) 143 | i = int(round((w - crop_w) / 2.)) 144 | return scipy.misc.imresize(x[j:j + crop_h, i:i + crop_w], 145 | [resize_h, resize_w]) 146 | 147 | 148 | def transform(image, 149 | input_height, 150 | input_width, 151 | resize_height=64, 152 | resize_width=64, 153 | crop=True): 154 | if crop: 155 | cropped_image = center_crop(image, input_height, input_width, 156 | resize_height, resize_width) 157 | else: 158 | cropped_image = scipy.misc.imresize(image, 159 | [resize_height, resize_width]) 160 | return np.array(cropped_image) / 127.5 - 1. 161 | 162 | 163 | def inverse_transform(images): 164 | return (images + 1.) / 2. 
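# Pixel-range round trip (sketch): transform() maps 8-bit pixels to [-1, 1]
# via x / 127.5 - 1, matching the generators' tanh outputs, and
# inverse_transform() maps them back to [0, 1] for imsave, e.g.
#   inverse_transform(transform(img, 64, 128, 64, 128, crop=False))
# recovers img rescaled to [0, 1] (up to the bilinear resize).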
165 | 166 | 167 | def to_json(output_path, *layers): 168 | with open(output_path, "w") as layer_f: 169 | lines = "" 170 | for w, b, bn in layers: 171 | layer_idx = w.name.split('/')[0].split('h')[1] 172 | 173 | B = b.eval() 174 | 175 | if "lin/" in w.name: 176 | W = w.eval() 177 | depth = W.shape[1] 178 | else: 179 | W = np.rollaxis(w.eval(), 2, 0) 180 | depth = W.shape[0] 181 | 182 | biases = { 183 | "sy": 1, 184 | "sx": 1, 185 | "depth": depth, 186 | "w": ['%.2f' % elem for elem in list(B)] 187 | } 188 | if bn is not None: 189 | gamma = bn.gamma.eval() 190 | beta = bn.beta.eval() 191 | 192 | gamma = { 193 | "sy": 1, 194 | "sx": 1, 195 | "depth": depth, 196 | "w": ['%.2f' % elem for elem in list(gamma)] 197 | } 198 | beta = { 199 | "sy": 1, 200 | "sx": 1, 201 | "depth": depth, 202 | "w": ['%.2f' % elem for elem in list(beta)] 203 | } 204 | else: 205 | gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []} 206 | beta = {"sy": 1, "sx": 1, "depth": 0, "w": []} 207 | 208 | if "lin/" in w.name: 209 | fs = [] 210 | for w in W.T: 211 | fs.append({ 212 | "sy": 1, 213 | "sx": 1, 214 | "depth": W.shape[0], 215 | "w": ['%.2f' % elem for elem in list(w)] 216 | }) 217 | 218 | lines += """ 219 | var layer_%s = { 220 | "layer_type": "fc", 221 | "sy": 1, "sx": 1, 222 | "out_sx": 1, "out_sy": 1, 223 | "stride": 1, "pad": 0, 224 | "out_depth": %s, "in_depth": %s, 225 | "biases": %s, 226 | "gamma": %s, 227 | "beta": %s, 228 | "filters": %s 229 | };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, 230 | gamma, beta, fs) 231 | else: 232 | fs = [] 233 | for w_ in W: 234 | fs.append({ 235 | "sy": 236 | 5, 237 | "sx": 238 | 5, 239 | "depth": 240 | W.shape[3], 241 | "w": ['%.2f' % elem for elem in list(w_.flatten())] 242 | }) 243 | 244 | lines += """ 245 | var layer_%s = { 246 | "layer_type": "deconv", 247 | "sy": 5, "sx": 5, 248 | "out_sx": %s, "out_sy": %s, 249 | "stride": 2, "pad": 1, 250 | "out_depth": %s, "in_depth": %s, 251 | "biases": %s, 252 | "gamma": %s, 253 | "beta": %s, 254 | "filters": %s 255 | };""" % (layer_idx, 2**(int(layer_idx) + 2), 2**(int(layer_idx) + 2), 256 | W.shape[0], W.shape[3], biases, gamma, beta, fs) 257 | layer_f.write(" ".join(lines.replace("'", "").split())) 258 | 259 | 260 | def make_gif(images, fname, duration=2, true_image=False): 261 | import moviepy.editor as mpy 262 | 263 | def make_frame(t): 264 | try: 265 | x = images[int(len(images) / duration * t)] 266 | except: 267 | x = images[-1] 268 | 269 | if true_image: 270 | return x.astype(np.uint8) 271 | else: 272 | return ((x + 1) / 2 * 255).astype(np.uint8) 273 | 274 | clip = mpy.VideoClip(make_frame, duration=duration) 275 | clip.write_gif(fname, fps=len(images) / duration) 276 | -------------------------------------------------------------------------------- /images/dataset_example/test/14809.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/images/dataset_example/test/14809.png -------------------------------------------------------------------------------- /images/dataset_example/test/14810.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/images/dataset_example/test/14810.png -------------------------------------------------------------------------------- /images/dataset_example/test/14811.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/images/dataset_example/test/14811.png -------------------------------------------------------------------------------- /images/dataset_example/test/14812.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/images/dataset_example/test/14812.png -------------------------------------------------------------------------------- /images/dataset_example/train/60975.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/images/dataset_example/train/60975.png -------------------------------------------------------------------------------- /images/dataset_example/train/60981.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/images/dataset_example/train/60981.png -------------------------------------------------------------------------------- /images/dataset_example/train/60987.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/images/dataset_example/train/60987.png -------------------------------------------------------------------------------- /images/dataset_example/train/60991.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/images/dataset_example/train/60991.png -------------------------------------------------------------------------------- /images/dataset_example/train/60994.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-imsl/EdgeGAN/7867a5a35ff860f692cf9a167a22742cb424407e/images/dataset_example/train/60994.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.1 2 | imageio==2.6.1 3 | tensorflow_gpu==1.14.0 4 | scipy==1.2.2 5 | moviepy==1.0.3 6 | Pillow==7.2.0 7 | --------------------------------------------------------------------------------