├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── docs
├── Makefile
├── _templates
│ └── about.html
├── conf.py
├── index.rst
├── make.bat
├── modules.rst
├── requirements.txt
├── tensorcv.callbacks.rst
├── tensorcv.dataflow.dataset.rst
├── tensorcv.dataflow.rst
├── tensorcv.models.rst
├── tensorcv.predicts.rst
├── tensorcv.rst
├── tensorcv.train.rst
└── tensorcv.utils.rst
├── requirements.txt
├── setup.py
├── tensorcv
├── __init__.py
├── algorithms
│ ├── GAN
│ │ ├── DCGAN.py
│ │ ├── README.md
│ │ ├── config.py
│ │ └── fig
│ │ │ └── mnist_result.png
│ └── pretrained
│ │ ├── VGG.py
│ │ └── VGG_.py
├── callbacks
│ ├── __init__.py
│ ├── base.py
│ ├── debug.py
│ ├── group.py
│ ├── hooks.py
│ ├── inference.py
│ ├── inferencer.py
│ ├── inputs.py
│ ├── monitors.py
│ ├── saver.py
│ ├── summary.py
│ └── trigger.py
├── data
│ └── imageNetLabel.txt
├── dataflow
│ ├── __init__.py
│ ├── argument.py
│ ├── base.py
│ ├── bk
│ │ └── image.py
│ ├── common.py
│ ├── dataset
│ │ ├── BSDS500.py
│ │ ├── CIFAR.py
│ │ ├── MNIST.py
│ │ └── __init__.py
│ ├── image.py
│ ├── matlab.py
│ ├── normalization.py
│ ├── operation.py
│ ├── preprocess.py
│ ├── randoms.py
│ ├── sequence.py
│ └── viz.py
├── models
│ ├── __init__.py
│ ├── base.py
│ ├── bk
│ │ └── layers.py
│ ├── layers.py
│ ├── losses.py
│ ├── model_builder
│ │ └── base.py
│ └── utils.py
├── predicts
│ ├── __init__.py
│ ├── base.py
│ ├── config.py
│ ├── predictions.py
│ └── simple.py
├── tfdataflow
│ ├── __init__.py
│ ├── base.py
│ ├── convert.py
│ └── write.py
├── train
│ ├── __init__.py
│ ├── base.py
│ ├── config.py
│ └── simple.py
└── utils
│ ├── __init__.py
│ ├── common.py
│ ├── debug.py
│ ├── default.py
│ ├── sesscreate.py
│ ├── utils.py
│ └── viz.py
├── test.py
└── test
├── GAN-1D.py
├── VGG.py
├── VGG_pre_trained.py
├── VGG_train.py
├── config.py
├── model.py
├── test.py
└── todo.md
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - "3.4"
4 | install:
5 | - pip install -r requirements.txt
6 | - pip install tensorflow
7 | # - pip install flake8
8 | - pip install coveralls
9 | script:
10 | # - flake8 . --ignore=F405,F403
11 | - python test.py
12 | - coverage run --source=. test.py
13 | after_success:
14 | - coveralls
15 |
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Qian Ge
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A deep learning package for computer vision algorithms built on top of TensorFlow
2 |
3 | [](https://travis-ci.org/conan7882/DeepVision-tensorflow)
4 |
5 | This package is for practicing and reproducing learning-based computer vision algorithms in recent papers.
6 |
7 | It is a simplified version of [tensorpack](https://github.com/ppwwyyxx/tensorpack), which is the reference code of this package.
8 |
9 | # Install
10 |
11 | ```
12 | pip install -r requirements.txt
13 | pip install -U git+https://github.com/conan7882/DeepVision-tensorflow.git
14 | ```
15 |
16 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = python -msphinx
7 | SPHINXPROJ = DeepVision-TensorFlow
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/_templates/about.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 |
3 | {%- block extrahead %}
4 |
7 |
10 |
11 |
19 | {% endblock %}
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Test Documentation
2 | --------------------
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | modules
8 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=python -msphinx
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=DeepVision-TensorFlow
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20 | echo.then set the SPHINXBUILD environment variable to point to the full
21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the
22 | echo.Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/modules.rst:
--------------------------------------------------------------------------------
1 | tensorcv
2 | ========
3 |
4 | .. toctree::
5 | :maxdepth: 4
6 |
7 | tensorcv
8 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | termcolor
2 | numpy
3 | tqdm
4 | Sphinx>=1.6
5 | recommonmark==0.4.0
6 | guzzle_sphinx_theme
7 | mock
8 | tensorflow
9 | matplotlib
10 | Pillow
--------------------------------------------------------------------------------
/docs/tensorcv.callbacks.rst:
--------------------------------------------------------------------------------
1 | tensorcv\.callbacks package
2 | ===========================
3 |
4 | Submodules
5 | ----------
6 |
7 | tensorcv\.callbacks\.base module
8 | --------------------------------
9 |
10 | .. automodule:: tensorcv.callbacks.base
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | tensorcv\.callbacks\.debug module
16 | ---------------------------------
17 |
18 | .. automodule:: tensorcv.callbacks.debug
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | tensorcv\.callbacks\.group module
24 | ---------------------------------
25 |
26 | .. automodule:: tensorcv.callbacks.group
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | tensorcv\.callbacks\.hooks module
32 | ---------------------------------
33 |
34 | .. automodule:: tensorcv.callbacks.hooks
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | tensorcv\.callbacks\.inference module
40 | -------------------------------------
41 |
42 | .. automodule:: tensorcv.callbacks.inference
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | tensorcv\.callbacks\.inferencer module
48 | --------------------------------------
49 |
50 | .. automodule:: tensorcv.callbacks.inferencer
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | tensorcv\.callbacks\.inputs module
56 | ----------------------------------
57 |
58 | .. automodule:: tensorcv.callbacks.inputs
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
63 | tensorcv\.callbacks\.monitors module
64 | ------------------------------------
65 |
66 | .. automodule:: tensorcv.callbacks.monitors
67 | :members:
68 | :undoc-members:
69 | :show-inheritance:
70 |
71 | tensorcv\.callbacks\.saver module
72 | ---------------------------------
73 |
74 | .. automodule:: tensorcv.callbacks.saver
75 | :members:
76 | :undoc-members:
77 | :show-inheritance:
78 |
79 | tensorcv\.callbacks\.summary module
80 | -----------------------------------
81 |
82 | .. automodule:: tensorcv.callbacks.summary
83 | :members:
84 | :undoc-members:
85 | :show-inheritance:
86 |
87 | tensorcv\.callbacks\.trigger module
88 | -----------------------------------
89 |
90 | .. automodule:: tensorcv.callbacks.trigger
91 | :members:
92 | :undoc-members:
93 | :show-inheritance:
94 |
95 |
96 | Module contents
97 | ---------------
98 |
99 | .. automodule:: tensorcv.callbacks
100 | :members:
101 | :undoc-members:
102 | :show-inheritance:
103 |
--------------------------------------------------------------------------------
/docs/tensorcv.dataflow.dataset.rst:
--------------------------------------------------------------------------------
1 | tensorcv\.dataflow\.dataset package
2 | ===================================
3 |
4 | Submodules
5 | ----------
6 |
7 | tensorcv\.dataflow\.dataset\.BSDS500 module
8 | -------------------------------------------
9 |
10 | .. automodule:: tensorcv.dataflow.dataset.BSDS500
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | tensorcv\.dataflow\.dataset\.CIFAR module
16 | -----------------------------------------
17 |
18 | .. automodule:: tensorcv.dataflow.dataset.CIFAR
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | tensorcv\.dataflow\.dataset\.MNIST module
24 | -----------------------------------------
25 |
26 | .. automodule:: tensorcv.dataflow.dataset.MNIST
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 |
32 | Module contents
33 | ---------------
34 |
35 | .. automodule:: tensorcv.dataflow.dataset
36 | :members:
37 | :undoc-members:
38 | :show-inheritance:
39 |
--------------------------------------------------------------------------------
/docs/tensorcv.dataflow.rst:
--------------------------------------------------------------------------------
1 | tensorcv\.dataflow package
2 | ==========================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 |
9 | tensorcv.dataflow.dataset
10 |
11 | Submodules
12 | ----------
13 |
14 | tensorcv\.dataflow\.base module
15 | -------------------------------
16 |
17 | .. automodule:: tensorcv.dataflow.base
18 | :members:
19 | :undoc-members:
20 | :show-inheritance:
21 |
22 | tensorcv\.dataflow\.common module
23 | ---------------------------------
24 |
25 | .. automodule:: tensorcv.dataflow.common
26 | :members:
27 | :undoc-members:
28 | :show-inheritance:
29 |
30 | tensorcv\.dataflow\.image module
31 | --------------------------------
32 |
33 | .. automodule:: tensorcv.dataflow.image
34 | :members:
35 | :undoc-members:
36 | :show-inheritance:
37 |
38 | tensorcv\.dataflow\.matlab module
39 | ---------------------------------
40 |
41 | .. automodule:: tensorcv.dataflow.matlab
42 | :members:
43 | :undoc-members:
44 | :show-inheritance:
45 |
46 | tensorcv\.dataflow\.randoms module
47 | ----------------------------------
48 |
49 | .. automodule:: tensorcv.dataflow.randoms
50 | :members:
51 | :undoc-members:
52 | :show-inheritance:
53 |
54 |
55 | Module contents
56 | ---------------
57 |
58 | .. automodule:: tensorcv.dataflow
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
--------------------------------------------------------------------------------
/docs/tensorcv.models.rst:
--------------------------------------------------------------------------------
1 | tensorcv\.models package
2 | ========================
3 |
4 | Submodules
5 | ----------
6 |
7 | tensorcv\.models\.base module
8 | -----------------------------
9 |
10 | .. automodule:: tensorcv.models.base
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | tensorcv\.models\.layers module
16 | -------------------------------
17 |
18 | .. automodule:: tensorcv.models.layers
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | tensorcv\.models\.losses module
24 | -------------------------------
25 |
26 | .. automodule:: tensorcv.models.losses
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 |
32 | Module contents
33 | ---------------
34 |
35 | .. automodule:: tensorcv.models
36 | :members:
37 | :undoc-members:
38 | :show-inheritance:
39 |
--------------------------------------------------------------------------------
/docs/tensorcv.predicts.rst:
--------------------------------------------------------------------------------
1 | tensorcv\.predicts package
2 | ==========================
3 |
4 | Submodules
5 | ----------
6 |
7 | tensorcv\.predicts\.base module
8 | -------------------------------
9 |
10 | .. automodule:: tensorcv.predicts.base
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | tensorcv\.predicts\.config module
16 | ---------------------------------
17 |
18 | .. automodule:: tensorcv.predicts.config
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | tensorcv\.predicts\.predictions module
24 | --------------------------------------
25 |
26 | .. automodule:: tensorcv.predicts.predictions
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | tensorcv\.predicts\.simple module
32 | ---------------------------------
33 |
34 | .. automodule:: tensorcv.predicts.simple
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 |
40 | Module contents
41 | ---------------
42 |
43 | .. automodule:: tensorcv.predicts
44 | :members:
45 | :undoc-members:
46 | :show-inheritance:
47 |
--------------------------------------------------------------------------------
/docs/tensorcv.rst:
--------------------------------------------------------------------------------
1 | tensorcv package
2 | ================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 |
9 | tensorcv.callbacks
10 | tensorcv.dataflow
11 | tensorcv.models
12 | tensorcv.predicts
13 | tensorcv.train
14 | tensorcv.utils
15 |
16 | Module contents
17 | ---------------
18 |
19 | .. automodule:: tensorcv
20 | :members:
21 | :undoc-members:
22 | :show-inheritance:
23 |
--------------------------------------------------------------------------------
/docs/tensorcv.train.rst:
--------------------------------------------------------------------------------
1 | tensorcv\.train package
2 | =======================
3 |
4 | Submodules
5 | ----------
6 |
7 | tensorcv\.train\.base module
8 | ----------------------------
9 |
10 | .. automodule:: tensorcv.train.base
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | tensorcv\.train\.config module
16 | ------------------------------
17 |
18 | .. automodule:: tensorcv.train.config
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | tensorcv\.train\.simple module
24 | ------------------------------
25 |
26 | .. automodule:: tensorcv.train.simple
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 |
32 | Module contents
33 | ---------------
34 |
35 | .. automodule:: tensorcv.train
36 | :members:
37 | :undoc-members:
38 | :show-inheritance:
39 |
--------------------------------------------------------------------------------
/docs/tensorcv.utils.rst:
--------------------------------------------------------------------------------
1 | tensorcv\.utils package
2 | =======================
3 |
4 | Submodules
5 | ----------
6 |
7 | tensorcv\.utils\.common module
8 | ------------------------------
9 |
10 | .. automodule:: tensorcv.utils.common
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | tensorcv\.utils\.default module
16 | -------------------------------
17 |
18 | .. automodule:: tensorcv.utils.default
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | tensorcv\.utils\.sesscreate module
24 | ----------------------------------
25 |
26 | .. automodule:: tensorcv.utils.sesscreate
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | tensorcv\.utils\.utils module
32 | -----------------------------
33 |
34 | .. automodule:: tensorcv.utils.utils
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | tensorcv\.utils\.viz module
40 | -----------------------------
41 |
42 | .. automodule:: tensorcv.utils.viz
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 |
48 | Module contents
49 | ---------------
50 |
51 | .. automodule:: tensorcv.utils
52 | :members:
53 | :undoc-members:
54 | :show-inheritance:
55 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | Pillow
3 | six
4 | termcolor>=1.1
5 | tabulate>=0.7.7
6 | tqdm>4.11.1
7 | msgpack-python>0.4.0
8 | msgpack-numpy>=0.3.9
9 | pyzmq>=16
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(name='tensorcv',
4 | version='0.1',
5 | description=' ',
6 | url='https://github.com/conan7882/DeepVision-tensorflow',
7 | author='Qian Ge',
8 | author_email='geqian1001@gmail.com',
9 | packages=['tensorcv', 'tensorcv.utils', 'tensorcv.algorithms', 'tensorcv.callbacks',
10 | 'tensorcv.data', 'tensorcv.dataflow', 'tensorcv.dataflow.dataset',
11 | 'tensorcv.models', 'tensorcv.predicts', 'tensorcv.train', 'tensorcv.tfdataflow'],
12 | zip_safe=False)
13 |
--------------------------------------------------------------------------------
/tensorcv/__init__.py:
--------------------------------------------------------------------------------
1 | # File: __init__.py
2 | # Author: Qian Ge
3 |
--------------------------------------------------------------------------------
/tensorcv/algorithms/GAN/DCGAN.py:
--------------------------------------------------------------------------------
1 | # File: DCGAN.py
2 | # Author: Qian Ge
3 |
4 | import argparse
5 |
6 | import tensorflow as tf
7 |
8 | from tensorcv.dataflow.randoms import RandomVec
9 | from tensorcv.dataflow.dataset.MNIST import MNIST
10 | # import tensorcv.callbacks as cb
11 | from tensorcv.callbacks import *
12 | from tensorcv.predicts import *
13 | from tensorcv.models.layers import *
14 | from tensorcv.models.losses import *
15 | from tensorcv.predicts.simple import SimpleFeedPredictor
16 | from tensorcv.predicts.config import PridectConfig
17 | from tensorcv.models.base import GANBaseModel
18 | from tensorcv.train.config import GANTrainConfig
19 | from tensorcv.train.simple import GANFeedTrainer
20 | from tensorcv.utils.common import deconv_size
21 |
22 | import config as config_path
23 |
24 |
class Model(GANBaseModel):
    """DCGAN-style generator/discriminator pair built on GANBaseModel.

    The generator upsamples a random input vector to an image through a
    fully connected layer followed by four transposed convolutions; the
    discriminator downsamples an image through four strided convolutions
    to a single logit.
    """

    def __init__(self,
                 input_vec_length=100,
                 learning_rate=None,
                 num_channels=None,
                 im_size=None):
        """
        Args:
            input_vec_length (int): length of the generator's random
                input vector.
            learning_rate (list of float): two learning rates used by the
                base class. Defaults to [0.0002, 0.0002].
            num_channels (int): number of image channels; base-class value
                is kept when None.
            im_size ((int, int)): (height, width) of the image; base-class
                value is kept when None.
        """
        # Avoid a mutable default argument: build a fresh list per call.
        if learning_rate is None:
            learning_rate = [0.0002, 0.0002]

        super(Model, self).__init__(input_vec_length, learning_rate)

        if num_channels is not None:
            self.num_channels = num_channels
        if im_size is not None:
            self.im_height, self.im_width = im_size

        self.set_is_training(True)

    def _create_input(self):
        """Create the placeholder for real training images."""
        self.real_data = tf.placeholder(
            tf.float32,
            [None, self.im_height, self.im_width, self.num_channels])
        self.set_train_placeholder(self.real_data)

    def _generator(self, train=True):
        """Build the generator network.

        Args:
            train (bool): whether batch norm runs in training mode.

        Returns:
            Tensor of generated images, tanh output (values in [-1, 1]).
        """
        final_dim = 64
        filter_size = 5

        # Spatial sizes for each x2 upsampling stage, derived from the
        # final image size.
        d_height_2, d_width_2 = deconv_size(self.im_height, self.im_width)
        d_height_4, d_width_4 = deconv_size(d_height_2, d_width_2)
        d_height_8, d_width_8 = deconv_size(d_height_4, d_width_4)
        d_height_16, d_width_16 = deconv_size(d_height_8, d_width_8)

        rand_vec = self.get_random_vec_placeholder()
        batch_size = tf.shape(rand_vec)[0]

        with tf.variable_scope('fc1'):
            fc1 = fc(rand_vec, d_height_16 * d_width_16 * final_dim * 8, 'fc')
            fc1 = tf.nn.relu(batch_norm(fc1, train=train))
            fc1_reshape = tf.reshape(
                fc1, [-1, d_height_16, d_width_16, final_dim * 8])

        with tf.variable_scope('dconv2'):
            output_shape = [batch_size, d_height_8, d_width_8, final_dim * 4]
            dconv2 = dconv(fc1_reshape, filter_size,
                           out_shape=output_shape, name='dconv')
            bn_dconv2 = tf.nn.relu(batch_norm(dconv2, train=train))

        with tf.variable_scope('dconv3'):
            output_shape = [batch_size, d_height_4, d_width_4, final_dim * 2]
            dconv3 = dconv(bn_dconv2, filter_size,
                           out_shape=output_shape, name='dconv')
            bn_dconv3 = tf.nn.relu(batch_norm(dconv3, train=train))

        with tf.variable_scope('dconv4'):
            output_shape = [batch_size, d_height_2, d_width_2, final_dim]
            dconv4 = dconv(bn_dconv3, filter_size,
                           out_shape=output_shape, name='dconv')
            bn_dconv4 = tf.nn.relu(batch_norm(dconv4, train=train))

        with tf.variable_scope('dconv5'):
            # Do not use batch norm for the last layer
            output_shape = [batch_size, self.im_height,
                            self.im_width, self.num_channels]
            dconv5 = dconv(bn_dconv4, filter_size,
                           out_shape=output_shape, name='dconv')

        generation = tf.nn.tanh(dconv5, 'gen_out')
        return generation

    def _discriminator(self, input_im):
        """Build the discriminator network.

        Args:
            input_im: batch of images, real or generated.

        Returns:
            Unnormalized logit tensor, one value per image.
        """
        filter_size = 5
        start_depth = 64

        with tf.variable_scope('conv1'):
            conv1 = conv(input_im, filter_size, start_depth, stride=2)
            bn_conv1 = leaky_relu((batch_norm(conv1)))

        with tf.variable_scope('conv2'):
            conv2 = conv(bn_conv1, filter_size, start_depth * 2, stride=2)
            bn_conv2 = leaky_relu((batch_norm(conv2)))

        with tf.variable_scope('conv3'):
            conv3 = conv(bn_conv2, filter_size, start_depth * 4, stride=2)
            bn_conv3 = leaky_relu((batch_norm(conv3)))

        with tf.variable_scope('conv4'):
            conv4 = conv(bn_conv3, filter_size, start_depth * 8, stride=2)
            bn_conv4 = leaky_relu((batch_norm(conv4)))

        with tf.variable_scope('fc5'):
            fc5 = fc(bn_conv4, 1, name='fc')

        return fc5

    def _ex_setup_graph(self):
        """Expose named tensors so callbacks/inferencers can fetch them."""
        tf.identity(self.get_sample_gen_data(), 'generate_image')
        tf.identity(self.get_generator_loss(), 'g_loss_check')
        tf.identity(self.get_discriminator_loss(), 'd_loss_check')

    def _setup_summary(self):
        """Add image summaries for generated samples and real data."""
        with tf.name_scope('generator_summary'):
            tf.summary.image('generate_sample',
                             tf.cast(self.get_sample_gen_data(), tf.float32),
                             collections=[self.g_collection])
            tf.summary.image('generate_train',
                             tf.cast(self.get_gen_data(), tf.float32),
                             collections=[self.d_collection])
        with tf.name_scope('real_data'):
            tf.summary.image('real_data',
                             tf.cast(self.real_data, tf.float32),
                             collections=[self.d_collection])
141 |
142 |
def get_config(FLAGS):
    """Assemble the GAN training configuration for the MNIST dataset.

    Args:
        FLAGS: parsed command-line arguments (see ``get_args``).

    Returns:
        GANTrainConfig wired with the dataset, model, monitors and
        per-network callbacks.
    """
    train_data = MNIST('train', data_dir=config_path.data_dir,
                       normalize='tanh')
    noise_input = RandomVec(len_vec=FLAGS.len_vec)
    gen_inferencer = InferImages('generate_image', prefix='gen')

    dcgan = Model(input_vec_length=FLAGS.len_vec,
                  learning_rate=[0.0002, 0.0002])

    return GANTrainConfig(
        dataflow=train_data,
        model=dcgan,
        monitors=TFSummaryWriter(),
        discriminator_callbacks=[
            CheckScalar(['d_loss_check', 'g_loss_check'],
                        periodic=10)],
        generator_callbacks=[GANInference(inputs=noise_input,
                                          periodic=100,
                                          inferencers=gen_inferencer)],
        batch_size=FLAGS.batch_size,
        max_epoch=100,
        summary_d_periodic=10,
        summary_g_periodic=10,
        default_dirs=config_path)
167 |
168 |
def get_predictConfig(FLAGS):
    """Assemble the prediction configuration for a trained DCGAN.

    Args:
        FLAGS: parsed command-line arguments (see ``get_args``).

    Returns:
        PridectConfig that feeds random vectors through the generator
        and saves the resulting images.
    """
    noise_input = RandomVec(len_vec=FLAGS.len_vec)
    save_images = PredictionImage('generate_image',
                                  'test', merge_im=True, tanh=True)

    dcgan = Model(input_vec_length=FLAGS.len_vec,
                  num_channels=FLAGS.input_channel,
                  im_size=[FLAGS.h, FLAGS.w])

    return PridectConfig(dataflow=noise_input,
                         model=dcgan,
                         model_name='model-100',
                         predictions=save_images,
                         batch_size=FLAGS.batch_size,
                         default_dirs=config_path)
182 |
183 |
def get_args():
    """Parse command-line options for DCGAN training/prediction.

    Returns:
        argparse.Namespace with fields: len_vec, input_channel, h, w,
        batch_size, predict, train.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--len_vec', default=100, type=int,
                        help='Length of input random vector')
    parser.add_argument('--input_channel', default=1, type=int,
                        help='Number of image channels')
    # Fixed typo in user-facing help text ('Heigh' -> 'Height').
    parser.add_argument('--h', default=28, type=int,
                        help='Height of input images')
    parser.add_argument('--w', default=28, type=int,
                        help='Width of input images')
    parser.add_argument('--batch_size', default=64, type=int)

    parser.add_argument('--predict', action='store_true',
                        help='Run prediction')
    parser.add_argument('--train', action='store_true',
                        help='Train the model')

    return parser.parse_args()
203 |
204 |
if __name__ == '__main__':
    FLAGS = get_args()

    # --train takes precedence when both flags are given; with neither
    # flag the script is a no-op.
    if FLAGS.train:
        GANFeedTrainer(get_config(FLAGS)).train()
    elif FLAGS.predict:
        SimpleFeedPredictor(get_predictConfig(FLAGS)).run_predict()
215 |
--------------------------------------------------------------------------------
/tensorcv/algorithms/GAN/README.md:
--------------------------------------------------------------------------------
1 | ## Deep Convolutional Generative Adversarial Networks (DCGAN)
2 |
3 | - Model for [DCGAN](https://arxiv.org/abs/1511.06434)
4 |
5 | - An example implementation of DCGAN using this model can be found [here](https://github.com/conan7882/tensorflow-DCGAN).
6 |
7 | *Details of how to write your own GAN model and callbacks configuration can be found in docs (coming soon).*
8 |
9 | ## Implementation Details
10 | #### Generator
11 |
12 | #### Discriminator
13 |
14 | #### Loss function
15 |
16 | #### Optimizer
17 |
18 | #### Variable initialization
19 |
20 | #### Batch normal and LeakyReLu
21 |
22 | #### Training settings
23 | - training rate
24 | - training step
25 |
26 | ## Default Summary
27 | ### Scalar:
28 | - loss of generator and discriminator
29 |
30 | ### Histogram:
31 | - gradients of generator and discriminator
32 | - discriminator output for real image and generated image
33 |
34 | ### Image
35 | - real image and generated image
36 |
37 | ## Callbacks
38 |
39 | ### Available callbacks:
40 |
41 | - TrainSummary()
42 | - CheckScalar()
43 | - GANInference()
44 |
45 | ### Available inferencer:
46 | - InferImages()
47 |
48 | ### Available predictor
49 | - PredictionImage()
50 |
51 | ## Test
52 | To test this model on MNIST dataset, first put all the directories in *config.py*.
53 |
54 | For training:
55 |
56 | $ python DCGAN.py --train --batch_size 64
57 |
58 | For testing, the batch size has to be the same as training:
59 |
60 | $ python DCGAN.py --predict --batch_size 64
61 |
62 | Using this model run on other dataset can be found [here](https://github.com/conan7882/tensorflow-DCGAN).
63 |
64 |
65 | ## Results
66 | *More results can be found [here](https://github.com/conan7882/tensorflow-DCGAN#results).*
67 | ### MNIST
68 |
69 | 
70 |
71 |
72 | ## Reference
73 | - [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.](https://arxiv.org/abs/1511.06434)
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/tensorcv/algorithms/GAN/config.py:
--------------------------------------------------------------------------------
1 | # File: config.py
2 | # Author: Qian Ge
3 |
4 | # directory of input data
5 | data_dir = 'D:\\Qian\\GitHub\\workspace\\tensorflow-DCGAN\\MNIST_data\\'
6 |
7 | # directory for saving inference data
8 | infer_dir = 'D:\\Qian\\GitHub\\workspace\\test\\result\\'
9 |
10 | # directory for saving summary
11 | summary_dir = 'D:\\Qian\\GitHub\\workspace\\test\\'
12 |
13 | # directory for saving checkpoint
14 | checkpoint_dir = 'D:\\Qian\\GitHub\\workspace\\test\\'
15 |
16 | # directory for restoring checkpoint
17 | model_dir = 'D:\\Qian\\GitHub\\workspace\\test\\'
18 |
19 | # directory for saving prediction results
20 | result_dir = 'D:\\Qian\\GitHub\\workspace\\test\\2\\'
21 |
--------------------------------------------------------------------------------
/tensorcv/algorithms/GAN/fig/mnist_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/DeepVision-tensorflow/0ffc81a62eccf021077019fb59b0e9e7615e8222/tensorcv/algorithms/GAN/fig/mnist_result.png
--------------------------------------------------------------------------------
/tensorcv/algorithms/pretrained/VGG_.py:
--------------------------------------------------------------------------------
1 | # File: VGG.py
2 | # Author: Qian Ge
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from tensorcv.models.layers import *
8 | from tensorcv.models.base import BaseModel
9 |
10 | import config
11 |
12 | VGG_MEAN = [103.939, 116.779, 123.68]
13 |
class BaseVGG(BaseModel):
    """Base class for VGG models.

    Stores hyper-parameters, creates the input placeholders and provides
    a loader for Caffe-converted pre-trained weights (.npy files).
    """
    def __init__(self, num_class=1000,
                 num_channels=3,
                 im_height=224, im_width=224,
                 learning_rate=0.0001,
                 is_load=False,
                 pre_train_path=None):
        """
        Args:
            num_class (int): number of image classes
            num_channels (int): number of input channels
            im_height, im_width (int): size of input image
                Can be unknown when testing.
            learning_rate (float): learning rate of training
            is_load (bool): whether to load pre-trained weights when
                building the model
            pre_train_path (str): path of the pre-trained .npy weight
                file; required when is_load is True

        Raises:
            ValueError: if is_load is True but pre_train_path is None.
        """
        self.learning_rate = learning_rate
        self.num_channels = num_channels
        self.im_height = im_height
        self.im_width = im_width
        self.num_class = num_class

        self._is_load = is_load
        if self._is_load and pre_train_path is None:
            raise ValueError('pre_train_path can not be None!')
        self._pre_train_path = pre_train_path

        self.set_is_training(True)

    def _create_input(self):
        """Create input placeholders and register them with the model base."""
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        self.image = tf.placeholder(tf.float32, name='image',
                                    shape=[None, self.im_height, self.im_width, self.num_channels])
        self.label = tf.placeholder(tf.int64, [None], 'label')

        self.set_model_input([self.image, self.keep_prob])
        self.set_dropout(self.keep_prob, keep_prob=0.5)
        self.set_train_placeholder([self.image, self.label])
        self.set_prediction_placeholder(self.image)

    @staticmethod
    def load_pre_trained(session, model_path, skip_layer=None):
        """Assign weights from a Caffe-converted .npy file to graph variables.

        Args:
            session: tf.Session used to run the assign ops.
            model_path (str): path of the .npy weight file.
            skip_layer (list[str]): names of layers to skip.
                Fix: was a mutable default argument (`skip_layer=[]`).
        """
        if skip_layer is None:
            skip_layer = []
        weights_dict = np.load(model_path, encoding='latin1').item()
        for layer_name in weights_dict:
            print('Loading ' + layer_name)
            if layer_name not in skip_layer:
                with tf.variable_scope(layer_name, reuse=True):
                    for data in weights_dict[layer_name]:
                        # 1-D arrays are biases; higher-rank arrays are weights
                        if len(data.shape) == 1:
                            var = tf.get_variable('biases', trainable=False)
                            session.run(var.assign(data))
                        else:
                            var = tf.get_variable('weights', trainable=False)
                            session.run(var.assign(data))
70 |
class VGG19(BaseVGG):
    """VGG19 with the standard fully-connected head (fc6/fc7/fc8)."""

    def _create_conv(self, input_im, data_dict):
        """Build the five VGG19 conv blocks and return the pool5 output.

        Also stores the last conv activation (conv5_4, pre-pooling)
        in self.conv_out.
        """
        # (n_filter, n_conv_layer) for conv blocks 1..5
        block_cfg = [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)]

        arg_scope = tf.contrib.framework.arg_scope
        x = input_im
        last_conv = input_im
        with arg_scope([conv], nl=tf.nn.relu, trainable=True, data_dict=data_dict):
            for block_id, (n_filter, n_layer) in enumerate(block_cfg, 1):
                for layer_id in range(1, n_layer + 1):
                    x = conv(x, 3, n_filter,
                             'conv{}_{}'.format(block_id, layer_id))
                last_conv = x
                x = max_pool(x, 'pool{}'.format(block_id), padding='SAME')

            # expose conv5_4 (before the final pooling)
            self.conv_out = tf.identity(last_conv)

        return x

    def _create_model(self):
        """Assemble the full network: BGR conversion, conv stack, fc head."""
        input_im = self.model_input[0]
        keep_prob = self.model_input[1]

        # RGB -> BGR with per-channel mean subtraction (Caffe VGG convention)
        red, green, blue = tf.split(axis=3, num_or_size_splits=3,
                                    value=input_im)
        input_bgr = tf.concat(axis=3, values=[
            blue - VGG_MEAN[0],
            green - VGG_MEAN[1],
            red - VGG_MEAN[2],
        ])

        data_dict = {}
        if self._is_load:
            data_dict = np.load(self._pre_train_path, encoding='latin1').item()

        conv_output = self._create_conv(input_bgr, data_dict)

        arg_scope = tf.contrib.framework.arg_scope
        with arg_scope([fc], trainable=True, data_dict=data_dict):
            fc6 = fc(conv_output, 4096, 'fc6', nl=tf.nn.relu)
            dropout_fc6 = dropout(fc6, keep_prob, self.is_training)

            fc7 = fc(dropout_fc6, 4096, 'fc7', nl=tf.nn.relu)
            dropout_fc7 = dropout(fc7, keep_prob, self.is_training)

            fc8 = fc(dropout_fc7, self.num_class, 'fc8')

            self.output = tf.identity(fc8, 'model_output')
149 |
class VGG19_FCN(VGG19):
    """Fully-convolutional VGG19.

    The fc layers are implemented as convolutions (7x7 for fc6, 1x1 for
    fc7/fc8), so the network accepts inputs of arbitrary spatial size.
    """

    def _create_model(self):
        """Assemble the network: BGR conversion, conv stack, conv-fc head."""
        input_im = self.model_input[0]
        keep_prob = self.model_input[1]

        # Convert RGB image to BGR image (Caffe VGG convention)
        red, green, blue = tf.split(axis=3, num_or_size_splits=3,
                                    value=input_im)

        input_bgr = tf.concat(axis=3, values=[
            blue - VGG_MEAN[0],
            green - VGG_MEAN[1],
            red - VGG_MEAN[2],
        ])

        data_dict = {}
        if self._is_load:
            data_dict = np.load(self._pre_train_path, encoding='latin1').item()

        # fix: local variable was misspelled 'conv_outptu'
        conv_output = self._create_conv(input_bgr, data_dict)

        arg_scope = tf.contrib.framework.arg_scope
        with arg_scope([conv], trainable=True, data_dict=data_dict):
            # fc layers as convolutions
            fc6 = conv(conv_output, 7, 4096, 'fc6', nl=tf.nn.relu, padding='VALID')
            dropout_fc6 = dropout(fc6, keep_prob, self.is_training)

            fc7 = conv(dropout_fc6, 1, 4096, 'fc7', nl=tf.nn.relu, padding='VALID')
            dropout_fc7 = dropout(fc7, keep_prob, self.is_training)

            fc8 = conv(dropout_fc7, 1, self.num_class, 'fc8', padding='VALID')

        self.output = tf.identity(fc8, 'model_output')
        # spatial average of the class-score map
        # (fix: removed unused local `filter_size`)
        self.avg_output = global_avg_pool(fc8)

    @staticmethod
    def load_pre_trained(session, model_path, skip_layer=None):
        """Load Caffe-converted VGG19 weights, reshaping fc weights to
        their convolutional equivalents.

        Args:
            session: tf.Session used to run the assign ops.
            model_path (str): path of the .npy weight file.
            skip_layer (list[str]): names of layers to skip.
                Fix: was a mutable default argument (`skip_layer=[]`);
                also removed the unused local `fc_layers`.
        """
        if skip_layer is None:
            skip_layer = []
        weights_dict = np.load(model_path, encoding='latin1').item()
        for layer_name in weights_dict:
            print('Loading ' + layer_name)
            if layer_name not in skip_layer:
                with tf.variable_scope(layer_name, reuse=True):
                    for data in weights_dict[layer_name]:
                        if len(data.shape) == 1:
                            var = tf.get_variable('biases', trainable=False)
                            session.run(var.assign(data))
                        else:
                            var = tf.get_variable('weights', trainable=False)
                            # fc weights are stored as 2-D matrices;
                            # reshape them into conv kernels
                            if layer_name == 'fc6':
                                data = tf.reshape(data, [7, 7, 512, 4096])
                            elif layer_name == 'fc7':
                                data = tf.reshape(data, [1, 1, 4096, 4096])
                            elif layer_name == 'fc8':
                                data = tf.reshape(data, [1, 1, 4096, 1000])
                            session.run(var.assign(data))
221 |
if __name__ == '__main__':
    # Smoke test: build the VGG19 graph and dump it for TensorBoard.
    VGG = VGG19(num_class=1000,
                num_channels=3,
                im_height=224,
                im_width=224)

    VGG.create_graph()

    # write the graph definition to config.summary_dir
    writer = tf.summary.FileWriter(config.summary_dir)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)

    writer.close()
236 |
237 |
238 |
239 |
--------------------------------------------------------------------------------
/tensorcv/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | # File: __init__.py
2 | # Author: Qian Ge
3 |
4 | from .saver import *
5 | from .summary import *
6 | from .inference import *
7 | from .inferencer import *
8 | from .monitors import *
9 | from .debug import *
--------------------------------------------------------------------------------
/tensorcv/callbacks/base.py:
--------------------------------------------------------------------------------
1 | import scipy.misc
2 | import os
3 | from abc import ABCMeta
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | __all__ = ['Callback', 'ProxyCallback']
9 |
def assert_type(v, tp):
    """Raise AssertionError unless `v` is an instance of `tp`."""
    assert isinstance(v, tp), \
        "Expect {}, but {} is given!".format(tp, v.__class__)
13 |
class Callback(object):
    """Base class for training callbacks.

    Each public event method (before_train, trigger_step, ...) delegates
    to a protected `_`-prefixed hook that subclasses override.
    """

    def setup_graph(self, trainer):
        # bind the owning trainer, then let the subclass build graph ops
        self.trainer = trainer
        self._setup_graph()

    @property
    def global_step(self):
        # NOTE(review): returns `trainer.get_global_step` without calling
        # it; presumably that attribute is itself a property on the
        # trainer — verify against the trainer implementation
        return self.trainer.get_global_step

    @property
    def epochs_completed(self):
        return self.trainer.epochs_completed

    def _setup_graph(self):
        # subclass hook for graph construction
        pass

    def before_run(self, rct):
        # subclass hook may return tf.train.SessionRunArgs to add
        # fetches/feeds to the next session run; None means nothing extra
        fetch = self._before_run(rct)
        if fetch is None:
            return None
        assert_type(fetch, tf.train.SessionRunArgs)
        return fetch

    def _before_run(self, rct):
        return None

    def after_run(self, rct, val):
        self._after_run(rct, val)

    def _after_run(self, rct, val):
        pass

    def before_train(self):
        self._before_train()

    def _before_train(self):
        pass

    def before_inference(self):
        self._before_inference()

    def _before_inference(self):
        pass

    def after_train(self):
        self._after_train()

    def _after_train(self):
        pass

    def before_epoch(self):
        self._before_epoch()

    def _before_epoch(self):
        pass

    def after_epoch(self):
        self._after_epoch()

    def _after_epoch(self):
        pass

    def trigger_epoch(self):
        self._trigger_epoch()

    def _trigger_epoch(self):
        # by default an epoch trigger fires the generic trigger()
        self.trigger()

    def trigger_step(self):
        self._trigger_step()

    def _trigger_step(self):
        pass

    def trigger(self):
        self._trigger()

    def _trigger(self):
        pass
95 |
96 | # def before_run(self):
97 |
class ProxyCallback(Callback):
    """Forward every callback event to a wrapped Callback."""

    def __init__(self, cb):
        """
        Args:
            cb (Callback): the callback to wrap
        """
        assert_type(cb, Callback)
        self.cb = cb

    def __str__(self):
        return "Proxy-" + str(self.cb)

    def _before_train(self):
        self.cb.before_train()

    def _before_inference(self):
        self.cb.before_inference()

    def _setup_graph(self):
        with tf.name_scope(None):
            self.cb.setup_graph(self.trainer)

    def _trigger_epoch(self):
        self.cb.trigger_epoch()

    def _trigger(self):
        self.cb.trigger()

    def _trigger_step(self):
        self.cb.trigger_step()

    def _after_train(self):
        self.cb.after_train()

    def _before_epoch(self):
        self.cb.before_epoch()

    def _after_epoch(self):
        self.cb.after_epoch()

    def _before_run(self, crt):
        # fix: the wrapped callback's SessionRunArgs were silently dropped
        # because the return value was not propagated
        return self.cb.before_run(crt)

    def _after_run(self, crt, val):
        self.cb.after_run(crt, val)
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
--------------------------------------------------------------------------------
/tensorcv/callbacks/debug.py:
--------------------------------------------------------------------------------
1 | # File: inference.py
2 | # Author: Qian Ge
3 |
4 | import tensorflow as tf
5 |
6 | from .base import Callback
7 | from ..utils.common import get_tensors_by_names
8 |
9 | __all__ = ['CheckScalar']
10 |
def assert_type(v, tp):
    """Assert that `v` is an instance of type `tp`."""
    message = "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"
    assert isinstance(v, tp), message
14 |
class CheckScalar(Callback):
    """Print the values of scalar tensors every `periodic` steps."""

    def __init__(self, tensors, periodic=1):
        """
        Args:
            tensors (str or list[str]): tensor name(s) to print
            periodic (int): print every `periodic` global steps
        """
        if not isinstance(tensors, list):
            tensors = [tensors]
        self._tensors = tensors
        self._names = tensors
        self._periodic = periodic

    def _setup_graph(self):
        # resolve names into the actual tensors once the graph exists
        self._tensors = get_tensors_by_names(self._tensors)

    def _before_run(self, _):
        # only fetch on steps that fall on the periodic boundary
        if self.global_step % self._periodic:
            return None
        return tf.train.SessionRunArgs(fetches=self._tensors)

    def _after_run(self, _, val):
        if val.results is not None:
            print([n + ': ' + str(v)
                   for n, v in zip(self._names, val.results)])
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
/tensorcv/callbacks/group.py:
--------------------------------------------------------------------------------
1 | import scipy.misc
2 | import os
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from .base import Callback
8 | from .hooks import Callback2Hook
9 |
10 | __all__ = ['Callbacks']
11 |
def assert_type(v, tp):
    """Check that `v` is an instance of `tp`; raise AssertionError otherwise."""
    assert isinstance(v, tp), (
        "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!")
15 |
class Callbacks(Callback):
    """Group a list of callbacks and broadcast every event to all of them."""

    def __init__(self, cbs):
        """
        Args:
            cbs (list[Callback]): callbacks to group
        """
        for each_cb in cbs:
            assert_type(each_cb, Callback)
        self.cbs = cbs

    def _setup_graph(self):
        with tf.name_scope(None):
            for each_cb in self.cbs:
                each_cb.setup_graph(self.trainer)

    def get_hooks(self):
        """Wrap every grouped callback into a session-run hook."""
        return [Callback2Hook(each_cb) for each_cb in self.cbs]

    def _broadcast(self, event):
        # invoke the named no-argument event on every grouped callback
        for each_cb in self.cbs:
            getattr(each_cb, event)()

    def _before_train(self):
        self._broadcast('before_train')

    def _before_inference(self):
        self._broadcast('before_inference')

    def _after_train(self):
        self._broadcast('after_train')

    def _before_epoch(self):
        self._broadcast('before_epoch')

    def _after_epoch(self):
        self._broadcast('after_epoch')

    def _trigger_epoch(self):
        self._broadcast('trigger_epoch')

    def _trigger_step(self):
        self._broadcast('trigger_step')
61 |
62 | # def trigger(self):
63 | # self._trigger()
64 |
65 | # def _trigger(self):
66 | # pass
67 |
68 | # def before_run(self):
--------------------------------------------------------------------------------
/tensorcv/callbacks/hooks.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from .base import Callback
4 | from .inferencer import InferencerBase
5 | from ..predicts.predictions import PredictionBase
6 |
7 | __all__ = ['Callback2Hook', 'Infer2Hook', 'Prediction2Hook']
8 |
def assert_type(v, tp):
    """Raise AssertionError when `v` is not an instance of `tp`."""
    ok = isinstance(v, tp)
    assert ok, "Expect {}, but {} is given!".format(tp, v.__class__)
12 |
class Callback2Hook(tf.train.SessionRunHook):
    """Adapt a Callback's run events to the tf SessionRunHook interface."""

    def __init__(self, cb):
        """
        Args:
            cb (Callback): callback whose before_run/after_run are forwarded
        """
        self.cb = cb

    def before_run(self, rct):
        return self.cb.before_run(rct)

    def after_run(self, rct, val):
        self.cb.after_run(rct, val)
23 |
class Infer2Hook(tf.train.SessionRunHook):
    """Run an inferencer's fetches on every session step."""

    def __init__(self, inferencer):
        # to be modified
        assert_type(inferencer, InferencerBase)
        self.inferencer = inferencer

    def before_run(self, rct):
        fetches = self.inferencer.put_fetch()
        return tf.train.SessionRunArgs(fetches=fetches)

    def after_run(self, rct, val):
        self.inferencer.get_fetch(val)
36 |
class Prediction2Hook(tf.train.SessionRunHook):
    """Run a prediction's fetches each step and hand results back to it."""

    def __init__(self, prediction):
        assert_type(prediction, PredictionBase)
        self.prediction = prediction

    def before_run(self, rct):
        fetches = self.prediction.get_predictions()
        return tf.train.SessionRunArgs(fetches=fetches)

    def after_run(self, rct, val):
        self.prediction.after_prediction(val.results)
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/tensorcv/callbacks/inference.py:
--------------------------------------------------------------------------------
1 | # File: inference.py
2 | # Author: Qian Ge
3 |
4 | import scipy.misc
5 | import os
6 | from abc import ABCMeta
7 |
8 | import numpy as np
9 | import tensorflow as tf
10 |
11 | from .base import Callback
12 | from .group import Callbacks
13 | from .inputs import FeedInput
14 | from ..dataflow.base import DataFlow
15 | from ..dataflow.randoms import RandomVec
16 | from .hooks import Callback2Hook, Infer2Hook
17 | from ..utils.sesscreate import ReuseSessionCreator
18 | from .inferencer import InferencerBase
19 |
20 | __all__ = ['FeedInference', 'GANInference', 'FeedInferenceBatch']
21 |
def assert_type(v, tp):
    """Assert `v` is an instance of `tp`, with a descriptive message."""
    assert isinstance(v, tp), \
        f"Expect {tp}, but {v.__class__} is given!"
25 |
class InferenceBase(Callback):
    """Base class for running inference periodically during training."""

    def __init__(self, inputs=None, periodic=1,
                 inferencers=None, extra_cbs=None,
                 infer_batch_size=None):
        """
        Args:
            inputs (DataFlow): input dataflow used at inference time
            periodic (int): run inference every `periodic` global steps
            inferencers (InferencerBase or list[InferencerBase])
            extra_cbs (list[Callback])
            infer_batch_size (int): batch size used for inference;
                defaults to the training batch size
        """
        self._inputs = inputs
        self._periodic = periodic
        self._infer_batch_size = infer_batch_size

        assert inferencers is not None or extra_cbs is not None,\
            "Inferencers and extra_cbs cannot be both None!"

        # fix: a bare None used to be wrapped into [None] and crash the
        # assert_type check below when only extra_cbs was provided
        if inferencers is None:
            inferencers = []
        elif not isinstance(inferencers, list):
            inferencers = [inferencers]
        for infer in inferencers:
            assert_type(infer, InferencerBase)
        self._inference_list = inferencers

        if extra_cbs is None:
            self._extra_cbs = []
        elif not isinstance(extra_cbs, list):
            self._extra_cbs = [extra_cbs]
        else:
            self._extra_cbs = extra_cbs

        self._cbs = []

    def _setup_graph(self):
        self.model = self.trainer.model
        self.setup_inference()
        self.register_cbs()
        self._cbs = Callbacks(self._cbs)
        self._cbs.setup_graph(self.trainer)

        for infer in self._inference_list:
            infer.setup_inferencer()

    def setup_inference(self):
        """Set up inferencers, extra callbacks and the inference batch size."""
        self._setup_inference()

        for infer in self._inference_list:
            assert_type(infer, InferencerBase)
            infer.setup_graph(self.trainer)

        # fix: guard against inputs=None (e.g. GANInference without a
        # dataflow) which used to raise AttributeError here
        if self._inputs is not None:
            if self._infer_batch_size is None:
                self._inputs.set_batch_size(self.trainer.config.batch_size)
            else:
                self._inputs.set_batch_size(self._infer_batch_size)

    def _setup_inference(self):
        """ setup extra default callbacks for inference """
        pass

    def register_cbs(self):
        for cb in self._extra_cbs:
            assert_type(cb, Callback)
            self._cbs.append(cb)

    def get_infer_hooks(self):
        return (self._cbs.get_hooks()
                + [Infer2Hook(infer) for infer in self._inference_list])

    def _create_infer_sess(self):
        # reuse the training session inside a hooked MonitoredSession
        self.sess = self.trainer.sess
        infer_hooks = self.get_infer_hooks()
        self.hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess),
            hooks=infer_hooks)

    def _trigger_step(self):
        if self.global_step % self._periodic == 0:
            for infer in self._inference_list:
                infer.before_inference()

            self._create_infer_sess()
            self.inference_step()

            for infer in self._inference_list:
                infer.after_inference()

    def inference_step(self):
        # TODO to be modified
        self.model.set_is_training(False)
        self._cbs.before_inference()
        self._inference_step()

    def _inference_step(self, extra_feed=None):
        # fix: inference_step() calls this with no argument, which used to
        # raise TypeError in the base implementation; default to empty feed
        if extra_feed is None:
            extra_feed = {}
        self.hooked_sess.run(fetches=[], feed_dict=extra_feed)

    def _after_train(self):
        self._cbs.after_train()
121 |
122 |
class FeedInference(InferenceBase):
    """Inference fed by a DataFlow through placeholders.

    Runs through one full epoch of `inputs` at each trigger.
    """
    def __init__(self, inputs, periodic=1,
                 inferencers=None, extra_cbs=None,
                 infer_batch_size=None):
        """
        Args:
            inputs (DataFlow): validation dataflow
            periodic (int): run inference every `periodic` global steps
            inferencers (list[InferencerBase])
            extra_cbs (list[Callback])
            infer_batch_size (int): batch size for inference
        """
        assert_type(inputs, DataFlow)

        # fix: `inferencers=[]` was a mutable default argument shared
        # across all instances
        if inferencers is None:
            inferencers = []
        super(FeedInference, self).__init__(inputs=inputs,
                                            periodic=periodic,
                                            inferencers=inferencers,
                                            extra_cbs=extra_cbs,
                                            infer_batch_size=infer_batch_size)

    def _setup_inference(self):
        # feed the validation data into the training placeholders
        placeholders = self.model.get_train_placeholder()
        self._extra_cbs.append(FeedInput(self._inputs, placeholders))

    def _inference_step(self):
        model_feed = self.model.get_graph_feed()
        # run until the dataflow completes one epoch, then reset the counter
        while self._inputs.epochs_completed <= 0:
            self.hooked_sess.run(fetches=[], feed_dict=model_feed)
        self._inputs.reset_epochs_completed(0)
149 |
class FeedInferenceBatch(FeedInference):
    """FeedInference that runs only a fixed number of batches instead of
    a full epoch of validation data."""

    def __init__(self, inputs, periodic=1,
                 batch_count=10,
                 inferencers=None, extra_cbs=None,
                 infer_batch_size=None):
        """
        Args:
            batch_count (int): number of batches to run per inference
            (other args: see FeedInference)
        """
        self._batch_count = batch_count
        # fix: `inferencers=[]` was a mutable default argument shared
        # across all instances
        if inferencers is None:
            inferencers = []
        super(FeedInferenceBatch, self).__init__(inputs=inputs,
                                                 periodic=periodic,
                                                 inferencers=inferencers,
                                                 extra_cbs=extra_cbs,
                                                 infer_batch_size=infer_batch_size)

    def _inference_step(self):
        model_feed = self.model.get_graph_feed()
        for i in range(self._batch_count):
            self.hooked_sess.run(fetches=[], feed_dict=model_feed)
166 |
167 |
class GANInference(InferenceBase):
    """Inference for GANs, optionally driven by a random input vector."""

    def __init__(self, inputs=None, periodic=1,
                 inferencers=None, extra_cbs=None):
        """
        Args:
            inputs (RandomVec): optional random-vector dataflow
            periodic (int): run inference every `periodic` global steps
            inferencers (list[InferencerBase])
            extra_cbs (list[Callback])
        """
        if inputs is not None:
            assert_type(inputs, RandomVec)
        super(GANInference, self).__init__(inputs=inputs,
                                           periodic=periodic,
                                           inferencers=inferencers,
                                           extra_cbs=extra_cbs)

    def _setup_inference(self):
        # only wire up a FeedInput when a random-vector dataflow is given
        if self._inputs is None:
            return
        self._inputs.set_batch_size(self.trainer.config.batch_size)
        rand_vec_phs = self.model.get_random_vec_placeholder()
        self._extra_cbs.append(FeedInput(self._inputs, rand_vec_phs))

    def _inference_step(self):
        # when a dataflow is attached, FeedInput supplies the feed dict;
        # otherwise fall back to the model's own graph feed
        if self._inputs is None:
            model_feed = self.model.get_graph_feed()
        else:
            model_feed = {}
        self.hooked_sess.run(fetches=[], feed_dict=model_feed)
192 |
193 |
--------------------------------------------------------------------------------
/tensorcv/callbacks/inferencer.py:
--------------------------------------------------------------------------------
1 | # File: inference.py
2 | # Author: Qian Ge
3 |
4 | import os
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | from .base import Callback
10 | from ..utils.common import get_tensors_by_names, check_dir, match_tensor_save_name
11 | from ..utils.viz import *
12 |
13 | __all__ = ['InferencerBase', 'InferImages', 'InferScalars', 'InferOverlay', 'InferMat']
14 |
class InferencerBase(Callback):
    """Base class for inferencers.

    Subclasses set self._names (tensor names to fetch) and override the
    `_`-prefixed hooks to collect and post-process the fetched values.
    """

    def setup_inferencer(self):
        # resolve the tensor names stored by the subclass into tensors
        if not isinstance(self._names, list):
            self._names = [self._names]
        self._names = get_tensors_by_names(self._names)
        self._setup_inference(self.trainer.default_dirs)

    def _setup_inference(self, default_dirs=None):
        # subclass hook, e.g. for preparing output directories
        pass

    def put_fetch(self):
        return self._put_fetch()

    def _put_fetch(self):
        # default: fetch the resolved tensors
        return self._names

    def get_fetch(self, val):
        self._get_fetch(val)

    def _get_fetch(self, val):
        pass

    def before_inference(self):
        """ process before every inference """
        self._before_inference()

    def _before_inference(self):
        pass

    def after_inference(self):
        # called once per inference pass, after all batches are fetched
        self._after_inference()

    def _after_inference(self):
        return None
57 |
58 |
59 |
class InferImages(InferencerBase):
    """Save image tensors produced during inference as merged grid images."""

    def __init__(self, im_name, prefix=None, color=False, tanh=False):
        """
        Args:
            im_name (str or list[str]): image tensor name(s) to fetch
            prefix (str): prefix of the saved file names
            color (bool): whether to apply a color map when saving
            tanh (bool): whether images are in tanh range and need rescaling
        """
        self._names, self._prefix = match_tensor_save_name(im_name, prefix)
        self._color = color
        self._tanh = tanh

    def _setup_inference(self, default_dirs=None):
        try:
            self._save_dir = os.path.join(self.trainer.default_dirs.infer_dir)
            check_dir(self._save_dir)
        except AttributeError:
            # fix: the message used to read
            # 'summary_dir is not set in infer_dir.py!'
            raise AttributeError('infer_dir is not set in config.py!')

    def _before_inference(self):
        # one entry per fetched batch
        self._result_list = []

    def _get_fetch(self, val):
        self._result_list.append(val.results)

    def _after_inference(self):
        # TODO add process_image to monitors
        batch_size = len(self._result_list[0][0])
        grid_size = self._get_grid_size(batch_size)
        local_step = 0
        for result_im in self._result_list:
            for im, save_name in zip(result_im, self._prefix):
                # NOTE(review): the path is built by concatenation, so
                # infer_dir is assumed to end with a path separator
                save_merge_images(im, [grid_size, grid_size],
                    self._save_dir + save_name + '_step_' + str(self.global_step) +\
                    '_b_' + str(local_step) + '.png',
                    color=self._color,
                    tanh=self._tanh)
                local_step += 1
        return None

    def _get_grid_size(self, batch_size):
        """Smallest square grid that fits `batch_size` images (cached)."""
        try:
            return self._grid_size
        except AttributeError:
            self._grid_size = np.ceil(batch_size**0.5).astype(int)
            return self._grid_size
103 |
class InferOverlay(InferImages):
    """Overlay two fetched image tensors and save the merged grids."""

    def __init__(self, im_name, prefix=None, color=False, tanh=False):
        """
        Args:
            im_name (list[str]): exactly two image tensor names
            prefix, color, tanh: see InferImages
        """
        if not isinstance(im_name, list):
            im_name = [im_name]
        assert len(im_name) == 2,\
            '[InferOverlay] requires two image tensors but the input len = {}.'.\
            format(len(im_name))
        super(InferOverlay, self).__init__(im_name=im_name,
                                           prefix=prefix,
                                           color=color,
                                           tanh=tanh)
        self._overlay_prefix = '{}_{}'.format(self._prefix[0], self._prefix[1])

    def _after_inference(self):
        # TODO add process_image to monitors
        n_im = len(self._result_list[0][0])
        grid = self._get_grid_size(n_im)
        for local_step, result_im in enumerate(self._result_list):
            # pairwise overlay of the two fetched tensors
            overlay_im_list = [image_overlay(im_a, im_b, color=self._color)
                               for im_a, im_b in zip(result_im[0], result_im[1])]
            out_path = (self._save_dir + self._overlay_prefix + '_step_' +
                        str(self.global_step) + '_b_' + str(local_step) + '.png')
            save_merge_images(np.squeeze(overlay_im_list), [grid, grid],
                              out_path, color=False, tanh=self._tanh)
        return None
135 |
class InferMat(InferImages):
    """Save fetched tensors to .mat files (one file per inference batch)."""

    def __init__(self, infer_save_name, mat_name, prefix=None):
        """
        Args:
            infer_save_name (str): base name of the saved .mat files
            mat_name (str or list[str]): tensor name(s) to save
            prefix (str): variable-name prefix inside the .mat file
        """
        self._infer_save_name = str(infer_save_name)
        super(InferMat, self).__init__(im_name=mat_name, prefix=prefix,
                                       color=False, tanh=False)

    def _after_inference(self):
        # fix: this module never imported scipy.io (only other modules
        # import scipy.misc), so scipy.io was an unresolved name at runtime
        import scipy.io
        for idx, batch_result in enumerate(self._result_list):
            save_path = os.path.join(
                self._save_dir,
                '{}_step_{}_b_{}.mat'.format(self._infer_save_name,
                                             self.global_step, idx))
            scipy.io.savemat(save_path,
                             {name: np.squeeze(val) for name, val
                              in zip(self._prefix, batch_result)})
        return None
149 |
class InferScalars(InferencerBase):
    """Average scalar tensors over an inference pass and write them as
    summaries."""

    def __init__(self, scaler_names, summary_names=None):
        """
        Args:
            scaler_names (str or list[str]): scalar tensor names to fetch
            summary_names (str or list[str]): summary tags; defaults to
                scaler_names and must match its length if given
        """
        if not isinstance(scaler_names, list):
            scaler_names = [scaler_names]
        self._names = scaler_names
        if summary_names is None:
            self._summary_names = scaler_names
        else:
            if not isinstance(summary_names, list):
                summary_names = [summary_names]
            assert len(self._names) == len(summary_names), \
                "length of scaler_names and summary_names has to be the same!"
            self._summary_names = summary_names

    def _before_inference(self):
        # one accumulator list per fetched tensor
        self.result_list = [[] for _ in range(len(self._names))]

    def _get_fetch(self, val):
        for i, v in enumerate(val.results):
            self.result_list[i].append(v)

    def _after_inference(self):
        """ average collected values and send them to the monitors """
        summary_dict = {name: np.mean(val) for name, val
                        in zip(self._summary_names, self.result_list)}
        # fix: removed an `if summary_dict is not None` guard that was
        # always true (a dict literal is never None)
        for key, val in summary_dict.items():
            s = tf.Summary()
            s.value.add(tag=key, simple_value=val)
            self.trainer.monitors.process_summary(s)
            print('[infer] ' + key + ': ' + str(val))
183 |
184 | # TODO to be modified
185 | # class BinaryClassificationStats(InferencerBase):
186 | # def __init__(self, accuracy):
187 | # self._names = accuracy
188 |
189 | # def _before_inference(self):
190 | # self.result_list = []
191 |
192 | # def _put_fetch(self):
193 | # # fetch_list = self.names
194 | # return self._names
195 |
196 | # def _get_fetch(self, val):
197 | # self.result_list += val.results,
198 |
199 | # def _after_inference(self):
200 | # """ process after get_fetch """
201 | # return {"test_accuracy": np.mean(self.result_list)}
202 |
if __name__ == '__main__':
    # NOTE(review): `InferGANGenerator` is not defined anywhere in this
    # module, so this ad-hoc manual test raises NameError if executed —
    # it appears to be leftover from a removed class.
    t = InferGANGenerator('gen_name',
        save_dir = 'D:\\Qian\\GitHub\\workspace\\test\\result\\', prefix = 1)
    print(t._prefix)
207 |
208 |
209 |
210 |
--------------------------------------------------------------------------------
/tensorcv/callbacks/inputs.py:
--------------------------------------------------------------------------------
1 | # File: input.py
2 | # Author: Qian Ge
3 |
4 | import scipy.misc
5 | import os
6 |
7 | import tensorflow as tf
8 |
9 | from .base import Callback
10 | from ..dataflow.base import DataFlow
11 |
12 | __all__ = ['FeedInput']
13 |
def assert_type(v, tp):
    """Ensure `v` is an instance of `tp`, raising AssertionError otherwise."""
    is_ok = isinstance(v, tp)
    assert is_ok, "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"
17 |
class FeedInput(Callback):
    """Feed batches from a DataFlow into placeholders at every run step."""

    def __init__(self, dataflow, placeholders):
        """
        Args:
            dataflow (DataFlow): source of input batches
            placeholders: a placeholder or list of placeholders to feed
        """
        assert_type(dataflow, DataFlow)
        self.dataflow = dataflow

        # fix: removed a leftover debug `print(type(placeholders))`
        if not isinstance(placeholders, list):
            placeholders = [placeholders]
        self.placeholders = placeholders

    def _setup_graph(self):
        pass

    def _before_train(self):
        self.dataflow.before_read_setup()

    def _before_inference(self):
        # inference uses the same read setup as training
        self._before_train()

    def _before_run(self, _):
        cur_batch = self.dataflow.next_batch()
        # pair the i-th batch component with the i-th placeholder
        feed = dict(zip(self.placeholders, cur_batch))
        return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)

    def _after_train(self):
        self.dataflow.after_reading()
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
--------------------------------------------------------------------------------
/tensorcv/callbacks/monitors.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import tensorflow as tf
4 |
5 | from .base import Callback
6 | from ..utils.common import check_dir
7 |
8 | __all__ = ['TrainingMonitor','Monitors','TFSummaryWriter']
9 |
def assert_type(v, tp):
    """Raise AssertionError when ``v`` is not an instance of ``tp``."""
    message = "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"
    assert isinstance(v, tp), message
13 |
class TrainingMonitor(Callback):
    """Base class for callbacks that consume serialized training summaries.

    ``process_summary`` is the public entry point; subclasses override the
    ``_process_summary`` hook.
    """
    def _setup_graph(self):
        pass

    def process_summary(self, summary):
        # Template method: delegate to the subclass hook.
        self._process_summary(summary)

    def _process_summary(self, summary):
        # Default: ignore the summary.
        pass
23 |
class Monitors(TrainingMonitor):
    """Fan each summary out to a group of monitors."""

    def __init__(self, mons):
        """
        Args:
            mons (list[TrainingMonitor]): monitors to dispatch to.
        """
        for monitor in mons:
            assert_type(monitor, TrainingMonitor)
        self.mons = mons

    def _process_summary(self, summary):
        # Forward through the public entry point of every monitor.
        for monitor in self.mons:
            monitor.process_summary(summary)
34 |
class TFSummaryWriter(TrainingMonitor):
    """Write summaries and the graph to TensorBoard event files."""

    def _setup_graph(self):
        # default_dirs.summary_dir must be configured; surface a clear
        # message instead of a bare AttributeError when it is missing.
        try:
            summary_dir = os.path.join(self.trainer.default_dirs.summary_dir)
            check_dir(summary_dir)
        except AttributeError:
            raise AttributeError('summary_dir is not set in config.py!')
        self._writer = tf.summary.FileWriter(summary_dir)

    def _before_train(self):
        # Always write the graph once at the start of training.
        self._writer.add_graph(self.trainer.sess.graph)

    def _after_train(self):
        self._writer.close()

    def _process_summary(self, summary):
        # Consistency fix: implement the TrainingMonitor hook instead of
        # overriding the public process_summary() directly; the base
        # class's process_summary() delegates here, so callers (e.g.
        # Monitors) behave the same.
        self._writer.add_summary(summary, self.global_step)
54 |
--------------------------------------------------------------------------------
/tensorcv/callbacks/saver.py:
--------------------------------------------------------------------------------
1 | import scipy.misc
2 | import os
3 | import os
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from .base import Callback
9 | from ..utils.common import check_dir
10 |
11 | __all__ = ['ModelSaver']
12 |
class ModelSaver(Callback):
    """Periodically save the model to the configured checkpoint directory.

    Args:
        max_to_keep (int): maximum number of recent checkpoints to keep.
        keep_checkpoint_every_n_hours (float): additionally keep one
            checkpoint every n hours.
        periodic (int): save every ``periodic`` global steps.
        checkpoint_dir (str): unused here; the directory is read from
            ``trainer.default_dirs.checkpoint_dir`` at setup time.
        var_collections: variable collection(s).
            NOTE(review): stored but not passed to tf.train.Saver, so all
            saveable objects are saved — confirm whether filtering by
            collection was intended.
    """
    def __init__(self, max_to_keep=5,
                 keep_checkpoint_every_n_hours=0.5,
                 periodic=1,
                 checkpoint_dir=None,
                 var_collections=tf.GraphKeys.GLOBAL_VARIABLES):

        self._periodic = periodic

        self._max_to_keep = max_to_keep
        self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours

        if not isinstance(var_collections, list):
            var_collections = [var_collections]
        self.var_collections = var_collections

    def _setup_graph(self):
        try:
            checkpoint_dir = os.path.join(self.trainer.default_dirs.checkpoint_dir)
            check_dir(checkpoint_dir)
        except AttributeError:
            raise AttributeError('checkpoint_dir is not set in config_path!')

        self._save_path = os.path.join(checkpoint_dir, 'model')
        # Fix: max_to_keep and keep_checkpoint_every_n_hours were accepted
        # by __init__ but never forwarded to the Saver.
        self._saver = tf.train.Saver(
            max_to_keep=self._max_to_keep,
            keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours)

    def _trigger_step(self):
        # Save on every `periodic`-th global step.
        if self.global_step % self._periodic == 0:
            self._saver.save(tf.get_default_session(), self._save_path,
                             global_step=self.global_step)
43 |
44 |
45 |
46 |
--------------------------------------------------------------------------------
/tensorcv/callbacks/summary.py:
--------------------------------------------------------------------------------
1 | import scipy.misc
2 | import os
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from .base import Callback
8 |
9 | __all__ = ['TrainSummary']
10 |
class TrainSummary(Callback):
    """Fetch merged summaries periodically and forward them to monitors."""

    def __init__(self,
                 key=None,
                 periodic=1):
        """
        Args:
            key (str or list[str]): summary collection key(s) to merge.
                If None, the default summary collection is merged.
            periodic (int): fetch summaries every ``periodic`` steps.
        """
        self.periodic = periodic
        if key is not None and not isinstance(key, list):
            key = [key]
        self._key = key

    def _setup_graph(self):
        # Fix: the original iterated self._key unconditionally and crashed
        # with TypeError when key was left at its default of None.
        if self._key is None:
            self.summary_list = tf.summary.merge_all()
        else:
            self.summary_list = tf.summary.merge(
                [tf.summary.merge_all(k) for k in self._key])

    def _before_run(self, _):
        if self.global_step % self.periodic == 0:
            return tf.train.SessionRunArgs(fetches=self.summary_list)
        # Fix: was a bare `None` expression statement; make the no-fetch
        # path an explicit return.
        return None

    def _after_run(self, _, val):
        # val.results is None on steps where nothing was fetched.
        if val.results is not None:
            self.trainer.monitors.process_summary(val.results)
35 |
36 |
37 |
--------------------------------------------------------------------------------
/tensorcv/callbacks/trigger.py:
--------------------------------------------------------------------------------
1 | import scipy.misc
2 | import os
3 | from abc import ABCMeta
4 | import os
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | from .base import ProxyCallback, Callback
10 |
11 | __all__ = ['PeriodicTrigger']
12 |
def assert_type(v, tp):
    """Check that ``v`` is an instance of ``tp``, raising AssertionError otherwise."""
    ok = isinstance(v, tp)
    assert ok, "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"
16 |
class PeriodicTrigger(ProxyCallback):
    """Fire the wrapped callback's trigger() every k steps and/or epochs."""

    def __init__(self, trigger_cb, every_k_steps=None, every_k_epochs=None):
        """
        Args:
            trigger_cb (Callback): callback whose trigger() is invoked.
            every_k_steps (int): fire every k global steps (None disables).
            every_k_epochs (int): fire every k epochs (None disables).
        """
        assert_type(trigger_cb, Callback)
        super(PeriodicTrigger, self).__init__(trigger_cb)

        assert (every_k_steps is not None) or (every_k_epochs is not None), \
            "every_k_steps and every_k_epochs cannot be both None!"
        self._k_step = every_k_steps
        self._k_epoch = every_k_epochs

    def __str__(self):
        return 'PeriodicTrigger' + str(self.cb)

    def _trigger_step(self):
        # Fire only when step-based triggering is enabled and due.
        if self._k_step is not None and self.global_step % self._k_step == 0:
            self.cb.trigger()

    def _trigger_epoch(self):
        # Fire only when epoch-based triggering is enabled and due.
        if self._k_epoch is not None and self.epochs_completed % self._k_epoch == 0:
            self.cb.trigger()
43 |
--------------------------------------------------------------------------------
/tensorcv/dataflow/__init__.py:
--------------------------------------------------------------------------------
1 | # File: __init__.py
2 | # Author: Qian Ge
3 |
4 | from .base import *
5 | from .image import *
6 | from .matlab import *
7 | from .randoms import *
8 | # from .dataset import *
9 | from .normalization import *
10 |
--------------------------------------------------------------------------------
/tensorcv/dataflow/argument.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: argument.py
4 | # Author: Qian Ge
5 |
6 | from .base import DataFlow
7 | from ..utils.utils import assert_type
8 |
9 |
class ArgumentDataflow(DataFlow):
    """Wrap a DataFlow and apply per-key transform functions to its batches.

    All bookkeeping (epochs, batch size, state) is delegated to the
    wrapped dataflow; only ``next_batch_dict`` output is transformed.
    """

    def __init__(self, dataflow, argument_order, argument_fnc):
        """
        Args:
            dataflow (DataFlow): the wrapped dataflow.
            argument_order: key or list of keys of the batch dict to transform.
            argument_fnc: function or list of functions, one per key.
        """
        assert_type(dataflow, DataFlow)
        if not isinstance(argument_order, list):
            argument_order = [argument_order]
        if not isinstance(argument_fnc, list):
            argument_fnc = [argument_fnc]

        self._dataflow = dataflow
        self._order = argument_order
        self._fnc = argument_fnc

    def setup(self, epoch_val, batch_size, **kwargs):
        self._dataflow.setup(epoch_val, batch_size, **kwargs)

    @property
    def epochs_completed(self):
        return self._dataflow.epochs_completed

    def reset_epochs_completed(self, val):
        self._dataflow.reset_epochs_completed(val)

    def set_batch_size(self, batch_size):
        self._dataflow.set_batch_size(batch_size)

    def size(self):
        return self._dataflow.size()

    def reset_state(self):
        self._dataflow.reset_state()

    def after_reading(self):
        self._dataflow.after_reading()

    def next_batch_dict(self):
        # Transform each configured key of the underlying batch dict.
        raw = self._dataflow.next_batch_dict()
        return {key: fnc(raw[key]) for key, fnc in zip(self._order, self._fnc)}

    def next_batch(self):
        print("***** [ArgumentDataflow.next_batch()] not implmented *****")
55 |
--------------------------------------------------------------------------------
/tensorcv/dataflow/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod, ABCMeta
2 | import numpy as np
3 |
4 | from ..utils.utils import get_rng
5 |
6 | __all__ = ['DataFlow', 'RNGDataFlow']
7 |
8 | # @six.add_metaclass(ABCMeta)
class DataFlow(object):
    """Base class for dataflows that produce batches of data.

    Subclasses implement ``next_batch()`` (and usually ``size()``).
    State managed here: ``_epochs_completed`` and ``_batch_size``.
    Cleanup: removed large blocks of commented-out dead code.
    """

    def before_read_setup(self, **kwargs):
        """Hook called once before reading starts (e.g. open readers)."""
        pass

    def setup(self, epoch_val, batch_size, **kwargs):
        """Initialize epoch counter, batch size and internal state."""
        self.reset_epochs_completed(epoch_val)
        self.set_batch_size(batch_size)
        self.reset_state()
        self._setup()

    def _setup(self, **kwargs):
        # Subclass hook for extra setup.
        pass

    def _get_im_size(self):
        # Default image-size placeholder kept for subclasses.
        return 0

    @property
    def epochs_completed(self):
        """Number of completed passes over the data."""
        return self._epochs_completed

    def reset_epochs_completed(self, val):
        self._epochs_completed = val

    # NOTE: the class does not use the ABCMeta metaclass, so this
    # @abstractmethod does not prevent instantiation; it only marks intent.
    @abstractmethod
    def next_batch(self):
        return

    def next_batch_dict(self):
        # Best-effort notice kept (project convention prints rather than raises).
        print('Need to be implemented!')

    def set_batch_size(self, batch_size):
        self._batch_size = batch_size

    def size(self):
        """Total number of samples; must be provided by subclasses."""
        raise NotImplementedError()

    def reset_state(self):
        self._reset_state()

    def _reset_state(self):
        # Subclass hook (e.g. RNG seeding).
        pass

    def after_reading(self):
        """Hook called after reading finishes (e.g. close readers)."""
        pass
75 |
class RNGDataFlow(DataFlow):
    """DataFlow with a per-instance RNG used for shuffling file lists."""

    def _reset_state(self):
        # Re-seed the RNG whenever state is reset.
        self.rng = get_rng(self)

    def _suffle_file_list(self):
        # Shuffle self.file_list in place via a permutation of indices.
        perm = np.arange(self.size())
        self.rng.shuffle(perm)
        self.file_list = self.file_list[perm]

    def suffle_data(self):
        self._suffle_file_list()
87 |
--------------------------------------------------------------------------------
/tensorcv/dataflow/bk/image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import collections
3 | import scipy.io
4 | import scipy.misc
5 | import os
6 | import pickle
7 | import random
8 |
9 | import tensorflow as tf
10 |
class CIFAT10(object):
    """CIFAR-10 reader over the five pickled 'data_batch_*' files.

    NOTE: the (misspelled) class name is kept, since external code may
    reference it.
    """
    def __init__(self, file_path):
        # The five training batch files shipped with CIFAR-10.
        self._file_list = [os.path.join(file_path, 'data_batch_' + str(idx))
                           for idx in range(1, 6)]
        self._num_channels = 3
        self._batch_id = 0
        self._batch_file_id = -1
        self._image = []
        self._epochs_completed = 0

    def next_batch_file(self):
        # Advance to the next pickle file; wrapping counts one epoch.
        last = len(self._file_list) - 1
        if self._batch_file_id >= last:
            self._batch_file_id = 0
            self._epochs_completed += 1
        else:
            self._batch_file_id += 1
        self._image = unpickle(self._file_list[self._batch_file_id])
        random.shuffle(self._image)

    def next_batch(self, batch_size):
        """Return the next batch_size images, loading a new file if needed."""
        end = self._batch_id + batch_size
        if end >= len(self._image):
            # Current file exhausted: load the next one and restart.
            self.next_batch_file()
            self._batch_id = 0
            end = batch_size
        batch_image = self._image[self._batch_id:end]
        self._batch_id = end
        return batch_image

    @property
    def epochs_completed(self):
        return self._epochs_completed
43 |
44 |
class TestImage(object):
    """Sliding-window patch iterator over test .mat images.

    Call next_image() to load a file and initialize the scan position,
    then call next_patch() repeatedly; both return None when exhausted.
    """
    def __init__(self, test_file_path, patch_size, sample_mean = 0, num_channels = 1):
        # sample_mean is indexed per channel in next_image(), so it is
        # assumed array-like of length num_channels; the default scalar 0
        # would fail there -- TODO confirm.
        self.patch_size = patch_size
        self.num_channels = num_channels
        self._sample_mean = sample_mean

        # All .mat files in the test directory.
        self._file_list = ([os.path.join(test_file_path, file)
                    for file in os.listdir(test_file_path)
                    if file.endswith(".mat")])

        self._image_id = -1
        self._image = None
        self._cnt_row = -1

        # Reusable buffer holding the patch returned by next_patch().
        self.cur_patch = np.empty(shape=[1, self.patch_size, self.patch_size, self.num_channels])

    def next_image(self):
        """Load the next image, reset the scan position; None when done.

        NOTE(review): load_image() returns shape [1, h, w, c], so
        im_rows/im_cols below actually read the batch dim and the height
        dim -- looks suspicious, confirm against the indexing in
        next_patch() before relying on this class.
        """
        if self._image_id >= len(self._file_list) - 1:
            return None
        self._image_id += 1
        self._image = load_image(self._file_list[self._image_id], self.num_channels)

        self.im_rows, self.im_cols = self._image.shape[0], self._image.shape[1]
        self._cnt_row = 1

        # Start position: window just inside the top-left border.
        half_patch_size = int(self.patch_size/2)
        self.row_id, self.col_id = half_patch_size - 1, half_patch_size - 2
        self.patch_row_start, self.patch_row_end = self.row_id - half_patch_size + 1, self.row_id + half_patch_size
        self.patch_col_start, self.patch_col_end = self.col_id - half_patch_size + 1, self.col_id + half_patch_size

        # Mean-center each channel in place.
        for cur_channel in range(0, self.num_channels):
            self._image[:,:,:,cur_channel] -= self._sample_mean[cur_channel]

        return self._image

    def next_patch(self):
        """Return the next patch of the scan, or None past the last row.

        Advances one column per call and wraps to the next row at the
        right border; requires next_image() to have been called.
        """
        if self._image is None:
            return None
        half_patch_size = int(self.patch_size/2)
        if self.col_id >= self.im_cols - half_patch_size - 1:
            # Right border reached: move one row down, reset columns.
            self.row_id += 1
            self.patch_row_start += 1
            self.patch_row_end += 1
            print('Row: ' + str(self.row_id))
            self._cnt_row += 1
            self.col_id = half_patch_size - 1
            self.patch_col_start, self.patch_col_end = self.col_id - half_patch_size + 1, self.col_id + half_patch_size
        else:
            self.col_id += 1
            self.patch_col_start += 1
            self.patch_col_end += 1

        # Past the bottom border: scan of this image finished.
        if self.row_id >= self.im_rows - half_patch_size:
            return None
        # Copy the current window into the reusable buffer, transposed
        # per channel.
        if self.num_channels > 1:
            cnt_depth = 0
            for channel_id in range(0, self.num_channels):
                self.cur_patch[:,:,:,cnt_depth] = self._image[self.patch_row_start:self.patch_row_end + 1,
                    self.patch_col_start:self.patch_col_end + 1, channel_id].transpose()
                cnt_depth += 1
        else:
            self.cur_patch[:,:,:,:] = np.reshape(self._image[self.patch_row_start:self.patch_row_end + 1,
                self.patch_col_start:self.patch_col_end + 1].transpose(),
                [1, self.patch_size, self.patch_size, self.num_channels])
        return self.cur_patch

    def reset_cnt_row(self):
        # Reset the row counter exposed by get_cnt_row().
        self._cnt_row = 1
    def get_cnt_row(self):
        return self._cnt_row
116 |
class TrainData(object):
    """Training-set reader over a numpy array of .mat file paths.

    Supports per-image iteration (next_image) and batched loading
    (next_batch); reshuffles the file list at each epoch boundary.
    """
    def __init__(self, file_list, sample_mean = 0, num_channels = 1):
        # file_list must be a numpy array: it is shuffled via fancy
        # indexing (file_list[perm]), which a plain list does not support.
        self._num_channels = num_channels

        self._file_list = file_list
        self._num_image = len(self._file_list)

        self._image_id = 0
        self._image = []
        self._label = []
        self._mask = []
        self._epochs_completed = 0

        # NOTE(review): indexed per channel in next_image(); the default
        # scalar 0 would raise TypeError there -- assumed array-like of
        # length num_channels, confirm.
        self._sample_mean = sample_mean

    @property
    def epochs_completed(self):
        # Number of complete passes over the file list.
        return self._epochs_completed

    def set_epochs_completed(self, value):
        self._epochs_completed = value

    @property
    def sample_mean(self):
        return self._sample_mean

    def next_image(self):
        """Load the next (image, label, mask), mean-centered per channel."""
        if self._image_id >= self._num_image - 1:
            self._epochs_completed += 1
            # NOTE(review): resets to 1 rather than 0, so the file at
            # index 0 is skipped after each wrap -- possible off-by-one,
            # confirm before changing.
            self._image_id = 1
            perm = np.arange(self._num_image)
            np.random.shuffle(perm)
            self._file_list = self._file_list[perm]
        else:
            self._image_id += 1

        self._image, self._label, self._mask = load_training_image(self._file_list[self._image_id], num_channels = self._num_channels)
        for cur_channel in range(0, self._num_channels):
            self._image[:,:,:,cur_channel] -= self._sample_mean[cur_channel]
        return self._image, self._label, self._mask

    def next_batch(self, batch_size):
        """Load a batch of images; returns None when batch_size > dataset size."""
        if batch_size > self._num_image:
            return None
        start = self._image_id
        self._image_id += batch_size
        if self._image_id >= self._num_image:
            # Epoch boundary: reshuffle and restart from the beginning.
            self._epochs_completed += 1
            perm = np.arange(self._num_image)
            np.random.shuffle(perm)
            self._file_list = self._file_list[perm]
            start = 0
            self._image_id = batch_size
        end = self._image_id
        batch_file_path = self._file_list[start:end]
        return load_batch_image(batch_file_path, num_channels = self._num_channels)
173 |
174 |
175 |
def prepare_data_set(file_path, valid_percentage, num_channels = 1, isSubstractMean = True):
    """Split the .mat files under file_path into train/validation sets.

    Optionally computes the masked per-channel mean over the whole set
    (isSubstractMean) and registers it as a tf.Variable 'sample_mean'.
    Returns a namedtuple with fields 'train' and 'validate'.
    """
    DataSplit = collections.namedtuple('TrainData', ['train', 'validate'])

    mat_files = [os.path.join(file_path, name)
                 for name in os.listdir(file_path)
                 if name.endswith(".mat")]
    file_list = np.array(mat_files)

    if isSubstractMean:
        sample_mean_value = average_train_data(file_list, num_channels)
    else:
        sample_mean_value = np.empty(shape=[num_channels], dtype = float)
    # Registered in the graph so the mean is saved alongside the model.
    sample_mean = tf.Variable(sample_mean_value, name = 'sample_mean')

    num_image = len(file_list)
    num_valid = int(num_image*valid_percentage)
    num_train = num_image - num_valid
    print('Number of training image: {}. Number of validation image: {}.'.format(num_train, num_valid))

    train = TrainData(file_list[:num_train], sample_mean = sample_mean_value, num_channels = num_channels)
    validate = TrainData(file_list[num_train:], sample_mean = sample_mean_value, num_channels = num_channels)
    return DataSplit(train = train, validate = validate)
194 |
def load_batch_image(batch_file_path, num_channels = 1):
    """Load the images of a list of .mat files and stack them into one array."""
    images = []
    for path in batch_file_path:
        im, _, _ = load_training_image(path, num_channels = num_channels)
        images.extend(im)
    return np.array(images)
203 |
204 |
def load_training_image(file_path, num_channels = 1):
    """Read (image, label, mask) from a training .mat file.

    The image is returned as float [1, h, w, num_channels]; label and
    mask as [1, h, w].
    """
    mat = scipy.io.loadmat(file_path)
    image = mat['level1Edge'].astype('float')
    label = mat['GT']
    mask = mat['Mask']
    # Prepend a batch dimension (and a channel dim for the image).
    h, w = image.shape[0], image.shape[1]
    image = np.reshape(image, [1, h, w, num_channels])
    label = np.reshape(label, [1, label.shape[0], label.shape[1]])
    mask = np.reshape(mask, [1, mask.shape[0], mask.shape[1]])
    return image, label, mask
216 |
def load_image(test_file_path, num_channels):
    """Load 'level1Edge' from a test .mat file as float [1, h, w, num_channels]."""
    print('Loading test file ' + test_file_path + '...')
    mat = scipy.io.loadmat(test_file_path)
    image = mat['level1Edge'].astype('float')
    print('Load successfully.')
    out_shape = [1, image.shape[0], image.shape[1], num_channels]
    return np.array(np.reshape(image, out_shape))
223 |
def average_train_data(file_list, num_channels):
    """Compute the per-channel mean over the masked region of all images.

    Fix: the accumulator was created with np.empty (uninitialized
    memory), so the result contained nondeterministic garbage; it must
    start at zero before the += accumulation.
    """
    mean_list = np.zeros(shape=[num_channels], dtype = float)
    for cur_file_path in file_list:
        image, label, mask = load_training_image(cur_file_path, num_channels = num_channels)
        for cur_channel in range(0, num_channels):
            # Mean of the channel restricted by the mask.
            mean_list[cur_channel] += np.ma.masked_array(image[:,:,:, cur_channel], mask = mask).mean()
    return mean_list/len(file_list)
231 |
232 |
def unpickle(file):
    """Load a CIFAR-10 pickle batch and return images as (N, 32, 32, 3)."""
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    flat = batch[b'data']

    # Each row is 3072 values: 1024 red, 1024 green, 1024 blue (row-major).
    plane = 32 * 32
    channels = [flat[:, i * plane:(i + 1) * plane].reshape(-1, 32, 32)
                for i in range(3)]
    return np.stack(channels, axis=-1)
244 |
245 |
--------------------------------------------------------------------------------
/tensorcv/dataflow/common.py:
--------------------------------------------------------------------------------
1 | # File: common.py
2 | # Author: Qian Ge
3 |
4 | import os
5 |
6 | from scipy import misc
7 | import numpy as np
8 |
9 | from .preprocess import resize_image_with_smallest_side, random_crop_to_size
10 | from .normalization import identity
11 |
12 |
def get_file_list(file_dir, file_ext, sub_name=None):
    """Recursively collect files under file_dir whose name ends with file_ext.

    When sub_name is given, only files whose name contains it
    (case-insensitive) are kept. Files are sorted within each directory.
    Returns a numpy array of full paths.
    """
    def _keep(name):
        if not name.lower().endswith(file_ext):
            return False
        return sub_name is None or sub_name.lower() in name.lower()

    return np.array([os.path.join(root, name)
                     for root, _, files in os.walk(file_dir)
                     for name in sorted(files) if _keep(name)])
30 |
def get_folder_list(folder_dir):
    """Return the full paths of the sub-directories of folder_dir.

    Fix: the filter previously tested the truthiness of
    os.path.join(...), which is always True for non-empty names, so
    regular files were included as well; use os.path.isdir as the
    function name clearly intends.
    """
    return np.array([os.path.join(folder_dir, folder)
                     for folder in os.listdir(folder_dir)
                     if os.path.isdir(os.path.join(folder_dir, folder))])
35 |
def get_folder_names(folder_dir):
    """Return the names of the sub-directories of folder_dir.

    Fix: same always-true os.path.join(...) filter bug as
    get_folder_list; test os.path.isdir instead.
    """
    return np.array([name for name in os.listdir(folder_dir)
                     if os.path.isdir(os.path.join(folder_dir, name))])
39 |
def input_val_range(in_mat):
    """Infer the (max, half) value scale of an input array.

    Data above 1 is treated as 0..255; non-negative data as 0..1;
    data with negatives as -1..1 (half value 0).
    """
    # TODO to be modified
    max_val = np.amax(in_mat)
    min_val = np.amin(in_mat)
    if max_val > 1:
        return 255.0, 128.0
    if min_val >= 0:
        return 1.0, 0.5
    return 1.0, 0
54 |
def tanh_normalization(data, half_in_val):
    """Map data from [0, 2*half_in_val] into the tanh range [-1, 1]."""
    centered = data * 1.0 - half_in_val
    return centered / half_in_val
57 |
58 |
def dense_to_one_hot(labels_dense, num_classes):
    """Convert class labels from scalars to one-hot vectors.

    Fix: the fancy index was wrapped in an extra list
    (``flat[[index_offset + ...]]``) — a deprecated NumPy pattern that
    raises in recent versions; index with the array directly.
    """
    num_labels = labels_dense.shape[0]
    # Offset of each row's start within the flattened one-hot matrix.
    index_offset = np.arange(num_labels) * num_classes
    labels_one_hot = np.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot
66 |
def reverse_label_dict(label_dict):
    """Return a new dict with keys and values of label_dict swapped."""
    return {value: key for key, value in label_dict.items()}
72 |
def load_image(im_path, read_channel=None, pf=identity, resize=None, resize_crop=None):
    """Load one image as a float array of shape [1, h, w, c].

    Args:
        im_path (str): path of the image file.
        read_channel: None reads the file as-is, 3 forces RGB, any other
            value reads grayscale (flatten=True).
        pf: preprocess function applied after the optional resize/crop.
        resize: deprecated (h, w) resize target; None skips resizing.
        resize_crop: deprecated smallest-side resize + random-crop size.

    NOTE(review): relies on scipy.misc.imread/imresize, which were removed
    from SciPy >= 1.2 -- confirm the pinned scipy version.
    """
    if resize is not None:
        print_warning('[load_image] resize will be unused in the future!\
            Use pf (preprocess_fnc) instead.')
    if resize_crop is not None:
        print_warning('[load_image] resize_crop will be unused in the future!\
            Use pf (preprocess_fnc) instead.')

    # im = cv2.imread(im_path, self._cv_read)
    if read_channel is None:
        im = misc.imread(im_path)
    elif read_channel == 3:
        im = misc.imread(im_path, mode='RGB')
    else:
        im = misc.imread(im_path, flatten=True)

    if len(im.shape) < 3:
        # Grayscale path: the TypeError from indexing resize=None is the
        # intended "no resize" no-op; then optional crop, preprocess, and
        # add batch + channel dims.
        try:
            im = misc.imresize(im, (resize[0], resize[1], 1))
        except TypeError:
            pass
        if resize_crop is not None:
            im = resize_image_with_smallest_side(im, resize_crop)
            im = random_crop_to_size(im, resize_crop)
        im = pf(im)
        im = np.reshape(im, [1, im.shape[0], im.shape[1], 1])
    else:
        # Multi-channel path: same pipeline, keeping the channel count.
        try:
            im = misc.imresize(im, (resize[0], resize[1], im.shape[2]))
        except TypeError:
            pass
        if resize_crop is not None:
            im = resize_image_with_smallest_side(im, resize_crop)
            im = random_crop_to_size(im, resize_crop)
        im = pf(im)
        im = np.reshape(im, [1, im.shape[0], im.shape[1], im.shape[2]])
    return im
110 |
111 |
def print_warning(warning_str):
    """Print a uniformly formatted warning line to stdout."""
    print('[**** warning ****] {}'.format(warning_str))
114 |
115 |
--------------------------------------------------------------------------------
/tensorcv/dataflow/dataset/BSDS500.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from scipy.io import loadmat
5 |
6 | from ..common import *
7 | from ..normalization import *
8 | from ..image import ImageFromFile
9 |
10 | __all__ = ['BSDS500', 'BSDS500HED']
11 |
class BSDS500(ImageFromFile):
    """BSDS500 dataset reader: images plus boundary ground truth.

    Reads .jpg images from <data_dir>/images/<name> and the matching .mat
    groundTruth files; optionally also .mat masks when is_mask=True.
    """
    def __init__(self, name,
                 data_dir='',
                 shuffle=True,
                 normalize=None,
                 is_mask=False,
                 normalize_fnc=identity,
                 resize=None):
        # name selects the dataset split.
        assert name in ['train', 'test', 'val', 'infer']
        self._load_name = name
        self._is_mask = is_mask

        super(BSDS500, self).__init__('.jpg',
                                      data_dir=data_dir,
                                      num_channel=3,
                                      shuffle=shuffle,
                                      normalize=normalize,
                                      normalize_fnc=normalize_fnc,
                                      resize=resize)

    def _load_file_list(self, _):
        # Image and ground-truth lists are built in the same sorted order,
        # so index k pairs an image with its annotation.
        im_dir = os.path.join(self.data_dir, 'images', self._load_name)
        self._im_list = get_file_list(im_dir, '.jpg')

        gt_dir = os.path.join(self.data_dir, 'groundTruth', self._load_name)
        self._gt_list = get_file_list(gt_dir, '.mat')

        # TODO may remove later
        if self._is_mask:
            mask_dir = os.path.join(self.data_dir, 'mask', self._load_name)
            self._mask_list = get_file_list(mask_dir, '.mat')

        if self._shuffle:
            self._suffle_file_list()

    def _load_data(self, start, end):
        """Load images and [0, 1]-scaled boundary maps for indices [start, end)."""
        input_im_list = []
        input_label_list = []
        if self._is_mask:
            input_mask_list = []

        for k in range(start, end):
            cur_path = self._im_list[k]
            im = load_image(cur_path, read_channel=self._read_channel,
                            resize=self._resize)
            input_im_list.extend(im)

            # Each .mat holds several annotators; sum their boundary maps
            # and rescale to [0, 1].
            gt = loadmat(self._gt_list[k])['groundTruth'][0]
            num_gt = gt.shape[0]
            gt = sum(gt[k]['Boundaries'][0][0] for k in range(num_gt))
            gt = gt.astype('float32')
            gt = gt * 1.0 / np.amax(gt)
            # gt = 1.0*gt/num_gt
            # TypeError from imresize means self._resize is None: skip.
            try:
                gt = misc.imresize(gt, (self._resize[0], self._resize[1]))
            except TypeError:
                pass
            gt = np.reshape(gt, [1, gt.shape[0], gt.shape[1]])
            input_label_list.extend(gt)

            if self._is_mask:
                # NOTE(review): `mask` is read before assignment here
                # (UnboundLocalError when is_mask=True) and reshapes `gt`
                # instead of loading self._mask_list[k] -- looks broken;
                # confirm intended behavior before using is_mask.
                mask = np.reshape(gt, [1, mask.shape[0], mask.shape[1]])
                input_mask_list.extend(mask)

        input_im_list = self._normalize_fnc(np.array(input_im_list),
                                            self._get_max_in_val(),
                                            self._get_half_in_val())

        input_label_list = np.array(input_label_list)
        if self._is_mask:
            input_mask_list = np.array(input_mask_list)
            return [input_im_list, input_label_list, input_mask_list]
        else:
            return [input_im_list, input_label_list]

    def _suffle_file_list(self):
        # Shuffle all parallel lists with the same permutation.
        idxs = np.arange(self.size())
        self.rng.shuffle(idxs)
        self._im_list = self._im_list[idxs]
        self._gt_list = self._gt_list[idxs]
        try:
            self._mask_list = self._mask_list[idxs]
        except AttributeError:
            pass
97 |
class BSDS500HED(BSDS500):
    """BSDS500 variant whose ground truth is stored as HED-style .png edge maps."""

    def _load_file_list(self, _):
        im_dir = os.path.join(self.data_dir, 'images', self._load_name)
        self._im_list = get_file_list(im_dir, '.jpg')

        # Ground truth is .png edge maps instead of the .mat annotations.
        gt_dir = os.path.join(self.data_dir, 'groundTruth', self._load_name)
        self._gt_list = get_file_list(gt_dir, '.png')

        if self._shuffle:
            self._suffle_file_list()

    def _load_data(self, start, end):
        """Load images and [0, 1]-scaled edge maps for indices [start, end)."""
        input_im_list = []
        input_label_list = []

        for k in range(start, end):
            im = load_image(self._im_list[k], read_channel=self._read_channel,
                            resize=self._resize)
            input_im_list.extend(im)

            # Read the edge map single-channel and rescale to [0, 1].
            gt = load_image(self._gt_list[k], read_channel=1,
                            resize=self._resize)
            gt = gt * 1.0 / np.amax(gt)

            # Drop the channel dim so labels are [1, h, w].
            gt = np.squeeze(gt, axis = -1)
            input_label_list.extend(gt)


        input_im_list = self._normalize_fnc(np.array(input_im_list),
                                            self._get_max_in_val(),
                                            self._get_half_in_val())

        input_label_list = np.array(input_label_list)

        return [input_im_list, input_label_list]
142 |
143 |
if __name__ == '__main__':
    # Manual smoke test: load the validation split from a local BSDS500 path.
    a = BSDS500('val','E:\\GITHUB\\workspace\\CNN\\dataset\\BSR_bsds500\\BSR\\BSDS500\\data\\')
    print(a.next_batch())
--------------------------------------------------------------------------------
/tensorcv/dataflow/dataset/CIFAR.py:
--------------------------------------------------------------------------------
1 | # File: CIFAR.py
2 | # Author: Qian Ge
3 |
4 | import os
5 | import pickle
6 |
7 | import numpy as np
8 |
9 | from ..base import RNGDataFlow
10 |
11 | __all__ = ['CIFAR']
12 |
## TODO Add batch size
class CIFAR(RNGDataFlow):
    """CIFAR-10 training-set dataflow over the pickled 'data_batch_*' files.

    One batch file is held in memory at a time; next_batch() slices
    images from it and advances to the next file when exhausted.
    """
    def __init__(self, data_dir='', shuffle=True, normalize=None):
        # data_dir must contain data_batch_1..5; normalize='tanh' maps
        # pixel values to roughly [-1, 1].
        self.num_channels = 3
        self.im_size = [32, 32]

        assert os.path.isdir(data_dir)
        self.data_dir = data_dir

        self.shuffle = shuffle
        self._normalize = normalize

        # Initializes epoch counter, batch size (default 1) and the RNG.
        self.setup(epoch_val=0, batch_size=1)
        self._file_list = [os.path.join(data_dir, 'data_batch_' + str(batch_id)) for batch_id in range(1,6)]

        # self._load_files()
        self._num_image = self.size()

        self._image_id = 0
        self._batch_file_id = -1
        self._image = []
        self._next_batch_file()

    def _next_batch_file(self):
        # Advance to the next pickle file; wrapping counts one epoch.
        if self._batch_file_id >= len(self._file_list) - 1:
            self._batch_file_id = 0
            self._epochs_completed += 1
        else:
            self._batch_file_id += 1
        self._image = np.array(unpickle(self._file_list[self._batch_file_id]))
        # TODO to be modified
        if self._normalize == 'tanh':
            self._image = (self._image*1. - 128)/128.0

        if self.shuffle:
            self._suffle_files()

    def _suffle_files(self):
        # In-place shuffle of the images of the currently loaded file.
        idxs = np.arange(len(self._image))

        self.rng.shuffle(idxs)
        self._image = self._image[idxs]

    def size(self):
        """Total number of images across all batch files (computed once, cached)."""
        try:
            return self.data_size
        except AttributeError:
            data_size = 0
            for k in range(len(self._file_list)):
                tmp_image = unpickle(self._file_list[k])
                data_size += len(tmp_image)
            self.data_size = data_size
            return self.data_size

    def next_batch(self):
        """Return [images] for the next batch.

        NOTE(review): batches never span two files; when fewer than
        batch_size images would remain in the current file, the tail is
        skipped and the next file is loaded -- confirm intended.
        """
        # TODO assume batch_size smaller than images in one file
        assert self._batch_size <= self.size(), \
            "batch_size {} cannot be larger than data size {}".\
            format(self._batch_size, self.size())

        start = self._image_id
        self._image_id += self._batch_size
        end = self._image_id
        batch_files = np.array(self._image[start:end])

        if self._image_id + self._batch_size > len(self._image):
            self._next_batch_file()
            self._image_id = 0
            if self.shuffle:
                self._suffle_files()

        return [batch_files]
85 |
86 |
def unpickle(file):
    """Load a CIFAR-10 pickle batch file and return images as (N, 32, 32, 3)."""
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    flat = batch[b'data']

    # Rows are 3072 values each: 1024 red, 1024 green, 1024 blue.
    plane = 32 * 32
    channels = [flat[:, c * plane:(c + 1) * plane].reshape(-1, 32, 32)
                for c in range(3)]
    return np.stack(channels, axis=-1)
98 |
if __name__ == '__main__':
    # Manual smoke test against a local copy of the CIFAR-10 python batches.
    a = CIFAR('D:\\Qian\\GitHub\\workspace\\tensorflow-DCGAN\\cifar-10-python.tar\\')
    t = a.next_batch()[0]
    print(t)
    print(t.shape)
    print(a.size())
    # print(a.next_batch()[0])
    # print(a.next_batch()[0])
--------------------------------------------------------------------------------
/tensorcv/dataflow/dataset/MNIST.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: UTF-8 -*-
3 | # File: MNIST.py
4 | # Author: Qian Ge
5 |
6 | import os
7 |
8 | import numpy as np
9 | from tensorflow.examples.tutorials.mnist import input_data
10 |
11 | from ..base import RNGDataFlow
12 |
13 | __all__ = ['MNIST', 'MNISTLabel']
14 |
def get_mnist_im_label(name, mnist_data):
    """Return (images, labels) of the requested split of a TF MNIST dataset.

    'train' and 'val' select the corresponding splits; anything else
    falls through to the test split.
    """
    if name == 'train':
        split = mnist_data.train
    elif name == 'val':
        split = mnist_data.validation
    else:
        split = mnist_data.test
    return split.images, split.labels
22 |
# TODO read data without tensorflow
class MNIST(RNGDataFlow):
    """MNIST dataflow backed by tensorflow.examples.tutorials input_data.

    The requested split is loaded entirely into memory as [28, 28, 1]
    images; next_batch() returns [images].
    """
    def __init__(self, name, data_dir='', shuffle=True, normalize=None):
        # name selects the split; normalize='tanh' maps images from
        # [0, 1] to [-1, 1].
        self.num_channels = 1
        self.im_size = [28, 28]

        assert os.path.isdir(data_dir)
        self.data_dir = data_dir

        self.shuffle = shuffle
        self._normalize = normalize

        assert name in ['train', 'test', 'val']
        # Initializes epoch counter, batch size (default 1) and the RNG.
        self.setup(epoch_val=0, batch_size=1)

        self._load_files(name)
        self._num_image = self.size()
        self._image_id = 0

    def _load_files(self, name):
        # Downloads (if needed) and reads the whole split into memory.
        mnist_data = input_data.read_data_sets(self.data_dir, one_hot=False)
        self.im_list = []
        self.label_list = []

        mnist_images, mnist_labels = get_mnist_im_label(name, mnist_data)
        for image, label in zip(mnist_images, mnist_labels):
            # TODO to be modified
            if self._normalize == 'tanh':
                image = image*2.-1.
            image = np.reshape(image, [28, 28, 1])
            self.im_list.append(image)
            self.label_list.append(label)
        self.im_list = np.array(self.im_list)
        self.label_list = np.array(self.label_list)

        if self.shuffle:
            self._suffle_files()

    def _suffle_files(self):
        # Shuffle images and labels with the same permutation.
        idxs = np.arange(self.size())

        self.rng.shuffle(idxs)
        self.im_list = self.im_list[idxs]
        self.label_list = self.label_list[idxs]

    def size(self):
        """Number of images in the loaded split."""
        return self.im_list.shape[0]

    def next_batch(self):
        """Return [images] for the next batch; reshuffles at epoch end."""
        assert self._batch_size <= self.size(), \
            "batch_size {} cannot be larger than data size {}".\
            format(self._batch_size, self.size())
        start = self._image_id
        self._image_id += self._batch_size
        end = self._image_id
        batch_files = self.im_list[start:end]

        # Epoch ends when another full batch would no longer fit.
        if self._image_id + self._batch_size > self._num_image:
            self._epochs_completed += 1
            self._image_id = 0
            if self.shuffle:
                self._suffle_files()
        return [batch_files]
90 |
class MNISTLabel(MNIST):
    """MNIST dataflow that yields labels alongside images."""

    def next_batch(self):
        """Return [images, labels] for the next batch.

        Epoch wraps when the following batch would overrun the data;
        leftover tail samples are skipped for that epoch.
        """
        assert self._batch_size <= self.size(), \
            "batch_size {} cannot be larger than data size {}".\
            format(self._batch_size, self.size())
        start = self._image_id
        end = start + self._batch_size
        self._image_id = end
        batch_im = self.im_list[start:end]
        batch_label = self.label_list[start:end]

        if end + self._batch_size > self._num_image:
            self._epochs_completed += 1
            self._image_id = 0
            if self.shuffle:
                self._suffle_files()
        return [batch_im, batch_label]
109 |
110 |
if __name__ == '__main__':
    # Ad-hoc smoke test with a hard-coded local Windows path; not part of
    # the library API.
    a = MNISTLabel('val','D:\\Qian\\GitHub\\workspace\\tensorflow-DCGAN\\MNIST_data\\')
    t = a.next_batch()
    print(t)
115 |
--------------------------------------------------------------------------------
/tensorcv/dataflow/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .BSDS500 import *
2 | from .CIFAR import *
3 | from .MNIST import *
4 |
--------------------------------------------------------------------------------
/tensorcv/dataflow/matlab.py:
--------------------------------------------------------------------------------
1 | # File: matlab.py
2 | # Author: Qian Ge
3 |
4 | import os
5 | from scipy.io import loadmat
6 |
7 | import numpy as np
8 |
9 | from .base import RNGDataFlow
10 | from .common import *
11 |
12 | __all__ = ['MatlabData']
13 |
class MatlabData(RNGDataFlow):
    """ dataflow from .mat file with mask

    Reads every ``*.mat`` file under ``data_dir``; each file holds one
    sample whose arrays are named in ``mat_name_list``. Batches are loaded
    lazily from disk in ``next_batch``.
    """
    def __init__(self,
                 data_dir='',
                 mat_name_list=None,
                 mat_type_list=None,
                 shuffle=True,
                 normalize=None):
        # Args:
        #     data_dir (str): directory to scan for .mat files.
        #     mat_name_list (str or list[str]): variable names to read from
        #         each .mat file; the first entry is treated as the image.
        #     mat_type_list (str or list[str]): numpy dtype per variable;
        #         defaults to 'float' for every entry.
        #     shuffle (bool): shuffle file order at start / after each epoch.
        #     normalize (str): 'tanh' rescales the image data to [-1, 1].

        self.setup(epoch_val=0, batch_size=1)

        self.shuffle = shuffle
        self._normalize = normalize

        assert os.path.isdir(data_dir)
        self.data_dir = data_dir

        assert mat_name_list is not None, 'mat_name_list cannot be None'
        if not isinstance(mat_name_list, list):
            mat_name_list = [mat_name_list]
        self._mat_name_list = mat_name_list
        if mat_type_list is None:
            mat_type_list = ['float']*len(self._mat_name_list)
        assert len(self._mat_name_list) == len(mat_type_list),\
            'Length of mat_name_list and mat_type_list has to be the same!'
        self._mat_type_list = mat_type_list

        self._load_file_list()
        self._get_im_size()
        self._num_image = self.size()
        self._image_id = 0  # cursor into file_list for batching

    def _get_im_size(self):
        # Run after _load_file_list
        # Assume all the image have the same size
        mat = loadmat(self.file_list[0])
        cur_mat = load_image_from_mat(mat, self._mat_name_list[0],
                                      self._mat_type_list[0])
        if len(cur_mat.shape) < 3:
            self.num_channels = 1
        else:
            self.num_channels = cur_mat.shape[2]
        self.im_size = [cur_mat.shape[0], cur_mat.shape[1]]

    def _load_file_list(self):
        # Collect all .mat files directly under data_dir (non-recursive).
        # data_dir = os.path.join(self.data_dir)
        self.file_list = np.array([os.path.join(self.data_dir, file)
            for file in os.listdir(self.data_dir) if file.endswith(".mat")])

        if self.shuffle:
            # _suffle_file_list is inherited from the base class -- confirm.
            self._suffle_file_list()

    def next_batch(self):
        # Returns one list entry per name in mat_name_list: the first has
        # shape [batch, h, w, channels], the rest [batch, h, w].
        assert self._batch_size <= self.size(), \
            "batch_size cannot be larger than data size"

        start = self._image_id
        self._image_id += self._batch_size
        end = self._image_id
        batch_file_path = self.file_list[start:end]

        # Wrap the epoch when the next batch would run past the end;
        # leftover tail files (< batch_size) are skipped that epoch.
        if self._image_id + self._batch_size > self._num_image:
            self._epochs_completed += 1
            self._image_id = 0
            if self.shuffle:
                self._suffle_file_list()
        return self._load_data(batch_file_path)

    def _load_data(self, batch_file_path):
        # TODO deal with num_channels
        # Load each file and gather per-variable batches.
        input_data = [[] for i in range(len(self._mat_name_list))]

        for file_path in batch_file_path:
            mat = loadmat(file_path)
            cur_data = load_image_from_mat(mat, self._mat_name_list[0],
                                           self._mat_type_list[0])
            cur_data = np.reshape(cur_data,
                [1, cur_data.shape[0], cur_data.shape[1], self.num_channels])
            input_data[0].extend(cur_data)

            for k in range(1, len(self._mat_name_list)):
                cur_data = load_image_from_mat(mat,
                    self._mat_name_list[k], self._mat_type_list[k])
                cur_data = np.reshape(cur_data,
                    [1, cur_data.shape[0], cur_data.shape[1]])
                input_data[k].extend(cur_data)
        input_data = [np.array(data) for data in input_data]

        if self._normalize == 'tanh':
            # _half_in_val is computed lazily from the first image seen
            try:
                input_data[0] = tanh_normalization(input_data[0], self._half_in_val)
            except AttributeError:
                self._input_val_range(input_data[0][0])
                input_data[0] = tanh_normalization(input_data[0], self._half_in_val)

        return input_data

    def _input_val_range(self, in_mat):
        # TODO to be modified
        # Cache input-range statistics used by tanh normalization.
        self._max_in_val, self._half_in_val = input_val_range(in_mat)

    def size(self):
        # number of .mat files found
        return len(self.file_list)
117 |
def load_image_from_mat(matfile, name, datatype):
    """Fetch array ``name`` from a loaded .mat dict, cast to ``datatype``."""
    return matfile[name].astype(datatype)
121 |
if __name__ == '__main__':
    # Ad-hoc smoke test with a hard-coded local Windows path -- only runs
    # on the author's machine; not part of the library API.
    a = MatlabData(data_dir='D:\\GoogleDrive_Qian\\Foram\\Training\\CNN_GAN_ORIGINAL_64\\',
        mat_name_list=['level1Edge'],
        normalize='tanh')
    print(a.next_batch()[0].shape)
    print(a.next_batch()[0][:,30:40,30:40,:])
    print(np.amax(a.next_batch()[0]))
--------------------------------------------------------------------------------
/tensorcv/dataflow/normalization.py:
--------------------------------------------------------------------------------
1 | # File: normalization.py
2 | # Author: Qian Ge
3 |
4 | import numpy as np
5 |
6 |
def identity(input_val, *args):
    """Pass-through normalization: return the input unchanged.

    Extra positional arguments are accepted (and ignored) so this matches
    the signature of the other normalization functions in this module.
    """
    return input_val
9 |
def normalize_tanh(input_val, max_in, half_in):
    """Rescale values from [0, 2*half_in] to [-1, 1].

    ``max_in`` is unused; it is kept so every normalization function in
    this module shares the same signature.
    """
    shifted = input_val * 1.0 - half_in
    return shifted / half_in
12 |
def normalize_one(input_val, max_in, half_in):
    """Rescale values from [0, max_in] to [0, 1] (``half_in`` unused)."""
    scaled = input_val * 1.0
    return scaled / max_in
15 |
--------------------------------------------------------------------------------
/tensorcv/dataflow/operation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: operation.py
4 | # Author: Qian Ge
5 |
6 | import numpy as np
7 | import copy
8 |
9 | from .base import DataFlow
10 | from ..utils.utils import assert_type
11 |
12 |
def display_dataflow(dataflow, data_name='data', simple=False):
    """Print a short summary of a DataFlow.

    When the dataflow exposes labels, prints sample and class counts and
    (unless ``simple``) the per-class element counts; otherwise falls back
    to the sample count only (via the AttributeError path).
    """
    assert_type(dataflow, DataFlow)

    n_sample = dataflow.size()
    try:
        labels = dataflow.get_label_list()
        print('[{}] num of samples {}, num of classes {}'.
              format(data_name, n_sample, len(set(labels))))
        if not simple:
            counts = {}
            for lab in labels:
                counts[lab] = counts.get(lab, 0) + 1
            for lab in counts:
                print('class {}: {}'.format(
                    dataflow.label_dict_reverse[lab],
                    counts[lab]))
    except AttributeError:
        # label API not available on this dataflow
        print('[{}] num of samples {}'.
              format(data_name, n_sample))
36 |
37 |
def k_fold_based_class(dataflow, k, shuffle=True):
    """Partition dataflow into k equal sized subsamples based on class labels

    Stratified split: each class is divided evenly across the k folds; the
    last fold absorbs the remainder of each class.

    Args:
        dataflows (DataFlow): DataFlow to be partitioned. Must contain labels.
        k (int): number of subsamples
        shuffle (bool): data will be shuffled before and after partition
            if is true

    Return:
        DataFlow: list of k subsample Dataflow
    """
    assert_type(dataflow, DataFlow)
    k = int(k)
    assert k > 0, 'k must be an integer grater than 0!'
    dataflow_data_list = dataflow.get_data_list()
    if not isinstance(dataflow_data_list, list):
        dataflow_data_list = [dataflow_data_list]

    label_list = dataflow.get_label_list()
    # im_list = dataflow.get_data_list()

    if shuffle:
        dataflow.suffle_data()

    # Group samples by class: class_dict[label][data_type] -> list of samples
    class_dict = {}
    for idx, cur_label in enumerate(label_list):
        try:
            for data_idx, data in enumerate(dataflow_data_list):
                class_dict[cur_label][data_idx] += [data[idx], ]
        except KeyError:
            # first sample of this class: create one list per data type
            class_dict[cur_label] = [[] for i in range(0, len(dataflow_data_list))]
            for data_idx, data in enumerate(dataflow_data_list):
                class_dict[cur_label][data_idx] = [data[idx], ]
            # class_dict[cur_label] = [cur_im, ]

    # fold_data_list[fold][data_type] accumulates each fold's samples
    # fold_im_list = [[] for i in range(0, k)]
    fold_data_list = [[[] for j in range(0, len(dataflow_data_list))] for i in range(0, k)]
    # fold_label_list = [[] for i in range(0, k)]
    for label_key in class_dict:
        cur_data_list = class_dict[label_key]
        # nelem samples of this class go to each of the first k-1 folds
        nelem = int(len(cur_data_list[0]) / k)
        start_id = 0
        for fold_id in range(0, k-1):
            for data_idx, data_list in enumerate(cur_data_list):
                fold_data_list[fold_id][data_idx].extend(data_list[start_id : start_id + nelem])
            start_id += nelem
        # the last fold takes the remainder
        for data_idx, data_list in enumerate(cur_data_list):
            fold_data_list[k - 1][data_idx].extend(data_list[start_id :])

    # Clone the source dataflow k times and install each fold's data.
    data_folds = [copy.deepcopy(dataflow) for i in range(0, k)]

    for cur_fold, cur_data_list in zip(data_folds, fold_data_list):
        cur_fold.set_data_list(cur_data_list)

        if shuffle:
            cur_fold.suffle_data()

    return data_folds
97 |
98 |
def combine_dataflow(dataflows, shuffle=True):
    """Combine several dataflow into one.

    Args:
        dataflows (DataFlow list): list of DataFlow to be combined
        shuffle (bool): data will be shuffled after combined if is true

    Return:
        DataFlow: the first input dataflow, now holding the merged data
    """
    if not isinstance(dataflows, list):
        dataflows = [dataflows]

    # Normalize every dataflow's data into a list-of-lists (one per type).
    all_lists = []
    for df in dataflows:
        assert_type(df, DataFlow)
        cur = df.get_data_list()
        all_lists.append(cur if isinstance(cur, list) else [cur])

    # Merge type-by-type; the first dataflow defines how many types exist.
    n_type = len(all_lists[0])
    merged = [[] for _ in range(n_type)]
    for cur in all_lists:
        for i in range(n_type):
            merged[i].extend(cur[i])

    dataflows[0].set_data_list(merged)
    if shuffle:
        dataflows[0].suffle_data()

    return dataflows[0]
131 |
--------------------------------------------------------------------------------
/tensorcv/dataflow/preprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: preprocess.py
4 | # Author: Qian Ge
5 |
6 | import numpy as np
7 | from scipy import misc
8 |
9 |
def image_fliplr(image):
    """Mirror an image left-to-right.

    Args:
        image (np.array): a 2-D image of shape [height, width] or a 3-D
            image of shape [height, width, channels].

    Returns:
        mirror version of original image (column axis reversed).
    """
    return image[:, ::-1]
22 |
23 |
def resize_image_with_smallest_side(image, small_size):
    """
    Resize single image array with smallest side = small_size and
    keep the original aspect ratio.

    Args:
        image (np.array): a 2-D image of shape
            [height, width] or a 3-D image of shape
            [height, width, channels].
        small_size (int): A 1-D int. The smallest side of resize image.

    Returns:
        rescaled image

    NOTE(review): relies on scipy.misc.imresize, which was removed in
    SciPy 1.3 -- confirm the pinned SciPy version or port to PIL/skimage.
    """
    im_shape = image.shape
    shape_dim = len(im_shape)
    assert shape_dim <= 3 and shape_dim >= 2,\
        'Wrong format of image!Shape is {}'.format(im_shape)

    height = float(im_shape[0])
    width = float(im_shape[1])

    # scale so the smaller side becomes small_size, keeping aspect ratio
    if height <= width:
        new_height = int(small_size)
        new_width = int(width * new_height/height)
    else:
        new_width = int(small_size)
        new_height = int(height * new_width/width)

    if shape_dim == 2:
        im = misc.imresize(image, (new_height, new_width))
    elif shape_dim == 3:
        im = misc.imresize(image, (new_height, new_width, image.shape[2]))
    return im
58 |
59 |
def random_crop_to_size(image, crop_size):
    """Randomly crop an image to ``crop_size``.

    Args:
        image (np.array): a 2-D [height, width] or 3-D
            [height, width, channels] array, at least as large as the crop.
        crop_size (int or length 2 list): The image size after cropped.

    Returns:
        cropped image
    """
    crop_size = get_shape2D(crop_size)
    im_shape = image.shape
    shape_dim = len(im_shape)
    assert shape_dim <= 3 and shape_dim >= 2, 'Wrong format of image!'

    height = im_shape[0]
    width = im_shape[1]
    assert height >= crop_size[0] and width >= crop_size[1],\
        'Image must be larger than crop size! {}'.format(im_shape)

    # top-left corner drawn uniformly over all valid positions
    top = int(np.floor((height - crop_size[0] + 1) * np.random.rand()))
    left = int(np.floor((width - crop_size[1] + 1) * np.random.rand()))

    return image[top:top + crop_size[0], left:left + crop_size[1]]
87 |
88 |
def four_connor_crop(image, crop_size):
    """Crop an image at its four corners.

    Args:
        image (np.array): a 2-D [height, width] or 3-D
            [height, width, channels] array, at least as large as the crop.
        crop_size (int or length 2 list): The image size after cropped.

    Returns:
        list of four crops in order: top-left, top-right,
        bottom-left, bottom-right.
    """
    crop_size = get_shape2D(crop_size)
    im_shape = image.shape
    shape_dim = len(im_shape)
    assert shape_dim <= 3 and shape_dim >= 2, 'Wrong format of image!'
    height = im_shape[0]
    width = im_shape[1]
    assert height >= crop_size[0] and width >= crop_size[1],\
        'Image must be larger than crop size! {}'.format(im_shape)

    c_h, c_w = crop_size[0], crop_size[1]
    row_spans = [(0, c_h), (height - c_h, height)]
    col_spans = [(0, c_w), (width - c_w, width)]
    return [image[r0:r1, c0:c1]
            for (r0, r1) in row_spans for (c0, c1) in col_spans]
118 |
119 |
def center_crop(image, crop_size):
    """Crop the central ``crop_size`` region of an image.

    Args:
        image (np.array): a 2-D [height, width] or 3-D
            [height, width, channels] array, at least as large as the crop.
        crop_size (int or length 2 list): The image size after cropped.

    Returns:
        cropped image
    """
    crop_size = get_shape2D(crop_size)
    im_shape = image.shape
    shape_dim = len(im_shape)
    assert shape_dim <= 3 and shape_dim >= 2, 'Wrong format of image!'
    height = im_shape[0]
    width = im_shape[1]
    assert height >= crop_size[0] and width >= crop_size[1],\
        'Image must be larger than crop size! {}'.format(im_shape)

    top = (height - crop_size[0]) // 2
    left = (width - crop_size[1]) // 2
    return image[top:top + crop_size[0], left:left + crop_size[1]]
144 |
145 |
def random_mirror_resize_crop(image, crop_size, scale_range, mirror_rate=0.5):
    """Randomly rescale, crop and possibly mirror an image.

    Args:
        image (np.array): a 2-D [height, width] or 3-D
            [height, width, channels] array.
        crop_size (int or length 2 list): The image size after cropped.
        scale_range (list of int with length 2): The range of scale.
        mirror_rate (float): probability threshold for mirroring;
            must lie within [0, 1].

    Returns:
        cropped and rescaled image
    """
    im_shape = image.shape
    shape_dim = len(im_shape)

    assert mirror_rate >= 0 and mirror_rate <= 1,\
        'mirror rate must be in range of [0, 1]!'
    assert shape_dim <= 3 and shape_dim >= 2,\
        'Wrong format of image!Shape is {}'.format(im_shape)

    # draw the target smaller-side size uniformly from scale_range
    lo = min(scale_range)
    hi = max(scale_range)
    small_size = int(np.random.rand() * (hi - lo) + lo)
    image = resize_image_with_smallest_side(image, small_size)

    image = random_crop_to_size(image, crop_size)

    # note: mirrors when the draw is >= mirror_rate, i.e. with
    # probability (1 - mirror_rate)
    if np.random.rand() >= mirror_rate:
        image = image_fliplr(image)

    return image
179 |
180 |
def get_shape2D(in_val):
    """
    Return a 2D shape.

    Args:
        in_val (int, or list/tuple with length 2): a single side length,
            or an explicit [height, width] pair. Tuples are accepted as a
            backward-compatible generalization of the list form.

    Returns:
        list with length 2, or None when ``in_val`` is None

    Raises:
        RuntimeError: if ``in_val`` is of any other type.
    """
    if in_val is None:
        return None
    if isinstance(in_val, int):
        return [in_val, in_val]
    if isinstance(in_val, list):
        assert len(in_val) == 2
        return in_val
    if isinstance(in_val, tuple):
        # accept tuples too; normalize to a list like the other branches
        assert len(in_val) == 2
        return list(in_val)
    raise RuntimeError('Illegal shape: {}'.format(in_val))
199 |
--------------------------------------------------------------------------------
/tensorcv/dataflow/randoms.py:
--------------------------------------------------------------------------------
1 | # File: randoms.py
2 | # Author: Qian Ge
3 |
4 | import numpy as np
5 |
6 | from .base import DataFlow
7 | from ..utils.utils import get_rng
8 |
9 | __all__ = ['RandomVec']
10 |
class RandomVec(DataFlow):
    """Dataflow that generates random normal vectors on demand."""

    def __init__(self, len_vec=100):
        # len_vec (int): dimensionality of each generated vector
        self.setup(epoch_val=0, batch_size=1)
        self._len_vec = len_vec

    def next_batch(self):
        """Draw one batch of N(0, 1) vectors, shape [batch_size, len_vec]."""
        self._epochs_completed += 1
        batch = np.random.normal(size=(self._batch_size, self._len_vec))
        return [batch]

    def size(self):
        # one batch is generated per call, so "size" equals the batch size
        return self._batch_size

    def reset_state(self):
        self._reset_state()

    def _reset_state(self):
        self.rng = get_rng(self)
31 |
if __name__ == '__main__':
    # Quick smoke test: print two random batches.
    vec = RandomVec()
    print(vec.next_batch())
    print(vec.next_batch())
--------------------------------------------------------------------------------
/tensorcv/dataflow/sequence.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: sequence.py
4 | # Author: Qian Ge
5 |
6 | import numpy as np
7 | import collections
8 |
9 | from .base import DataFlow
10 | from .normalization import identity
11 | from ..utils.utils import assert_type
12 |
class SeqDataflow(DataFlow):
    """ base class for sequence data

    Serves fixed-length windows of one long sequence. Subclasses implement
    load_entire_seq / get_entire_seq / load_data.
    """
    def __init__(self, data_dir='',
                 load_ratio=1,
                 predict_step=0,
                 batch_dict_name=None,
                 normalize_fnc=identity):
        # Args:
        #     data_dir (str): location of the raw sequence data.
        #     load_ratio: fraction of the sequence to load (used by subclasses).
        #     predict_step (int): offset between an input window and its target.
        #     batch_dict_name (str or list[str]): names for the returned batches.
        #     normalize_fnc (callable): normalization hook used by subclasses.
        self._pred_step = predict_step
        self._data_dir = data_dir
        self._normalize_fnc = normalize_fnc
        self._load_ratio = load_ratio

        if not isinstance(batch_dict_name, list):
            batch_dict_name = [batch_dict_name]
        self._batch_dict_name = batch_dict_name

        self.load_entire_seq()

        self._data_id = 0  # cursor into the sequence

        self.setup(epoch_val=0, batch_size=1)
        self.setup_seq_para(num_step=10, stride=1)
        self._updata_batch_partition_len()

    def _updata_batch_partition_len(self):
        # [sic] "updata" -- length of each of the batch_size contiguous
        # partitions the sequence is split into. Guarded because size()
        # may not be available before subclass loading finishes.
        try:
            self._batch_partition_len = self.size() // self._batch_size
        except AttributeError:
            pass

    def set_batch_size(self, batch_size):
        # keep the partition length in sync with the new batch size
        self._batch_size = batch_size
        self._updata_batch_partition_len()

    def size(self):
        # length of the full underlying sequence
        return len(self.get_entire_seq())

    def setup_seq_para(self, num_step, stride):
        # num_step: window length; stride: step between windows
        self._num_step = num_step
        self._stride = stride

    def next_batch(self):
        # Row i of the batch reads a num_step window starting at
        # data_id + batch_partition_len * i, i.e. the sequence is split
        # into batch_size contiguous partitions read in parallel.
        b_size = self._batch_size
        bp_len = self._batch_partition_len
        assert b_size * self._num_step <= self.size()
        if self._data_id + bp_len * (b_size - 1) + self._num_step + self._pred_step > self.size():
            self._epochs_completed += 1
            # self._data_id = 0
            # self._data_id = self._epochs_completed
            # restart with an epoch-dependent offset so successive epochs
            # see shifted windows
            self._data_id = self._epochs_completed % self._num_step
            # self._data_id = self._epochs_completed % (bp_len - self._num_step - self._pred_step)
        start_id = self._data_id

        batch_data = []
        for i in range(b_size):
            start_id = self._data_id + bp_len * i
            end_id = start_id + self._num_step
            cur_data = self.load_data(start_id, end_id)
            batch_data.append(cur_data)

        self._data_id += self._num_step
        # self._data_id +=
        # print(np.array(batch_data).shape)
        return self._batch_transform(batch_data)

    def _batch_transform(self, batch_data):
        # hook for subclasses to reshape/transpose a finished batch
        return batch_data

        # if len(np.array(batch_data).shape) == 3:
        #     return np.array(batch_data).transpose(1, 0, 2)
        # else:
        #     return np.array(batch_data).transpose(1, 0, 2, 3)

    # def next_batch(self):
    #     assert self.size() > self._batch_size * self._stride + self._num_step - self._stride
    #     batch_data = []
    #     batch_id = 0
    #     start_id = self._data_id
    #     while batch_id < self._batch_size:
    #         end_id = start_id + self._num_step
    #         if end_id + 1 > self.size():
    #             start_id = 0
    #             end_id = start_id + self._num_step
    #             self._epochs_completed += 1
    #         cur_data = self.load_data(start_id, end_id)
    #         batch_data.append(cur_data)
    #         start_id = start_id + self._stride
    #         batch_id += 1
    #     return np.array(batch_data).transpose(1, 0, 2)

    def load_data(self, start_id, end_id):
        # subclass hook: return the window [start_id, end_id)
        pass
        # return self.get_entire_seq()[start_id: end_id]

    def load_entire_seq(self):
        # subclass hook: read the raw sequence into memory
        pass

    def get_entire_seq(self):
        # subclass hook: return the loaded sequence
        pass
114 |
115 |
class SepWord(SeqDataflow):
    """Sequence dataflow over word tokens, with an id lookup table."""

    def __init__(self, data_dir='',
                 predict_step=1,
                 word_dict=None,
                 batch_dict_name=None,
                 normalize_fnc=identity):
        # word_dict: optional pre-built word -> id mapping
        self.word_dict = word_dict
        super(SepWord, self).__init__(data_dir=data_dir,
                                      predict_step=predict_step,
                                      batch_dict_name=batch_dict_name,
                                      normalize_fnc=normalize_fnc)

    def gen_word_dict(self, word_data):
        """Build word -> id mapping ordered by descending frequency,
        ties broken alphabetically (id 0 = most frequent word)."""
        freq = collections.Counter(word_data)
        ordered = sorted(freq.items(), key=lambda kv: (-kv[1], kv[0]))
        self.word_dict = {pair[0]: idx for idx, pair in enumerate(ordered)}
133 |
134 |
class SeqNumber(SeqDataflow):
    """Sequence dataflow over numeric data with pluggable scaling."""

    def _scale(self, data):
        """Apply the configured normalization; cache its scale_dict if any."""
        scaled = self._normalize_fnc(data)
        try:
            self.scale_dict = scaled['scale_dict']
        except KeyError:
            # normalization provided no scale info; nothing to cache
            pass
        return scaled['data']

        # max_data = np.amax(data)
        # min_data = np.amin(data)
        # return (data - min_data) / (max_data - min_data)

    def load_data(self, start_id, end_id):
        """Return [features, targets]; targets are shifted by predict_step."""
        features = self.get_entire_seq()[start_id: end_id]
        shift = self._pred_step
        targets = self.get_label_seq()[start_id + shift: end_id + shift]
        return [features, targets]

    def load_entire_seq(self):
        # subclass hook
        pass

    def get_entire_seq(self):
        # subclass hook
        pass

    def get_label_seq(self):
        # subclass hook
        pass
161 |
--------------------------------------------------------------------------------
/tensorcv/dataflow/viz.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: viz.py
4 | # Author: Qian Ge
5 |
6 | import matplotlib.pyplot as plt
7 |
8 | from .sequence import SeqDataflow
9 | from ..utils.utils import assert_type
10 |
def plot_seq(dataflow, data_range=None):
    """Plot the full sequence of a SeqDataflow with matplotlib.

    Args:
        dataflow (SeqDataflow): source of the sequence to plot.
        data_range: optional slice/index applied to the sequence first.
    """
    assert_type(dataflow, SeqDataflow)
    seq = dataflow.get_entire_seq()
    to_plot = seq if data_range is None else seq[data_range]
    plt.plot(to_plot)
    plt.show()
19 |
--------------------------------------------------------------------------------
/tensorcv/models/__init__.py:
--------------------------------------------------------------------------------
1 | # File: __init__.py
2 | # Author: Qian Ge
3 |
4 |
--------------------------------------------------------------------------------
/tensorcv/models/bk/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
def conv(x, filter_height, filter_width, num_filters,
         name, stride_x = 1, stride_y = 1,
         padding = 'SAME', relu = True):
    """2-D convolution layer (TF1 variable_scope style).

    Creates weights/biases via new_normal_variable inside scope ``name``
    and optionally applies ReLU (named after the scope).
    """
    in_channels = int(x.shape[-1])
    with tf.variable_scope(name) as scope:
        weights = new_normal_variable(
            'weights',
            shape = [filter_height, filter_width, in_channels, num_filters])
        biases = new_normal_variable('biases', shape = [num_filters])

        out = tf.nn.conv2d(x, weights,
                           strides=[1, stride_y, stride_x, 1],
                           padding = padding)
        out = tf.nn.bias_add(out, biases)

        # only the ReLU output carries the scope name, as before
        return tf.nn.relu(out, name = scope.name) if relu else out
26 |
def dconv(x, filter_height, filter_width,
          name, fuse_x = None,
          output_shape = [], output_channels = None,
          stride_x = 2, stride_y = 2, padding = 'SAME'):
    """Transposed (de-)convolution layer, optionally fused with ``fuse_x``.

    The output shape and channel count are taken from ``fuse_x`` when it is
    given, otherwise from ``output_shape``/``output_channels``. When fusing,
    the result is added element-wise to ``fuse_x``.
    """
    in_channels = int(x.shape[-1])

    # resolve target shape: fuse_x wins, then explicit output_channels,
    # then the last entry of output_shape
    if fuse_x is not None:
        output_shape = tf.shape(fuse_x)
        output_channels = int(fuse_x.shape[-1])
    elif output_channels is None:
        output_channels = output_shape[-1]

    with tf.variable_scope(name) as scope:
        weights = new_normal_variable(
            'weights',
            shape = [filter_height, filter_width, output_channels, in_channels])
        biases = new_normal_variable('biases', shape = [output_channels])

        out = tf.nn.conv2d_transpose(x, weights,
                                     output_shape = output_shape,
                                     strides=[1, stride_y, stride_x, 1],
                                     padding = padding, name = scope.name)
        out = tf.nn.bias_add(out, biases)
        out = tf.reshape(out, output_shape)

        # element-wise skip connection when fusing
        return tf.add(out, fuse_x, name = 'fuse') if fuse_x is not None else out
58 |
def fc(x, num_in, num_out, name, relu = True):
    """Fully-connected layer (TF1 variable_scope style).

    NOTE(review): the ``num_in`` argument is ignored -- the input width is
    always re-derived from ``x``'s static shape; confirm callers expect this.
    """
    num_in = x.get_shape().as_list()[1]
    with tf.variable_scope(name) as scope:
        weights = new_normal_variable('weights', shape = [num_in, num_out])
        biases = new_normal_variable('biases', shape = [num_out])
        act = tf.nn.xw_plus_b(x, weights, biases, name = scope.name)

        return tf.nn.relu(act) if relu else act
74 |
def max_pool(x, name, filter_height = 2, filter_width = 2,
             stride_x = 2, stride_y = 2, padding = 'SAME'):
    """Max pooling over [filter_height, filter_width] windows."""
    window = [1, filter_height, filter_width, 1]
    strides = [1, stride_y, stride_x, 1]
    return tf.nn.max_pool(x, ksize = window, strides = strides,
                          padding = padding, name = name)
80 |
def dropout(x, keep_prob, is_training):
    """Dropout active only while training.

    ``keep_prob`` is the probability of keeping a unit, so the drop
    ``rate`` passed to tf.layers.dropout is its complement.
    """
    return tf.layers.dropout(x, rate = 1 - keep_prob, training = is_training)
85 |
def batch_norm(x, name, train = True):
    """Batch normalization (tf.contrib wrapper) under scope ``name``."""
    return tf.contrib.layers.batch_norm(
        x,
        decay = 0.9,
        updates_collections = None,
        epsilon = 1e-5,
        scale = False,
        is_training = train,
        scope = name)
94 |
def leaky_relu(x, leak = 0.2):
    """Leaky ReLU: element-wise max of x and leak * x."""
    return tf.maximum(leak * x, x)
97 |
def new_normal_variable(name, shape = None, trainable = True, stddev = 0.002):
    """Create (or re-fetch) a variable initialized from N(0, stddev)."""
    initializer = tf.random_normal_initializer(stddev = stddev)
    return tf.get_variable(name, shape = shape, trainable = trainable,
                           initializer = initializer)
101 |
102 |
103 |
104 |
105 |
--------------------------------------------------------------------------------
/tensorcv/models/losses.py:
--------------------------------------------------------------------------------
1 | # File: losses.py
2 | # Author: Qian Ge
3 |
4 | import tensorflow as tf
5 | import numpy as np
6 |
7 |
def GAN_discriminator_loss(d_real, d_fake, name='d_loss'):
    """Standard GAN discriminator loss.

    Sum of cross-entropy of real logits against 1 and fake logits
    against 0, returned as a tensor named '<name>/result'.
    """
    print('---- d_loss -----')
    with tf.name_scope(name):
        total = comp_loss_real(d_real) + comp_loss_fake(d_fake)
        return tf.identity(total, name='result')
14 |
def GAN_generator_loss(d_fake, name='g_loss'):
    """Standard GAN generator loss.

    Cross-entropy of the discriminator's fake logits against 1 (i.e. the
    generator wants fakes classified as real).
    """
    print('---- g_loss -----')
    with tf.name_scope(name):
        g_loss = comp_loss_real(d_fake)
        return tf.identity(g_loss, name='result')
19 |
def comp_loss_fake(discrim_output):
    """Mean sigmoid cross-entropy of logits against all-zero labels."""
    zeros = tf.zeros_like(discrim_output)
    xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=discrim_output,
                                                   labels=zeros)
    return tf.reduce_mean(xent)
24 |
def comp_loss_real(discrim_output):
    """Mean sigmoid cross-entropy of logits against all-one labels."""
    ones = tf.ones_like(discrim_output)
    xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=discrim_output,
                                                   labels=ones)
    return tf.reduce_mean(xent)
29 |
30 |
--------------------------------------------------------------------------------
/tensorcv/models/model_builder/base.py:
--------------------------------------------------------------------------------
1 | # File: base.py
2 | # Author: Qian Ge
3 |
4 | from abc import abstractmethod
5 |
6 | import tensorflow as tf
7 | import numpy as np
8 |
9 |
10 | __all__ = ['BaseBuilder']
11 |
class BaseBuilder(object):
    """Base class for incremental model builders.

    Holds the growing lists of input and output layers; subclasses
    implement the actual layer-adding logic.
    """

    def __init__(self):
        # per-instance layer lists (fresh for every builder)
        self.input = []
        self.output = []

    def Add(self, BaseLayer):
        """Append one layer to the model (no-op in the base class)."""
        pass
20 |
21 | # def set_batch_size(self, val):
22 | # self._batch_size = val
23 |
24 | # def get_batch_size(self):
25 | # return self._batch_size
26 |
27 | # def set_is_training(self, is_training = True):
28 | # self.is_training = is_training
29 |
30 | # def get_placeholder(self):
31 | # return self._get_placeholder()
32 |
33 | # def _get_placeholder(self):
34 | # return []
35 |
36 | # # TODO to be modified
37 | # def get_prediction_placeholder(self):
38 | # return self._get_prediction_placeholder()
39 |
40 | # def _get_prediction_placeholder(self):
41 | # return []
42 |
43 | # def get_graph_feed(self):
44 | # return self._get_graph_feed()
45 |
46 | # def _get_graph_feed(self):
47 | # return {}
48 |
49 | # def create_graph(self):
50 | # self._create_graph()
51 | # self._setup_graph()
52 | # # self._setup_summary()
53 |
54 | # @abstractmethod
55 | # def _create_graph(self):
56 | # raise NotImplementedError()
57 |
58 | # def _setup_graph(self):
59 | # pass
60 |
61 | # # TDDO move outside of class
62 | # # summary will be created before prediction
63 | # # which is unnecessary
64 | # def setup_summary(self):
65 | # self._setup_summary()
66 |
67 | # def _setup_summary(self):
68 | # pass
69 |
70 |
class BaseModel(ModelDes):
    """ Model with single loss and single optimizer.

    Optimizer, loss and gradients are created lazily on first access and
    cached as instance attributes (EAFP try/except AttributeError pattern).
    """

    def get_optimizer(self):
        """Return the cached optimizer, creating it on first access."""
        try:
            return self.optimizer
        except AttributeError:
            self.optimizer = self._get_optimizer()
            return self.optimizer

    @property
    def default_collection(self):
        # summary collection name shared by loss and gradient summaries
        return 'default'

    def _get_optimizer(self):
        """Subclasses return the tf optimizer to use."""
        raise NotImplementedError()

    def get_loss(self):
        """Return the cached loss, creating it (plus its scalar summary) once."""
        try:
            return self._loss
        except AttributeError:
            self._loss = self._get_loss()
            # use self._loss directly; the original re-entered get_loss()
            # here, an unnecessary recursive call
            tf.summary.scalar('loss_summary', self._loss,
                              collections=[self.default_collection])
            return self._loss

    def _get_loss(self):
        """Subclasses return the loss tensor."""
        raise NotImplementedError()

    def get_grads(self):
        """Return cached (gradient, variable) pairs, computing them once.

        Also attaches a histogram summary per gradient.
        """
        try:
            return self.grads
        except AttributeError:
            optimizer = self.get_optimizer()
            loss = self.get_loss()
            self.grads = optimizer.compute_gradients(loss)
            for grad, var in self.grads:
                # compute_gradients() yields grad=None for variables the
                # loss does not depend on; tf.summary.histogram would raise
                # on None, so skip those entries
                if grad is not None:
                    tf.summary.histogram('gradient/' + var.name, grad,
                                         collections=[self.default_collection])
            return self.grads
110 |
class GANBaseModel(ModelDes):
    """ Base model for GANs.

    Holds separate optimizers, losses, gradients and summary collections
    for the discriminator and the generator. All expensive objects are
    created lazily on first access and cached as instance attributes.
    """
    def __init__(self, input_vec_length, learning_rate):
        # input_vec_length: length of the random noise vector fed to the
        # generator (second dim of placeholder Z)
        # learning_rate: pair (discriminator_lr, generator_lr)
        self.input_vec_length = input_vec_length
        assert len(learning_rate) == 2
        self.dis_learning_rate, self.gen_learning_rate = learning_rate

    @property
    def g_collection(self):
        # summary collection name for generator summaries
        return 'default_g'

    @property
    def d_collection(self):
        # summary collection name for discriminator summaries
        return 'default_d'

    def get_random_vec_placeholder(self):
        """Return (creating once) the noise placeholder Z of shape [None, input_vec_length]."""
        try:
            return self.Z
        except AttributeError:
            self.Z = tf.placeholder(tf.float32, [None, self.input_vec_length])
            return self.Z

    def _get_prediction_placeholder(self):
        # at prediction time the only input is the random vector
        return self.get_random_vec_placeholder()

    def get_graph_feed(self):
        """Merge the default feed dict with a fresh random-noise feed."""
        default_feed = self._get_graph_feed()
        random_input_feed = self._get_random_input_feed()
        default_feed.update(random_input_feed)
        return default_feed

    def _get_random_input_feed(self):
        # draw a new standard-normal batch for Z on every call
        feed = {self.get_random_vec_placeholder():
                np.random.normal(size = (self.get_batch_size(),
                                         self.input_vec_length))}
        return feed

    def create_GAN(self, real_data, gen_data_name = 'gen_data'):
        """Build generator and discriminator graphs with shared variables.

        Args:
            real_data: tensor of real samples fed to the discriminator
            gen_data_name (str): name given to the inference-mode
                generator output (used to fetch it at prediction time)

        Returns:
            (gen_data, sample_gen_data, d_real, d_fake): training-mode
            generator output, inference-mode generator output, and
            discriminator logits on real and generated data.
        """
        with tf.variable_scope('generator') as scope:
            gen_data = self._generator()
            # reuse generator weights for the inference-mode sample path
            scope.reuse_variables()
            sample_gen_data = tf.identity(self._generator(train = False),
                                          name = gen_data_name)

        with tf.variable_scope('discriminator') as scope:
            d_real = self._discriminator(real_data)
            # same discriminator weights applied to generated data
            scope.reuse_variables()
            d_fake = self._discriminator(gen_data)

            with tf.name_scope('discriminator_out'):
                tf.summary.histogram('discrim_real',
                                     tf.nn.sigmoid(d_real),
                                     collections = [self.d_collection])
                tf.summary.histogram('discrim_gen',
                                     tf.nn.sigmoid(d_fake),
                                     collections = [self.d_collection])

        return gen_data, sample_gen_data, d_real, d_fake

    # def get_random_input_feed(self):
    #     return self._get_random_input_feed()

    # def _get_random_input_feed(self):
    #     return {}

    def get_discriminator_optimizer(self):
        """Return the cached discriminator optimizer, creating it once."""
        try:
            return self.d_optimizer
        except AttributeError:
            self.d_optimizer = self._get_discriminator_optimizer()
            return self.d_optimizer

    def get_generator_optimizer(self):
        """Return the cached generator optimizer, creating it once."""
        try:
            return self.g_optimizer
        except AttributeError:
            self.g_optimizer = self._get_generator_optimizer()
            return self.g_optimizer

    def _get_discriminator_optimizer(self):
        raise NotImplementedError()

    def _get_generator_optimizer(self):
        raise NotImplementedError()

    def get_discriminator_loss(self):
        """Return the cached discriminator loss (with scalar summary), creating it once."""
        try:
            return self.d_loss
        except AttributeError:
            self.d_loss = self._get_discriminator_loss()
            tf.summary.scalar('d_loss_summary', self.d_loss,
                              collections = [self.d_collection])
            return self.d_loss

    def get_generator_loss(self):
        """Return the cached generator loss (with scalar summary), creating it once."""
        try:
            return self.g_loss
        except AttributeError:
            self.g_loss = self._get_generator_loss()
            tf.summary.scalar('g_loss_summary', self.g_loss,
                              collections = [self.g_collection])
            return self.g_loss

    def _get_discriminator_loss(self):
        raise NotImplementedError()

    def _get_generator_loss(self):
        raise NotImplementedError()

    def get_discriminator_grads(self):
        """Return cached discriminator gradients, restricted to discriminator/ variables."""
        try:
            return self.d_grads
        except AttributeError:
            # only update variables created under the 'discriminator/' scope
            d_training_vars = [v for v in tf.trainable_variables()
                               if v.name.startswith('discriminator/')]
            optimizer = self.get_discriminator_optimizer()
            loss = self.get_discriminator_loss()
            self.d_grads = optimizer.compute_gradients(loss,
                                                       var_list = d_training_vars)

            # list comprehension used only for its summary side effects
            [tf.summary.histogram('d_gradient/' + var.name, grad,
                                  collections = [self.d_collection])
             for grad, var in self.d_grads]
            return self.d_grads

    def get_generator_grads(self):
        """Return cached generator gradients, restricted to generator/ variables."""
        try:
            return self.g_grads
        except AttributeError:
            # only update variables created under the 'generator/' scope
            g_training_vars = [v for v in tf.trainable_variables()
                               if v.name.startswith('generator/')]
            optimizer = self.get_generator_optimizer()
            loss = self.get_generator_loss()
            self.g_grads = optimizer.compute_gradients(loss,
                                                       var_list = g_training_vars)
            # list comprehension used only for its summary side effects
            [tf.summary.histogram('g_gradient/' + var.name, grad,
                                  collections = [self.g_collection])
             for grad, var in self.g_grads]
            return self.g_grads

    @staticmethod
    def comp_loss_fake(discrim_output):
        # cross-entropy of discriminator logits against all-zero (fake) labels
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits = discrim_output,
                labels = tf.zeros_like(discrim_output)))

    @staticmethod
    def comp_loss_real(discrim_output):
        # cross-entropy of discriminator logits against all-one (real) labels
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits = discrim_output,
                labels = tf.ones_like(discrim_output)))
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
--------------------------------------------------------------------------------
/tensorcv/models/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: utils.py
4 | # Author: Qian Ge
5 |
6 | import math
7 |
8 |
def deconv_size(input_height, input_width, stride=2):
    """Compute the feature size (height, width) after filtering with a stride.

    Mostly used for setting the output shape of a deconvolution.

    Args:
        input_height (int): height of input feature
        input_width (int): width of input feature
        stride (int): stride of the filter

    Returns:
        (int, int): height and width of the feature after filtering.
    """
    out_h = int(math.ceil(input_height / float(stride)))
    out_w = int(math.ceil(input_width / float(stride)))
    return out_h, out_w
--------------------------------------------------------------------------------
/tensorcv/predicts/__init__.py:
--------------------------------------------------------------------------------
1 | # File: __init__.py
2 | # Author: Qian Ge
3 |
4 | from .config import *
5 | from .predictions import *
6 |
--------------------------------------------------------------------------------
/tensorcv/predicts/base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: base.py
4 | # Author: Qian Ge
5 |
6 | import os
7 |
8 | import tensorflow as tf
9 |
10 | from .config import PridectConfig
11 | from ..utils.sesscreate import ReuseSessionCreator
12 | from ..utils.common import assert_type
13 | from ..callbacks.hooks import Prediction2Hook
14 |
15 | __all__ = ['Predictor']
16 |
17 |
class Predictor(object):
    """Base class for a predictor. Used to run all predictions.

    Attributes:
        config (PridectConfig): the config used for this predictor
        model (ModelDes): model description whose graph is built here
        input (DataFlow): dataflow providing prediction batches
        sess (tf.Session): plain session holding the restored weights
        hooked_sess (tf.train.MonitoredSession): session wrapping ``sess``
            with one Prediction2Hook per configured prediction
    """
    def __init__(self, config):
        """ Inits Predictor with config (PridectConfig).

        Will create session as well as monitored sessions for
        each predictions, and load pre-trained parameters.

        Args:
            config (PridectConfig): the config used for this predictor
        """
        assert_type(config, PridectConfig)
        self._config = config
        self._model = config.model

        self._input = config.dataflow
        self._result_dir = config.result_dir

        # TODO to be modified
        # graph must exist before hooks/savers reference its tensors
        self._model.set_is_training(False)
        self._model.create_graph()
        # NOTE(review): PridectConfig only assigns restore_vars when the
        # caller passed a non-None value — with the default config this
        # attribute access raises AttributeError; confirm config always
        # defines restore_vars.
        self._restore_vars = self._config.restore_vars

        # pass saving directory to predictions
        for pred in self._config.predictions:
            pred.setup(self._result_dir)

        hooks = [Prediction2Hook(pred) for pred in self._config.predictions]

        # hooked_sess reuses the same underlying tf.Session so the weights
        # restored below are visible to the hooks
        self.sess = self._config.session_creator.create_session()
        self.hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess), hooks=hooks)

        # load pre-trained parameters
        load_model_path = os.path.join(self._config.model_dir,
                                       self._config.model_name)
        if self._restore_vars is not None:
            # variables = tf.contrib.framework.get_variables_to_restore()
            # variables_to_restore = [v for v in variables if v.name.split('/')[0] in self._restore_vars]
            # print(variables_to_restore)
            # restore only the requested subset of variables
            saver = tf.train.Saver(self._restore_vars)
        else:
            saver = tf.train.Saver()
        saver.restore(self.sess, load_model_path)

    def run_predict(self):
        """
        Run predictions and the process after finishing predictions.
        """
        with self.sess.as_default():
            self._input.before_read_setup()
            self._predict_step()
            # per-prediction wrap-up (e.g. writing aggregate results)
            for pred in self._config.predictions:
                pred.after_finish_predict()

        self.after_prediction()

    def _predict_step(self):
        """ Run predictions. Defined in subclass.
        """
        pass

    def after_prediction(self):
        """Template hook invoked once after run_predict finishes."""
        self._after_prediction()

    def _after_prediction(self):
        # overridden by subclasses (e.g. to release dataflow resources)
        pass
93 |
94 |
--------------------------------------------------------------------------------
/tensorcv/predicts/config.py:
--------------------------------------------------------------------------------
1 | import scipy.misc
2 | import os
3 | import numpy as np
4 |
5 | from ..dataflow.base import DataFlow
6 | from ..models.base import ModelDes
7 | from ..utils.default import get_default_session_config
8 | from ..utils.sesscreate import NewSessionCreator
9 | from .predictions import PredictionBase
10 | from ..utils.common import check_dir
11 |
12 | __all__ = ['PridectConfig']
13 |
def assert_type(v, tp):
    """Assert that *v* is an instance of type *tp*."""
    message = "Expect {}, but {} is given!".format(tp, v.__class__)
    assert isinstance(v, tp), message
17 |
class PridectConfig(object):
    """Configuration object consumed by :class:`Predictor`.

    Validates and stores the dataflow, model, prediction list, batch size
    and the model/result directories used when running predictions.
    """
    def __init__(self,
                 dataflow=None, model=None,
                 model_dir=None, model_name='',
                 restore_vars=None,
                 session_creator=None,
                 predictions=None,
                 batch_size=1,
                 default_dirs=None):
        """
        Args:
            dataflow (DataFlow): input dataflow used for prediction
            model (ModelDes): model description used to build the graph
            model_dir: unused; the directory comes from default_dirs
            model_name (str): file name of the checkpoint to restore
            restore_vars (list): variables to restore; None restores all
            session_creator: must be None (custom creators unsupported)
            predictions (PredictionBase or list): predictions to run
            batch_size (int): prediction batch size, must be > 0
            default_dirs: object with model_dir and result_dir attributes

        Raises:
            AttributeError: if default_dirs lacks model_dir or result_dir.
            ValueError: if a custom session_creator is supplied.
        """
        self.model_name = model_name
        try:
            self.model_dir = os.path.join(default_dirs.model_dir)
            check_dir(self.model_dir)
        except AttributeError:
            raise AttributeError('model_dir is not set!')

        try:
            self.result_dir = os.path.join(default_dirs.result_dir)
            check_dir(self.result_dir)
        except AttributeError:
            raise AttributeError('result_dir is not set!')

        # Predictor always reads config.restore_vars, so the attribute must
        # exist even when the caller leaves it as None (meaning: restore
        # everything). The original only assigned it inside the not-None
        # branch, crashing Predictor.__init__ with the default config.
        if restore_vars is not None and not isinstance(restore_vars, list):
            restore_vars = [restore_vars]
        self.restore_vars = restore_vars

        assert dataflow is not None, "dataflow cannot be None!"
        assert_type(dataflow, DataFlow)
        self.dataflow = dataflow

        assert batch_size > 0
        self.dataflow.set_batch_size(batch_size)
        self.batch_size = batch_size

        assert model is not None, "model cannot be None!"
        assert_type(model, ModelDes)
        self.model = model

        assert predictions is not None, "predictions cannot be None"
        if not isinstance(predictions, list):
            predictions = [predictions]
        for pred in predictions:
            assert_type(pred, PredictionBase)
        self.predictions = predictions

        # if not isinstance(callbacks, list):
        #     callbacks = [callbacks]
        # self._callbacks = callbacks

        if session_creator is None:
            self.session_creator = \
                NewSessionCreator(config=get_default_session_config())
        else:
            # message typo fixed ('custormer' -> 'custom')
            raise ValueError('custom session creator is '
                             'not allowed at this point!')

    @property
    def callbacks(self):
        # NOTE(review): self._callbacks is never assigned in __init__ (the
        # callbacks handling above is commented out), so this property
        # raises AttributeError until callback support is restored.
        return self._callbacks
82 |
83 |
--------------------------------------------------------------------------------
/tensorcv/predicts/predictions.py:
--------------------------------------------------------------------------------
1 | # File: predictions.py
2 | # Author: Qian Ge
3 |
4 | import os
5 | import scipy.io
6 | import scipy.misc
7 |
8 | import tensorflow as tf
9 | import numpy as np
10 |
11 | from ..utils.common import get_tensors_by_names
12 | from ..utils.viz import *
13 |
14 | __all__ = ['PredictionImage', 'PredictionScalar', 'PredictionMat', 'PredictionMeanScalar', 'PredictionOverlay']
15 |
def assert_type(v, tp):
    """Assert that *v* is an instance of *tp*, with a descriptive message."""
    if not isinstance(v, tp):
        assert False, "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"
19 |
class PredictionBase(object):
    """ Base class for prediction ops.

    Attributes:
        _predictions: tensor names at construction; resolved tensors after setup()
        _prefix_list: file-name / print prefix per tensor
        _global_ind: running index used when naming saved results
        _save_dir: directory where results are written (set in setup())
    """
    def __init__(self, prediction_tensors, save_prefix):
        """ Store the tensors to predict and the saving prefix of each.

        Args:
            prediction_tensors : list[string] A tensor name or list of tensor names
            save_prefix: list[string] A string or list of strings.
                Length of prediction_tensors and save_prefix have
                to be the same.
        """
        if not isinstance(prediction_tensors, list):
            prediction_tensors = [prediction_tensors]
        if not isinstance(save_prefix, list):
            save_prefix = [save_prefix]
        n_tensor = len(prediction_tensors)
        n_prefix = len(save_prefix)
        assert n_tensor == n_prefix, \
            'Length of prediction_tensors {} and save_prefix {} has to be the same'.\
            format(n_tensor, n_prefix)

        self._predictions = prediction_tensors
        self._prefix_list = save_prefix
        self._global_ind = 0

    def setup(self, result_dir):
        """Resolve tensor names to tensors and record the save directory."""
        assert os.path.isdir(result_dir)
        self._save_dir = result_dir
        self._predictions = get_tensors_by_names(self._predictions)

    def get_predictions(self):
        """Return the prediction tensors (names before setup, tensors after)."""
        return self._predictions

    def after_prediction(self, results):
        """ Process one batch of results; by default saves them. """
        self._save_prediction(results)

    def _save_prediction(self, results):
        # overridden by subclasses with the actual saving logic
        pass

    def after_finish_predict(self):
        """ Hook run once after all prediction steps finish. """
        self._after_finish_predict()

    def _after_finish_predict(self):
        # overridden by subclasses needing end-of-run processing
        pass
77 |
class PredictionImage(PredictionBase):
    """ Predict image output and save as files.

    Images are saved every batch. Each batch result can be
    saved as one merged image or as individual images.
    """
    def __init__(self, prediction_image_tensors,
                 save_prefix, merge_im=False,
                 tanh=False, color=False):
        """
        Args:
            prediction_image_tensors (list): a list of tensor names
            save_prefix (list): a list of file prefixes for saving
                each tensor in prediction_image_tensors
            merge_im (bool): merge output of one batch or not
            tanh (bool): forwarded to save_merge_images for merged output
            color (bool): map intensities to RGB before saving
        """
        self._merge = merge_im
        self._tanh = tanh
        self._color = color
        super(PredictionImage, self).__init__(
            prediction_tensors=prediction_image_tensors,
            save_prefix=save_prefix)

    def _save_prediction(self, results):
        """Write every result tensor of the batch to .png files."""
        for batch_im, prefix in zip(results, self._prefix_list):
            # each tensor restarts from the same base index so files from
            # different tensors share the same numbering
            cur_ind = self._global_ind
            if self._merge and batch_im.shape[0] > 1:
                # whole batch merged into one square grid image
                grid = self._get_grid_size(batch_im.shape[0])
                file_name = str(cur_ind) + '_' + prefix + '.png'
                save_merge_images(np.squeeze(batch_im),
                                  [grid, grid],
                                  os.path.join(self._save_dir, file_name),
                                  tanh=self._tanh, color=self._color)
                cur_ind += 1
            else:
                # one file per image in the batch
                for im in batch_im:
                    file_name = str(cur_ind) + '_' + prefix + '.png'
                    if self._color:
                        im = intensity_to_rgb(np.squeeze(im), normalize=True)
                    scipy.misc.imsave(os.path.join(self._save_dir, file_name),
                                      np.squeeze(im))
                    cur_ind += 1
        # advance the shared index once, after all tensors are written
        self._global_ind = cur_ind

    def _get_grid_size(self, batch_size):
        """Smallest square grid edge that fits batch_size images (cached)."""
        try:
            return self._grid_size
        except AttributeError:
            self._grid_size = np.ceil(batch_size**0.5).astype(int)
            return self._grid_size
129 |
class PredictionOverlay(PredictionImage):
    """Overlay two predicted image tensors and save the result as .png files."""
    def __init__(self, prediction_image_tensors,
                 save_prefix, merge_im=False,
                 tanh=False, color=False):
        """
        Args:
            prediction_image_tensors (list): exactly two tensor names whose
                outputs are overlaid pairwise
            save_prefix (list): file prefixes; the two are joined for the
                overlay file name
            merge_im (bool): merge the overlays of one batch into one image
            tanh (bool): forwarded to save_merge_images for merged output
            color (bool): forwarded to image_overlay
        """
        if not isinstance(prediction_image_tensors, list):
            prediction_image_tensors = [prediction_image_tensors]
        assert len(prediction_image_tensors) == 2,\
            '[PredictionOverlay] requires two image tensors but the input len = {}.'.\
            format(len(prediction_image_tensors))

        super(PredictionOverlay, self).__init__(prediction_image_tensors,
                                                save_prefix, merge_im=merge_im,
                                                tanh=tanh, color=color)

        # combined prefix used for the saved overlay files
        self._overlay_prefix = '{}_{}'.format(self._prefix_list[0], self._prefix_list[1])

    def _save_prediction(self, results):
        """Overlay results[0] with results[1] pairwise and save to disk."""
        cur_global_ind = self._global_ind

        if self._merge and results[0].shape[0] > 1:
            # merged path: overlay each pair, then tile into one grid image
            overlay_im_list = []
            for im_1, im_2 in zip(results[0], results[1]):
                overlay_im = image_overlay(im_1, im_2, color=self._color)
                overlay_im_list.append(overlay_im)

            grid_size = self._get_grid_size(results[0].shape[0])
            save_path = os.path.join(self._save_dir,
                                     str(cur_global_ind) + '_' + self._overlay_prefix + '.png')
            save_merge_images(np.squeeze(overlay_im_list),
                              [grid_size, grid_size], save_path,
                              tanh=self._tanh, color=False)
            cur_global_ind += 1
        else:
            # one overlay file per image pair in the batch
            for im_1, im_2 in zip(results[0], results[1]):
                overlay_im = image_overlay(im_1, im_2, color=self._color)
                save_path = os.path.join(self._save_dir,
                                         str(cur_global_ind) + '_' + self._overlay_prefix + '.png')
                scipy.misc.imsave(save_path, np.squeeze(overlay_im))
                cur_global_ind += 1
        self._global_ind = cur_global_ind
170 |
class PredictionScalar(PredictionBase):
    """Predict scalar tensors and print each value at every batch."""

    def __init__(self, prediction_scalar_tensors, print_prefix):
        """
        Args:
            prediction_scalar_tensors (list): a list of tensor names
            print_prefix (list): a list of name prefixes used when printing
                each tensor in prediction_scalar_tensors
        """
        super(PredictionScalar, self).__init__(
            prediction_tensors=prediction_scalar_tensors,
            save_prefix=print_prefix)

    def _save_prediction(self, results):
        # scalars are printed rather than written to files
        for prefix, value in zip(self._prefix_list, results):
            print('{} = {}'.format(prefix, value))
186 |
class PredictionMeanScalar(PredictionScalar):
    """Print scalar predictions per batch and report their means at the end."""

    def __init__(self, prediction_scalar_tensors, print_prefix):
        super(PredictionMeanScalar, self).__init__(
            prediction_scalar_tensors=prediction_scalar_tensors,
            print_prefix=print_prefix)

        # one accumulator list per predicted scalar
        self.scalar_list = [[] for _ in range(len(self._predictions))]

    def _save_prediction(self, results):
        # print each scalar and keep it for the end-of-run mean
        for idx, (value, prefix) in enumerate(zip(results, self._prefix_list)):
            print('{} = {}'.format(prefix, value))
            self.scalar_list[idx].append(value)

    def _after_finish_predict(self):
        # report the mean of each scalar over all processed batches
        for idx, prefix in enumerate(self._prefix_list):
            print('Overall {} = {}'.format(prefix, np.mean(self.scalar_list[idx])))
205 |
206 |
class PredictionMat(PredictionBase):
    """Save all predicted tensors of a batch into one .mat file per batch."""

    def _save_prediction(self, results):
        # e.g. '0_batch_test.mat', '1_batch_test.mat', ...
        file_name = str(self._global_ind) + '_' + 'batch_test' + '.mat'
        mat_dict = {name: np.squeeze(val)
                    for name, val in zip(self._prefix_list, results)}
        scipy.io.savemat(os.path.join(self._save_dir, file_name), mat_dict)

        self._global_ind += 1
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
--------------------------------------------------------------------------------
/tensorcv/predicts/simple.py:
--------------------------------------------------------------------------------
1 | # File: predictions.py
2 | # Author: Qian Ge
3 |
4 | import tensorflow as tf
5 |
6 | from .base import Predictor
7 |
8 | __all__ = ['SimpleFeedPredictor']
9 |
def assert_type(v, tp):
    """Assert that *v* is an instance of *tp*; message names both types."""
    ok = isinstance(v, tp)
    assert ok, "Expect {}, but {} is given!".format(tp, v.__class__)
13 |
class SimpleFeedPredictor(Predictor):
    """ Predictor with feed input.

    Feeds batches from the configured dataflow into the model's
    prediction placeholders for one epoch.
    """
    def __init__(self, config):
        """Build the predictor and collect the model's prediction placeholders.

        Args:
            config (PridectConfig): the config used for this predictor
        """
        super(SimpleFeedPredictor, self).__init__(config)
        # TODO change len_input to other
        placeholders = self._model.get_prediction_placeholder()
        if not isinstance(placeholders, list):
            placeholders = [placeholders]
        self._plhs = placeholders

    def _predict_step(self):
        """Run the hooked session on every batch of one epoch."""
        while self._input.epochs_completed < 1:
            # the original wrapped this call in
            # try/except AttributeError with an identical call in both
            # branches; that dead handler has been removed
            cur_batch = self._input.next_batch()
            feed = dict(zip(self._plhs, cur_batch))
            # fetches=[]: the Prediction2Hook instances attached to
            # hooked_sess do the actual work
            self.hooked_sess.run(fetches=[], feed_dict=feed)
        self._input.reset_epochs_completed(0)

    def _after_prediction(self):
        """Release dataflow resources once prediction is done."""
        self._input.after_reading()
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/tensorcv/tfdataflow/__init__.py:
--------------------------------------------------------------------------------
1 | # File: __init__.py
2 | # Author: Qian Ge
--------------------------------------------------------------------------------
/tensorcv/tfdataflow/base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: base.py
4 | # Author: Qian Ge
5 |
6 | import tensorflow as tf
7 | from tensorflow.python.lib.io import file_io
8 |
9 | from ..dataflow.base import DataFlow
10 | from ..dataflow.normalization import identity
11 | from ..utils.utils import assert_type
12 |
13 |
class DataFromTfrecord(DataFlow):
    """Dataflow that reads batches from tfrecord files through TF input queues.

    Decode ops are built once at construction; the batching op is rebuilt
    whenever the batch size changes. Epoch progress is tracked by counting
    batches against size()/batch_size.
    """
    def __init__(self, tfname,
                 record_names,
                 record_types,
                 raw_types,
                 decode_fncs,
                 batch_dict_name,
                 shuffle=True,
                 data_shape=[],
                 feature_len_list=None,
                 pf=identity):
        """
        Args:
            tfname: tfrecord file path or list of paths
            record_names: stored feature name or list of names
            record_types (tf.DType): on-disk type per feature
            raw_types (tf.DType): decoded type per feature
            decode_fncs: decode function per feature, called as
                decode_fnc(feature_tensor, raw_type)
            batch_dict_name: dict key per feature for next_batch_dict()
            shuffle (bool): use shuffle_batch instead of plain batch
            data_shape: per-feature reshape target; falsy entries skipped
            feature_len_list: FixedLenFeature size per feature; padded
                with [] for features without an explicit size
            pf: preprocess function — accepted but never used in this
                class; confirm whether it should be applied to the data
        """
        if not isinstance(tfname, list):
            tfname = [tfname]
        if not isinstance(record_names, list):
            record_names = [record_names]
        if not isinstance(record_types, list):
            record_types = [record_types]
        for c_type in record_types:
            assert_type(c_type, tf.DType)
        if not isinstance(raw_types, list):
            raw_types = [raw_types]
        for raw_type in raw_types:
            assert_type(raw_type, tf.DType)
        if not isinstance(decode_fncs, list):
            decode_fncs = [decode_fncs]
        if not isinstance(batch_dict_name, list):
            batch_dict_name = [batch_dict_name]
        assert len(record_types) == len(record_names)
        assert len(record_types) == len(batch_dict_name)

        # pad feature length list so every record name has an entry
        if feature_len_list is None:
            feature_len_list = [[] for i in range(0, len(record_names))]
        elif not isinstance(feature_len_list, list):
            feature_len_list = [feature_len_list]
        self._feat_len_list = feature_len_list
        if len(self._feat_len_list) < len(record_names):
            self._feat_len_list.extend([[] for i in range(0, len(record_names) - len(self._feat_len_list))])

        self.record_names = record_names
        self.record_types = record_types
        self.raw_types = raw_types
        self.decode_fncs = decode_fncs
        self.data_shape = data_shape
        self._tfname = tfname
        self._batch_dict_name = batch_dict_name

        self._shuffle = shuffle

        # self._batch_step = 0
        # self.reset_epochs_completed(0)
        # self.set_batch_size(batch_size)
        self.setup_decode_data()
        self.setup(epoch_val=0, batch_size=1)


    def set_batch_size(self, batch_size):
        """Set the batch size and rebuild the batch op / epoch step count."""
        self._batch_size = batch_size
        self.updata_data_op(batch_size)
        self.updata_step_per_epoch(batch_size)

    def updata_data_op(self, batch_size):
        """(Re)create the queue-based batching op for the given batch size."""
        try:
            if self._shuffle is True:
                self._data = tf.train.shuffle_batch(
                    self._decode_data,
                    batch_size=batch_size,
                    capacity=batch_size * 4,
                    num_threads=2,
                    min_after_dequeue=batch_size * 2)
            else:
                print('***** data is not shuffled *****')
                self._data = tf.train.batch(
                    self._decode_data,
                    batch_size=batch_size,
                    capacity=batch_size,
                    num_threads=1,
                    allow_smaller_final_batch=False)
            # self._data = self._decode_data[0]
            # print(self._data)
        except AttributeError:
            # _decode_data not built yet; setup_decode_data() will call
            # set_batch_size again once it exists
            pass

    def reset_epochs_completed(self, val):
        """Reset the epoch counter and the within-epoch batch counter."""
        self._epochs_completed = val
        self._batch_step = 0

    # def _setup(self, **kwargs):
    def setup_decode_data(self):
        """Build the reader/parse/decode graph for the tfrecord files."""
        # n_epoch = kwargs['num_epoch']

        feature = {}
        for record_name, r_type, cur_size in zip(self.record_names, self.record_types, self._feat_len_list):
            feature[record_name] = tf.FixedLenFeature(cur_size, r_type)
        # filename_queue = tf.train.string_input_producer(self._tfname, num_epochs=n_epoch)
        filename_queue = tf.train.string_input_producer(self._tfname)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(serialized_example, features=feature)
        decode_data = [decode_fnc(features[record_name], raw_type)
                       for decode_fnc, record_name, raw_type
                       in zip(self.decode_fncs, self.record_names, self.raw_types)]

        # reshape only the entries with a non-empty target shape
        for idx, c_shape in enumerate(self.data_shape):
            if c_shape:
                decode_data[idx] = tf.reshape(decode_data[idx], c_shape)

        self._decode_data = decode_data

        # self._data = self._decode_data[0]
        # rebuild the batch op now that decode ops exist
        try:
            self.set_batch_size(batch_size=self._batch_size)
        except AttributeError:
            self.set_batch_size(batch_size=1)

    def updata_step_per_epoch(self, batch_size):
        """Recompute how many batches constitute one epoch."""
        self._step_per_epoch = int(self.size() / batch_size)

    def before_read_setup(self):
        """Start the queue-runner threads; must precede next_batch calls."""
        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(coord=self.coord)

    def next_batch(self):
        """Return the next batch as a list, advancing the epoch counter."""
        sess = tf.get_default_session()
        batch_data = sess.run(self._data)
        self._batch_step += 1
        if self._batch_step % self._step_per_epoch == 0:
            self._epochs_completed += 1
        # print(batch_data[2])
        return batch_data

    def next_batch_dict(self):
        """Return the next batch keyed by batch_dict_name entries."""
        sess = tf.get_default_session()
        batch_data = sess.run(self._data)
        self._batch_step += 1
        if self._batch_step % self._step_per_epoch == 0:
            self._epochs_completed += 1
        batch_dict = {name: data for name, data in zip(self._batch_dict_name, batch_data)}
        return batch_dict

    def after_reading(self):
        """Stop and join the queue-runner threads."""
        self.coord.request_stop()
        self.coord.join(self.threads)

    def size(self):
        """Total number of records across all tfrecord files (cached)."""
        try:
            return self._size
        except AttributeError:
            # counting requires a full pass over the files; done once
            self._size = sum(1 for f in self._tfname for _ in tf.python_io.tf_record_iterator(f))
            return self._size
164 |
--------------------------------------------------------------------------------
/tensorcv/tfdataflow/convert.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: convert.py
4 | # Author: Qian Ge
5 |
6 | import tensorflow as tf
7 |
8 | from ..utils.utils import assert_type
9 | from ..dataflow.base import DataFlow
10 |
11 |
def int64_feature(value):
    """Wrap a scalar integer into a tf.train.Feature holding an Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
14 |
15 |
def bytes_feature(value):
    """Wrap a bytes value into a tf.train.Feature holding a BytesList."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
18 |
19 |
def convert_image(image):
    """Serialize a numpy image to raw bytes and wrap it as a bytes feature."""
    raw = tf.compat.as_bytes(image.tostring())
    return bytes_feature(raw)
22 |
23 |
def float_feature(value):
    """Wrap an iterable of floats into a tf.train.Feature holding a FloatList."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)
26 |
27 |
def dataflow2tfrecord(dataflow, tfname, record_names, c_fncs):
    """Serialize one epoch of a dataflow into a tfrecord file.

    Args:
        dataflow (DataFlow): data source; read with batch size 1
        tfname (str): output tfrecord file name
        record_names: feature name or list of feature names
        c_fncs: conversion function or list, one per record name
    """
    assert_type(dataflow, DataFlow)
    dataflow.setup(epoch_val=0, batch_size=1)

    if not isinstance(record_names, list):
        record_names = [record_names]
    if not isinstance(c_fncs, list):
        c_fncs = [c_fncs]
    assert len(c_fncs) == len(record_names)

    writer = tf.python_io.TFRecordWriter(tfname)

    # one tf.train.Example per (batch-size-1) sample
    while dataflow.epochs_completed < 1:
        batch_data = dataflow.next_batch()
        feature = {
            name: convert_fnc(data[0])
            for name, convert_fnc, data in zip(record_names, c_fncs, batch_data)
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())

    writer.close()
52 |
--------------------------------------------------------------------------------
/tensorcv/tfdataflow/write.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: write.py
4 | # Author: Qian Ge
5 |
6 | import tensorflow as tf
7 |
8 | from ..dataflow.base import DataFlow
9 | from ..models.base import BaseModel
10 | from ..utils.utils import assert_type
11 | from .convert import float_feature
12 |
13 |
class Bottleneck2TFrecord(object):
    """Extract bottleneck features from pretrained nets and write tfrecords.

    Builds the inference graph of each net once at construction; write()
    then runs the dataflow through the nets and serializes both selected
    dataflow fields and the extracted features.
    """
    def __init__(self, nets, record_feat_names,
                 feat_preprocess=tf.identity):
        """
        Args:
            nets (BaseModel or list): pretrained model(s) providing the
                'conv_out' layer used as the bottleneck feature
            record_feat_names: tfrecord feature name(s), one per net
            feat_preprocess: op applied to each net's 'conv_out' layer
                before it is fetched
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            assert_type(net, BaseModel)
        if not isinstance(record_feat_names, list):
            record_feat_names = [record_feat_names]
        assert len(nets) == len(record_feat_names)
        self._w_f_names = record_feat_names

        # per-net: feature fetch op, placeholder keys to feed, input dict
        self._feat_ops = []
        self._feed_plh_keys = []
        self._net_input_dicts = []
        for net in nets:
            net.set_is_training(False)
            net.create_graph()
            self._net_input_dicts.append(net.input_dict)
            self._feed_plh_keys.append(net.prediction_plh_dict)
            self._feat_ops.append(feat_preprocess(net.layer['conv_out']))

    def write(self, tfname, dataflow,
              record_dataflow_keys=[], record_dataflow_names=[], c_fncs=[]):
        """Run one epoch of the dataflow and write features to a tfrecord.

        Args:
            tfname (str): output tfrecord file name
            dataflow (DataFlow): input data; read via next_batch_dict()
            record_dataflow_keys (list): batch-dict keys to copy through
            record_dataflow_names (list): tfrecord feature name per key
            c_fncs (list): conversion function per copied key
        """
        assert_type(dataflow, DataFlow)
        dataflow.setup(epoch_val=0, batch_size=1)

        if not isinstance(record_dataflow_names, list):
            record_dataflow_names = [record_dataflow_names]
        if not isinstance(c_fncs, list):
            c_fncs = [c_fncs]
        if not isinstance(record_dataflow_keys, list):
            record_dataflow_keys = [record_dataflow_keys]
        assert len(c_fncs) == len(record_dataflow_names)
        assert len(record_dataflow_keys) == len(record_dataflow_names)

        tfrecords_filename = tfname
        writer = tf.python_io.TFRecordWriter(tfrecords_filename)

        with tf.Session() as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())
            dataflow.before_read_setup()
            cnt = 0
            while dataflow.epochs_completed < 1:
                print('Writing data {}...'.format(cnt))
                batch_data = dataflow.next_batch_dict()

                # run every net on this batch to get its bottleneck features
                feats = []
                for feat_op, feed_plh_key, net_input_dict in zip(self._feat_ops, self._feed_plh_keys, self._net_input_dicts):
                    feed_dict = {net_input_dict[key]: batch_data[key]
                                 for key in feed_plh_key}
                    feats.append(sess.run(feat_op, feed_dict=feed_dict))

                # feature = {}
                # for record_name, convert_fnc, key in zip(record_dataflow_names, c_fncs, record_dataflow_keys):
                #     feature[record_name] = convert_fnc(batch_data[key][0])


                # for record_name, feat in zip(self._w_f_names, feats):
                #     feature[record_name] = float_feature(feat.reshape(-1).tolist())

                # feature_list = []
                # write one Example per sample of the batch
                batch_size = len(feats[0])
                for idx in range(0, batch_size):
                    feature = {}
                    # copy through the requested raw dataflow fields
                    for record_name, convert_fnc, key in zip(record_dataflow_names, c_fncs, record_dataflow_keys):
                        feature[record_name] = convert_fnc(
                            batch_data[key][idx])

                    # flatten each net's feature map into a float feature
                    for record_name, feat in zip(self._w_f_names, feats):
                        feature[record_name] =\
                            float_feature(feat[idx].reshape(-1).tolist())

                    example = tf.train.Example(
                        features=tf.train.Features(feature=feature))
                    writer.write(example.SerializeToString())

                cnt += 1
            dataflow.after_reading()

        writer.flush()
        writer.close()
97 |
--------------------------------------------------------------------------------
/tensorcv/train/__init__.py:
--------------------------------------------------------------------------------
1 | # File: __init__.py
2 | # Author: Qian Ge
3 |
4 |
--------------------------------------------------------------------------------
/tensorcv/train/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | import weakref
3 | import os
4 |
5 | import tensorflow as tf
6 |
7 | from .config import TrainConfig
8 | from ..callbacks.base import Callback
9 | from ..callbacks.group import Callbacks
10 | from ..utils.sesscreate import ReuseSessionCreator
11 | from ..callbacks.monitors import TrainingMonitor, Monitors
12 |
13 |
14 | __all__ = ['Trainer']
15 |
def assert_type(v, tp):
    """Assert that `v` is an instance of type `tp`."""
    assert isinstance(v, tp), \
        "Expect {}, but {} is given!".format(tp, v.__class__)
19 |
class Trainer(object):
    """ base class for trainer """
    def __init__(self, config):
        # config (TrainConfig): bundles the model, dataflow, callbacks,
        # monitors and session creator used for training.
        assert_type(config, TrainConfig)
        self._is_load = config.is_load
        self.config = config
        self.model = config.model
        # weakref.proxy avoids a reference cycle between model and trainer.
        self.model.ex_init_model(config.dataflow, weakref.proxy(self))
        self.dataflow = config.dataflow
        # self.monitors = self.config.monitors
        self._global_step = 0
        # Plain lists until setup() wraps them in Callbacks/Monitors groups.
        self._callbacks = []
        self.monitors = []

        self.default_dirs = config.default_dirs

    @property
    def epochs_completed(self):
        # Number of full passes the dataflow has made over the data.
        return self.dataflow.epochs_completed

    @property
    def get_global_step(self):
        # Number of training steps executed so far.
        return self._global_step

    def register_callback(self, cb):
        # Callbacks can only be added before setup() groups them.
        assert_type(cb, Callback)
        assert not isinstance(self._callbacks, Callbacks), \
            "callbacks have been setup"
        self._callbacks.append(cb)

    def register_monitor(self, monitor):
        # A monitor is also a callback, so it is registered in both lists.
        assert_type(monitor, TrainingMonitor)
        assert not isinstance(self.monitors, Monitors), \
            "monitors have been setup"
        self.monitors.append(monitor)
        self.register_callback(monitor)


    def _create_session(self):
        hooks = self._callbacks.get_hooks()
        self.sess = self.config.session_creator.create_session()

        # Wrap the raw session so callback hooks fire on every run() call.
        self.hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess), hooks=hooks)

        if self._is_load:
            # Restore weights from a previously saved checkpoint.
            load_model_path = os.path.join(self.config.model_dir,
                                           self.config.model_name)
            saver = tf.train.Saver()
            saver.restore(self.sess, load_model_path)

    def main_loop(self):
        with self.sess.as_default():
            self._callbacks.before_train()
            while self.epochs_completed <= self.config.max_epoch:
                self._global_step += 1
                print('Epoch: {}. Step: {}'.\
                    format(self.epochs_completed, self._global_step))
                # self._callbacks.before_epoch()
                # TODO to be modified
                self.model.set_is_training(True)
                self._run_step()
                # self._callbacks.after_epoch()
                self._callbacks.trigger_step()
            self._callbacks.after_train()

    def train(self):
        # Convenience entry point: build graph/callbacks/session, then loop.
        self.setup()
        self.main_loop()

    @abstractmethod
    def _run_step(self):
        # NOTE(review): marked @abstractmethod but ships a default
        # implementation (single train_op per step); subclasses such as
        # GANFeedTrainer override it — confirm the decorator is intended.
        model_feed = self.model.get_graph_feed()
        self.hooked_sess.run(self.train_op, feed_dict=model_feed)

    def setup(self):
        # setup graph from model
        self.setup_graph()

        # setup callbacks
        for cb in self.config.callbacks:
            self.register_callback(cb)
        for monitor in self.config.monitors:
            self.register_monitor(monitor)
        self._callbacks = Callbacks(self._callbacks)
        self._callbacks.setup_graph(weakref.proxy(self))
        self.monitors = Monitors(self.monitors)
        # create session
        self._create_session()

        # Freeze the graph: any op creation after this point is an error.
        self.sess.graph.finalize()

    def setup_graph(self):
        self.model.create_graph()
        self._setup()
        self.model.setup_summary()

    def _setup(self):
        # Hook for subclasses to build optimizers / train ops.
        pass
121 |
122 |
123 |
124 |
125 |
126 |
127 |
--------------------------------------------------------------------------------
/tensorcv/train/config.py:
--------------------------------------------------------------------------------
1 | import scipy.misc
2 | import os
3 | import numpy as np
4 |
5 | from ..dataflow.base import DataFlow
6 | from ..models.base import ModelDes, GANBaseModel
7 | from ..utils.default import get_default_session_config
8 | from ..utils.sesscreate import NewSessionCreator
9 | from ..callbacks.monitors import TFSummaryWriter
10 | from ..callbacks.summary import TrainSummary
11 | from ..utils.common import check_dir
12 |
13 | __all__ = ['TrainConfig', 'GANTrainConfig']
14 |
def assert_type(v, tp):
    """Assert that `v` is an instance of type `tp`."""
    assert isinstance(v, tp), \
        "Expect {}, but {} is given!".format(tp, v.__class__)
18 |
class TrainConfig(object):
    """Configuration bundle consumed by Trainer.

    Holds the dataflow, model, callbacks, monitors, session creator and
    basic training hyper-parameters (batch size, max epoch).
    """
    def __init__(self,
                 dataflow=None, model=None,
                 callbacks=None,
                 session_creator=None,
                 monitors=None,
                 batch_size=1, max_epoch=100,
                 summary_periodic=None,
                 is_load=False,
                 model_name=None,
                 default_dirs=None):
        """
        Args:
            dataflow (DataFlow): training data source (required).
            model (ModelDes): model to train (required).
            callbacks: a Callback or list of Callbacks.
            session_creator: must be None; only the default
                NewSessionCreator is supported at this point.
            monitors: a TFSummaryWriter or list of them.
            batch_size (int): > 0.
            max_epoch (int): > 0.
            summary_periodic (int): if set, adds a periodic TrainSummary.
            is_load (bool): restore weights from `model_name` if True.
            model_name (str): checkpoint name, required when is_load.
            default_dirs: object providing directory attributes
                (e.g. model_dir).
        """
        self.default_dirs = default_dirs

        # Normalize to a list first so that passing a list of monitors
        # does not trip the type assertion (previously a list always failed).
        if not isinstance(monitors, list):
            monitors = [monitors]
        for monitor in monitors:
            assert_type(monitor, TFSummaryWriter), \
                "monitors has to be TFSummaryWriter at this point!"
        self.monitors = monitors

        assert dataflow is not None, "dataflow cannot be None!"
        assert_type(dataflow, DataFlow)
        self.dataflow = dataflow

        assert model is not None, "model cannot be None!"
        assert_type(model, ModelDes)
        self.model = model

        assert batch_size > 0 and max_epoch > 0
        self.dataflow.set_batch_size(batch_size)
        self.model.set_batch_size(batch_size)
        self.batch_size = batch_size
        self.max_epoch = max_epoch

        self.is_load = is_load
        if is_load:
            assert model_name is not None,\
                '[TrainConfig]: model_name cannot be None when is_load is True!'
            self.model_name = model_name
            try:
                self.model_dir = os.path.join(default_dirs.model_dir)
                check_dir(self.model_dir)
            except AttributeError:
                raise AttributeError('model_dir is not set!')

        # None default instead of a mutable `=[]` default (which is shared
        # across all instances); copy so later appends never mutate the
        # caller's list.
        if callbacks is None:
            callbacks = []
        elif not isinstance(callbacks, list):
            callbacks = [callbacks]
        self._callbacks = list(callbacks)

        # TODO model.default_collection only in BaseModel class
        if isinstance(summary_periodic, int):
            self._callbacks.append(
                TrainSummary(key=model.default_collection,
                             periodic=summary_periodic))

        if session_creator is None:
            self.session_creator = \
                NewSessionCreator(config=get_default_session_config())
        else:
            # Fixed typo in the error message ("custormer").
            raise ValueError(
                'custom session creator is not allowed at this point!')

    @property
    def callbacks(self):
        """List of registered callbacks (including any periodic summary)."""
        return self._callbacks
84 |
85 |
class GANTrainConfig(TrainConfig):
    """TrainConfig for GAN models with separate discriminator/generator
    callback lists and per-phase periodic summaries."""
    def __init__(self,
                 dataflow=None, model=None,
                 discriminator_callbacks=None,
                 generator_callbacks=None,
                 session_creator=None,
                 monitors=None,
                 batch_size=1, max_epoch=100,
                 summary_d_periodic=None,
                 summary_g_periodic=None,
                 default_dirs=None):
        """
        Args:
            model (GANBaseModel): GAN model to train (required).
            discriminator_callbacks: Callback or list run on the
                discriminator update.
            generator_callbacks: Callback or list run on the generator
                update.
            summary_d_periodic (int): period for discriminator summaries.
            summary_g_periodic (int): period for generator summaries.
            Remaining arguments: same as TrainConfig.
        """
        assert_type(model, GANBaseModel)

        # None defaults instead of mutable `=[]` defaults, which would be
        # shared across every GANTrainConfig instance.
        if discriminator_callbacks is None:
            discriminator_callbacks = []
        elif not isinstance(discriminator_callbacks, list):
            discriminator_callbacks = [discriminator_callbacks]
        self._dis_callbacks = discriminator_callbacks

        if generator_callbacks is None:
            generator_callbacks = []
        elif not isinstance(generator_callbacks, list):
            generator_callbacks = [generator_callbacks]
        self._gen_callbacks = generator_callbacks

        if isinstance(summary_d_periodic, int):
            self._dis_callbacks.append(
                TrainSummary(key=model.d_collection,
                             periodic=summary_d_periodic))
        if isinstance(summary_g_periodic, int):
            # Bug fix: generator summaries were appended to the
            # discriminator callback list (copy-paste error).
            self._gen_callbacks.append(
                TrainSummary(key=model.g_collection,
                             periodic=summary_g_periodic))

        callbacks = self._dis_callbacks + self._gen_callbacks

        super(GANTrainConfig, self).__init__(
            dataflow=dataflow, model=model,
            callbacks=callbacks,
            session_creator=session_creator,
            monitors=monitors,
            # Bug fix: was `max_epoch=ßmax_epoch` (stray character caused
            # a NameError at call time).
            batch_size=batch_size, max_epoch=max_epoch,
            default_dirs=default_dirs)

    @property
    def dis_callbacks(self):
        """Callbacks run during the discriminator update phase."""
        return self._dis_callbacks

    @property
    def gen_callbacks(self):
        """Callbacks run during the generator update phase."""
        return self._gen_callbacks
132 |
133 |
--------------------------------------------------------------------------------
/tensorcv/train/simple.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 | import tensorflow as tf
4 |
5 | from .config import TrainConfig, GANTrainConfig
6 | from .base import Trainer
7 | from ..callbacks.inputs import FeedInput
8 | from ..callbacks.group import Callbacks
9 | from ..callbacks.hooks import Callback2Hook
10 | from ..models.base import BaseModel, GANBaseModel
11 | from ..utils.sesscreate import ReuseSessionCreator
12 |
13 |
14 | __all__ = ['SimpleFeedTrainer']
15 |
def assert_type(v, tp):
    """Assert that `v` is an instance of type `tp`."""
    assert isinstance(v, tp), \
        "Expect {}, but {} is given!".format(tp, v.__class__)
19 |
class SimpleFeedTrainer(Trainer):
    """ single optimizer """
    def __init__(self, config):
        # Only plain BaseModel (single optimizer/train op) is supported.
        assert_type(config.model, BaseModel)
        super(SimpleFeedTrainer, self).__init__(config)

    def _setup(self):
        # TODO to be modified
        # Feed dataflow batches into the model's training placeholders.
        cbs = FeedInput(self.dataflow, self.model.get_train_placeholder())

        self.config.callbacks.append(cbs)

        # Build the single training op from the model's grads + optimizer.
        grads = self.model.get_grads()
        opt = self.model.get_optimizer()
        self.train_op = opt.apply_gradients(grads, name='train')
35 |
class GANFeedTrainer(Trainer):
    """Trainer for GAN models.

    Each training step runs one discriminator update followed by two
    generator updates, using two hooked sessions that share one
    underlying tf.Session.
    """
    def __init__(self, config):
        # config must carry separate discriminator/generator callbacks.
        assert_type(config, GANTrainConfig)
        super(GANFeedTrainer, self).__init__(config)

    def _setup(self):
        # FeedInput only implements before_run, so it is safe to attach it
        # purely as a hook on the discriminator session (the generator
        # session reuses the data already fed).
        cbs = FeedInput(self.dataflow, self.model.get_train_placeholder())
        self.feed_input_hook = [Callback2Hook(cbs)]

        dis_grads = self.model.get_discriminator_grads()
        dis_opt = self.model.get_discriminator_optimizer()
        self.dis_train_op = dis_opt.apply_gradients(
            dis_grads, name='discriminator_train')

        gen_grads = self.model.get_generator_grads()
        gen_opt = self.model.get_generator_optimizer()
        self.gen_train_op = gen_opt.apply_gradients(
            gen_grads, name='generator_train')

    def _create_session(self):
        self._dis_callbacks = Callbacks(list(self.config.dis_callbacks))
        self._gen_callbacks = Callbacks(list(self.config.gen_callbacks))
        dis_hooks = self._dis_callbacks.get_hooks()
        gen_hooks = self._gen_callbacks.get_hooks()

        self.sess = self.config.session_creator.create_session()
        # Both hooked sessions wrap the same raw session; only the
        # discriminator step triggers input feeding.
        self.dis_hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess),
            hooks=dis_hooks + self.feed_input_hook)
        self.gen_hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess),
            hooks=gen_hooks)

    def _run_step(self):
        # One discriminator update ...
        model_feed = self.model.get_graph_feed()
        self.dis_hooked_sess.run(self.dis_train_op, feed_dict=model_feed)

        # ... then two generator updates per step.
        # Bug fix: feed_dict previously referenced `ßmodel_feed` (stray
        # character), raising NameError on every generator step.
        for _ in range(2):
            model_feed = self.model.get_graph_feed()
            self.gen_hooked_sess.run(self.gen_train_op, feed_dict=model_feed)
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
--------------------------------------------------------------------------------
/tensorcv/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # File: __init__.py
2 | # Author: Qian Ge
3 |
--------------------------------------------------------------------------------
/tensorcv/utils/common.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: common.py
4 | # Author: Qian Ge
5 |
6 | import math
7 | import os
8 |
9 | import tensorflow as tf
10 |
11 | __all__ = ['apply_mask', 'apply_mask_inverse', 'get_tensors_by_names',
12 | 'deconv_size', 'match_tensor_save_name']
13 |
14 |
def apply_mask(input_matrix, mask):
    """Select the entries of ``input_matrix`` where ``mask`` equals 1.

    Args:
        input_matrix (Tensor): A Tensor
        mask (int): A Tensor of type int32 with indices in {0, 1}. Shape
            has to be the same as input_matrix.

    Return:
        A Tensor with elements from data with entries in mask equal to 1.
    """
    partitions = tf.dynamic_partition(input_matrix, mask, 2)
    return partitions[1]
27 |
28 |
def apply_mask_inverse(input_matrix, mask):
    """Select the entries of ``input_matrix`` where ``mask`` equals 0.

    Args:
        input_matrix (Tensor): A Tensor
        mask (int): A Tensor of type int32 with indices in {0, 1}. Shape
            has to be the same as input_matrix.

    Return:
        A Tensor with elements from data with entries in mask equal to 0.
    """
    partitions = tf.dynamic_partition(input_matrix, mask, 2)
    return partitions[0]
41 |
42 |
def get_tensors_by_names(names):
    """Fetch tensors from the default graph by name.

    Args:
        names (str): A str or a list of str

    Return:
        A list of tensors with name in input names.

    Warning:
        If more than one tensor have the same name in the graph, this
        function only returns the tensor named NAME:0.
    """
    if not isinstance(names, list):
        names = [names]

    graph = tf.get_default_graph()
    # TODO assumes there are no repeated names
    return [graph.get_tensor_by_name(name + ':0') for name in names]
66 |
67 |
def deconv_size(input_height, input_width, stride=2):
    """
    Compute the feature size (height and width) after filtering with
    a specific stride. Mostly used for setting the shape for deconvolution.

    Args:
        input_height (int): height of input feature
        input_width (int): width of input feature
        stride (int): stride of the filter

    Return:
        (int, int): Height and width of feature after filtering.
    """
    print('***** WARNING ********: deconv_size is moved to models.utils.py')
    out_height = int(math.ceil(input_height / float(stride)))
    out_width = int(math.ceil(input_width / float(stride)))
    return out_height, out_width
84 |
85 |
def match_tensor_save_name(tensor_names, save_names):
    """
    Pair tensor names with names used to save their results.

    If save_names covers at least as many entries as tensor_names, the
    tensors are saved under save_names; otherwise (or when save_names is
    None) they fall back to their own names. Used for prediction or
    inference.

    Args:
        tensor_names (str): List of tensor names
        save_names (str): List of names for saving tensors

    Return:
        (list, list): List of tensor names and list of names to save
            the tensors.
    """
    if not isinstance(tensor_names, list):
        tensor_names = [tensor_names]
    if save_names is None:
        return tensor_names, tensor_names
    if not isinstance(save_names, list):
        save_names = [save_names]
    # Too few save names: fall back to the tensors' own names.
    if len(save_names) < len(tensor_names):
        return tensor_names, tensor_names
    return tensor_names, save_names
112 |
113 |
def check_dir(input_dir):
    """Assert that `input_dir` is an existing directory (deprecated copy)."""
    print('***** WARNING ********: check_dir is moved to utils.utils.py')
    assert input_dir is not None, "dir cannot be None!"
    assert os.path.isdir(input_dir), input_dir + ' does not exist!'
118 |
119 |
def assert_type(v, tp):
    """Assert that input `v` is an instance of type `tp` (deprecated copy).

    Moved to utils.utils; this copy only adds a relocation warning.
    """
    # Bug fix: the docstring was previously placed AFTER the print
    # statement, making it a dead string literal instead of a docstring.
    print('***** WARNING ********: assert_type is moved to utils.utils.py')
    assert isinstance(v, tp),\
        "Expect " + str(tp) + ", but " + str(v.__class__) + " is given!"
127 |
--------------------------------------------------------------------------------
/tensorcv/utils/debug.py:
--------------------------------------------------------------------------------
def hello_cv():
    """Print a greeting, used as an import smoke test for the package."""
    print('hello')
3 |
--------------------------------------------------------------------------------
/tensorcv/utils/default.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: default.py
4 | # Author: Qian Ge
5 |
6 | import tensorflow as tf
7 |
8 | __all__ = ['get_default_session_config']
9 |
10 |
def get_default_session_config(memory_fraction=1):
    """Default config of a TensorFlow session

    Args:
        memory_fraction (float): Memory fraction of GPU for this session

    Return:
        tf.ConfigProto(): Config of session.
    """
    conf = tf.ConfigProto()
    gpu_opts = conf.gpu_options
    gpu_opts.per_process_gpu_memory_fraction = memory_fraction
    gpu_opts.allow_growth = True
    return conf
25 |
--------------------------------------------------------------------------------
/tensorcv/utils/sesscreate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: default.py
4 | # Author: Qian Ge
5 | # Modified from https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/tfutils/sesscreate.py
6 |
7 | import tensorflow as tf
8 |
9 | from .default import get_default_session_config
10 |
11 | __all__ = ['NewSessionCreator', 'ReuseSessionCreator']
12 |
13 |
class NewSessionCreator(tf.train.SessionCreator):
    """
    tf.train.SessionCreator that builds a brand-new session.
    """
    def __init__(self, target='', graph=None, config=None):
        """Inits NewSessionCreator with target, graph and config.

        Args:
            target: same as :meth:`tf.Session.__init__()`.
            graph: same as :meth:`tf.Session.__init__()`.
            config: same as :meth:`tf.Session.__init__()`. Default to
                :func:`utils.default.get_default_session_config()`.
        """
        self.target = target
        self.graph = graph
        # Fall back to the package-wide default session config.
        self.config = config if config is not None \
            else get_default_session_config()

    def create_session(self):
        """Create a session and initialize global and local variables.

        Return:
            A tf.Session object containing nodes for all of the
            operations in the underlying TensorFlow graph.
        """
        sess = tf.Session(target=self.target,
                          graph=self.graph, config=self.config)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        return sess
46 |
47 |
class ReuseSessionCreator(tf.train.SessionCreator):
    """
    tf.train.SessionCreator that hands back an already-created session.
    """
    def __init__(self, sess):
        """Inits ReuseSessionCreator with an existing session.

        Args:
            sess (tf.Session): the session object to be reused
        """
        self.sess = sess

    def create_session(self):
        """Return the wrapped, pre-existing session unchanged.

        Return:
            A reused tf.Session object containing nodes for all of the
            operations in the underlying TensorFlow graph.
        """
        return self.sess
68 |
--------------------------------------------------------------------------------
/tensorcv/utils/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: utils.py
4 | # Author: Qian Ge
5 |
6 | import os
7 | from datetime import datetime
8 | import numpy as np
9 |
10 | __all__ = ['get_rng']
11 |
_RNG_SEED = None


def get_rng(obj=None):
    """
    Get a well-seeded RNG (seed mixes time, pid and the object id).

    This function is copied from `tensorpack
    `__.
    Args:
        obj: some object to use to generate random seed.
    Returns:
        np.random.RandomState: the RNG.
    """
    if _RNG_SEED is not None:
        # Global override, used for reproducible runs.
        return np.random.RandomState(_RNG_SEED)
    timestamp = int(datetime.now().strftime("%Y%m%d%H%M%S%f"))
    seed = (id(obj) + os.getpid() + timestamp) % 4294967295
    return np.random.RandomState(seed)
30 |
31 |
def check_dir(input_dir):
    """Assert that `input_dir` is a path to an existing directory."""
    assert input_dir is not None, "dir cannot be None!"
    assert os.path.isdir(input_dir), input_dir + ' does not exist!'
35 |
36 |
def assert_type(v, tp):
    """
    Assert that input `v` is an instance of type `tp`.
    """
    assert isinstance(v, tp), \
        "Expect {}, but {} is given!".format(tp, v.__class__)
43 |
--------------------------------------------------------------------------------
/tensorcv/utils/viz.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # File: viz.py
3 | # Author: Qian Ge
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import scipy.misc
8 |
9 |
def intensity_to_rgb(intensity, cmap='jet', normalize=False):
    """
    Convert a 1-channel matrix of intensities to an RGB image employing
    a colormap. (Copied from tensorpack.)
    This function requires matplotlib; see the matplotlib colormap docs
    for the list of available colormap names.

    Args:
        intensity (np.ndarray): array of intensities such as saliency.
        cmap (str): name of the colormap to use.
        normalize (bool): if True, will normalize the intensity so that it has
            minimum 0 and maximum 1.

    Returns:
        np.ndarray: an RGB float32 image in range [0, 255], a colored heatmap.
    """
    intensity = intensity.astype("float")

    if normalize:
        # Shift to min 0 first, then scale by the new max.
        intensity = intensity - intensity.min()
        intensity = intensity / intensity.max()

    # Already 3-D (H, W, C): just rescale to [0, 255], no colormap needed.
    if intensity.ndim == 3:
        return intensity.astype('float32') * 255.0

    colormap = plt.get_cmap(cmap)
    colored = colormap(intensity)[..., :3]  # drop the alpha channel
    return colored.astype('float32') * 255.0
43 |
44 |
def save_merge_images(images, merge_grid, save_path, color=False, tanh=False):
    """Save multiple images with same size into one larger image.

    The best size number is
    int(max(sqrt(image.shape[0]),sqrt(image.shape[1]))) + 1

    Args:
        images (np.ndarray): A batch of image array to be merged with size
            [BATCH_SIZE, HEIGHT, WIDTH, CHANNEL].
        merge_grid (list): List of length 2. The grid size for merge images.
        save_path (str): Path for saving the merged image.
        color (bool): Whether convert intensity image to color image.
        tanh (bool): If True, will normalize the image in range [-1, 1]
            to [0, 1] (for GAN models).

    Example:
        The batch_size is 64, then the size is recommended [8, 8].
        The batch_size is 32, then the size is recommended [6, 6].
    """

    # normalization of tanh output
    img = images

    if tanh:
        # Map tanh output range [-1, 1] into [0, 1].
        img = (img + 1.0) / 2.0

    if color:
        # TODO
        # Colorize each (squeezed) intensity image independently.
        img_list = []
        for im in np.squeeze(img):
            im = intensity_to_rgb(np.squeeze(im), normalize=True)
            img_list.append(im)
        img = np.array(img_list)
        # img = np.expand_dims(img, 0)

    # A single image (2-D, or 3-D with <= 4 channels) gets a batch axis.
    if len(img.shape) == 2 or (len(img.shape) == 3 and img.shape[2] <= 4):
        img = np.expand_dims(img, 0)
    # img = images
    h, w = img.shape[1], img.shape[2]
    merge_img = np.zeros((h * merge_grid[0], w * merge_grid[1], 3))
    if len(img.shape) < 4:
        # Grayscale batch: add a trailing channel axis so the assignment
        # below broadcasts over the 3 output channels.
        img = np.expand_dims(img, -1)

    # Tile images row-major into the merge grid.
    for idx, image in enumerate(img):
        i = idx % merge_grid[1]
        j = idx // merge_grid[1]
        merge_img[j*h:j*h+h, i*w:i*w+w, :] = image

    # NOTE(review): scipy.misc.imsave was removed in SciPy >= 1.2 --
    # confirm the pinned SciPy version or migrate to imageio.imwrite.
    scipy.misc.imsave(save_path, merge_img)
94 |
95 |
def image_overlay(im_1, im_2, color=True, normalize=True):
    """Overlay two images with the same size as an equal-weight blend.

    Args:
        im_1 (np.ndarray): image arrary
        im_2 (np.ndarray): image arrary
        color (bool): Whether convert intensity image to color image.
        normalize (bool): If both color and normalize are True, will
            normalize the intensity so that it has minimum 0 and maximum 1.

    Returns:
        np.ndarray: an overlay image of im_1*0.5 + im_2*0.5
    """
    if not color:
        return im_1*0.5 + im_2*0.5
    rgb_1 = intensity_to_rgb(np.squeeze(im_1), normalize=normalize)
    rgb_2 = intensity_to_rgb(np.squeeze(im_2), normalize=normalize)
    return rgb_1*0.5 + rgb_2*0.5
114 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import tensorcv.utils.debug as debug


# Smoke test: confirm tensorcv imports correctly and print the TF version.
debug.hello_cv()
print(tf.__version__)
7 |
8 |
--------------------------------------------------------------------------------
/test/VGG_pre_trained.py:
--------------------------------------------------------------------------------
1 | # File: VGG.py
2 | # Author: Qian Ge
3 |
4 | import argparse
5 | import os
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 | import scipy
10 | import cv2
11 |
12 | from tensorcv.dataflow.image import *
13 | from tensorcv.callbacks import *
14 | from tensorcv.predicts import *
15 | from tensorcv.train.config import TrainConfig
16 | from tensorcv.train.simple import SimpleFeedTrainer
17 |
18 | import VGG
19 | import config
20 |
21 | VGG_MEAN = [103.939, 116.779, 123.68]
22 |
23 | # def load_VGG_model(session, model_path, skip_layer = []):
24 | # weights_dict = np.load(model_path, encoding='latin1').item()
25 | # for layer_name in weights_dict:
26 | # print(layer_name)
27 | # if layer_name not in skip_layer:
28 | # with tf.variable_scope(layer_name, reuse = True):
29 | # for data in weights_dict[layer_name]:
30 | # if len(data.shape) == 1:
31 | # var = tf.get_variable('biases', trainable = False)
32 | # session.run(var.assign(data))
33 | # else:
34 | # var = tf.get_variable('weights', trainable = False)
35 | # session.run(var.assign(data))
36 |
def get_args():
    """Parse command-line options for the VGG train/predict script."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--batch_size', default = 128)

    parser.add_argument('--predict', action='store_true',
                        help = 'Run prediction')
    parser.add_argument('--train', action='store_true',
                        help = 'Train the model')

    return parser.parse_args()
50 |
if __name__ == '__main__':
    # Build a VGG19 classifier sized for full ImageNet (1000 classes).
    Model = VGG.VGG19(num_class = 1000,
                      num_channels = 3,
                      im_height = 224,
                      im_width = 224)

    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    image = tf.placeholder(tf.float32, name = 'image',
                           shape = [None, 64, 64, 3])
    # Inputs are 64x64 (tiny-imagenet); upsample to VGG's 224x224.
    input_im = tf.image.resize_images(image, [224, 224])

    Model.create_model([input_im, keep_prob])
    predict_op = tf.argmax(Model.output, dimension = -1)

    dataset_val = ImageLabelFromFile('.JPEG', data_dir = config.valid_data_dir,
                                     label_file_name = 'val_annotations.txt',
                                     num_channel = 3,
                                     label_dict = {},
                                     shuffle = False)

    # dataset_val = ImageData('.JPEG', data_dir = config.valid_data_dir,
    #                          shuffle = False)

    dataset_val.setup(epoch_val = 0, batch_size = 32)
    # o_label_dict = dataset_val.label_dict_reverse

    # Map WordNet ids (e.g. 'n01443537') to human-readable words.
    word_dict = {}
    word_file = open(os.path.join('D:\\Qian\\GitHub\\workspace\\dataset\\tiny-imagenet-200\\tiny-imagenet-200\\',
                     'words.txt'), 'r')
    lines = word_file.read().split('\n')
    for line in lines:
        label, word = line.split('\t')
        word_dict[label] = word

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # NOTE(review): load_VGG_model is only defined in the commented-out
        # block above -- this call raises NameError as written. Confirm
        # whether the helper should be restored or imported from VGG.
        load_VGG_model(sess, 'D:\\Qian\\GitHub\\workspace\\VGG\\vgg19.npy')
        batch_data = dataset_val.next_batch()

        result = sess.run(predict_op, feed_dict = {keep_prob: 1, image: batch_data[0]})
        print(result)
        # NOTE(review): o_label_dict is assigned only on the commented-out
        # line above -- this raises NameError as written; verify intent.
        print([word_dict[o_label_dict[label]] for label in batch_data[1]])
95 |
96 |
97 |
98 |
--------------------------------------------------------------------------------
/test/config.py:
--------------------------------------------------------------------------------
1 | # File: config.py
2 | # Author: Qian Ge
3 |
4 | # directory of training data
5 | # data_dir = 'D:\\GoogleDrive_Qian\\Foram\\Training\\CNN_GAN_ORIGINAL_64\\'
6 | # data_dir = 'D:\\Qian\\GitHub\\workspace\\tensorflow-DCGAN\\cifar-10-python.tar\\')
7 | train_data_dir = 'D:\\Qian\\GitHub\\workspace\\dataset\\tiny-imagenet-200\\tiny-imagenet-200\\train\\'
8 |
9 | # directory of validation data
10 | valid_data_dir = 'D:\\Qian\\GitHub\\workspace\\dataset\\tiny-imagenet-200\\tiny-imagenet-200\\val\\'
11 |
12 | # directory for saving inference data
13 | infer_dir = 'D:\\Qian\\GitHub\\workspace\\test\\result\\'
14 |
15 | # directory for saving summary
16 | summary_dir = 'D:\\Qian\\GitHub\\workspace\\test\\'
17 |
18 | # directory for saving checkpoint
19 | checkpoint_dir = 'D:\\Qian\\GitHub\\workspace\\test\\'
20 |
21 | # directory for restoring checkpoint
22 | model_dir = 'D:\\Qian\\GitHub\\workspace\\test\\'
23 |
24 | # directory for saving prediction results
25 | result_dir = 'D:\\Qian\\GitHub\\workspace\\test\\2\\'
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/test/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 |
6 | from tensorcv.dataflow.matlab import MatlabData
7 | from tensorcv.models.layers import *
8 | from tensorcv.models.base import BaseModel
9 | from tensorcv.utils.common import apply_mask, get_tensors_by_names
10 | from tensorcv.train.config import TrainConfig
11 | from tensorcv.predicts.config import PridectConfig
12 | from tensorcv.train.simple import SimpleFeedTrainer
13 | from tensorcv.callbacks.saver import ModelSaver
14 | from tensorcv.callbacks.summary import TrainSummary
15 | from tensorcv.callbacks.inference import FeedInference
16 | from tensorcv.callbacks.monitors import TFSummaryWriter
17 | from tensorcv.callbacks.inferencer import InferScalars
18 | from tensorcv.predicts.simple import SimpleFeedPredictor
19 | from tensorcv.predicts.predictions import PredictionImage
20 | from tensorcv.callbacks.debug import CheckScalar
21 |
class Model(BaseModel):
    """Fully convolutional network for masked 2-class segmentation.

    Encoder: four conv+max-pool stages; decoder: three transposed
    convolutions with skip additions back to the encoder feature maps.
    Ground truth and prediction are compared only where ``mask`` is set.
    """

    def __init__(self, num_channels=3, num_class=2,
                 learning_rate=0.0001):
        # num_channels: channels of the input image placeholder
        # num_class:    number of output classes (channels of the logits)
        self.learning_rate = learning_rate
        self.num_channels = num_channels
        self.num_class = num_class
        self.set_is_training(True)

    def _get_placeholder(self):
        # Training placeholders: image, label, mask.
        return [self.image, self.gt, self.mask]

    def _get_prediction_placeholder(self):
        # Only the image is needed at inference time.
        return self.image

    def _get_graph_feed(self):
        # Dropout keep probability: 0.5 while training, 1.0 at inference.
        if self.is_training:
            feed = {self.keep_prob: 0.5}
        else:
            feed = {self.keep_prob: 1}
        return feed

    def _create_graph(self):
        """Build the encoder/decoder graph and the prediction ops."""
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        # NHWC input with a variable spatial size.
        self.image = tf.placeholder(tf.float32, name='image',
                        shape=[None, None, None, self.num_channels])
        self.gt = tf.placeholder(tf.int64, [None, None, None], 'gt')
        self.mask = tf.placeholder(tf.int32, [None, None, None], 'mask')

        with tf.variable_scope('conv1') as scope:
            conv1 = conv(self.image, 5, 32, nl=tf.nn.relu)
            pool1 = max_pool(conv1, padding='SAME')

        with tf.variable_scope('conv2') as scope:
            conv2 = conv(pool1, 3, 48, nl=tf.nn.relu)
            pool2 = max_pool(conv2, padding='SAME')

        with tf.variable_scope('conv3') as scope:
            conv3 = conv(pool2, 3, 64, nl=tf.nn.relu)
            pool3 = max_pool(conv3, padding='SAME')

        with tf.variable_scope('conv4') as scope:
            conv4 = conv(pool3, 3, 128, nl=tf.nn.relu)
            pool4 = max_pool(conv4, padding='SAME')

        # 1x1-style "fully connected" layers implemented as convolutions.
        with tf.variable_scope('fc1') as scope:
            fc1 = conv(pool4, 2, 128, nl=tf.nn.relu)
            dropout_fc1 = dropout(fc1, self.keep_prob, self.is_training)

        with tf.variable_scope('fc2') as scope:
            fc2 = conv(dropout_fc1, 1, self.num_class)

        # Decoder with additive skip connections to pool3/pool2, upsampled
        # back to the input resolution by the last transposed convolution.
        dconv1 = tf.add(dconv(fc2, 4, name='dconv1',
                              out_shape_by_tensor=pool3), pool3)
        dconv2 = tf.add(dconv(dconv1, 4, name='dconv2',
                              out_shape_by_tensor=pool2), pool2)
        dconv3 = dconv(dconv2, 16, self.num_class,
                       out_shape_by_tensor=self.image,
                       name='dconv3', stride=4)
        # BUG FIX: keep the raw (pre-softmax) logits so the loss can use
        # them; previously only the softmax output was stored and fed to
        # sparse_softmax_cross_entropy_with_logits, which is incorrect.
        self.dconv3 = dconv3

        with tf.name_scope('prediction'):
            self.prediction = tf.argmax(dconv3, name='label', dimension=-1)
            self.softmax_dconv3 = tf.nn.softmax(dconv3)
            prediction_pro = tf.identity(self.softmax_dconv3[:, :, :, 1],
                                         name='probability')

    def _setup_graph(self):
        """Masked pixel accuracy used by the inference callback."""
        with tf.name_scope('accuracy'):
            correct_prediction = apply_mask(
                tf.equal(self.prediction, self.gt),
                self.mask)
            self.accuracy = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32),
                name='result')

    def _get_loss(self):
        """Masked sparse softmax cross-entropy over the logits."""
        with tf.name_scope('loss'):
            # sparse_softmax_cross_entropy_with_logits expects UNSCALED
            # logits — it applies softmax internally. Feeding it the
            # softmax output (as the old code did) produces incorrect
            # losses and gradients, hence self.dconv3 here.
            return tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=apply_mask(self.dconv3, self.mask),
                    labels=apply_mask(self.gt, self.mask)),
                name='result')

    def _get_optimizer(self):
        return tf.train.AdamOptimizer(learning_rate=self.learning_rate)

    def _setup_summary(self):
        """Image/scalar summaries for the 'train' and 'test' collections."""
        with tf.name_scope('train_summary'):
            tf.summary.image("train_Predict",
                tf.expand_dims(tf.cast(self.prediction, tf.float32), -1),
                collections=['train'])
            tf.summary.image("im", tf.cast(self.image, tf.float32),
                collections=['train'])
            tf.summary.image("gt",
                tf.expand_dims(tf.cast(self.gt, tf.float32), -1),
                collections=['train'])
            tf.summary.image("mask",
                tf.expand_dims(tf.cast(self.mask, tf.float32), -1),
                collections=['train'])
            tf.summary.scalar('train_accuracy', self.accuracy,
                collections=['train'])
        with tf.name_scope('test_summary'):
            tf.summary.image("test_Predict",
                tf.expand_dims(tf.cast(self.prediction, tf.float32), -1),
                collections=['test'])
128 |
def get_config(FLAGS):
    """Build the training configuration.

    Args:
        FLAGS: parsed command-line arguments (see get_args()).

    Returns:
        TrainConfig wiring the Matlab dataflows, the model and the
        training callbacks (checkpoint saver, summaries, periodic
        validation inference).
    """
    mat_name_list = ['level1Edge', 'GT', 'Mask']
    dataset_train = MatlabData('train', mat_name_list=mat_name_list,
                               data_dir=FLAGS.data_dir)
    dataset_val = MatlabData('val', mat_name_list=mat_name_list,
                             data_dir=FLAGS.data_dir)
    inference_list = [InferScalars('accuracy/result', 'test_accuracy')]

    return TrainConfig(
        dataflow=dataset_train,
        model=Model(num_channels=FLAGS.input_channel,
                    num_class=FLAGS.num_class,
                    learning_rate=0.0001),
        monitors=TFSummaryWriter(summary_dir=FLAGS.summary_dir),
        callbacks=[
            # BUG FIX: checkpoints were saved to FLAGS.summary_dir even
            # though a dedicated --checkpoint_dir flag exists (and was
            # otherwise unused); honor it here.
            ModelSaver(periodic=10,
                       checkpoint_dir=FLAGS.checkpoint_dir),
            TrainSummary(key='train', periodic=10),
            FeedInference(dataset_val, periodic=10,
                          extra_cbs=TrainSummary(key='test'),
                          inferencers=inference_list),
        ],
        batch_size=FLAGS.batch_size,
        max_epoch=200,
        summary_periodic=10)
155 |
def get_predictConfig(FLAGS):
    """Build the prediction configuration from parsed FLAGS."""
    # Test samples come from a single matlab variable, read in order.
    matlab_vars = ['level1Edge']
    test_data = MatlabData('Level_1', shuffle=False,
                           mat_name_list=matlab_vars,
                           data_dir=FLAGS.test_data_dir)
    # Save the predicted label map and its probability map as images.
    im_prediction = PredictionImage(
        ['prediction/label', 'prediction/probability'],
        ['test', 'test_pro'],
        merge_im=True)

    return PridectConfig(
        dataflow=test_data,
        model=Model(FLAGS.input_channel,
                    num_class=FLAGS.num_class),
        model_name='model-14070',
        model_dir=FLAGS.model_dir,
        result_dir=FLAGS.result_dir,
        predictions=im_prediction,
        batch_size=FLAGS.batch_size)
174 |
def get_args():
    """Parse command-line arguments for training/prediction runs.

    Returns:
        argparse.Namespace holding directory paths, model hyper-parameters
        (as ints) and the mutually independent --train / --predict flags.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir',
                        help='Directory of input training data.',
                        default='D:\\GoogleDrive_Qian\\Foram\\Training\\CNN_Image\\')
    parser.add_argument('--summary_dir',
                        help='Directory for saving summary.',
                        default='D:\\Qian\\GitHub\\workspace\\test\\')
    parser.add_argument('--checkpoint_dir',
                        help='Directory for saving checkpoint.',
                        default='D:\\Qian\\GitHub\\workspace\\test\\')

    parser.add_argument('--test_data_dir',
                        help='Directory of input test data.',
                        default='D:\\GoogleDrive_Qian\\Foram\\testing\\')
    parser.add_argument('--model_dir',
                        help='Directory for restoring checkpoint.',
                        default='D:\\Qian\\GitHub\\workspace\\test\\')
    parser.add_argument('--result_dir',
                        help='Directory for saving prediction results.',
                        default='D:\\Qian\\GitHub\\workspace\\test\\2\\')

    # BUG FIX: type=int so values supplied on the command line arrive as
    # ints; without it argparse passes them through as strings (only the
    # defaults were ints), breaking any arithmetic downstream.
    parser.add_argument('--input_channel', default=1, type=int,
                        help='Number of image channels')
    parser.add_argument('--num_class', default=2, type=int,
                        help='Number of classes')
    parser.add_argument('--batch_size', default=1, type=int)

    parser.add_argument('--predict', help='Run prediction', action='store_true')
    parser.add_argument('--train', help='Train the model', action='store_true')

    return parser.parse_args()
207 |
if __name__ == '__main__':
    # Entry point: dispatch to training or prediction based on the flags.
    FLAGS = get_args()
    if FLAGS.train:
        SimpleFeedTrainer(get_config(FLAGS)).train()
    elif FLAGS.predict:
        SimpleFeedPredictor(get_predictConfig(FLAGS)).run_predict()
217 |
218 |
219 |
--------------------------------------------------------------------------------
/test/todo.md:
--------------------------------------------------------------------------------
1 | # todo
2 |
3 | ## utils.common
4 | - Support returning multiple tensors that share the same name in `get_tensors_by_names`
5 |
--------------------------------------------------------------------------------