');
38 | ul.append(entry);
39 | obj[groupName] = ul;
40 | }
41 | });
42 |
43 | return obj;
44 | }
45 |
46 |
47 |
48 | var customIndex = $('.custom-index');
49 | customIndex.empty();
50 |
51 |
52 | var selected = $('div.section>dl>dt');
53 | if (selected.length === 0)
54 | return;
55 |
56 | var obj = createList(selected);
57 | var block = $('<div/>');
58 | for(var key in obj) {
59 | var a = $('<a/>');
60 | a.html(key + ':');
61 | block.append(a);
62 | block.append(obj[key]);
63 | }
64 | customIndex.append(block);
65 | });
66 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/_static/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/ssl_detection/00d52272f61b56eade8d5ace18213cba6c74f6d8/third_party/tensorpack/docs/_static/favicon.ico
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/_static/sanitize_desc_name.js:
--------------------------------------------------------------------------------
1 | //File: sanitize_desc_name.js
2 |
3 | $(function (){
4 | var selected = $('div.section>dl>dt>code.descclassname');
5 | selected.each(function(_, e) {
6 | var text = e.innerText;
7 | if (text.startsWith('tensorpack.')) {
8 | text = text.substr(11);
9 | e.innerText = text;
10 | }
11 | });
12 | });
13 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 |
3 | {%- block extrahead %}
4 |
7 |
10 |
11 |
20 | {% endblock %}
21 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/index.rst:
--------------------------------------------------------------------------------
1 | Tensorpack Documentation
2 | ==============================
3 |
4 | .. image:: ../.github/tensorpack.png
5 |
6 | Tensorpack is a **training interface** based on TensorFlow, with a focus on speed + flexibility.
7 | TensorFlow is powerful, but has its own drawbacks:
8 | its low-level APIs are too complicated for many users,
9 | and its existing high-level APIs sacrifice a lot in either speed or flexibility.
10 | The Tensorpack API brings speed and flexibility together.
11 |
12 | Tensorpack is Yet Another TF high-level API, but it differs in the following ways:
13 |
14 | - Focus on **training speed**.
15 |
16 | - Speed comes for free with tensorpack -- it uses TensorFlow in an
17 | **efficient way** with no extra overhead. On common CNNs, it runs
18 | `1.2~5x faster <https://github.com/tensorpack/benchmarks/tree/master/other-wrappers>`_
19 | than the equivalent Keras code.
20 |
21 | - Data-parallel multi-GPU/distributed training strategies are available off-the-shelf.
22 | It scales as well as Google's
23 | `official benchmark <https://www.tensorflow.org/performance/benchmarks>`_.
24 | You cannot beat its speed unless you're a TensorFlow expert.
25 |
26 | - See `tensorpack/benchmarks <https://github.com/tensorpack/benchmarks>`_ for some benchmark scripts.
27 |
28 | - Focus on **large datasets**.
29 |
30 | - You don't usually need `tf.data`. Symbolic programming often makes data processing harder.
31 | Tensorpack helps you efficiently process large datasets (e.g. ImageNet) in **pure Python** with autoparallelization.
32 |
33 | - It's not a model wrapper.
34 |
35 | - There are already too many symbolic function wrappers in the world.
36 | Tensorpack includes only a few common models, but you can use any symbolic function library inside tensorpack, including tf.layers/Keras/slim/tflearn/tensorlayer/...
37 |
38 | See :doc:`tutorial/index` to learn more about these features.
39 |
40 |
41 | .. toctree::
42 | :maxdepth: 3
43 |
44 | tutorial/index
45 | modules/index
46 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/callbacks.rst:
--------------------------------------------------------------------------------
1 | tensorpack.callbacks package
2 | ============================
3 |
4 | **Everything** other than the training iterations happens in the callbacks.
5 | Most of the fancy things you want to do will probably end up here.
6 | See relevant tutorials: :doc:`../tutorial/callback`.
7 |
8 | .. container:: custom-index
9 |
10 | .. raw:: html
11 |
12 | <script type="text/javascript" src='../_static/build_toc_group.js'></script>
13 |
14 |
15 | .. automodule:: tensorpack.callbacks
16 | :members:
17 | :no-undoc-members:
18 | :show-inheritance:
19 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/contrib.rst:
--------------------------------------------------------------------------------
1 |
2 | tensorpack.contrib package
3 | ==========================
4 |
5 | .. automodule:: tensorpack.contrib.keras
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
9 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/dataflow.dataset.rst:
--------------------------------------------------------------------------------
1 | tensorpack.dataflow.dataset package
2 | ===================================
3 |
4 | .. container:: custom-index
5 |
6 | .. raw:: html
7 |
8 | <script type="text/javascript" src='../_static/build_toc_group.js'></script>
9 |
10 |
11 | .. automodule:: tensorpack.dataflow.dataset
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
15 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/dataflow.imgaug.rst:
--------------------------------------------------------------------------------
1 | tensorpack.dataflow.imgaug package
2 | ==================================
3 |
4 | This package contains Tensorpack's augmentors.
5 | Read the `tutorial <../tutorial/extend/augmentor.html>`_
6 | first for its design and general usage.
7 |
8 | Note that other image augmentation libraries can be wrapped into Tensorpack's interface as well.
9 | For example, `imgaug.IAAugmentor <#tensorpack.dataflow.imgaug.IAAugmentor>`_
10 | and `imgaug.Albumentations <#tensorpack.dataflow.imgaug.Albumentations>`_
11 | wrap two popular image augmentation libraries.
12 |
13 |
14 | .. container:: custom-index
15 |
16 | .. raw:: html
17 |
18 | <script type="text/javascript" src='../_static/build_toc_group.js'></script>
19 |
20 | .. automodule:: tensorpack.dataflow.imgaug
21 | :members:
22 | :undoc-members:
23 | :show-inheritance:
24 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/dataflow.rst:
--------------------------------------------------------------------------------
1 | tensorpack.dataflow package
2 | ===========================
3 |
4 | Relevant tutorials: :doc:`../tutorial/dataflow`, :doc:`../tutorial/philosophy/dataflow`.
5 |
6 | .. container:: custom-index
7 |
8 | .. raw:: html
9 |
10 | <script type="text/javascript" src='../_static/build_toc_group.js'></script>
11 |
12 |
13 | .. automodule:: tensorpack.dataflow
14 | :members:
15 | :undoc-members:
16 | :show-inheritance:
17 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/graph_builder.rst:
--------------------------------------------------------------------------------
1 | tensorpack.graph_builder package
2 | ================================
3 |
4 | These are some useful functions if you need to write your own trainers.
5 | Otherwise you probably don't need to use them.
6 |
7 | .. automodule:: tensorpack.graph_builder
8 | :members:
9 | :undoc-members:
10 | :show-inheritance:
11 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/index.rst:
--------------------------------------------------------------------------------
1 | API Documentation
2 | --------------------
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 |
8 | dataflow
9 | dataflow.dataset
10 | dataflow.imgaug
11 | input_source
12 | models
13 | callbacks
14 | graph_builder
15 | train
16 | predict
17 | tfutils
18 | utils
19 | contrib
20 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/input_source.rst:
--------------------------------------------------------------------------------
1 | tensorpack.input_source package
2 | ================================
3 |
4 | Read the relevant tutorials first for an overview of InputSource: :doc:`../tutorial/extend/input-source`.
5 |
6 | .. automodule:: tensorpack.input_source
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/models.rst:
--------------------------------------------------------------------------------
1 | tensorpack.models package
2 | =========================
3 |
4 | Relevant tutorials: :doc:`../tutorial/symbolic`.
5 |
6 | .. container:: custom-index
7 |
8 | .. raw:: html
9 |
10 | <script type="text/javascript" src='../_static/build_toc_group.js'></script>
11 |
12 |
13 | .. automodule:: tensorpack.models
14 | :members:
15 | :undoc-members:
16 | :show-inheritance:
17 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/predict.rst:
--------------------------------------------------------------------------------
1 | tensorpack.predict package
2 | ==========================
3 |
4 | .. automodule:: tensorpack.predict
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/tfutils.rst:
--------------------------------------------------------------------------------
1 | tensorpack.tfutils package
2 | ==========================
3 |
4 | .. container:: custom-index
5 |
6 | .. raw:: html
7 |
8 | <script type="text/javascript" src='../_static/build_toc_group.js'></script>
9 |
10 | tensorpack.tfutils.argscope module
11 | ------------------------------------
12 |
13 | .. automodule:: tensorpack.tfutils.argscope
14 | :members:
15 | :undoc-members:
16 | :show-inheritance:
17 |
18 | tensorpack.tfutils.collection module
19 | ------------------------------------
20 |
21 | .. automodule:: tensorpack.tfutils.collection
22 | :members:
23 | :undoc-members:
24 | :show-inheritance:
25 |
26 | tensorpack.tfutils.gradproc module
27 | ------------------------------------
28 |
29 | .. automodule:: tensorpack.tfutils.gradproc
30 | :members:
31 | :undoc-members:
32 | :show-inheritance:
33 |
34 | tensorpack.tfutils.tower module
35 | ------------------------------------
36 |
37 | .. automodule:: tensorpack.tfutils.tower
38 | :members:
39 | :undoc-members:
40 | :show-inheritance:
41 |
42 | tensorpack.tfutils.scope_utils module
43 | --------------------------------------
44 |
45 | .. automodule:: tensorpack.tfutils.scope_utils
46 | :members:
47 | :undoc-members:
48 | :show-inheritance:
49 |
50 | tensorpack.tfutils.optimizer module
51 | ------------------------------------
52 |
53 | .. automodule:: tensorpack.tfutils.optimizer
54 | :members:
55 | :undoc-members:
56 | :show-inheritance:
57 |
58 | tensorpack.tfutils.sesscreate module
59 | ------------------------------------
60 |
61 | .. automodule:: tensorpack.tfutils.sesscreate
62 | :members:
63 | :undoc-members:
64 | :show-inheritance:
65 |
66 | tensorpack.tfutils.sessinit module
67 | ------------------------------------
68 |
69 | .. automodule:: tensorpack.tfutils.sessinit
70 | :members:
71 | :undoc-members:
72 | :show-inheritance:
73 |
74 | tensorpack.tfutils.summary module
75 | ---------------------------------
76 |
77 | .. automodule:: tensorpack.tfutils.summary
78 | :members:
79 | :undoc-members:
80 | :show-inheritance:
81 |
82 | tensorpack.tfutils.varmanip module
83 | ----------------------------------
84 |
85 | .. automodule:: tensorpack.tfutils.varmanip
86 | :members:
87 | :undoc-members:
88 | :show-inheritance:
89 |
90 | tensorpack.tfutils.varreplace module
91 | ------------------------------------
92 |
93 | .. automodule:: tensorpack.tfutils.varreplace
94 | :members:
95 | :undoc-members:
96 | :show-inheritance:
97 |
98 | tensorpack.tfutils.export module
99 | ------------------------------------
100 |
101 | .. automodule:: tensorpack.tfutils.export
102 | :members:
103 | :undoc-members:
104 | :show-inheritance:
105 |
106 | tensorpack.tfutils.dependency module
107 | ------------------------------------
108 |
109 | .. automodule:: tensorpack.tfutils.dependency
110 | :members:
111 | :undoc-members:
112 | :show-inheritance:
113 |
114 | Other functions in tensorpack.tfutils module
115 | ---------------------------------------------
116 |
117 | .. automethod:: tensorpack.tfutils.get_default_sess_config
118 | .. automethod:: tensorpack.tfutils.get_global_step_var
119 | .. automethod:: tensorpack.tfutils.get_global_step_value
120 | .. automethod:: tensorpack.tfutils.get_tf_version_tuple
121 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/train.rst:
--------------------------------------------------------------------------------
1 | tensorpack.train package
2 | ========================
3 |
4 | Relevant tutorials: :doc:`../tutorial/trainer`, :doc:`../tutorial/training-interface`
5 |
6 | .. container:: custom-index
7 |
8 | .. raw:: html
9 |
10 | <script type="text/javascript" src='../_static/build_toc_group.js'></script>
11 |
12 | .. automodule:: tensorpack.train
13 | :members:
14 | :undoc-members:
15 | :show-inheritance:
16 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/modules/utils.rst:
--------------------------------------------------------------------------------
1 | tensorpack.utils package
2 | ========================
3 |
4 | .. automodule:: tensorpack.utils
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
9 | tensorpack.utils.argtools module
10 | --------------------------------
11 |
12 | .. automodule:: tensorpack.utils.argtools
13 | :members:
14 | :undoc-members:
15 | :show-inheritance:
16 |
17 | tensorpack.utils.concurrency module
18 | -----------------------------------
19 |
20 | .. automodule:: tensorpack.utils.concurrency
21 | :members:
22 | :undoc-members:
23 | :show-inheritance:
24 |
25 |
26 | tensorpack.utils.fs module
27 | --------------------------
28 |
29 | .. automodule:: tensorpack.utils.fs
30 | :members:
31 | :undoc-members:
32 | :show-inheritance:
33 |
34 | tensorpack.utils.loadcaffe module
35 | ---------------------------------
36 |
37 | .. automodule:: tensorpack.utils.loadcaffe
38 | :members:
39 | :undoc-members:
40 | :show-inheritance:
41 |
42 | tensorpack.utils.logger module
43 | ------------------------------
44 |
45 | .. automodule:: tensorpack.utils.logger
46 | :members:
47 | :undoc-members:
48 | :show-inheritance:
49 |
50 |
51 | tensorpack.utils.serialize module
52 | ---------------------------------
53 |
54 | .. automodule:: tensorpack.utils.serialize
55 | :members:
56 | :undoc-members:
57 | :show-inheritance:
58 |
59 | tensorpack.utils.stats module
60 | -----------------------------
61 |
62 | .. automodule:: tensorpack.utils.stats
63 | :members:
64 | :undoc-members:
65 | :show-inheritance:
66 |
67 | tensorpack.utils.timer module
68 | -----------------------------
69 |
70 | .. automodule:: tensorpack.utils.timer
71 | :members:
72 | :undoc-members:
73 | :show-inheritance:
74 |
75 | tensorpack.utils.viz module
76 | ---------------------------
77 |
78 | .. automodule:: tensorpack.utils.viz
79 | :members:
80 | :undoc-members:
81 | :show-inheritance:
82 |
83 | tensorpack.utils.gpu module
84 | ---------------------------
85 |
86 | .. automodule:: tensorpack.utils.gpu
87 | :members:
88 | :undoc-members:
89 | :show-inheritance:
90 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | termcolor
2 | numpy
3 | tqdm
4 | docutils>=0.14
5 | Sphinx>=1.6
6 | recommonmark==0.4.0
7 | sphinx_rtd_theme
8 | mock
9 | matplotlib
10 | tensorflow==1.4.0
11 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/tensorpack.xml:
--------------------------------------------------------------------------------
1 | <entry>
2 | <version>0.9.0.1</version>
3 | <url>https://github.com/tensorpack/tensorpack/releases/download/doc-v0.9.0.1/tensorpack.docset.tgz</url>
4 | </entry>
5 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/tutorial/dataflow.md:
--------------------------------------------------------------------------------
1 |
2 | # DataFlow
3 |
4 | DataFlow is a pure-Python library to create iterators for efficient data loading.
5 | It was originally part of tensorpack, and is now also available as a [separate library](https://github.com/tensorpack/dataflow).
6 |
7 | ### What is DataFlow
8 |
9 | **Definition**: A DataFlow instance is an idiomatic Python iterable object that has a `__iter__()` method
10 | which yields `datapoints`, and optionally a `__len__()` method returning the size of the DataFlow.
11 | A datapoint is a **list or dict** of Python objects, each of which is called a `component` of the datapoint.
12 |
13 | **Example**: to train on the MNIST dataset, you may need a DataFlow with a `__iter__()` method
14 | that yields datapoints (lists) of two components:
15 | a numpy array of shape (64, 28, 28), and an array of shape (64,).
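
As a toy sketch of this shape contract (random arrays standing in for real MNIST data; `FakeMNISTBatches` is hypothetical):

```python
import numpy as np
from tensorpack.dataflow import DataFlow

class FakeMNISTBatches(DataFlow):
    """Toy DataFlow yielding datapoints of [images, labels]."""
    def __iter__(self):
        for _ in range(100):
            images = np.random.rand(64, 28, 28)          # component 0
            labels = np.random.randint(10, size=(64,))   # component 1
            yield [images, labels]
```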
16 |
17 | DataFlow is independent of the training framework since it produces arbitrary Python objects
18 | (usually numpy arrays).
19 | You can simply use DataFlow as a data processing pipeline and plug it into your own training code.
20 |
21 | ### Load Raw Data
22 | We do not make any assumptions about your data format.
23 | You would usually want to write the source DataFlow (`MyDataFlow` in the example below) for your own data format.
24 | See [another tutorial](extend/dataflow.html) for simple instructions on writing a DataFlow.
25 |
26 | ### Assemble the Pipeline
27 | There are a lot of existing DataFlow utilities in tensorpack, which you can use to assemble
28 | the source DataFlow into a complex data pipeline.
29 | A common pipeline usually would
30 | __read from disk (or other sources),
31 | apply transformations,
32 | group into batches, prefetch data__, etc., all __running in parallel__.
33 | A simple DataFlow pipeline is like the following:
34 |
35 | ```python
36 | # a DataFlow you implement to produce [tensor1, tensor2, ..] lists from whatever sources:
37 | df = MyDataFlow(dir='/my/data', shuffle=True)
38 | # apply transformation to your data
39 | df = MapDataComponent(df, lambda t: transform(t), 0)
40 | # group data into batches of size 128
41 | df = BatchData(df, 128)
42 | # start 3 processes to run the dataflow in parallel
43 | df = MultiProcessRunnerZMQ(df, 3)
44 | ```
45 |
46 | A list of built-in DataFlows can be found in the [API docs](../modules/dataflow.html).
47 | You can also find complicated real-life DataFlow pipelines in the [ImageNet training script](../examples/ImageNetModels/imagenet_utils.py)
48 | or other tensorpack examples.
49 |
50 | ### Parallelize the Pipeline
51 |
52 | DataFlow includes **carefully optimized** parallel runners and parallel mappers: `Multi{Thread,Process}{Runner,MapData}`.
53 | Runners execute multiple clones of a dataflow in parallel.
54 | Mappers execute a mapping function in parallel on top of an existing dataflow.
55 | You can find details in the [API docs](../modules/dataflow.html) under the
56 | "parallel" and "parallel_map" section.
57 |
58 | [Parallel DataFlow tutorial](parallel-dataflow.html) gives a deeper dive
59 | on how to use them to optimize your data pipeline.
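
As a minimal sketch of the difference (`heavy_map` stands for a hypothetical, expensive per-datapoint function; use one of the two, not both):

```python
from tensorpack.dataflow import MultiProcessRunnerZMQ, MultiProcessMapData

# Runner: launch 4 identical copies of the whole dataflow `df` and merge their outputs.
df = MultiProcessRunnerZMQ(df, 4)

# Mapper: keep a single `df`, but evaluate heavy_map on its datapoints in 4 worker processes.
df = MultiProcessMapData(df, 4, heavy_map)
```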
60 |
61 | ### Run the DataFlow
62 |
63 | When training with tensorpack, typically it is the `InputSource` interface that runs the DataFlow.
64 |
65 | When using DataFlow alone without tensorpack,
66 | you need to call `reset_state()` first to initialize it,
67 | and then use the generator however you like:
68 |
69 | ```python
70 | df = SomeDataFlow()
71 |
72 | df.reset_state()
73 | for dp in df:
74 | # dp is now a list/dict. do whatever with it
75 | ```
76 |
77 | ### Why DataFlow?
78 |
79 | It's **easy and fast**.
80 | For more discussion, see [Why DataFlow?](/tutorial/philosophy/dataflow.html).
81 | Nevertheless, using DataFlow is not required in tensorpack.
82 | Tensorpack supports data loading with native TF operators / TF datasets as well.
83 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/tutorial/extend/augmentor.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ### Write an Image Augmentor
4 |
5 | The first thing to note: __you never have to write an augmentor__.
6 | An augmentor is a part of the DataFlow, so you can always
7 | [write a DataFlow](dataflow.html)
8 | to perform whatever operations you need on your data, rather than writing an augmentor.
9 |
10 | An augmentor makes things easier when what you want fits its design.
11 | But remember it is just an abstraction that may not always work for your use case.
12 | For example, if your data transformation depends on multiple dataflow components,
13 | or if you want to apply different transformations to different components,
14 | the abstraction is often not enough for you, and you need to write code on the
15 | DataFlow level instead.
16 |
17 | An image augmentor maps an image to an image.
18 | If you have such a mapping function `f` already, you can simply use
19 | [imgaug.MapImage(f)](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.MapImage)
20 | as the augmentor, or use
21 | [MapDataComponent(dataflow, f, index)](../../modules/dataflow.html#tensorpack.dataflow.MapDataComponent)
22 | as the DataFlow.
23 | In other words, for simple mapping you do not need to write an augmentor.
24 |
25 | An augmentor may do something more than just applying a mapping.
26 | To do custom augmentation, you can implement one yourself.
27 |
28 |
29 | #### The Design of imgaug Module
30 |
31 | The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the following usage:
32 |
33 | * Factor out randomness and determinism.
34 | An augmentor often contains a randomized policy, e.g., it randomly perturbs each image differently.
35 | However, its "deterministic" part needs to be factored out, so that
36 | the same transformation can be re-applied to other data
37 | associated with the image. This is achieved like this:
38 |
39 | ```python
40 | tfm = augmentor.get_transform(img) # a deterministic transformation
41 | new_img = tfm.apply_image(img)
42 | new_img2 = tfm.apply_image(img2)
43 | new_coords = tfm.apply_coords(coords)
44 | ```
45 |
46 | Due to this design, it can augment images together with their annotations
47 | (e.g., segmentation masks, bounding boxes, keypoints).
48 | Our coordinate augmentation enforces floating-point coordinates
49 | to avoid quantization error.
50 |
51 | When you don't need to re-apply the same transformation, you can also just call
52 |
53 | ```python
54 | new_img = augmentor.augment(img)
55 | ```
56 |
57 | * Reset random seed. Random seed can be reset by
58 | [reset_state](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.ImageAugmentor.reset_state).
59 | This is important for multi-process data loading, to make sure different
60 | processes get different seeds.
61 | The reset method is called automatically if you use tensorpack's
62 | [image augmentation dataflow](../../modules/dataflow.html#tensorpack.dataflow.AugmentImageComponent)
63 | or if you use Python 3.7+.
64 | Otherwise, **you are responsible** for calling it yourself in subprocesses.
65 | See the
66 | [API documentation](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.ImageAugmentor.reset_state)
67 | of this method for more details.
68 |
69 |
70 | ### Write an Augmentor
71 |
72 | The interface you will need to implement is:
73 |
74 | ```python
75 | class MyAug(imgaug.ImageAugmentor):
76 | def get_transform(self, img):
77 | # Randomly generate a deterministic transformation, to be applied on img
78 | x = random_parameters()
79 | return MyTransform(x)
80 |
81 | class MyTransform(imgaug.Transform):
82 | def apply_image(self, img):
83 | return new_img
84 |
85 | def apply_coords(self, coords):
86 | return new_coords
87 | ```
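
A sketch of how `MyAug` could then be plugged into a dataflow, assuming the image is component 0 of each datapoint:

```python
from tensorpack.dataflow import AugmentImageComponent

# applies MyAug (and any other augmentors in the list) to component 0 of `df`
df = AugmentImageComponent(df, [MyAug()], index=0)
```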
88 |
89 | Check out the zoo of built-in augmentors to get a better sense.
90 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/tutorial/extend/dataflow.md:
--------------------------------------------------------------------------------
1 |
2 | ### Write a DataFlow
3 |
4 | First, make sure you know about Python's generators and the `yield` keyword.
5 | If you don't, learn it on Google.
6 |
7 | #### Write a Source DataFlow
8 |
9 | There are several existing DataFlow, e.g. [ImageFromFile](../../modules/dataflow.html#tensorpack.dataflow.ImageFromFile),
10 | [DataFromList](../../modules/dataflow.html#tensorpack.dataflow.DataFromList),
11 | which you can use if your data format is simple.
12 | In general, you probably need to write a source DataFlow to produce data for your task,
13 | and then compose it with other DataFlow (e.g. mapping, batching, prefetching, ...).
14 |
15 | The easiest way to create a DataFlow to load custom data is to wrap a custom generator, e.g.:
16 | ```python
17 | def my_data_loader():
18 | # load data from somewhere with Python, and yield them
19 | for k in range(100):
20 | yield [my_array, my_label]
21 |
22 | df = DataFromGenerator(my_data_loader)
23 | ```
24 |
25 | To write more complicated DataFlow, you need to inherit the base `DataFlow` class.
26 | Usually, you just need to implement the `__iter__()` method which yields a datapoint every time.
27 | ```python
28 | class MyDataFlow(DataFlow):
29 | def __iter__(self):
30 | # load data from somewhere with Python, and yield them
31 | for k in range(100):
32 | digit = np.random.rand(28, 28)
33 | label = np.random.randint(10)
34 | yield [digit, label]
35 |
36 | df = MyDataFlow()
37 | df.reset_state()
38 | for datapoint in df:
39 | print(datapoint[0], datapoint[1])
40 | ```
41 |
42 | Optionally, you can implement the `__len__` and `reset_state` methods.
43 | The detailed semantics of these three methods are explained
44 | in the [API documentation](../../modules/dataflow.html#tensorpack.dataflow.DataFlow).
45 | If you're writing a complicated DataFlow, make sure to read the API documentation
46 | for the semantics.
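
For instance, a sketch of `MyDataFlow` above extended with an optional `__len__()`:

```python
import numpy as np
from tensorpack.dataflow import DataFlow

class MyDataFlow(DataFlow):
    def __len__(self):
        # optional: the number of datapoints yielded per epoch
        return 100

    def __iter__(self):
        for k in range(100):
            yield [np.random.rand(28, 28), np.random.randint(10)]
```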
47 |
48 | DataFlow implementations for several well-known datasets are provided in the
49 | [dataflow.dataset](../../modules/dataflow.dataset.html)
50 | module. You can take them as examples.
51 |
52 | #### More Data Processing
53 |
54 | You can put any data processing you need in the source DataFlow you write, or you can write a new DataFlow for data
55 | processing on top of the source DataFlow, e.g.:
56 |
57 | ```python
58 | class ProcessingDataFlow(DataFlow):
59 | def __init__(self, ds):
60 | self.ds = ds
61 |
62 | def reset_state(self):
63 | self.ds.reset_state()
64 |
65 | def __iter__(self):
66 | for datapoint in self.ds:
67 | # do something
68 | yield new_datapoint
69 | ```
70 |
71 | Some built-in dataflows, e.g.
72 | [MapData](../../modules/dataflow.html#tensorpack.dataflow.MapData) and
73 | [MapDataComponent](../../modules/dataflow.html#tensorpack.dataflow.MapDataComponent)
74 | can do common types of data processing for you.
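
For example, a simple version of `ProcessingDataFlow` can often be written with them directly (a sketch; `f` stands for your own processing function; use one of the two):

```python
from tensorpack.dataflow import MapData, MapDataComponent

df = MapData(ds, lambda dp: [f(dp[0]), dp[1]])   # map the whole datapoint
df = MapDataComponent(ds, f, index=0)            # or: map only component 0
```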
75 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/tutorial/extend/input-source.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/ssl_detection/00d52272f61b56eade8d5ace18213cba6c74f6d8/third_party/tensorpack/docs/tutorial/extend/input-source.png
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/tutorial/extend/model.md:
--------------------------------------------------------------------------------
1 |
2 | ## Write a Layer
3 |
4 | The first thing to note: __you never have to write a layer__.
5 | Tensorpack layers are nothing but wrappers of symbolic functions.
6 | In tensorpack, you can use __any__ symbolic functions you have written or seen elsewhere with or without tensorpack layers.
7 |
8 | If you would like, you can make a symbolic function a "layer" by following some simple rules, and gain extra benefits from tensorpack.
9 |
10 | Take a look at the [ShuffleNet example](../../examples/ImageNetModels/shufflenet.py#L22)
11 | to see an example of how to define a custom layer:
12 |
13 | ```python
14 | @layer_register(log_shape=True)
15 | def DepthConv(x, out_channel, kernel_shape, padding='SAME', stride=1,
16 | W_init=None, activation=tf.identity):
17 | ```
18 |
19 | Basically, a tensorpack layer is just a symbolic function, but with the following rules:
20 |
21 | + It is decorated by `@layer_register`.
22 | + The first argument is its "input". It must be a **tensor or a list of tensors**.
23 | + It returns either a tensor or a list of tensors as its "output".
24 |
25 |
26 | By making a symbolic function a "layer", the following things will happen:
27 | + You will need to call the function with a scope name as the first argument, e.g. `Conv2D('conv0', x, 32, 3)`.
28 | Everything happening in this function will be under the variable scope `conv0`.
29 | You can register the layer with `use_scope=False` to disable this feature.
30 | + Static shapes of input/output will be printed to screen (if you register with `log_shape=True`).
31 | + `argscope` will work for all its arguments except the input tensor(s).
32 | + It will work with `LinearWrap`: you can use it if the output of one layer matches the input of the next layer.
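
As a toy sketch of these rules (`Scale` is a hypothetical layer, not part of tensorpack):

```python
import tensorflow as tf
from tensorpack import layer_register

@layer_register(log_shape=True)
def Scale(x, gamma_init=1.0):
    # one trainable scalar, created under the layer's variable scope
    gamma = tf.get_variable('gamma', [], initializer=tf.constant_initializer(gamma_init))
    return x * gamma

# y = Scale('scale0', x)   # everything above runs under the variable scope "scale0"
```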
33 |
34 | There is no rule about what kind of symbolic functions should be made a layer -- they are quite
35 | similar anyway. However, in general, I define the following symbolic functions as layers:
36 | + Functions which contain variables. A variable scope is almost always helpful for such functions.
37 | + Functions which are commonly referred to as "layers", such as pooling. This makes a model
38 | definition more straightforward.
39 |
40 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/tutorial/faq.md:
--------------------------------------------------------------------------------
1 |
2 | # FAQs
3 |
4 | ## Does it support data format X / augmentation Y / layer Z?
5 |
6 | The library tries to __support__ everything, but it cannot really __include__ everything.
7 |
8 | The interface attempts to be flexible enough so you can put any XYZ on it.
9 | You can either implement them under the interface or simply wrap some existing Python code.
10 | See [Extend Tensorpack](index.html#extend-tensorpack)
11 | for more details.
12 |
13 | If you think:
14 | 1. The framework has limitations in its interface so your XYZ cannot be supported, OR
15 | 2. Your XYZ is super common / very well-defined / very useful, so it would be nice to include it.
16 |
17 | Then it is a good time to open an issue.
18 |
19 | ## How to print/dump intermediate results during training
20 |
21 | 1. Learn `tf.Print`. Most of the time, adding one line in between:
22 |
23 | ```python
24 | tensor = obtain_a_tensor()
25 | tensor = tf.Print(tensor, [tf.shape(tensor), tensor], tensor.name, summarize=100)
26 | use_the_tensor(tensor)
27 | ```
28 | is sufficient.
29 |
30 | 2. Know [DumpTensors](../modules/callbacks.html#tensorpack.callbacks.DumpTensors),
31 | [ProcessTensors](../modules/callbacks.html#tensorpack.callbacks.ProcessTensors) callbacks.
32 | It's also easy to write your own versions of them.
33 |
34 | 3. The [ProgressBar](../modules/callbacks.html#tensorpack.callbacks.ProgressBar)
35 | callback can print some scalar statistics, though not enabled by default.
36 |
37 | 4. Read [Summary and Logging](summary.html) for more options on logging.
38 |
39 | ## How to freeze some variables in training
40 |
41 | 1. Learn `tf.stop_gradient`. You can simply use `tf.stop_gradient` in your model code in many situations (e.g. to freeze first several layers).
42 | Note that it stops the gradient flow in the current Tensor but your variables may still contribute to the
43 | final loss through other tensors (e.g., weight decay).
44 |
45 | 2. [varreplace.freeze_variables](../modules/tfutils.html#tensorpack.tfutils.varreplace.freeze_variables) returns a context where variables are frozen (a sketch is given below).
46 | It is implemented with the `custom_getter` argument of `tf.variable_scope` -- learn it to gain more control over what & how variables are frozen.
47 |
48 | 3. [ScaleGradient](../modules/tfutils.html#tensorpack.tfutils.gradproc.ScaleGradient) can be used to set the gradients of some variables to 0.
49 | But it may be slow, since variables still have gradients.
50 |
51 | Note that the above methods only prevent variables from being updated by SGD.
52 | Some variables may be updated by other means,
53 | e.g., BatchNorm statistics are updated through the `UPDATE_OPS` collection and the [RunUpdateOps](../modules/callbacks.html#tensorpack.callbacks.RunUpdateOps) callback.
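
A minimal sketch of option 2 (the flag values are illustrative; see the API documentation for their exact semantics):

```python
from tensorpack.tfutils import varreplace

with varreplace.freeze_variables(stop_gradient=False, skip_collection=True):
    # variables created in this context are kept out of TRAINABLE_VARIABLES,
    # so the optimizer will not update them
    x = some_symbolic_function(x)   # some_symbolic_function is hypothetical
```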
54 |
55 | ## The model does not run on CPUs?
56 |
57 | Some TensorFlow ops are not implemented on CPUs.
58 | For example, many ops in NCHW format are not supported on CPUs.
59 | Note that if you use an MKL-enabled version of TensorFlow, more NCHW ops are supported.
60 |
61 | In general, you need to implement the model in a way your version of TensorFlow supports.
62 |
63 | ## My training seems slow. Why?
64 |
65 | Check out the [Performance Tuning tutorial](performance-tuning.html).
66 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/tutorial/index.rst:
--------------------------------------------------------------------------------
1 |
2 | Tutorials
3 | ---------------------
4 |
5 | Introduction
6 | =============
7 |
8 | .. include:: intro.rst
9 |
10 | Basic Tutorials
11 | ========================
12 |
13 | .. toctree::
14 | :maxdepth: 1
15 |
16 | trainer
17 | training-interface
18 | callback
19 | symbolic
20 | save-load
21 | summary
22 | inference
23 | faq
24 |
25 | DataFlow Tutorials
26 | ========================
27 |
28 | .. toctree::
29 | :maxdepth: 1
30 |
31 | dataflow
32 | philosophy/dataflow
33 | extend/dataflow
34 | parallel-dataflow
35 | efficient-dataflow
36 |
37 | Advanced Tutorials
38 | ==================
39 |
40 | .. toctree::
41 | :maxdepth: 1
42 |
43 | extend/input-source
44 | extend/callback
45 | extend/augmentor
46 | extend/model
47 | extend/trainer
48 | performance-tuning
49 |
--------------------------------------------------------------------------------
/third_party/tensorpack/docs/tutorial/intro.rst:
--------------------------------------------------------------------------------
1 |
2 | What is tensorpack?
3 | ~~~~~~~~~~~~~~~~~~~
4 |
5 | Tensorpack is a **training interface** based on TensorFlow, which means:
6 | you'll mostly use tensorpack's high-level APIs for training, rather than TensorFlow's low-level APIs.
7 |
8 | Why tensorpack?
9 | ~~~~~~~~~~~~~~~~~~~
10 |
11 | TensorFlow is powerful, but has its own drawbacks:
12 | its low-level APIs are too complicated for many users,
13 | and its existing high-level APIs sacrifice a lot in either speed or flexibility.
14 | The Tensorpack API brings speed and flexibility together.
15 |
16 |
17 | Is TensorFlow Slow?
18 | *******************
19 |
20 |
21 | No, despite a common misconception, TensorFlow is not slow.
22 | But it is not easy to write TensorFlow code in an efficient way.
23 |
24 | When **speed** is a concern, users will have to worry a lot about things unrelated to the model.
25 | Code written with low-level APIs or other existing high-level wrappers is often suboptimal in speed.
26 | Even most of the official TensorFlow examples are written for simplicity rather than efficiency,
27 | which as a result makes people think TensorFlow is *slow*.
28 |
29 | The `official TensorFlow benchmark <https://github.com/tensorflow/benchmarks>`_ said this in their README:
30 |
31 | These models are designed for performance. For models that have clean and easy-to-read implementations, see the TensorFlow Official Models.
32 |
33 | which seems to suggest that you cannot have **performance and ease-of-use together**.
34 | However, you can have them both in tensorpack.
35 | Tensorpack
36 | `uses TensorFlow efficiently <https://github.com/tensorpack/benchmarks>`_,
37 | and hides performance details under its APIs.
38 | You no longer need to write
39 | data prefetching, multi-GPU replication, device placement, or variable synchronization -- anything that's unrelated to the model itself.
40 | You still need to understand the graph and learn to write models with TF, but performance is all taken care of by tensorpack.
41 |
42 | A High Level Glance
43 | ~~~~~~~~~~~~~~~~~~~
44 |
45 | .. image:: https://user-images.githubusercontent.com/1381301/29187907-2caaa740-7dc6-11e7-8220-e20ca52c3ca6.png
46 |
47 |
48 | * ``DataFlow`` is a library to load data efficiently in Python.
49 | Apart from DataFlow, native TF operators can be used for data loading as well.
50 | They will eventually be wrapped under the same ``InputSource`` interface and go through prefetching.
51 |
52 | * You can use any TF-based symbolic function library to define a model, including
53 | a small set of functions within tensorpack. ``ModelDesc`` is an interface to connect
54 | the model with the trainers, but you can also use trainers without ``ModelDesc``.
55 |
56 | * Tensorpack trainers manage the training loops for you.
57 | They also include data parallel logic for multi-GPU and distributed training.
58 | At the same time, you have the power of customization through callbacks.
59 |
60 | * Callbacks are like ``tf.train.SessionRunHook``, or plugins. During training,
61 | everything you want to do other than the main iterations can be defined through callbacks and easily reused.
62 |
63 | * All the components, though they work perfectly together, are highly decoupled: you can:
64 |
65 | * Use DataFlow alone as a data loading library, without TensorFlow at all.
66 | * Use tensorpack to build the graph with multi-GPU or distributed support,
67 | then train it with your own loops.
68 | * Build the graph on your own, and train it with tensorpack callbacks.
69 |
--------------------------------------------------------------------------------
/third_party/tensorpack/readthedocs.yml:
--------------------------------------------------------------------------------
1 | formats:
2 | - none
3 | requirements_file: docs/requirements.txt
4 |
5 | python:
6 | version: 3.5
7 |
--------------------------------------------------------------------------------
/third_party/tensorpack/scripts/README.md:
--------------------------------------------------------------------------------
1 |
2 | These scripts are some helpful utilities for the library.
3 |
4 | They are meant to be __examples__ of how to do some basic model manipulation
5 | with tensorpack. The scripts themselves are not part of the library and
6 | therefore are not subject to any compatibility guarantee.
7 |
--------------------------------------------------------------------------------
/third_party/tensorpack/scripts/checkpoint-manipulate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: checkpoint-manipulate.py
4 |
5 |
6 | import argparse
7 | import numpy as np
8 |
9 | from tensorpack.tfutils.varmanip import load_chkpt_vars
10 | from tensorpack.utils import logger
11 |
12 | if __name__ == '__main__':
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('model')
15 | parser.add_argument('--dump', help='dump to an npz file')
16 | parser.add_argument('--shell', action='store_true', help='start a shell with the params')
17 | args = parser.parse_args()
18 |
19 | if args.model.endswith('.npy'):
20 | params = np.load(args.model, encoding='latin1').item()
21 | elif args.model.endswith('.npz'):
22 | params = dict(np.load(args.model))
23 | else:
24 | params = load_chkpt_vars(args.model)
25 | logger.info("Variables in the model:")
26 | logger.info(str(params.keys()))
27 |
28 | if args.dump:
29 | assert args.dump.endswith('.npz'), args.dump
30 | np.savez(args.dump, **params)
31 |
32 | if args.shell:
33 | # params is a dict. play with it
34 | import IPython as IP
35 | IP.embed(config=IP.terminal.ipapp.load_default_config())
36 |
--------------------------------------------------------------------------------
/third_party/tensorpack/scripts/checkpoint-prof.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: checkpoint-prof.py
4 |
5 | import argparse
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | from tensorpack import get_default_sess_config, get_op_tensor_name
10 | from tensorpack.tfutils.sessinit import SmartInit
11 | from tensorpack.utils import logger
12 |
13 | if __name__ == '__main__':
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--model', help='model file')
16 | parser.add_argument('--meta', help='metagraph proto file. Will be used to load the graph', required=True)
17 | parser.add_argument('-i', '--input', nargs='+', help='list of input tensors with their shapes.')
18 | parser.add_argument('-o', '--output', nargs='+', help='list of output tensors')
19 | parser.add_argument('--warmup', help='warmup iterations', type=int, default=5)
20 | parser.add_argument('--print-flops', action='store_true')
21 | parser.add_argument('--print-params', action='store_true')
22 | parser.add_argument('--print-timing', action='store_true')
23 | args = parser.parse_args()
24 |
25 | tf.train.import_meta_graph(args.meta, clear_devices=True)
26 | G = tf.get_default_graph()
27 | with tf.Session(config=get_default_sess_config()) as sess:
28 | init = SmartInit(args.model)
29 | init.init(sess)
30 |
31 | feed = {}
32 | for inp in args.input:
33 | inp = inp.split('=')
34 | name = get_op_tensor_name(inp[0].strip())[1]
35 | shape = list(map(int, inp[1].strip().split(',')))
36 | tensor = G.get_tensor_by_name(name)
37 | logger.info("Feeding shape ({}) to tensor {}".format(','.join(map(str, shape)), name))
38 | feed[tensor] = np.random.rand(*shape)
39 |
40 | fetches = []
41 | for name in args.output:
42 | name = get_op_tensor_name(name)[1]
43 | fetches.append(G.get_tensor_by_name(name))
44 | logger.info("Fetching tensors: {}".format(', '.join([k.name for k in fetches])))
45 |
46 | for _ in range(args.warmup):
47 | sess.run(fetches, feed_dict=feed)
48 |
49 | opt = tf.RunOptions()
50 | opt.trace_level = tf.RunOptions.FULL_TRACE
51 | meta = tf.RunMetadata()
52 | sess.run(fetches, feed_dict=feed, options=opt, run_metadata=meta)
53 |
54 | if args.print_flops:
55 | tf.profiler.profile(
56 | G,
57 | run_meta=meta,
58 | cmd='op',
59 | options=tf.profiler.ProfileOptionBuilder.float_operation())
60 |
61 | if args.print_params:
62 | tf.profiler.profile(
63 | G,
64 | run_meta=meta,
65 | options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
66 |
67 | if args.print_timing:
68 | tf.profiler.profile(
69 | G,
70 | run_meta=meta,
71 | options=tf.profiler.ProfileOptionBuilder.time_and_memory())
72 |
--------------------------------------------------------------------------------
/third_party/tensorpack/scripts/ls-checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # File: ls-checkpoint.py
4 |
5 | import numpy as np
6 | import pprint
7 | import sys
8 | import six
9 | import tensorflow as tf
10 |
11 | from tensorpack.tfutils.varmanip import get_checkpoint_path
12 |
13 | if __name__ == '__main__':
14 | fpath = sys.argv[1]
15 |
16 | if fpath.endswith('.npy'):
17 | params = np.load(fpath, encoding='latin1').item()
18 | dic = {k: v.shape for k, v in six.iteritems(params)}
19 | elif fpath.endswith('.npz'):
20 | params = dict(np.load(fpath))
21 | dic = {k: v.shape for k, v in six.iteritems(params)}
22 | else:
23 | path = get_checkpoint_path(fpath)
24 | reader = tf.train.NewCheckpointReader(path)
25 | dic = reader.get_variable_to_shape_map()
26 | pprint.pprint(dic)
27 |
--------------------------------------------------------------------------------
/third_party/tensorpack/setup.py:
--------------------------------------------------------------------------------
1 | from os import path
2 | import setuptools
3 | from setuptools import setup, find_packages
4 |
5 | version = int(setuptools.__version__.split('.')[0])
6 | assert version > 30, "Tensorpack installation requires setuptools > 30"
7 |
8 | this_directory = path.abspath(path.dirname(__file__))
9 |
10 | # setup metainfo
11 | libinfo_py = path.join(this_directory, 'tensorpack', 'libinfo.py')
12 | libinfo_content = open(libinfo_py, "r").readlines()
13 | version_line = [l.strip() for l in libinfo_content if l.startswith('__version__')][0]
14 | exec(version_line) # produce __version__
15 |
16 | with open(path.join(this_directory, 'README.md'), 'rb') as f:
17 | long_description = f.read().decode('utf-8')
18 |
19 |
20 | def add_git_version():
21 |
22 | def get_git_version():
23 | from subprocess import check_output
24 | try:
25 | return check_output("git describe --tags --long --dirty".split()).decode('utf-8').strip()
26 | except Exception:
27 | return __version__ # noqa
28 |
29 | newlibinfo_content = [l for l in libinfo_content if not l.startswith('__git_version__')]
30 | newlibinfo_content.append('__git_version__ = "{}"'.format(get_git_version()))
31 | with open(libinfo_py, "w") as f:
32 | f.write("".join(newlibinfo_content))
33 |
34 |
35 | add_git_version()
36 |
37 |
38 | setup(
39 | name='tensorpack',
40 | author="TensorPack contributors",
41 | author_email="ppwwyyxxc@gmail.com",
42 | url="https://github.com/tensorpack/tensorpack",
43 | keywords="tensorflow, deep learning, neural network",
44 | license="Apache",
45 |
46 | version=__version__, # noqa
47 | description='A Neural Network Training Interface on TensorFlow',
48 | long_description=long_description,
49 | long_description_content_type='text/markdown',
50 |
51 | packages=find_packages(exclude=["examples", "tests"]),
52 | zip_safe=False, # dataset and __init__ use file
53 |
54 | install_requires=[
55 | "numpy>=1.14",
56 | "six",
57 | "termcolor>=1.1",
58 | "tabulate>=0.7.7",
59 | "tqdm>4.29.0",
60 | "msgpack>=0.5.2",
61 | "msgpack-numpy>=0.4.4.2",
62 | "pyzmq>=16",
63 | "psutil>=5",
64 | ],
65 | tests_require=['flake8', 'scikit-image'],
66 | extras_require={
67 | 'all': ['scipy', 'h5py', 'lmdb>=0.92', 'matplotlib', 'scikit-learn'],
68 | 'all: "linux" in sys_platform': ['python-prctl'],
69 | },
70 |
71 | # https://packaging.python.org/guides/distributing-packages-using-setuptools/#universal-wheels
72 | options={'bdist_wheel': {'universal': '1'}},
73 | )
74 |
--------------------------------------------------------------------------------
/third_party/tensorpack/sotabench/sotabench.yml:
--------------------------------------------------------------------------------
1 | image: cuda10.0-cudnn7-ubuntu18.04
2 |
--------------------------------------------------------------------------------
/third_party/tensorpack/sotabench/sotabench_setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -v
4 |
5 | . /workspace/venv/bin/activate
6 |
7 | pip install -e .
8 | pip install tensorflow-gpu==1.14.0
9 | pip install opencv-python scipy
10 |
11 | echo "Extracting ..."
12 | cd ./.data/vision/coco
13 | python -c 'import zipfile; zipfile.ZipFile("annotations_trainval2017.zip").extractall()'
14 | python -c 'import zipfile; zipfile.ZipFile("val2017.zip").extractall()'
15 | cd -
16 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: __init__.py
3 |
4 |
5 | from tensorpack.libinfo import __version__, __git_version__, _HAS_TF
6 |
7 | from tensorpack.utils import *
8 | from tensorpack.dataflow import *
9 |
10 | # dataflow can be used alone without installing tensorflow
11 |
12 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36
13 | STATICA_HACK = True
14 | globals()['kcah_acitats'[::-1].upper()] = _HAS_TF
15 | if STATICA_HACK:
16 | from tensorpack.models import *
17 |
18 | from tensorpack.callbacks import *
19 | from tensorpack.tfutils import *
20 |
21 | from tensorpack.train import *
22 | from tensorpack.graph_builder import InputDesc # kept for BC
23 | from tensorpack.input_source import *
24 | from tensorpack.predict import *
25 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: __init__.py
3 |
4 |
5 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36
6 | STATICA_HACK = True
7 | globals()['kcah_acitats'[::-1].upper()] = False
8 | if STATICA_HACK:
9 | from .base import *
10 | from .concurrency import *
11 | from .graph import *
12 | from .group import *
13 | from .hooks import *
14 | from .inference import *
15 | from .inference_runner import *
16 | from .monitor import *
17 | from .param import *
18 | from .prof import *
19 | from .saver import *
20 | from .misc import *
21 | from .steps import *
22 | from .summary import *
23 | from .trigger import *
24 |
25 |
26 | from pkgutil import iter_modules
27 | import os
28 |
29 |
30 | __all__ = []
31 |
32 |
33 | def _global_import(name):
34 | p = __import__(name, globals(), locals(), level=1)
35 | lst = p.__all__ if '__all__' in dir(p) else dir(p)
36 | if lst:
37 | del globals()[name]
38 | for k in lst:
39 | if not k.startswith('__'):
40 | globals()[k] = p.__dict__[k]
41 | __all__.append(k)
42 |
43 |
44 | _CURR_DIR = os.path.dirname(__file__)
45 | for _, module_name, _ in iter_modules(
46 | [_CURR_DIR]):
47 | srcpath = os.path.join(_CURR_DIR, module_name + '.py')
48 | if not os.path.isfile(srcpath):
49 | continue
50 | if module_name.endswith('_test'):
51 | continue
52 | if not module_name.startswith('_'):
53 | _global_import(module_name)
54 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/callbacks/concurrency.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: concurrency.py
3 |
4 | import multiprocessing as mp
5 |
6 | from ..utils import logger
7 | from ..utils.concurrency import StoppableThread, start_proc_mask_signal
8 | from .base import Callback
9 |
10 | __all__ = ['StartProcOrThread']
11 |
12 |
13 | class StartProcOrThread(Callback):
14 | """
15 | Start some threads or processes before training.
16 | """
17 |
18 | _chief_only = False
19 |
20 | def __init__(self, startable, stop_at_last=True):
21 | """
22 | Args:
23 | startable (list): list of processes or threads which have ``start()`` method.
24 | Can also be a single instance of process or thread.
25 | stop_at_last (bool): whether to stop the processes or threads
26 | after training. It will use :meth:`Process.terminate()` or
27 | :meth:`StoppableThread.stop()`, but will do nothing on normal
28 | ``threading.Thread`` or other startable objects.
29 | """
30 | if not isinstance(startable, list):
31 | startable = [startable]
32 | self._procs_threads = startable
33 | self._stop_at_last = stop_at_last
34 |
35 | def _before_train(self):
36 | logger.info("Starting " +
37 | ', '.join([k.name for k in self._procs_threads]) + ' ...')
38 | # avoid sigint get handled by other processes
39 | start_proc_mask_signal(self._procs_threads)
40 |
41 | def _after_train(self):
42 | if not self._stop_at_last:
43 | return
44 | for k in self._procs_threads:
45 | if not k.is_alive():
46 | continue
47 | if isinstance(k, mp.Process):
48 | logger.info("Stopping {} ...".format(k.name))
49 | k.terminate()
50 | k.join(5.0)
51 | if k.is_alive():
52 | logger.error("Cannot join process {}.".format(k.name))
53 | elif isinstance(k, StoppableThread):
54 | logger.info("Stopping {} ...".format(k.name))
55 | k.stop()
56 | k.join(5.0)
57 | if k.is_alive():
58 | logger.error("Cannot join thread {}.".format(k.name))
59 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/callbacks/group.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: group.py
3 |
4 |
5 | import traceback
6 | from contextlib import contextmanager
7 | from time import perf_counter as timer # noqa
8 | from ..compat import tfv1 as tf
9 |
10 | from ..utils import logger
11 | from ..utils.utils import humanize_time_delta
12 | from .base import Callback
13 | from .hooks import CallbackToHook
14 |
15 | __all__ = ['Callbacks']
16 |
17 |
18 | class CallbackTimeLogger(object):
19 | def __init__(self):
20 | self.times = []
21 | self.tot = 0
22 |
23 | def add(self, name, time):
24 | self.tot += time
25 | self.times.append((name, time))
26 |
27 | @contextmanager
28 | def timed_callback(self, name):
29 | s = timer()
30 | yield
31 | self.add(name, timer() - s)
32 |
33 | def log(self):
34 | """ log the time of some heavy callbacks """
35 |
36 | if self.tot < 3:
37 | return
38 | msgs = []
39 | for name, t in self.times:
40 | if t / self.tot > 0.3 and t > 1:
41 | msgs.append(name + ": " + humanize_time_delta(t))
42 | logger.info(
43 | "Callbacks took {:.3f} sec in total. {}".format(
44 | self.tot, '; '.join(msgs)))
45 |
46 |
47 | class Callbacks(Callback):
48 | """
49 | A container to hold all callbacks, and trigger them iteratively.
50 |
51 | This is only used by the base trainer to run all the callbacks.
52 | Users do not need to use this class.
53 | """
54 |
55 | def __init__(self, cbs):
56 | """
57 | Args:
58 | cbs(list): a list of :class:`Callback` instances.
59 | """
60 | # check type
61 | for cb in cbs:
62 | assert isinstance(cb, Callback), cb.__class__
63 | self.cbs = cbs
64 |
65 | def _setup_graph(self):
66 | with tf.name_scope(None): # clear the name scope
67 | for cb in self.cbs:
68 | cb.setup_graph(self.trainer)
69 |
70 | def _before_train(self):
71 | for cb in self.cbs:
72 | cb.before_train()
73 |
74 | def _after_train(self):
75 | for cb in self.cbs:
76 | # make sure callbacks are properly finalized
77 | try:
78 | cb.after_train()
79 | except Exception:
80 | traceback.print_exc()
81 |
82 | def get_hooks(self):
83 | return [CallbackToHook(cb) for cb in self.cbs]
84 |
85 | def trigger_step(self):
86 | for cb in self.cbs:
87 | cb.trigger_step()
88 |
89 | def _trigger_epoch(self):
90 | tm = CallbackTimeLogger()
91 |
92 | for cb in self.cbs:
93 | display_name = str(cb)
94 | with tm.timed_callback(display_name):
95 | cb.trigger_epoch()
96 | tm.log()
97 |
98 | def _before_epoch(self):
99 | for cb in self.cbs:
100 | cb.before_epoch()
101 |
102 | def _after_epoch(self):
103 | for cb in self.cbs:
104 | cb.after_epoch()
105 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/callbacks/hooks.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: hooks.py
3 |
4 |
5 | """ Compatible layers between tf.train.SessionRunHook and Callback"""
6 |
7 | import tensorflow as tf
8 |
9 | from ..compat import tfv1
10 | from ..utils.develop import HIDE_DOC
11 |
12 | from .base import Callback
13 |
14 | __all__ = ['CallbackToHook', 'HookToCallback', 'TFLocalCLIDebugHook']
15 |
16 |
17 | class CallbackToHook(tfv1.train.SessionRunHook):
18 | """
19 | Hooks are less powerful than callbacks so the conversion is incomplete.
20 | It only converts the ``before_run/after_run`` calls.
21 |
22 | This is only for internal implementation of
23 | ``before_run/after_run`` callbacks.
24 | You shouldn't need to use this.
25 | """
26 |
27 | def __init__(self, cb):
28 | self._cb = cb
29 |
30 | @HIDE_DOC
31 | def before_run(self, ctx):
32 | return self._cb.before_run(ctx)
33 |
34 | @HIDE_DOC
35 | def after_run(self, ctx, vals):
36 | self._cb.after_run(ctx, vals)
37 |
38 |
39 | class HookToCallback(Callback):
40 | """
41 | Make a ``tf.train.SessionRunHook`` into a callback.
42 | Note that when ``SessionRunHook.after_create_session`` is called, the ``coord`` argument will be None.
43 | """
44 |
45 | _chief_only = False
46 |
47 | def __init__(self, hook):
48 | """
49 | Args:
50 | hook (tf.train.SessionRunHook):
51 | """
52 | self._hook = hook
53 |
54 | def _setup_graph(self):
55 | with tf.name_scope(None): # jump out of the name scope
56 | self._hook.begin()
57 |
58 | def _before_train(self):
59 | sess = tf.get_default_session()
60 | # coord is set to None when converting
61 | self._hook.after_create_session(sess, None)
62 |
63 | def _before_run(self, ctx):
64 | return self._hook.before_run(ctx)
65 |
66 | def _after_run(self, ctx, run_values):
67 | self._hook.after_run(ctx, run_values)
68 |
69 | def _after_train(self):
70 | self._hook.end(self.trainer.sess)
71 |
72 |
73 | class TFLocalCLIDebugHook(HookToCallback):
74 | """
75 | Use the hook `tfdbg.LocalCLIDebugHook` in tensorpack.
76 | """
77 |
78 | _chief_only = True
79 |
80 | def __init__(self, *args, **kwargs):
81 | """
82 | Args:
83 | args, kwargs: arguments to create `tfdbg.LocalCLIDebugHook`.
84 | Refer to tensorflow documentation for details.
85 | """
86 | from tensorflow.python import debug as tfdbg
87 | super(TFLocalCLIDebugHook, self).__init__(tfdbg.LocalCLIDebugHook(*args, **kwargs))
88 |
89 | def add_tensor_filter(self, *args, **kwargs):
90 | """
91 | Wrapper of `tfdbg.LocalCLIDebugHook.add_tensor_filter`.
92 | Refer to tensorflow documentation for details.
93 | """
94 | self._hook.add_tensor_filter(*args, **kwargs)
95 |
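96 | if __name__ == '__main__':
97 | # A minimal usage sketch, assuming a TF1-style environment: wrap a stock
98 | # SessionRunHook so it runs as a tensorpack callback. The wrapped hook can
99 | # then be passed in TrainConfig(callbacks=[...]) like any other callback.
100 | stop_hook = tfv1.train.StopAtStepHook(last_step=100)
101 | cb = HookToCallback(stop_hook)
102 | print(cb)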
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/callbacks/misc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: misc.py
3 |
4 |
5 | import numpy as np
6 | import os
7 | import time
8 | from collections import deque
9 |
10 | from ..utils import logger
11 | from ..utils.utils import humanize_time_delta
12 | from .base import Callback
13 |
14 | __all__ = ['SendStat', 'InjectShell', 'EstimatedTimeLeft']
15 |
16 |
17 | class SendStat(Callback):
18 | """ An equivalent of :class:`SendMonitorData`, but as a normal callback. """
19 | def __init__(self, command, names):
20 | self.command = command
21 | if not isinstance(names, list):
22 | names = [names]
23 | self.names = names
24 |
25 | def _trigger(self):
26 | M = self.trainer.monitors
27 | v = {k: M.get_latest(k) for k in self.names}
28 | cmd = self.command.format(**v)
29 | ret = os.system(cmd)
30 | if ret != 0:
31 | logger.error("Command {} failed with ret={}!".format(cmd, ret))
32 |
33 |
34 | class InjectShell(Callback):
35 | """
36 | Allow users to create a specific file as a signal to pause
37 | and iteratively debug the training.
38 | Once the :meth:`trigger` method is called, it checks whether the file exists, and opens an
39 | IPython/pdb shell if it does.
40 | In the shell, ``self`` is this callback, ``self.trainer`` is the trainer, and
41 | from that you can access everything else.
42 |
43 | Example:
44 |
45 | .. code-block:: none
46 |
47 | callbacks=[InjectShell('/path/to/pause-training.tmp'), ...]
48 |
49 | # the following command will pause the training and start a shell when the epoch finishes:
50 | $ touch /path/to/pause-training.tmp
51 |
52 | """
53 |
54 | def __init__(self, file='INJECT_SHELL.tmp', shell='ipython'):
55 | """
56 | Args:
57 | file (str): if this file exists, will open a shell.
58 | shell (str): one of 'ipython', 'pdb'
59 | """
60 | self._file = file
61 | assert shell in ['ipython', 'pdb']
62 | self._shell = shell
63 | logger.info("Create a file '{}' to open {} shell.".format(file, shell))
64 |
65 | def _trigger(self):
66 | if os.path.isfile(self._file):
67 | logger.info("File {} exists, entering shell.".format(self._file))
68 | self._inject()
69 |
70 | def _inject(self):
71 | trainer = self.trainer # noqa
72 | if self._shell == 'ipython':
73 | import IPython as IP # noqa
74 | IP.embed()
75 | elif self._shell == 'pdb':
76 | import pdb # noqa
77 | pdb.set_trace()
78 |
79 | def _after_train(self):
80 | if os.path.isfile(self._file):
81 | os.unlink(self._file)
82 |
83 |
84 | class EstimatedTimeLeft(Callback):
85 | """
86 | Estimate the time left until completion of training.
87 | """
88 | def __init__(self, last_k_epochs=5, median=True):
89 | """
90 | Args:
91 | last_k_epochs (int): use the time spent on the last k epochs to estimate total time left.
92 | median (bool): if True, use the median (instead of the mean) of the last k epoch durations.
93 | """
94 | self._times = deque(maxlen=last_k_epochs)
95 | self._median = median
96 |
97 | def _before_train(self):
98 | self._max_epoch = self.trainer.max_epoch
99 | self._last_time = time.time()
100 |
101 | def _trigger_epoch(self):
102 | duration = time.time() - self._last_time
103 | self._last_time = time.time()
104 | self._times.append(duration)
105 |
106 | epoch_time = np.median(self._times) if self._median else np.mean(self._times)
107 | time_left = (self._max_epoch - self.epoch_num) * epoch_time
108 | if time_left > 0:
109 | logger.info("Estimated Time Left: " + humanize_time_delta(time_left))
110 |
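111 | if __name__ == '__main__':
112 | # A minimal sketch of the arithmetic EstimatedTimeLeft performs each epoch,
113 | # using hypothetical epoch durations (in seconds) instead of a live trainer.
114 | durations = deque([100.0, 104.0, 98.0], maxlen=5)
115 | epoch_time = np.median(durations)
116 | max_epoch, epoch_num = 50, 3
117 | time_left = (max_epoch - epoch_num) * epoch_time
118 | print('Estimated Time Left: ' + humanize_time_delta(time_left))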
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/callbacks/param_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import unittest
3 | import tensorflow as tf
4 |
5 | from ..utils import logger
6 | from ..train.trainers import NoOpTrainer
7 | from .param import ScheduledHyperParamSetter, ObjAttrParam
8 |
9 |
10 | class ParamObject(object):
11 | """
12 | An object that holds the param to be set, for testing purposes.
13 | """
14 | PARAM_NAME = 'param'
15 |
16 | def __init__(self):
17 | self.param_history = {}
18 | self.__dict__[self.PARAM_NAME] = 1.0
19 |
20 | def __setattr__(self, name, value):
21 | if name == self.PARAM_NAME:
22 | self._set_param(value)
23 | super(ParamObject, self).__setattr__(name, value)
24 |
25 | def _set_param(self, value):
26 | self.param_history[self.trainer.global_step] = value
27 |
28 |
29 | class ScheduledHyperParamSetterTest(unittest.TestCase):
30 | def setUp(self):
31 | self._param_obj = ParamObject()
32 |
33 | def tearDown(self):
34 | tf.reset_default_graph()
35 |
36 | def _create_trainer_with_scheduler(self, scheduler,
37 | steps_per_epoch, max_epoch, starting_epoch=1):
38 | trainer = NoOpTrainer()
39 | tf.get_variable(name='test_var', shape=[])
40 | self._param_obj.trainer = trainer
41 | trainer.train_with_defaults(
42 | callbacks=[scheduler],
43 | extra_callbacks=[],
44 | monitors=[],
45 | steps_per_epoch=steps_per_epoch,
46 | max_epoch=max_epoch,
47 | starting_epoch=starting_epoch
48 | )
49 | return self._param_obj.param_history
50 |
51 | def testInterpolation(self):
52 | scheduler = ScheduledHyperParamSetter(
53 | ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
54 | [(30, 0.3), (40, 0.4), (50, 0.5)], interp='linear', step_based=True)
55 | history = self._create_trainer_with_scheduler(scheduler, 10, 50, starting_epoch=20)
56 | self.assertEqual(min(history.keys()), 30)
57 | self.assertEqual(history[30], 0.3)
58 | self.assertEqual(history[40], 0.4)
59 | self.assertEqual(history[45], 0.45)
60 |
61 | def testSchedule(self):
62 | scheduler = ScheduledHyperParamSetter(
63 | ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
64 | [(10, 0.3), (20, 0.4), (30, 0.5)])
65 | history = self._create_trainer_with_scheduler(scheduler, 1, 50)
66 | self.assertEqual(min(history.keys()), 10)
67 | self.assertEqual(len(history), 3)
68 |
69 | def testStartAfterSchedule(self):
70 | scheduler = ScheduledHyperParamSetter(
71 | ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
72 | [(10, 0.3), (20, 0.4), (30, 0.5)])
73 | history = self._create_trainer_with_scheduler(scheduler, 1, 92, starting_epoch=90)
74 | self.assertEqual(len(history), 0)
75 |
76 | def testWarningStartInTheMiddle(self):
77 | scheduler = ScheduledHyperParamSetter(
78 | ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
79 | [(10, 0.3), (20, 0.4), (30, 0.5)])
80 | with self.assertLogs(logger=logger._logger, level='WARNING'):
81 | self._create_trainer_with_scheduler(scheduler, 1, 21, starting_epoch=20)
82 |
83 | def testNoWarningStartInTheMiddle(self):
84 | scheduler = ScheduledHyperParamSetter(
85 | ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
86 | [(10, 0.3), (20, 1.0), (30, 1.5)])
87 | with unittest.mock.patch('tensorpack.utils.logger.warning') as warning:
88 | self._create_trainer_with_scheduler(scheduler, 1, 22, starting_epoch=21)
89 | self.assertFalse(warning.called)
90 |
91 |
92 | if __name__ == '__main__':
93 | unittest.main()
94 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/callbacks/stats.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: stats.py
3 |
4 | from .graph import DumpParamAsImage # noqa
5 | # for compatibility only
6 | from .misc import InjectShell, SendStat # noqa
7 |
8 | __all__ = []
9 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/compat/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import tensorflow as tf
4 |
5 |
6 | def backport_tensor_spec():
7 | if hasattr(tf, 'TensorSpec'):
8 | return tf.TensorSpec
9 | try:
10 | # available since 1.7
11 | from tensorflow.python.framework.tensor_spec import TensorSpec
12 | except ImportError:
13 | pass
14 | else:
15 | tf.TensorSpec = TensorSpec
16 | return TensorSpec
17 |
18 | from .tensor_spec import TensorSpec
19 | tf.TensorSpec = TensorSpec
20 | return TensorSpec
21 |
22 |
23 | def is_tfv2():
24 | try:
25 | from tensorflow.python import tf2
26 | return tf2.enabled()
27 | except Exception:
28 | return False
29 |
30 |
31 | if is_tfv2():
32 | tfv1 = tf.compat.v1
33 | if not hasattr(tf, 'layers'):
34 | # promised at https://github.com/tensorflow/community/pull/24#issuecomment-440453886
35 | tf.layers = tf.keras.layers
36 | else:
37 | try:
38 | tfv1 = tf.compat.v1 # this will silence some warnings
39 | except AttributeError:
40 | tfv1 = tf
41 |
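42 | if __name__ == '__main__':
43 | # A small sketch of how the alias is meant to be used: code written against
44 | # the TF1 API imports tfv1 instead of guessing the TF version itself. Under
45 | # TF2 the alias resolves to tf.compat.v1; under TF1 it is (a subset of) tf.
46 | print(is_tfv2(), tfv1.__name__)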
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/compat/tensor_spec.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | Copied from tensorflow/python/framework/tensor_spec.py
4 | """
5 |
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | import numpy as np
11 |
12 | from tensorflow.python.framework import common_shapes
13 | from tensorflow.python.framework import dtypes
14 | from tensorflow.python.framework import ops
15 | from tensorflow.python.framework import tensor_shape
16 |
17 |
18 | class TensorSpec(object):
19 | """Describes a tf.Tensor.
20 |
21 | Metadata for describing the `tf.Tensor` objects accepted or returned
22 | by some TensorFlow APIs.
23 | """
24 |
25 | __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"]
26 |
27 | def __init__(self, shape, dtype=dtypes.float32, name=None):
28 | """Creates a TensorSpec.
29 |
30 | Args:
31 | shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
32 | dtype: Value convertible to `tf.DType`. The type of the tensor values.
33 | name: Optional name for the Tensor.
34 |
35 | Raises:
36 | TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
37 | not convertible to a `tf.DType`.
38 | """
39 | self._shape = tensor_shape.TensorShape(shape)
40 | try:
41 | self._shape_tuple = tuple(self.shape.as_list())
42 | except ValueError:
43 | self._shape_tuple = None
44 | self._dtype = dtypes.as_dtype(dtype)
45 | self._name = name
46 |
47 | @classmethod
48 | def from_spec(cls, spec, name=None):
49 | return cls(spec.shape, spec.dtype, name or spec.name)
50 |
51 | @classmethod
52 | def from_tensor(cls, tensor, name=None):
53 | if isinstance(tensor, ops.EagerTensor):
54 | return TensorSpec(tensor.shape, tensor.dtype, name)
55 | elif isinstance(tensor, ops.Tensor):
56 | return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
57 | else:
58 | raise ValueError("`tensor` should be a tf.Tensor")
59 |
60 | @property
61 | def shape(self):
62 | """Returns the `TensorShape` that represents the shape of the tensor."""
63 | return self._shape
64 |
65 | @property
66 | def dtype(self):
67 | """Returns the `dtype` of elements in the tensor."""
68 | return self._dtype
69 |
70 | @property
71 | def name(self):
72 | """Returns the (optionally provided) name of the described tensor."""
73 | return self._name
74 |
75 | def is_compatible_with(self, spec_or_tensor):
76 | """Returns True if spec_or_tensor is compatible with this TensorSpec.
77 |
78 | Two tensors are considered compatible if they have the same dtype
79 | and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).
80 |
81 | Args:
82 | spec_or_tensor: A tf.TensorSpec or a tf.Tensor
83 |
84 | Returns:
85 | True if spec_or_tensor is compatible with self.
86 | """
87 | return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and
88 | self._shape.is_compatible_with(spec_or_tensor.shape))
89 |
90 | def __repr__(self):
91 | return "TensorSpec(shape={}, dtype={}, name={})".format(
92 | self.shape, repr(self.dtype), repr(self.name))
93 |
94 | def __hash__(self):
95 | return hash((self._shape_tuple, self.dtype))
96 |
97 | def __eq__(self, other):
98 | return (self._shape_tuple == other._shape_tuple # pylint: disable=protected-access
99 | and self.dtype == other.dtype
100 | and self._name == other._name) # pylint: disable=protected-access
101 |
102 | def __ne__(self, other):
103 | return not self == other
104 |
105 | def __reduce__(self):
106 | return TensorSpec, (self._shape, self._dtype, self._name)
107 |
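108 | if __name__ == '__main__':
109 | # A small sketch of the compatibility check above; the spec and placeholder
110 | # shapes/dtypes are arbitrary examples (assumes a TF1-style graph mode).
111 | import tensorflow as tf
112 | spec = TensorSpec([None, 28, 28, 3], dtypes.float32, name='image')
113 | x = tf.compat.v1.placeholder(tf.float32, [8, 28, 28, 3])
114 | print(spec.is_compatible_with(x)) # True: dtypes match, shapes compatible
115 | print(repr(spec))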
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/contrib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/ssl_detection/00d52272f61b56eade8d5ace18213cba6c74f6d8/third_party/tensorpack/tensorpack/contrib/__init__.py
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/dataflow/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: __init__.py
3 |
4 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36
5 | STATICA_HACK = True
6 | globals()['kcah_acitats'[::-1].upper()] = False
7 | if STATICA_HACK:
8 | from .base import *
9 | from .common import *
10 | from .format import *
11 | from .image import *
12 | from .parallel_map import *
13 | from .parallel import *
14 | from .raw import *
15 | from .remote import *
16 | from . import imgaug
17 | from . import dataset
18 |
19 |
20 | from pkgutil import iter_modules
21 | import os
22 | import os.path
23 | from ..utils.develop import LazyLoader
24 |
25 | __all__ = []
26 |
27 |
28 | def _global_import(name):
29 | p = __import__(name, globals(), locals(), level=1)
30 | lst = p.__all__ if '__all__' in dir(p) else dir(p)
31 | if lst:
32 | del globals()[name]
33 | for k in lst:
34 | if not k.startswith('__'):
35 | globals()[k] = p.__dict__[k]
36 | __all__.append(k)
37 |
38 |
39 | __SKIP = set(['dataset', 'imgaug'])
40 | _CURR_DIR = os.path.dirname(__file__)
41 | for _, module_name, __ in iter_modules(
42 | [os.path.dirname(__file__)]):
43 | srcpath = os.path.join(_CURR_DIR, module_name + '.py')
44 | if not os.path.isfile(srcpath):
45 | continue
46 | if "_test" not in module_name and \
47 | not module_name.startswith('_') and \
48 | module_name not in __SKIP:
49 | _global_import(module_name)
50 |
51 |
52 | globals()['dataset'] = LazyLoader('dataset', globals(), __name__ + '.dataset')
53 | globals()['imgaug'] = LazyLoader('imgaug', globals(), __name__ + '.imgaug')
54 |
55 | del LazyLoader
56 |
57 | __all__.extend(['imgaug', 'dataset'])
58 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/dataflow/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: __init__.py
3 |
4 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36
5 | STATICA_HACK = True
6 | globals()['kcah_acitats'[::-1].upper()] = False
7 | if STATICA_HACK:
8 | from .bsds500 import *
9 | from .cifar import *
10 | from .ilsvrc import *
11 | from .mnist import *
12 | from .svhn import *
13 | from .caltech101 import *
14 |
15 | from pkgutil import iter_modules
16 | import os
17 | import os.path
18 |
19 | __all__ = []
20 |
21 |
22 | def global_import(name):
23 | p = __import__(name, globals(), locals(), level=1)
24 | lst = p.__all__ if '__all__' in dir(p) else dir(p)
25 | if lst:
26 | del globals()[name]
27 | for k in lst:
28 | if not k.startswith('__'):
29 | globals()[k] = p.__dict__[k]
30 | __all__.append(k)
31 |
32 |
33 | _CURR_DIR = os.path.dirname(__file__)
34 | for _, module_name, _ in iter_modules(
35 | [_CURR_DIR]):
36 | srcpath = os.path.join(_CURR_DIR, module_name + '.py')
37 | if not os.path.isfile(srcpath):
38 | continue
39 | if not module_name.startswith('_'):
40 | global_import(module_name)
41 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/dataflow/dataset/bsds500.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: bsds500.py
3 |
4 |
5 | import glob
6 | import numpy as np
7 | import os
8 |
9 | from ...utils.fs import download, get_dataset_path
10 | from ..base import RNGDataFlow
11 |
12 | __all__ = ['BSDS500']
13 |
14 |
15 | DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
16 | DATA_SIZE = 70763455
17 | IMG_W, IMG_H = 481, 321
18 |
19 |
20 | class BSDS500(RNGDataFlow):
21 | """
22 | `Berkeley Segmentation Data Set and Benchmarks 500 dataset
23 | `_.
24 |
25 | Produces ``(image, label)`` pairs, where ``image`` has shape (321, 481, 3) in BGR order
26 | with values in [0, 255].
27 | ``Label`` is a floating point image of shape (321, 481) in range [0, 1].
28 | The value of each pixel is ``number of times it is annotated as edge / total number of annotators for this image``.
29 | """
30 |
31 | def __init__(self, name, data_dir=None, shuffle=True):
32 | """
33 | Args:
34 | name (str): 'train', 'test', 'val'
35 | data_dir (str): a directory containing the original 'BSR' directory.
36 | """
37 | # check and download data
38 | if data_dir is None:
39 | data_dir = get_dataset_path('bsds500_data')
40 | if not os.path.isdir(os.path.join(data_dir, 'BSR')):
41 | download(DATA_URL, data_dir, expect_size=DATA_SIZE)
42 | filename = DATA_URL.split('/')[-1]
43 | filepath = os.path.join(data_dir, filename)
44 | import tarfile
45 | tarfile.open(filepath, 'r:gz').extractall(data_dir)
46 | self.data_root = os.path.join(data_dir, 'BSR', 'BSDS500', 'data')
47 | assert os.path.isdir(self.data_root)
48 |
49 | self.shuffle = shuffle
50 | assert name in ['train', 'test', 'val']
51 | self._load(name)
52 |
53 | def _load(self, name):
54 | image_glob = os.path.join(self.data_root, 'images', name, '*.jpg')
55 | image_files = glob.glob(image_glob)
56 | gt_dir = os.path.join(self.data_root, 'groundTruth', name)
57 | self.data = np.zeros((len(image_files), IMG_H, IMG_W, 3), dtype='uint8')
58 | self.label = np.zeros((len(image_files), IMG_H, IMG_W), dtype='float32')
59 |
60 | for idx, f in enumerate(image_files):
61 | im = cv2.imread(f, cv2.IMREAD_COLOR)
62 | assert im is not None
63 | if im.shape[0] > im.shape[1]:
64 | im = np.transpose(im, (1, 0, 2))
65 | assert im.shape[:2] == (IMG_H, IMG_W), "{} != {}".format(im.shape[:2], (IMG_H, IMG_W))
66 |
67 | imgid = os.path.basename(f).split('.')[0]
68 | gt_file = os.path.join(gt_dir, imgid)
69 | gt = loadmat(gt_file)['groundTruth'][0]
70 | n_annot = gt.shape[0]
71 | gt = sum(gt[k]['Boundaries'][0][0] for k in range(n_annot))
72 | gt = gt.astype('float32')
73 | gt *= 1.0 / n_annot
74 | if gt.shape[0] > gt.shape[1]:
75 | gt = gt.transpose()
76 | assert gt.shape == (IMG_H, IMG_W)
77 |
78 | self.data[idx] = im
79 | self.label[idx] = gt
80 |
81 | def __len__(self):
82 | return self.data.shape[0]
83 |
84 | def __iter__(self):
85 | idxs = np.arange(self.data.shape[0])
86 | if self.shuffle:
87 | self.rng.shuffle(idxs)
88 | for k in idxs:
89 | yield [self.data[k], self.label[k]]
90 |
91 |
92 | try:
93 | from scipy.io import loadmat
94 | import cv2
95 | except ImportError:
96 | from ...utils.develop import create_dummy_class
97 | BSDS500 = create_dummy_class('BSDS500', ['scipy.io', 'cv2']) # noqa
98 |
99 | if __name__ == '__main__':
100 | a = BSDS500('val')
101 | a.reset_state()
102 | for k in a:
103 | cv2.imshow("haha", k[1].astype('uint8') * 255)
104 | cv2.waitKey(1000)
105 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/dataflow/dataset/caltech101.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: caltech101.py
3 |
4 |
5 | import os
6 |
7 | from ...utils import logger
8 | from ...utils.fs import download, get_dataset_path
9 | from ..base import RNGDataFlow
10 |
11 | __all__ = ["Caltech101Silhouettes"]
12 |
13 |
14 | def maybe_download(url, work_directory):
15 | """Download the data from Marlin's website, unless it's already here."""
16 | filename = url.split("/")[-1]
17 | filepath = os.path.join(work_directory, filename)
18 | if not os.path.exists(filepath):
19 | logger.info("Downloading to {}...".format(filepath))
20 | download(url, work_directory)
21 | return filepath
22 |
23 |
24 | class Caltech101Silhouettes(RNGDataFlow):
25 | """
26 | Produces [image, label] pairs from the Caltech101 Silhouettes dataset;
27 | image is 28x28 in the range [0,1], label is an int in the range [0,100].
28 | """
29 |
30 | _DIR_NAME = "caltech101_data"
31 | _SOURCE_URL = "https://people.cs.umass.edu/~marlin/data/"
32 |
33 | def __init__(self, name, shuffle=True, dir=None):
34 | """
35 | Args:
36 | name (str): 'train', 'test', 'val'
37 | shuffle (bool): shuffle the dataset
38 | """
39 | if dir is None:
40 | dir = get_dataset_path(self._DIR_NAME)
41 | assert name in ['train', 'test', 'val']
42 | self.name = name
43 | self.shuffle = shuffle
44 |
45 | def get_images_and_labels(data_file):
46 | f = maybe_download(self._SOURCE_URL + data_file, dir)
47 | data = scipy.io.loadmat(f)
48 | return data
49 |
50 | self.data = get_images_and_labels("caltech101_silhouettes_28_split1.mat")
51 |
52 | if self.name == "train":
53 | self.images = self.data["train_data"].reshape((4100, 28, 28))
54 | self.labels = self.data["train_labels"].ravel() - 1
55 | elif self.name == "test":
56 | self.images = self.data["test_data"].reshape((2307, 28, 28))
57 | self.labels = self.data["test_labels"].ravel() - 1
58 | else:
59 | self.images = self.data["val_data"].reshape((2264, 28, 28))
60 | self.labels = self.data["val_labels"].ravel() - 1
61 |
62 | def __len__(self):
63 | return self.images.shape[0]
64 |
65 | def __iter__(self):
66 | idxs = list(range(self.__len__()))
67 | if self.shuffle:
68 | self.rng.shuffle(idxs)
69 | for k in idxs:
70 | img = self.images[k]
71 | label = self.labels[k]
72 | yield [img, label]
73 |
74 |
75 | try:
76 | import scipy.io
77 | except ImportError:
78 | from ...utils.develop import create_dummy_class
79 | Caltech101Silhouettes = create_dummy_class('Caltech101Silhouettes', 'scipy.io') # noqa
80 |
81 |
82 | if __name__ == "__main__":
83 | ds = Caltech101Silhouettes("train")
84 | ds.reset_state()
85 | for _ in ds:
86 | from IPython import embed
87 |
88 | embed()
89 | break
90 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/dataflow/dataset/svhn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: svhn.py
3 |
4 |
5 | import numpy as np
6 | import os
7 |
8 | from ...utils import logger
9 | from ...utils.fs import download, get_dataset_path
10 | from ..base import RNGDataFlow
11 |
12 | __all__ = ['SVHNDigit']
13 |
14 | SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
15 |
16 |
17 | class SVHNDigit(RNGDataFlow):
18 | """
19 | `SVHN `_ Cropped Digit Dataset.
20 | Produces [img, label]; img is 32x32x3 in range [0,255], label is an int in 0-9.
21 | """
22 | _Cache = {}
23 |
24 | def __init__(self, name, data_dir=None, shuffle=True):
25 | """
26 | Args:
27 | name (str): 'train', 'test', or 'extra'.
28 | data_dir (str): a directory containing the original {train,test,extra}_32x32.mat.
29 | shuffle (bool): shuffle the dataset.
30 | """
31 | self.shuffle = shuffle
32 |
33 | if name in SVHNDigit._Cache:
34 | self.X, self.Y = SVHNDigit._Cache[name]
35 | return
36 | if data_dir is None:
37 | data_dir = get_dataset_path('svhn_data')
38 | assert name in ['train', 'test', 'extra'], name
39 | filename = os.path.join(data_dir, name + '_32x32.mat')
40 | if not os.path.isfile(filename):
41 | url = SVHN_URL + os.path.basename(filename)
42 | logger.info("File {} not found!".format(filename))
43 | logger.info("Downloading from {} ...".format(url))
44 | download(url, os.path.dirname(filename))
45 | logger.info("Loading {} ...".format(filename))
46 | data = scipy.io.loadmat(filename)
47 | self.X = data['X'].transpose(3, 0, 1, 2)
48 | self.Y = data['y'].reshape((-1))
49 | self.Y[self.Y == 10] = 0
50 | SVHNDigit._Cache[name] = (self.X, self.Y)
51 |
52 | def __len__(self):
53 | return self.X.shape[0]
54 |
55 | def __iter__(self):
56 | n = self.X.shape[0]
57 | idxs = np.arange(n)
58 | if self.shuffle:
59 | self.rng.shuffle(idxs)
60 | for k in idxs:
61 | # since svhn is quite small, just do it for safety
62 | yield [self.X[k], self.Y[k]]
63 |
64 | @staticmethod
65 | def get_per_pixel_mean(names=('train', 'test', 'extra')):
66 | """
67 | Args:
68 | names (tuple[str]): names of the dataset split
69 |
70 | Returns:
71 | a 32x32x3 image, the mean of all images in the given datasets
72 | """
73 | for name in names:
74 | assert name in ['train', 'test', 'extra'], name
75 | images = [SVHNDigit(x).X for x in names]
76 | return np.concatenate(tuple(images)).mean(axis=0)
77 |
78 |
79 | try:
80 | import scipy.io
81 | except ImportError:
82 | from ...utils.develop import create_dummy_class
83 | SVHNDigit = create_dummy_class('SVHNDigit', 'scipy.io') # noqa
84 |
85 | if __name__ == '__main__':
86 | a = SVHNDigit('train')
87 | b = SVHNDigit.get_per_pixel_mean()
88 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/dataflow/imgaug/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: __init__.py
3 |
4 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36
5 | STATICA_HACK = True
6 | globals()['kcah_acitats'[::-1].upper()] = False
7 | if STATICA_HACK:
8 | from .base import *
9 | from .convert import *
10 | from .crop import *
11 | from .deform import *
12 | from .geometry import *
13 | from .imgproc import *
14 | from .meta import *
15 | from .misc import *
16 | from .noise import *
17 | from .paste import *
18 | from .transform import *
19 | from .external import *
20 |
21 |
22 | import os
23 | from pkgutil import iter_modules
24 |
25 | __all__ = []
26 |
27 |
28 | def global_import(name):
29 | p = __import__(name, globals(), locals(), level=1)
30 | lst = p.__all__ if '__all__' in dir(p) else dir(p)
31 | if lst:
32 | del globals()[name]
33 | for k in lst:
34 | if not k.startswith('__'):
35 | globals()[k] = p.__dict__[k]
36 | __all__.append(k)
37 |
38 |
39 | try:
40 | import cv2 # noqa
41 | except ImportError:
42 | from ...utils import logger
43 | logger.warn("Cannot import 'cv2', therefore image augmentation is not available.")
44 | else:
45 | _CURR_DIR = os.path.dirname(__file__)
46 | for _, module_name, _ in iter_modules(
47 | [os.path.dirname(__file__)]):
48 | srcpath = os.path.join(_CURR_DIR, module_name + '.py')
49 | if not os.path.isfile(srcpath):
50 | continue
51 | if not module_name.startswith('_') and "_test" not in module_name:
52 | global_import(module_name)
53 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/dataflow/imgaug/convert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: convert.py
3 |
4 | import numpy as np
5 | import cv2
6 |
7 | from .base import PhotometricAugmentor
8 |
9 | __all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32']
10 |
11 |
12 | class ColorSpace(PhotometricAugmentor):
13 | """ Convert into another color space. """
14 |
15 | def __init__(self, mode, keepdims=True):
16 | """
17 | Args:
18 | mode: OpenCV color space conversion code (e.g., ``cv2.COLOR_BGR2HSV``)
19 | keepdims (bool): keep the dimension of image unchanged if OpenCV
20 | changes it.
21 | """
22 | super(ColorSpace, self).__init__()
23 | self._init(locals())
24 |
25 | def _augment(self, img, _):
26 | transf = cv2.cvtColor(img, self.mode)
27 | if self.keepdims:
28 | if len(transf.shape) != len(img.shape):
29 | transf = transf[..., None]
30 | return transf
31 |
32 |
33 | class Grayscale(ColorSpace):
34 | """ Convert image to grayscale. """
35 |
36 | def __init__(self, keepdims=True, rgb=False):
37 | """
38 | Args:
39 | keepdims (bool): return image of shape [H, W, 1] instead of [H, W]
40 | rgb (bool): interpret input as RGB instead of the default BGR
41 | """
42 | mode = cv2.COLOR_RGB2GRAY if rgb else cv2.COLOR_BGR2GRAY
43 | super(Grayscale, self).__init__(mode, keepdims)
44 |
45 |
46 | class ToUint8(PhotometricAugmentor):
47 | """ Convert image to uint8. Useful to reduce communication overhead. """
48 | def _augment(self, img, _):
49 | return np.clip(img, 0, 255).astype(np.uint8)
50 |
51 |
52 | class ToFloat32(PhotometricAugmentor):
53 | """ Convert image to float32, may increase quality of the augmentor. """
54 | def _augment(self, img, _):
55 | return img.astype(np.float32)
56 |
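57 | if __name__ == '__main__':
58 | # A quick sketch of what ColorSpace/Grayscale compute, using cv2 directly
59 | # on a random example image (BGR, uint8).
60 | img = np.random.randint(0, 256, size=(32, 32, 3)).astype(np.uint8)
61 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
62 | if gray.ndim != img.ndim: # the keepdims=True branch above
63 | gray = gray[..., None]
64 | print(gray.shape) # (32, 32, 1)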
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/dataflow/imgaug/external.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import numpy as np
4 |
5 | from .base import ImageAugmentor
6 | from .transform import Transform
7 |
8 | __all__ = ['IAAugmentor', 'Albumentations']
9 |
10 |
11 | class IAATransform(Transform):
12 | def __init__(self, aug, img_shape):
13 | self._init(locals())
14 |
15 | def apply_image(self, img):
16 | return self.aug.augment_image(img)
17 |
18 | def apply_coords(self, coords):
19 | import imgaug as IA
20 | points = [IA.Keypoint(x=x, y=y) for x, y in coords]
21 | points = IA.KeypointsOnImage(points, shape=self.img_shape)
22 | augmented = self.aug.augment_keypoints([points])[0].keypoints
23 | return np.asarray([[p.x, p.y] for p in augmented])
24 |
25 |
26 | class IAAugmentor(ImageAugmentor):
27 | """
28 | Wrap an augmentor from the IAA library: https://github.com/aleju/imgaug.
29 | Both images and coordinates are supported.
30 |
31 | Note:
32 | 1. It's NOT RECOMMENDED
33 | to use coordinates because the IAA library does not handle coordinates accurately.
34 |
35 | 2. Only uint8 images are supported by the IAA library.
36 |
37 | 3. The IAA library can only produce images of the same shape.
38 |
39 | Example:
40 |
41 | .. code-block:: python
42 |
43 | from imgaug import augmenters as iaa # this is the aleju/imgaug library
44 | from tensorpack import imgaug # this is not the aleju/imgaug library
45 | # or from dataflow import imgaug # if you're using the standalone version of dataflow
46 | myaug = imgaug.IAAugmentor(
47 | iaa.Sequential([
48 | iaa.Sharpen(alpha=(0, 1), lightness=(0.75, 1.5)),
49 | iaa.Fliplr(0.5),
50 | iaa.Crop(px=(0, 100)),
51 | ]))
52 | """
53 |
54 | def __init__(self, augmentor):
55 | """
56 | Args:
57 | augmentor (iaa.Augmenter):
58 | """
59 | super(IAAugmentor, self).__init__()
60 | self._aug = augmentor
61 |
62 | def get_transform(self, img):
63 | return IAATransform(self._aug.to_deterministic(), img.shape)
64 |
65 |
66 | class AlbumentationsTransform(Transform):
67 | def __init__(self, aug, param):
68 | self._init(locals())
69 |
70 | def apply_image(self, img):
71 | return self.aug.apply(img, **self.param)
72 |
73 |
74 | class Albumentations(ImageAugmentor):
75 | """
76 | Wrap an augmentor from the albumentations library: https://github.com/albu/albumentations.
77 | Coordinate augmentation is not supported by the library.
78 |
79 | Example:
80 |
81 | .. code-block:: python
82 |
83 | from tensorpack import imgaug
84 | # or from dataflow import imgaug # if you're using the standalone version of dataflow
85 | import albumentations as AB
86 | myaug = imgaug.Albumentations(AB.RandomRotate90(p=1))
87 | """
88 | def __init__(self, augmentor):
89 | """
90 | Args:
91 | augmentor (albumentations.BasicTransform):
92 | """
93 | super(Albumentations, self).__init__()
94 | self._aug = augmentor
95 |
96 | def get_transform(self, img):
97 | return AlbumentationsTransform(self._aug, self._aug.get_params())
98 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/dataflow/imgaug/noise.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: noise.py
3 |
4 |
5 | import numpy as np
6 | import cv2
7 |
8 | from .base import PhotometricAugmentor
9 |
10 | __all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
11 |
12 |
13 | class JpegNoise(PhotometricAugmentor):
14 | """ Random JPEG noise. """
15 |
16 | def __init__(self, quality_range=(40, 100)):
17 | """
18 | Args:
19 | quality_range (tuple): range to sample JPEG quality
20 | """
21 | super(JpegNoise, self).__init__()
22 | self._init(locals())
23 |
24 | def _get_augment_params(self, img):
25 | return self.rng.randint(*self.quality_range)
26 |
27 | def _augment(self, img, q):
28 | enc = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, q])[1]
29 | return cv2.imdecode(enc, 1).astype(img.dtype)
30 |
31 |
32 | class GaussianNoise(PhotometricAugmentor):
33 | """
34 | Add random Gaussian noise N(0, sigma^2) of the same shape to img.
35 | """
36 | def __init__(self, sigma=1, clip=True):
37 | """
38 | Args:
39 | sigma (float): stddev of the Gaussian distribution.
40 | clip (bool): clip the result to [0,255] in the end.
41 | """
42 | super(GaussianNoise, self).__init__()
43 | self._init(locals())
44 |
45 | def _get_augment_params(self, img):
46 | return self.rng.randn(*img.shape)
47 |
48 | def _augment(self, img, noise):
49 | old_dtype = img.dtype
50 | ret = img + noise * self.sigma
51 | if self.clip or old_dtype == np.uint8:
52 | ret = np.clip(ret, 0, 255)
53 | return ret.astype(old_dtype)
54 |
55 |
56 | class SaltPepperNoise(PhotometricAugmentor):
57 | """ Salt and pepper noise.
58 | Randomly set some elements of the image to 0 or 255, regardless of channel.
59 | """
60 |
61 | def __init__(self, white_prob=0.05, black_prob=0.05):
62 | """
63 | Args:
64 | white_prob (float), black_prob (float): probabilities of setting an element to 255 or 0, respectively.
65 | """
66 | assert white_prob + black_prob <= 1, "Sum of probabilities cannot be greater than 1"
67 | super(SaltPepperNoise, self).__init__()
68 | self._init(locals())
69 |
70 | def _get_augment_params(self, img):
71 | return self.rng.uniform(low=0, high=1, size=img.shape)
72 |
73 | def _augment(self, img, param):
74 | img[param > (1 - self.white_prob)] = 255
75 | img[param < self.black_prob] = 0
76 | return img
77 |
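78 | if __name__ == '__main__':
79 | # A NumPy-only sketch of the salt-and-pepper logic above: one uniform sample
80 | # per element, thresholded at both ends. The probabilities are the defaults
81 | # of SaltPepperNoise; the image is a random example.
82 | rng = np.random.RandomState(0)
83 | img = rng.randint(0, 256, size=(4, 4, 3)).astype(np.uint8)
84 | param = rng.uniform(low=0, high=1, size=img.shape)
85 | white_prob, black_prob = 0.05, 0.05
86 | img[param > (1 - white_prob)] = 255 # salt
87 | img[param < black_prob] = 0 # pepper
88 | print(img)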
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/dataflow/imgaug/paste.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: paste.py
3 |
4 |
5 | import numpy as np
6 | from abc import abstractmethod
7 |
8 | from .base import ImageAugmentor
9 | from .transform import TransformFactory
10 |
11 | __all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
12 | 'RandomPaste']
13 |
14 |
15 | class BackgroundFiller(object):
16 | """ Base class for all BackgroundFiller"""
17 |
18 | def fill(self, background_shape, img):
19 | """
20 | Return a proper background image of background_shape, given img.
21 |
22 | Args:
23 | background_shape (tuple): a shape (h, w)
24 | img: an image
25 | Returns:
26 | a background image
27 | """
28 | background_shape = tuple(background_shape)
29 | return self._fill(background_shape, img)
30 |
31 | @abstractmethod
32 | def _fill(self, background_shape, img):
33 | pass
34 |
35 |
36 | class ConstantBackgroundFiller(BackgroundFiller):
37 | """ Fill the background by a constant """
38 |
39 | def __init__(self, value):
40 | """
41 | Args:
42 | value (float): the value to fill the background.
43 | """
44 | self.value = value
45 |
46 | def _fill(self, background_shape, img):
47 | assert img.ndim in [3, 2]
48 | if img.ndim == 3:
49 | return_shape = background_shape + (img.shape[2],)
50 | else:
51 | return_shape = background_shape
52 | return np.zeros(return_shape, dtype=img.dtype) + self.value
53 |
54 |
55 | # NOTE:
56 | # apply_coords should be implemented in the paste transform, but is not yet done
57 |
58 |
59 | class CenterPaste(ImageAugmentor):
60 | """
61 | Paste the image onto the center of a background canvas.
62 | """
63 |
64 | def __init__(self, background_shape, background_filler=None):
65 | """
66 | Args:
67 | background_shape (tuple): shape of the background canvas.
68 | background_filler (BackgroundFiller): How to fill the background. Defaults to zero-filler.
69 | """
70 | if background_filler is None:
71 | background_filler = ConstantBackgroundFiller(0)
72 |
73 | self._init(locals())
74 |
75 | def get_transform(self, _):
76 | return TransformFactory(name=str(self), apply_image=lambda img: self._impl(img))
77 |
78 | def _impl(self, img):
79 | img_shape = img.shape[:2]
80 | assert self.background_shape[0] >= img_shape[0] and self.background_shape[1] >= img_shape[1]
81 |
82 | background = self.background_filler.fill(
83 | self.background_shape, img)
84 | y0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
85 | x0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
86 | background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
87 | return background
88 |
89 |
90 | class RandomPaste(CenterPaste):
91 | """
92 | Randomly paste the image onto a background canvas.
93 | """
94 |
95 | def get_transform(self, img):
96 | img_shape = img.shape[:2]
97 | assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
98 |
99 | y0 = self._rand_range(self.background_shape[0] - img_shape[0])
100 | x0 = self._rand_range(self.background_shape[1] - img_shape[1])
101 | loc = int(x0), int(y0)
102 | return TransformFactory(name=str(self), apply_image=lambda img: self._impl(img, loc))
103 |
104 | def _impl(self, img, loc):
105 | x0, y0 = loc
106 | img_shape = img.shape[:2]
107 | background = self.background_filler.fill(
108 | self.background_shape, img)
109 | background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
110 | return background
111 |
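112 | if __name__ == '__main__':
113 | # A NumPy-only sketch of the center-paste arithmetic above: an arbitrary
114 | # 10x12 image pasted onto a zero-filled 20x30 canvas.
115 | img = np.ones((10, 12), dtype=np.uint8)
116 | bg = np.zeros((20, 30), dtype=img.dtype) # ConstantBackgroundFiller(0)
117 | y0 = int((bg.shape[0] - img.shape[0]) * 0.5)
118 | x0 = int((bg.shape[1] - img.shape[1]) * 0.5)
119 | bg[y0:y0 + img.shape[0], x0:x0 + img.shape[1]] = img
120 | print(bg.sum() == img.sum()) # True: the image landed intact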
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/dataflow/serialize_test.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import numpy as np
5 | import os
6 | import unittest
7 |
8 | from tensorpack.dataflow import HDF5Serializer, LMDBSerializer, NumpySerializer, TFRecordSerializer
9 | from tensorpack.dataflow.base import DataFlow
10 |
11 |
12 | def delete_file_if_exists(fn):
13 | try:
14 | os.remove(fn)
15 | except OSError:
16 | pass
17 |
18 |
19 | class SeededFakeDataFlow(DataFlow):
20 | """docstring for SeededFakeDataFlow"""
21 |
22 | def __init__(self, seed=42, size=32):
23 | super(SeededFakeDataFlow, self).__init__()
24 | self.seed = seed
25 | self._size = size
26 | self.cache = []
27 |
28 | def reset_state(self):
29 | np.random.seed(self.seed)
30 | for _ in range(self._size):
31 | label = np.random.randint(low=0, high=10)
32 | img = np.random.randn(28, 28, 3)
33 | self.cache.append([label, img])
34 |
35 | def __len__(self):
36 | return self._size
37 |
38 | def __iter__(self):
39 | for dp in self.cache:
40 | yield dp
41 |
42 |
43 | class SerializerTest(unittest.TestCase):
44 |
45 | def run_write_read_test(self, file, serializer, w_args, w_kwargs, r_args, r_kwargs, error_msg):
46 | try:
47 | delete_file_if_exists(file)
48 |
49 | ds_expected = SeededFakeDataFlow()
50 | serializer.save(ds_expected, file, *w_args, **w_kwargs)
51 | ds_actual = serializer.load(file, *r_args, **r_kwargs)
52 |
53 | ds_actual.reset_state()
54 | ds_expected.reset_state()
55 |
56 | for dp_expected, dp_actual in zip(ds_expected.__iter__(), ds_actual.__iter__()):
57 | self.assertEqual(dp_expected[0], dp_actual[0])
58 | self.assertTrue(np.allclose(dp_expected[1], dp_actual[1]))
59 | except ImportError:
60 | print(error_msg)
61 |
62 | def test_lmdb(self):
63 | self.run_write_read_test('test.lmdb', LMDBSerializer,
64 | {}, {},
65 | {}, {'shuffle': False},
66 | 'Skip test_lmdb, no lmdb available')
67 |
68 | def test_tfrecord(self):
69 | self.run_write_read_test('test.tfrecord', TFRecordSerializer,
70 | {}, {},
71 | {}, {'size': 32},
72 | 'Skip test_tfrecord, no tensorflow available')
73 |
74 | def test_numpy(self):
75 | self.run_write_read_test('test.npz', NumpySerializer,
76 | {}, {},
77 | {}, {'shuffle': False},
78 | 'Skip test_numpy, no numpy available')
79 |
80 | def test_hdf5(self):
81 | args = [['label', 'image']]
82 | self.run_write_read_test('test.h5', HDF5Serializer,
83 | args, {},
84 | args, {'shuffle': False},
85 | 'Skip test_hdf5, no h5py available')
86 |
87 |
88 | if __name__ == '__main__':
89 | unittest.main()
90 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/graph_builder/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: __init__.py
3 |
4 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36
5 | STATICA_HACK = True
6 | globals()['kcah_acitats'[::-1].upper()] = False
7 | if STATICA_HACK:
8 | from .model_desc import *
9 | from .training import *
10 | from .distributed import *
11 | from .utils import *
12 |
13 | from .model_desc import InputDesc, ModelDesc, ModelDescBase # kept for BC # noqa
14 |
15 | from pkgutil import iter_modules
16 | import os
17 | import os.path
18 |
19 | __all__ = []
20 |
21 | def global_import(name):
22 | p = __import__(name, globals(), locals(), level=1)
23 | lst = p.__all__ if '__all__' in dir(p) else []
24 | del globals()[name]
25 | for k in lst:
26 | if not k.startswith('__'):
27 | globals()[k] = p.__dict__[k]
28 | __all__.append(k)
29 |
30 |
31 | _CURR_DIR = os.path.dirname(__file__)
32 | _SKIP = ['distributed']
33 | for _, module_name, _ in iter_modules(
34 | [_CURR_DIR]):
35 | srcpath = os.path.join(_CURR_DIR, module_name + '.py')
36 | if not os.path.isfile(srcpath):
37 | continue
38 | if module_name.startswith('_'):
39 | continue
40 | if module_name not in _SKIP:
41 | global_import(module_name)
42 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/graph_builder/model_desc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: model_desc.py
3 |
4 |
5 | from collections import namedtuple
6 | import tensorflow as tf
7 |
8 | from ..utils.develop import log_deprecated
9 | from ..train.model_desc import ModelDesc, ModelDescBase # kept for BC # noqa
10 |
11 |
12 | __all__ = ['InputDesc']
13 |
14 |
15 | class InputDesc(
16 | namedtuple('InputDescTuple', ['type', 'shape', 'name'])):
17 | """
18 | An equivalent of `tf.TensorSpec`.
19 |
20 | History: this concept is used to represent metadata about the inputs,
21 | which can be later used to build placeholders or other types of input source.
22 | It is introduced much much earlier than the equivalent concept `tf.TensorSpec`
23 | was introduced in TensorFlow.
24 | Therefore, we now switched to use `tf.TensorSpec`, but keep this here for compatibility reasons.
25 | """
26 |
27 | def __new__(cls, type, shape, name):
28 | """
29 | Args:
30 | type (tf.DType):
31 | shape (tuple):
32 | name (str):
33 | """
34 | log_deprecated("InputDesc", "Use tf.TensorSpec instead!", "2020-03-01")
35 | assert isinstance(type, tf.DType), type
36 | return tf.TensorSpec(shape=shape, dtype=type, name=name)
37 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/input_source/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: __init__.py
3 |
4 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36
5 | STATICA_HACK = True
6 | globals()['kcah_acitats'[::-1].upper()] = False
7 | if STATICA_HACK:
8 | from .input_source_base import *
9 | from .input_source import *
10 |
11 | from pkgutil import iter_modules
12 | import os
13 | import os.path
14 |
15 | __all__ = []
16 |
17 |
18 | def global_import(name):
19 | p = __import__(name, globals(), locals(), level=1)
20 | lst = p.__all__ if '__all__' in dir(p) else []
21 | del globals()[name]
22 | for k in lst:
23 | if not k.startswith('__'):
24 | globals()[k] = p.__dict__[k]
25 | __all__.append(k)
26 |
27 |
28 | _CURR_DIR = os.path.dirname(__file__)
29 | _SKIP = []
30 | for _, module_name, _ in iter_modules(
31 | [_CURR_DIR]):
32 | srcpath = os.path.join(_CURR_DIR, module_name + '.py')
33 | if not os.path.isfile(srcpath):
34 | continue
35 | if module_name.startswith('_'):
36 | continue
37 | if module_name not in _SKIP:
38 | global_import(module_name)
39 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/libinfo.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 |
4 | # issue#7378 may happen with custom opencv. It doesn't hurt to disable opencl
5 | os.environ['OPENCV_OPENCL_RUNTIME'] = 'disabled' # https://github.com/opencv/opencv/pull/10155
6 | try:
7 | # issue#1924 may happen on old systems
8 | import cv2 # noqa
9 | # cv2.setNumThreads(0)
10 | if int(cv2.__version__.split('.')[0]) >= 3:
11 | cv2.ocl.setUseOpenCL(False)
12 | # check if cv is built with cuda or openmp
13 | info = cv2.getBuildInformation().split('\n')
14 | for line in info:
15 | splits = line.split()
16 | if not len(splits):
17 | continue
18 | answer = splits[-1].lower()
19 | if answer in ['yes', 'no']:
20 | if 'cuda' in line.lower() and answer == 'yes':
21 | # issue#1197
22 | print("OpenCV is built with CUDA support. "
23 | "This may cause slow initialization or sometimes segfault with TensorFlow.")
24 | if answer == 'openmp':
25 | print("OpenCV is built with OpenMP support. This usually results in poor performance. For details, see "
26 | "https://github.com/tensorpack/benchmarks/blob/master/ImageNet/benchmark-opencv-resize.py")
27 | except (ImportError, TypeError):
28 | pass
29 |
30 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' # issue#9339
31 | os.environ['TF_AUTOTUNE_THRESHOLD'] = '2' # use more warm-up
32 |
33 | # Since 1.3, this is not needed
34 | os.environ['TF_AVGPOOL_USE_CUDNN'] = '1' # issue#8566
35 |
36 | # TF1.5 features
37 | os.environ['TF_SYNC_ON_FINISH'] = '0' # will become default
38 | os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
39 | os.environ['TF_GPU_THREAD_COUNT'] = '2'
40 |
41 | # Available in TF1.6+ & cudnn7. Haven't seen different performance on R50.
42 | # NOTE we disable it because:
43 | # this mode may use scaled atomic integer reduction that may cause a numerical
44 | # overflow for certain input data range.
45 | os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '0'
46 |
47 | # Available since 1.12. issue#15874
48 | # But they're sometimes buggy. We leave this decision to users.
49 | # os.environ['TF_ENABLE_WHILE_V2'] = '1'
50 | # os.environ['TF_ENABLE_COND_V2'] = '1'
51 |
52 | try:
53 | import tensorflow as tf # noqa
54 | _version = tf.__version__.split('.')
55 | assert (int(_version[0]), int(_version[1])) >= (1, 3), "TF>=1.3 is required!"
56 | _HAS_TF = True
57 | except ImportError:
58 | print("Failed to import tensorflow.")
59 | _HAS_TF = False
60 | else:
61 | # Install stacktrace handler
62 | try:
63 | from tensorflow.python.framework import test_util
64 | test_util.InstallStackTraceHandler()
65 | except Exception:
66 | pass
67 |
68 | # silence the massive deprecation warnings in TF 1.13+
69 | if (int(_version[0]), int(_version[1])) >= (1, 13):
70 | try:
71 | from tensorflow.python.util.deprecation import silence
72 | except Exception:
73 | pass
74 | else:
75 | silence().__enter__()
76 | try:
77 | from tensorflow.python.util import deprecation_wrapper
78 | deprecation_wrapper._PER_MODULE_WARNING_LIMIT = 0
79 | except Exception:
80 | pass
81 |
82 | # Monkey-patch tf.test.is_gpu_available to avoid side effects:
83 | # https://github.com/tensorflow/tensorflow/issues/26460
84 | try:
85 | list_dev = tf.config.experimental.list_physical_devices
86 | except AttributeError:
87 | pass
88 | else:
89 | old_is_gpu_available = tf.test.is_gpu_available
90 |
91 | def is_gpu_available(*args, **kwargs):
92 | if len(args) == 0 and len(kwargs) == 0:
93 | return len(list_dev('GPU')) > 0
94 | return old_is_gpu_available(*args, **kwargs)
95 |
96 | tf.test.is_gpu_available = is_gpu_available
97 |
98 |
99 | # These lines will be programmatically read/written by setup.py
100 | # Don't touch them.
101 | __version__ = '0.9.8'
102 | __git_version__ = "v0.9.8-61-g4ac2e22b-dirty"
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: __init__.py
3 |
4 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36
5 | STATICA_HACK = True
6 | globals()['kcah_acitats'[::-1].upper()] = False
7 | if STATICA_HACK:
8 | from .batch_norm import *
9 | from .common import *
10 | from .conv2d import *
11 | from .fc import *
12 | from .layer_norm import *
13 | from .linearwrap import *
14 | from .nonlin import *
15 | from .pool import *
16 | from .regularize import *
17 |
18 |
19 | from pkgutil import iter_modules
20 | import os
21 | import os.path
22 | # this line is necessary for _TFModuleFunc to work
23 | import tensorflow as tf # noqa: F401
24 |
25 | __all__ = []
26 |
27 |
28 | def _global_import(name):
29 | p = __import__(name, globals(), locals(), level=1)
30 | lst = p.__all__ if '__all__' in dir(p) else dir(p)
31 | del globals()[name]
32 | for k in lst:
33 | if not k.startswith('__'):
34 | globals()[k] = p.__dict__[k]
35 | __all__.append(k)
36 |
37 |
38 | _CURR_DIR = os.path.dirname(__file__)
39 | _SKIP = ['utils', 'registry', 'tflayer']
40 | for _, module_name, _ in iter_modules(
41 | [_CURR_DIR]):
42 | srcpath = os.path.join(_CURR_DIR, module_name + '.py')
43 | if not os.path.isfile(srcpath):
44 | continue
45 | if module_name.startswith('_'):
46 | continue
47 | if "_test" in module_name:
48 | continue
49 | if module_name not in _SKIP:
50 | _global_import(module_name)
51 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/models/common.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: common.py
3 |
4 | from .registry import layer_register, disable_layer_logging # noqa
5 | from .tflayer import rename_tflayer_get_variable
6 | from .utils import VariableHolder # noqa
7 |
8 | __all__ = ['layer_register', 'VariableHolder', 'rename_tflayer_get_variable',
9 | 'disable_layer_logging']
10 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/models/fc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: fc.py
3 |
4 |
5 | import numpy as np
6 | from ..compat import tfv1 as tf # this should be avoided first in model code
7 |
8 | from ..tfutils.common import get_tf_version_tuple
9 | from .common import VariableHolder, layer_register
10 | from .tflayer import convert_to_tflayer_args, rename_get_variable
11 |
12 | __all__ = ['FullyConnected']
13 |
14 |
15 | def batch_flatten(x):
16 | """
17 | Flatten all dimensions of the tensor except the first (batch) dimension.
18 | """
19 | shape = x.get_shape().as_list()[1:]
20 | if None not in shape:
21 | return tf.reshape(x, [-1, int(np.prod(shape))])
22 | return tf.reshape(x, tf.stack([tf.shape(x)[0], -1]))
23 |
24 |
25 | @layer_register(log_shape=True)
26 | @convert_to_tflayer_args(
27 | args_names=['units'],
28 | name_mapping={'out_dim': 'units'})
29 | def FullyConnected(
30 | inputs,
31 | units,
32 | activation=None,
33 | use_bias=True,
34 | kernel_initializer=None,
35 | bias_initializer=tf.zeros_initializer(),
36 | kernel_regularizer=None,
37 | bias_regularizer=None,
38 | activity_regularizer=None):
39 | """
40 | A wrapper around `tf.layers.Dense`.
41 | One difference to maintain backward-compatibility:
42 | Default weight initializer is variance_scaling_initializer(2.0).
43 |
44 | Variable Names:
45 |
46 | * ``W``: weights of shape [in_dim, out_dim]
47 | * ``b``: bias
48 | """
49 | if kernel_initializer is None:
50 | if get_tf_version_tuple() <= (1, 12):
51 | kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0) # deprecated
52 | else:
53 | kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
54 |
55 | inputs = batch_flatten(inputs)
56 | with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
57 | layer = tf.layers.Dense(
58 | units=units,
59 | activation=activation,
60 | use_bias=use_bias,
61 | kernel_initializer=kernel_initializer,
62 | bias_initializer=bias_initializer,
63 | kernel_regularizer=kernel_regularizer,
64 | bias_regularizer=bias_regularizer,
65 | activity_regularizer=activity_regularizer,
66 | _reuse=tf.get_variable_scope().reuse)
67 | ret = layer.apply(inputs, scope=tf.get_variable_scope())
68 | ret = tf.identity(ret, name='output')
69 |
70 | ret.variables = VariableHolder(W=layer.kernel)
71 | if use_bias:
72 | ret.variables.b = layer.bias
73 | return ret
74 |
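75 | if __name__ == '__main__':
76 | # A NumPy sketch of the shape arithmetic batch_flatten performs: all
77 | # dimensions except the first are collapsed, which is why FullyConnected
78 | # accepts feature maps directly. The shapes are arbitrary examples.
79 | x = np.zeros((8, 7, 7, 64))
80 | flat = x.reshape(x.shape[0], -1)
81 | assert flat.shape == (8, 7 * 7 * 64)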
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/models/layer_norm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: layer_norm.py
3 |
4 |
5 | from ..compat import tfv1 as tf # this should be avoided first in model code
6 |
7 | from ..utils.argtools import get_data_format
8 | from .common import VariableHolder, layer_register
9 |
10 | __all__ = ['LayerNorm', 'InstanceNorm']
11 |
12 |
13 | @layer_register()
14 | def LayerNorm(
15 | x, epsilon=1e-5,
16 | use_bias=True, use_scale=True,
17 | gamma_init=None, data_format='channels_last'):
18 | """
19 | Layer Normalization layer, as described in the paper:
20 | `Layer Normalization `_.
21 |
22 | Args:
23 | x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
24 | epsilon (float): epsilon to avoid divide-by-zero.
25 | use_scale, use_bias (bool): whether to use the extra affine transformation or not.
26 | """
27 | data_format = get_data_format(data_format, keras_mode=False)
28 | shape = x.get_shape().as_list()
29 | ndims = len(shape)
30 | assert ndims in [2, 4]
31 |
32 | mean, var = tf.nn.moments(x, list(range(1, len(shape))), keep_dims=True)
33 |
34 | if data_format == 'NCHW':
35 | chan = shape[1]
36 | new_shape = [1, chan, 1, 1]
37 | else:
38 | chan = shape[-1]
39 | new_shape = [1, 1, 1, chan]
40 | if ndims == 2:
41 | new_shape = [1, chan]
42 |
43 | if use_bias:
44 | beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
45 | beta = tf.reshape(beta, new_shape)
46 | else:
47 | beta = tf.zeros([1] * ndims, name='beta')
48 | if use_scale:
49 | if gamma_init is None:
50 | gamma_init = tf.constant_initializer(1.0)
51 | gamma = tf.get_variable('gamma', [chan], initializer=gamma_init)
52 | gamma = tf.reshape(gamma, new_shape)
53 | else:
54 | gamma = tf.ones([1] * ndims, name='gamma')
55 |
56 | ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
57 |
58 | vh = ret.variables = VariableHolder()
59 | if use_scale:
60 | vh.gamma = gamma
61 | if use_bias:
62 | vh.beta = beta
63 | return ret
64 |
65 |
66 | @layer_register()
67 | def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='channels_last'):
68 | """
69 | Instance Normalization, as in the paper:
70 | `Instance Normalization: The Missing Ingredient for Fast Stylization
71 | `_.
72 |
73 | Args:
74 | x (tf.Tensor): a 4D tensor.
75 | epsilon (float): avoid divide-by-zero
76 | use_affine (bool): whether to apply learnable affine transformation
77 | """
78 | data_format = get_data_format(data_format, keras_mode=False)
79 | shape = x.get_shape().as_list()
80 | assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"
81 |
82 | if data_format == 'NHWC':
83 | axis = [1, 2]
84 | ch = shape[3]
85 | new_shape = [1, 1, 1, ch]
86 | else:
87 | axis = [2, 3]
88 | ch = shape[1]
89 | new_shape = [1, ch, 1, 1]
90 |     assert ch is not None, "Input of InstanceNorm requires a known number of channels!"
91 |
92 | mean, var = tf.nn.moments(x, axis, keep_dims=True)
93 |
94 | if not use_affine:
95 | return tf.divide(x - mean, tf.sqrt(var + epsilon), name='output')
96 |
97 | beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
98 | beta = tf.reshape(beta, new_shape)
99 | if gamma_init is None:
100 | gamma_init = tf.constant_initializer(1.0)
101 | gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
102 | gamma = tf.reshape(gamma, new_shape)
103 | ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
104 |
105 | vh = ret.variables = VariableHolder()
106 | if use_affine:
107 | vh.gamma = gamma
108 | vh.beta = beta
109 | return ret
110 |
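A numpy sketch of the normalization math implemented above (channels_last layout assumed), useful for checking which axes the two layers reduce over:

import numpy as np

x = np.random.rand(2, 4, 4, 3).astype("float32")        # NHWC
# LayerNorm: moments over every axis except the batch axis
mean = x.mean(axis=(1, 2, 3), keepdims=True)
var = x.var(axis=(1, 2, 3), keepdims=True)
y_ln = (x - mean) / np.sqrt(var + 1e-5)                 # beta=0, gamma=1
# InstanceNorm: moments over the spatial axes only, per sample and channel
mean_in = x.mean(axis=(1, 2), keepdims=True)
var_in = x.var(axis=(1, 2), keepdims=True)
y_in = (x - mean_in) / np.sqrt(var_in + 1e-5)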
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/models/nonlin.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: nonlin.py
3 |
4 |
5 | import tensorflow as tf
6 |
7 | from ..utils.develop import log_deprecated
8 | from ..compat import tfv1
9 | from .batch_norm import BatchNorm
10 | from .common import VariableHolder, layer_register
11 | from .utils import disable_autograph
12 |
13 | __all__ = ['Maxout', 'PReLU', 'BNReLU']
14 |
15 |
16 | @layer_register(use_scope=None)
17 | def Maxout(x, num_unit):
18 | """
19 | Maxout as in the paper `Maxout Networks `_.
20 |
21 | Args:
22 | x (tf.Tensor): a NHWC or NC tensor. Channel has to be known.
23 |         num_unit (int): an int. C must be divisible by num_unit.
24 |
25 | Returns:
26 | tf.Tensor: of shape NHW(C/num_unit) named ``output``.
27 | """
28 | input_shape = x.get_shape().as_list()
29 | ndim = len(input_shape)
30 | assert ndim == 4 or ndim == 2
31 | ch = input_shape[-1]
32 | assert ch is not None and ch % num_unit == 0
33 | if ndim == 4:
34 |         x = tf.reshape(x, [-1, input_shape[1], input_shape[2], ch // num_unit, num_unit])  # integer division for Python 3
35 | else:
36 |         x = tf.reshape(x, [-1, ch // num_unit, num_unit])
37 | return tf.reduce_max(x, ndim, name='output')
38 |
39 |
40 | @layer_register()
41 | @disable_autograph()
42 | def PReLU(x, init=0.001, name=None):
43 | """
44 | Parameterized ReLU as in the paper `Delving Deep into Rectifiers: Surpassing
45 | Human-Level Performance on ImageNet Classification
46 | `_.
47 |
48 | Args:
49 | x (tf.Tensor): input
50 | init (float): initial value for the learnable slope.
51 | name (str): deprecated argument. Don't use
52 |
53 | Variable Names:
54 |
55 | * ``alpha``: learnable slope.
56 | """
57 | if name is not None:
58 | log_deprecated("PReLU(name=...)", "The output tensor will be named `output`.")
59 | init = tfv1.constant_initializer(init)
60 | alpha = tfv1.get_variable('alpha', [], initializer=init)
61 | x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
62 | ret = tf.multiply(x, 0.5, name=name or None)
63 |
64 | ret.variables = VariableHolder(alpha=alpha)
65 | return ret
66 |
67 |
68 | @layer_register(use_scope=None)
69 | def BNReLU(x, name=None):
70 | """
71 | A shorthand of BatchNormalization + ReLU.
72 |
73 | Args:
74 | x (tf.Tensor): the input
75 | name: deprecated, don't use.
76 | """
77 | if name is not None:
78 | log_deprecated("BNReLU(name=...)", "The output tensor will be named `output`.")
79 |
80 | x = BatchNorm('bn', x)
81 | x = tf.nn.relu(x, name=name)
82 | return x
83 |
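Two quick numpy checks of the formulas above. The PReLU expression 0.5 * ((1 + alpha) * x + (1 - alpha) * |x|) is an algebraic rewrite of the usual piecewise definition, and Maxout reduces over groups of `num_unit` channels:

import numpy as np

alpha = 0.25
x = np.linspace(-2, 2, 9)
prelu = 0.5 * ((1 + alpha) * x + (1 - alpha) * np.abs(x))
assert np.allclose(prelu, np.where(x >= 0, x, alpha * x))

num_unit = 2
feat = np.random.rand(8, 6)                                   # NC, C divisible by num_unit
out = feat.reshape(8, 6 // num_unit, num_unit).max(axis=-1)   # shape (8, 3)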
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/models/shape_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: shape_utils.py
3 |
4 | import tensorflow as tf
5 |
6 | __all__ = []
7 |
8 |
9 | class StaticDynamicAxis(object):
10 | def __init__(self, static, dynamic):
11 | self.static = static
12 | self.dynamic = dynamic
13 |
14 | def apply(self, f):
15 | try:
16 | st = f(self.static)
17 | return StaticDynamicAxis(st, st)
18 | except TypeError:
19 | return StaticDynamicAxis(None, f(self.dynamic))
20 |
21 | def __str__(self):
22 | return "S={}, D={}".format(str(self.static), str(self.dynamic))
23 |
24 |
25 | def DynamicLazyAxis(shape, idx):
26 | return lambda: shape[idx]
27 |
28 |
29 | def StaticLazyAxis(dim):
30 | return lambda: dim
31 |
32 |
33 | class StaticDynamicShape(object):
34 | def __init__(self, tensor):
35 | assert isinstance(tensor, tf.Tensor), tensor
36 | ndims = tensor.shape.ndims
37 | self.static = tensor.shape.as_list()
38 | if tensor.shape.is_fully_defined():
39 | self.dynamic = self.static[:]
40 | else:
41 | dynamic = tf.shape(tensor)
42 | self.dynamic = [DynamicLazyAxis(dynamic, k) for k in range(ndims)]
43 |
44 | for k in range(ndims):
45 | if self.static[k] is not None:
46 | self.dynamic[k] = StaticLazyAxis(self.static[k])
47 |
48 | def apply(self, axis, f):
49 | if self.static[axis] is not None:
50 | try:
51 | st = f(self.static[axis])
52 | self.static[axis] = st
53 | self.dynamic[axis] = StaticLazyAxis(st)
54 | return
55 | except TypeError:
56 | pass
57 | self.static[axis] = None
58 | dyn = self.dynamic[axis]
59 | self.dynamic[axis] = lambda: f(dyn())
60 |
61 | def get_static(self):
62 | return self.static
63 |
64 | @property
65 | def ndims(self):
66 | return len(self.static)
67 |
68 | def get_dynamic(self, axis=None):
69 | if axis is None:
70 | return [self.dynamic[k]() for k in range(self.ndims)]
71 | return self.dynamic[axis]()
72 |
73 |
74 | if __name__ == '__main__':
75 | x = tf.placeholder(tf.float32, shape=[None, 3, None, 10])
76 | shape = StaticDynamicShape(x)
77 | shape.apply(1, lambda x: x * 3)
78 | shape.apply(2, lambda x: x + 5)
79 | print(shape.get_static())
80 | print(shape.get_dynamic())
81 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/models/shapes.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: shapes.py
3 |
4 |
5 | import tensorflow as tf
6 |
7 | from .common import layer_register
8 |
9 | __all__ = ['ConcatWith']
10 |
11 |
12 | @layer_register(use_scope=None)
13 | def ConcatWith(x, tensor, dim):
14 | """
15 | A wrapper around ``tf.concat`` to cooperate with :class:`LinearWrap`.
16 |
17 | Args:
18 | x (tf.Tensor): input
19 | tensor (list[tf.Tensor]): a tensor or list of tensors to concatenate with x.
20 | x will be at the beginning
21 | dim (int): the dimension along which to concatenate
22 |
23 | Returns:
24 | tf.Tensor: ``tf.concat([x] + tensor, dim)``
25 | """
26 |     if not isinstance(tensor, list):
27 |         tensor = [tensor]
28 | return tf.concat([x] + tensor, dim)
29 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/models/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: utils.py
3 |
4 | import six
5 |
6 |
7 | class VariableHolder(object):
8 | """ A proxy to access variables defined in a layer. """
9 | def __init__(self, **kwargs):
10 | """
11 | Args:
12 | kwargs: {name:variable}
13 | """
14 | self._vars = {}
15 | for k, v in six.iteritems(kwargs):
16 | self._add_variable(k, v)
17 |
18 | def _add_variable(self, name, var):
19 | assert name not in self._vars
20 | self._vars[name] = var
21 |
22 | def __setattr__(self, name, var):
23 | if not name.startswith('_'):
24 | self._add_variable(name, var)
25 | else:
26 | # private attributes
27 | super(VariableHolder, self).__setattr__(name, var)
28 |
29 | def __getattr__(self, name):
30 | return self._vars[name]
31 |
32 | def all(self):
33 | """
34 | Returns:
35 | list of all variables
36 | """
37 | return list(six.itervalues(self._vars))
38 |
39 |
40 | try:
41 | # When BN is used as an activation, keras layers try to autograph.convert it
42 | # This leads to massive warnings so we disable it.
43 | from tensorflow.python.autograph.impl.api import do_not_convert as disable_autograph
44 | except ImportError:
45 | def disable_autograph():
46 | return lambda x: x
47 |
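A small sketch of how the layers above use the VariableHolder class just defined (plain strings stand in for tf.Variable objects here):

vh = VariableHolder(W='conv0/W:0')   # registered at construction
vh.b = 'conv0/b:0'                   # public attribute assignment also registers
assert vh.W == 'conv0/W:0' and vh.b == 'conv0/b:0'
assert sorted(vh.all()) == ['conv0/W:0', 'conv0/b:0']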
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/predict/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: __init__.py
3 |
4 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36
5 | STATICA_HACK = True
6 | globals()['kcah_acitats'[::-1].upper()] = False
7 | if STATICA_HACK:
8 | from .base import *
9 | from .concurrency import *
10 | from .config import *
11 | from .dataset import *
12 | from .multigpu import *
13 |
14 |
15 | from pkgutil import iter_modules
16 | import os
17 | import os.path
18 |
19 | __all__ = []
20 |
21 |
22 | def global_import(name):
23 | p = __import__(name, globals(), locals(), level=1)
24 | lst = p.__all__ if '__all__' in dir(p) else dir(p)
25 | if lst:
26 | del globals()[name]
27 | for k in lst:
28 | globals()[k] = p.__dict__[k]
29 | __all__.append(k)
30 |
31 |
32 | _CURR_DIR = os.path.dirname(__file__)
33 | for _, module_name, _ in iter_modules(
34 | [_CURR_DIR]):
35 | srcpath = os.path.join(_CURR_DIR, module_name + '.py')
36 | if not os.path.isfile(srcpath):
37 | continue
38 | if module_name.startswith('_'):
39 | continue
40 | global_import(module_name)
41 |
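The STATICA_HACK block above (borrowed from kombu) makes the wildcard imports visible to static analyzers and IDEs while skipping them at runtime, leaving the real importing to `global_import`. A minimal demonstration of the trick:

STATICA_HACK = True
globals()['kcah_acitats'[::-1].upper()] = False   # 'kcah_acitats'[::-1].upper() == 'STATICA_HACK'
print(STATICA_HACK)   # False at runtime; linters only see the literal `True` above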
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/predict/feedfree.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from tensorflow.python.training.monitored_session import _HookedSession as HookedSession
4 |
5 | from ..callbacks import Callbacks
6 | from ..tfutils.tower import PredictTowerContext
7 | from .base import PredictorBase
8 |
9 | __all__ = ['FeedfreePredictor']
10 |
11 |
12 | class FeedfreePredictor(PredictorBase):
13 | """
14 | Create a predictor that takes inputs from an :class:`InputSource`, instead of from feeds.
15 |     An instance `pred` of :class:`FeedfreePredictor` takes no arguments: calling `pred()` returns
16 |     a list of output values as defined in config.output_names.
17 | """
18 |
19 | def __init__(self, config, input_source):
20 | """
21 | Args:
22 | config (PredictConfig): the config to use.
23 | input_source (InputSource): the feedfree InputSource to use.
24 | Must match the signature of the tower function in config.
25 | """
26 | self._config = config
27 | self._input_source = input_source
28 | assert config.return_input is False, \
29 | "return_input is not supported in FeedfreePredictor! " \
30 | "If you need to fetch inputs, add the names to the output_names!"
31 |
32 | self._hooks = []
33 | self.graph = config._maybe_create_graph()
34 | with self.graph.as_default():
35 | self._input_callbacks = Callbacks(
36 | self._input_source.setup(config.input_signature))
37 | with PredictTowerContext(''):
38 | self._input_tensors = self._input_source.get_input_tensors()
39 | config.tower_func(*self._input_tensors)
40 | self._tower_handle = config.tower_func.towers[-1]
41 |
42 | self._output_tensors = self._tower_handle.get_tensors(config.output_names)
43 |
44 | self._input_callbacks.setup_graph(None)
45 |
46 | for h in self._input_callbacks.get_hooks():
47 | self._register_hook(h)
48 | self._initialize_session()
49 |
50 | def _register_hook(self, hook):
51 | """
52 | Args:
53 | hook (tf.train.SessionRunHook):
54 | """
55 | self._hooks.append(hook)
56 |
57 | def _initialize_session(self):
58 | # init the session
59 | self._config.session_init._setup_graph()
60 | self._sess = self._config.session_creator.create_session()
61 | self._config.session_init._run_init(self._sess)
62 |
63 | with self._sess.as_default():
64 | self._input_callbacks.before_train()
65 | self._hooked_sess = HookedSession(self._sess, self._hooks)
66 |
67 | def __call__(self):
68 | return self._hooked_sess.run(self._output_tensors)
69 |
70 | def _do_call(self):
71 | raise NotImplementedError("You're calling the wrong function!")
72 |
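A hypothetical usage sketch (`my_model` and `my_dataflow` below are placeholders, not defined anywhere in this file); QueueInput and get_model_loader are the usual tensorpack helpers:

from tensorpack import PredictConfig, QueueInput
from tensorpack.tfutils import get_model_loader
from tensorpack.predict import FeedfreePredictor

config = PredictConfig(
    model=my_model,                          # hypothetical ModelDesc
    session_init=get_model_loader('/path/to/checkpoint'),
    input_names=['input'],
    output_names=['prob'])
pred = FeedfreePredictor(config, QueueInput(my_dataflow))
probs = pred()                               # fetch one batch of outputs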
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/tfutils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: __init__.py
3 |
4 |
5 | from .tower import get_current_tower_context, TowerContext
6 |
7 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36
8 | STATICA_HACK = True
9 | globals()['kcah_acitats'[::-1].upper()] = False
10 | if STATICA_HACK:
11 | from .common import *
12 | from .sessinit import *
13 | from .argscope import *
14 |
15 |
16 | # don't want to include everything from .tower
17 | __all__ = ['get_current_tower_context', 'TowerContext']
18 |
19 |
20 | def _global_import(name):
21 | p = __import__(name, globals(), None, level=1)
22 | lst = p.__all__ if '__all__' in dir(p) else dir(p)
23 | for k in lst:
24 | if not k.startswith('__'):
25 | globals()[k] = p.__dict__[k]
26 | __all__.append(k)
27 |
28 |
29 | _TO_IMPORT = frozenset([
30 | 'common',
31 | 'sessinit',
32 | 'argscope',
33 | ])
34 |
35 | for module_name in _TO_IMPORT:
36 | _global_import(module_name)
37 |
38 | """
39 | TODO remove this line in the future.
40 | Better to keep submodule names (sesscreate, varmanip, etc) out of __all__,
41 | so that these names will be invisible under `tensorpack.` namespace.
42 |
43 | To use these utilities, users are expected to import them explicitly, e.g.:
44 |
45 | import tensorpack.tfutils.sessinit as sessinit
46 | """
47 | __all__.extend(['sessinit', 'summary', 'optimizer',
48 | 'sesscreate', 'gradproc', 'varreplace',
49 | 'tower'])
50 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/tfutils/dependency.py:
--------------------------------------------------------------------------------
1 |
2 | import tensorflow as tf
3 |
4 | from ..utils.argtools import graph_memoized
5 |
6 | """
7 | Utils about parsing dependencies in the graph.
8 | """
9 |
10 | __all__ = [
11 | 'dependency_of_targets', 'dependency_of_fetches'
12 | ]
13 |
14 |
15 | @graph_memoized
16 | def dependency_of_targets(targets, op):
17 | """
18 | Check that op is in the subgraph induced by the dependencies of targets.
19 | The result is memoized.
20 |
21 | This is useful if some SessionRunHooks should be run only together with certain ops.
22 |
23 | Args:
24 | targets: a tuple of ops or tensors. The targets to find dependencies of.
25 | op (tf.Operation or tf.Tensor):
26 |
27 | Returns:
28 | bool: True if any one of `targets` depend on `op`.
29 | """
30 | # TODO tensorarray? sparsetensor?
31 | if isinstance(op, tf.Tensor):
32 | op = op.op
33 | assert isinstance(op, tf.Operation), op
34 |
35 | try:
36 | from tensorflow.contrib.graph_editor import get_backward_walk_ops # deprecated
37 | except ImportError:
38 | from tensorflow.python.ops.op_selector import get_backward_walk_ops
39 | # alternative implementation can use graph_util.extract_sub_graph
40 | dependent_ops = get_backward_walk_ops(targets, control_inputs=True)
41 | return op in dependent_ops
42 |
43 |
44 | def dependency_of_fetches(fetches, op):
45 | """
46 | Check that op is in the subgraph induced by the dependencies of fetches.
47 | fetches may have more general structure.
48 |
49 | Args:
50 | fetches: An argument to `sess.run`. Nested structure will affect performance.
51 | op (tf.Operation or tf.Tensor):
52 |
53 | Returns:
54 | bool: True if any of `fetches` depend on `op`.
55 | """
56 | try:
57 | from tensorflow.python.client.session import _FetchHandler as FetchHandler
58 | # use the graph of the op, so that this function can be called without being under a default graph
59 | handler = FetchHandler(op.graph, fetches, {})
60 | targets = tuple(handler.fetches() + handler.targets())
61 | except ImportError:
62 | if isinstance(fetches, list):
63 | targets = tuple(fetches)
64 | elif isinstance(fetches, dict):
65 | raise ValueError("Don't know how to parse dictionary to fetch list! "
66 | "This is a bug of tensorpack.")
67 | else:
68 | targets = (fetches, )
69 | return dependency_of_targets(targets, op)
70 |
71 |
72 | if __name__ == '__main__':
73 | a = tf.random_normal(shape=[3, 3])
74 | b = tf.random_normal(shape=[3, 3])
75 | print(dependency_of_fetches(a, a))
76 | print(dependency_of_fetches([a, b], a))
77 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/tfutils/distributed.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: distributed.py
3 |
4 |
5 | import tensorflow as tf
6 |
7 |
8 | def get_distributed_session_creator(server):
9 | """
10 | Args:
11 | server (tf.train.Server):
12 |
13 | Returns:
14 | tf.train.SessionCreator
15 | """
16 |
17 | server_def = server.server_def
18 | is_chief = (server_def.job_name == 'worker') and (server_def.task_index == 0)
19 |
20 | init_op = tf.global_variables_initializer()
21 | local_init_op = tf.local_variables_initializer()
22 | ready_op = tf.report_uninitialized_variables()
23 | ready_for_local_init_op = tf.report_uninitialized_variables(tf.global_variables())
24 | sm = tf.train.SessionManager(
25 | local_init_op=local_init_op,
26 | ready_op=ready_op,
27 | ready_for_local_init_op=ready_for_local_init_op,
28 | graph=tf.get_default_graph())
29 |
30 | # to debug wrong variable collection
31 | # from pprint import pprint
32 | # print("GLOBAL:")
33 | # pprint([(k.name, k.device) for k in tf.global_variables()])
34 | # print("LOCAL:")
35 | # pprint([(k.name, k.device) for k in tf.local_variables()])
36 |
37 | class _Creator(tf.train.SessionCreator):
38 | def create_session(self):
39 | if is_chief:
40 | return sm.prepare_session(master=server.target, init_op=init_op)
41 | else:
42 | tf.logging.set_verbosity(tf.logging.INFO) # print message about uninitialized vars
43 | ret = sm.wait_for_session(master=server.target)
44 | tf.logging.set_verbosity(tf.logging.WARN)
45 | return ret
46 |
47 | return _Creator()
48 |
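A hypothetical single-process sketch of how this creator is wired to a tf.train.Server (the cluster spec values are placeholders):

import tensorflow as tf

cluster = tf.train.ClusterSpec({'worker': ['localhost:2222']})
server = tf.train.Server(cluster, job_name='worker', task_index=0)
creator = get_distributed_session_creator(server)   # task 0 of 'worker' acts as chief
sess = creator.create_session()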
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/tfutils/model_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: model_utils.py
3 | # Author: tensorpack contributors
4 |
5 | from ..compat import tfv1 as tf
6 | from tabulate import tabulate
7 | from termcolor import colored
8 |
9 | from .common import get_op_tensor_name
10 | from ..utils import logger
11 |
12 | __all__ = []
13 |
14 |
15 | def describe_trainable_vars():
16 | """
17 | Print a description of the current model parameters.
18 | Skip variables starting with "tower", as they are just duplicates built by data-parallel logic.
19 | """
20 | train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
21 | if len(train_vars) == 0:
22 | logger.warn("No trainable variables in the graph!")
23 | return
24 | total = 0
25 | total_bytes = 0
26 | data = []
27 | for v in train_vars:
28 | if v.name.startswith('tower'):
29 | continue
30 | shape = v.get_shape()
31 | ele = shape.num_elements()
32 | if ele is None:
33 | logger.warn("Shape of variable {} is not fully defined but {}.".format(v.name, shape))
34 | ele = 0
35 | try:
36 | shape = shape.as_list()
37 | except ValueError:
38 | shape = ''
39 |
40 | total += ele
41 | total_bytes += ele * v.dtype.size
42 | data.append([get_op_tensor_name(v.name)[0], shape, ele, v.device, v.dtype.base_dtype.name])
43 | headers = ['name', 'shape', '#elements', 'device', 'dtype']
44 |
45 | dtypes = list({x[4] for x in data})
46 | if len(dtypes) == 1 and dtypes[0] == "float32":
47 | # don't log the dtype if all vars are float32 (default dtype)
48 | for x in data:
49 | del x[4]
50 | del headers[4]
51 |
52 | devices = {x[3] for x in data}
53 | if len(devices) == 1:
54 | # don't log the device if all vars on the same device
55 | for x in data:
56 | del x[3]
57 | del headers[3]
58 |
59 | table = tabulate(data, headers=headers)
60 |
61 | size_mb = total_bytes / 1024.0**2
62 | summary_msg = colored(
63 | "\nNumber of trainable variables: {}".format(len(data)) +
64 | "\nNumber of parameters (elements): {}".format(total) +
65 | "\nStorage space needed for all trainable variables: {:.02f}MB".format(size_mb),
66 | 'cyan')
67 | logger.info(colored("List of Trainable Variables: \n", 'cyan') + table + summary_msg)
68 |
69 |
70 | def get_shape_str(tensors):
71 | """
72 | Internally used by layer registry, to print shapes of inputs/outputs of layers.
73 |
74 | Args:
75 | tensors (list or tf.Tensor): a tensor or a list of tensors
76 | Returns:
77 | str: a string to describe the shape
78 | """
79 | if isinstance(tensors, (list, tuple)):
80 | for v in tensors:
81 | assert isinstance(v, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(v))
82 | shape_str = ", ".join(map(get_shape_str, tensors))
83 | else:
84 | assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
85 | shape_str = str(tensors.get_shape().as_list()).replace("None", "?")
86 | return shape_str
87 |
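For illustration, the kind of table tabulate() renders from the rows built above (the rows here are made up):

from tabulate import tabulate

data = [['conv0/W', [3, 3, 3, 64], 1728],
        ['fc0/W', [1024, 10], 10240]]
print(tabulate(data, headers=['name', 'shape', '#elements']))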
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/tfutils/symbolic_functions.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: symbolic_functions.py
3 |
4 |
5 | import tensorflow as tf
6 |
7 | from ..compat import tfv1
8 |
9 | __all__ = ['print_stat', 'rms']
10 |
11 |
12 | def print_stat(x, message=None):
13 | """ A simple print Op that might be easier to use than :meth:`tf.Print`.
14 | Use it like: ``x = print_stat(x, message='This is x')``.
15 | """
16 | if message is None:
17 | message = x.op.name
18 | lst = [tf.shape(x), tf.reduce_mean(x)]
19 | if x.dtype.is_floating:
20 | lst.append(rms(x))
21 | return tf.Print(x, lst + [x], summarize=20,
22 | message=message, name='print_' + x.op.name)
23 |
24 |
25 | # for internal use only
26 | def rms(x, name=None):
27 | """
28 | Returns:
29 | root mean square of tensor x.
30 | """
31 |     if name is None:
32 |         name = x.op.name + '/rms'
33 |         with tfv1.name_scope(None):  # name already contains the scope
34 |             return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
35 |     return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
36 |
37 |
38 | # don't hurt to leave it here
39 | def psnr(prediction, ground_truth, maxp=None, name='psnr'):
40 | """`Peak Signal to Noise Ratio `_.
41 |
42 | .. math::
43 |
44 | PSNR = 20 \cdot \log_{10}(MAX_p) - 10 \cdot \log_{10}(MSE)
45 |
46 | Args:
47 | prediction: a :class:`tf.Tensor` representing the prediction signal.
48 | ground_truth: another :class:`tf.Tensor` with the same shape.
49 |         maxp: maximum possible pixel value of the image (255 for 8-bit images)
50 |
51 | Returns:
52 | A scalar tensor representing the PSNR
53 | """
54 |
55 |     maxp = float(maxp) if maxp is not None else None  # guard: maxp defaults to None
56 |
57 | def log10(x):
58 | with tf.name_scope("log10"):
59 | numerator = tf.log(x)
60 | denominator = tf.log(tf.constant(10, dtype=numerator.dtype))
61 | return numerator / denominator
62 |
63 | mse = tf.reduce_mean(tf.square(prediction - ground_truth))
64 | if maxp is None:
65 | psnr = tf.multiply(log10(mse), -10., name=name)
66 | else:
67 | psnr = tf.multiply(log10(mse), -10.)
68 | psnr = tf.add(tf.multiply(20., log10(maxp)), psnr, name=name)
69 |
70 | return psnr
71 |
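A numpy check of the PSNR formula in the docstring above:

import numpy as np

pred = np.array([0.0, 0.5, 1.0])
gt = np.array([0.0, 1.0, 1.0])
maxp = 1.0
mse = np.mean((pred - gt) ** 2)
psnr = 20 * np.log10(maxp) - 10 * np.log10(mse)   # reduces to -10*log10(mse) when maxp == 1
print(round(psnr, 2))                             # ~10.79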
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/tfutils/unit_tests.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import unittest
4 | import tensorflow as tf
5 |
6 | from ..utils import logger
7 | from .scope_utils import under_name_scope
8 |
9 |
10 | class ScopeUtilsTest(unittest.TestCase):
11 |
12 | @under_name_scope(name_scope='s')
13 | def _f(self, check=True):
14 | if check:
15 | assert tf.get_default_graph().get_name_scope().endswith('s')
16 | return True
17 |
18 | def test_under_name_scope(self):
19 | self.assertTrue(self._f())
20 | with self.assertRaises(AssertionError):
21 | self._f() # name conflict
22 |
23 | def test_under_name_scope_warning(self):
24 | x = tf.placeholder(tf.float32, [3])
25 | tf.nn.relu(x, name='s')
26 | with self.assertLogs(logger=logger._logger, level='WARNING'):
27 | self._f(check=False, name_scope='s')
28 |
29 |
30 | if __name__ == '__main__':
31 | unittest.main()
32 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/train/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: __init__.py
3 | # flake8: noqa
4 |
5 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36
6 | STATICA_HACK = True
7 | globals()['kcah_acitats'[::-1].upper()] = False
8 | if STATICA_HACK:
9 | from .base import *
10 | from .config import *
11 | from .interface import *
12 | from .tower import *
13 | from .trainers import *
14 |
15 |
16 | from pkgutil import iter_modules
17 | import os
18 | import os.path
19 |
20 | __all__ = []
21 |
22 |
23 | def global_import(name):
24 | p = __import__(name, globals(), locals(), level=1)
25 | lst = p.__all__ if '__all__' in dir(p) else []
26 | if lst:
27 | del globals()[name]
28 | for k in lst:
29 | globals()[k] = p.__dict__[k]
30 | __all__.append(k)
31 |
32 |
33 | _CURR_DIR = os.path.dirname(__file__)
34 | _SKIP = ['utility']
35 | for _, module_name, _ in iter_modules(
36 | [_CURR_DIR]):
37 | srcpath = os.path.join(_CURR_DIR, module_name + '.py')
38 | if not os.path.isfile(srcpath):
39 | continue
40 | if module_name.startswith('_'):
41 | continue
42 | if module_name not in _SKIP:
43 | global_import(module_name)
44 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/train/utility.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: utility.py
3 |
4 | # for backwards-compatibility
5 | from ..graph_builder.utils import LeastLoadedDeviceSetter, OverrideToLocalVariable, override_to_local_variable # noqa
6 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: __init__.py
3 |
4 | """
5 | Common utils.
6 | These utils should be irrelevant to tensorflow.
7 | """
8 |
9 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36
10 | STATICA_HACK = True
11 | globals()['kcah_acitats'[::-1].upper()] = False
12 | if STATICA_HACK:
13 | from .utils import *
14 |
15 |
16 | __all__ = []
17 |
18 |
19 | def _global_import(name):
20 | p = __import__(name, globals(), None, level=1)
21 | lst = p.__all__ if '__all__' in dir(p) else dir(p)
22 | for k in lst:
23 | if not k.startswith('__'):
24 | globals()[k] = p.__dict__[k]
25 | __all__.append(k)
26 |
27 |
28 | _global_import('utils')
29 |
30 | # Import no other submodules. they are supposed to be explicitly imported by users.
31 | __all__.extend(['logger'])
32 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/utils/compatible_serialize.py:
--------------------------------------------------------------------------------
1 | from .serialize import loads, dumps # noqa
2 |
3 | # keep this file for BC
4 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/utils/debug.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: debug.py
3 |
4 |
5 | import sys
6 |
7 |
8 | def enable_call_trace():
9 | """ Enable trace for calls to any function. """
10 | def tracer(frame, event, arg):
11 | if event == 'call':
12 | co = frame.f_code
13 | func_name = co.co_name
14 | if func_name == 'write' or func_name == 'print':
15 | # ignore write() calls from print statements
16 | return
17 | func_line_no = frame.f_lineno
18 | func_filename = co.co_filename
19 | caller = frame.f_back
20 | if caller:
21 | caller_line_no = caller.f_lineno
22 | caller_filename = caller.f_code.co_filename
23 | print('Call to `%s` on line %s:%s from %s:%s' %
24 | (func_name, func_filename, func_line_no,
25 | caller_filename, caller_line_no))
26 | return
27 | sys.settrace(tracer)
28 |
29 |
30 | if __name__ == '__main__':
31 | enable_call_trace()
32 |
33 | def b(a):
34 | print(2)
35 |
36 | def a():
37 | print(1)
38 | b(1)
39 |
40 | a()
41 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/utils/gpu.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: gpu.py
3 |
4 |
5 | import os
6 |
7 | from . import logger
8 | from .concurrency import subproc_call
9 | from .nvml import NVMLContext
10 | from .utils import change_env
11 |
12 | __all__ = ['change_gpu', 'get_nr_gpu', 'get_num_gpu']
13 |
14 |
15 | def change_gpu(val):
16 | """
17 | Args:
18 | val: an integer, the index of the GPU or -1 to disable GPU.
19 |
20 | Returns:
21 | a context where ``CUDA_VISIBLE_DEVICES=val``.
22 | """
23 | val = str(val)
24 | if val == '-1':
25 | val = ''
26 | return change_env('CUDA_VISIBLE_DEVICES', val)
27 |
28 |
29 | def get_num_gpu():
30 | """
31 | Returns:
32 | int: #available GPUs in CUDA_VISIBLE_DEVICES, or in the system.
33 | """
34 |
35 | def warn_return(ret, message):
36 | try:
37 | import tensorflow as tf
38 | except ImportError:
39 | return ret
40 |
41 | built_with_cuda = tf.test.is_built_with_cuda()
42 | if not built_with_cuda and ret > 0:
43 | logger.warn(message + "But TensorFlow was not built with CUDA support and could not use GPUs!")
44 | return ret
45 |
46 | env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
47 | if env:
48 | return warn_return(len(env.split(',')), "Found non-empty CUDA_VISIBLE_DEVICES. ")
49 | output, code = subproc_call("nvidia-smi -L", timeout=5)
50 | if code == 0:
51 | output = output.decode('utf-8')
52 | return warn_return(len(output.strip().split('\n')), "Found nvidia-smi. ")
53 | try:
54 | # Use NVML to query device properties
55 | with NVMLContext() as ctx:
56 | return warn_return(ctx.num_devices(), "NVML found nvidia devices. ")
57 | except Exception:
58 | # Fallback
59 | logger.info("Loading local devices by TensorFlow ...")
60 |
61 | try:
62 | import tensorflow as tf
63 | # available since TF 1.14
64 | gpu_devices = tf.config.experimental.list_physical_devices('GPU')
65 | except AttributeError:
66 | from tensorflow.python.client import device_lib
67 | local_device_protos = device_lib.list_local_devices()
68 | # Note this will initialize all GPUs and therefore has side effect
69 | # https://github.com/tensorflow/tensorflow/issues/8136
70 | gpu_devices = [x.name for x in local_device_protos if x.device_type == 'GPU']
71 | return len(gpu_devices)
72 |
73 |
74 | get_nr_gpu = get_num_gpu
75 |
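A short usage sketch of the helpers above; change_gpu only rewrites the CUDA_VISIBLE_DEVICES environment variable inside the context:

from tensorpack.utils.gpu import change_gpu, get_num_gpu

print(get_num_gpu())      # GPUs visible via CUDA_VISIBLE_DEVICES, nvidia-smi, NVML or TF
with change_gpu(1):
    pass                  # code here sees only GPU 1
with change_gpu(-1):
    pass                  # -1 becomes an empty CUDA_VISIBLE_DEVICES, disabling GPUs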
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/utils/naming.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: naming.py
3 |
4 |
5 | GLOBAL_STEP_INCR_OP_NAME = 'global_step_incr'
6 |
7 | # extra variables to summarize during training in a moving-average way
8 | MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'
9 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tensorpack/utils/serialize.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: serialize.py
3 |
4 | import os
5 |
6 | import pickle
7 | from multiprocessing.reduction import ForkingPickler
8 | import msgpack
9 | import msgpack_numpy
10 |
11 | msgpack_numpy.patch()
12 | assert msgpack.version >= (0, 5, 2)
13 |
14 | __all__ = ['loads', 'dumps']
15 |
16 |
17 | MAX_MSGPACK_LEN = 1000000000
18 |
19 |
20 | class MsgpackSerializer(object):
21 |
22 | @staticmethod
23 | def dumps(obj):
24 | """
25 | Serialize an object.
26 |
27 | Returns:
28 | Implementation-dependent bytes-like object.
29 | """
30 | return msgpack.dumps(obj, use_bin_type=True)
31 |
32 | @staticmethod
33 | def loads(buf):
34 | """
35 | Args:
36 | buf: the output of `dumps`.
37 | """
38 |         # Since msgpack 0.6, the default max length of each input field was set to 1MB.
39 |         # We raise it to approximately 1G.
40 | return msgpack.loads(buf, raw=False,
41 | max_bin_len=MAX_MSGPACK_LEN,
42 | max_array_len=MAX_MSGPACK_LEN,
43 | max_map_len=MAX_MSGPACK_LEN,
44 | max_str_len=MAX_MSGPACK_LEN)
45 |
46 |
47 | class PyarrowSerializer(object):
48 | @staticmethod
49 | def dumps(obj):
50 | """
51 | Serialize an object.
52 |
53 | Returns:
54 | Implementation-dependent bytes-like object.
55 | May not be compatible across different versions of pyarrow.
56 | """
57 | import pyarrow as pa
58 | return pa.serialize(obj).to_buffer()
59 |
60 | @staticmethod
61 | def dumps_bytes(obj):
62 | """
63 | Returns:
64 | bytes
65 | """
66 | return PyarrowSerializer.dumps(obj).to_pybytes()
67 |
68 | @staticmethod
69 | def loads(buf):
70 | """
71 | Args:
72 | buf: the output of `dumps` or `dumps_bytes`.
73 | """
74 | import pyarrow as pa
75 | return pa.deserialize(buf)
76 |
77 |
78 | class PickleSerializer(object):
79 | @staticmethod
80 | def dumps(obj):
81 | """
82 | Returns:
83 | bytes
84 | """
85 | return pickle.dumps(obj, protocol=-1)
86 |
87 | @staticmethod
88 | def loads(buf):
89 | """
90 | Args:
91 |             buf (bytes): the output of `dumps`.
92 | """
93 | return pickle.loads(buf)
94 |
95 |
96 | # Define the default serializer to be used that dumps data to bytes
97 | _DEFAULT_S = os.environ.get('TENSORPACK_SERIALIZE', 'pickle')
98 |
99 | if _DEFAULT_S == "pyarrow":
100 | dumps = PyarrowSerializer.dumps_bytes
101 | loads = PyarrowSerializer.loads
102 | elif _DEFAULT_S == "pickle":
103 | dumps = PickleSerializer.dumps
104 | loads = PickleSerializer.loads
105 | else:
106 | dumps = MsgpackSerializer.dumps
107 | loads = MsgpackSerializer.loads
108 |
109 | # Define the default serializer to be used for passing data
110 | # among a pair of peers. In this case the deserialization is
111 | # known to happen only once
112 | _DEFAULT_S = os.environ.get('TENSORPACK_ONCE_SERIALIZE', 'pickle')
113 |
114 | if _DEFAULT_S == "pyarrow":
115 | dumps_once = PyarrowSerializer.dumps
116 | loads_once = PyarrowSerializer.loads
117 | elif _DEFAULT_S == "pickle":
118 | dumps_once = ForkingPickler.dumps
119 | loads_once = ForkingPickler.loads
120 | else:
121 | dumps_once = MsgpackSerializer.dumps
122 | loads_once = MsgpackSerializer.loads
123 |
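A round-trip sketch with the module-level default serializer (pickle unless TENSORPACK_SERIALIZE says otherwise):

import numpy as np
from tensorpack.utils.serialize import dumps, loads

obj = {'step': 3, 'img': np.zeros((2, 2), dtype='float32')}
buf = dumps(obj)
restored = loads(buf)
assert restored['step'] == 3 and restored['img'].shape == (2, 2)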
--------------------------------------------------------------------------------
/third_party/tensorpack/tests/benchmark-serializer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import numpy as np
4 | import argparse
5 | import pyarrow as pa
6 | from tabulate import tabulate
7 | import operator
8 | from tensorpack.utils import logger
9 | from tensorpack.utils.serialize import (
10 | MsgpackSerializer,
11 | PyarrowSerializer,
12 | PickleSerializer,
13 | ForkingPickler,
14 | )
15 | from tensorpack.utils.timer import Timer
16 |
17 |
18 | def benchmark_serializer(dumps, loads, data, num):
19 | buf = dumps(data)
20 |
21 | enc_timer = Timer()
22 | dec_timer = Timer()
23 | enc_timer.pause()
24 | dec_timer.pause()
25 |
26 | for k in range(num):
27 | enc_timer.resume()
28 | buf = dumps(data)
29 | enc_timer.pause()
30 |
31 | dec_timer.resume()
32 | loads(buf)
33 | dec_timer.pause()
34 |
35 | dumps_time = enc_timer.seconds() / num
36 | loads_time = dec_timer.seconds() / num
37 | return dumps_time, loads_time
38 |
39 |
40 | def display_results(name, results):
41 | logger.info("Encoding benchmark for {}:".format(name))
42 | data = sorted(((x, y[0]) for x, y in results), key=operator.itemgetter(1))
43 | print(tabulate(data, floatfmt='.5f'))
44 |
45 | logger.info("Decoding benchmark for {}:".format(name))
46 | data = sorted(((x, y[1]) for x, y in results), key=operator.itemgetter(1))
47 | print(tabulate(data, floatfmt='.5f'))
48 |
49 |
50 | def benchmark_all(name, serializers, data, num=30):
51 | logger.info("Benchmarking {} ...".format(name))
52 | results = []
53 | for serializer_name, dumps, loads in serializers:
54 | results.append((serializer_name, benchmark_serializer(dumps, loads, data, num=num)))
55 | display_results(name, results)
56 |
57 |
58 | def fake_json_data():
59 | return {
60 | 'words': """
61 | Lorem ipsum dolor sit amet, consectetur adipiscing
62 | elit. Mauris adipiscing adipiscing placerat.
63 | Vestibulum augue augue,
64 | pellentesque quis sollicitudin id, adipiscing.
65 | """ * 100,
66 | 'list': list(range(100)) * 500,
67 | 'dict': {str(i): 'a' for i in range(50000)},
68 | 'dict2': {i: 'a' for i in range(50000)},
69 | 'int': 3000,
70 | 'float': 100.123456
71 | }
72 |
73 |
74 | if __name__ == '__main__':
75 | parser = argparse.ArgumentParser()
76 | parser.add_argument("task")
77 | args = parser.parse_args()
78 |
79 | serializers = [
80 | ("msgpack", MsgpackSerializer.dumps, MsgpackSerializer.loads),
81 | ("pyarrow-buf", PyarrowSerializer.dumps, PyarrowSerializer.loads),
82 | ("pyarrow-bytes", PyarrowSerializer.dumps_bytes, PyarrowSerializer.loads),
83 | ("pickle", PickleSerializer.dumps, PickleSerializer.loads),
84 | ("forking-pickle", ForkingPickler.dumps, ForkingPickler.loads),
85 | ]
86 |
87 | if args.task == "numpy":
88 | numpy_data = [np.random.rand(64, 224, 224, 3).astype("float32"), np.random.rand(64).astype('int32')]
89 | benchmark_all("numpy data", serializers, numpy_data)
90 | elif args.task == "json":
91 | benchmark_all("json data", serializers, fake_json_data(), num=50)
92 | elif args.task == "torch":
93 | import torch
94 | from pyarrow.lib import _default_serialization_context
95 |
96 | pa.register_torch_serialization_handlers(_default_serialization_context)
97 | torch_data = [torch.rand(64, 224, 224, 3), torch.rand(64).to(dtype=torch.int32)]
98 | benchmark_all("torch data", serializers[1:], torch_data)
99 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tests/case_script.py:
--------------------------------------------------------------------------------
1 | from abc import abstractproperty
2 | import unittest
3 | import subprocess
4 | import shlex
5 | import sys
6 | import threading
7 | import os
8 | import shutil
9 |
10 |
11 | class PythonScript(threading.Thread):
12 | """A wrapper to start a python script with timeout.
13 |
14 | To test the actual models even without GPUs we simply start them and
15 | test whether they survive a certain amount of time "timeout". This allows to
16 | test if all imports are correct and the computation graph can be built without
17 | run the entire model on the CPU.
18 |
19 | Attributes:
20 | cmd (str): command to execute the example with all flags (including python)
21 | p: process handle
22 | timeout (int): timeout in seconds
23 | """
24 | def __init__(self, cmd, timeout):
25 | """Prepare a python script
26 |
27 | Args:
28 | cmd (str): command to execute the example with all flags (including python)
29 | timeout (int): time in seconds the script has to survive
30 | """
31 | threading.Thread.__init__(self)
32 | self.cmd = cmd
33 | self.timeout = timeout
34 |
35 | def run(self):
36 | self.p = subprocess.Popen(shlex.split(self.cmd), stderr=subprocess.PIPE, stdout=subprocess.PIPE)
37 | self.out, self.err = self.p.communicate()
38 |
39 | def execute(self):
40 | """Execute python script in other process.
41 |
42 | Raises:
43 | SurviveException: contains the error message of the script if it terminated before timeout
44 | """
45 | self.start()
46 | self.join(self.timeout)
47 |
48 | if self.is_alive():
49 | self.p.terminate()
50 | self.p.kill() # kill -9
51 | self.join()
52 | else:
53 |             # something unexpected happened here; this script was supposed to survive at least the timeout
54 | if len(self.err) > 0:
55 | output = u"STDOUT: \n\n\n" + self.out.decode('utf-8')
56 | output += u"\n\n\n STDERR: \n\n\n" + self.err.decode('utf-8')
57 | raise AssertionError(output)
58 |
59 |
60 | class TestPythonScript(unittest.TestCase):
61 |
62 | @abstractproperty
63 | def script(self):
64 | pass
65 |
66 | @staticmethod
67 | def clear_trainlog(script):
68 | script = os.path.basename(script)
69 | script = script[:-3]
70 | if os.path.isdir(os.path.join("train_log", script)):
71 | shutil.rmtree(os.path.join("train_log", script))
72 |
73 | def assertSurvive(self, script, args=None, timeout=20): # noqa
74 | cmd = "python{} {}".format(sys.version_info.major, script)
75 | if args:
76 | cmd += " " + " ".join(args)
77 | PythonScript(cmd, timeout=timeout).execute()
78 |
79 | def setUp(self):
80 | TestPythonScript.clear_trainlog(self.script)
81 |
82 | def tearDown(self):
83 | TestPythonScript.clear_trainlog(self.script)
84 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tests/dev/git-hooks/pre-commit:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 | flake8 .
3 |
4 | cd examples
5 | GIT_ARG="--git-dir ../.git --work-tree .."
6 |
7 | # find out modified python files, so that we ignored unstaged files
8 | # exclude ../docs
9 | MOD=$(git $GIT_ARG status -s \
10 | | grep -E '\.py$' | grep -v '../docs' | grep -v '__init__' \
11 | | grep -E '^ *M|^ *A' | cut -c 4- )
12 | if [[ -n $MOD ]]; then
13 | flake8 $MOD
14 | fi
15 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tests/install-tensorflow.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 |
3 | if [ $TF_VERSION == "nightly" ]; then
4 | TF_BINARY_URL="tf-nightly"
5 | else
6 | if [[ $TRAVIS_PYTHON_VERSION == 2* ]]; then
7 | TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-${TF_VERSION}-cp27-none-linux_x86_64.whl
8 | fi
9 | if [[ $TRAVIS_PYTHON_VERSION == 3.4* ]]; then
10 | TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-${TF_VERSION}-cp34-cp34m-linux_x86_64.whl
11 | fi
12 | if [[ $TRAVIS_PYTHON_VERSION == 3.5* ]]; then
13 | TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-${TF_VERSION}-cp35-cp35m-linux_x86_64.whl
14 | fi
15 | if [[ $TRAVIS_PYTHON_VERSION == 3.6* ]]; then
16 | TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-${TF_VERSION}-cp36-cp36m-linux_x86_64.whl
17 | fi
18 | fi
19 |
20 |
21 | python -m pip install --upgrade ${TF_BINARY_URL}
22 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tests/run-tests.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -ev
2 | # File: run-tests.sh
3 |
4 | mkdir -p "$TENSORPACK_DATASET"
5 | DIR=$(dirname $0)
6 | cd $DIR
7 |
8 | export TF_CPP_MIN_LOG_LEVEL=2
9 | export TF_CPP_MIN_VLOG_LEVEL=2
10 | # test import (#471)
11 | python -c 'from tensorpack.dataflow import imgaug'
12 | # Check that these private names can be imported because tensorpack is using them
13 | python -c "from tensorflow.python.client.session import _FetchHandler"
14 | python -c "from tensorflow.python.training.monitored_session import _HookedSession"
15 | python -c "import tensorflow as tf; tf.Operation._add_control_input"
16 |
17 | # run tests
18 | python -m unittest tensorpack.callbacks.param_test
19 | python -m unittest tensorpack.tfutils.unit_tests
20 | python -m unittest tensorpack.dataflow.imgaug.imgaug_test
21 | python -m unittest tensorpack.models.models_test
22 |
23 | # use pyarrow after we organize the serializers.
24 | # TENSORPACK_SERIALIZE=pyarrow python ...
25 | python -m unittest tensorpack.dataflow.serialize_test
26 |
27 | # e2e tests
28 | python -m unittest discover -v
29 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tests/test_char_rnn.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from case_script import TestPythonScript
4 |
5 |
6 | def random_content():
7 | return ('Lorem ipsum dolor sit amet\n'
8 | 'consetetur sadipscing elitr\n'
9 | 'sed diam nonumy eirmod tempor invidunt ut labore\n')
10 |
11 |
12 | class CharRNNTest(TestPythonScript):
13 |
14 | @property
15 | def script(self):
16 | return '../examples/Char-RNN/char-rnn.py'
17 |
18 | def setUp(self):
19 | super(CharRNNTest, self).setUp()
20 | with open('input.txt', 'w') as f:
21 | f.write(random_content())
22 |
23 | def test(self):
24 | self.assertSurvive(self.script, args=['train'])
25 |
26 | def tearDown(self):
27 | super(CharRNNTest, self).tearDown()
28 | os.remove('input.txt')
29 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tests/test_infogan.py:
--------------------------------------------------------------------------------
1 | from case_script import TestPythonScript
2 |
3 | from tensorpack.tfutils.common import get_tf_version_tuple
4 |
5 |
6 | class InfoGANTest(TestPythonScript):
7 |
8 | @property
9 | def script(self):
10 | return '../examples/GAN/InfoGAN-mnist.py'
11 |
12 | def test(self):
13 | return True # https://github.com/tensorflow/tensorflow/issues/24517
14 | if get_tf_version_tuple() < (1, 4):
15 | return True # requires leaky_relu
16 | self.assertSurvive(self.script, args=None)
17 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tests/test_mnist.py:
--------------------------------------------------------------------------------
1 | from case_script import TestPythonScript
2 |
3 |
4 | class MnistTest(TestPythonScript):
5 |
6 | @property
7 | def script(self):
8 | return '../examples/basics/mnist-convnet.py'
9 |
10 | def test(self):
11 | self.assertSurvive(self.script, args=None)
12 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tests/test_mnist_similarity.py:
--------------------------------------------------------------------------------
1 | from case_script import TestPythonScript
2 |
3 |
4 | class SimilarityLearningTest(TestPythonScript):
5 |
6 | @property
7 | def script(self):
8 | return '../examples/SimilarityLearning/mnist-embeddings.py'
9 |
10 | def test(self):
11 | self.assertSurvive(self.script, args=['--algorithm triplet'], timeout=10)
12 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tests/test_resnet.py:
--------------------------------------------------------------------------------
1 | from case_script import TestPythonScript # noqa
2 |
3 | # this tests occasionally fails (memory issue on travis?)
4 |
5 |
6 | # class ResnetTest(TestPythonScript):
7 | # @property
8 | # def script(self):
9 | # return '../examples/ResNet/imagenet-resnet.py'
10 | #
11 | # def test(self):
12 | # self.assertSurvive(
13 | # self.script,
14 | # args=['--fake', '--data_format NHWC'], timeout=20)
15 |
--------------------------------------------------------------------------------
/third_party/tensorpack/tox.ini:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | # See https://pep8.readthedocs.io/en/latest/intro.html#error-codes
4 | ignore = E265,E741,E742,E743,W504,W605,C408,B007,B008
5 | exclude = .git,
6 | __init__.py,
7 | setup.py,
8 | tensorpack/compat/*,
9 | docs,
10 | examples,
11 |           docs/conf.py,
12 | snippet,
13 | examples_v2,
14 | _test.py,
15 | show-source = true
16 |
17 | [isort]
18 | line_length=100
19 | skip=docs/conf.py
20 | multi_line_output=4
21 | known_tensorpack=tensorpack
22 | known_standard_library=numpy
23 | known_third_party=bob,gym,matplotlib
24 | no_lines_before=STDLIB,THIRDPARTY
25 | sections=FUTURE,STDLIB,THIRDPARTY,tensorpack,FIRSTPARTY,LOCALFOLDER
26 |
--------------------------------------------------------------------------------