├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _templates │ └── about.html ├── conf.py ├── index.rst ├── make.bat ├── modules.rst ├── requirements.txt ├── tensorcv.callbacks.rst ├── tensorcv.dataflow.dataset.rst ├── tensorcv.dataflow.rst ├── tensorcv.models.rst ├── tensorcv.predicts.rst ├── tensorcv.rst ├── tensorcv.train.rst └── tensorcv.utils.rst ├── requirements.txt ├── setup.py ├── tensorcv ├── __init__.py ├── algorithms │ ├── GAN │ │ ├── DCGAN.py │ │ ├── README.md │ │ ├── config.py │ │ └── fig │ │ │ └── mnist_result.png │ └── pretrained │ │ ├── VGG.py │ │ └── VGG_.py ├── callbacks │ ├── __init__.py │ ├── base.py │ ├── debug.py │ ├── group.py │ ├── hooks.py │ ├── inference.py │ ├── inferencer.py │ ├── inputs.py │ ├── monitors.py │ ├── saver.py │ ├── summary.py │ └── trigger.py ├── data │ └── imageNetLabel.txt ├── dataflow │ ├── __init__.py │ ├── argument.py │ ├── base.py │ ├── bk │ │ └── image.py │ ├── common.py │ ├── dataset │ │ ├── BSDS500.py │ │ ├── CIFAR.py │ │ ├── MNIST.py │ │ └── __init__.py │ ├── image.py │ ├── matlab.py │ ├── normalization.py │ ├── operation.py │ ├── preprocess.py │ ├── randoms.py │ ├── sequence.py │ └── viz.py ├── models │ ├── __init__.py │ ├── base.py │ ├── bk │ │ └── layers.py │ ├── layers.py │ ├── losses.py │ ├── model_builder │ │ └── base.py │ └── utils.py ├── predicts │ ├── __init__.py │ ├── base.py │ ├── config.py │ ├── predictions.py │ └── simple.py ├── tfdataflow │ ├── __init__.py │ ├── base.py │ ├── convert.py │ └── write.py ├── train │ ├── __init__.py │ ├── base.py │ ├── config.py │ └── simple.py └── utils │ ├── __init__.py │ ├── common.py │ ├── debug.py │ ├── default.py │ ├── sesscreate.py │ ├── utils.py │ └── viz.py ├── test.py └── test ├── GAN-1D.py ├── VGG.py ├── VGG_pre_trained.py ├── VGG_train.py ├── config.py ├── model.py ├── test.py └── todo.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 
| __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.4" 4 | install: 5 | - pip install -r requirements.txt 6 | - pip install tensorflow 7 | # - pip install flake8 8 | - pip install coveralls 9 | script: 10 | # - flake8 . 
--ignore=F405,F403 11 | - python test.py 12 | - coverage run --source=. test.py 13 | after_success: 14 | - coveralls 15 | 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Qian Ge 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A deep learning package for computer vision algorithms built on top of TensorFlow 2 | 3 | [![Build Status](https://travis-ci.org/conan7882/DeepVision-tensorflow.svg?branch=master)](https://travis-ci.org/conan7882/DeepVision-tensorflow) 4 | 5 | This package is for practicing and reproducing learning-based computer vision algorithms in recent papers. 
6 | 7 | It is a simplified version of [tensorpack](https://github.com/ppwwyyxx/tensorpack), which is the reference code of this package. 8 | 9 | # Install 10 | 11 | ``` 12 | pip install -r requirements.txt 13 | pip install -U git+https://github.com/conan7882/DeepVision-tensorflow.git 14 | ``` 15 | 16 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = DeepVision-TensorFlow 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/_templates/about.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | {%- block extrahead %} 4 | 7 | 10 | 11 | 19 | {% endblock %} -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Test Documentation 2 | -------------------- 3 | 4 | .. 
toctree:: 5 | :maxdepth: 1 6 | 7 | modules 8 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=DeepVision-TensorFlow 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | tensorcv 2 | ======== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | tensorcv 8 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | termcolor 2 | numpy 3 | tqdm 4 | Sphinx>=1.6 5 | recommonmark==0.4.0 6 | guzzle_sphinx_theme 7 | mock 8 | tensorflow 9 | matplotlib 10 | Pillow -------------------------------------------------------------------------------- /docs/tensorcv.callbacks.rst: -------------------------------------------------------------------------------- 1 | tensorcv\.callbacks package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | tensorcv\.callbacks\.base module 8 | -------------------------------- 9 | 10 | .. automodule:: tensorcv.callbacks.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | tensorcv\.callbacks\.debug module 16 | --------------------------------- 17 | 18 | .. automodule:: tensorcv.callbacks.debug 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | tensorcv\.callbacks\.group module 24 | --------------------------------- 25 | 26 | .. automodule:: tensorcv.callbacks.group 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | tensorcv\.callbacks\.hooks module 32 | --------------------------------- 33 | 34 | .. automodule:: tensorcv.callbacks.hooks 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | tensorcv\.callbacks\.inference module 40 | ------------------------------------- 41 | 42 | .. automodule:: tensorcv.callbacks.inference 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | tensorcv\.callbacks\.inferencer module 48 | -------------------------------------- 49 | 50 | .. automodule:: tensorcv.callbacks.inferencer 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | tensorcv\.callbacks\.inputs module 56 | ---------------------------------- 57 | 58 | .. 
automodule:: tensorcv.callbacks.inputs 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | tensorcv\.callbacks\.monitors module 64 | ------------------------------------ 65 | 66 | .. automodule:: tensorcv.callbacks.monitors 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | tensorcv\.callbacks\.saver module 72 | --------------------------------- 73 | 74 | .. automodule:: tensorcv.callbacks.saver 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | tensorcv\.callbacks\.summary module 80 | ----------------------------------- 81 | 82 | .. automodule:: tensorcv.callbacks.summary 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | tensorcv\.callbacks\.trigger module 88 | ----------------------------------- 89 | 90 | .. automodule:: tensorcv.callbacks.trigger 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | 96 | Module contents 97 | --------------- 98 | 99 | .. automodule:: tensorcv.callbacks 100 | :members: 101 | :undoc-members: 102 | :show-inheritance: 103 | -------------------------------------------------------------------------------- /docs/tensorcv.dataflow.dataset.rst: -------------------------------------------------------------------------------- 1 | tensorcv\.dataflow\.dataset package 2 | =================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | tensorcv\.dataflow\.dataset\.BSDS500 module 8 | ------------------------------------------- 9 | 10 | .. automodule:: tensorcv.dataflow.dataset.BSDS500 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | tensorcv\.dataflow\.dataset\.CIFAR module 16 | ----------------------------------------- 17 | 18 | .. automodule:: tensorcv.dataflow.dataset.CIFAR 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | tensorcv\.dataflow\.dataset\.MNIST module 24 | ----------------------------------------- 25 | 26 | .. 
automodule:: tensorcv.dataflow.dataset.MNIST 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. automodule:: tensorcv.dataflow.dataset 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /docs/tensorcv.dataflow.rst: -------------------------------------------------------------------------------- 1 | tensorcv\.dataflow package 2 | ========================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | tensorcv.dataflow.dataset 10 | 11 | Submodules 12 | ---------- 13 | 14 | tensorcv\.dataflow\.base module 15 | ------------------------------- 16 | 17 | .. automodule:: tensorcv.dataflow.base 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | tensorcv\.dataflow\.common module 23 | --------------------------------- 24 | 25 | .. automodule:: tensorcv.dataflow.common 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | tensorcv\.dataflow\.image module 31 | -------------------------------- 32 | 33 | .. automodule:: tensorcv.dataflow.image 34 | :members: 35 | :undoc-members: 36 | :show-inheritance: 37 | 38 | tensorcv\.dataflow\.matlab module 39 | --------------------------------- 40 | 41 | .. automodule:: tensorcv.dataflow.matlab 42 | :members: 43 | :undoc-members: 44 | :show-inheritance: 45 | 46 | tensorcv\.dataflow\.randoms module 47 | ---------------------------------- 48 | 49 | .. automodule:: tensorcv.dataflow.randoms 50 | :members: 51 | :undoc-members: 52 | :show-inheritance: 53 | 54 | 55 | Module contents 56 | --------------- 57 | 58 | .. 
automodule:: tensorcv.dataflow 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | -------------------------------------------------------------------------------- /docs/tensorcv.models.rst: -------------------------------------------------------------------------------- 1 | tensorcv\.models package 2 | ======================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | tensorcv\.models\.base module 8 | ----------------------------- 9 | 10 | .. automodule:: tensorcv.models.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | tensorcv\.models\.layers module 16 | ------------------------------- 17 | 18 | .. automodule:: tensorcv.models.layers 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | tensorcv\.models\.losses module 24 | ------------------------------- 25 | 26 | .. automodule:: tensorcv.models.losses 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. automodule:: tensorcv.models 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /docs/tensorcv.predicts.rst: -------------------------------------------------------------------------------- 1 | tensorcv\.predicts package 2 | ========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | tensorcv\.predicts\.base module 8 | ------------------------------- 9 | 10 | .. automodule:: tensorcv.predicts.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | tensorcv\.predicts\.config module 16 | --------------------------------- 17 | 18 | .. automodule:: tensorcv.predicts.config 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | tensorcv\.predicts\.predictions module 24 | -------------------------------------- 25 | 26 | .. 
automodule:: tensorcv.predicts.predictions 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | tensorcv\.predicts\.simple module 32 | --------------------------------- 33 | 34 | .. automodule:: tensorcv.predicts.simple 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: tensorcv.predicts 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /docs/tensorcv.rst: -------------------------------------------------------------------------------- 1 | tensorcv package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | tensorcv.callbacks 10 | tensorcv.dataflow 11 | tensorcv.models 12 | tensorcv.predicts 13 | tensorcv.train 14 | tensorcv.utils 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: tensorcv 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/tensorcv.train.rst: -------------------------------------------------------------------------------- 1 | tensorcv\.train package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | tensorcv\.train\.base module 8 | ---------------------------- 9 | 10 | .. automodule:: tensorcv.train.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | tensorcv\.train\.config module 16 | ------------------------------ 17 | 18 | .. automodule:: tensorcv.train.config 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | tensorcv\.train\.simple module 24 | ------------------------------ 25 | 26 | .. automodule:: tensorcv.train.simple 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. 
automodule:: tensorcv.train 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /docs/tensorcv.utils.rst: -------------------------------------------------------------------------------- 1 | tensorcv\.utils package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | tensorcv\.utils\.common module 8 | ------------------------------ 9 | 10 | .. automodule:: tensorcv.utils.common 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | tensorcv\.utils\.default module 16 | ------------------------------- 17 | 18 | .. automodule:: tensorcv.utils.default 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | tensorcv\.utils\.sesscreate module 24 | ---------------------------------- 25 | 26 | .. automodule:: tensorcv.utils.sesscreate 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | tensorcv\.utils\.utils module 32 | ----------------------------- 33 | 34 | .. automodule:: tensorcv.utils.utils 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | tensorcv\.utils\.viz module 40 | ----------------------------- 41 | 42 | .. automodule:: tensorcv.utils.viz 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | 48 | Module contents 49 | --------------- 50 | 51 | .. 
automodule:: tensorcv.utils 52 | :members: 53 | :undoc-members: 54 | :show-inheritance: 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | Pillow 3 | six 4 | termcolor>=1.1 5 | tabulate>=0.7.7 6 | tqdm>4.11.1 7 | msgpack-python>0.4.0 8 | msgpack-numpy>=0.3.9 9 | pyzmq>=16 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='tensorcv', 4 | version='0.1', 5 | description=' ', 6 | url='https://github.com/conan7882/DeepVision-tensorflow', 7 | author='Qian Ge', 8 | author_email='geqian1001@gmail.com', 9 | packages=['tensorcv', 'tensorcv.utils', 'tensorcv.algorithms', 'tensorcv.callbacks', 10 | 'tensorcv.data', 'tensorcv.dataflow', 'tensorcv.dataflow.dataset', 11 | 'tensorcv.models', 'tensorcv.predicts', 'tensorcv.train', 'tensorcv.tfdataflow'], 12 | zip_safe=False) 13 | -------------------------------------------------------------------------------- /tensorcv/__init__.py: -------------------------------------------------------------------------------- 1 | # File: __init__.py 2 | # Author: Qian Ge 3 | -------------------------------------------------------------------------------- /tensorcv/algorithms/GAN/DCGAN.py: -------------------------------------------------------------------------------- 1 | # File: DCGAN.py 2 | # Author: Qian Ge 3 | 4 | import argparse 5 | 6 | import tensorflow as tf 7 | 8 | from tensorcv.dataflow.randoms import RandomVec 9 | from tensorcv.dataflow.dataset.MNIST import MNIST 10 | # import tensorcv.callbacks as cb 11 | from tensorcv.callbacks import * 12 | from tensorcv.predicts import * 13 | from tensorcv.models.layers import * 14 | from tensorcv.models.losses import * 15 | from tensorcv.predicts.simple import SimpleFeedPredictor 
16 | from tensorcv.predicts.config import PridectConfig 17 | from tensorcv.models.base import GANBaseModel 18 | from tensorcv.train.config import GANTrainConfig 19 | from tensorcv.train.simple import GANFeedTrainer 20 | from tensorcv.utils.common import deconv_size 21 | 22 | import config as config_path 23 | 24 | 25 | class Model(GANBaseModel): 26 | def __init__(self, 27 | input_vec_length=100, 28 | learning_rate=[0.0002, 0.0002], 29 | num_channels=None, 30 | im_size=None): 31 | 32 | super(Model, self).__init__(input_vec_length, learning_rate) 33 | 34 | if num_channels is not None: 35 | self.num_channels = num_channels 36 | if im_size is not None: 37 | self.im_height, self.im_width = im_size 38 | 39 | self.set_is_training(True) 40 | 41 | # def _get_placeholder(self): 42 | # # image 43 | # return [self.real_data] 44 | 45 | def _create_input(self): 46 | self.real_data = tf.placeholder( 47 | tf.float32, 48 | [None, self.im_height, self.im_width, self.num_channels]) 49 | self.set_train_placeholder(self.real_data) 50 | 51 | def _generator(self, train=True): 52 | 53 | final_dim = 64 54 | filter_size = 5 55 | 56 | d_height_2, d_width_2 = deconv_size(self.im_height, self.im_width) 57 | d_height_4, d_width_4 = deconv_size(d_height_2, d_width_2) 58 | d_height_8, d_width_8 = deconv_size(d_height_4, d_width_4) 59 | d_height_16, d_width_16 = deconv_size(d_height_8, d_width_8) 60 | 61 | rand_vec = self.get_random_vec_placeholder() 62 | batch_size = tf.shape(rand_vec)[0] 63 | 64 | with tf.variable_scope('fc1'): 65 | fc1 = fc(rand_vec, d_height_16 * d_width_16 * final_dim * 8, 'fc') 66 | fc1 = tf.nn.relu(batch_norm(fc1, train=train)) 67 | fc1_reshape = tf.reshape( 68 | fc1, [-1, d_height_16, d_width_16, final_dim * 8]) 69 | 70 | with tf.variable_scope('dconv2'): 71 | output_shape = [batch_size, d_height_8, d_width_8, final_dim * 4] 72 | dconv2 = dconv(fc1_reshape, filter_size, 73 | out_shape=output_shape, name='dconv') 74 | bn_dconv2 = tf.nn.relu(batch_norm(dconv2, train=train)) 
75 | 76 | with tf.variable_scope('dconv3'): 77 | output_shape = [batch_size, d_height_4, d_width_4, final_dim * 2] 78 | dconv3 = dconv(bn_dconv2, filter_size, 79 | out_shape=output_shape, name='dconv') 80 | bn_dconv3 = tf.nn.relu(batch_norm(dconv3, train=train)) 81 | 82 | with tf.variable_scope('dconv4'): 83 | output_shape = [batch_size, d_height_2, d_width_2, final_dim] 84 | dconv4 = dconv(bn_dconv3, filter_size, 85 | out_shape=output_shape, name='dconv') 86 | bn_dconv4 = tf.nn.relu(batch_norm(dconv4, train=train)) 87 | 88 | with tf.variable_scope('dconv5'): 89 | # Do not use batch norm for the last layer 90 | output_shape = [batch_size, self.im_height, 91 | self.im_width, self.num_channels] 92 | dconv5 = dconv(bn_dconv4, filter_size, 93 | out_shape=output_shape, name='dconv') 94 | 95 | generation = tf.nn.tanh(dconv5, 'gen_out') 96 | return generation 97 | 98 | def _discriminator(self, input_im): 99 | 100 | filter_size = 5 101 | start_depth = 64 102 | 103 | with tf.variable_scope('conv1'): 104 | conv1 = conv(input_im, filter_size, start_depth, stride=2) 105 | bn_conv1 = leaky_relu((batch_norm(conv1))) 106 | 107 | with tf.variable_scope('conv2'): 108 | conv2 = conv(bn_conv1, filter_size, start_depth * 2, stride=2) 109 | bn_conv2 = leaky_relu((batch_norm(conv2))) 110 | 111 | with tf.variable_scope('conv3'): 112 | conv3 = conv(bn_conv2, filter_size, start_depth * 4, stride=2) 113 | bn_conv3 = leaky_relu((batch_norm(conv3))) 114 | 115 | with tf.variable_scope('conv4'): 116 | conv4 = conv(bn_conv3, filter_size, start_depth * 8, stride=2) 117 | bn_conv4 = leaky_relu((batch_norm(conv4))) 118 | 119 | with tf.variable_scope('fc5'): 120 | fc5 = fc(bn_conv4, 1, name='fc') 121 | 122 | return fc5 123 | 124 | def _ex_setup_graph(self): 125 | tf.identity(self.get_sample_gen_data(), 'generate_image') 126 | tf.identity(self.get_generator_loss(), 'g_loss_check') 127 | tf.identity(self.get_discriminator_loss(), 'd_loss_check') 128 | 129 | def _setup_summary(self): 130 | with 
tf.name_scope('generator_summary'): 131 | tf.summary.image('generate_sample', 132 | tf.cast(self.get_sample_gen_data(), tf.float32), 133 | collections=[self.g_collection]) 134 | tf.summary.image('generate_train', 135 | tf.cast(self.get_gen_data(), tf.float32), 136 | collections=[self.d_collection]) 137 | with tf.name_scope('real_data'): 138 | tf.summary.image('real_data', 139 | tf.cast(self.real_data, tf.float32), 140 | collections=[self.d_collection]) 141 | 142 | 143 | def get_config(FLAGS): 144 | dataset_train = MNIST('train', data_dir=config_path.data_dir, 145 | normalize='tanh') 146 | 147 | inference_list = InferImages('generate_image', prefix='gen') 148 | random_feed = RandomVec(len_vec=FLAGS.len_vec) 149 | 150 | return GANTrainConfig( 151 | dataflow=dataset_train, 152 | model=Model(input_vec_length=FLAGS.len_vec, 153 | learning_rate=[0.0002, 0.0002]), 154 | monitors=TFSummaryWriter(), 155 | discriminator_callbacks=[ 156 | # ModelSaver(periodic = 100), 157 | CheckScalar(['d_loss_check', 'g_loss_check'], 158 | periodic=10)], 159 | generator_callbacks=[GANInference(inputs=random_feed, 160 | periodic=100, 161 | inferencers=inference_list)], 162 | batch_size=FLAGS.batch_size, 163 | max_epoch=100, 164 | summary_d_periodic=10, 165 | summary_g_periodic=10, 166 | default_dirs=config_path) 167 | 168 | 169 | def get_predictConfig(FLAGS): 170 | random_feed = RandomVec(len_vec=FLAGS.len_vec) 171 | prediction_list = PredictionImage('generate_image', 172 | 'test', merge_im=True, tanh=True) 173 | im_size = [FLAGS.h, FLAGS.w] 174 | return PridectConfig(dataflow=random_feed, 175 | model=Model(input_vec_length=FLAGS.len_vec, 176 | num_channels=FLAGS.input_channel, 177 | im_size=im_size), 178 | model_name='model-100', 179 | predictions=prediction_list, 180 | batch_size=FLAGS.batch_size, 181 | default_dirs=config_path) 182 | 183 | 184 | def get_args(): 185 | parser = argparse.ArgumentParser() 186 | 187 | parser.add_argument('--len_vec', default=100, type=int, 188 | help='Length 
of input random vector') 189 | parser.add_argument('--input_channel', default=1, type=int, 190 | help='Number of image channels') 191 | parser.add_argument('--h', default=28, type=int, 192 | help='Heigh of input images') 193 | parser.add_argument('--w', default=28, type=int, 194 | help='Width of input images') 195 | parser.add_argument('--batch_size', default=64, type=int) 196 | 197 | parser.add_argument('--predict', action='store_true', 198 | help='Run prediction') 199 | parser.add_argument('--train', action='store_true', 200 | help='Train the model') 201 | 202 | return parser.parse_args() 203 | 204 | 205 | if __name__ == '__main__': 206 | 207 | FLAGS = get_args() 208 | 209 | if FLAGS.train: 210 | config = get_config(FLAGS) 211 | GANFeedTrainer(config).train() 212 | elif FLAGS.predict: 213 | config = get_predictConfig(FLAGS) 214 | SimpleFeedPredictor(config).run_predict() 215 | -------------------------------------------------------------------------------- /tensorcv/algorithms/GAN/README.md: -------------------------------------------------------------------------------- 1 | ## Deep Convolutional Generative Adversarial Networks (DCGAN) 2 | 3 | - Model for [DCGAN](https://arxiv.org/abs/1511.06434) 4 | 5 | - An example implementation of DCGAN using this model can be found [here](https://github.com/conan7882/tensorflow-DCGAN). 
6 | 7 | *Details on how to write your own GAN model and callback configuration can be found in the docs (coming soon).* 8 | 9 | ## Implementation Details 10 | #### Generator 11 | 12 | #### Discriminator 13 | 14 | #### Loss function 15 | 16 | #### Optimizer 17 | 18 | #### Variable initialization 19 | 20 | #### Batch normalization and LeakyReLU 21 | 22 | #### Training settings 23 | - learning rate 24 | - training step 25 | 26 | ## Default Summary 27 | ### Scalar: 28 | - loss of generator and discriminator 29 | 30 | ### Histogram: 31 | - gradients of generator and discriminator 32 | - discriminator output for real images and generated images 33 | 34 | ### Image 35 | - real images and generated images 36 | 37 | ## Callbacks 38 | 39 | ### Available callbacks: 40 | 41 | - TrainSummary() 42 | - CheckScalar() 43 | - GANInference() 44 | 45 | ### Available inferencers: 46 | - InferImages() 47 | 48 | ### Available predictors: 49 | - PredictionImage() 50 | 51 | ## Test 52 | To test this model on the MNIST dataset, first set all the directories in *config.py*. 53 | 54 | For training: 55 | 56 | $ python DCGAN.py --train --batch_size 64 57 | 58 | For testing, the batch size must be the same as the one used for training: 59 | 60 | $ python DCGAN.py --predict --batch_size 64 61 | 62 | Examples of running this model on other datasets can be found [here](https://github.com/conan7882/tensorflow-DCGAN).
63 | 64 | 65 | ## Results 66 | *More results can be found [here](https://github.com/conan7882/tensorflow-DCGAN#results).* 67 | ### MNIST 68 | 69 | ![MNIST_result1](fig/mnist_result.png) 70 | 71 | 72 | ## Reference 73 | - [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.](https://arxiv.org/abs/1511.06434) 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /tensorcv/algorithms/GAN/config.py: -------------------------------------------------------------------------------- 1 | # File: config.py 2 | # Author: Qian Ge 3 | # NOTE: machine-specific absolute Windows paths (read by DCGAN.py via config_path); edit them for your environment before running. 4 | # directory of input data 5 | data_dir = 'D:\\Qian\\GitHub\\workspace\\tensorflow-DCGAN\\MNIST_data\\' 6 | 7 | # directory for saving inference data 8 | infer_dir = 'D:\\Qian\\GitHub\\workspace\\test\\result\\' 9 | 10 | # directory for saving summary 11 | summary_dir = 'D:\\Qian\\GitHub\\workspace\\test\\' 12 | 13 | # directory for saving checkpoint 14 | checkpoint_dir = 'D:\\Qian\\GitHub\\workspace\\test\\' 15 | 16 | # directory for restoring checkpoint 17 | model_dir = 'D:\\Qian\\GitHub\\workspace\\test\\' 18 | 19 | # directory for saving prediction results 20 | result_dir = 'D:\\Qian\\GitHub\\workspace\\test\\2\\' 21 | -------------------------------------------------------------------------------- /tensorcv/algorithms/GAN/fig/mnist_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/conan7882/DeepVision-tensorflow/0ffc81a62eccf021077019fb59b0e9e7615e8222/tensorcv/algorithms/GAN/fig/mnist_result.png -------------------------------------------------------------------------------- /tensorcv/algorithms/pretrained/VGG_.py: -------------------------------------------------------------------------------- 1 | # File: VGG.py 2 | # Author: Qian Ge 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from tensorcv.models.layers import * 8 | from tensorcv.models.base import
BaseModel 9 | 10 | import config 11 | 12 | VGG_MEAN = [103.939, 116.779, 123.68] 13 | 14 | class BaseVGG(BaseModel): 15 | """ base of VGG class """ 16 | def __init__(self, num_class=1000, 17 | num_channels=3, 18 | im_height=224, im_width=224, 19 | learning_rate=0.0001, 20 | is_load=False, 21 | pre_train_path=None): 22 | """ 23 | Args: 24 | num_class (int): number of image classes 25 | num_channels (int): number of input channels 26 | im_height, im_width (int): size of input image 27 | Can be unknown when testing. 28 | learning_rate (float): learning rate of training 29 | """ 30 | 31 | self.learning_rate = learning_rate 32 | self.num_channels = num_channels 33 | self.im_height = im_height 34 | self.im_width = im_width 35 | self.num_class = num_class 36 | 37 | self._is_load = is_load 38 | if self._is_load and pre_train_path is None: 39 | raise ValueError('pre_train_path can not be None!') 40 | self._pre_train_path = pre_train_path 41 | 42 | self.set_is_training(True) 43 | 44 | def _create_input(self): 45 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 46 | self.image = tf.placeholder(tf.float32, name='image', 47 | shape=[None, self.im_height, self.im_width, self.num_channels]) 48 | self.label = tf.placeholder(tf.int64, [None], 'label') 49 | # self.label = tf.placeholder(tf.int64, [None, self.num_class], 'label') 50 | 51 | self.set_model_input([self.image, self.keep_prob]) 52 | self.set_dropout(self.keep_prob, keep_prob=0.5) 53 | self.set_train_placeholder([self.image, self.label]) 54 | self.set_prediction_placeholder(self.image) 55 | 56 | @staticmethod 57 | def load_pre_trained(session, model_path, skip_layer=[]): 58 | weights_dict = np.load(model_path, encoding='latin1').item() 59 | for layer_name in weights_dict: 60 | print('Loading ' + layer_name) 61 | if layer_name not in skip_layer: 62 | with tf.variable_scope(layer_name, reuse=True): 63 | for data in weights_dict[layer_name]: 64 | if len(data.shape) == 1: 65 | var = tf.get_variable('biases', 
trainable=False) 66 | session.run(var.assign(data)) 67 | else: 68 | var = tf.get_variable('weights', trainable=False) 69 | session.run(var.assign(data)) 70 | 71 | class VGG19(BaseVGG): 72 | # def __init__(self, num_class = 1000, 73 | # num_channels = 3, 74 | # im_height = 224, im_width = 224, 75 | # learning_rate = 0.0001): 76 | 77 | # super(VGG19, self).__init__(num_class = num_class, 78 | # num_channels = num_channels, 79 | # im_height = im_height, 80 | # im_width = im_width, 81 | # learning_rate = learning_rate) 82 | 83 | def _create_conv(self, input_im, data_dict): 84 | 85 | arg_scope = tf.contrib.framework.arg_scope 86 | with arg_scope([conv], nl=tf.nn.relu, trainable=True, data_dict=data_dict): 87 | conv1_1 = conv(input_im, 3, 64, 'conv1_1') 88 | conv1_2 = conv(conv1_1, 3, 64, 'conv1_2') 89 | pool1 = max_pool(conv1_2, 'pool1', padding='SAME') 90 | 91 | conv2_1 = conv(pool1, 3, 128, 'conv2_1') 92 | conv2_2 = conv(conv2_1, 3, 128, 'conv2_2') 93 | pool2 = max_pool(conv2_2, 'pool2', padding='SAME') 94 | 95 | conv3_1 = conv(pool2, 3, 256, 'conv3_1') 96 | conv3_2 = conv(conv3_1, 3, 256, 'conv3_2') 97 | conv3_3 = conv(conv3_2, 3, 256, 'conv3_3') 98 | conv3_4 = conv(conv3_3, 3, 256, 'conv3_4') 99 | pool3 = max_pool(conv3_4, 'pool3', padding='SAME') 100 | 101 | conv4_1 = conv(pool3, 3, 512, 'conv4_1') 102 | conv4_2 = conv(conv4_1, 3, 512, 'conv4_2') 103 | conv4_3 = conv(conv4_2, 3, 512, 'conv4_3') 104 | conv4_4 = conv(conv4_3, 3, 512, 'conv4_4') 105 | pool4 = max_pool(conv4_4, 'pool4', padding='SAME') 106 | 107 | conv5_1 = conv(pool4, 3, 512, 'conv5_1') 108 | conv5_2 = conv(conv5_1, 3, 512, 'conv5_2') 109 | conv5_3 = conv(conv5_2, 3, 512, 'conv5_3') 110 | conv5_4 = conv(conv5_3, 3, 512, 'conv5_4') 111 | pool5 = max_pool(conv5_4, 'pool5', padding='SAME') 112 | 113 | self.conv_out = tf.identity(conv5_4) 114 | 115 | return pool5 116 | 117 | def _create_model(self): 118 | 119 | input_im = self.model_input[0] 120 | keep_prob = self.model_input[1] 121 | 122 | # Convert RGB 
image to BGR image 123 | red, green, blue = tf.split(axis=3, num_or_size_splits=3, 124 | value=input_im) 125 | 126 | input_bgr = tf.concat(axis=3, values=[ 127 | blue - VGG_MEAN[0], 128 | green - VGG_MEAN[1], 129 | red - VGG_MEAN[2], 130 | ]) 131 | 132 | data_dict = {} 133 | if self._is_load: 134 | data_dict = np.load(self._pre_train_path, encoding='latin1').item() 135 | 136 | conv_output = self._create_conv(input_bgr, data_dict) 137 | 138 | arg_scope = tf.contrib.framework.arg_scope 139 | with arg_scope([fc], trainable=True, data_dict=data_dict): 140 | fc6 = fc(conv_output, 4096, 'fc6', nl=tf.nn.relu) 141 | dropout_fc6 = dropout(fc6, keep_prob, self.is_training) 142 | 143 | fc7 = fc(dropout_fc6, 4096, 'fc7', nl=tf.nn.relu) 144 | dropout_fc7 = dropout(fc7, keep_prob, self.is_training) 145 | 146 | fc8 = fc(dropout_fc7, self.num_class, 'fc8') 147 | 148 | self.output = tf.identity(fc8, 'model_output') 149 | 150 | class VGG19_FCN(VGG19): 151 | # def __init__(self, num_class = 1000, 152 | # num_channels = 3, 153 | # im_height = 224, im_width = 224, 154 | # learning_rate = 0.0001): 155 | 156 | # super(VGG19, self).__init__(num_class = num_class, 157 | # num_channels = num_channels, 158 | # im_height = im_height, 159 | # im_width = im_width, 160 | # learning_rate = learning_rate) 161 | 162 | def _create_model(self): 163 | 164 | input_im = self.model_input[0] 165 | keep_prob = self.model_input[1] 166 | 167 | # Convert rgb image to bgr image 168 | red, green, blue = tf.split(axis=3, num_or_size_splits=3, 169 | value=input_im) 170 | 171 | input_bgr = tf.concat(axis=3, values=[ 172 | blue - VGG_MEAN[0], 173 | green - VGG_MEAN[1], 174 | red - VGG_MEAN[2], 175 | ]) 176 | 177 | data_dict = {} 178 | if self._is_load: 179 | data_dict = np.load(self._pre_train_path, encoding='latin1').item() 180 | 181 | conv_outptu = self._create_conv(input_bgr, data_dict) 182 | 183 | arg_scope = tf.contrib.framework.arg_scope 184 | with arg_scope([conv], trainable=True, data_dict=data_dict): 185 | 
186 | fc6 = conv(conv_outptu, 7, 4096, 'fc6', nl=tf.nn.relu, padding='VALID') 187 | dropout_fc6 = dropout(fc6, keep_prob, self.is_training) 188 | 189 | fc7 = conv(dropout_fc6, 1, 4096, 'fc7', nl=tf.nn.relu, padding='VALID') 190 | dropout_fc7 = dropout(fc7, keep_prob, self.is_training) 191 | 192 | fc8 = conv(dropout_fc7, 1, self.num_class, 'fc8', padding='VALID') 193 | 194 | # self.conv_output = tf.identity(conv5_4, 'conv_output') 195 | self.output = tf.identity(fc8, 'model_output') 196 | filter_size = [tf.shape(fc8)[1], tf.shape(fc8)[2]] 197 | 198 | self.avg_output = global_avg_pool(fc8) 199 | 200 | @staticmethod 201 | def load_pre_trained(session, model_path, skip_layer=[]): 202 | fc_layers = ['fc6', 'fc7', 'fc8'] 203 | weights_dict = np.load(model_path, encoding='latin1').item() 204 | for layer_name in weights_dict: 205 | print('Loading ' + layer_name) 206 | if layer_name not in skip_layer: 207 | with tf.variable_scope(layer_name, reuse=True): 208 | for data in weights_dict[layer_name]: 209 | if len(data.shape) == 1: 210 | var = tf.get_variable('biases', trainable=False) 211 | session.run(var.assign(data)) 212 | else: 213 | var = tf.get_variable('weights', trainable=False) 214 | if layer_name == 'fc6': 215 | data = tf.reshape(data, [7,7,512,4096]) 216 | elif layer_name == 'fc7': 217 | data = tf.reshape(data, [1,1,4096,4096]) 218 | elif layer_name == 'fc8': 219 | data = tf.reshape(data, [1,1,4096,1000]) 220 | session.run(var.assign(data)) 221 | 222 | if __name__ == '__main__': 223 | VGG = VGG19(num_class=1000, 224 | num_channels=3, 225 | im_height=224, 226 | im_width=224) 227 | 228 | VGG.create_graph() 229 | 230 | writer = tf.summary.FileWriter(config.summary_dir) 231 | with tf.Session() as sess: 232 | sess.run(tf.global_variables_initializer()) 233 | writer.add_graph(sess.graph) 234 | 235 | writer.close() 236 | 237 | 238 | 239 | -------------------------------------------------------------------------------- /tensorcv/callbacks/__init__.py: 
-------------------------------------------------------------------------------- 1 | # File: __init__.py 2 | # Author: Qian Ge 3 | 4 | from .saver import * 5 | from .summary import * 6 | from .inference import * 7 | from .inferencer import * 8 | from .monitors import * 9 | from .debug import * -------------------------------------------------------------------------------- /tensorcv/callbacks/base.py: -------------------------------------------------------------------------------- 1 | import scipy.misc 2 | import os 3 | from abc import ABCMeta 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | __all__ = ['Callback', 'ProxyCallback'] 9 | 10 | def assert_type(v, tp): 11 | assert isinstance(v, tp),\ 12 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!" 13 | 14 | class Callback(object): 15 | """ base class for callbacks """ 16 | 17 | def setup_graph(self, trainer): 18 | self.trainer = trainer 19 | self._setup_graph() 20 | 21 | @property 22 | def global_step(self): 23 | return self.trainer.get_global_step 24 | 25 | @property 26 | def epochs_completed(self): 27 | return self.trainer.epochs_completed 28 | 29 | def _setup_graph(self): 30 | pass 31 | 32 | def before_run(self, rct): 33 | fetch = self._before_run(rct) 34 | if fetch is None: 35 | return None 36 | assert_type(fetch, tf.train.SessionRunArgs) 37 | return fetch 38 | 39 | def _before_run(self, rct): 40 | return None 41 | 42 | def after_run(self, rct, val): 43 | self._after_run(rct, val) 44 | 45 | def _after_run(self, rct, val): 46 | pass 47 | 48 | def before_train(self): 49 | self._before_train() 50 | 51 | def _before_train(self): 52 | pass 53 | 54 | def before_inference(self): 55 | self._before_inference() 56 | 57 | def _before_inference(self): 58 | pass 59 | 60 | def after_train(self): 61 | self._after_train() 62 | 63 | def _after_train(self): 64 | pass 65 | 66 | def before_epoch(self): 67 | self._before_epoch() 68 | 69 | def _before_epoch(self): 70 | pass 71 | 72 | def 
after_epoch(self): 73 | self._after_epoch() 74 | 75 | def _after_epoch(self): 76 | pass 77 | 78 | def trigger_epoch(self): 79 | self._trigger_epoch() 80 | 81 | def _trigger_epoch(self): 82 | self.trigger() 83 | 84 | def trigger_step(self): 85 | self._trigger_step() 86 | 87 | def _trigger_step(self): 88 | pass 89 | 90 | def trigger(self): 91 | self._trigger() 92 | 93 | def _trigger(self): 94 | pass 95 | 96 | # def before_run(self): 97 | 98 | class ProxyCallback(Callback): 99 | def __init__(self, cb): 100 | assert_type(cb, Callback) 101 | self.cb = cb 102 | 103 | def __str__(self): 104 | return "Proxy-" + str(self.cb) 105 | 106 | def _before_train(self): 107 | self.cb.before_train() 108 | 109 | def _before_inference(self): 110 | self.cb.before_inference() 111 | 112 | def _setup_graph(self): 113 | with tf.name_scope(None): 114 | self.cb.setup_graph(self.trainer) 115 | 116 | def _trigger_epoch(self): 117 | self.cb.trigger_epoch() 118 | 119 | def _trigger(self): 120 | self.cb.trigger() 121 | 122 | def _trigger_step(self): 123 | self.cb.trigger_step() 124 | 125 | def _after_train(self): 126 | self.cb.after_train() 127 | 128 | def _before_epoch(self): 129 | self.cb.before_epoch() 130 | 131 | def _after_epoch(self): 132 | self.cb.after_epoch() 133 | 134 | def _before_run(self, crt): 135 | self.cb.before_run(crt) 136 | 137 | def _after_run(self, crt, val): 138 | self.cb.after_run(crt, val) 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /tensorcv/callbacks/debug.py: -------------------------------------------------------------------------------- 1 | # File: inference.py 2 | # Author: Qian Ge 3 | 4 | import tensorflow as tf 5 | 6 | from .base import Callback 7 | from ..utils.common import get_tensors_by_names 8 | 9 | __all__ = ['CheckScalar'] 10 | 11 | def assert_type(v, tp): 12 | assert isinstance(v, tp), \ 13 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!" 
14 | 15 | class CheckScalar(Callback): 16 | """ Print the values of scalar tensors during training. 17 | Attributes: 18 | _tensors: tensor names at construction; replaced by the result of get_tensors_by_names in _setup_graph 19 | _names: the tensor names passed by the caller, used as labels when printing 20 | """ 21 | def __init__(self, tensors, periodic=1): 22 | """ init CheckScalar object 23 | Args: 24 | tensors (str or list[str]): a tensor name or list of tensor names; periodic (int): fetch and print the tensors every `periodic` global steps 25 | """ 26 | if not isinstance(tensors, list): 27 | tensors = [tensors] 28 | self._tensors = tensors 29 | self._names = tensors 30 | 31 | self._periodic = periodic 32 | 33 | def _setup_graph(self): 34 | self._tensors = get_tensors_by_names(self._tensors) 35 | 36 | def _before_run(self, _): 37 | if self.global_step % self._periodic == 0: # request the tensors only on every `periodic`-th step 38 | return tf.train.SessionRunArgs(fetches = self._tensors) 39 | else: 40 | return None 41 | 42 | def _after_run(self, _, val): 43 | if val.results is not None: # presumably None on steps where _before_run requested nothing 44 | print([name + ': ' + str(v) 45 | for name, v in zip(self._names, val.results)]) 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /tensorcv/callbacks/group.py: -------------------------------------------------------------------------------- 1 | import scipy.misc 2 | import os 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from .base import Callback 8 | from .hooks import Callback2Hook 9 | 10 | __all__ = ['Callbacks'] 11 | 12 | def assert_type(v, tp): 13 | assert isinstance(v, tp),\ 14 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"
15 | 16 | class Callbacks(Callback): 17 | """ Group several callbacks; each lifecycle event below is forwarded to every member, in list order. """ 18 | def __init__(self, cbs): 19 | for cb in cbs: 20 | assert_type(cb, Callback) 21 | self.cbs = cbs 22 | 23 | # before_run/after_run are not forwarded here; members are expected to reach the session through the hooks returned by get_hooks() 24 | def _setup_graph(self): 25 | with tf.name_scope(None): # build the members' graph parts from the top-level name scope 26 | for cb in self.cbs: 27 | cb.setup_graph(self.trainer) 28 | 29 | def get_hooks(self): # one Callback2Hook per member, e.g. for a tf.train.MonitoredSession 30 | return [Callback2Hook(cb) for cb in self.cbs] 31 | 32 | def _before_train(self): 33 | for cb in self.cbs: 34 | cb.before_train() 35 | 36 | def _before_inference(self): 37 | for cb in self.cbs: 38 | cb.before_inference() 39 | 40 | def _after_train(self): 41 | for cb in self.cbs: 42 | cb.after_train() 43 | 44 | 45 | def _before_epoch(self): 46 | for cb in self.cbs: 47 | cb.before_epoch() 48 | 49 | 50 | def _after_epoch(self): 51 | for cb in self.cbs: 52 | cb.after_epoch() 53 | 54 | def _trigger_epoch(self): 55 | for cb in self.cbs: 56 | cb.trigger_epoch() 57 | 58 | def _trigger_step(self): 59 | for cb in self.cbs: 60 | cb.trigger_step() 61 | 62 | # def trigger(self): 63 | # self._trigger() 64 | 65 | # def _trigger(self): 66 | # pass 67 | 68 | # def before_run(self): -------------------------------------------------------------------------------- /tensorcv/callbacks/hooks.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .base import Callback 4 | from .inferencer import InferencerBase 5 | from ..predicts.predictions import PredictionBase 6 | 7 | __all__ = ['Callback2Hook', 'Infer2Hook', 'Prediction2Hook'] 8 | 9 | def assert_type(v, tp): 10 | assert isinstance(v, tp), \ 11 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"
12 | 13 | class Callback2Hook(tf.train.SessionRunHook): 14 | """ """ 15 | def __init__(self, cb): 16 | self.cb = cb 17 | 18 | def before_run(self, rct): 19 | return self.cb.before_run(rct) 20 | 21 | def after_run(self, rct, val): 22 | self.cb.after_run(rct, val) 23 | 24 | class Infer2Hook(tf.train.SessionRunHook): 25 | 26 | def __init__(self, inferencer): 27 | # to be modified 28 | assert_type(inferencer, InferencerBase) 29 | self.inferencer = inferencer 30 | 31 | def before_run(self, rct): 32 | return tf.train.SessionRunArgs(fetches=self.inferencer.put_fetch()) 33 | 34 | def after_run(self, rct, val): 35 | self.inferencer.get_fetch(val) 36 | 37 | class Prediction2Hook(tf.train.SessionRunHook): 38 | def __init__(self, prediction): 39 | assert_type(prediction, PredictionBase) 40 | self.prediction = prediction 41 | 42 | def before_run(self, rct): 43 | 44 | return tf.train.SessionRunArgs(fetches=self.prediction.get_predictions()) 45 | 46 | def after_run(self, rct, val): 47 | self.prediction.after_prediction(val.results) 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /tensorcv/callbacks/inference.py: -------------------------------------------------------------------------------- 1 | # File: inference.py 2 | # Author: Qian Ge 3 | 4 | import scipy.misc 5 | import os 6 | from abc import ABCMeta 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | from .base import Callback 12 | from .group import Callbacks 13 | from .inputs import FeedInput 14 | from ..dataflow.base import DataFlow 15 | from ..dataflow.randoms import RandomVec 16 | from .hooks import Callback2Hook, Infer2Hook 17 | from ..utils.sesscreate import ReuseSessionCreator 18 | from .inferencer import InferencerBase 19 | 20 | __all__ = ['FeedInference', 'GANInference', 'FeedInferenceBatch'] 21 | 22 | def assert_type(v, tp): 23 | assert isinstance(v, tp), \ 24 | "Expect " + str(tp) + ", but " + 
str(v.__class__) + " is given!" 25 | 26 | class InferenceBase(Callback): 27 | """ base class for Inference """ 28 | def __init__(self, inputs=None, periodic=1, 29 | inferencers=None, extra_cbs=None, 30 | infer_batch_size=None): 31 | """ 32 | Args: 33 | extra_cbs (list[Callback]) 34 | """ 35 | self._inputs = inputs 36 | self._periodic = periodic 37 | self._infer_batch_size = infer_batch_size 38 | 39 | assert inferencers is not None or extra_cbs is not None,\ 40 | "Inferencers and extra_cbs cannot be both None!" 41 | 42 | if not isinstance(inferencers, list): 43 | inferencers = [inferencers] 44 | for infer in inferencers: 45 | assert_type(infer, InferencerBase) 46 | self._inference_list = inferencers 47 | 48 | if extra_cbs is None: 49 | self._extra_cbs = [] 50 | elif not isinstance(extra_cbs, list): 51 | self._extra_cbs = [extra_cbs] 52 | else: 53 | self._extra_cbs = extra_cbs 54 | 55 | self._cbs = [] 56 | 57 | def _setup_graph(self): 58 | self.model = self.trainer.model 59 | self.setup_inference() 60 | self.register_cbs() 61 | self._cbs = Callbacks(self._cbs) 62 | self._cbs.setup_graph(self.trainer) 63 | 64 | for infer in self._inference_list: 65 | infer.setup_inferencer() 66 | 67 | def setup_inference(self): 68 | self._setup_inference() 69 | 70 | for infer in self._inference_list: 71 | assert_type(infer, InferencerBase) 72 | infer.setup_graph(self.trainer) 73 | 74 | if self._infer_batch_size is None: 75 | self._inputs.set_batch_size(self.trainer.config.batch_size) 76 | else: 77 | self._inputs.set_batch_size(self._infer_batch_size) 78 | 79 | def _setup_inference(self): 80 | """ setup extra default callbacks for inference """ 81 | pass 82 | 83 | def register_cbs(self): 84 | for cb in self._extra_cbs: 85 | assert_type(cb, Callback) 86 | self._cbs.append(cb) 87 | 88 | def get_infer_hooks(self): 89 | return (self._cbs.get_hooks() 90 | + [Infer2Hook(infer) for infer in self._inference_list]) 91 | 92 | def _create_infer_sess(self): 93 | self.sess = self.trainer.sess 94 | 
infer_hooks = self.get_infer_hooks() 95 | self.hooked_sess = tf.train.MonitoredSession( 96 | session_creator = ReuseSessionCreator(self.sess), 97 | hooks = infer_hooks) 98 | 99 | def _trigger_step(self): 100 | if self.global_step % self._periodic == 0: 101 | for infer in self._inference_list: 102 | infer.before_inference() 103 | 104 | self._create_infer_sess() 105 | self.inference_step() 106 | 107 | for infer in self._inference_list: 108 | infer.after_inference() 109 | 110 | def inference_step(self): 111 | # TODO to be modified 112 | self.model.set_is_training(False) 113 | self._cbs.before_inference() 114 | self._inference_step() 115 | 116 | def _inference_step(self, extra_feed): 117 | self.hooked_sess.run(fetches = [], feed_dict = extra_feed) 118 | 119 | def _after_train(self): 120 | self._cbs.after_train() 121 | 122 | 123 | class FeedInference(InferenceBase): 124 | """ 125 | default inferencer: 126 | inference_list = InferImages('generator/gen_image', prefix = 'gen') 127 | """ 128 | def __init__(self, inputs, periodic=1, 129 | inferencers=[], extra_cbs=None, 130 | infer_batch_size=None): 131 | assert_type(inputs, DataFlow) 132 | 133 | # inferencers.append(InferImages('default', prefix = 'gen')) 134 | super(FeedInference, self).__init__(inputs=inputs, 135 | periodic=periodic, 136 | inferencers=inferencers, 137 | extra_cbs=extra_cbs, 138 | infer_batch_size=infer_batch_size) 139 | 140 | def _setup_inference(self): 141 | placeholders = self.model.get_train_placeholder() 142 | self._extra_cbs.append(FeedInput(self._inputs, placeholders)) 143 | 144 | def _inference_step(self): 145 | model_feed = self.model.get_graph_feed() 146 | while self._inputs.epochs_completed <= 0: 147 | self.hooked_sess.run(fetches = [], feed_dict = model_feed) 148 | self._inputs.reset_epochs_completed(0) 149 | 150 | class FeedInferenceBatch(FeedInference): 151 | """ do not use all validation data """ 152 | def __init__(self, inputs, periodic=1, 153 | batch_count=10, 154 | inferencers=[], 
extra_cbs=None, 155 | infer_batch_size=None): 156 | self._batch_count = batch_count 157 | super(FeedInferenceBatch, self).__init__(inputs=inputs, 158 | periodic=periodic, 159 | inferencers=inferencers, 160 | extra_cbs=extra_cbs, 161 | infer_batch_size=infer_batch_size) 162 | def _inference_step(self): 163 | model_feed = self.model.get_graph_feed() 164 | for i in range(self._batch_count): 165 | self.hooked_sess.run(fetches=[], feed_dict=model_feed) 166 | 167 | 168 | class GANInference(InferenceBase): 169 | def __init__(self, inputs=None, periodic=1, 170 | inferencers=None, extra_cbs=None): 171 | if inputs is not None: 172 | assert_type(inputs, RandomVec) 173 | super(GANInference, self).__init__(inputs=inputs, 174 | periodic=periodic, 175 | inferencers=inferencers, 176 | extra_cbs=extra_cbs) 177 | 178 | def _setup_inference(self): 179 | if self._inputs is not None: 180 | self._inputs.set_batch_size(self.trainer.config.batch_size) 181 | rand_vec_phs = self.model.get_random_vec_placeholder() 182 | self._extra_cbs.append(FeedInput(self._inputs, rand_vec_phs)) 183 | 184 | def _inference_step(self): 185 | if self._inputs is None: 186 | model_feed = self.model.get_graph_feed() 187 | else: 188 | model_feed = {} 189 | # while self._inputs.epochs_completed <= 0: 190 | self.hooked_sess.run(fetches=[], feed_dict=model_feed) 191 | # self._inputs.reset_epochs_completed(0) 192 | 193 | -------------------------------------------------------------------------------- /tensorcv/callbacks/inferencer.py: -------------------------------------------------------------------------------- 1 | # File: inference.py 2 | # Author: Qian Ge 3 | 4 | import os 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from .base import Callback 10 | from ..utils.common import get_tensors_by_names, check_dir, match_tensor_save_name 11 | from ..utils.viz import * 12 | 13 | __all__ = ['InferencerBase', 'InferImages', 'InferScalars', 'InferOverlay', 'InferMat'] 14 | 15 | class 
InferencerBase(Callback): 16 | 17 | def setup_inferencer(self): 18 | if not isinstance(self._names, list): 19 | self._names = [self._names] 20 | self._names = get_tensors_by_names(self._names) 21 | self._setup_inference(self.trainer.default_dirs) 22 | 23 | def _setup_inference(self, default_dirs=None): 24 | pass 25 | 26 | def put_fetch(self): 27 | return self._put_fetch() 28 | 29 | def _put_fetch(self): 30 | return self._names 31 | # pass 32 | 33 | def get_fetch(self, val): 34 | self._get_fetch(val) 35 | 36 | def _get_fetch(self, val): 37 | pass 38 | 39 | def before_inference(self): 40 | """ process before every inference """ 41 | self._before_inference() 42 | 43 | def _before_inference(self): 44 | pass 45 | 46 | def after_inference(self): 47 | self._after_inference() 48 | 49 | # if re is not None: 50 | # for key, val in re.items(): 51 | # s = tf.Summary() 52 | # s.value.add(tag = key, simple_value = val) 53 | # self.trainer.monitors.process_summary(s) 54 | 55 | def _after_inference(self): 56 | return None 57 | 58 | 59 | 60 | class InferImages(InferencerBase): 61 | def __init__(self, im_name, prefix=None, color=False, tanh=False): 62 | self._names, self._prefix = match_tensor_save_name(im_name, prefix) 63 | self._color = color 64 | self._tanh = tanh 65 | 66 | def _setup_inference(self, default_dirs=None): 67 | try: 68 | self._save_dir = os.path.join(self.trainer.default_dirs.infer_dir) 69 | check_dir(self._save_dir) 70 | except AttributeError: 71 | raise AttributeError('summary_dir is not set in infer_dir.py!') 72 | 73 | def _before_inference(self): 74 | self._result_list = [] 75 | 76 | def _get_fetch(self, val): 77 | self._result_list.append(val.results) 78 | # self._result_im = val.results 79 | 80 | def _after_inference(self): 81 | # TODO add process_image to monitors 82 | # batch_size = len(self._result_im[0]) 83 | batch_size = len(self._result_list[0][0]) 84 | grid_size = self._get_grid_size(batch_size) 85 | # grid_size = [8, 8] if batch_size == 64 else [6, 6] 
86 | local_step = 0 87 | for result_im in self._result_list: 88 | for im, save_name in zip(result_im, self._prefix): 89 | save_merge_images(im, [grid_size, grid_size], 90 | self._save_dir + save_name + '_step_' + str(self.global_step) +\ 91 | '_b_' + str(local_step) + '.png', 92 | color = self._color, 93 | tanh = self._tanh) 94 | local_step += 1 95 | return None 96 | 97 | def _get_grid_size(self, batch_size): 98 | try: 99 | return self._grid_size 100 | except AttributeError: 101 | self._grid_size = np.ceil(batch_size**0.5).astype(int) 102 | return self._grid_size 103 | 104 | class InferOverlay(InferImages): 105 | def __init__(self, im_name, prefix=None, color=False, tanh=False): 106 | if not isinstance(im_name, list): 107 | im_name = [im_name] 108 | assert len(im_name) == 2,\ 109 | '[InferOverlay] requires two image tensors but the input len = {}.'.\ 110 | format(len(im_name)) 111 | super(InferOverlay, self).__init__(im_name=im_name, 112 | prefix=prefix, 113 | color=color, 114 | tanh=tanh) 115 | self._overlay_prefix = '{}_{}'.format(self._prefix[0], self._prefix[1]) 116 | 117 | def _after_inference(self): 118 | # TODO add process_image to monitors 119 | # batch_size = len(self._result_im[0]) 120 | batch_size = len(self._result_list[0][0]) 121 | grid_size = self._get_grid_size(batch_size) 122 | # grid_size = [8, 8] if batch_size == 64 else [6, 6] 123 | local_step = 0 124 | for result_im in self._result_list: 125 | overlay_im_list = [] 126 | for im_1, im_2 in zip(result_im[0], result_im[1]): 127 | overlay_im = image_overlay(im_1, im_2, color = self._color) 128 | overlay_im_list.append(overlay_im) 129 | save_merge_images(np.squeeze(overlay_im_list), [grid_size, grid_size], 130 | self._save_dir + self._overlay_prefix + '_step_' +str(self.global_step) +\ 131 | '_b_' + str(local_step) + '.png', 132 | color = False, tanh = self._tanh) 133 | local_step += 1 134 | return None 135 | 136 | class InferMat(InferImages): 137 | def __init__(self, infer_save_name, mat_name, 
prefix=None): 138 | self._infer_save_name = str(infer_save_name) 139 | super(InferMat, self).__init__(im_name = mat_name, prefix=prefix, 140 | color=False, tanh=False) 141 | def _after_inference(self): 142 | for idx, batch_result in enumerate(self._result_list): 143 | save_path = os.path.join(self._save_dir, 144 | '{}_step_{}_b_{}.mat'.format(self._infer_save_name, self.global_step, idx)) 145 | # self._infer_save_name + '_b_' + str(idx) + str(self.global_step) + '.mat') 146 | scipy.io.savemat(save_path, {name: np.squeeze(val) for name, val 147 | in zip(self._prefix, batch_result)}) 148 | return None 149 | 150 | class InferScalars(InferencerBase): 151 | def __init__(self, scaler_names, summary_names=None): 152 | if not isinstance(scaler_names, list): 153 | scaler_names = [scaler_names] 154 | self._names = scaler_names 155 | if summary_names is None: 156 | self._summary_names = scaler_names 157 | else: 158 | if not isinstance(summary_names, list): 159 | summary_names = [summary_names] 160 | assert len(self._names) == len(summary_names), \ 161 | "length of scaler_names and summary_names has to be the same!" 
162 | self._summary_names = summary_names 163 | 164 | def _before_inference(self): 165 | self.result_list = [[] for i in range(0, len(self._names))] 166 | 167 | def _get_fetch(self, val): 168 | for i,v in enumerate(val.results): 169 | self.result_list[i] += v, 170 | 171 | def _after_inference(self): 172 | """ process after get_fetch """ 173 | summary_dict = {name: np.mean(val) for name, val 174 | in zip(self._summary_names, self.result_list)} 175 | if summary_dict is not None: 176 | for key, val in summary_dict.items(): 177 | s = tf.Summary() 178 | s.value.add(tag=key, simple_value=val) 179 | self.trainer.monitors.process_summary(s) 180 | print('[infer] '+ key + ': ' + str(val)) 181 | # return {name: np.mean(val) for name, val 182 | # in zip(self._summary_names, self.result_list)} 183 | 184 | # TODO to be modified 185 | # class BinaryClassificationStats(InferencerBase): 186 | # def __init__(self, accuracy): 187 | # self._names = accuracy 188 | 189 | # def _before_inference(self): 190 | # self.result_list = [] 191 | 192 | # def _put_fetch(self): 193 | # # fetch_list = self.names 194 | # return self._names 195 | 196 | # def _get_fetch(self, val): 197 | # self.result_list += val.results, 198 | 199 | # def _after_inference(self): 200 | # """ process after get_fetch """ 201 | # return {"test_accuracy": np.mean(self.result_list)} 202 | 203 | if __name__ == '__main__': 204 | t = InferGANGenerator('gen_name', 205 | save_dir = 'D:\\Qian\\GitHub\\workspace\\test\\result\\', prefix = 1) 206 | print(t._prefix) 207 | 208 | 209 | 210 | -------------------------------------------------------------------------------- /tensorcv/callbacks/inputs.py: -------------------------------------------------------------------------------- 1 | # File: input.py 2 | # Author: Qian Ge 3 | 4 | import scipy.misc 5 | import os 6 | 7 | import tensorflow as tf 8 | 9 | from .base import Callback 10 | from ..dataflow.base import DataFlow 11 | 12 | __all__ = ['FeedInput'] 13 | 14 | def assert_type(v, 
tp): 15 | assert isinstance(v, tp), \ 16 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!" 17 | 18 | class FeedInput(Callback): 19 | """ input using feed """ 20 | def __init__(self, dataflow, placeholders): 21 | assert_type(dataflow, DataFlow) 22 | self.dataflow = dataflow 23 | 24 | if not isinstance(placeholders, list): 25 | print(type(placeholders)) 26 | placeholders = [placeholders] 27 | self.placeholders = placeholders 28 | 29 | # def _setup_graph(self): 30 | # pass 31 | def _setup_graph(self): 32 | pass 33 | # self.dataflow._setup(num_epoch=self.trainer.config.max_epoch) 34 | 35 | def _before_train(self): 36 | self.dataflow.before_read_setup() 37 | 38 | def _before_inference(self): 39 | self._before_train() 40 | 41 | def _before_run(self, _): 42 | cur_batch = self.dataflow.next_batch() 43 | 44 | # assert len(cur_batch) == len(self.placeholders), \ 45 | # "[FeedInput] lenght of input {} is not equal to length of placeholders {}"\ 46 | # .format(len(cur_batch), len(self.placeholders)) 47 | 48 | feed = dict(zip(self.placeholders, cur_batch)) 49 | return tf.train.SessionRunArgs(fetches=[], feed_dict=feed) 50 | 51 | def _after_train(self): 52 | self.dataflow.after_reading() 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /tensorcv/callbacks/monitors.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | 5 | from .base import Callback 6 | from ..utils.common import check_dir 7 | 8 | __all__ = ['TrainingMonitor','Monitors','TFSummaryWriter'] 9 | 10 | def assert_type(v, tp): 11 | assert isinstance(v, tp), \ 12 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!" 
13 | 14 | class TrainingMonitor(Callback): 15 | def _setup_graph(self): 16 | pass 17 | 18 | def process_summary(self, summary): 19 | self._process_summary(summary) 20 | 21 | def _process_summary(self, summary): 22 | pass 23 | 24 | class Monitors(TrainingMonitor): 25 | """ group monitors """ 26 | def __init__(self, mons): 27 | for mon in mons: 28 | assert_type(mon, TrainingMonitor) 29 | self.mons = mons 30 | 31 | def _process_summary(self, summary): 32 | for mon in self.mons: 33 | mon.process_summary(summary) 34 | 35 | class TFSummaryWriter(TrainingMonitor): 36 | 37 | def _setup_graph(self): 38 | try: 39 | summary_dir = os.path.join(self.trainer.default_dirs.summary_dir) 40 | check_dir(summary_dir) 41 | except AttributeError: 42 | raise AttributeError('summary_dir is not set in config.py!') 43 | self._writer = tf.summary.FileWriter(summary_dir) 44 | 45 | def _before_train(self): 46 | # default to write graph 47 | self._writer.add_graph(self.trainer.sess.graph) 48 | 49 | def _after_train(self): 50 | self._writer.close() 51 | 52 | def process_summary(self, summary): 53 | self._writer.add_summary(summary, self.global_step) 54 | -------------------------------------------------------------------------------- /tensorcv/callbacks/saver.py: -------------------------------------------------------------------------------- 1 | import scipy.misc 2 | import os 3 | import os 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from .base import Callback 9 | from ..utils.common import check_dir 10 | 11 | __all__ = ['ModelSaver'] 12 | 13 | class ModelSaver(Callback): 14 | def __init__(self, max_to_keep=5, 15 | keep_checkpoint_every_n_hours=0.5, 16 | periodic=1, 17 | checkpoint_dir=None, 18 | var_collections=tf.GraphKeys.GLOBAL_VARIABLES): 19 | 20 | self._periodic = periodic 21 | 22 | self._max_to_keep = max_to_keep 23 | self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours 24 | 25 | if not isinstance(var_collections, list): 26 | var_collections = 
[var_collections] 27 | self.var_collections = var_collections 28 | 29 | def _setup_graph(self): 30 | try: 31 | checkpoint_dir = os.path.join(self.trainer.default_dirs.checkpoint_dir) 32 | check_dir(checkpoint_dir) 33 | except AttributeError: 34 | raise AttributeError('checkpoint_dir is not set in config_path!') 35 | 36 | self._save_path = os.path.join(checkpoint_dir, 'model') 37 | self._saver = tf.train.Saver() 38 | 39 | def _trigger_step(self): 40 | if self.global_step % self._periodic == 0: 41 | self._saver.save(tf.get_default_session(), self._save_path, 42 | global_step = self.global_step) 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /tensorcv/callbacks/summary.py: -------------------------------------------------------------------------------- 1 | import scipy.misc 2 | import os 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from .base import Callback 8 | 9 | __all__ = ['TrainSummary'] 10 | 11 | class TrainSummary(Callback): 12 | def __init__(self, 13 | key=None, 14 | periodic=1): 15 | 16 | self.periodic = periodic 17 | if not key is None and not isinstance(key, list): 18 | key = [key] 19 | self._key = key 20 | 21 | def _setup_graph(self): 22 | self.summary_list = tf.summary.merge( 23 | [tf.summary.merge_all(key) for key in self._key]) 24 | # self.all_summary = tf.summary.merge_all(self._key) 25 | 26 | def _before_run(self, _): 27 | if self.global_step % self.periodic == 0: 28 | return tf.train.SessionRunArgs(fetches = self.summary_list) 29 | else: 30 | None 31 | 32 | def _after_run(self, _, val): 33 | if val.results is not None: 34 | self.trainer.monitors.process_summary(val.results) 35 | 36 | 37 | -------------------------------------------------------------------------------- /tensorcv/callbacks/trigger.py: -------------------------------------------------------------------------------- 1 | import scipy.misc 2 | import os 3 | from abc import ABCMeta 4 | import os 5 | 6 | import numpy as 
np 7 | import tensorflow as tf 8 | 9 | from .base import ProxyCallback, Callback 10 | 11 | __all__ = ['PeriodicTrigger'] 12 | 13 | def assert_type(v, tp): 14 | assert isinstance(v, tp), \ 15 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!" 16 | 17 | class PeriodicTrigger(ProxyCallback): 18 | """ may not need """ 19 | def __init__(self, trigger_cb, every_k_steps=None, every_k_epochs=None): 20 | 21 | assert_type(trigger_cb, Callback) 22 | super(PeriodicTrigger, self).__init__(trigger_cb) 23 | 24 | assert (every_k_steps is not None) or (every_k_epochs is not None), \ 25 | "every_k_steps and every_k_epochs cannot be both None!" 26 | self._k_step = every_k_steps 27 | self._k_epoch = every_k_epochs 28 | 29 | def __str__(self): 30 | return 'PeriodicTrigger' + str(self.cb) 31 | 32 | def _trigger_step(self): 33 | if self._k_step is None: 34 | return 35 | if self.global_step % self._k_step == 0: 36 | self.cb.trigger() 37 | 38 | def _trigger_epoch(self): 39 | if self._k_epoch is None: 40 | return 41 | if self.epochs_completed % self._k_epoch == 0: 42 | self.cb.trigger() 43 | -------------------------------------------------------------------------------- /tensorcv/dataflow/__init__.py: -------------------------------------------------------------------------------- 1 | # File: __init__.py 2 | # Author: Qian Ge 3 | 4 | from .base import * 5 | from .image import * 6 | from .matlab import * 7 | from .randoms import * 8 | # from .dataset import * 9 | from .normalization import * 10 | -------------------------------------------------------------------------------- /tensorcv/dataflow/argument.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: argument.py 4 | # Author: Qian Ge 5 | 6 | from .base import DataFlow 7 | from ..utils.utils import assert_type 8 | 9 | 10 | class ArgumentDataflow(DataFlow): 11 | def __init__(self, dataflow, argument_order, argument_fnc): 12 | 
assert_type(dataflow, DataFlow) 13 | if not isinstance(argument_order, list): 14 | argument_order = [argument_order] 15 | if not isinstance(argument_fnc, list): 16 | argument_fnc = [argument_fnc] 17 | 18 | self._dataflow = dataflow 19 | 20 | self._order = argument_order 21 | self._fnc = argument_fnc 22 | 23 | def setup(self, epoch_val, batch_size, **kwargs): 24 | self._dataflow .setup(epoch_val, batch_size, **kwargs) 25 | 26 | @property 27 | def epochs_completed(self): 28 | return self._dataflow.epochs_completed 29 | 30 | def reset_epochs_completed(self, val): 31 | self._dataflow.reset_epochs_completed(val) 32 | 33 | def set_batch_size(self, batch_size): 34 | self._dataflow.set_batch_size(batch_size) 35 | 36 | def size(self): 37 | return self._dataflow.size() 38 | 39 | def reset_state(self): 40 | self._dataflow.reset_state() 41 | 42 | def after_reading(self): 43 | self._dataflow.after_reading() 44 | 45 | def next_batch_dict(self): 46 | batch_data_dict = self._dataflow.next_batch_dict() 47 | arg_batch_data_dict = {} 48 | for key, arg_fnc in zip(self._order, self._fnc): 49 | arg_batch_data_dict[key] = arg_fnc(batch_data_dict[key]) 50 | 51 | return arg_batch_data_dict 52 | 53 | def next_batch(self): 54 | print("***** [ArgumentDataflow.next_batch()] not implmented *****") 55 | -------------------------------------------------------------------------------- /tensorcv/dataflow/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABCMeta 2 | import numpy as np 3 | 4 | from ..utils.utils import get_rng 5 | 6 | __all__ = ['DataFlow', 'RNGDataFlow'] 7 | 8 | # @six.add_metaclass(ABCMeta) 9 | class DataFlow(object): 10 | """ base class for dataflow """ 11 | # self._epochs_completed = 0 12 | 13 | def before_read_setup(self, **kwargs): 14 | pass 15 | 16 | def setup(self, epoch_val, batch_size, **kwargs): 17 | self.reset_epochs_completed(epoch_val) 18 | self.set_batch_size(batch_size) 19 | self.reset_state() 
20 | self._setup() 21 | 22 | def _setup(self, **kwargs): 23 | pass 24 | 25 | # @property 26 | # def channels(self): 27 | # try: 28 | # return self._num_channels 29 | # except AttributeError: 30 | # self._num_channels = self._get_channels() 31 | # return self._num_channels 32 | 33 | # def _get_channels(self): 34 | # return 0 35 | 36 | # @property 37 | # def im_size(self): 38 | # try: 39 | # return self._im_size 40 | # except AttributeError: 41 | # self._im_size = self._get_im_size() 42 | # return self._im_size 43 | 44 | def _get_im_size(self): 45 | return 0 46 | 47 | @property 48 | def epochs_completed(self): 49 | return self._epochs_completed 50 | 51 | def reset_epochs_completed(self, val): 52 | self._epochs_completed = val 53 | 54 | @abstractmethod 55 | def next_batch(self): 56 | return 57 | 58 | def next_batch_dict(self): 59 | print('Need to be implemented!') 60 | 61 | def set_batch_size(self, batch_size): 62 | self._batch_size = batch_size 63 | 64 | def size(self): 65 | raise NotImplementedError() 66 | 67 | def reset_state(self): 68 | self._reset_state() 69 | 70 | def _reset_state(self): 71 | pass 72 | 73 | def after_reading(self): 74 | pass 75 | 76 | class RNGDataFlow(DataFlow): 77 | def _reset_state(self): 78 | self.rng = get_rng(self) 79 | 80 | def _suffle_file_list(self): 81 | idxs = np.arange(self.size()) 82 | self.rng.shuffle(idxs) 83 | self.file_list = self.file_list[idxs] 84 | 85 | def suffle_data(self): 86 | self._suffle_file_list() 87 | -------------------------------------------------------------------------------- /tensorcv/dataflow/bk/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import collections 3 | import scipy.io 4 | import scipy.misc 5 | import os 6 | import pickle 7 | import random 8 | 9 | import tensorflow as tf 10 | 11 | class CIFAT10(object): 12 | def __init__(self, file_path): 13 | self._file_list = [os.path.join(file_path, 'data_batch_' + str(batch_id)) for batch_id 
in range(1,6)] 14 | self._num_channels = 3 15 | self._batch_id = 0 16 | self._batch_file_id = -1 17 | self._image = [] 18 | self._epochs_completed = 0 19 | 20 | def next_batch_file(self): 21 | if self._batch_file_id >= len(self._file_list)-1: 22 | self._batch_file_id = 0 23 | self._epochs_completed += 1 24 | else: 25 | self._batch_file_id += 1 26 | self._image = unpickle(self._file_list[self._batch_file_id]) 27 | random.shuffle(self._image) 28 | # scipy.misc.imsave('test.png', self.image[100]) 29 | 30 | def next_batch(self, batch_size): 31 | batch_id_end = self._batch_id + batch_size 32 | if batch_id_end >= len(self._image): 33 | self.next_batch_file() 34 | self._batch_id = 0 35 | batch_id_end = batch_size 36 | batch_image = self._image[self._batch_id:batch_id_end] 37 | self._batch_id = batch_id_end 38 | return batch_image 39 | 40 | @property 41 | def epochs_completed(self): 42 | return self._epochs_completed 43 | 44 | 45 | class TestImage(object): 46 | def __init__(self, test_file_path, patch_size, sample_mean = 0, num_channels = 1): 47 | self.patch_size = patch_size 48 | self.num_channels = num_channels 49 | self._sample_mean = sample_mean 50 | 51 | self._file_list = ([os.path.join(test_file_path, file) 52 | for file in os.listdir(test_file_path) 53 | if file.endswith(".mat")]) 54 | 55 | self._image_id = -1 56 | self._image = None 57 | self._cnt_row = -1 58 | 59 | self.cur_patch = np.empty(shape=[1, self.patch_size, self.patch_size, self.num_channels]) 60 | 61 | def next_image(self): 62 | if self._image_id >= len(self._file_list) - 1: 63 | return None 64 | self._image_id += 1 65 | self._image = load_image(self._file_list[self._image_id], self.num_channels) 66 | 67 | self.im_rows, self.im_cols = self._image.shape[0], self._image.shape[1] 68 | self._cnt_row = 1 69 | 70 | half_patch_size = int(self.patch_size/2) 71 | self.row_id, self.col_id = half_patch_size - 1, half_patch_size - 2 72 | self.patch_row_start, self.patch_row_end = self.row_id - half_patch_size + 1, 
self.row_id + half_patch_size 73 | self.patch_col_start, self.patch_col_end = self.col_id - half_patch_size + 1, self.col_id + half_patch_size 74 | 75 | for cur_channel in range(0, self.num_channels): 76 | self._image[:,:,:,cur_channel] -= self._sample_mean[cur_channel] 77 | 78 | return self._image 79 | 80 | def next_patch(self): 81 | if self._image is None: 82 | return None 83 | half_patch_size = int(self.patch_size/2) 84 | if self.col_id >= self.im_cols - half_patch_size - 1: 85 | self.row_id += 1 86 | self.patch_row_start += 1 87 | self.patch_row_end += 1 88 | print('Row: ' + str(self.row_id)) 89 | self._cnt_row += 1 90 | self.col_id = half_patch_size - 1 91 | self.patch_col_start, self.patch_col_end = self.col_id - half_patch_size + 1, self.col_id + half_patch_size 92 | else: 93 | self.col_id += 1 94 | self.patch_col_start += 1 95 | self.patch_col_end += 1 96 | 97 | if self.row_id >= self.im_rows - half_patch_size: 98 | return None 99 | # print(self.patch_row_start,self.patch_row_end,self.patch_col_start,self.patch_col_end ) 100 | if self.num_channels > 1: 101 | cnt_depth = 0 102 | for channel_id in range(0, self.num_channels): 103 | self.cur_patch[:,:,:,cnt_depth] = self._image[self.patch_row_start:self.patch_row_end + 1, 104 | self.patch_col_start:self.patch_col_end + 1, channel_id].transpose() 105 | cnt_depth += 1 106 | else: 107 | self.cur_patch[:,:,:,:] = np.reshape(self._image[self.patch_row_start:self.patch_row_end + 1, 108 | self.patch_col_start:self.patch_col_end + 1].transpose(), 109 | [1, self.patch_size, self.patch_size, self.num_channels]) 110 | return self.cur_patch 111 | 112 | def reset_cnt_row(self): 113 | self._cnt_row = 1 114 | def get_cnt_row(self): 115 | return self._cnt_row 116 | 117 | class TrainData(object): 118 | def __init__(self, file_list, sample_mean = 0, num_channels = 1): 119 | self._num_channels = num_channels 120 | 121 | self._file_list = file_list 122 | self._num_image = len(self._file_list) 123 | 124 | self._image_id = 0 125 | 
self._image = [] 126 | self._label = [] 127 | self._mask = [] 128 | self._epochs_completed = 0 129 | 130 | self._sample_mean = sample_mean 131 | 132 | @property 133 | def epochs_completed(self): 134 | return self._epochs_completed 135 | 136 | def set_epochs_completed(self, value): 137 | self._epochs_completed = value 138 | 139 | @property 140 | def sample_mean(self): 141 | return self._sample_mean 142 | 143 | def next_image(self): 144 | if self._image_id >= self._num_image - 1: 145 | self._epochs_completed += 1 146 | self._image_id = 1 147 | perm = np.arange(self._num_image) 148 | np.random.shuffle(perm) 149 | self._file_list = self._file_list[perm] 150 | else: 151 | self._image_id += 1 152 | 153 | self._image, self._label, self._mask = load_training_image(self._file_list[self._image_id], num_channels = self._num_channels) 154 | for cur_channel in range(0, self._num_channels): 155 | self._image[:,:,:,cur_channel] -= self._sample_mean[cur_channel] 156 | return self._image, self._label, self._mask 157 | 158 | def next_batch(self, batch_size): 159 | if batch_size > self._num_image: 160 | return None 161 | start = self._image_id 162 | self._image_id += batch_size 163 | if self._image_id >= self._num_image: 164 | self._epochs_completed += 1 165 | perm = np.arange(self._num_image) 166 | np.random.shuffle(perm) 167 | self._file_list = self._file_list[perm] 168 | start = 0 169 | self._image_id = batch_size 170 | end = self._image_id 171 | batch_file_path = self._file_list[start:end] 172 | return load_batch_image(batch_file_path, num_channels = self._num_channels) 173 | 174 | 175 | 176 | def prepare_data_set(file_path, valid_percentage, num_channels = 1, isSubstractMean = True): 177 | file_list = np.array([os.path.join(file_path, file) 178 | for file in os.listdir(file_path) 179 | if file.endswith(".mat")]) 180 | if isSubstractMean: 181 | sample_mean_value = average_train_data(file_list, num_channels) 182 | else: 183 | sample_mean_value = np.empty(shape=[num_channels], 
dtype = float) 184 | sample_mean = tf.Variable(sample_mean_value, name = 'sample_mean') 185 | num_image = len(file_list) 186 | num_valid = int(num_image*valid_percentage) 187 | num_train = num_image - num_valid 188 | train = TrainData(file_list[:num_train], sample_mean = sample_mean_value, num_channels = num_channels) 189 | validate = TrainData(file_list[num_train:], sample_mean = sample_mean_value, num_channels = num_channels) 190 | print('Number of training image: {}. Number of validation image: {}.'.format(num_train, num_valid)) 191 | # print('Mean of training samples: {}'.format(sampel_mean)) 192 | ds = collections.namedtuple('TrainData', ['train', 'validate']) 193 | return ds(train = train, validate = validate) 194 | 195 | def load_batch_image(batch_file_path, num_channels = 1): 196 | image_list = [] 197 | # image_list = np.empty(shape = [0, im_size, im_size, num_channels]) 198 | for file_path in batch_file_path: 199 | image ,_ ,_ = load_training_image(file_path, num_channels = num_channels) 200 | image_list.extend(image) 201 | # image_list = np.append(image_list, image, axis=0) 202 | return np.array(image_list) 203 | 204 | 205 | def load_training_image(file_path, num_channels = 1): 206 | # print('Loading training file ' + file_path) 207 | mat = scipy.io.loadmat(file_path) 208 | image = mat['level1Edge'].astype('float') 209 | label = mat['GT'] 210 | mask = mat['Mask'] 211 | image = np.reshape(image, [1, image.shape[0], image.shape[1], num_channels]) 212 | label = np.reshape(label, [1, label.shape[0], label.shape[1]]) 213 | mask = np.reshape(mask, [1, mask.shape[0], mask.shape[1]]) 214 | # print('Load successfully.') 215 | return image, label, mask 216 | 217 | def load_image(test_file_path, num_channels): 218 | print('Loading test file ' + test_file_path + '...') 219 | mat = scipy.io.loadmat(test_file_path) 220 | image = mat['level1Edge'].astype('float') 221 | print('Load successfully.') 222 | return np.array(np.reshape(image, [1, image.shape[0], 
image.shape[1], num_channels])) 223 | 224 | def average_train_data(file_list, num_channels): 225 | mean_list = np.empty(shape=[num_channels], dtype = float) 226 | for cur_file_path in file_list: 227 | image, label, mask = load_training_image(cur_file_path, num_channels = num_channels) 228 | for cur_channel in range(0, num_channels): 229 | mean_list[cur_channel] += np.ma.masked_array(image[:,:,:, cur_channel], mask = mask).mean() 230 | return mean_list/len(file_list) 231 | 232 | 233 | def unpickle(file): 234 | with open(file, 'rb') as fo: 235 | dict = pickle.load(fo, encoding='bytes') 236 | image = dict[b'data'] 237 | 238 | r = image[:,:32*32].reshape(-1,32,32) 239 | g = image[:,32*32: 2*32*32].reshape(-1,32,32) 240 | b = image[:,2*32*32:].reshape(-1,32,32) 241 | 242 | image = np.stack((r,g,b),axis=-1) 243 | return image 244 | 245 | -------------------------------------------------------------------------------- /tensorcv/dataflow/common.py: -------------------------------------------------------------------------------- 1 | # File: common.py 2 | # Author: Qian Ge 3 | 4 | import os 5 | 6 | from scipy import misc 7 | import numpy as np 8 | 9 | from .preprocess import resize_image_with_smallest_side, random_crop_to_size 10 | from .normalization import identity 11 | 12 | 13 | def get_file_list(file_dir, file_ext, sub_name=None): 14 | # assert file_ext in ['.mat', '.png', '.jpg', '.jpeg'] 15 | re_list = [] 16 | 17 | if sub_name is None: 18 | return np.array([os.path.join(root, name) 19 | for root, dirs, files in os.walk(file_dir) 20 | for name in sorted(files) if name.lower().endswith(file_ext)]) 21 | else: 22 | return np.array([os.path.join(root, name) 23 | for root, dirs, files in os.walk(file_dir) 24 | for name in sorted(files) if name.lower().endswith(file_ext) and sub_name.lower() in name.lower()]) 25 | # for root, dirs, files in os.walk(file_dir): 26 | # for name in files: 27 | # if name.lower().endswith(file_ext): 28 | # re_list.append(os.path.join(root, name)) 
29 | # return np.array(re_list) 30 | 31 | def get_folder_list(folder_dir): 32 | return np.array([os.path.join(folder_dir, folder) 33 | for folder in os.listdir(folder_dir) 34 | if os.path.join(folder_dir, folder)]) 35 | 36 | def get_folder_names(folder_dir): 37 | return np.array([name for name in os.listdir(folder_dir) 38 | if os.path.join(folder_dir, name)]) 39 | 40 | def input_val_range(in_mat): 41 | # TODO to be modified 42 | max_val = np.amax(in_mat) 43 | min_val = np.amin(in_mat) 44 | if max_val > 1: 45 | max_in_val = 255.0 46 | half_in_val = 128.0 47 | elif min_val >= 0: 48 | max_in_val = 1.0 49 | half_in_val = 0.5 50 | else: 51 | max_in_val = 1.0 52 | half_in_val = 0 53 | return max_in_val, half_in_val 54 | 55 | def tanh_normalization(data, half_in_val): 56 | return (data*1.0 - half_in_val)/half_in_val 57 | 58 | 59 | def dense_to_one_hot(labels_dense, num_classes): 60 | """Convert class labels from scalars to one-hot vectors.""" 61 | num_labels = labels_dense.shape[0] 62 | index_offset = np.arange(num_labels) * num_classes 63 | labels_one_hot = np.zeros((num_labels, num_classes)) 64 | labels_one_hot.flat[[index_offset + labels_dense.ravel()]] = 1 65 | return labels_one_hot 66 | 67 | def reverse_label_dict(label_dict): 68 | label_dict_reverse = {} 69 | for key, value in label_dict.items(): 70 | label_dict_reverse[value] = key 71 | return label_dict_reverse 72 | 73 | def load_image(im_path, read_channel=None, pf=identity, resize=None, resize_crop=None): 74 | if resize is not None: 75 | print_warning('[load_image] resize will be unused in the future!\ 76 | Use pf (preprocess_fnc) instead.') 77 | if resize_crop is not None: 78 | print_warning('[load_image] resize_crop will be unused in the future!\ 79 | Use pf (preprocess_fnc) instead.') 80 | 81 | # im = cv2.imread(im_path, self._cv_read) 82 | if read_channel is None: 83 | im = misc.imread(im_path) 84 | elif read_channel == 3: 85 | im = misc.imread(im_path, mode='RGB') 86 | else: 87 | im = misc.imread(im_path, 
flatten=True) 88 | 89 | if len(im.shape) < 3: 90 | try: 91 | im = misc.imresize(im, (resize[0], resize[1], 1)) 92 | except TypeError: 93 | pass 94 | if resize_crop is not None: 95 | im = resize_image_with_smallest_side(im, resize_crop) 96 | im = random_crop_to_size(im, resize_crop) 97 | im = pf(im) 98 | im = np.reshape(im, [1, im.shape[0], im.shape[1], 1]) 99 | else: 100 | try: 101 | im = misc.imresize(im, (resize[0], resize[1], im.shape[2])) 102 | except TypeError: 103 | pass 104 | if resize_crop is not None: 105 | im = resize_image_with_smallest_side(im, resize_crop) 106 | im = random_crop_to_size(im, resize_crop) 107 | im = pf(im) 108 | im = np.reshape(im, [1, im.shape[0], im.shape[1], im.shape[2]]) 109 | return im 110 | 111 | 112 | def print_warning(warning_str): 113 | print('[**** warning ****] {}'.format(warning_str)) 114 | 115 | -------------------------------------------------------------------------------- /tensorcv/dataflow/dataset/BSDS500.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from scipy.io import loadmat 5 | 6 | from ..common import * 7 | from ..normalization import * 8 | from ..image import ImageFromFile 9 | 10 | __all__ = ['BSDS500', 'BSDS500HED'] 11 | 12 | class BSDS500(ImageFromFile): 13 | def __init__(self, name, 14 | data_dir='', 15 | shuffle=True, 16 | normalize=None, 17 | is_mask=False, 18 | normalize_fnc=identity, 19 | resize=None): 20 | 21 | assert name in ['train', 'test', 'val', 'infer'] 22 | self._load_name = name 23 | self._is_mask = is_mask 24 | 25 | super(BSDS500, self).__init__('.jpg', 26 | data_dir=data_dir, 27 | num_channel=3, 28 | shuffle=shuffle, 29 | normalize=normalize, 30 | normalize_fnc=normalize_fnc, 31 | resize=resize) 32 | 33 | def _load_file_list(self, _): 34 | im_dir = os.path.join(self.data_dir, 'images', self._load_name) 35 | self._im_list = get_file_list(im_dir, '.jpg') 36 | 37 | gt_dir = os.path.join(self.data_dir, 'groundTruth', 
self._load_name) 38 | self._gt_list = get_file_list(gt_dir, '.mat') 39 | 40 | # TODO may remove later 41 | if self._is_mask: 42 | mask_dir = os.path.join(self.data_dir, 'mask', self._load_name) 43 | self._mask_list = get_file_list(mask_dir, '.mat') 44 | 45 | if self._shuffle: 46 | self._suffle_file_list() 47 | 48 | def _load_data(self, start, end): 49 | input_im_list = [] 50 | input_label_list = [] 51 | if self._is_mask: 52 | input_mask_list = [] 53 | 54 | for k in range(start, end): 55 | cur_path = self._im_list[k] 56 | im = load_image(cur_path, read_channel=self._read_channel, 57 | resize=self._resize) 58 | input_im_list.extend(im) 59 | 60 | gt = loadmat(self._gt_list[k])['groundTruth'][0] 61 | num_gt = gt.shape[0] 62 | gt = sum(gt[k]['Boundaries'][0][0] for k in range(num_gt)) 63 | gt = gt.astype('float32') 64 | gt = gt * 1.0 / np.amax(gt) 65 | # gt = 1.0*gt/num_gt 66 | try: 67 | gt = misc.imresize(gt, (self._resize[0], self._resize[1])) 68 | except TypeError: 69 | pass 70 | gt = np.reshape(gt, [1, gt.shape[0], gt.shape[1]]) 71 | input_label_list.extend(gt) 72 | 73 | if self._is_mask: 74 | mask = np.reshape(gt, [1, mask.shape[0], mask.shape[1]]) 75 | input_mask_list.extend(mask) 76 | 77 | input_im_list = self._normalize_fnc(np.array(input_im_list), 78 | self._get_max_in_val(), 79 | self._get_half_in_val()) 80 | 81 | input_label_list = np.array(input_label_list) 82 | if self._is_mask: 83 | input_mask_list = np.array(input_mask_list) 84 | return [input_im_list, input_label_list, input_mask_list] 85 | else: 86 | return [input_im_list, input_label_list] 87 | 88 | def _suffle_file_list(self): 89 | idxs = np.arange(self.size()) 90 | self.rng.shuffle(idxs) 91 | self._im_list = self._im_list[idxs] 92 | self._gt_list = self._gt_list[idxs] 93 | try: 94 | self._mask_list = self._mask_list[idxs] 95 | except AttributeError: 96 | pass 97 | 98 | class BSDS500HED(BSDS500): 99 | def _load_file_list(self, _): 100 | im_dir = os.path.join(self.data_dir, 'images', self._load_name) 
101 | self._im_list = get_file_list(im_dir, '.jpg') 102 | 103 | gt_dir = os.path.join(self.data_dir, 'groundTruth', self._load_name) 104 | self._gt_list = get_file_list(gt_dir, '.png') 105 | 106 | if self._shuffle: 107 | self._suffle_file_list() 108 | 109 | def _load_data(self, start, end): 110 | input_im_list = [] 111 | input_label_list = [] 112 | 113 | for k in range(start, end): 114 | im = load_image(self._im_list[k], read_channel=self._read_channel, 115 | resize=self._resize) 116 | input_im_list.extend(im) 117 | 118 | gt = load_image(self._gt_list[k], read_channel=1, 119 | resize=self._resize) 120 | gt = gt * 1.0 / np.amax(gt) 121 | 122 | # gt = loadmat(self._gt_list[k])['groundTruth'][0] 123 | # num_gt = gt.shape[0] 124 | # gt = sum(gt[k]['Boundaries'][0][0] for k in range(num_gt)) 125 | # gt = gt.astype('float32') 126 | # gt = 1.0*gt/num_gt 127 | # try: 128 | # gt = misc.imresize(gt, (self._resize[0], self._resize[1])) 129 | # except TypeError: 130 | # pass 131 | gt = np.squeeze(gt, axis = -1) 132 | input_label_list.extend(gt) 133 | 134 | 135 | input_im_list = self._normalize_fnc(np.array(input_im_list), 136 | self._get_max_in_val(), 137 | self._get_half_in_val()) 138 | 139 | input_label_list = np.array(input_label_list) 140 | 141 | return [input_im_list, input_label_list] 142 | 143 | 144 | if __name__ == '__main__': 145 | a = BSDS500('val','E:\\GITHUB\\workspace\\CNN\\dataset\\BSR_bsds500\\BSR\\BSDS500\\data\\') 146 | print(a.next_batch()) -------------------------------------------------------------------------------- /tensorcv/dataflow/dataset/CIFAR.py: -------------------------------------------------------------------------------- 1 | # File: CIFAR.py 2 | # Author: Qian Ge 3 | 4 | import os 5 | import pickle 6 | 7 | import numpy as np 8 | 9 | from ..base import RNGDataFlow 10 | 11 | __all__ = ['CIFAR'] 12 | 13 | ## TODO Add batch size 14 | class CIFAR(RNGDataFlow): 15 | def __init__(self, data_dir='', shuffle=True, normalize=None): 16 | self.num_channels 
= 3 17 | self.im_size = [32, 32] 18 | 19 | assert os.path.isdir(data_dir) 20 | self.data_dir = data_dir 21 | 22 | self.shuffle = shuffle 23 | self._normalize = normalize 24 | 25 | self.setup(epoch_val=0, batch_size=1) 26 | self._file_list = [os.path.join(data_dir, 'data_batch_' + str(batch_id)) for batch_id in range(1,6)] 27 | 28 | # self._load_files() 29 | self._num_image = self.size() 30 | 31 | self._image_id = 0 32 | self._batch_file_id = -1 33 | self._image = [] 34 | self._next_batch_file() 35 | 36 | def _next_batch_file(self): 37 | if self._batch_file_id >= len(self._file_list) - 1: 38 | self._batch_file_id = 0 39 | self._epochs_completed += 1 40 | else: 41 | self._batch_file_id += 1 42 | self._image = np.array(unpickle(self._file_list[self._batch_file_id])) 43 | # TODO to be modified 44 | if self._normalize == 'tanh': 45 | self._image = (self._image*1. - 128)/128.0 46 | 47 | if self.shuffle: 48 | self._suffle_files() 49 | 50 | def _suffle_files(self): 51 | idxs = np.arange(len(self._image)) 52 | 53 | self.rng.shuffle(idxs) 54 | self._image = self._image[idxs] 55 | 56 | def size(self): 57 | try: 58 | return self.data_size 59 | except AttributeError: 60 | data_size = 0 61 | for k in range(len(self._file_list)): 62 | tmp_image = unpickle(self._file_list[k]) 63 | data_size += len(tmp_image) 64 | self.data_size = data_size 65 | return self.data_size 66 | 67 | def next_batch(self): 68 | # TODO assume batch_size smaller than images in one file 69 | assert self._batch_size <= self.size(), \ 70 | "batch_size {} cannot be larger than data size {}".\ 71 | format(self._batch_size, self.size()) 72 | 73 | start = self._image_id 74 | self._image_id += self._batch_size 75 | end = self._image_id 76 | batch_files = np.array(self._image[start:end]) 77 | 78 | if self._image_id + self._batch_size > len(self._image): 79 | self._next_batch_file() 80 | self._image_id = 0 81 | if self.shuffle: 82 | self._suffle_files() 83 | 84 | return [batch_files] 85 | 86 | 87 | def unpickle(file): 
88 | with open(file, 'rb') as fo: 89 | dict = pickle.load(fo, encoding='bytes') 90 | image = dict[b'data'] 91 | 92 | r = image[:,:32*32].reshape(-1,32,32) 93 | g = image[:,32*32: 2*32*32].reshape(-1,32,32) 94 | b = image[:,2*32*32:].reshape(-1,32,32) 95 | 96 | image = np.stack((r,g,b),axis=-1) 97 | return image 98 | 99 | if __name__ == '__main__': 100 | a = CIFAR('D:\\Qian\\GitHub\\workspace\\tensorflow-DCGAN\\cifar-10-python.tar\\') 101 | t = a.next_batch()[0] 102 | print(t) 103 | print(t.shape) 104 | print(a.size()) 105 | # print(a.next_batch()[0]) 106 | # print(a.next_batch()[0]) -------------------------------------------------------------------------------- /tensorcv/dataflow/dataset/MNIST.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | # File: MNIST.py 4 | # Author: Qian Ge 5 | 6 | import os 7 | 8 | import numpy as np 9 | from tensorflow.examples.tutorials.mnist import input_data 10 | 11 | from ..base import RNGDataFlow 12 | 13 | __all__ = ['MNIST', 'MNISTLabel'] 14 | 15 | def get_mnist_im_label(name, mnist_data): 16 | if name == 'train': 17 | return mnist_data.train.images, mnist_data.train.labels 18 | elif name == 'val': 19 | return mnist_data.validation.images, mnist_data.validation.labels 20 | else: 21 | return mnist_data.test.images, mnist_data.test.labels 22 | 23 | # TODO read data without tensorflow 24 | class MNIST(RNGDataFlow): 25 | """ 26 | 27 | """ 28 | def __init__(self, name, data_dir='', shuffle=True, normalize=None): 29 | 30 | self.num_channels = 1 31 | self.im_size = [28, 28] 32 | 33 | assert os.path.isdir(data_dir) 34 | self.data_dir = data_dir 35 | 36 | self.shuffle = shuffle 37 | self._normalize = normalize 38 | 39 | assert name in ['train', 'test', 'val'] 40 | self.setup(epoch_val=0, batch_size=1) 41 | 42 | self._load_files(name) 43 | self._num_image = self.size() 44 | self._image_id = 0 45 | 46 | def _load_files(self, name): 47 | mnist_data = 
input_data.read_data_sets(self.data_dir, one_hot=False) 48 | self.im_list = [] 49 | self.label_list = [] 50 | 51 | mnist_images, mnist_labels = get_mnist_im_label(name, mnist_data) 52 | for image, label in zip(mnist_images, mnist_labels): 53 | # TODO to be modified 54 | if self._normalize == 'tanh': 55 | image = image*2.-1. 56 | image = np.reshape(image, [28, 28, 1]) 57 | self.im_list.append(image) 58 | self.label_list.append(label) 59 | self.im_list = np.array(self.im_list) 60 | self.label_list = np.array(self.label_list) 61 | 62 | if self.shuffle: 63 | self._suffle_files() 64 | 65 | def _suffle_files(self): 66 | idxs = np.arange(self.size()) 67 | 68 | self.rng.shuffle(idxs) 69 | self.im_list = self.im_list[idxs] 70 | self.label_list = self.label_list[idxs] 71 | 72 | def size(self): 73 | return self.im_list.shape[0] 74 | 75 | def next_batch(self): 76 | assert self._batch_size <= self.size(), \ 77 | "batch_size {} cannot be larger than data size {}".\ 78 | format(self._batch_size, self.size()) 79 | start = self._image_id 80 | self._image_id += self._batch_size 81 | end = self._image_id 82 | batch_files = self.im_list[start:end] 83 | 84 | if self._image_id + self._batch_size > self._num_image: 85 | self._epochs_completed += 1 86 | self._image_id = 0 87 | if self.shuffle: 88 | self._suffle_files() 89 | return [batch_files] 90 | 91 | class MNISTLabel(MNIST): 92 | 93 | def next_batch(self): 94 | assert self._batch_size <= self.size(), \ 95 | "batch_size {} cannot be larger than data size {}".\ 96 | format(self._batch_size, self.size()) 97 | start = self._image_id 98 | self._image_id += self._batch_size 99 | end = self._image_id 100 | batch_im = self.im_list[start:end] 101 | batch_label = self.label_list[start:end] 102 | 103 | if self._image_id + self._batch_size > self._num_image: 104 | self._epochs_completed += 1 105 | self._image_id = 0 106 | if self.shuffle: 107 | self._suffle_files() 108 | return [batch_im, batch_label] 109 | 110 | 111 | if __name__ == '__main__': 
112 | a = MNISTLabel('val','D:\\Qian\\GitHub\\workspace\\tensorflow-DCGAN\\MNIST_data\\') 113 | t = a.next_batch() 114 | print(t) 115 | -------------------------------------------------------------------------------- /tensorcv/dataflow/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .BSDS500 import * 2 | from .CIFAR import * 3 | from .MNIST import * 4 | -------------------------------------------------------------------------------- /tensorcv/dataflow/matlab.py: -------------------------------------------------------------------------------- 1 | # File: matlab.py 2 | # Author: Qian Ge 3 | 4 | import os 5 | from scipy.io import loadmat 6 | 7 | import numpy as np 8 | 9 | from .base import RNGDataFlow 10 | from .common import * 11 | 12 | __all__ = ['MatlabData'] 13 | 14 | class MatlabData(RNGDataFlow): 15 | """ dataflow from .mat file with mask """ 16 | def __init__(self, 17 | data_dir='', 18 | mat_name_list=None, 19 | mat_type_list=None, 20 | shuffle=True, 21 | normalize=None): 22 | 23 | self.setup(epoch_val=0, batch_size=1) 24 | 25 | self.shuffle = shuffle 26 | self._normalize = normalize 27 | 28 | assert os.path.isdir(data_dir) 29 | self.data_dir = data_dir 30 | 31 | assert mat_name_list is not None, 'mat_name_list cannot be None' 32 | if not isinstance(mat_name_list, list): 33 | mat_name_list = [mat_name_list] 34 | self._mat_name_list = mat_name_list 35 | if mat_type_list is None: 36 | mat_type_list = ['float']*len(self._mat_name_list) 37 | assert len(self._mat_name_list) == len(mat_type_list),\ 38 | 'Length of mat_name_list and mat_type_list has to be the same!' 
39 | self._mat_type_list = mat_type_list 40 | 41 | self._load_file_list() 42 | self._get_im_size() 43 | self._num_image = self.size() 44 | self._image_id = 0 45 | 46 | def _get_im_size(self): 47 | # Run after _load_file_list 48 | # Assume all the image have the same size 49 | mat = loadmat(self.file_list[0]) 50 | cur_mat = load_image_from_mat(mat, self._mat_name_list[0], 51 | self._mat_type_list[0]) 52 | if len(cur_mat.shape) < 3: 53 | self.num_channels = 1 54 | else: 55 | self.num_channels = cur_mat.shape[2] 56 | self.im_size = [cur_mat.shape[0], cur_mat.shape[1]] 57 | 58 | def _load_file_list(self): 59 | # data_dir = os.path.join(self.data_dir) 60 | self.file_list = np.array([os.path.join(self.data_dir, file) 61 | for file in os.listdir(self.data_dir) if file.endswith(".mat")]) 62 | 63 | if self.shuffle: 64 | self._suffle_file_list() 65 | 66 | def next_batch(self): 67 | assert self._batch_size <= self.size(), \ 68 | "batch_size cannot be larger than data size" 69 | 70 | start = self._image_id 71 | self._image_id += self._batch_size 72 | end = self._image_id 73 | batch_file_path = self.file_list[start:end] 74 | 75 | if self._image_id + self._batch_size > self._num_image: 76 | self._epochs_completed += 1 77 | self._image_id = 0 78 | if self.shuffle: 79 | self._suffle_file_list() 80 | return self._load_data(batch_file_path) 81 | 82 | def _load_data(self, batch_file_path): 83 | # TODO deal with num_channels 84 | input_data = [[] for i in range(len(self._mat_name_list))] 85 | 86 | for file_path in batch_file_path: 87 | mat = loadmat(file_path) 88 | cur_data = load_image_from_mat(mat, self._mat_name_list[0], 89 | self._mat_type_list[0]) 90 | cur_data = np.reshape(cur_data, 91 | [1, cur_data.shape[0], cur_data.shape[1], self.num_channels]) 92 | input_data[0].extend(cur_data) 93 | 94 | for k in range(1, len(self._mat_name_list)): 95 | cur_data = load_image_from_mat(mat, 96 | self._mat_name_list[k], self._mat_type_list[k]) 97 | cur_data = np.reshape(cur_data, 98 | [1, 
cur_data.shape[0], cur_data.shape[1]]) 99 | input_data[k].extend(cur_data) 100 | input_data = [np.array(data) for data in input_data] 101 | 102 | if self._normalize == 'tanh': 103 | try: 104 | input_data[0] = tanh_normalization(input_data[0], self._half_in_val) 105 | except AttributeError: 106 | self._input_val_range(input_data[0][0]) 107 | input_data[0] = tanh_normalization(input_data[0], self._half_in_val) 108 | 109 | return input_data 110 | 111 | def _input_val_range(self, in_mat): 112 | # TODO to be modified 113 | self._max_in_val, self._half_in_val = input_val_range(in_mat) 114 | 115 | def size(self): 116 | return len(self.file_list) 117 | 118 | def load_image_from_mat(matfile, name, datatype): 119 | mat = matfile[name].astype(datatype) 120 | return mat 121 | 122 | if __name__ == '__main__': 123 | a = MatlabData(data_dir='D:\\GoogleDrive_Qian\\Foram\\Training\\CNN_GAN_ORIGINAL_64\\', 124 | mat_name_list=['level1Edge'], 125 | normalize='tanh') 126 | print(a.next_batch()[0].shape) 127 | print(a.next_batch()[0][:,30:40,30:40,:]) 128 | print(np.amax(a.next_batch()[0])) -------------------------------------------------------------------------------- /tensorcv/dataflow/normalization.py: -------------------------------------------------------------------------------- 1 | # File: normalization.py 2 | # Author: Qian Ge 3 | 4 | import numpy as np 5 | 6 | 7 | def identity(input_val, *args): 8 | return input_val 9 | 10 | def normalize_tanh(input_val, max_in, half_in): 11 | return (input_val*1.0 - half_in)/half_in 12 | 13 | def normalize_one(input_val, max_in, half_in): 14 | return input_val*1.0/max_in 15 | -------------------------------------------------------------------------------- /tensorcv/dataflow/operation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: operation.py 4 | # Author: Qian Ge 5 | 6 | import numpy as np 7 | import copy 8 | 9 | from .base import 
DataFlow 10 | from ..utils.utils import assert_type 11 | 12 | 13 | def display_dataflow(dataflow, data_name='data', simple=False): 14 | assert_type(dataflow, DataFlow) 15 | 16 | n_sample = dataflow.size() 17 | try: 18 | label_list = dataflow.get_label_list() 19 | n_class = len(set(label_list)) 20 | print('[{}] num of samples {}, num of classes {}'. 21 | format(data_name, n_sample, n_class)) 22 | if not simple: 23 | nelem_dict = {} 24 | for c_label in label_list: 25 | try: 26 | nelem_dict[c_label] += 1 27 | except KeyError: 28 | nelem_dict[c_label] = 1 29 | for c_label in nelem_dict: 30 | print('class {}: {}'.format( 31 | dataflow.label_dict_reverse[c_label], 32 | nelem_dict[c_label])) 33 | except AttributeError: 34 | print('[{}] num of samples {}'. 35 | format(data_name, n_sample)) 36 | 37 | 38 | def k_fold_based_class(dataflow, k, shuffle=True): 39 | """Partition dataflow into k equal sized subsamples based on class labels 40 | 41 | Args: 42 | dataflows (DataFlow): DataFlow to be partitioned. Must contain labels. 43 | k (int): number of subsamples 44 | shuffle (bool): data will be shuffled before and after partition 45 | if is true 46 | 47 | Return: 48 | DataFlow: list of k subsample Dataflow 49 | """ 50 | assert_type(dataflow, DataFlow) 51 | k = int(k) 52 | assert k > 0, 'k must be an integer grater than 0!' 
53 | dataflow_data_list = dataflow.get_data_list() 54 | if not isinstance(dataflow_data_list, list): 55 | dataflow_data_list = [dataflow_data_list] 56 | 57 | label_list = dataflow.get_label_list() 58 | # im_list = dataflow.get_data_list() 59 | 60 | if shuffle: 61 | dataflow.suffle_data() 62 | 63 | class_dict = {} 64 | for idx, cur_label in enumerate(label_list): 65 | try: 66 | for data_idx, data in enumerate(dataflow_data_list): 67 | class_dict[cur_label][data_idx] += [data[idx], ] 68 | except KeyError: 69 | class_dict[cur_label] = [[] for i in range(0, len(dataflow_data_list))] 70 | for data_idx, data in enumerate(dataflow_data_list): 71 | class_dict[cur_label][data_idx] = [data[idx], ] 72 | # class_dict[cur_label] = [cur_im, ] 73 | 74 | # fold_im_list = [[] for i in range(0, k)] 75 | fold_data_list = [[[] for j in range(0, len(dataflow_data_list))] for i in range(0, k)] 76 | # fold_label_list = [[] for i in range(0, k)] 77 | for label_key in class_dict: 78 | cur_data_list = class_dict[label_key] 79 | nelem = int(len(cur_data_list[0]) / k) 80 | start_id = 0 81 | for fold_id in range(0, k-1): 82 | for data_idx, data_list in enumerate(cur_data_list): 83 | fold_data_list[fold_id][data_idx].extend(data_list[start_id : start_id + nelem]) 84 | start_id += nelem 85 | for data_idx, data_list in enumerate(cur_data_list): 86 | fold_data_list[k - 1][data_idx].extend(data_list[start_id :]) 87 | 88 | data_folds = [copy.deepcopy(dataflow) for i in range(0, k)] 89 | 90 | for cur_fold, cur_data_list in zip(data_folds, fold_data_list): 91 | cur_fold.set_data_list(cur_data_list) 92 | 93 | if shuffle: 94 | cur_fold.suffle_data() 95 | 96 | return data_folds 97 | 98 | 99 | def combine_dataflow(dataflows, shuffle=True): 100 | """Combine several dataflow into one 101 | 102 | Args: 103 | dataflows (DataFlow list): list of DataFlow to be combined 104 | shuffle (bool): data will be shuffled after combined if is true 105 | 106 | Return: 107 | DataFlow: Combined DataFlow 108 | """ 109 | if 
not isinstance(dataflows, list): 110 | dataflows = [dataflows] 111 | 112 | data_list = [] 113 | for cur_dataflow in dataflows: 114 | assert_type(cur_dataflow, DataFlow) 115 | cur_data_list = cur_dataflow.get_data_list() 116 | if not isinstance(cur_data_list, list): 117 | cur_data_list = [cur_data_list] 118 | data_list.append(cur_data_list) 119 | 120 | num_data_type = len(data_list[0]) 121 | combined_data_list = [[] for i in range(0, num_data_type)] 122 | for cur_data_list in data_list: 123 | for i in range(0, num_data_type): 124 | combined_data_list[i].extend(cur_data_list[i]) 125 | 126 | dataflows[0].set_data_list(combined_data_list) 127 | if shuffle: 128 | dataflows[0].suffle_data() 129 | 130 | return dataflows[0] 131 | -------------------------------------------------------------------------------- /tensorcv/dataflow/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: preprocess.py 4 | # Author: Qian Ge 5 | 6 | import numpy as np 7 | from scipy import misc 8 | 9 | 10 | def image_fliplr(image): 11 | """ Generate mirror image 12 | 13 | Args: 14 | image (np.array): a 2-D image of shape 15 | [height, width] or a 3-D image of shape 16 | [height, width, channels]. 17 | 18 | Returns: 19 | mirror version of original image. 20 | """ 21 | return np.fliplr(image) 22 | 23 | 24 | def resize_image_with_smallest_side(image, small_size): 25 | """ 26 | Resize single image array with smallest side = small_size and 27 | keep the original aspect ratio. 28 | 29 | Args: 30 | image (np.array): a 2-D image of shape 31 | [height, width] or a 3-D image of shape 32 | [height, width, channels]. 33 | small_size (int): A 1-D int. The smallest side of resize image. 
34 | 35 | Returns: 36 | rescaled image 37 | """ 38 | im_shape = image.shape 39 | shape_dim = len(im_shape) 40 | assert shape_dim <= 3 and shape_dim >= 2,\ 41 | 'Wrong format of image!Shape is {}'.format(im_shape) 42 | 43 | height = float(im_shape[0]) 44 | width = float(im_shape[1]) 45 | 46 | if height <= width: 47 | new_height = int(small_size) 48 | new_width = int(width * new_height/height) 49 | else: 50 | new_width = int(small_size) 51 | new_height = int(height * new_width/width) 52 | 53 | if shape_dim == 2: 54 | im = misc.imresize(image, (new_height, new_width)) 55 | elif shape_dim == 3: 56 | im = misc.imresize(image, (new_height, new_width, image.shape[2])) 57 | return im 58 | 59 | 60 | def random_crop_to_size(image, crop_size): 61 | """ Rondomly crop an image into crop_size 62 | 63 | Args: 64 | image (np.array): a 2-D image of shape 65 | [height, width] or a 3-D image of shape 66 | [height, width, channels]. 67 | The size has to be larger than cropped image. 68 | crop_size (int or length 2 list): The image size after cropped. 69 | 70 | Returns: 71 | cropped image 72 | """ 73 | crop_size = get_shape2D(crop_size) 74 | im_shape = image.shape 75 | shape_dim = len(im_shape) 76 | assert shape_dim <= 3 and shape_dim >= 2, 'Wrong format of image!' 77 | 78 | height = im_shape[0] 79 | width = im_shape[1] 80 | assert height >= crop_size[0] and width >= crop_size[1],\ 81 | 'Image must be larger than crop size! {}'.format(im_shape) 82 | 83 | s_h = int(np.floor((height - crop_size[0] + 1) * np.random.rand())) 84 | s_w = int(np.floor((width - crop_size[1] + 1) * np.random.rand())) 85 | 86 | return image[s_h:s_h + crop_size[0], s_w:s_w + crop_size[1]] 87 | 88 | 89 | def four_connor_crop(image, crop_size): 90 | """ Crop an image into crop_size with four corner crops 91 | 92 | Args: 93 | image (np.array): a 2-D image of shape 94 | [height, width] or a 3-D image of shape 95 | [height, width, channels]. 96 | The size has to be larger than cropped image. 
97 | crop_size (int or length 2 list): The image size after cropped. 98 | 99 | Returns: 100 | four cropped images 101 | """ 102 | crop_size = get_shape2D(crop_size) 103 | im_shape = image.shape 104 | shape_dim = len(im_shape) 105 | assert shape_dim <= 3 and shape_dim >= 2, 'Wrong format of image!' 106 | height = im_shape[0] 107 | width = im_shape[1] 108 | assert height >= crop_size[0] and width >= crop_size[1],\ 109 | 'Image must be larger than crop size! {}'.format(im_shape) 110 | 111 | crop_im = [] 112 | crop_im.append(image[: crop_size[0], : crop_size[1]]) 113 | crop_im.append(image[: crop_size[0], width - crop_size[1]:]) 114 | crop_im.append(image[height - crop_size[0]:, : crop_size[1]]) 115 | crop_im.append(image[height - crop_size[0]:, width - crop_size[1]:]) 116 | 117 | return crop_im 118 | 119 | 120 | def center_crop(image, crop_size): 121 | """ Center crop an image into crop_size 122 | 123 | Args: 124 | image (np.array): a 2-D image of shape 125 | [height, width] or a 3-D image of shape 126 | [height, width, channels]. 127 | The size has to be larger than cropped image. 128 | crop_size (int or length 2 list): The image size after cropped. 129 | 130 | Returns: 131 | cropped images 132 | """ 133 | crop_size = get_shape2D(crop_size) 134 | im_shape = image.shape 135 | shape_dim = len(im_shape) 136 | assert shape_dim <= 3 and shape_dim >= 2, 'Wrong format of image!' 137 | height = im_shape[0] 138 | width = im_shape[1] 139 | assert height >= crop_size[0] and width >= crop_size[1],\ 140 | 'Image must be larger than crop size! {}'.format(im_shape) 141 | 142 | return image[(height - crop_size[0])//2:(height + crop_size[0])//2, 143 | (width - crop_size[1])//2:(width + crop_size[1])//2] 144 | 145 | 146 | def random_mirror_resize_crop(image, crop_size, scale_range, mirror_rate=0.5): 147 | """ Ramdomly rescale, crop and image. 148 | 149 | Args: 150 | image (np.array): a 2-D image of shape 151 | [height, width] or a 3-D image of shape 152 | [height, width, channels]. 
153 | crop_size (int or length 2 list): The image size after cropped. 154 | scale_range (list of int with length 2): The range of scale. 155 | mirror_rate (float): The probability of mirror image. 156 | Must within the range [0, 1] 157 | 158 | Returns: 159 | cropped and rescaled images 160 | """ 161 | im_shape = image.shape 162 | shape_dim = len(im_shape) 163 | 164 | assert mirror_rate >= 0 and mirror_rate <= 1,\ 165 | 'mirror rate must be in range of [0, 1]!' 166 | assert shape_dim <= 3 and shape_dim >= 2,\ 167 | 'Wrong format of image!Shape is {}'.format(im_shape) 168 | 169 | small_size = int(np.random.rand() * (max(scale_range) - min(scale_range)) 170 | + min(scale_range)) 171 | image = resize_image_with_smallest_side(image, small_size) 172 | 173 | image = random_crop_to_size(image, crop_size) 174 | 175 | if np.random.rand() >= mirror_rate: 176 | image = image_fliplr(image) 177 | 178 | return image 179 | 180 | 181 | def get_shape2D(in_val): 182 | """ 183 | Return a 2D shape 184 | 185 | Args: 186 | in_val (int or list with length 2) 187 | 188 | Returns: 189 | list with length 2 190 | """ 191 | if in_val is None: 192 | return None 193 | if isinstance(in_val, int): 194 | return [in_val, in_val] 195 | if isinstance(in_val, list): 196 | assert len(in_val) == 2 197 | return in_val 198 | raise RuntimeError('Illegal shape: {}'.format(in_val)) 199 | -------------------------------------------------------------------------------- /tensorcv/dataflow/randoms.py: -------------------------------------------------------------------------------- 1 | # File: randoms.py 2 | # Author: Qian Ge 3 | 4 | import numpy as np 5 | 6 | from .base import DataFlow 7 | from ..utils.utils import get_rng 8 | 9 | __all__ = ['RandomVec'] 10 | 11 | class RandomVec(DataFlow): 12 | """ random vector input """ 13 | def __init__(self, 14 | len_vec=100): 15 | 16 | self.setup(epoch_val=0, batch_size=1) 17 | self._len_vec = len_vec 18 | 19 | def next_batch(self): 20 | self._epochs_completed += 1 21 | 
return [np.random.normal(size=(self._batch_size, self._len_vec))] 22 | 23 | def size(self): 24 | return self._batch_size 25 | 26 | def reset_state(self): 27 | self._reset_state() 28 | 29 | def _reset_state(self): 30 | self.rng = get_rng(self) 31 | 32 | if __name__ == '__main__': 33 | vec = RandomVec() 34 | print(vec.next_batch()) 35 | print(vec.next_batch()) -------------------------------------------------------------------------------- /tensorcv/dataflow/sequence.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: sequence.py 4 | # Author: Qian Ge 5 | 6 | import numpy as np 7 | import collections 8 | 9 | from .base import DataFlow 10 | from .normalization import identity 11 | from ..utils.utils import assert_type 12 | 13 | class SeqDataflow(DataFlow): 14 | """ base class for sequence data 15 | 16 | """ 17 | def __init__(self, data_dir='', 18 | load_ratio=1, 19 | predict_step=0, 20 | batch_dict_name=None, 21 | normalize_fnc=identity): 22 | self._pred_step = predict_step 23 | self._data_dir = data_dir 24 | self._normalize_fnc = normalize_fnc 25 | self._load_ratio = load_ratio 26 | 27 | if not isinstance(batch_dict_name, list): 28 | batch_dict_name = [batch_dict_name] 29 | self._batch_dict_name = batch_dict_name 30 | 31 | self.load_entire_seq() 32 | 33 | self._data_id = 0 34 | 35 | self.setup(epoch_val=0, batch_size=1) 36 | self.setup_seq_para(num_step=10, stride=1) 37 | self._updata_batch_partition_len() 38 | 39 | def _updata_batch_partition_len(self): 40 | try: 41 | self._batch_partition_len = self.size() // self._batch_size 42 | except AttributeError: 43 | pass 44 | 45 | def set_batch_size(self, batch_size): 46 | self._batch_size = batch_size 47 | self._updata_batch_partition_len() 48 | 49 | def size(self): 50 | return len(self.get_entire_seq()) 51 | 52 | def setup_seq_para(self, num_step, stride): 53 | self._num_step = num_step 54 | self._stride = stride 55 | 56 | 
def next_batch(self):
    """Return the next batch of fixed-length windows over the sequence.

    Method of ``SeqDataflow``. The sequence is viewed as
    ``self._batch_size`` partitions of ``self._batch_partition_len``
    elements. Batch row ``i`` is the window
    ``[offset + i * partition_len, offset + i * partition_len + num_step)``
    loaded through ``self.load_data``, where ``offset`` is the cursor
    ``self._data_id``. The cursor advances by ``self._num_step`` per call.

    When the last row's window, extended by ``self._pred_step`` (the shift
    applied to labels), would run past ``self.size()``, one epoch is counted
    and the cursor restarts at ``self._epochs_completed % self._num_step``,
    so the window alignment varies with the epoch count. If that restart
    offset would itself run past the end, the cursor restarts at 0 instead.

    Returns:
        ``self._batch_transform`` applied to the list of
        ``self._batch_size`` loaded windows.

    Raises:
        AssertionError: if ``batch_size * num_step`` exceeds ``self.size()``.
    """
    b_size = self._batch_size
    part_len = self._batch_partition_len
    n_step = self._num_step
    assert b_size * n_step <= self.size(), \
        'batch_size * num_step cannot be larger than data size'

    def _overflows(offset):
        # The last row starts at offset + part_len * (b_size - 1) and its
        # labels reach pred_step elements beyond the end of its window.
        return (offset + part_len * (b_size - 1) + n_step + self._pred_step
                > self.size())

    if _overflows(self._data_id):
        self._epochs_completed += 1
        self._data_id = self._epochs_completed % n_step
        # The epoch-dependent offset can itself overflow when partitions are
        # only slightly longer than num_step; restart from the beginning
        # rather than requesting windows that reach past the sequence end.
        if _overflows(self._data_id):
            self._data_id = 0

    batch_data = []
    for i in range(b_size):
        start_id = self._data_id + part_len * i
        batch_data.append(self.load_data(start_id, start_id + n_step))

    self._data_id += n_step
    return self._batch_transform(batch_data)


def _batch_transform(self, batch_data):
    """Hook applied to every batch before ``next_batch`` returns it.

    The base implementation returns ``batch_data`` unchanged; subclasses
    may override it to reshape or convert the batch.
    """
    return batch_data


def load_data(self, start_id, end_id):
    """Return the sample covering sequence positions ``[start_id, end_id)``.

    Stub in the base class (returns ``None``); subclasses such as
    ``SeqNumber`` override it.
    """
    return None


def load_entire_seq(self):
    """Load the whole sequence; invoked once from ``__init__``.

    Stub in the base class (does nothing).
    """
    return None


def get_entire_seq(self):
    """Return the whole sequence; ``size()`` reports its length.

    Stub in the base class (returns ``None``).
    """
    return None
115 | 116 | class SepWord(SeqDataflow): 117 | def __init__(self, data_dir='', 118 | predict_step=1, 119 | word_dict=None, 120 | batch_dict_name=None, 121 | normalize_fnc=identity): 122 | self.word_dict = word_dict 123 | super(SepWord, self).__init__(data_dir=data_dir, 124 | predict_step=predict_step, 125 | batch_dict_name=batch_dict_name, 126 | normalize_fnc=normalize_fnc) 127 | 128 | def gen_word_dict(self, word_data): 129 | counter = collections.Counter(word_data) 130 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) 131 | words, _ = list(zip(*count_pairs)) 132 | self.word_dict = dict(zip(words, range(len(words)))) 133 | 134 | 135 | class SeqNumber(SeqDataflow): 136 | def _scale(self, data): 137 | normal_dict = self._normalize_fnc(data) 138 | try: 139 | self.scale_dict = normal_dict['scale_dict'] 140 | except KeyError: 141 | pass 142 | return normal_dict['data'] 143 | 144 | # max_data = np.amax(data) 145 | # min_data = np.amin(data) 146 | # return (data - min_data) / (max_data - min_data) 147 | 148 | def load_data(self, start_id, end_id): 149 | feature_seq = self.get_entire_seq()[start_id: end_id] 150 | label = self.get_label_seq()[start_id + self._pred_step: end_id + self._pred_step] 151 | return [feature_seq, label] 152 | 153 | def load_entire_seq(self): 154 | pass 155 | 156 | def get_entire_seq(self): 157 | pass 158 | 159 | def get_label_seq(self): 160 | pass 161 | -------------------------------------------------------------------------------- /tensorcv/dataflow/viz.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: viz.py 4 | # Author: Qian Ge 5 | 6 | import matplotlib.pyplot as plt 7 | 8 | from .sequence import SeqDataflow 9 | from ..utils.utils import assert_type 10 | 11 | def plot_seq(dataflow, data_range=None): 12 | assert_type(dataflow, SeqDataflow) 13 | data = dataflow.get_entire_seq() 14 | if data_range is None: 15 | plt.plot(data) 16 
| else: 17 | plt.plot(data[data_range]) 18 | plt.show() 19 | -------------------------------------------------------------------------------- /tensorcv/models/__init__.py: -------------------------------------------------------------------------------- 1 | # File: __init__.py 2 | # Author: Qian Ge 3 | 4 | -------------------------------------------------------------------------------- /tensorcv/models/bk/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def conv(x, filter_height, filter_width, num_filters, 4 | name, stride_x = 1, stride_y = 1, 5 | padding = 'SAME', relu = True): 6 | input_channel = int(x.shape[-1]) 7 | convolve = lambda i, k: tf.nn.conv2d(i, k, 8 | strides=[1, stride_y, stride_x, 1], 9 | padding = padding) 10 | with tf.variable_scope(name) as scope: 11 | # weights = tf.get_variable('weights', shape = [filter_height, filter_width, input_channel, num_filters]) 12 | # biases = tf.get_variable('biases', shape = [num_filters]) 13 | weights = new_normal_variable('weights', 14 | shape = [filter_height, filter_width, input_channel, num_filters]) 15 | biases = new_normal_variable('biases', shape = [num_filters]) 16 | 17 | conv = convolve(x, weights) 18 | bias = tf.nn.bias_add(conv, biases) 19 | # bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape().as_list()) 20 | 21 | if relu: 22 | relu = tf.nn.relu(bias, name = scope.name) 23 | return relu 24 | else: 25 | return bias 26 | 27 | def dconv(x, filter_height, filter_width, 28 | name, fuse_x = None, 29 | output_shape = [], output_channels = None, 30 | stride_x = 2, stride_y = 2, padding = 'SAME'): 31 | input_channels = int(x.shape[-1]) 32 | 33 | if fuse_x is not None: 34 | output_shape = tf.shape(fuse_x) 35 | output_channels = int(fuse_x.shape[-1]) 36 | elif output_channels is None: 37 | output_channels = output_shape[-1] 38 | 39 | with tf.variable_scope(name) as scope: 40 | # weights = tf.get_variable('weights', shape = 
[filter_height, filter_width, output_channels, input_channels]) 41 | # biases = tf.get_variable('biases', shape = [output_channels]) 42 | weights = new_normal_variable('weights', 43 | shape = [filter_height, filter_width, output_channels, input_channels]) 44 | biases = new_normal_variable('biases', shape = [output_channels]) 45 | 46 | dconv = tf.nn.conv2d_transpose(x, weights, 47 | output_shape = output_shape, 48 | strides=[1, stride_y, stride_x, 1], 49 | padding = padding, name = scope.name) 50 | bias = tf.nn.bias_add(dconv, biases) 51 | bias = tf.reshape(bias, output_shape) 52 | 53 | if fuse_x is not None: 54 | fuse = tf.add(bias, fuse_x, name = 'fuse') 55 | return fuse 56 | else: 57 | return bias 58 | 59 | def fc(x, num_in, num_out, name, relu = True): 60 | num_in = x.get_shape().as_list()[1] 61 | # num_in = x.shape[-1] 62 | with tf.variable_scope(name) as scope: 63 | # weights = tf.get_variable('weights', shape = [num_in, num_out], trainable = True) 64 | # biases = tf.get_variable('biases', shape = [num_out], trainable = True) 65 | weights = new_normal_variable('weights', shape = [num_in, num_out]) 66 | biases = new_normal_variable('biases', shape = [num_out]) 67 | act = tf.nn.xw_plus_b(x, weights, biases, name = scope.name) 68 | 69 | if relu: 70 | relu = tf.nn.relu(act) 71 | return relu 72 | else: 73 | return act 74 | 75 | def max_pool(x, name, filter_height = 2, filter_width = 2, 76 | stride_x = 2, stride_y = 2, padding = 'SAME'): 77 | return tf.nn.max_pool(x, ksize = [1, filter_height, filter_width, 1], 78 | strides = [1, stride_y, stride_x, 1], 79 | padding = padding, name = name) 80 | 81 | def dropout(x, keep_prob, is_training): 82 | # print(is_training) 83 | return tf.layers.dropout(x, rate = 1 - keep_prob, training = is_training) 84 | # return tf.nn.dropout(x, keep_prob, is_training = is_training) 85 | 86 | def batch_norm(x, name, train = True): 87 | return tf.contrib.layers.batch_norm(x, 88 | decay = 0.9, 89 | updates_collections = None, 90 | epsilon = 
1e-5, 91 | scale = False, 92 | is_training = train, 93 | scope = name) 94 | 95 | def leaky_relu(x, leak = 0.2): 96 | return tf.maximum(x, leak*x) 97 | 98 | def new_normal_variable(name, shape = None, trainable = True, stddev = 0.002): 99 | return tf.get_variable(name, shape = shape, trainable = trainable, 100 | initializer = tf.random_normal_initializer(stddev = stddev)) 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /tensorcv/models/losses.py: -------------------------------------------------------------------------------- 1 | # File: losses.py 2 | # Author: Qian Ge 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | 8 | def GAN_discriminator_loss(d_real, d_fake, name='d_loss'): 9 | print('---- d_loss -----') 10 | with tf.name_scope(name): 11 | d_loss_real = comp_loss_real(d_real) 12 | d_loss_fake = comp_loss_fake(d_fake) 13 | return tf.identity(d_loss_real + d_loss_fake, name='result') 14 | 15 | def GAN_generator_loss(d_fake, name='g_loss'): 16 | print('---- g_loss -----') 17 | with tf.name_scope(name): 18 | return tf.identity(comp_loss_real(d_fake), name='result') 19 | 20 | def comp_loss_fake(discrim_output): 21 | return tf.reduce_mean( 22 | tf.nn.sigmoid_cross_entropy_with_logits(logits=discrim_output, 23 | labels=tf.zeros_like(discrim_output))) 24 | 25 | def comp_loss_real(discrim_output): 26 | return tf.reduce_mean( 27 | tf.nn.sigmoid_cross_entropy_with_logits(logits=discrim_output, 28 | labels=tf.ones_like(discrim_output))) 29 | 30 | -------------------------------------------------------------------------------- /tensorcv/models/model_builder/base.py: -------------------------------------------------------------------------------- 1 | # File: base.py 2 | # Author: Qian Ge 3 | 4 | from abc import abstractmethod 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | 10 | __all__ = ['BaseBuilder'] 11 | 12 | class BaseBuilder(object): 13 | """ base model for model builder """ 14 
| def __init__(self): 15 | self.input = [] 16 | self.output = [] 17 | def Add(self, BaseLayer): 18 | """ add one layer """ 19 | pass 20 | 21 | # def set_batch_size(self, val): 22 | # self._batch_size = val 23 | 24 | # def get_batch_size(self): 25 | # return self._batch_size 26 | 27 | # def set_is_training(self, is_training = True): 28 | # self.is_training = is_training 29 | 30 | # def get_placeholder(self): 31 | # return self._get_placeholder() 32 | 33 | # def _get_placeholder(self): 34 | # return [] 35 | 36 | # # TODO to be modified 37 | # def get_prediction_placeholder(self): 38 | # return self._get_prediction_placeholder() 39 | 40 | # def _get_prediction_placeholder(self): 41 | # return [] 42 | 43 | # def get_graph_feed(self): 44 | # return self._get_graph_feed() 45 | 46 | # def _get_graph_feed(self): 47 | # return {} 48 | 49 | # def create_graph(self): 50 | # self._create_graph() 51 | # self._setup_graph() 52 | # # self._setup_summary() 53 | 54 | # @abstractmethod 55 | # def _create_graph(self): 56 | # raise NotImplementedError() 57 | 58 | # def _setup_graph(self): 59 | # pass 60 | 61 | # # TDDO move outside of class 62 | # # summary will be created before prediction 63 | # # which is unnecessary 64 | # def setup_summary(self): 65 | # self._setup_summary() 66 | 67 | # def _setup_summary(self): 68 | # pass 69 | 70 | 71 | class BaseModel(ModelDes): 72 | """ Model with single loss and single optimizer """ 73 | 74 | def get_optimizer(self): 75 | try: 76 | return self.optimizer 77 | except AttributeError: 78 | self.optimizer = self._get_optimizer() 79 | return self.optimizer 80 | 81 | @property 82 | def default_collection(self): 83 | return 'default' 84 | 85 | def _get_optimizer(self): 86 | raise NotImplementedError() 87 | 88 | def get_loss(self): 89 | try: 90 | return self._loss 91 | except AttributeError: 92 | self._loss = self._get_loss() 93 | tf.summary.scalar('loss_summary', self.get_loss(), 94 | collections = [self.default_collection]) 95 | return self._loss 
96 | 97 | def _get_loss(self): 98 | raise NotImplementedError() 99 | 100 | def get_grads(self): 101 | try: 102 | return self.grads 103 | except AttributeError: 104 | optimizer = self.get_optimizer() 105 | loss = self.get_loss() 106 | self.grads = optimizer.compute_gradients(loss) 107 | [tf.summary.histogram('gradient/' + var.name, grad, 108 | collections = [self.default_collection]) for grad, var in self.grads] 109 | return self.grads 110 | 111 | class GANBaseModel(ModelDes): 112 | """ Base model for GANs """ 113 | def __init__(self, input_vec_length, learning_rate): 114 | self.input_vec_length = input_vec_length 115 | assert len(learning_rate) == 2 116 | self.dis_learning_rate, self.gen_learning_rate = learning_rate 117 | 118 | @property 119 | def g_collection(self): 120 | return 'default_g' 121 | 122 | @property 123 | def d_collection(self): 124 | return 'default_d' 125 | 126 | def get_random_vec_placeholder(self): 127 | try: 128 | return self.Z 129 | except AttributeError: 130 | self.Z = tf.placeholder(tf.float32, [None, self.input_vec_length]) 131 | return self.Z 132 | 133 | def _get_prediction_placeholder(self): 134 | return self.get_random_vec_placeholder() 135 | 136 | def get_graph_feed(self): 137 | default_feed = self._get_graph_feed() 138 | random_input_feed = self._get_random_input_feed() 139 | default_feed.update(random_input_feed) 140 | return default_feed 141 | 142 | def _get_random_input_feed(self): 143 | feed = {self.get_random_vec_placeholder(): 144 | np.random.normal(size = (self.get_batch_size(), 145 | self.input_vec_length))} 146 | return feed 147 | 148 | def create_GAN(self, real_data, gen_data_name = 'gen_data'): 149 | with tf.variable_scope('generator') as scope: 150 | gen_data = self._generator() 151 | scope.reuse_variables() 152 | sample_gen_data = tf.identity(self._generator(train = False), 153 | name = gen_data_name) 154 | 155 | with tf.variable_scope('discriminator') as scope: 156 | d_real = self._discriminator(real_data) 157 | 
scope.reuse_variables() 158 | d_fake = self._discriminator(gen_data) 159 | 160 | with tf.name_scope('discriminator_out'): 161 | tf.summary.histogram('discrim_real', 162 | tf.nn.sigmoid(d_real), 163 | collections = [self.d_collection]) 164 | tf.summary.histogram('discrim_gen', 165 | tf.nn.sigmoid(d_fake), 166 | collections = [self.d_collection]) 167 | 168 | return gen_data, sample_gen_data, d_real, d_fake 169 | 170 | # def get_random_input_feed(self): 171 | # return self._get_random_input_feed() 172 | 173 | # def _get_random_input_feed(self): 174 | # return {} 175 | 176 | def get_discriminator_optimizer(self): 177 | try: 178 | return self.d_optimizer 179 | except AttributeError: 180 | self.d_optimizer = self._get_discriminator_optimizer() 181 | return self.d_optimizer 182 | 183 | def get_generator_optimizer(self): 184 | try: 185 | return self.g_optimizer 186 | except AttributeError: 187 | self.g_optimizer = self._get_generator_optimizer() 188 | return self.g_optimizer 189 | 190 | def _get_discriminator_optimizer(self): 191 | raise NotImplementedError() 192 | 193 | def _get_generator_optimizer(self): 194 | raise NotImplementedError() 195 | 196 | def get_discriminator_loss(self): 197 | try: 198 | return self.d_loss 199 | except AttributeError: 200 | self.d_loss = self._get_discriminator_loss() 201 | tf.summary.scalar('d_loss_summary', self.d_loss, 202 | collections = [self.d_collection]) 203 | return self.d_loss 204 | 205 | def get_generator_loss(self): 206 | try: 207 | return self.g_loss 208 | except AttributeError: 209 | self.g_loss = self._get_generator_loss() 210 | tf.summary.scalar('g_loss_summary', self.g_loss, 211 | collections = [self.g_collection]) 212 | return self.g_loss 213 | 214 | def _get_discriminator_loss(self): 215 | raise NotImplementedError() 216 | 217 | def _get_generator_loss(self): 218 | raise NotImplementedError() 219 | 220 | def get_discriminator_grads(self): 221 | try: 222 | return self.d_grads 223 | except AttributeError: 224 | 
d_training_vars = [v for v in tf.trainable_variables() 225 | if v.name.startswith('discriminator/')] 226 | optimizer = self.get_discriminator_optimizer() 227 | loss = self.get_discriminator_loss() 228 | self.d_grads = optimizer.compute_gradients(loss, 229 | var_list = d_training_vars) 230 | 231 | [tf.summary.histogram('d_gradient/' + var.name, grad, 232 | collections = [self.d_collection]) 233 | for grad, var in self.d_grads] 234 | return self.d_grads 235 | 236 | def get_generator_grads(self): 237 | try: 238 | return self.g_grads 239 | except AttributeError: 240 | g_training_vars = [v for v in tf.trainable_variables() 241 | if v.name.startswith('generator/')] 242 | optimizer = self.get_generator_optimizer() 243 | loss = self.get_generator_loss() 244 | self.g_grads = optimizer.compute_gradients(loss, 245 | var_list = g_training_vars) 246 | [tf.summary.histogram('g_gradient/' + var.name, grad, 247 | collections = [self.g_collection]) 248 | for grad, var in self.g_grads] 249 | return self.g_grads 250 | 251 | @staticmethod 252 | def comp_loss_fake(discrim_output): 253 | return tf.reduce_mean( 254 | tf.nn.sigmoid_cross_entropy_with_logits(logits = discrim_output, 255 | labels = tf.zeros_like(discrim_output))) 256 | 257 | @staticmethod 258 | def comp_loss_real(discrim_output): 259 | return tf.reduce_mean( 260 | tf.nn.sigmoid_cross_entropy_with_logits(logits = discrim_output, 261 | labels = tf.ones_like(discrim_output))) 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | -------------------------------------------------------------------------------- /tensorcv/models/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: utils.py 4 | # Author: Qian Ge 5 | 6 | import math 7 | 8 | 9 | def deconv_size(input_height, input_width, stride=2): 10 | """ 11 | Compute the feature size (height and width) after filtering with 12 | a specific stride. 
Mostly used for setting the shape for deconvolution. 13 | 14 | Args: 15 | input_height (int): height of input feature 16 | input_width (int): width of input feature 17 | stride (int): stride of the filter 18 | 19 | Return: 20 | (int, int): Height and width of feature after filtering. 21 | """ 22 | return int(math.ceil(float(input_height) / float(stride))),\ 23 | int(math.ceil(float(input_width) / float(stride))) -------------------------------------------------------------------------------- /tensorcv/predicts/__init__.py: -------------------------------------------------------------------------------- 1 | # File: __init__.py 2 | # Author: Qian Ge 3 | 4 | from .config import * 5 | from .predictions import * 6 | -------------------------------------------------------------------------------- /tensorcv/predicts/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: base.py 4 | # Author: Qian Ge 5 | 6 | import os 7 | 8 | import tensorflow as tf 9 | 10 | from .config import PridectConfig 11 | from ..utils.sesscreate import ReuseSessionCreator 12 | from ..utils.common import assert_type 13 | from ..callbacks.hooks import Prediction2Hook 14 | 15 | __all__ = ['Predictor'] 16 | 17 | 18 | class Predictor(object): 19 | """Base class for a predictor. Used to run all predictions. 20 | 21 | Attributes: 22 | config (PridectConfig): the config used for this predictor 23 | model (ModelDes): 24 | input (DataFlow): 25 | sess (tf.Session): 26 | hooked_sess (tf.train.MonitoredSession): 27 | """ 28 | def __init__(self, config): 29 | """ Inits Predictor with config (PridectConfig). 30 | 31 | Will create session as well as monitored sessions for 32 | each predictions, and load pre-trained parameters. 
33 | 34 | Args: 35 | config (PridectConfig): the config used for this predictor 36 | """ 37 | assert_type(config, PridectConfig) 38 | self._config = config 39 | self._model = config.model 40 | 41 | self._input = config.dataflow 42 | self._result_dir = config.result_dir 43 | 44 | # TODO to be modified 45 | self._model.set_is_training(False) 46 | self._model.create_graph() 47 | self._restore_vars = self._config.restore_vars 48 | 49 | # pass saving directory to predictions 50 | for pred in self._config.predictions: 51 | pred.setup(self._result_dir) 52 | 53 | hooks = [Prediction2Hook(pred) for pred in self._config.predictions] 54 | 55 | self.sess = self._config.session_creator.create_session() 56 | self.hooked_sess = tf.train.MonitoredSession( 57 | session_creator=ReuseSessionCreator(self.sess), hooks=hooks) 58 | 59 | # load pre-trained parameters 60 | load_model_path = os.path.join(self._config.model_dir, 61 | self._config.model_name) 62 | if self._restore_vars is not None: 63 | # variables = tf.contrib.framework.get_variables_to_restore() 64 | # variables_to_restore = [v for v in variables if v.name.split('/')[0] in self._restore_vars] 65 | # print(variables_to_restore) 66 | saver = tf.train.Saver(self._restore_vars) 67 | else: 68 | saver = tf.train.Saver() 69 | saver.restore(self.sess, load_model_path) 70 | 71 | def run_predict(self): 72 | """ 73 | Run predictions and the process after finishing predictions. 74 | """ 75 | with self.sess.as_default(): 76 | self._input.before_read_setup() 77 | self._predict_step() 78 | for pred in self._config.predictions: 79 | pred.after_finish_predict() 80 | 81 | self.after_prediction() 82 | 83 | def _predict_step(self): 84 | """ Run predictions. Defined in subclass. 
85 | """ 86 | pass 87 | 88 | def after_prediction(self): 89 | self._after_prediction() 90 | 91 | def _after_prediction(self): 92 | pass 93 | 94 | -------------------------------------------------------------------------------- /tensorcv/predicts/config.py: -------------------------------------------------------------------------------- 1 | import scipy.misc 2 | import os 3 | import numpy as np 4 | 5 | from ..dataflow.base import DataFlow 6 | from ..models.base import ModelDes 7 | from ..utils.default import get_default_session_config 8 | from ..utils.sesscreate import NewSessionCreator 9 | from .predictions import PredictionBase 10 | from ..utils.common import check_dir 11 | 12 | __all__ = ['PridectConfig'] 13 | 14 | def assert_type(v, tp): 15 | assert isinstance(v, tp), \ 16 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!" 17 | 18 | class PridectConfig(object): 19 | def __init__(self, 20 | dataflow=None, model=None, 21 | model_dir=None, model_name='', 22 | restore_vars=None, 23 | session_creator=None, 24 | predictions=None, 25 | batch_size=1, 26 | default_dirs=None): 27 | """ 28 | Args: 29 | """ 30 | self.model_name = model_name 31 | try: 32 | self.model_dir = os.path.join(default_dirs.model_dir) 33 | check_dir(self.model_dir) 34 | except AttributeError: 35 | raise AttributeError('model_dir is not set!') 36 | 37 | try: 38 | self.result_dir = os.path.join(default_dirs.result_dir) 39 | check_dir(self.result_dir) 40 | except AttributeError: 41 | raise AttributeError('result_dir is not set!') 42 | 43 | if restore_vars is not None: 44 | if not isinstance(restore_vars, list): 45 | restore_vars = [restore_vars] 46 | self.restore_vars = restore_vars 47 | 48 | 49 | assert dataflow is not None, "dataflow cannot be None!" 50 | assert_type(dataflow, DataFlow) 51 | self.dataflow = dataflow 52 | 53 | assert batch_size > 0 54 | self.dataflow.set_batch_size(batch_size) 55 | self.batch_size = batch_size 56 | 57 | assert model is not None, "model cannot be None!" 
58 | assert_type(model, ModelDes) 59 | self.model = model 60 | 61 | assert predictions is not None, "predictions cannot be None" 62 | if not isinstance(predictions, list): 63 | predictions = [predictions] 64 | for pred in predictions: 65 | assert_type(pred, PredictionBase) 66 | self.predictions = predictions 67 | 68 | # if not isinstance(callbacks, list): 69 | # callbacks = [callbacks] 70 | # self._callbacks = callbacks 71 | 72 | if session_creator is None: 73 | self.session_creator = \ 74 | NewSessionCreator(config=get_default_session_config()) 75 | else: 76 | raise ValueError('custormer session creator is \ 77 | not allowed at this point!') 78 | 79 | @property 80 | def callbacks(self): 81 | return self._callbacks 82 | 83 | -------------------------------------------------------------------------------- /tensorcv/predicts/predictions.py: -------------------------------------------------------------------------------- 1 | # File: predictions.py 2 | # Author: Qian Ge 3 | 4 | import os 5 | import scipy.io 6 | import scipy.misc 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | from ..utils.common import get_tensors_by_names 12 | from ..utils.viz import * 13 | 14 | __all__ = ['PredictionImage', 'PredictionScalar', 'PredictionMat', 'PredictionMeanScalar', 'PredictionOverlay'] 15 | 16 | def assert_type(v, tp): 17 | assert isinstance(v, tp), \ 18 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!" 
19 | 20 | class PredictionBase(object): 21 | """ base class for prediction 22 | 23 | Attributes: 24 | _predictions 25 | _prefix_list 26 | _global_ind 27 | _save_dir 28 | """ 29 | def __init__(self, prediction_tensors, save_prefix): 30 | """ init prediction object 31 | 32 | Get tensors to be predicted and the prefix for saving 33 | each tensors 34 | 35 | Args: 36 | prediction_tensors : list[string] A tensor name or list of tensor names 37 | save_prefix: list[string] A string or list of strings 38 | Length of prediction_tensors and save_prefix have 39 | to be the same 40 | """ 41 | if not isinstance(prediction_tensors, list): 42 | prediction_tensors = [prediction_tensors] 43 | if not isinstance(save_prefix, list): 44 | save_prefix = [save_prefix] 45 | assert len(prediction_tensors) == len(save_prefix), \ 46 | 'Length of prediction_tensors {} and save_prefix {} has to be the same'.\ 47 | format(len(prediction_tensors), len(save_prefix)) 48 | 49 | self._predictions = prediction_tensors 50 | self._prefix_list = save_prefix 51 | self._global_ind = 0 52 | 53 | def setup(self, result_dir): 54 | assert os.path.isdir(result_dir) 55 | self._save_dir = result_dir 56 | 57 | self._predictions = get_tensors_by_names(self._predictions) 58 | 59 | def get_predictions(self): 60 | return self._predictions 61 | 62 | def after_prediction(self, results): 63 | """ process after predition 64 | default to save predictions 65 | """ 66 | self._save_prediction(results) 67 | 68 | def _save_prediction(self, results): 69 | pass 70 | 71 | def after_finish_predict(self): 72 | """ process after all prediction steps """ 73 | self._after_finish_predict() 74 | 75 | def _after_finish_predict(self): 76 | pass 77 | 78 | class PredictionImage(PredictionBase): 79 | """ Predict image output and save as files. 80 | 81 | Images are saved every batch. Each batch result can be 82 | save in one image or individule images. 
83 | 84 | """ 85 | def __init__(self, prediction_image_tensors, 86 | save_prefix, merge_im=False, 87 | tanh=False, color=False): 88 | """ 89 | Args: 90 | prediction_image_tensors (list): a list of tensor names 91 | save_prefix (list): a list of file prefix for saving 92 | each tensor in prediction_image_tensors 93 | merge_im (bool): merge output of one batch or not 94 | """ 95 | self._merge = merge_im 96 | self._tanh = tanh 97 | self._color = color 98 | super(PredictionImage, self).__init__(prediction_tensors=prediction_image_tensors, 99 | save_prefix=save_prefix) 100 | 101 | def _save_prediction(self, results): 102 | 103 | for re, prefix in zip(results, self._prefix_list): 104 | cur_global_ind = self._global_ind 105 | if self._merge and re.shape[0] > 1: 106 | grid_size = self._get_grid_size(re.shape[0]) 107 | save_path = os.path.join(self._save_dir, 108 | str(cur_global_ind) + '_' + prefix + '.png') 109 | save_merge_images(np.squeeze(re), 110 | [grid_size, grid_size], save_path, 111 | tanh=self._tanh, color=self._color) 112 | cur_global_ind += 1 113 | else: 114 | for im in re: 115 | save_path = os.path.join(self._save_dir, 116 | str(cur_global_ind) + '_' + prefix + '.png') 117 | if self._color: 118 | im = intensity_to_rgb(np.squeeze(im), normalize=True) 119 | scipy.misc.imsave(save_path, np.squeeze(im)) 120 | cur_global_ind += 1 121 | self._global_ind = cur_global_ind 122 | 123 | def _get_grid_size(self, batch_size): 124 | try: 125 | return self._grid_size 126 | except AttributeError: 127 | self._grid_size = np.ceil(batch_size**0.5).astype(int) 128 | return self._grid_size 129 | 130 | class PredictionOverlay(PredictionImage): 131 | def __init__(self, prediction_image_tensors, 132 | save_prefix, merge_im=False, 133 | tanh=False, color=False): 134 | if not isinstance(prediction_image_tensors, list): 135 | prediction_image_tensors = [prediction_image_tensors] 136 | assert len(prediction_image_tensors) == 2,\ 137 | '[PredictionOverlay] requires two image tensors but 
the input len = {}.'.\ 138 | format(len(prediction_image_tensors)) 139 | 140 | super(PredictionOverlay, self).__init__(prediction_image_tensors, 141 | save_prefix, merge_im=merge_im, 142 | tanh=tanh, color=color) 143 | 144 | self._overlay_prefix = '{}_{}'.format(self._prefix_list[0], self._prefix_list[1]) 145 | 146 | def _save_prediction(self, results): 147 | cur_global_ind = self._global_ind 148 | 149 | if self._merge and results[0].shape[0] > 1: 150 | overlay_im_list = [] 151 | for im_1, im_2 in zip(results[0], results[1]): 152 | overlay_im = image_overlay(im_1, im_2, color=self._color) 153 | overlay_im_list.append(overlay_im) 154 | 155 | grid_size = self._get_grid_size(results[0].shape[0]) 156 | save_path = os.path.join(self._save_dir, 157 | str(cur_global_ind) + '_' + self._overlay_prefix + '.png') 158 | save_merge_images(np.squeeze(overlay_im_list), 159 | [grid_size, grid_size], save_path, 160 | tanh=self._tanh, color=False) 161 | cur_global_ind += 1 162 | else: 163 | for im_1, im_2 in zip(results[0], results[1]): 164 | overlay_im = image_overlay(im_1, im_2, color=self._color) 165 | save_path = os.path.join(self._save_dir, 166 | str(cur_global_ind) + '_' + self._overlay_prefix + '.png') 167 | scipy.misc.imsave(save_path, np.squeeze(overlay_im)) 168 | cur_global_ind += 1 169 | self._global_ind = cur_global_ind 170 | 171 | class PredictionScalar(PredictionBase): 172 | def __init__(self, prediction_scalar_tensors, print_prefix): 173 | """ 174 | Args: 175 | prediction_scalar_tensors (list): a list of tensor names 176 | print_prefix (list): a list of name prefix for printing 177 | each tensor in prediction_scalar_tensors 178 | """ 179 | 180 | super(PredictionScalar, self).__init__(prediction_tensors=prediction_scalar_tensors, 181 | save_prefix=print_prefix) 182 | 183 | def _save_prediction(self, results): 184 | for re, prefix in zip(results, self._prefix_list): 185 | print('{} = {}'.format(prefix, re)) 186 | 187 | class PredictionMeanScalar(PredictionScalar): 188 | 
def __init__(self, prediction_scalar_tensors, print_prefix): 189 | 190 | super(PredictionMeanScalar, self).__init__(prediction_scalar_tensors=prediction_scalar_tensors, 191 | print_prefix=print_prefix) 192 | 193 | self.scalar_list = [[] for i in range(0, len(self._predictions))] 194 | 195 | def _save_prediction(self, results): 196 | cnt = 0 197 | for re, prefix in zip(results, self._prefix_list): 198 | print('{} = {}'.format(prefix, re)) 199 | self.scalar_list[cnt].append(re) 200 | cnt += 1 201 | 202 | def _after_finish_predict(self): 203 | for i, prefix in enumerate(self._prefix_list): 204 | print('Overall {} = {}'.format(prefix, np.mean(self.scalar_list[i]))) 205 | 206 | 207 | class PredictionMat(PredictionBase): 208 | def _save_prediction(self, results): 209 | save_path = os.path.join(self._save_dir, 210 | str(self._global_ind) + '_' + 'batch_test' + '.mat') 211 | scipy.io.savemat(save_path, {name: np.squeeze(val) for name, val 212 | in zip(self._prefix_list, results)}) 213 | 214 | self._global_ind += 1 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | -------------------------------------------------------------------------------- /tensorcv/predicts/simple.py: -------------------------------------------------------------------------------- 1 | # File: predictions.py 2 | # Author: Qian Ge 3 | 4 | import tensorflow as tf 5 | 6 | from .base import Predictor 7 | 8 | __all__ = ['SimpleFeedPredictor'] 9 | 10 | def assert_type(v, tp): 11 | assert isinstance(v, tp),\ 12 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!" 
13 | 14 | class SimpleFeedPredictor(Predictor): 15 | """ predictor with feed input """ 16 | # set_is_training 17 | def __init__(self, config): 18 | super(SimpleFeedPredictor, self).__init__(config) 19 | # TODO change len_input to other 20 | placeholders = self._model.get_prediction_placeholder() 21 | if not isinstance(placeholders, list): 22 | placeholders = [placeholders] 23 | self._plhs = placeholders 24 | # self.placeholder = self._model.get_random_vec_placeholder() 25 | # assert self.len_input <= len(self.placeholder) 26 | # self.placeholder = self.placeholder[0:self.len_input] 27 | 28 | def _predict_step(self): 29 | while self._input.epochs_completed < 1: 30 | try: 31 | cur_batch = self._input.next_batch() 32 | except AttributeError: 33 | cur_batch = self._input.next_batch() 34 | 35 | feed = dict(zip(self._plhs, cur_batch)) 36 | self.hooked_sess.run(fetches=[], feed_dict=feed) 37 | self._input.reset_epochs_completed(0) 38 | 39 | def _after_prediction(self): 40 | self._input.after_reading() 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /tensorcv/tfdataflow/__init__.py: -------------------------------------------------------------------------------- 1 | # File: __init__.py 2 | # Author: Qian Ge -------------------------------------------------------------------------------- /tensorcv/tfdataflow/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: base.py 4 | # Author: Qian Ge 5 | 6 | import tensorflow as tf 7 | from tensorflow.python.lib.io import file_io 8 | 9 | from ..dataflow.base import DataFlow 10 | from ..dataflow.normalization import identity 11 | from ..utils.utils import assert_type 12 | 13 | 14 | class DataFromTfrecord(DataFlow): 15 | def __init__(self, tfname, 16 | record_names, 17 | record_types, 18 | raw_types, 19 | decode_fncs, 20 | batch_dict_name, 21 | 
shuffle=True, 22 | data_shape=[], 23 | feature_len_list=None, 24 | pf=identity): 25 | 26 | if not isinstance(tfname, list): 27 | tfname = [tfname] 28 | if not isinstance(record_names, list): 29 | record_names = [record_names] 30 | if not isinstance(record_types, list): 31 | record_types = [record_types] 32 | for c_type in record_types: 33 | assert_type(c_type, tf.DType) 34 | if not isinstance(raw_types, list): 35 | raw_types = [raw_types] 36 | for raw_type in raw_types: 37 | assert_type(raw_type, tf.DType) 38 | if not isinstance(decode_fncs, list): 39 | decode_fncs = [decode_fncs] 40 | if not isinstance(batch_dict_name, list): 41 | batch_dict_name = [batch_dict_name] 42 | assert len(record_types) == len(record_names) 43 | assert len(record_types) == len(batch_dict_name) 44 | 45 | if feature_len_list is None: 46 | feature_len_list = [[] for i in range(0, len(record_names))] 47 | elif not isinstance(feature_len_list, list): 48 | feature_len_list = [feature_len_list] 49 | self._feat_len_list = feature_len_list 50 | if len(self._feat_len_list) < len(record_names): 51 | self._feat_len_list.extend([[] for i in range(0, len(record_names) - len(self._feat_len_list))]) 52 | 53 | self.record_names = record_names 54 | self.record_types = record_types 55 | self.raw_types = raw_types 56 | self.decode_fncs = decode_fncs 57 | self.data_shape = data_shape 58 | self._tfname = tfname 59 | self._batch_dict_name = batch_dict_name 60 | 61 | self._shuffle = shuffle 62 | 63 | # self._batch_step = 0 64 | # self.reset_epochs_completed(0) 65 | # self.set_batch_size(batch_size) 66 | self.setup_decode_data() 67 | self.setup(epoch_val=0, batch_size=1) 68 | 69 | 70 | def set_batch_size(self, batch_size): 71 | self._batch_size = batch_size 72 | self.updata_data_op(batch_size) 73 | self.updata_step_per_epoch(batch_size) 74 | 75 | def updata_data_op(self, batch_size): 76 | try: 77 | if self._shuffle is True: 78 | self._data = tf.train.shuffle_batch( 79 | self._decode_data, 80 | 
batch_size=batch_size, 81 | capacity=batch_size * 4, 82 | num_threads=2, 83 | min_after_dequeue=batch_size * 2) 84 | else: 85 | print('***** data is not shuffled *****') 86 | self._data = tf.train.batch( 87 | self._decode_data, 88 | batch_size=batch_size, 89 | capacity=batch_size, 90 | num_threads=1, 91 | allow_smaller_final_batch=False) 92 | # self._data = self._decode_data[0] 93 | # print(self._data) 94 | except AttributeError: 95 | pass 96 | 97 | def reset_epochs_completed(self, val): 98 | self._epochs_completed = val 99 | self._batch_step = 0 100 | 101 | # def _setup(self, **kwargs): 102 | def setup_decode_data(self): 103 | # n_epoch = kwargs['num_epoch'] 104 | 105 | feature = {} 106 | for record_name, r_type, cur_size in zip(self.record_names, self.record_types, self._feat_len_list): 107 | feature[record_name] = tf.FixedLenFeature(cur_size, r_type) 108 | # filename_queue = tf.train.string_input_producer(self._tfname, num_epochs=n_epoch) 109 | filename_queue = tf.train.string_input_producer(self._tfname) 110 | reader = tf.TFRecordReader() 111 | _, serialized_example = reader.read(filename_queue) 112 | features = tf.parse_single_example(serialized_example, features=feature) 113 | decode_data = [decode_fnc(features[record_name], raw_type) 114 | for decode_fnc, record_name, raw_type 115 | in zip(self.decode_fncs, self.record_names, self.raw_types)] 116 | 117 | for idx, c_shape in enumerate(self.data_shape): 118 | if c_shape: 119 | decode_data[idx] = tf.reshape(decode_data[idx], c_shape) 120 | 121 | self._decode_data = decode_data 122 | 123 | # self._data = self._decode_data[0] 124 | try: 125 | self.set_batch_size(batch_size=self._batch_size) 126 | except AttributeError: 127 | self.set_batch_size(batch_size=1) 128 | 129 | def updata_step_per_epoch(self, batch_size): 130 | self._step_per_epoch = int(self.size() / batch_size) 131 | 132 | def before_read_setup(self): 133 | self.coord = tf.train.Coordinator() 134 | self.threads = 
tf.train.start_queue_runners(coord=self.coord) 135 | 136 | def next_batch(self): 137 | sess = tf.get_default_session() 138 | batch_data = sess.run(self._data) 139 | self._batch_step += 1 140 | if self._batch_step % self._step_per_epoch == 0: 141 | self._epochs_completed += 1 142 | # print(batch_data[2]) 143 | return batch_data 144 | 145 | def next_batch_dict(self): 146 | sess = tf.get_default_session() 147 | batch_data = sess.run(self._data) 148 | self._batch_step += 1 149 | if self._batch_step % self._step_per_epoch == 0: 150 | self._epochs_completed += 1 151 | batch_dict = {name: data for name, data in zip(self._batch_dict_name, batch_data)} 152 | return batch_dict 153 | 154 | def after_reading(self): 155 | self.coord.request_stop() 156 | self.coord.join(self.threads) 157 | 158 | def size(self): 159 | try: 160 | return self._size 161 | except AttributeError: 162 | self._size = sum(1 for f in self._tfname for _ in tf.python_io.tf_record_iterator(f)) 163 | return self._size 164 | -------------------------------------------------------------------------------- /tensorcv/tfdataflow/convert.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: convert.py 4 | # Author: Qian Ge 5 | 6 | import tensorflow as tf 7 | 8 | from ..utils.utils import assert_type 9 | from ..dataflow.base import DataFlow 10 | 11 | 12 | def int64_feature(value): 13 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 14 | 15 | 16 | def bytes_feature(value): 17 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 18 | 19 | 20 | def convert_image(image): 21 | return bytes_feature(tf.compat.as_bytes(image.tostring())) 22 | 23 | 24 | def float_feature(value): 25 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 26 | 27 | 28 | def dataflow2tfrecord(dataflow, tfname, record_names, c_fncs): 29 | assert_type(dataflow, DataFlow) 30 | 
dataflow.setup(epoch_val=0, batch_size=1) 31 | 32 | if not isinstance(record_names, list): 33 | record_names = [record_names] 34 | if not isinstance(c_fncs, list): 35 | c_fncs = [c_fncs] 36 | assert len(c_fncs) == len(record_names) 37 | 38 | tfrecords_filename = tfname 39 | writer = tf.python_io.TFRecordWriter(tfrecords_filename) 40 | 41 | while dataflow.epochs_completed < 1: 42 | batch_data = dataflow.next_batch() 43 | feature = {} 44 | for record_name, convert_fnc, data in\ 45 | zip(record_names, c_fncs, batch_data): 46 | feature[record_name] = convert_fnc(data[0]) 47 | 48 | example = tf.train.Example(features=tf.train.Features(feature=feature)) 49 | writer.write(example.SerializeToString()) 50 | 51 | writer.close() 52 | -------------------------------------------------------------------------------- /tensorcv/tfdataflow/write.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: write.py 4 | # Author: Qian Ge 5 | 6 | import tensorflow as tf 7 | 8 | from ..dataflow.base import DataFlow 9 | from ..models.base import BaseModel 10 | from ..utils.utils import assert_type 11 | from .convert import float_feature 12 | 13 | 14 | class Bottleneck2TFrecord(object): 15 | def __init__(self, nets, record_feat_names, 16 | feat_preprocess=tf.identity): 17 | if not isinstance(nets, list): 18 | nets = [nets] 19 | for net in nets: 20 | assert_type(net, BaseModel) 21 | if not isinstance(record_feat_names, list): 22 | record_feat_names = [record_feat_names] 23 | assert len(nets) == len(record_feat_names) 24 | self._w_f_names = record_feat_names 25 | 26 | self._feat_ops = [] 27 | self._feed_plh_keys = [] 28 | self._net_input_dicts = [] 29 | for net in nets: 30 | net.set_is_training(False) 31 | net.create_graph() 32 | self._net_input_dicts.append(net.input_dict) 33 | self._feed_plh_keys.append(net.prediction_plh_dict) 34 | self._feat_ops.append(feat_preprocess(net.layer['conv_out'])) 35 | 
36 | def write(self, tfname, dataflow, 37 | record_dataflow_keys=[], record_dataflow_names=[], c_fncs=[]): 38 | assert_type(dataflow, DataFlow) 39 | dataflow.setup(epoch_val=0, batch_size=1) 40 | 41 | if not isinstance(record_dataflow_names, list): 42 | record_dataflow_names = [record_dataflow_names] 43 | if not isinstance(c_fncs, list): 44 | c_fncs = [c_fncs] 45 | if not isinstance(record_dataflow_keys, list): 46 | record_dataflow_keys = [record_dataflow_keys] 47 | assert len(c_fncs) == len(record_dataflow_names) 48 | assert len(record_dataflow_keys) == len(record_dataflow_names) 49 | 50 | tfrecords_filename = tfname 51 | writer = tf.python_io.TFRecordWriter(tfrecords_filename) 52 | 53 | with tf.Session() as sess: 54 | sess.run(tf.local_variables_initializer()) 55 | sess.run(tf.global_variables_initializer()) 56 | dataflow.before_read_setup() 57 | cnt = 0 58 | while dataflow.epochs_completed < 1: 59 | print('Writing data {}...'.format(cnt)) 60 | batch_data = dataflow.next_batch_dict() 61 | 62 | feats = [] 63 | for feat_op, feed_plh_key, net_input_dict in zip(self._feat_ops, self._feed_plh_keys, self._net_input_dicts): 64 | feed_dict = {net_input_dict[key]: batch_data[key] 65 | for key in feed_plh_key} 66 | feats.append(sess.run(feat_op, feed_dict=feed_dict)) 67 | 68 | # feature = {} 69 | # for record_name, convert_fnc, key in zip(record_dataflow_names, c_fncs, record_dataflow_keys): 70 | # feature[record_name] = convert_fnc(batch_data[key][0]) 71 | 72 | 73 | # for record_name, feat in zip(self._w_f_names, feats): 74 | # feature[record_name] = float_feature(feat.reshape(-1).tolist()) 75 | 76 | # feature_list = [] 77 | batch_size = len(feats[0]) 78 | for idx in range(0, batch_size): 79 | feature = {} 80 | for record_name, convert_fnc, key in zip(record_dataflow_names, c_fncs, record_dataflow_keys): 81 | feature[record_name] = convert_fnc( 82 | batch_data[key][idx]) 83 | 84 | for record_name, feat in zip(self._w_f_names, feats): 85 | feature[record_name] =\ 86 | 
float_feature(feat[idx].reshape(-1).tolist()) 87 | 88 | example = tf.train.Example( 89 | features=tf.train.Features(feature=feature)) 90 | writer.write(example.SerializeToString()) 91 | 92 | cnt += 1 93 | dataflow.after_reading() 94 | 95 | writer.flush() 96 | writer.close() 97 | -------------------------------------------------------------------------------- /tensorcv/train/__init__.py: -------------------------------------------------------------------------------- 1 | # File: __init__.py 2 | # Author: Qian Ge 3 | 4 | -------------------------------------------------------------------------------- /tensorcv/train/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import weakref 3 | import os 4 | 5 | import tensorflow as tf 6 | 7 | from .config import TrainConfig 8 | from ..callbacks.base import Callback 9 | from ..callbacks.group import Callbacks 10 | from ..utils.sesscreate import ReuseSessionCreator 11 | from ..callbacks.monitors import TrainingMonitor, Monitors 12 | 13 | 14 | __all__ = ['Trainer'] 15 | 16 | def assert_type(v, tp): 17 | assert isinstance(v, tp),\ 18 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!" 
19 | 20 | class Trainer(object): 21 | """ base class for trainer """ 22 | def __init__(self, config): 23 | assert_type(config, TrainConfig) 24 | self._is_load = config.is_load 25 | self.config = config 26 | self.model = config.model 27 | self.model.ex_init_model(config.dataflow, weakref.proxy(self)) 28 | self.dataflow = config.dataflow 29 | # self.monitors = self.config.monitors 30 | self._global_step = 0 31 | self._callbacks = [] 32 | self.monitors = [] 33 | 34 | self.default_dirs = config.default_dirs 35 | 36 | @property 37 | def epochs_completed(self): 38 | return self.dataflow.epochs_completed 39 | 40 | @property 41 | def get_global_step(self): 42 | return self._global_step 43 | 44 | def register_callback(self, cb): 45 | assert_type(cb, Callback) 46 | assert not isinstance(self._callbacks, Callbacks), \ 47 | "callbacks have been setup" 48 | self._callbacks.append(cb) 49 | 50 | def register_monitor(self, monitor): 51 | assert_type(monitor, TrainingMonitor) 52 | assert not isinstance(self.monitors, Monitors), \ 53 | "monitors have been setup" 54 | self.monitors.append(monitor) 55 | self.register_callback(monitor) 56 | 57 | 58 | def _create_session(self): 59 | hooks = self._callbacks.get_hooks() 60 | self.sess = self.config.session_creator.create_session() 61 | 62 | self.hooked_sess = tf.train.MonitoredSession( 63 | session_creator=ReuseSessionCreator(self.sess), hooks=hooks) 64 | 65 | if self._is_load: 66 | load_model_path = os.path.join(self.config.model_dir, 67 | self.config.model_name) 68 | saver = tf.train.Saver() 69 | saver.restore(self.sess, load_model_path) 70 | 71 | def main_loop(self): 72 | with self.sess.as_default(): 73 | self._callbacks.before_train() 74 | while self.epochs_completed <= self.config.max_epoch: 75 | self._global_step += 1 76 | print('Epoch: {}. 
Step: {}'.\ 77 | format(self.epochs_completed, self._global_step)) 78 | # self._callbacks.before_epoch() 79 | # TODO to be modified 80 | self.model.set_is_training(True) 81 | self._run_step() 82 | # self._callbacks.after_epoch() 83 | self._callbacks.trigger_step() 84 | self._callbacks.after_train() 85 | 86 | def train(self): 87 | self.setup() 88 | self.main_loop() 89 | 90 | @abstractmethod 91 | def _run_step(self): 92 | model_feed = self.model.get_graph_feed() 93 | self.hooked_sess.run(self.train_op, feed_dict=model_feed) 94 | 95 | def setup(self): 96 | # setup graph from model 97 | self.setup_graph() 98 | 99 | # setup callbacks 100 | for cb in self.config.callbacks: 101 | self.register_callback(cb) 102 | for monitor in self.config.monitors: 103 | self.register_monitor(monitor) 104 | self._callbacks = Callbacks(self._callbacks) 105 | self._callbacks.setup_graph(weakref.proxy(self)) 106 | self.monitors = Monitors(self.monitors) 107 | # create session 108 | self._create_session() 109 | 110 | 111 | 112 | self.sess.graph.finalize() 113 | 114 | def setup_graph(self): 115 | self.model.create_graph() 116 | self._setup() 117 | self.model.setup_summary() 118 | 119 | def _setup(self): 120 | pass 121 | 122 | 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /tensorcv/train/config.py: -------------------------------------------------------------------------------- 1 | import scipy.misc 2 | import os 3 | import numpy as np 4 | 5 | from ..dataflow.base import DataFlow 6 | from ..models.base import ModelDes, GANBaseModel 7 | from ..utils.default import get_default_session_config 8 | from ..utils.sesscreate import NewSessionCreator 9 | from ..callbacks.monitors import TFSummaryWriter 10 | from ..callbacks.summary import TrainSummary 11 | from ..utils.common import check_dir 12 | 13 | __all__ = ['TrainConfig', 'GANTrainConfig'] 14 | 15 | def assert_type(v, tp): 16 | assert isinstance(v, tp),\ 17 | "Expect " + str(tp) + ", 
but " + str(v.__class__) + " is given!" 18 | 19 | class TrainConfig(object): 20 | def __init__(self, 21 | dataflow=None, model=None, 22 | callbacks=[], 23 | session_creator=None, 24 | monitors=None, 25 | batch_size=1, max_epoch=100, 26 | summary_periodic=None, 27 | is_load=False, 28 | model_name=None, 29 | default_dirs=None): 30 | self.default_dirs = default_dirs 31 | 32 | assert_type(monitors, TFSummaryWriter), \ 33 | "monitors has to be TFSummaryWriter at this point!" 34 | if not isinstance(monitors, list): 35 | monitors = [monitors] 36 | self.monitors = monitors 37 | 38 | assert dataflow is not None, "dataflow cannot be None!" 39 | assert_type(dataflow, DataFlow) 40 | self.dataflow = dataflow 41 | 42 | assert model is not None, "model cannot be None!" 43 | assert_type(model, ModelDes) 44 | self.model = model 45 | 46 | assert batch_size > 0 and max_epoch > 0 47 | self.dataflow.set_batch_size(batch_size) 48 | self.model.set_batch_size(batch_size) 49 | self.batch_size = batch_size 50 | self.max_epoch = max_epoch 51 | 52 | self.is_load = is_load 53 | if is_load: 54 | assert not model_name is None,\ 55 | '[TrainConfig]: model_name cannot be None when is_load is True!' 
56 | self.model_name = model_name 57 | try: 58 | self.model_dir = os.path.join(default_dirs.model_dir) 59 | check_dir(self.model_dir) 60 | except AttributeError: 61 | raise AttributeError('model_dir is not set!') 62 | 63 | # if callbacks is None: 64 | # callbacks = [] 65 | if not isinstance(callbacks, list): 66 | callbacks = [callbacks] 67 | self._callbacks = callbacks 68 | 69 | # TODO model.default_collection only in BaseModel class 70 | if isinstance(summary_periodic, int): 71 | self._callbacks.append( 72 | TrainSummary(key=model.default_collection, 73 | periodic=summary_periodic)) 74 | 75 | if session_creator is None: 76 | self.session_creator = \ 77 | NewSessionCreator(config=get_default_session_config()) 78 | else: 79 | raise ValueError('custormer session creator is not allowed at this point!') 80 | 81 | @property 82 | def callbacks(self): 83 | return self._callbacks 84 | 85 | 86 | class GANTrainConfig(TrainConfig): 87 | def __init__(self, 88 | dataflow=None, model=None, 89 | discriminator_callbacks=[], 90 | generator_callbacks=[], 91 | session_creator=None, 92 | monitors=None, 93 | batch_size=1, max_epoch=100, 94 | summary_d_periodic=None, 95 | summary_g_periodic=None, 96 | default_dirs=None): 97 | 98 | assert_type(model, GANBaseModel) 99 | 100 | if not isinstance(discriminator_callbacks, list): 101 | discriminator_callbacks = [discriminator_callbacks] 102 | self._dis_callbacks = discriminator_callbacks 103 | 104 | if not isinstance(generator_callbacks, list): 105 | generator_callbacks = [generator_callbacks] 106 | self._gen_callbacks = generator_callbacks 107 | 108 | if isinstance(summary_d_periodic, int): 109 | self._dis_callbacks.append( 110 | TrainSummary(key=model.d_collection, 111 | periodic=summary_d_periodic)) 112 | if isinstance(summary_g_periodic, int): 113 | self._dis_callbacks.append( 114 | TrainSummary(key=model.g_collection, 115 | periodic=summary_g_periodic)) 116 | 117 | callbacks = self._dis_callbacks + self._gen_callbacks 118 | 119 | 
super(GANTrainConfig, self).__init__( 120 | dataflow=dataflow, model=model, 121 | callbacks=callbacks, 122 | session_creator=session_creator, 123 | monitors=monitors, 124 | batch_size=batch_size, max_epoch=ßmax_epoch, 125 | default_dirs=default_dirs) 126 | @property 127 | def dis_callbacks(self): 128 | return self._dis_callbacks 129 | @property 130 | def gen_callbacks(self): 131 | return self._gen_callbacks 132 | 133 | -------------------------------------------------------------------------------- /tensorcv/train/simple.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import tensorflow as tf 4 | 5 | from .config import TrainConfig, GANTrainConfig 6 | from .base import Trainer 7 | from ..callbacks.inputs import FeedInput 8 | from ..callbacks.group import Callbacks 9 | from ..callbacks.hooks import Callback2Hook 10 | from ..models.base import BaseModel, GANBaseModel 11 | from ..utils.sesscreate import ReuseSessionCreator 12 | 13 | 14 | __all__ = ['SimpleFeedTrainer'] 15 | 16 | def assert_type(v, tp): 17 | assert isinstance(v, tp),\ 18 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!" 
19 | 20 | class SimpleFeedTrainer(Trainer): 21 | """ single optimizer """ 22 | def __init__(self, config): 23 | assert_type(config.model, BaseModel) 24 | super(SimpleFeedTrainer, self).__init__(config) 25 | 26 | def _setup(self): 27 | # TODO to be modified 28 | cbs = FeedInput(self.dataflow, self.model.get_train_placeholder()) 29 | 30 | self.config.callbacks.append(cbs) 31 | 32 | grads = self.model.get_grads() 33 | opt = self.model.get_optimizer() 34 | self.train_op = opt.apply_gradients(grads, name='train') 35 | 36 | class GANFeedTrainer(Trainer): 37 | def __init__(self, config): 38 | assert_type(config, GANTrainConfig) 39 | # assert_type(config.model, GANBaseModel) 40 | 41 | # config.model.set_batch_size(config.batch_size) 42 | 43 | super(GANFeedTrainer, self).__init__(config) 44 | 45 | def _setup(self): 46 | # TODO to be modified 47 | # Since FeedInput only have before_run, 48 | # it is safe to put this cb only in hooks. 49 | cbs = FeedInput(self.dataflow, self.model.get_train_placeholder()) 50 | # self.config.callbacks.append(cbs) 51 | self.feed_input_hook = [Callback2Hook(cbs)] 52 | 53 | dis_grads = self.model.get_discriminator_grads() 54 | dis_opt = self.model.get_discriminator_optimizer() 55 | self.dis_train_op = dis_opt.apply_gradients(dis_grads, 56 | name='discriminator_train') 57 | 58 | gen_grads = self.model.get_generator_grads() 59 | gen_opt = self.model.get_generator_optimizer() 60 | self.gen_train_op = gen_opt.apply_gradients(gen_grads, 61 | name='generator_train') 62 | 63 | def _create_session(self): 64 | self._dis_callbacks = Callbacks([cb 65 | for cb in self.config.dis_callbacks]) 66 | self._gen_callbacks = Callbacks([cb 67 | for cb in self.config.gen_callbacks]) 68 | dis_hooks = self._dis_callbacks.get_hooks() 69 | gen_hooks = self._gen_callbacks.get_hooks() 70 | 71 | self.sess = self.config.session_creator.create_session() 72 | self.dis_hooked_sess = tf.train.MonitoredSession( 73 | session_creator=ReuseSessionCreator(self.sess), 74 | 
hooks=dis_hooks + self.feed_input_hook) 75 | self.gen_hooked_sess = tf.train.MonitoredSession( 76 | session_creator=ReuseSessionCreator(self.sess), 77 | hooks=gen_hooks) 78 | 79 | def _run_step(self): 80 | model_feed = self.model.get_graph_feed() 81 | self.dis_hooked_sess.run(self.dis_train_op, feed_dict=model_feed) 82 | 83 | for k in range(0,2): 84 | model_feed = self.model.get_graph_feed() 85 | self.gen_hooked_sess.run(self.gen_train_op, feed_dict=ßmodel_feed) 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /tensorcv/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # File: __init__.py 2 | # Author: Qian Ge 3 | -------------------------------------------------------------------------------- /tensorcv/utils/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: common.py 4 | # Author: Qian Ge 5 | 6 | import math 7 | import os 8 | 9 | import tensorflow as tf 10 | 11 | __all__ = ['apply_mask', 'apply_mask_inverse', 'get_tensors_by_names', 12 | 'deconv_size', 'match_tensor_save_name'] 13 | 14 | 15 | def apply_mask(input_matrix, mask): 16 | """Get partition of input_matrix using index 1 in mask. 17 | 18 | Args: 19 | input_matrix (Tensor): A Tensor 20 | mask (int): A Tensor of type int32 with indices in {0, 1}. Shape 21 | has to be the same as input_matrix. 22 | 23 | Return: 24 | A Tensor with elements from data with entries in mask equal to 1. 25 | """ 26 | return tf.dynamic_partition(input_matrix, mask, 2)[1] 27 | 28 | 29 | def apply_mask_inverse(input_matrix, mask): 30 | """Get partition of input_matrix using index 0 in mask. 31 | 32 | Args: 33 | input_matrix (Tensor): A Tensor 34 | mask (int): A Tensor of type int32 with indices in {0, 1}. Shape 35 | has to be the same as input_matrix. 
36 | 37 | Return: 38 | A Tensor with elements from data with entries in mask equal to 0. 39 | """ 40 | return tf.dynamic_partition(input_matrix, mask, 2)[0] 41 | 42 | 43 | def get_tensors_by_names(names): 44 | """Get a list of tensors by the input name list. 45 | 46 | Args: 47 | names (str): A str or a list of str 48 | 49 | Return: 50 | A list of tensors with name in input names. 51 | 52 | Warning: 53 | If more than one tensor have the same name in the graph. This function 54 | will only return the tensor with name NAME:0. 55 | """ 56 | if not isinstance(names, list): 57 | names = [names] 58 | 59 | graph = tf.get_default_graph() 60 | tensor_list = [] 61 | # TODO assume there is no repeativie names 62 | for name in names: 63 | tensor_name = name + ':0' 64 | tensor_list += graph.get_tensor_by_name(tensor_name), 65 | return tensor_list 66 | 67 | 68 | def deconv_size(input_height, input_width, stride=2): 69 | """ 70 | Compute the feature size (height and width) after filtering with 71 | a specific stride. Mostly used for setting the shape for deconvolution. 72 | 73 | Args: 74 | input_height (int): height of input feature 75 | input_width (int): width of input feature 76 | stride (int): stride of the filter 77 | 78 | Return: 79 | (int, int): Height and width of feature after filtering. 80 | """ 81 | print('***** WARNING ********: deconv_size is moved to models.utils.py') 82 | return int(math.ceil(float(input_height) / float(stride))),\ 83 | int(math.ceil(float(input_width) / float(stride))) 84 | 85 | 86 | def match_tensor_save_name(tensor_names, save_names): 87 | """ 88 | Match tensor_names and corresponding save_names for saving the results of 89 | the tenors. If the number of tensors is less or equal to the length 90 | of save names, tensors will be saved using the corresponding names in 91 | save_names. Otherwise, tensors will be saved using their own names. 92 | Used for prediction or inference. 
93 | 94 | Args: 95 | tensor_names (str): List of tensor names 96 | save_names (str): List of names for saving tensors 97 | 98 | Return: 99 | (list, list): List of tensor names and list of names to save 100 | the tensors. 101 | """ 102 | if not isinstance(tensor_names, list): 103 | tensor_names = [tensor_names] 104 | if save_names is None: 105 | return tensor_names, tensor_names 106 | elif not isinstance(save_names, list): 107 | save_names = [save_names] 108 | if len(save_names) < len(tensor_names): 109 | return tensor_names, tensor_names 110 | else: 111 | return tensor_names, save_names 112 | 113 | 114 | def check_dir(input_dir): 115 | print('***** WARNING ********: check_dir is moved to utils.utils.py') 116 | assert input_dir is not None, "dir cannot be None!" 117 | assert os.path.isdir(input_dir), input_dir + ' does not exist!' 118 | 119 | 120 | def assert_type(v, tp): 121 | print('***** WARNING ********: assert_type is moved to utils.utils.py') 122 | """ 123 | Assert type of input v be type tp 124 | """ 125 | assert isinstance(v, tp),\ 126 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!" 127 | -------------------------------------------------------------------------------- /tensorcv/utils/debug.py: -------------------------------------------------------------------------------- 1 | def hello_cv(): 2 | print('hello') 3 | -------------------------------------------------------------------------------- /tensorcv/utils/default.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: default.py 4 | # Author: Qian Ge 5 | 6 | import tensorflow as tf 7 | 8 | __all__ = ['get_default_session_config'] 9 | 10 | 11 | def get_default_session_config(memory_fraction=1): 12 | """Default config of a TensorFlow session 13 | 14 | Args: 15 | memory_fraction (float): Memory fraction of GPU for this session 16 | 17 | Return: 18 | tf.ConfigProto(): Config of session. 
19 | """ 20 | conf = tf.ConfigProto() 21 | conf.gpu_options.per_process_gpu_memory_fraction = memory_fraction 22 | conf.gpu_options.allow_growth = True 23 | 24 | return conf 25 | -------------------------------------------------------------------------------- /tensorcv/utils/sesscreate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: default.py 4 | # Author: Qian Ge 5 | # Modified from https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/tfutils/sesscreate.py 6 | 7 | import tensorflow as tf 8 | 9 | from .default import get_default_session_config 10 | 11 | __all__ = ['NewSessionCreator', 'ReuseSessionCreator'] 12 | 13 | 14 | class NewSessionCreator(tf.train.SessionCreator): 15 | """ 16 | tf.train.SessionCreator for a new session 17 | """ 18 | def __init__(self, target='', graph=None, config=None): 19 | """ Inits NewSessionCreator with targe, graph and config. 20 | 21 | Args: 22 | target: same as :meth:`tf.Session.__init__()`. 23 | graph: same as :meth:`tf.Session.__init__()`. 24 | config: same as :meth:`tf.Session.__init__()`. Default to 25 | :func:`utils.default.get_default_session_config()`. 26 | """ 27 | self.target = target 28 | if config is not None: 29 | self.config = config 30 | else: 31 | self.config = get_default_session_config() 32 | self.graph = graph 33 | 34 | def create_session(self): 35 | """Create session as well as initialize global and local variables 36 | 37 | Return: 38 | A tf.Session object containing nodes for all of the 39 | operations in the underlying TensorFlow graph. 
40 | """ 41 | sess = tf.Session(target=self.target, 42 | graph=self.graph, config=self.config) 43 | sess.run(tf.global_variables_initializer()) 44 | sess.run(tf.local_variables_initializer()) 45 | return sess 46 | 47 | 48 | class ReuseSessionCreator(tf.train.SessionCreator): 49 | """ 50 | tf.train.SessionCreator for reuse an existed session 51 | """ 52 | def __init__(self, sess): 53 | """ Inits ReuseSessionCreator with an existed session. 54 | 55 | Args: 56 | sess (tf.Session): an existed tf.Session object 57 | """ 58 | self.sess = sess 59 | 60 | def create_session(self): 61 | """Create session by reusing an existing session 62 | 63 | Return: 64 | A reused tf.Session object containing nodes for all of the 65 | operations in the underlying TensorFlow graph. 66 | """ 67 | return self.sess 68 | -------------------------------------------------------------------------------- /tensorcv/utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: utils.py 4 | # Author: Qian Ge 5 | 6 | import os 7 | from datetime import datetime 8 | import numpy as np 9 | 10 | __all__ = ['get_rng'] 11 | 12 | _RNG_SEED = None 13 | 14 | 15 | def get_rng(obj=None): 16 | """ 17 | This function is copied from `tensorpack 18 | `__. 19 | Get a good RNG seeded with time, pid and the object. 20 | Args: 21 | obj: some object to use to generate random seed. 22 | Returns: 23 | np.random.RandomState: the RNG. 24 | """ 25 | seed = (id(obj) + os.getpid() + 26 | int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295 27 | if _RNG_SEED is not None: 28 | seed = _RNG_SEED 29 | return np.random.RandomState(seed) 30 | 31 | 32 | def check_dir(input_dir): 33 | assert input_dir is not None, "dir cannot be None!" 34 | assert os.path.isdir(input_dir), input_dir + ' does not exist!' 
35 | 36 | 37 | def assert_type(v, tp): 38 | """ 39 | Assert type of input v be type tp 40 | """ 41 | assert isinstance(v, tp),\ 42 | "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!" 43 | -------------------------------------------------------------------------------- /tensorcv/utils/viz.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # File: viz.py 3 | # Author: Qian Ge 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import scipy.misc 8 | 9 | 10 | def intensity_to_rgb(intensity, cmap='jet', normalize=False): 11 | """ 12 | This function is copied from `tensorpack 13 | `__. 14 | Convert a 1-channel matrix of intensities to an RGB image employing 15 | a colormap. 16 | This function requires matplotlib. See `matplotlib colormaps 17 | `_ for a 18 | list of available colormap. 19 | 20 | Args: 21 | intensity (np.ndarray): array of intensities such as saliency. 22 | cmap (str): name of the colormap to use. 23 | normalize (bool): if True, will normalize the intensity so that it has 24 | minimum 0 and maximum 1. 25 | 26 | Returns: 27 | np.ndarray: an RGB float32 image in range [0, 255], a colored heatmap. 28 | """ 29 | 30 | # assert intensity.ndim == 2, intensity.shape 31 | intensity = intensity.astype("float") 32 | 33 | if normalize: 34 | intensity -= intensity.min() 35 | intensity /= intensity.max() 36 | 37 | if intensity.ndim == 3: 38 | return intensity.astype('float32') * 255.0 39 | 40 | cmap = plt.get_cmap(cmap) 41 | intensity = cmap(intensity)[..., :3] 42 | return intensity.astype('float32') * 255.0 43 | 44 | 45 | def save_merge_images(images, merge_grid, save_path, color=False, tanh=False): 46 | """Save multiple images with same size into one larger image. 
47 | 48 | The best size number is 49 | int(max(sqrt(image.shape[0]),sqrt(image.shape[1]))) + 1 50 | 51 | Args: 52 | images (np.ndarray): A batch of image array to be merged with size 53 | [BATCH_SIZE, HEIGHT, WIDTH, CHANNEL]. 54 | merge_grid (list): List of length 2. The grid size for merge images. 55 | save_path (str): Path for saving the merged image. 56 | color (bool): Whether convert intensity image to color image. 57 | tanh (bool): If True, will normalize the image in range [-1, 1] 58 | to [0, 1] (for GAN models). 59 | 60 | Example: 61 | The batch_size is 64, then the size is recommended [8, 8]. 62 | The batch_size is 32, then the size is recommended [6, 6]. 63 | """ 64 | 65 | # normalization of tanh output 66 | img = images 67 | 68 | if tanh: 69 | img = (img + 1.0) / 2.0 70 | 71 | if color: 72 | # TODO 73 | img_list = [] 74 | for im in np.squeeze(img): 75 | im = intensity_to_rgb(np.squeeze(im), normalize=True) 76 | img_list.append(im) 77 | img = np.array(img_list) 78 | # img = np.expand_dims(img, 0) 79 | 80 | if len(img.shape) == 2 or (len(img.shape) == 3 and img.shape[2] <= 4): 81 | img = np.expand_dims(img, 0) 82 | # img = images 83 | h, w = img.shape[1], img.shape[2] 84 | merge_img = np.zeros((h * merge_grid[0], w * merge_grid[1], 3)) 85 | if len(img.shape) < 4: 86 | img = np.expand_dims(img, -1) 87 | 88 | for idx, image in enumerate(img): 89 | i = idx % merge_grid[1] 90 | j = idx // merge_grid[1] 91 | merge_img[j*h:j*h+h, i*w:i*w+w, :] = image 92 | 93 | scipy.misc.imsave(save_path, merge_img) 94 | 95 | 96 | def image_overlay(im_1, im_2, color=True, normalize=True): 97 | """Overlay two images with the same size. 98 | 99 | Args: 100 | im_1 (np.ndarray): image arrary 101 | im_2 (np.ndarray): image arrary 102 | color (bool): Whether convert intensity image to color image. 103 | normalize (bool): If both color and normalize are True, will 104 | normalize the intensity so that it has minimum 0 and maximum 1. 
105 | 106 | Returns: 107 | np.ndarray: an overlay image of im_1*0.5 + im_2*0.5 108 | """ 109 | if color: 110 | im_1 = intensity_to_rgb(np.squeeze(im_1), normalize=normalize) 111 | im_2 = intensity_to_rgb(np.squeeze(im_2), normalize=normalize) 112 | 113 | return im_1*0.5 + im_2*0.5 114 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorcv.utils.debug as debug 3 | 4 | 5 | debug.hello_cv() 6 | print(tf.__version__) 7 | 8 | -------------------------------------------------------------------------------- /test/VGG_pre_trained.py: -------------------------------------------------------------------------------- 1 | # File: VGG.py 2 | # Author: Qian Ge 3 | 4 | import argparse 5 | import os 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | import scipy 10 | import cv2 11 | 12 | from tensorcv.dataflow.image import * 13 | from tensorcv.callbacks import * 14 | from tensorcv.predicts import * 15 | from tensorcv.train.config import TrainConfig 16 | from tensorcv.train.simple import SimpleFeedTrainer 17 | 18 | import VGG 19 | import config 20 | 21 | VGG_MEAN = [103.939, 116.779, 123.68] 22 | 23 | # def load_VGG_model(session, model_path, skip_layer = []): 24 | # weights_dict = np.load(model_path, encoding='latin1').item() 25 | # for layer_name in weights_dict: 26 | # print(layer_name) 27 | # if layer_name not in skip_layer: 28 | # with tf.variable_scope(layer_name, reuse = True): 29 | # for data in weights_dict[layer_name]: 30 | # if len(data.shape) == 1: 31 | # var = tf.get_variable('biases', trainable = False) 32 | # session.run(var.assign(data)) 33 | # else: 34 | # var = tf.get_variable('weights', trainable = False) 35 | # session.run(var.assign(data)) 36 | 37 | def get_args(): 38 | parser = argparse.ArgumentParser() 39 | 40 | # parser.add_argument('--input_channel', default = 1, 41 | # help = 'Number 
of image channels') 42 | # parser.add_argument('--num_class', default = 2, 43 | # help = 'Number of classes') 44 | parser.add_argument('--batch_size', default = 128) 45 | 46 | parser.add_argument('--predict', help = 'Run prediction', action='store_true') 47 | parser.add_argument('--train', help = 'Train the model', action='store_true') 48 | 49 | return parser.parse_args() 50 | 51 | if __name__ == '__main__': 52 | Model = VGG.VGG19(num_class = 1000, 53 | num_channels = 3, 54 | im_height = 224, 55 | im_width = 224) 56 | 57 | keep_prob = tf.placeholder(tf.float32, name='keep_prob') 58 | image = tf.placeholder(tf.float32, name = 'image', 59 | shape = [None, 64, 64, 3]) 60 | input_im = tf.image.resize_images(image, [224, 224]) 61 | 62 | Model.create_model([input_im, keep_prob]) 63 | predict_op = tf.argmax(Model.output, dimension = -1) 64 | 65 | dataset_val = ImageLabelFromFile('.JPEG', data_dir = config.valid_data_dir, 66 | label_file_name = 'val_annotations.txt', 67 | num_channel = 3, 68 | label_dict = {}, 69 | shuffle = False) 70 | 71 | # dataset_val = ImageData('.JPEG', data_dir = config.valid_data_dir, 72 | # shuffle = False) 73 | 74 | dataset_val.setup(epoch_val = 0, batch_size = 32) 75 | # o_label_dict = dataset_val.label_dict_reverse 76 | 77 | 78 | 79 | word_dict = {} 80 | word_file = open(os.path.join('D:\\Qian\\GitHub\\workspace\\dataset\\tiny-imagenet-200\\tiny-imagenet-200\\', 81 | 'words.txt'), 'r') 82 | lines = word_file.read().split('\n') 83 | for line in lines: 84 | label, word = line.split('\t') 85 | word_dict[label] = word 86 | 87 | with tf.Session() as sess: 88 | sess.run(tf.global_variables_initializer()) 89 | load_VGG_model(sess, 'D:\\Qian\\GitHub\\workspace\\VGG\\vgg19.npy') 90 | batch_data = dataset_val.next_batch() 91 | 92 | result = sess.run(predict_op, feed_dict = {keep_prob: 1, image: batch_data[0]}) 93 | print(result) 94 | print([word_dict[o_label_dict[label]] for label in batch_data[1]]) 95 | 96 | 97 | 98 | 
-------------------------------------------------------------------------------- /test/config.py: -------------------------------------------------------------------------------- 1 | # File: config.py 2 | # Author: Qian Ge 3 | 4 | # directory of training data 5 | # data_dir = 'D:\\GoogleDrive_Qian\\Foram\\Training\\CNN_GAN_ORIGINAL_64\\' 6 | # data_dir = 'D:\\Qian\\GitHub\\workspace\\tensorflow-DCGAN\\cifar-10-python.tar\\') 7 | train_data_dir = 'D:\\Qian\\GitHub\\workspace\\dataset\\tiny-imagenet-200\\tiny-imagenet-200\\train\\' 8 | 9 | # directory of validataion data 10 | valid_data_dir = 'D:\\Qian\\GitHub\\workspace\\dataset\\tiny-imagenet-200\\tiny-imagenet-200\\val\\' 11 | 12 | # directory for saving inference data 13 | infer_dir = 'D:\\Qian\\GitHub\\workspace\\test\\result\\' 14 | 15 | # directory for saving summary 16 | summary_dir = 'D:\\Qian\\GitHub\\workspace\\test\\' 17 | 18 | # directory for saving checkpoint 19 | checkpoint_dir = 'D:\\Qian\\GitHub\\workspace\\test\\' 20 | 21 | # directory for restoring checkpoint 22 | model_dir = 'D:\\Qian\\GitHub\\workspace\\test\\' 23 | 24 | # directory for saving prediction results 25 | result_dir = 'D:\\Qian\\GitHub\\workspace\\test\\2\\' 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from tensorcv.dataflow.matlab import MatlabData 7 | from tensorcv.models.layers import * 8 | from tensorcv.models.base import BaseModel 9 | from tensorcv.utils.common import apply_mask, get_tensors_by_names 10 | from tensorcv.train.config import TrainConfig 11 | from tensorcv.predicts.config import PridectConfig 12 | from tensorcv.train.simple import SimpleFeedTrainer 13 | from tensorcv.callbacks.saver import ModelSaver 14 | from tensorcv.callbacks.summary import TrainSummary 15 | from 
tensorcv.callbacks.inference import FeedInference 16 | from tensorcv.callbacks.monitors import TFSummaryWriter 17 | from tensorcv.callbacks.inferencer import InferScalars 18 | from tensorcv.predicts.simple import SimpleFeedPredictor 19 | from tensorcv.predicts.predictions import PredictionImage 20 | from tensorcv.callbacks.debug import CheckScalar 21 | 22 | class Model(BaseModel): 23 | def __init__(self, num_channels = 3, num_class = 2, 24 | learning_rate = 0.0001): 25 | self.learning_rate = learning_rate 26 | self.num_channels = num_channels 27 | self.num_class = num_class 28 | self.set_is_training(True) 29 | 30 | def _get_placeholder(self): 31 | return [self.image, self.gt, self.mask] 32 | # image, label, mask 33 | 34 | def _get_prediction_placeholder(self): 35 | return self.image 36 | 37 | def _get_graph_feed(self): 38 | if self.is_training: 39 | feed = {self.keep_prob: 0.5} 40 | else: 41 | feed = {self.keep_prob: 1} 42 | return feed 43 | 44 | def _create_graph(self): 45 | 46 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 47 | 48 | self.image = tf.placeholder(tf.float32, name = 'image', 49 | shape = [None, None, None, self.num_channels]) 50 | self.gt = tf.placeholder(tf.int64, [None, None, None], 'gt') 51 | self.mask = tf.placeholder(tf.int32, [None, None, None], 'mask') 52 | 53 | with tf.variable_scope('conv1') as scope: 54 | conv1 = conv(self.image, 5, 32, nl = tf.nn.relu) 55 | pool1 = max_pool(conv1, padding = 'SAME') 56 | 57 | with tf.variable_scope('conv2') as scope: 58 | conv2 = conv(pool1, 3, 48, nl = tf.nn.relu) 59 | pool2 = max_pool(conv2, padding = 'SAME') 60 | 61 | with tf.variable_scope('conv3') as scope: 62 | conv3 = conv(pool2, 3, 64, nl = tf.nn.relu) 63 | pool3 = max_pool(conv3, padding = 'SAME') 64 | 65 | with tf.variable_scope('conv4') as scope: 66 | conv4 = conv(pool3, 3, 128, nl = tf.nn.relu) 67 | pool4 = max_pool(conv4, padding = 'SAME') 68 | 69 | with tf.variable_scope('fc1') as scope: 70 | fc1 = conv(pool4, 2, 128, nl = 
tf.nn.relu) 71 | dropout_fc1 = dropout(fc1, self.keep_prob, self.is_training) 72 | 73 | with tf.variable_scope('fc2') as scope: 74 | fc2 = conv(dropout_fc1, 1, self.num_class) 75 | 76 | dconv1 = tf.add(dconv(fc2, 4, name = 'dconv1', 77 | out_shape_by_tensor = pool3), pool3) 78 | dconv2 = tf.add(dconv(dconv1, 4, name = 'dconv2', 79 | out_shape_by_tensor = pool2), pool2) 80 | dconv3 = dconv(dconv2, 16, self.num_class, 81 | out_shape_by_tensor = self.image, 82 | name = 'dconv3', stride = 4) 83 | 84 | with tf.name_scope('prediction'): 85 | self.prediction = tf.argmax(dconv3, name='label', dimension = -1) 86 | self.softmax_dconv3 = tf.nn.softmax(dconv3) 87 | prediction_pro = tf.identity(self.softmax_dconv3[:,:,:,1], 88 | name = 'probability') 89 | 90 | def _setup_graph(self): 91 | with tf.name_scope('accuracy'): 92 | correct_prediction = apply_mask( 93 | tf.equal(self.prediction, self.gt), 94 | self.mask) 95 | self.accuracy = tf.reduce_mean( 96 | tf.cast(correct_prediction, tf.float32), 97 | name = 'result') 98 | 99 | def _get_loss(self): 100 | with tf.name_scope('loss'): 101 | # This op expects unscaled logits, since it performs a softmax on logits internally for efficiency. Do not call this op with the output of softmax, as it will produce incorrect results. 
102 | return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits 103 | (logits = apply_mask(self.softmax_dconv3, self.mask), 104 | labels = apply_mask(self.gt, self.mask)), name = 'result') 105 | 106 | def _get_optimizer(self): 107 | return tf.train.AdamOptimizer(learning_rate = self.learning_rate) 108 | 109 | def _setup_summary(self): 110 | with tf.name_scope('train_summary'): 111 | tf.summary.image("train_Predict", 112 | tf.expand_dims(tf.cast(self.prediction, tf.float32), -1), 113 | collections = ['train']) 114 | tf.summary.image("im",tf.cast(self.image, tf.float32), 115 | collections = ['train']) 116 | tf.summary.image("gt", 117 | tf.expand_dims(tf.cast(self.gt, tf.float32), -1), 118 | collections = ['train']) 119 | tf.summary.image("mask", 120 | tf.expand_dims(tf.cast(self.mask, tf.float32), -1), 121 | collections = ['train']) 122 | tf.summary.scalar('train_accuracy', self.accuracy, 123 | collections = ['train']) 124 | with tf.name_scope('test_summary'): 125 | tf.summary.image("test_Predict", 126 | tf.expand_dims(tf.cast(self.prediction, tf.float32), -1), 127 | collections = ['test']) 128 | 129 | def get_config(FLAGS): 130 | mat_name_list = ['level1Edge', 'GT', 'Mask'] 131 | dataset_train = MatlabData('train', mat_name_list = mat_name_list, 132 | data_dir = FLAGS.data_dir) 133 | dataset_val = MatlabData('val', mat_name_list = mat_name_list, 134 | data_dir = FLAGS.data_dir) 135 | inference_list = [InferScalars('accuracy/result', 'test_accuracy')] 136 | 137 | return TrainConfig( 138 | dataflow = dataset_train, 139 | model = Model(num_channels = FLAGS.input_channel, 140 | num_class = FLAGS.num_class, 141 | learning_rate = 0.0001), 142 | monitors = TFSummaryWriter(summary_dir = FLAGS.summary_dir), 143 | callbacks = [ 144 | ModelSaver(periodic = 10, 145 | checkpoint_dir = FLAGS.summary_dir), 146 | TrainSummary(key = 'train', periodic = 10), 147 | FeedInference(dataset_val, periodic = 10, 148 | extra_cbs = TrainSummary(key = 'test'), 149 | inferencers = 
inference_list), 150 | # CheckScalar(['accuracy/result'], periodic = 10), 151 | ], 152 | batch_size = FLAGS.batch_size, 153 | max_epoch = 200, 154 | summary_periodic = 10) 155 | 156 | def get_predictConfig(FLAGS): 157 | mat_name_list = ['level1Edge'] 158 | dataset_test = MatlabData('Level_1', shuffle = False, 159 | mat_name_list = mat_name_list, 160 | data_dir = FLAGS.test_data_dir) 161 | prediction_list = PredictionImage(['prediction/label', 'prediction/probability'], 162 | ['test','test_pro'], 163 | merge_im = True) 164 | 165 | return PridectConfig( 166 | dataflow = dataset_test, 167 | model = Model(FLAGS.input_channel, 168 | num_class = FLAGS.num_class), 169 | model_name = 'model-14070', 170 | model_dir = FLAGS.model_dir, 171 | result_dir = FLAGS.result_dir, 172 | predictions = prediction_list, 173 | batch_size = FLAGS.batch_size) 174 | 175 | def get_args(): 176 | parser = argparse.ArgumentParser() 177 | parser.add_argument('--data_dir', 178 | help = 'Directory of input training data.', 179 | default = 'D:\\GoogleDrive_Qian\\Foram\\Training\\CNN_Image\\') 180 | parser.add_argument('--summary_dir', 181 | help = 'Directory for saving summary.', 182 | default = 'D:\\Qian\\GitHub\\workspace\\test\\') 183 | parser.add_argument('--checkpoint_dir', 184 | help = 'Directory for saving checkpoint.', 185 | default = 'D:\\Qian\\GitHub\\workspace\\test\\') 186 | 187 | parser.add_argument('--test_data_dir', 188 | help = 'Directory of input test data.', 189 | default = 'D:\\GoogleDrive_Qian\\Foram\\testing\\') 190 | parser.add_argument('--model_dir', 191 | help = 'Directory for restoring checkpoint.', 192 | default = 'D:\\Qian\\GitHub\\workspace\\test\\') 193 | parser.add_argument('--result_dir', 194 | help = 'Directory for saving prediction results.', 195 | default = 'D:\\Qian\\GitHub\\workspace\\test\\2\\') 196 | 197 | parser.add_argument('--input_channel', default = 1, 198 | help = 'Number of image channels') 199 | parser.add_argument('--num_class', default = 2, 200 | help 
= 'Number of classes') 201 | parser.add_argument('--batch_size', default = 1) 202 | 203 | parser.add_argument('--predict', help = 'Run prediction', action='store_true') 204 | parser.add_argument('--train', help = 'Train the model', action='store_true') 205 | 206 | return parser.parse_args() 207 | 208 | if __name__ == '__main__': 209 | 210 | FLAGS = get_args() 211 | if FLAGS.train: 212 | config = get_config(FLAGS) 213 | SimpleFeedTrainer(config).train() 214 | elif FLAGS.predict: 215 | config = get_predictConfig(FLAGS) 216 | SimpleFeedPredictor(config).run_predict() 217 | 218 | 219 | -------------------------------------------------------------------------------- /test/todo.md: -------------------------------------------------------------------------------- 1 | # todo 2 | 3 | ## utils.common 4 | - add multiple tensors with the same name in get_tensors_by_names 5 | --------------------------------------------------------------------------------