├── .gitignore ├── LICENSE ├── README.md ├── Test ├── Set14 │ ├── baboon.bmp │ ├── barbara.bmp │ ├── bridge.bmp │ ├── coastguard.bmp │ ├── comic.bmp │ ├── face.bmp │ ├── flowers.bmp │ ├── foreman.bmp │ ├── lenna.bmp │ ├── man.bmp │ ├── monarch.bmp │ ├── pepper.bmp │ ├── ppt3.bmp │ └── zebra.bmp └── Set5 │ ├── baby_GT.bmp │ ├── bird_GT.bmp │ ├── butterfly_GT.bmp │ ├── head_GT.bmp │ └── woman_GT.bmp ├── Train ├── t1.bmp ├── t10.bmp ├── t11.bmp ├── t12.bmp ├── t13.bmp ├── t14.bmp ├── t15.bmp ├── t16.bmp ├── t17.bmp ├── t18.bmp ├── t19.bmp ├── t2.bmp ├── t20.bmp ├── t21.bmp ├── t22.bmp ├── t23.bmp ├── t24.bmp ├── t25.bmp ├── t26.bmp ├── t27.bmp ├── t28.bmp ├── t29.bmp ├── t3.bmp ├── t30.bmp ├── t31.bmp ├── t32.bmp ├── t33.bmp ├── t34.bmp ├── t35.bmp ├── t36.bmp ├── t37.bmp ├── t38.bmp ├── t39.bmp ├── t4.bmp ├── t40.bmp ├── t42.bmp ├── t43.bmp ├── t44.bmp ├── t45.bmp ├── t46.bmp ├── t47.bmp ├── t48.bmp ├── t49.bmp ├── t5.bmp ├── t50.bmp ├── t51.bmp ├── t52.bmp ├── t53.bmp ├── t54.bmp ├── t55.bmp ├── t56.bmp ├── t57.bmp ├── t58.bmp ├── t59.bmp ├── t6.bmp ├── t60.bmp ├── t61.bmp ├── t62.bmp ├── t63.bmp ├── t64.bmp ├── t65.bmp ├── t66.bmp ├── t7.bmp ├── t8.bmp ├── t9.bmp ├── tt1.bmp ├── tt10.bmp ├── tt12.bmp ├── tt13.bmp ├── tt14.bmp ├── tt15.bmp ├── tt16.bmp ├── tt17.bmp ├── tt18.bmp ├── tt19.bmp ├── tt2.bmp ├── tt20.bmp ├── tt21.bmp ├── tt22.bmp ├── tt23.bmp ├── tt24.bmp ├── tt25.bmp ├── tt26.bmp ├── tt27.bmp ├── tt3.bmp ├── tt4.bmp ├── tt5.bmp ├── tt6.bmp ├── tt7.bmp ├── tt8.bmp └── tt9.bmp ├── expand_data.py ├── main.py ├── model.py ├── result ├── bicubic.png ├── fsrcnn.png └── original.png └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | 
lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Drake Levy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial 
portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FSRCNN-TensorFlow 2 | TensorFlow implementation of the Fast Super-Resolution Convolutional Neural Network (FSRCNN). This implements two models: FSRCNN which is more accurate but slower and FSRCNN-s which is faster but less accurate. Based on this [project](http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html). 3 | 4 | ## Prerequisites 5 | * Python 2.7 6 | * TensorFlow 7 | * Scipy version > 0.18 8 | * h5py 9 | * PIL 10 | 11 | ## Usage 12 | For training: `python main.py` 13 |
14 | For testing: `python main.py --train False` 15 | 16 | To use FSRCNN-s instead of FSRCNN: `python main.py --fast True` 17 | 18 | Can specify epochs, learning rate, data directory, etc: 19 |
20 | `python main.py --epochs 10 --learning_rate 0.0001 --data_dir Train` 21 |
22 | Check `main.py` for all the possible flags 23 | 24 | Also includes script `expand_data.py` which scales and rotates all the images in the specified training set to expand it 25 | 26 | ## Result 27 | 28 | Original butterfly image: 29 | 30 | ![orig](https://github.com/drakelevy/FSRCNN-Tensorflow/blob/master/result/original.png?raw=true) 31 | 32 | 33 | Bicubic interpolated image: 34 | 35 | ![bicubic](https://github.com/drakelevy/FSRCNN-Tensorflow/blob/master/result/bicubic.png?raw=true) 36 | 37 | 38 | Super-resolved image: 39 | 40 | ![srcnn](https://github.com/drakelevy/FSRCNN-Tensorflow/blob/master/result/fsrcnn.png?raw=true) 41 | 42 | ## TODO 43 | 44 | * Add RGB support (Increase each layer depth to 3) 45 | * Speed up pre-processing for large datasets 46 | * Set learning rate for deconvolutional layer to 1e-4 (vs 1e-3 for the rest) 47 | 48 | ## References 49 | 50 | * [tegg89/SRCNN-Tensorflow](https://github.com/tegg89/SRCNN-Tensorflow) 51 | 52 | * [liliumao/Tensorflow-srcnn](https://github.com/liliumao/Tensorflow-srcnn) 53 | 54 | * [carpedm20/DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow) 55 | -------------------------------------------------------------------------------- /Test/Set14/baboon.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/baboon.bmp -------------------------------------------------------------------------------- /Test/Set14/barbara.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/barbara.bmp -------------------------------------------------------------------------------- /Test/Set14/bridge.bmp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/bridge.bmp -------------------------------------------------------------------------------- /Test/Set14/coastguard.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/coastguard.bmp -------------------------------------------------------------------------------- /Test/Set14/comic.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/comic.bmp -------------------------------------------------------------------------------- /Test/Set14/face.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/face.bmp -------------------------------------------------------------------------------- /Test/Set14/flowers.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/flowers.bmp -------------------------------------------------------------------------------- /Test/Set14/foreman.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/foreman.bmp -------------------------------------------------------------------------------- /Test/Set14/lenna.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/lenna.bmp 
-------------------------------------------------------------------------------- /Test/Set14/man.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/man.bmp -------------------------------------------------------------------------------- /Test/Set14/monarch.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/monarch.bmp -------------------------------------------------------------------------------- /Test/Set14/pepper.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/pepper.bmp -------------------------------------------------------------------------------- /Test/Set14/ppt3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/ppt3.bmp -------------------------------------------------------------------------------- /Test/Set14/zebra.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set14/zebra.bmp -------------------------------------------------------------------------------- /Test/Set5/baby_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set5/baby_GT.bmp -------------------------------------------------------------------------------- /Test/Set5/bird_GT.bmp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set5/bird_GT.bmp -------------------------------------------------------------------------------- /Test/Set5/butterfly_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set5/butterfly_GT.bmp -------------------------------------------------------------------------------- /Test/Set5/head_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set5/head_GT.bmp -------------------------------------------------------------------------------- /Test/Set5/woman_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Test/Set5/woman_GT.bmp -------------------------------------------------------------------------------- /Train/t1.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t1.bmp -------------------------------------------------------------------------------- /Train/t10.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t10.bmp -------------------------------------------------------------------------------- /Train/t11.bmp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t11.bmp -------------------------------------------------------------------------------- /Train/t12.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t12.bmp -------------------------------------------------------------------------------- /Train/t13.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t13.bmp -------------------------------------------------------------------------------- /Train/t14.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t14.bmp -------------------------------------------------------------------------------- /Train/t15.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t15.bmp -------------------------------------------------------------------------------- /Train/t16.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t16.bmp -------------------------------------------------------------------------------- /Train/t17.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t17.bmp -------------------------------------------------------------------------------- /Train/t18.bmp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t18.bmp -------------------------------------------------------------------------------- /Train/t19.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t19.bmp -------------------------------------------------------------------------------- /Train/t2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t2.bmp -------------------------------------------------------------------------------- /Train/t20.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t20.bmp -------------------------------------------------------------------------------- /Train/t21.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t21.bmp -------------------------------------------------------------------------------- /Train/t22.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t22.bmp -------------------------------------------------------------------------------- /Train/t23.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t23.bmp 
-------------------------------------------------------------------------------- /Train/t24.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t24.bmp -------------------------------------------------------------------------------- /Train/t25.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t25.bmp -------------------------------------------------------------------------------- /Train/t26.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t26.bmp -------------------------------------------------------------------------------- /Train/t27.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t27.bmp -------------------------------------------------------------------------------- /Train/t28.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t28.bmp -------------------------------------------------------------------------------- /Train/t29.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t29.bmp -------------------------------------------------------------------------------- /Train/t3.bmp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t3.bmp -------------------------------------------------------------------------------- /Train/t30.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t30.bmp -------------------------------------------------------------------------------- /Train/t31.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t31.bmp -------------------------------------------------------------------------------- /Train/t32.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t32.bmp -------------------------------------------------------------------------------- /Train/t33.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t33.bmp -------------------------------------------------------------------------------- /Train/t34.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t34.bmp -------------------------------------------------------------------------------- /Train/t35.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t35.bmp -------------------------------------------------------------------------------- /Train/t36.bmp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t36.bmp -------------------------------------------------------------------------------- /Train/t37.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t37.bmp -------------------------------------------------------------------------------- /Train/t38.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t38.bmp -------------------------------------------------------------------------------- /Train/t39.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t39.bmp -------------------------------------------------------------------------------- /Train/t4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t4.bmp -------------------------------------------------------------------------------- /Train/t40.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t40.bmp -------------------------------------------------------------------------------- /Train/t42.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t42.bmp 
-------------------------------------------------------------------------------- /Train/t43.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t43.bmp -------------------------------------------------------------------------------- /Train/t44.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t44.bmp -------------------------------------------------------------------------------- /Train/t45.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t45.bmp -------------------------------------------------------------------------------- /Train/t46.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t46.bmp -------------------------------------------------------------------------------- /Train/t47.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t47.bmp -------------------------------------------------------------------------------- /Train/t48.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t48.bmp -------------------------------------------------------------------------------- /Train/t49.bmp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t49.bmp -------------------------------------------------------------------------------- /Train/t5.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t5.bmp -------------------------------------------------------------------------------- /Train/t50.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t50.bmp -------------------------------------------------------------------------------- /Train/t51.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t51.bmp -------------------------------------------------------------------------------- /Train/t52.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t52.bmp -------------------------------------------------------------------------------- /Train/t53.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t53.bmp -------------------------------------------------------------------------------- /Train/t54.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t54.bmp -------------------------------------------------------------------------------- /Train/t55.bmp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t55.bmp -------------------------------------------------------------------------------- /Train/t56.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t56.bmp -------------------------------------------------------------------------------- /Train/t57.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t57.bmp -------------------------------------------------------------------------------- /Train/t58.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t58.bmp -------------------------------------------------------------------------------- /Train/t59.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t59.bmp -------------------------------------------------------------------------------- /Train/t6.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t6.bmp -------------------------------------------------------------------------------- /Train/t60.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t60.bmp 
-------------------------------------------------------------------------------- /Train/t61.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t61.bmp -------------------------------------------------------------------------------- /Train/t62.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t62.bmp -------------------------------------------------------------------------------- /Train/t63.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t63.bmp -------------------------------------------------------------------------------- /Train/t64.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t64.bmp -------------------------------------------------------------------------------- /Train/t65.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t65.bmp -------------------------------------------------------------------------------- /Train/t66.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t66.bmp -------------------------------------------------------------------------------- /Train/t7.bmp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t7.bmp -------------------------------------------------------------------------------- /Train/t8.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t8.bmp -------------------------------------------------------------------------------- /Train/t9.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/t9.bmp -------------------------------------------------------------------------------- /Train/tt1.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt1.bmp -------------------------------------------------------------------------------- /Train/tt10.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt10.bmp -------------------------------------------------------------------------------- /Train/tt12.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt12.bmp -------------------------------------------------------------------------------- /Train/tt13.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt13.bmp -------------------------------------------------------------------------------- /Train/tt14.bmp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt14.bmp -------------------------------------------------------------------------------- /Train/tt15.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt15.bmp -------------------------------------------------------------------------------- /Train/tt16.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt16.bmp -------------------------------------------------------------------------------- /Train/tt17.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt17.bmp -------------------------------------------------------------------------------- /Train/tt18.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt18.bmp -------------------------------------------------------------------------------- /Train/tt19.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt19.bmp -------------------------------------------------------------------------------- /Train/tt2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt2.bmp 
-------------------------------------------------------------------------------- /Train/tt20.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt20.bmp -------------------------------------------------------------------------------- /Train/tt21.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt21.bmp -------------------------------------------------------------------------------- /Train/tt22.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt22.bmp -------------------------------------------------------------------------------- /Train/tt23.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt23.bmp -------------------------------------------------------------------------------- /Train/tt24.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt24.bmp -------------------------------------------------------------------------------- /Train/tt25.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt25.bmp -------------------------------------------------------------------------------- /Train/tt26.bmp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt26.bmp -------------------------------------------------------------------------------- /Train/tt27.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt27.bmp -------------------------------------------------------------------------------- /Train/tt3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt3.bmp -------------------------------------------------------------------------------- /Train/tt4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt4.bmp -------------------------------------------------------------------------------- /Train/tt5.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt5.bmp -------------------------------------------------------------------------------- /Train/tt6.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt6.bmp -------------------------------------------------------------------------------- /Train/tt7.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifanw90/FSRCNN-TensorFlow/e66845dbe076b4075d027307087be29e6ac29f74/Train/tt7.bmp -------------------------------------------------------------------------------- /Train/tt8.bmp: 
import os
import sys
import glob


# Artificially expands the dataset by a factor of 19 by scaling and then
# rotating every image in the folder given on the command line.
def main():
    """Expand the image folder named by sys.argv[1] in place."""
    if len(sys.argv) == 2:
        data = prepare_data(sys.argv[1])
    else:
        print("Missing argument: You must specify a folder with images to expand")
        return

    # range() instead of Python-2-only xrange(): identical iteration behaviour
    # on Python 2 and also works on Python 3.
    for path in data:
        scale(path)
        rotate(path)

def prepare_data(dataset):
    """Return the list of .bmp file paths inside *dataset* (resolved against cwd)."""
    data_dir = os.path.join(os.getcwd(), dataset)
    return glob.glob(os.path.join(data_dir, "*.bmp"))

def scale(file):
    """Save downscaled copies of *file* at 90/80/70/60% of its original size."""
    # Local import so prepare_data() can be used without Pillow installed.
    from PIL import Image

    image = Image.open(file)
    width, height = image.size

    # NOTE(review): ANTIALIAS is the legacy Pillow name; newer Pillow spells it
    # Image.Resampling.LANCZOS -- confirm the installed Pillow version.
    for factor in [0.9, 0.8, 0.7, 0.6]:
        new_width, new_height = int(width * factor), int(height * factor)
        new_image = image.resize((new_width, new_height), Image.ANTIALIAS)
        new_image.save('{}-{}.bmp'.format(file[:-4], factor))

def rotate(file):
    """Save copies of *file* rotated by 90, 180 and 270 degrees."""
    from PIL import Image

    image = Image.open(file)

    for rotation in [90, 180, 270]:
        # expand=True grows the canvas so the rotated image is not clipped
        new_image = image.rotate(rotation, expand=True)
        new_image.save('{}-{}.bmp'.format(file[:-4], rotation))
from model import FSRCNN

import numpy as np
import tensorflow as tf

import pprint
import os

# Command-line configuration; the bracketed suffix in each help string is the
# default value.
flags = tf.app.flags
flags.DEFINE_boolean("fast", False, "Use the fast model (FSRCNN-s) [False]")
flags.DEFINE_integer("epoch", 10, "Number of epochs [10]")
flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
flags.DEFINE_float("learning_rate", 1e-3, "The learning rate of gradient descent algorithm [1e-3]")
flags.DEFINE_float("momentum", 0.9, "The momentum value for the momentum SGD [0.9]")
flags.DEFINE_integer("c_dim", 1, "Dimension of image color [1]")
flags.DEFINE_integer("scale", 3, "The size of scale factor for preprocessing input image [3]")
flags.DEFINE_integer("stride", 4, "The size of stride to apply to input image [4]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
flags.DEFINE_string("data_dir", "FastTrain", "Name of data directory to train on [FastTrain]")
flags.DEFINE_boolean("train", True, "True for training, false for testing [True]")
flags.DEFINE_integer("threads", 1, "Number of processes to pre-process data with [1]")
flags.DEFINE_boolean("params", False, "Save weight and bias parameters [False]")

FLAGS = flags.FLAGS

pp = pprint.PrettyPrinter()

def main(_):
    """Entry point: log the configuration, prepare directories, run the model."""
    pp.pprint(flags.FLAGS.__flags)

    # The fast model (FSRCNN-s) keeps its checkpoints under a separate prefix.
    if FLAGS.fast:
        FLAGS.checkpoint_dir = 'fast_{}'.format(FLAGS.checkpoint_dir)

    for directory in (FLAGS.checkpoint_dir, FLAGS.output_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        FSRCNN(sess, config=FLAGS).run()

if __name__ == '__main__':
    tf.app.run()
from utils import (
    read_data,
    thread_train_setup,
    train_input_setup,
    test_input_setup,
    save_params,
    merge,
    array_image_save
)

import time
import os

import numpy as np
import tensorflow as tf

from PIL import Image
import pdb

# TensorFlow implementation of FSRCNN, based on
# http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html
class FSRCNN(object):

    def __init__(self, sess, config):
        """Copy the run configuration onto the instance and build the graph."""
        self.sess = sess
        self.fast = config.fast
        self.train = config.train
        self.c_dim = config.c_dim
        self.is_grayscale = (self.c_dim == 1)
        self.epoch = config.epoch
        self.scale = config.scale
        self.stride = config.stride
        self.batch_size = config.batch_size
        self.learning_rate = config.learning_rate
        self.momentum = config.momentum
        self.threads = config.threads
        self.params = config.params

        # Per-scale (input, label) sub-image sizes for x2 / x3 / x4.
        scale_factors = [[14, 20], [11, 21], [10, 24]]
        self.image_size, self.label_size = scale_factors[self.scale - 2]
        # Testing overrides the stride so sub-images tile the image exactly.
        if not self.train:
            self.stride = [10, 7, 6][self.scale - 2]

        # (s, d, m) from the paper; index 0 = full FSRCNN, index 1 = FSRCNN-s.
        model_params = [[56, 12, 4], [32, 5, 1]]
        self.model_params = model_params[self.fast]

        self.checkpoint_dir = config.checkpoint_dir
        self.output_dir = config.output_dir
        self.data_dir = config.data_dir
        self.build_model()


    def build_model(self):
        """Create the placeholders, every weight/bias variable, and the loss."""
        self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
        self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')
        # Batch size differs between training and testing, so it is fed at run time.
        self.batch = tf.placeholder(tf.int32, shape=[], name='batch')

        s, d, m = self.model_params

        # Layer indices: w1/b1 feature extraction, w2/b2 shrink, w3..w{m+2}
        # mapping, w{m+3} expand, w{m+4} deconvolution.
        expand_weight = 'w{}'.format(m + 3)
        deconv_weight = 'w{}'.format(m + 4)
        self.weights = {
            'w1': tf.Variable(tf.random_normal([5, 5, 1, s], stddev=0.0378, dtype=tf.float32), name='w1'),
            'w2': tf.Variable(tf.random_normal([1, 1, s, d], stddev=0.3536, dtype=tf.float32), name='w2'),
            expand_weight: tf.Variable(tf.random_normal([1, 1, d, s], stddev=0.189, dtype=tf.float32), name=expand_weight),
            deconv_weight: tf.Variable(tf.random_normal([9, 9, 1, s], stddev=0.0001, dtype=tf.float32), name=deconv_weight),
        }

        expand_bias = 'b{}'.format(m + 3)
        deconv_bias = 'b{}'.format(m + 4)
        self.biases = {
            'b1': tf.Variable(tf.zeros([s]), name='b1'),
            'b2': tf.Variable(tf.zeros([d]), name='b2'),
            expand_bias: tf.Variable(tf.zeros([s]), name=expand_bias),
            deconv_bias: tf.Variable(tf.zeros([1]), name=deconv_bias),
        }

        # The m 3x3 mapping layers that sit between shrink and expand.
        for i in range(3, m + 3):
            w_name, b_name = 'w{}'.format(i), 'b{}'.format(i)
            self.weights[w_name] = tf.Variable(tf.random_normal([3, 3, d, d], stddev=0.1179, dtype=tf.float32), name=w_name)
            self.biases[b_name] = tf.Variable(tf.zeros([d]), name=b_name)

        self.pred = self.model()

        # MSE loss: sum over the batch axis, then mean over the spatial axes.
        self.loss = tf.reduce_mean(tf.reduce_sum(tf.square(self.labels - self.pred), reduction_indices=0))

        self.saver = tf.train.Saver()

    def run(self):
        """Initialize variables, restore a checkpoint if any, then dispatch."""
        # Plain SGD with momentum.
        self.train_op = tf.train.MomentumOptimizer(self.learning_rate, self.momentum).minimize(self.loss)

        # NOTE(review): tf.initialize_all_variables() is the TF<1.0 spelling of
        # tf.global_variables_initializer() -- confirm the pinned TF version.
        tf.initialize_all_variables().run()

        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        if self.params:
            save_params(self.sess, self.weights, self.biases)
        elif self.train:
            self.run_train()
        else:
            self.run_test()

    def run_train(self):
        """Pre-process the training set, then run the epoch/batch loop."""
        start_time = time.time()
        print("Beginning training setup...")
        if self.threads == 1:
            train_input_setup(self)
        else:
            thread_train_setup(self)
        print("Training setup took {} seconds with {} threads".format(time.time() - start_time, self.threads))

        data_dir = os.path.join('./{}'.format(self.checkpoint_dir), "train.h5")
        train_data, train_label = read_data(data_dir)
        print("Total setup time took {} seconds with {} threads".format(time.time() - start_time, self.threads))

        print("Training...")
        start_time = time.time()
        start_average, end_average, counter = 0, 0, 0

        for ep in xrange(self.epoch):
            batch_idxs = len(train_data) // self.batch_size
            batch_average = 0
            for idx in xrange(0, batch_idxs):
                lo, hi = idx * self.batch_size, (idx + 1) * self.batch_size
                batch_images = train_data[lo:hi]
                batch_labels = train_label[lo:hi]

                counter += 1
                _, err = self.sess.run([self.train_op, self.loss],
                                       feed_dict={self.images: batch_images, self.labels: batch_labels, self.batch: self.batch_size})
                batch_average += err

                if counter % 10 == 0:
                    print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
                        % ((ep+1), counter, time.time() - start_time, err))

                # Checkpoint every 500 steps.
                if counter % 500 == 0:
                    self.save(self.checkpoint_dir, counter)

            batch_average = float(batch_average) / batch_idxs
            # Accumulate the mean loss of the first 20% and last 20% of epochs.
            if ep < (self.epoch * 0.2):
                start_average += batch_average
            elif ep >= (self.epoch * 0.8):
                end_average += batch_average

        # Report how much the loss improved from the first to the last fifth.
        start_average = float(start_average) / (self.epoch * 0.2)
        end_average = float(end_average) / (self.epoch * 0.2)
        print("Start Average: [%.6f], End Average: [%.6f], Improved: [%.2f%%]" \
            % (start_average, end_average, 100 - (100*end_average/start_average)))

    def run_test(self):
        """Upscale one test image and write the merged result to output_dir."""
        nx, ny = test_input_setup(self)
        data_dir = os.path.join('./{}'.format(self.checkpoint_dir), "test.h5")
        test_data, test_label = read_data(data_dir)

        print("Testing...")

        start_time = time.time()
        result = self.pred.eval({self.images: test_data, self.labels: test_label, self.batch: nx * ny})
        print("Took %.3f seconds" % (time.time() - start_time))

        result = merge(result, [nx, ny]).squeeze()
        image_path = os.path.join(os.getcwd(), self.output_dir, "test_image.png")

        # Predictions are in [0, 1]; rescale to 8-bit range before saving.
        array_image_save(result * 255, image_path)

    def model(self):
        """Forward pass: feature extraction -> shrink -> m mappings -> expand -> deconv."""
        # Feature extraction (5x5, VALID so borders are trimmed).
        conv_feature = self.prelu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'], 1)

        # Shrinking (1x1 channel reduction s -> d).
        conv_shrink = self.prelu(tf.nn.conv2d(conv_feature, self.weights['w2'], strides=[1,1,1,1], padding='SAME') + self.biases['b2'], 2)

        # The m 3x3 mapping layers.
        prev_layer, m = conv_shrink, self.model_params[2]
        for i in range(3, m + 3):
            weights, biases = self.weights['w{}'.format(i)], self.biases['b{}'.format(i)]
            prev_layer = self.prelu(tf.nn.conv2d(prev_layer, weights, strides=[1,1,1,1], padding='SAME') + biases, i)

        # Expanding (1x1 channel restoration d -> s).
        # NOTE(review): the PReLU alpha index 7 is hard-coded here; with the
        # fast model (m=1) that leaves a gap after alpha3 -- confirm intended.
        expand_weights, expand_biases = self.weights['w{}'.format(m + 3)], self.biases['b{}'.format(m + 3)]
        conv_expand = self.prelu(tf.nn.conv2d(prev_layer, expand_weights, strides=[1,1,1,1], padding='SAME') + expand_biases, 7)

        # Deconvolution performs the actual upscaling by self.scale.
        deconv_output = [self.batch, self.label_size, self.label_size, self.c_dim]
        deconv_stride = [1, self.scale, self.scale, 1]
        deconv_weights, deconv_biases = self.weights['w{}'.format(m + 4)], self.biases['b{}'.format(m + 4)]
        conv_deconv = tf.nn.conv2d_transpose(conv_expand, deconv_weights, output_shape=deconv_output,
                                             strides=deconv_stride, padding='SAME') + deconv_biases

        return conv_deconv

    def prelu(self, _x, i):
        """
        Parametric ReLU: max(0, x) + alpha * min(0, x), with one learned
        alpha per channel stored under the graph name 'alpha{i}'.
        """
        alphas = tf.get_variable('alpha{}'.format(i), _x.get_shape()[-1],
                                 initializer=tf.constant_initializer(0.0), dtype=tf.float32)
        pos = tf.nn.relu(_x)
        # (x - |x|) / 2 equals min(0, x): zero for positives, x for negatives.
        neg = alphas * (_x - abs(_x)) * 0.5

        return pos + neg

    def save(self, checkpoint_dir, step):
        """Write a checkpoint tagged with the global step."""
        model_name = "FSRCNN.model"
        model_dir = "%s_%s" % ("fsrcnn", self.label_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint; return True on success."""
        print(" [*] Reading checkpoints...")
        model_dir = "%s_%s" % ("fsrcnn", self.label_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            return False
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
        return True
"""
Scipy version > 0.18 is needed, due to 'mode' option from scipy.misc.imread function
"""

import os
import glob
import h5py
import random
from math import floor
import struct

import tensorflow as tf
from PIL import Image
from scipy.misc import imread
import numpy as np
from multiprocessing import Pool, Lock, active_children

import pdb

FLAGS = tf.app.flags.FLAGS

def read_data(path):
    """
    Read h5 format data file.

    Args:
        path: file path of desired file

    Returns:
        data: numpy array of training/test input values
        label: numpy array of training/test label values
    """
    with h5py.File(path, 'r') as hf:
        data = np.array(hf.get('data'))
        label = np.array(hf.get('label'))
        return data, label

def preprocess(path, scale=3):
    """
    Preprocess a single image file:
      (1) read as grayscale ('L'),
      (2) normalize to [0, 1],
      (3) crop so both dimensions divide by scale (see modcrop),
      (4) downsample by the scale factor (anti-aliased).

    Returns:
        (input_, label_): downsampled input and full-size ground truth.
    """
    image = Image.open(path).convert('L')
    (width, height) = image.size
    # float instead of np.float: the np.float alias was removed in NumPy 1.24
    # and was always just the builtin float (-> float64 array), so behaviour
    # is identical on old NumPy too.
    label_ = np.array(list(image.getdata())).astype(float).reshape((height, width)) / 255
    image.close()

    cropped_image = Image.fromarray(modcrop(label_, scale))

    (width, height) = cropped_image.size
    new_width, new_height = int(width / scale), int(height / scale)
    scaled_image = cropped_image.resize((new_width, new_height), Image.ANTIALIAS)
    cropped_image.close()

    # NOTE: input_ is NOT divided by 255 here; label_ already was -- this
    # asymmetry is preserved from the original pipeline.
    (width, height) = scaled_image.size
    input_ = np.array(list(scaled_image.getdata())).astype(float).reshape((height, width))

    return input_, label_

def prepare_data(sess, dataset):
    """
    Return the sorted list of .bmp paths for the train or test dataset.

    For the train dataset, output would be ['.../t1.bmp', ..., '.../t99.bmp'];
    for testing, images are taken from the dataset's "Set5" subfolder.
    """
    if FLAGS.train:
        data_dir = os.path.join(os.getcwd(), dataset)
    else:
        data_dir = os.path.join(os.sep, (os.path.join(os.getcwd(), dataset)), "Set5")
    data = sorted(glob.glob(os.path.join(data_dir, "*.bmp")))

    return data

def make_data(sess, checkpoint_dir, data, label):
    """
    Save input data and labels as an h5 file inside checkpoint_dir.
    The file name depends on the 'train' flag (train.h5 vs test.h5).
    """
    if FLAGS.train:
        savepath = os.path.join(os.getcwd(), '{}/train.h5'.format(checkpoint_dir))
    else:
        savepath = os.path.join(os.getcwd(), '{}/test.h5'.format(checkpoint_dir))

    with h5py.File(savepath, 'w') as hf:
        hf.create_dataset('data', data=data)
        hf.create_dataset('label', data=label)
def image_read(path, is_grayscale=True):
    """
    Read an image by path in YCbCr (flattened to grayscale by default).

    NOTE(review): scipy.misc.imread was removed in SciPy 1.2, so this helper
    fails on modern SciPy -- it appears unused by the main pipeline (preprocess
    uses PIL directly); confirm before relying on it.
    """
    if is_grayscale:
        # float instead of the removed np.float alias -- same float64 result.
        return imread(path, flatten=True, mode='YCbCr').astype(float)
    else:
        return imread(path, mode='YCbCr').astype(float)

def modcrop(image, scale=3):
    """
    Crop an image so both spatial dimensions are exact multiples of scale.

    This guarantees no remainder when the image is later downscaled and
    upscaled by the same factor. Works for 2-D and 3-D (H, W, C) arrays:
    slicing only the first two axes leaves any channel axis intact.
    """
    h, w = image.shape[0], image.shape[1]
    h = h - np.mod(h, scale)
    w = w - np.mod(w, scale)
    return image[0:h, 0:w]

def train_input_worker(args):
    """
    Multiprocessing worker: cut one image into aligned (input, label)
    sub-image pairs.

    Args:
        args: (image_path, [image_size, label_size, stride, scale, save_image])

    Returns:
        [input_patches, label_patches] for this image.
    """
    image_data, config = args
    image_size, label_size, stride, scale, save_image = config

    single_input_sequence, single_label_sequence = [], []
    # Floor division: these values are used as array indices and as range()
    # bounds, so they must be ints. Python 2's "/" on ints already floored;
    # "//" keeps that behaviour and also works on Python 3.
    padding = abs(image_size - label_size) // 2  # eg. for 3x: (21 - 11) / 2 = 5
    label_padding = label_size // scale  # eg. for 3x: 21 / 3 = 7

    input_, label_ = preprocess(image_data, scale)

    if len(input_.shape) == 3:
        h, w, _ = input_.shape
    else:
        h, w = input_.shape

    for x in range(0, h - image_size - padding + 1, stride):
        for y in range(0, w - image_size - padding + 1, stride):
            sub_input = input_[x + padding : x + padding + image_size, y + padding : y + padding + image_size]
            # The label patch is taken at scale-times the (padded) input offset
            # so input and label sub-images stay aligned.
            x_loc, y_loc = x + label_padding, y + label_padding
            sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]

            sub_input = sub_input.reshape([image_size, image_size, 1])
            sub_label = sub_label.reshape([label_size, label_size, 1])

            single_input_sequence.append(sub_input)
            single_label_sequence.append(sub_label)

    return [single_input_sequence, single_label_sequence]
def thread_train_setup(config):
    """
    Spawns |config.threads| worker processes to pre-process the data.

    This has not been extensively tested so use at your own risk.
    Also this is technically multiprocessing not threading, I just say thread
    because it's shorter to type.
    """
    sess = config.sess

    # Load data paths
    data = prepare_data(sess, dataset=config.data_dir)

    # Initialize multiprocessing pool with # of processes = config.threads
    pool = Pool(config.threads)

    # getattr guard: main.py defines no "save_image" flag, so a plain
    # attribute access raised AttributeError on the threaded path.
    config_values = [config.image_size, config.label_size, config.stride, config.scale,
                     getattr(config, 'save_image', False)]

    # Floor division keeps the range() bounds integral on Python 3 (identical
    # to Python 2's int "/"). The last worker also picks up the
    # len(data) % threads leftover images, which were previously dropped.
    images_per_thread = len(data) // config.threads
    workers = []
    for thread in range(config.threads):
        start = thread * images_per_thread
        end = len(data) if thread == config.threads - 1 else start + images_per_thread
        args_list = [(data[i], config_values) for i in range(start, end)]
        worker = pool.map_async(train_input_worker, args_list)
        workers.append(worker)
    print("{} worker processes created".format(config.threads))

    pool.close()

    results = []
    for i in range(len(workers)):
        print("Waiting for worker process {}".format(i))
        results.extend(workers[i].get(timeout=240))
        print("Worker process {} done".format(i))

    print("All worker processes done!")

    # Flatten the per-image patch lists into one input and one label sequence.
    sub_input_sequence, sub_label_sequence = [], []
    for single_input_sequence, single_label_sequence in results:
        sub_input_sequence.extend(single_input_sequence)
        sub_label_sequence.extend(single_label_sequence)

    arrdata = np.asarray(sub_input_sequence)
    arrlabel = np.asarray(sub_label_sequence)

    make_data(sess, config.checkpoint_dir, arrdata, arrlabel)
def train_input_setup(config):
    """
    Read image files, make their sub-images, and save them as a h5 file format.
    Single-process equivalent of thread_train_setup.
    """
    sess = config.sess
    image_size, label_size, stride, scale = config.image_size, config.label_size, config.stride, config.scale

    # Load data paths
    data = prepare_data(sess, dataset=config.data_dir)

    sub_input_sequence, sub_label_sequence = [], []
    # Floor division: used as array indices and range() bounds, so these must
    # stay ints (Python 2's int "/" floored; "//" also works on Python 3).
    padding = abs(image_size - label_size) // 2  # eg. for 3x: (21 - 11) / 2 = 5
    label_padding = label_size // scale  # eg. for 3x: 21 / 3 = 7

    # range() instead of Python-2-only xrange() -- identical iteration.
    for i in range(len(data)):
        input_, label_ = preprocess(data[i], scale)

        if len(input_.shape) == 3:
            h, w, _ = input_.shape
        else:
            h, w = input_.shape

        for x in range(0, h - image_size - padding + 1, stride):
            for y in range(0, w - image_size - padding + 1, stride):
                sub_input = input_[x + padding : x + padding + image_size, y + padding : y + padding + image_size]
                # Label patch at scale-times the padded offset keeps input and
                # label sub-images aligned.
                x_loc, y_loc = x + label_padding, y + label_padding
                sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]

                sub_input = sub_input.reshape([image_size, image_size, 1])
                sub_label = sub_label.reshape([label_size, label_size, 1])

                sub_input_sequence.append(sub_input)
                sub_label_sequence.append(sub_label)

    arrdata = np.asarray(sub_input_sequence)
    arrlabel = np.asarray(sub_label_sequence)

    make_data(sess, config.checkpoint_dir, arrdata, arrlabel)
def test_input_setup(config):
    """
    Cut ONE test image into sub-images, save them as an h5 file, and return
    (nx, ny): the number of patch columns/rows needed to re-merge the result.
    """
    sess = config.sess
    image_size, label_size, stride, scale = config.image_size, config.label_size, config.stride, config.scale

    # Load data paths
    data = prepare_data(sess, dataset="Test")

    sub_input_sequence, sub_label_sequence = [], []
    # Floor division: used as array indices and range() bounds, so these must
    # stay ints (Python 2's int "/" floored; "//" also works on Python 3).
    padding = abs(image_size - label_size) // 2  # eg. (21 - 11) / 2 = 5
    label_padding = label_size // scale  # eg. 21 / 3 = 7

    # NOTE(review): only a single, hard-coded image is ever tested; the index
    # is by lexicographic order in the data folder.
    pic_index = 2
    input_, label_ = preprocess(data[pic_index], config.scale)

    if len(input_.shape) == 3:
        h, w, _ = input_.shape
    else:
        h, w = input_.shape

    # nx / ny count the patch grid so merge() can reassemble the image.
    nx, ny = 0, 0
    for x in range(0, h - image_size - padding + 1, stride):
        nx += 1
        ny = 0
        for y in range(0, w - image_size - padding + 1, stride):
            ny += 1
            sub_input = input_[x + padding : x + padding + image_size, y + padding : y + padding + image_size]
            x_loc, y_loc = x + label_padding, y + label_padding
            sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]

            sub_input = sub_input.reshape([image_size, image_size, 1])
            sub_label = sub_label.reshape([label_size, label_size, 1])

            sub_input_sequence.append(sub_input)
            sub_label_sequence.append(sub_label)

    arrdata = np.asarray(sub_input_sequence)
    arrlabel = np.asarray(sub_label_sequence)

    make_data(sess, config.checkpoint_dir, arrdata, arrlabel)

    return nx, ny
# You can ignore, I just wanted to see how much space all the parameters would take up
def save_params(sess, weights, biases):
    """Dump every weight as packed binary floats and every bias as text."""
    param_dir = "params/"

    if not os.path.exists(param_dir):
        os.makedirs(param_dir)

    # Weights: each value as a raw 4-byte float followed by one pad byte ("x").
    # Context managers guarantee the files are closed even on error.
    with open(param_dir + "weights", 'wb') as weight_file:
        for layer in weights:
            layer_weights = sess.run(weights[layer])

            for filter_x in range(len(layer_weights)):
                for filter_y in range(len(layer_weights[filter_x])):
                    filter_weights = layer_weights[filter_x][filter_y]
                    for input_channel in range(len(filter_weights)):
                        for output_channel in range(len(filter_weights[input_channel])):
                            weight_value = filter_weights[input_channel][output_channel]
                            # Write bytes directly to save space
                            weight_file.write(struct.pack("f", weight_value))
                            weight_file.write(struct.pack("x"))

            # b"...": the file is opened in binary mode, so writing a str
            # literal here raises TypeError on Python 3.
            weight_file.write(b"\n\n")

    with open(param_dir + "biases.txt", 'w') as bias_file:
        for layer in biases:
            bias_file.write("Layer {}\n".format(layer))
            layer_biases = sess.run(biases[layer])
            for bias in layer_biases:
                # Can write as characters due to low bias parameter count
                bias_file.write("{}, ".format(bias))
            bias_file.write("\n\n")

def merge(images, size):
    """
    Merge a batch of single-channel sub-images back into one image.

    Args:
        images: array of shape (n, h, w, 1), in row-major patch order
        size: [rows, cols] of the patch grid

    Returns:
        array of shape (h * rows, w * cols, 1)
    """
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 1))
    for idx, image in enumerate(images):
        col = idx % size[1]
        row = idx // size[1]
        img[row*h:row*h+h, col*w:col*w+w, :] = image

    return img

def array_image_save(array, image_path):
    """
    Converts np array to image and saves it
    """
    image = Image.fromarray(array)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image.save(image_path)
    print("Saved image: {}".format(image_path))