├── .gitignore
├── LICENSE
├── README.md
├── Test
├── Set14
│ ├── baboon.bmp
│ ├── barbara.bmp
│ ├── bridge.bmp
│ ├── coastguard.bmp
│ ├── comic.bmp
│ ├── face.bmp
│ ├── flowers.bmp
│ ├── foreman.bmp
│ ├── lenna.bmp
│ ├── man.bmp
│ ├── monarch.bmp
│ ├── pepper.bmp
│ ├── ppt3.bmp
│ └── zebra.bmp
└── Set5
│ ├── baby_GT.bmp
│ ├── bird_GT.bmp
│ ├── butterfly_GT.bmp
│ ├── head_GT.bmp
│ └── woman_GT.bmp
├── Train
├── t1.bmp
├── t10.bmp
├── t11.bmp
├── t12.bmp
├── t13.bmp
├── t14.bmp
├── t15.bmp
├── t16.bmp
├── t17.bmp
├── t18.bmp
├── t19.bmp
├── t2.bmp
├── t20.bmp
├── t21.bmp
├── t22.bmp
├── t23.bmp
├── t24.bmp
├── t25.bmp
├── t26.bmp
├── t27.bmp
├── t28.bmp
├── t29.bmp
├── t3.bmp
├── t30.bmp
├── t31.bmp
├── t32.bmp
├── t33.bmp
├── t34.bmp
├── t35.bmp
├── t36.bmp
├── t37.bmp
├── t38.bmp
├── t39.bmp
├── t4.bmp
├── t40.bmp
├── t42.bmp
├── t43.bmp
├── t44.bmp
├── t45.bmp
├── t46.bmp
├── t47.bmp
├── t48.bmp
├── t49.bmp
├── t5.bmp
├── t50.bmp
├── t51.bmp
├── t52.bmp
├── t53.bmp
├── t54.bmp
├── t55.bmp
├── t56.bmp
├── t57.bmp
├── t58.bmp
├── t59.bmp
├── t6.bmp
├── t60.bmp
├── t61.bmp
├── t62.bmp
├── t63.bmp
├── t64.bmp
├── t65.bmp
├── t66.bmp
├── t7.bmp
├── t8.bmp
├── t9.bmp
├── tt1.bmp
├── tt10.bmp
├── tt12.bmp
├── tt13.bmp
├── tt14.bmp
├── tt15.bmp
├── tt16.bmp
├── tt17.bmp
├── tt18.bmp
├── tt19.bmp
├── tt2.bmp
├── tt20.bmp
├── tt21.bmp
├── tt22.bmp
├── tt23.bmp
├── tt24.bmp
├── tt25.bmp
├── tt26.bmp
├── tt27.bmp
├── tt3.bmp
├── tt4.bmp
├── tt5.bmp
├── tt6.bmp
├── tt7.bmp
├── tt8.bmp
└── tt9.bmp
├── expand_data.py
├── main.py
├── model.py
├── result
├── bicubic.png
├── fsrcnn.png
└── original.png
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *.cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2016 Drake Levy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FSRCNN-TensorFlow
2 | TensorFlow implementation of the Fast Super-Resolution Convolutional Neural Network (FSRCNN). This implements two models: FSRCNN which is more accurate but slower and FSRCNN-s which is faster but less accurate. Based on this [project](http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html).
3 |
4 | ## Prerequisites
5 | * Python 2.7
6 | * TensorFlow
7 | * Scipy version > 0.18
8 | * h5py
9 | * PIL
10 |
11 | ## Usage
12 | For training: `python main.py`
13 |
14 | For testing: `python main.py --train False`
15 |
16 | To use FSRCNN-s instead of FSRCNN: `python main.py --fast True`
17 |
18 | Can specify epochs, learning rate, data directory, etc:
19 |
20 | `python main.py --epochs 10 --learning_rate 0.0001 --data_dir Train`
21 |
22 | Check `main.py` for all the possible flags
23 |
24 | Also includes script `expand_data.py` which scales and rotates all the images in the specified training set to expand it
25 |
26 | ## Result
27 |
28 | Original butterfly image:
29 |
30 | 
31 |
32 |
33 | Bicubic interpolated image:
34 |
35 | 
36 |
37 |
38 | Super-resolved image:
39 |
40 | 
41 |
42 | ## TODO
43 |
44 | * Add RGB support (Increase each layer depth to 3)
45 | * Speed up pre-processing for large datasets
46 | * Set learning rate for deconvolutional layer to 1e-4 (vs 1e-3 for the rest)
47 |
48 | ## References
49 |
50 | * [tegg89/SRCNN-Tensorflow](https://github.com/tegg89/SRCNN-Tensorflow)
51 |
52 | * [liliumao/Tensorflow-srcnn](https://github.com/liliumao/Tensorflow-srcnn)
53 |
54 | * [carpedm20/DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow)
55 |
--------------------------------------------------------------------------------
/Test/Set14/baboon.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/baboon.bmp
--------------------------------------------------------------------------------
/Test/Set14/barbara.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/barbara.bmp
--------------------------------------------------------------------------------
/Test/Set14/bridge.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/bridge.bmp
--------------------------------------------------------------------------------
/Test/Set14/coastguard.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/coastguard.bmp
--------------------------------------------------------------------------------
/Test/Set14/comic.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/comic.bmp
--------------------------------------------------------------------------------
/Test/Set14/face.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/face.bmp
--------------------------------------------------------------------------------
/Test/Set14/flowers.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/flowers.bmp
--------------------------------------------------------------------------------
/Test/Set14/foreman.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/foreman.bmp
--------------------------------------------------------------------------------
/Test/Set14/lenna.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/lenna.bmp
--------------------------------------------------------------------------------
/Test/Set14/man.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/man.bmp
--------------------------------------------------------------------------------
/Test/Set14/monarch.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/monarch.bmp
--------------------------------------------------------------------------------
/Test/Set14/pepper.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/pepper.bmp
--------------------------------------------------------------------------------
/Test/Set14/ppt3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/ppt3.bmp
--------------------------------------------------------------------------------
/Test/Set14/zebra.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/zebra.bmp
--------------------------------------------------------------------------------
/Test/Set5/baby_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set5/baby_GT.bmp
--------------------------------------------------------------------------------
/Test/Set5/bird_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set5/bird_GT.bmp
--------------------------------------------------------------------------------
/Test/Set5/butterfly_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set5/butterfly_GT.bmp
--------------------------------------------------------------------------------
/Test/Set5/head_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set5/head_GT.bmp
--------------------------------------------------------------------------------
/Test/Set5/woman_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set5/woman_GT.bmp
--------------------------------------------------------------------------------
/Train/t1.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t1.bmp
--------------------------------------------------------------------------------
/Train/t10.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t10.bmp
--------------------------------------------------------------------------------
/Train/t11.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t11.bmp
--------------------------------------------------------------------------------
/Train/t12.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t12.bmp
--------------------------------------------------------------------------------
/Train/t13.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t13.bmp
--------------------------------------------------------------------------------
/Train/t14.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t14.bmp
--------------------------------------------------------------------------------
/Train/t15.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t15.bmp
--------------------------------------------------------------------------------
/Train/t16.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t16.bmp
--------------------------------------------------------------------------------
/Train/t17.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t17.bmp
--------------------------------------------------------------------------------
/Train/t18.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t18.bmp
--------------------------------------------------------------------------------
/Train/t19.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t19.bmp
--------------------------------------------------------------------------------
/Train/t2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t2.bmp
--------------------------------------------------------------------------------
/Train/t20.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t20.bmp
--------------------------------------------------------------------------------
/Train/t21.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t21.bmp
--------------------------------------------------------------------------------
/Train/t22.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t22.bmp
--------------------------------------------------------------------------------
/Train/t23.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t23.bmp
--------------------------------------------------------------------------------
/Train/t24.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t24.bmp
--------------------------------------------------------------------------------
/Train/t25.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t25.bmp
--------------------------------------------------------------------------------
/Train/t26.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t26.bmp
--------------------------------------------------------------------------------
/Train/t27.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t27.bmp
--------------------------------------------------------------------------------
/Train/t28.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t28.bmp
--------------------------------------------------------------------------------
/Train/t29.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t29.bmp
--------------------------------------------------------------------------------
/Train/t3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t3.bmp
--------------------------------------------------------------------------------
/Train/t30.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t30.bmp
--------------------------------------------------------------------------------
/Train/t31.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t31.bmp
--------------------------------------------------------------------------------
/Train/t32.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t32.bmp
--------------------------------------------------------------------------------
/Train/t33.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t33.bmp
--------------------------------------------------------------------------------
/Train/t34.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t34.bmp
--------------------------------------------------------------------------------
/Train/t35.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t35.bmp
--------------------------------------------------------------------------------
/Train/t36.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t36.bmp
--------------------------------------------------------------------------------
/Train/t37.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t37.bmp
--------------------------------------------------------------------------------
/Train/t38.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t38.bmp
--------------------------------------------------------------------------------
/Train/t39.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t39.bmp
--------------------------------------------------------------------------------
/Train/t4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t4.bmp
--------------------------------------------------------------------------------
/Train/t40.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t40.bmp
--------------------------------------------------------------------------------
/Train/t42.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t42.bmp
--------------------------------------------------------------------------------
/Train/t43.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t43.bmp
--------------------------------------------------------------------------------
/Train/t44.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t44.bmp
--------------------------------------------------------------------------------
/Train/t45.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t45.bmp
--------------------------------------------------------------------------------
/Train/t46.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t46.bmp
--------------------------------------------------------------------------------
/Train/t47.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t47.bmp
--------------------------------------------------------------------------------
/Train/t48.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t48.bmp
--------------------------------------------------------------------------------
/Train/t49.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t49.bmp
--------------------------------------------------------------------------------
/Train/t5.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t5.bmp
--------------------------------------------------------------------------------
/Train/t50.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t50.bmp
--------------------------------------------------------------------------------
/Train/t51.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t51.bmp
--------------------------------------------------------------------------------
/Train/t52.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t52.bmp
--------------------------------------------------------------------------------
/Train/t53.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t53.bmp
--------------------------------------------------------------------------------
/Train/t54.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t54.bmp
--------------------------------------------------------------------------------
/Train/t55.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t55.bmp
--------------------------------------------------------------------------------
/Train/t56.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t56.bmp
--------------------------------------------------------------------------------
/Train/t57.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t57.bmp
--------------------------------------------------------------------------------
/Train/t58.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t58.bmp
--------------------------------------------------------------------------------
/Train/t59.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t59.bmp
--------------------------------------------------------------------------------
/Train/t6.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t6.bmp
--------------------------------------------------------------------------------
/Train/t60.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t60.bmp
--------------------------------------------------------------------------------
/Train/t61.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t61.bmp
--------------------------------------------------------------------------------
/Train/t62.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t62.bmp
--------------------------------------------------------------------------------
/Train/t63.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t63.bmp
--------------------------------------------------------------------------------
/Train/t64.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t64.bmp
--------------------------------------------------------------------------------
/Train/t65.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t65.bmp
--------------------------------------------------------------------------------
/Train/t66.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t66.bmp
--------------------------------------------------------------------------------
/Train/t7.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t7.bmp
--------------------------------------------------------------------------------
/Train/t8.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t8.bmp
--------------------------------------------------------------------------------
/Train/t9.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t9.bmp
--------------------------------------------------------------------------------
/Train/tt1.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt1.bmp
--------------------------------------------------------------------------------
/Train/tt10.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt10.bmp
--------------------------------------------------------------------------------
/Train/tt12.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt12.bmp
--------------------------------------------------------------------------------
/Train/tt13.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt13.bmp
--------------------------------------------------------------------------------
/Train/tt14.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt14.bmp
--------------------------------------------------------------------------------
/Train/tt15.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt15.bmp
--------------------------------------------------------------------------------
/Train/tt16.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt16.bmp
--------------------------------------------------------------------------------
/Train/tt17.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt17.bmp
--------------------------------------------------------------------------------
/Train/tt18.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt18.bmp
--------------------------------------------------------------------------------
/Train/tt19.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt19.bmp
--------------------------------------------------------------------------------
/Train/tt2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt2.bmp
--------------------------------------------------------------------------------
/Train/tt20.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt20.bmp
--------------------------------------------------------------------------------
/Train/tt21.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt21.bmp
--------------------------------------------------------------------------------
/Train/tt22.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt22.bmp
--------------------------------------------------------------------------------
/Train/tt23.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt23.bmp
--------------------------------------------------------------------------------
/Train/tt24.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt24.bmp
--------------------------------------------------------------------------------
/Train/tt25.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt25.bmp
--------------------------------------------------------------------------------
/Train/tt26.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt26.bmp
--------------------------------------------------------------------------------
/Train/tt27.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt27.bmp
--------------------------------------------------------------------------------
/Train/tt3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt3.bmp
--------------------------------------------------------------------------------
/Train/tt4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt4.bmp
--------------------------------------------------------------------------------
/Train/tt5.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt5.bmp
--------------------------------------------------------------------------------
/Train/tt6.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt6.bmp
--------------------------------------------------------------------------------
/Train/tt7.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt7.bmp
--------------------------------------------------------------------------------
/Train/tt8.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt8.bmp
--------------------------------------------------------------------------------
/Train/tt9.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt9.bmp
--------------------------------------------------------------------------------
/expand_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import numpy as np
5 | from PIL import Image
6 | import pdb
7 |
# Artificially expands the dataset by a factor of 19 by scaling and then rotating every image
def main():
    """Expand every image in the folder given as argv[1] by scaling and rotating it."""
    if len(sys.argv) == 2:
        data = prepare_data(sys.argv[1])
    else:
        print("Missing argument: You must specify a folder with images to expand")
        return

    # Iterate the paths directly (the original used Python 2-only xrange()).
    for path in data:
        scale(path)
        rotate(path)
19 |
def prepare_data(dataset):
    """
    Collect the image paths to expand.

    Args:
        dataset: folder (relative to the current working directory) holding images

    Returns:
        List of absolute paths to every .bmp file in the folder.
    """
    # The original also called os.listdir() here but never used the result.
    data_dir = os.path.join(os.getcwd(), dataset)
    data = glob.glob(os.path.join(data_dir, "*.bmp"))

    return data
26 |
def scale(file):
    """Save four downscaled copies (90%, 80%, 70%, 60%) of the image beside it."""
    # Use a context manager so the source file handle is closed instead of leaked.
    with Image.open(file) as image:
        width, height = image.size

        # NOTE(review): Image.ANTIALIAS was removed in Pillow 10 (Image.LANCZOS
        # is the replacement); kept here to match the old PIL this project targets.
        # The loop variable was renamed from 'scale', which shadowed this function.
        for factor in [0.9, 0.8, 0.7, 0.6]:
            new_width, new_height = int(width * factor), int(height * factor)
            new_image = image.resize((new_width, new_height), Image.ANTIALIAS)
            new_path = '{}-{}.bmp'.format(file[:-4], factor)
            new_image.save(new_path)
37 |
def rotate(file):
    """Save three rotated copies (90, 180, 270 degrees) of the image beside it."""
    # Use a context manager so the source file handle is closed instead of leaked.
    with Image.open(file) as image:
        for rotation in [90, 180, 270]:
            # expand=True grows the canvas so non-180 rotations are not cropped.
            new_image = image.rotate(rotation, expand=True)
            new_path = '{}-{}.bmp'.format(file[:-4], rotation)
            new_image.save(new_path)
46 |
# Script entry point: expand the image folder passed on the command line.
if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from model import FSRCNN
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 |
6 | import pprint
7 | import os
8 |
# Command-line configuration (TF 1.x tf.app.flags API). The bracketed value at
# the end of each help string documents the default.
flags = tf.app.flags
flags.DEFINE_boolean("fast", False, "Use the fast model (FSRCNN-s) [False]")
flags.DEFINE_integer("epoch", 10, "Number of epochs [10]")
flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
flags.DEFINE_float("learning_rate", 1e-3, "The learning rate of gradient descent algorithm [1e-3]")
flags.DEFINE_float("momentum", 0.9, "The momentum value for the momentum SGD [0.9]")
flags.DEFINE_integer("c_dim", 1, "Dimension of image color [1]")
flags.DEFINE_integer("scale", 3, "The size of scale factor for preprocessing input image [3]")
flags.DEFINE_integer("stride", 4, "The size of stride to apply to input image [4]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
flags.DEFINE_string("data_dir", "FastTrain", "Name of data directory to train on [FastTrain]")
flags.DEFINE_boolean("train", True, "True for training, false for testing [True]")
flags.DEFINE_integer("threads", 1, "Number of processes to pre-process data with [1]")
flags.DEFINE_boolean("params", False, "Save weight and bias parameters [False]")

# Parsed flag values, shared with model.py/utils.py via tf.app.flags.FLAGS.
FLAGS = flags.FLAGS

# Pretty-printer used to dump the flag values at startup.
pp = pprint.PrettyPrinter()
28 |
def main(_):
  """Print the run configuration, prepare output directories, and run FSRCNN."""
  pp.pprint(flags.FLAGS.__flags)

  # The fast model keeps its checkpoints in a separate 'fast_*' directory.
  if FLAGS.fast:
    FLAGS.checkpoint_dir = 'fast_{}'.format(FLAGS.checkpoint_dir)

  # Make sure both output locations exist before the model touches them.
  for directory in (FLAGS.checkpoint_dir, FLAGS.output_dir):
    if not os.path.exists(directory):
      os.makedirs(directory)

  with tf.Session() as sess:
    FSRCNN(sess, config=FLAGS).run()
43 |
# tf.app.run() parses the flags and then invokes main().
if __name__ == '__main__':
  tf.app.run()
46 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from utils import (
2 | read_data,
3 | thread_train_setup,
4 | train_input_setup,
5 | test_input_setup,
6 | save_params,
7 | merge,
8 | array_image_save
9 | )
10 |
11 | import time
12 | import os
13 |
14 | import numpy as np
15 | import tensorflow as tf
16 |
17 | from PIL import Image
18 | import pdb
19 |
20 | # Based on http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html
class FSRCNN(object):
  """FSRCNN / FSRCNN-s super-resolution network (TensorFlow 1.x graph mode).

  Pipeline: feature extraction -> shrinking -> m mapping layers -> expanding
  -> deconvolution, following http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html.
  Also provides the training loop, single-image testing, and checkpoint I/O.
  """

  def __init__(self, sess, config):
    """Store the run configuration and build the TF graph.

    Args:
      sess: tf.Session used for all graph execution.
      config: parsed command-line flags (see main.py for the full list).
    """
    self.sess = sess
    self.fast = config.fast
    self.train = config.train
    self.c_dim = config.c_dim
    self.is_grayscale = (self.c_dim == 1)
    self.epoch = config.epoch
    self.scale = config.scale
    self.stride = config.stride
    self.batch_size = config.batch_size
    self.learning_rate = config.learning_rate
    self.momentum = config.momentum
    self.threads = config.threads
    self.params = config.params

    # Different image/label sub-sizes for different scaling factors x2, x3, x4
    scale_factors = [[14, 20], [11, 21], [10, 24]]
    self.image_size, self.label_size = scale_factors[self.scale - 2]
    # Testing uses different strides to ensure sub-images line up correctly
    if not self.train:
      self.stride = [10, 7, 6][self.scale - 2]

    # Different model layer counts and filter sizes for FSRCNN vs FSRCNN-s (fast), (s, d, m) in paper.
    # self.fast is a bool, so it indexes 0 (full model) or 1 (fast model).
    model_params = [[56, 12, 4], [32, 5, 1]]
    self.model_params = model_params[self.fast]

    self.checkpoint_dir = config.checkpoint_dir
    self.output_dir = config.output_dir
    self.data_dir = config.data_dir
    self.build_model()


  def build_model(self):
    """Create placeholders, all layer variables, the prediction op, and the loss."""
    self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
    self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')
    # Batch size differs in training vs testing
    self.batch = tf.placeholder(tf.int32, shape=[], name='batch')

    # FSRCNN-s (fast) has smaller filters and fewer layers but runs faster
    s, d, m = self.model_params

    # The expanding and deconvolution layers come after the m mapping layers,
    # hence the w{m+3} / w{m+4} names.
    expand_weight, deconv_weight = 'w{}'.format(m + 3), 'w{}'.format(m + 4)
    self.weights = {
      'w1': tf.Variable(tf.random_normal([5, 5, 1, s], stddev=0.0378, dtype=tf.float32), name='w1'),
      'w2': tf.Variable(tf.random_normal([1, 1, s, d], stddev=0.3536, dtype=tf.float32), name='w2'),
      expand_weight: tf.Variable(tf.random_normal([1, 1, d, s], stddev=0.189, dtype=tf.float32), name=expand_weight),
      deconv_weight: tf.Variable(tf.random_normal([9, 9, 1, s], stddev=0.0001, dtype=tf.float32), name=deconv_weight)
    }

    expand_bias, deconv_bias = 'b{}'.format(m + 3), 'b{}'.format(m + 4)
    self.biases = {
      'b1': tf.Variable(tf.zeros([s]), name='b1'),
      'b2': tf.Variable(tf.zeros([d]), name='b2'),
      expand_bias: tf.Variable(tf.zeros([s]), name=expand_bias),
      deconv_bias: tf.Variable(tf.zeros([1]), name=deconv_bias)
    }

    # Create the m mapping layers weights/biases
    for i in range(3, m + 3):
      weight_name, bias_name = 'w{}'.format(i), 'b{}'.format(i)
      self.weights[weight_name] = tf.Variable(tf.random_normal([3, 3, d, d], stddev=0.1179, dtype=tf.float32), name=weight_name)
      self.biases[bias_name] = tf.Variable(tf.zeros([d]), name=bias_name)

    self.pred = self.model()

    # Loss function (MSE)
    # NOTE(review): reduce_sum over reduction_indices=0 sums squared errors
    # across the batch before the mean, so this is the per-pixel MSE scaled by
    # the batch size rather than a plain MSE -- confirm this is intentional.
    self.loss = tf.reduce_mean(tf.reduce_sum(tf.square(self.labels - self.pred), reduction_indices=0))

    self.saver = tf.train.Saver()

  def run(self):
    """Initialize variables, restore any checkpoint, then dump params, train, or test."""
    # SGD with momentum
    self.train_op = tf.train.MomentumOptimizer(self.learning_rate, self.momentum).minimize(self.loss)

    # NOTE(review): tf.initialize_all_variables() is deprecated since TF 0.12
    # in favor of tf.global_variables_initializer().
    tf.initialize_all_variables().run()

    if self.load(self.checkpoint_dir):
      print(" [*] Load SUCCESS")
    else:
      print(" [!] Load failed...")

    if self.params:
      save_params(self.sess, self.weights, self.biases)
    elif self.train:
      self.run_train()
    else:
      self.run_test()

  def run_train(self):
    """Pre-process the training data, then run the SGD loop over all epochs."""
    start_time = time.time()
    print("Beginning training setup...")
    if self.threads == 1:
      train_input_setup(self)
    else:
      thread_train_setup(self)
    print("Training setup took {} seconds with {} threads".format(time.time() - start_time, self.threads))

    data_dir = os.path.join('./{}'.format(self.checkpoint_dir), "train.h5")
    train_data, train_label = read_data(data_dir)
    print("Total setup time took {} seconds with {} threads".format(time.time() - start_time, self.threads))

    print("Training...")
    start_time = time.time()
    start_average, end_average, counter = 0, 0, 0

    # NOTE(review): xrange is Python 2 only; this file predates Python 3.
    for ep in xrange(self.epoch):
      # Run by batch images
      batch_idxs = len(train_data) // self.batch_size
      batch_average = 0
      for idx in xrange(0, batch_idxs):
        batch_images = train_data[idx * self.batch_size : (idx + 1) * self.batch_size]
        batch_labels = train_label[idx * self.batch_size : (idx + 1) * self.batch_size]

        counter += 1
        _, err = self.sess.run([self.train_op, self.loss], feed_dict={self.images: batch_images, self.labels: batch_labels, self.batch: self.batch_size})
        batch_average += err

        if counter % 10 == 0:
          print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
            % ((ep+1), counter, time.time() - start_time, err))

        # Save every 500 steps
        if counter % 500 == 0:
          self.save(self.checkpoint_dir, counter)

      batch_average = float(batch_average) / batch_idxs
      if ep < (self.epoch * 0.2):
        start_average += batch_average
      elif ep >= (self.epoch * 0.8):
        end_average += batch_average

    # Compare loss of the first 20% and the last 20% epochs
    start_average = float(start_average) / (self.epoch * 0.2)
    end_average = float(end_average) / (self.epoch * 0.2)
    print("Start Average: [%.6f], End Average: [%.6f], Improved: [%.2f%%]" \
      % (start_average, end_average, 100 - (100*end_average/start_average)))

    # Linux desktop notification when training has been completed
    # title = "Training complete - FSRCNN"
    # notification = "{}-{}-{} done training after {} epochs".format(self.image_size, self.label_size, self.stride, self.epoch);
    # notify_command = 'notify-send "{}" "{}"'.format(title, notification)
    # os.system(notify_command)


  def run_test(self):
    """Super-resolve one test image and write the merged result to output_dir."""
    nx, ny = test_input_setup(self)
    data_dir = os.path.join('./{}'.format(self.checkpoint_dir), "test.h5")
    test_data, test_label = read_data(data_dir)

    print("Testing...")

    start_time = time.time()
    # All nx*ny sub-images of the test picture are fed as a single batch.
    result = self.pred.eval({self.images: test_data, self.labels: test_label, self.batch: nx * ny})
    print("Took %.3f seconds" % (time.time() - start_time))

    result = merge(result, [nx, ny])
    result = result.squeeze()
    image_path = os.path.join(os.getcwd(), self.output_dir)
    image_path = os.path.join(image_path, "test_image.png")

    # Pixel values were normalized to [0, 1] in preprocessing; scale back to 8-bit.
    array_image_save(result * 255, image_path)

  def model(self):
    """Assemble the network and return the deconvolution output tensor."""
    # Feature Extraction
    conv_feature = self.prelu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'], 1)

    # Shrinking
    conv_shrink = self.prelu(tf.nn.conv2d(conv_feature, self.weights['w2'], strides=[1,1,1,1], padding='SAME') + self.biases['b2'], 2)

    # Mapping (# mapping layers = m)
    prev_layer, m = conv_shrink, self.model_params[2]
    for i in range(3, m + 3):
      weights, biases = self.weights['w{}'.format(i)], self.biases['b{}'.format(i)]
      prev_layer = self.prelu(tf.nn.conv2d(prev_layer, weights, strides=[1,1,1,1], padding='SAME') + biases, i)

    # Expanding
    # NOTE(review): the PReLU index below is hard-coded to 7, which equals
    # m + 3 only for the non-fast model (m = 4). For the fast model (m = 1)
    # the 'alpha7' variable is still unique, but its name does not match the
    # w{m+3} weight naming scheme -- confirm whether this was intended.
    expand_weights, expand_biases = self.weights['w{}'.format(m + 3)], self.biases['b{}'.format(m + 3)]
    conv_expand = self.prelu(tf.nn.conv2d(prev_layer, expand_weights, strides=[1,1,1,1], padding='SAME') + expand_biases, 7)

    # Deconvolution: upsamples by |scale| back to the label resolution.
    deconv_output = [self.batch, self.label_size, self.label_size, self.c_dim]
    deconv_stride = [1, self.scale, self.scale, 1]
    deconv_weights, deconv_biases = self.weights['w{}'.format(m + 4)], self.biases['b{}'.format(m + 4)]
    conv_deconv = tf.nn.conv2d_transpose(conv_expand, deconv_weights, output_shape=deconv_output, strides=deconv_stride, padding='SAME') + deconv_biases

    return conv_deconv

  def prelu(self, _x, i):
    """
    PReLU tensorflow implementation.

    Computes max(0, x) + alpha * min(0, x) with a learnable per-channel alpha
    (initialized to 0, i.e. a plain ReLU at the start of training). |i| makes
    the variable name 'alpha{i}' unique per layer.
    """
    alphas = tf.get_variable('alpha{}'.format(i), _x.get_shape()[-1], initializer=tf.constant_initializer(0.0), dtype=tf.float32)
    pos = tf.nn.relu(_x)
    # (x - |x|) / 2 equals x for negative inputs and 0 otherwise.
    neg = alphas * (_x - abs(_x)) * 0.5

    return pos + neg

  def save(self, checkpoint_dir, step):
    """Write a checkpoint of all variables at training step |step|."""
    model_name = "FSRCNN.model"
    model_dir = "%s_%s" % ("fsrcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

    if not os.path.exists(checkpoint_dir):
      os.makedirs(checkpoint_dir)

    self.saver.save(self.sess,
                    os.path.join(checkpoint_dir, model_name),
                    global_step=step)

  def load(self, checkpoint_dir):
    """Restore the latest checkpoint if one exists; return True on success."""
    print(" [*] Reading checkpoints...")
    model_dir = "%s_%s" % ("fsrcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
      self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
      return True
    else:
      return False
--------------------------------------------------------------------------------
/result/bicubic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/result/bicubic.png
--------------------------------------------------------------------------------
/result/fsrcnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/result/fsrcnn.png
--------------------------------------------------------------------------------
/result/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/result/original.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Scipy version > 0.18 is needed, due to 'mode' option from scipy.misc.imread function
3 | """
4 |
5 | import os
6 | import glob
7 | import h5py
8 | import random
9 | from math import floor
10 | import struct
11 |
12 | import tensorflow as tf
13 | from PIL import Image
14 | from scipy.misc import imread
15 | import numpy as np
16 | from multiprocessing import Pool, Lock, active_children
17 |
18 | import pdb
19 |
20 | FLAGS = tf.app.flags.FLAGS
21 |
def read_data(path):
    """
    Load a dataset from an h5 file.

    Args:
        path: location of the '.h5' file on disk

    Returns:
        Tuple (data, label) of numpy arrays built from the file's
        'data' and 'label' datasets.
    """
    with h5py.File(path, 'r') as hf:
        return np.array(hf.get('data')), np.array(hf.get('label'))
37 |
def preprocess(path, scale=3):
    """
    Preprocess a single image file.
      (1) Read the image as grayscale ('L') and normalize to [0, 1]
      (2) Crop so both dimensions divide evenly by |scale| (modcrop)
      (3) Downsample the cropped image by |scale| (using anti-aliasing)

    Args:
        path: path of the image file
        scale: downsampling factor

    Returns:
        (input_, label_): the downsampled input and the cropped
        full-resolution ground truth, both 2-D float arrays.
    """
    image = Image.open(path).convert('L')
    (width, height) = image.size
    # np.float was removed in NumPy 1.24; the builtin float is the same dtype here.
    label_ = np.array(list(image.getdata())).astype(float).reshape((height, width)) / 255
    image.close()

    cropped_image = Image.fromarray(modcrop(label_, scale))

    (width, height) = cropped_image.size
    new_width, new_height = int(width / scale), int(height / scale)
    scaled_image = cropped_image.resize((new_width, new_height), Image.ANTIALIAS)
    cropped_image.close()

    (width, height) = scaled_image.size
    input_ = np.array(list(scaled_image.getdata())).astype(float).reshape((height, width))
    # Close the intermediate image as well (the original leaked this handle).
    scaled_image.close()

    return input_, label_
62 |
def prepare_data(sess, dataset):
    """
    Collect the .bmp image paths to process.

    Args:
        sess: TensorFlow session (unused; kept for interface compatibility)
        dataset: choose train dataset or test dataset; in test mode the
                 'Set5' subfolder of |dataset| is used

    For train dataset, output data would be ['.../t1.bmp', '.../t2.bmp', ..., '.../t99.bmp']
    """
    # The original also called os.listdir() here but never used the result.
    if FLAGS.train:
        data_dir = os.path.join(os.getcwd(), dataset)
    else:
        data_dir = os.path.join(os.sep, (os.path.join(os.getcwd(), dataset)), "Set5")
    data = sorted(glob.glob(os.path.join(data_dir, "*.bmp")))

    return data
78 |
def make_data(sess, checkpoint_dir, data, label):
    """
    Save input/label arrays as an h5 file under |checkpoint_dir|.
    The file is named train.h5 or test.h5 depending on FLAGS.train.
    """
    filename = 'train.h5' if FLAGS.train else 'test.h5'
    savepath = os.path.join(os.getcwd(), '{}/{}'.format(checkpoint_dir, filename))

    with h5py.File(savepath, 'w') as hf:
        hf.create_dataset('data', data=data)
        hf.create_dataset('label', data=label)
92 |
def image_read(path, is_grayscale=True):
    """
    Read image using its path.
    Default value is gray-scale, and image is read by YCbCr format as the paper said.

    NOTE(review): scipy.misc.imread was deprecated in SciPy 1.0 and removed in
    SciPy 1.2, and np.float was removed in NumPy 1.24 -- this helper only works
    against the old scipy/numpy versions this project was written for.
    """
    if is_grayscale:
        return imread(path, flatten=True, mode='YCbCr').astype(np.float)
    else:
        return imread(path, mode='YCbCr').astype(np.float)
102 |
def modcrop(image, scale=3):
    """
    Crop |image| so its height and width are exact multiples of |scale|.

    Scaling down and back up only reproduces the original size when both
    dimensions divide evenly by the scale factor, so any remainder pixels
    are trimmed from the bottom/right edges. Handles both grayscale (H, W)
    and color (H, W, C) arrays.
    """
    if len(image.shape) == 3:
        h, w, _ = image.shape
        return image[:h - h % scale, :w - w % scale, :]
    h, w = image.shape
    return image[:h - h % scale, :w - w % scale]
122 |
def train_input_worker(args):
    """
    Pre-process one image into aligned (input, label) sub-image pairs.

    Args:
        args: tuple (image_path, config) where config is
              [image_size, label_size, stride, scale, save_image]

    Returns:
        [input_patches, label_patches]: lists of arrays shaped
        (image_size, image_size, 1) and (label_size, label_size, 1).
    """
    image_data, config = args
    image_size, label_size, stride, scale, save_image = config

    single_input_sequence, single_label_sequence = [], []
    # Floor division keeps these ints under Python 3 (plain '/' yields a
    # float, which breaks range() and the slicing below).
    padding = abs(image_size - label_size) // 2  # eg. for 3x: (21 - 11) / 2 = 5
    label_padding = label_size // scale  # eg. for 3x: 21 / 3 = 7

    input_, label_ = preprocess(image_data, scale)

    if len(input_.shape) == 3:
        h, w, _ = input_.shape
    else:
        h, w = input_.shape

    for x in range(0, h - image_size - padding + 1, stride):
        for y in range(0, w - image_size - padding + 1, stride):
            sub_input = input_[x + padding : x + padding + image_size, y + padding : y + padding + image_size]
            # Label patch starts at the scaled-up position of the input patch.
            x_loc, y_loc = x + label_padding, y + label_padding
            sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]

            sub_input = sub_input.reshape([image_size, image_size, 1])
            sub_label = sub_label.reshape([label_size, label_size, 1])

            single_input_sequence.append(sub_input)
            single_label_sequence.append(sub_label)

    return [single_input_sequence, single_label_sequence]
151 |
152 |
def thread_train_setup(config):
    """
    Spawns |config.threads| worker processes to pre-process the data.

    This has not been extensively tested so use at your own risk.
    Also this is technically multiprocessing not threading, I just say thread
    because it's shorter to type.
    """
    sess = config.sess

    # Load data path
    data = prepare_data(sess, dataset=config.data_dir)

    # Initialize multiprocessing pool with # of processes = config.threads
    pool = Pool(config.threads)

    # Distribute the images across the worker processes. Chunk boundaries use
    # floor division (required on Python 3, where '/' yields a float) computed
    # so the remainder images are still assigned when len(data) is not an
    # exact multiple of the thread count (the original dropped them).
    config_values = [config.image_size, config.label_size, config.stride, config.scale, config.save_image]
    workers = []
    for thread in range(config.threads):
        start = thread * len(data) // config.threads
        end = (thread + 1) * len(data) // config.threads
        args_list = [(data[i], config_values) for i in range(start, end)]
        worker = pool.map_async(train_input_worker, args_list)
        workers.append(worker)
    print("{} worker processes created".format(config.threads))

    pool.close()

    results = []
    for i in range(len(workers)):
        print("Waiting for worker process {}".format(i))
        results.extend(workers[i].get(timeout=240))
        print("Worker process {} done".format(i))

    print("All worker processes done!")

    # Flatten the per-image patch lists into one input and one label sequence.
    sub_input_sequence, sub_label_sequence = [], []
    for image in range(len(results)):
        single_input_sequence, single_label_sequence = results[image]
        sub_input_sequence.extend(single_input_sequence)
        sub_label_sequence.extend(single_label_sequence)

    arrdata = np.asarray(sub_input_sequence)
    arrlabel = np.asarray(sub_label_sequence)

    make_data(sess, config.checkpoint_dir, arrdata, arrlabel)
200 |
201 |
def train_input_setup(config):
    """
    Read image files, make their sub-images, and save them as a h5 file format.

    Single-process variant of thread_train_setup.
    """
    sess = config.sess
    image_size, label_size, stride, scale = config.image_size, config.label_size, config.stride, config.scale

    # Load data path
    data = prepare_data(sess, dataset=config.data_dir)

    sub_input_sequence, sub_label_sequence = [], []
    # Floor division keeps these ints under Python 3 (plain '/' yields a
    # float, which breaks range() and the slicing below).
    padding = abs(image_size - label_size) // 2  # eg. for 3x: (21 - 11) / 2 = 5
    label_padding = label_size // scale  # eg. for 3x: 21 / 3 = 7

    for i in range(len(data)):  # range() instead of the Python 2-only xrange()
        input_, label_ = preprocess(data[i], scale)

        if len(input_.shape) == 3:
            h, w, _ = input_.shape
        else:
            h, w = input_.shape

        for x in range(0, h - image_size - padding + 1, stride):
            for y in range(0, w - image_size - padding + 1, stride):
                sub_input = input_[x + padding : x + padding + image_size, y + padding : y + padding + image_size]
                x_loc, y_loc = x + label_padding, y + label_padding
                sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]

                sub_input = sub_input.reshape([image_size, image_size, 1])
                sub_label = sub_label.reshape([label_size, label_size, 1])

                sub_input_sequence.append(sub_input)
                sub_label_sequence.append(sub_label)

    arrdata = np.asarray(sub_input_sequence)
    arrlabel = np.asarray(sub_label_sequence)

    make_data(sess, config.checkpoint_dir, arrdata, arrlabel)
240 |
241 |
def test_input_setup(config):
    """
    Read one test image, cut it into sub-images, and save them as an h5 file.

    Args:
        config: object providing image_size, label_size, stride, scale,
                sess and checkpoint_dir.

    Returns:
        (nx, ny): number of sub-images along the height and width axes,
        needed later to merge the network output back into one image.
    """
    sess = config.sess
    image_size, label_size, stride, scale = config.image_size, config.label_size, config.stride, config.scale

    # Load data path
    data = prepare_data(sess, dataset="Test")

    sub_input_sequence, sub_label_sequence = [], []
    # Floor division keeps these ints under Python 3 (plain '/' yields a
    # float, which breaks range() and the slicing below).
    padding = abs(image_size - label_size) // 2  # eg. (21 - 11) / 2 = 5
    label_padding = label_size // scale  # eg. 21 / 3 = 7

    pic_index = 2  # Index of image based on lexicographic order in data folder
    input_, label_ = preprocess(data[pic_index], config.scale)

    if len(input_.shape) == 3:
        h, w, _ = input_.shape
    else:
        h, w = input_.shape

    nx, ny = 0, 0
    for x in range(0, h - image_size - padding + 1, stride):
        nx += 1
        ny = 0
        for y in range(0, w - image_size - padding + 1, stride):
            ny += 1
            sub_input = input_[x + padding : x + padding + image_size, y + padding : y + padding + image_size]
            x_loc, y_loc = x + label_padding, y + label_padding
            sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]

            sub_input = sub_input.reshape([image_size, image_size, 1])
            sub_label = sub_label.reshape([label_size, label_size, 1])

            sub_input_sequence.append(sub_input)
            sub_label_sequence.append(sub_label)

    arrdata = np.asarray(sub_input_sequence)
    arrlabel = np.asarray(sub_label_sequence)

    make_data(sess, config.checkpoint_dir, arrdata, arrlabel)

    return nx, ny
286 |
287 | # You can ignore, I just wanted to see how much space all the parameters would take up
def save_params(sess, weights, biases):
    """
    Dump all weight and bias parameters under params/: weights as packed
    binary floats (params/weights), biases as text (params/biases.txt),
    to see how much space the parameters take up.
    """
    param_dir = "params/"

    if not os.path.exists(param_dir):
        os.makedirs(param_dir)

    # with-blocks guarantee the files are closed even if sess.run raises.
    with open(param_dir + "weights", 'wb') as weight_file:
        for layer in weights:
            layer_weights = sess.run(weights[layer])

            for filter_x in range(len(layer_weights)):
                for filter_y in range(len(layer_weights[filter_x])):
                    filter_weights = layer_weights[filter_x][filter_y]
                    for input_channel in range(len(filter_weights)):
                        for output_channel in range(len(filter_weights[input_channel])):
                            weight_value = filter_weights[input_channel][output_channel]
                            # Write bytes directly to save space
                            weight_file.write(struct.pack("f", weight_value))
                            weight_file.write(struct.pack("x"))

            # Binary mode needs a bytes separator (the original wrote a str,
            # which raises TypeError on Python 3).
            weight_file.write(b"\n\n")

    with open(param_dir + "biases.txt", 'w') as bias_file:
        for layer in biases:
            bias_file.write("Layer {}\n".format(layer))
            layer_biases = sess.run(biases[layer])
            for bias in layer_biases:
                # Can write as characters due to low bias parameter count
                bias_file.write("{}, ".format(bias))
            bias_file.write("\n\n")
321 |
def merge(images, size):
    """
    Stitch a batch of sub-images back into one image.

    Args:
        images: array of shape (N, h, w, 1) holding the sub-images in
                row-major order.
        size: [rows, cols] grid layout, with N == rows * cols.

    Returns:
        A single (rows*h, cols*w, 1) image array.
    """
    sub_h, sub_w = images.shape[1], images.shape[2]
    rows, cols = size
    merged = np.zeros((sub_h * rows, sub_w * cols, 1))
    for idx, patch in enumerate(images):
        col, row = idx % cols, idx // cols
        merged[row * sub_h:(row + 1) * sub_h, col * sub_w:(col + 1) * sub_w, :] = patch

    return merged
334 |
def array_image_save(array, image_path):
    """
    Convert a numpy array to an image and write it to |image_path|.
    Non-RGB data (e.g. grayscale floats) is converted to RGB first.
    """
    img = Image.fromarray(array)
    rgb = img if img.mode == 'RGB' else img.convert('RGB')
    rgb.save(image_path)
    print("Saved image: {}".format(image_path))
344 |
--------------------------------------------------------------------------------