# How to use TensorLayer

While research in Deep Learning continues to improve the world, we use a handful of tricks to implement algorithms with TensorLayer day to day.

Here is a summary of the tricks we use with TensorLayer.
If you find a trick that is particularly useful in practice, please open a Pull Request to add it to this document. If we find it reasonable and verified, we will merge it in.

- 🇨🇳 The Chinese book [《深度学习:一起玩转TensorLayer》](https://item.jd.com/12286942.html) is now available.

## 1. Installation
* To pin your TL version and edit the source code easily, you can download the whole repository by executing `git clone https://github.com/zsdonghao/tensorlayer.git` in your terminal, then copy the `tensorlayer` folder into your project.
* As TL is growing very fast, if you want to install via `pip`, we suggest installing the master version, e.g. `pip install git+https://github.com/tensorlayer/tensorlayer.git`.
* For NLP applications, you will need to install [NLTK and NLTK data](http://www.nltk.org/install.html).

## 2. Interaction between TF and TL
* TF to TL: use [InputLayer](https://tensorlayer.readthedocs.io/en/latest/modules/layers.html#input-layers)
* TL to TF: use [network.outputs](http://tensorlayer.readthedocs.io/en/latest/modules/layers.html#understand-basic-layer)
* Other methods: [issues7](https://github.com/tensorlayer/tensorlayer/issues/7); multiple inputs: [issues31](https://github.com/tensorlayer/tensorlayer/issues/31)
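
A minimal sketch of both directions (the 784-dim placeholder is illustrative):

```python
import tensorflow as tf
import tensorlayer as tl

x = tf.placeholder(tf.float32, [None, 784])
# TF -> TL: wrap the TF tensor as the input of a TL network
net = tl.layers.InputLayer(x, name='input')
net = tl.layers.DenseLayer(net, n_units=100, act=tf.nn.relu, name='dense')
# TL -> TF: take the TF tensor out of the TL network and continue with plain TF ops
y = tf.nn.softmax(net.outputs)
```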

## 3. Training/Testing switching
* Use [network.all_drop](http://tensorlayer.readthedocs.io/en/latest/modules/layers.html#understand-basic-layer) to control the training/testing phase (for [DropoutLayer](http://tensorlayer.readthedocs.io/en/latest/modules/layers.html#dropout-layer) only); see [this example](https://github.com/tensorlayer/tensorlayer/blob/master/examples/basic_tutorials/tutorial_mlp_dropout1.py) and [Understand Basic layer](http://tensorlayer.readthedocs.io/en/latest/modules/layers.html#understand-basic-layer)
* Alternatively, set `is_fix` to `True` in [DropoutLayer](http://tensorlayer.readthedocs.io/en/latest/modules/layers.html#dropout-layer), and build separate graphs for training/testing by reusing the parameters. You can also set a different `batch_size` and noise probability for each graph. This method is the best when you use layers such as [GaussianNoiseLayer](https://tensorlayer.readthedocs.io/en/stable/modules/layers.html#tensorlayer.layers.GaussianNoiseLayer) and [BatchNormLayer](https://tensorlayer.readthedocs.io/en/stable/modules/layers.html#normalization-layers). Here is an example:

```python
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import InputLayer, DropoutLayer, DenseLayer

def mlp(x, is_train=True, reuse=False):
    with tf.variable_scope("MLP", reuse=reuse):
        net = InputLayer(x, name='in')
        # is_fix=True: dropout is switched on/off by `is_train` when the graph is built
        net = DropoutLayer(net, 0.8, True, is_train, name='drop1')
        net = DenseLayer(net, n_units=800, act=tf.nn.relu, name='dense1')
        net = DropoutLayer(net, 0.8, True, is_train, name='drop2')
        net = DenseLayer(net, n_units=800, act=tf.nn.relu, name='dense2')
        net = DropoutLayer(net, 0.8, True, is_train, name='drop3')
        net = DenseLayer(net, n_units=10, act=tf.identity, name='out')
        logits = net.outputs
        net.outputs = tf.nn.sigmoid(net.outputs)
        return net, logits

x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
y_ = tf.placeholder(tf.int64, shape=[None, ], name='y_')
# two graphs that share parameters: one for training, one for testing
net_train, logits = mlp(x, is_train=True, reuse=False)
net_test, _ = mlp(x, is_train=False, reuse=True)
cost = tl.cost.cross_entropy(logits, y_, name='cost')
```

More details [here](https://github.com/tensorlayer/tensorlayer/blob/master/examples/basic_tutorials/tutorial_mlp_dropout2.py).

## 4. Get variables and outputs
* Use [tl.layers.get_variables_with_name](https://tensorlayer.readthedocs.io/en/stable/modules/layers.html#get-variables-with-name) instead of [net.all_params](https://tensorlayer.readthedocs.io/en/stable/modules/layers.html#understanding-the-basic-layer)
```python
train_vars = tl.layers.get_variables_with_name('MLP', True, True)
train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_vars)
```
* This method can also be used to freeze some layers during training: simply leave their variables out of the list.
* Other methods: [issues17](https://github.com/zsdonghao/tensorlayer/issues/17), [issues26](https://github.com/zsdonghao/tensorlayer/issues/26), [FAQ](http://tensorlayer.readthedocs.io/en/latest/user/more.html#exclude-some-layers-from-training)
* Use [tl.layers.get_layers_with_name](https://tensorlayer.readthedocs.io/en/stable/modules/layers.html#name-scope-and-sharing-parameters) to get a list of activation outputs from a network.
```python
layers = tl.layers.get_layers_with_name(network, "MLP", True)
```
* This method is usually used for activation regularization, as in the sketch below.
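
A minimal sketch, assuming `network` and `cost` are the network and loss defined earlier, with an illustrative scale of 0.001:

```python
# add an L2 penalty on the activation outputs to the training cost
layers = tl.layers.get_layers_with_name(network, "MLP", True)
activation_reg = 0.001 * tf.add_n([tf.reduce_mean(tf.square(a)) for a in layers])
cost = cost + activation_reg
```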

## 5. Data augmentation for large dataset
If your dataset is large, data loading and data augmentation will become a bottleneck and slow down training.
To speed up the data processing you can:

* Use TFRecord or the TF Dataset API, as sketched below; see the [cifar10 examples](https://github.com/tensorlayer/tensorlayer/tree/master/examples/basic_tutorials)
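
A minimal Dataset API sketch, assuming `images` and `labels` are in-memory NumPy arrays; the augmentation ops and sizes are illustrative:

```python
import tensorflow as tf

def _augment(img, label):
    # augmentation runs inside the input pipeline, in parallel threads
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_brightness(img, max_delta=0.1)
    return img, label

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.map(_augment, num_parallel_calls=4)
dataset = dataset.batch(128).prefetch(1)
img_batch, label_batch = dataset.make_one_shot_iterator().get_next()
```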

## 6. Data augmentation for small dataset
If your dataset is small enough to fit into the memory of your machine and the augmentation is simple, you can do the following for easy debugging (see the sketch after this list):

* Use [tl.iterate.minibatches](http://tensorlayer.readthedocs.io/en/latest/modules/iterate.html#tensorlayer.iterate.minibatches) to shuffle and return the examples and labels for a given batch size.
* Use [tl.prepro.threading_data](http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html#tensorlayer.prepro.threading_data) to read a batch of data at the beginning of every step; this is slower, but fine for small datasets.
* For time-series data, use [tl.iterate.seq_minibatches, tl.iterate.seq_minibatches2, tl.iterate.ptb_iterator, etc.](http://tensorlayer.readthedocs.io/en/latest/modules/iterate.html#time-series)
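
A minimal sketch of the first two points, assuming `X_train`/`y_train` are in-memory arrays, `x`/`y_`/`train_op`/`sess` come from your training graph, and the random flip is an illustrative augmentation:

```python
for X_batch, y_batch in tl.iterate.minibatches(X_train, y_train, batch_size=64, shuffle=True):
    # run the same augmentation function on every example of the batch, using threads
    X_batch = tl.prepro.threading_data(X_batch, fn=tl.prepro.flip_axis, axis=1, is_random=True)
    sess.run(train_op, feed_dict={x: X_batch, y_: y_batch})
```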

## 7. Pre-trained CNN and ResNet
* Pre-trained CNN
  * Many applications need a pre-trained CNN model
  * TL provides pre-trained VGG16, VGG19, MobileNet, SqueezeNet, etc.: [tl.models](https://tensorlayer.readthedocs.io/en/stable/modules/models.html#)
  * [tl.layers.SlimNetsLayer](https://tensorlayer.readthedocs.io/en/stable/modules/layers.html#external-libraries-layers) allows you to use all [Tf-Slim pre-trained models](https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models) and [tensorlayer/pretrained-models](https://github.com/tensorlayer/pretrained-models)
* ResNet
  * Implemented with a "for" loop, as sketched below: [issues85](https://github.com/zsdonghao/tensorlayer/issues/85)
  * Other methods: [by @ritchieng](https://github.com/ritchieng/wideresnet-tensorlayer)
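
A hypothetical sketch of the for-loop pattern (not the exact code from the issue); layer names, channel counts, and block count are illustrative, and the identity shortcut assumes `net` already has 64 channels:

```python
import tensorflow as tf
from tensorlayer.layers import Conv2d, ElementwiseLayer

def residual_blocks(net, n_blocks=3):
    for i in range(n_blocks):
        shortcut = net
        net = Conv2d(net, 64, (3, 3), (1, 1), act=tf.nn.relu, name='block%d/conv1' % i)
        net = Conv2d(net, 64, (3, 3), (1, 1), act=None, name='block%d/conv2' % i)
        # identity shortcut: add the block input back onto its output
        net = ElementwiseLayer([shortcut, net], combine_fn=tf.add, name='block%d/add' % i)
        net.outputs = tf.nn.relu(net.outputs)
    return net
```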

## 8. Using `tl.models`

* Use pre-trained VGG16 for ImageNet classification
```python
x = tf.placeholder(tf.float32, [None, 224, 224, 3])
# get the whole model
vgg = tl.models.VGG16(x)
# restore pre-trained VGG parameters
sess = tf.InteractiveSession()
vgg.restore_params(sess)
# use for inference
probs = tf.nn.softmax(vgg.outputs)
```

* Extract features with VGG16 and retrain a classifier with 100 classes
```python
x = tf.placeholder(tf.float32, [None, 224, 224, 3])
# get VGG without the last layer
vgg = tl.models.VGG16(x, end_with='fc2_relu')
# add one more layer
net = tl.layers.DenseLayer(vgg, 100, name='out')
# initialize all parameters
sess = tf.InteractiveSession()
tl.layers.initialize_global_variables(sess)
# restore pre-trained VGG parameters
vgg.restore_params(sess)
# train your own classifier (only update the last layer)
train_params = tl.layers.get_variables_with_name('out')
```

* Reuse model

```python
x1 = tf.placeholder(tf.float32, [None, 224, 224, 3])
x2 = tf.placeholder(tf.float32, [None, 224, 224, 3])
# get VGG without the last layer
vgg1 = tl.models.VGG16(x1, end_with='fc2_relu')
# reuse the parameters of vgg1 with a different input
vgg2 = tl.models.VGG16(x2, end_with='fc2_relu', reuse=True)
# restore pre-trained VGG parameters (as they share parameters, we don't need to restore vgg2)
sess = tf.InteractiveSession()
vgg1.restore_params(sess)
```

## 9. Customized layer
* 1. [Write a TL layer directly](https://tensorlayer.readthedocs.io/en/stable/modules/layers.html#customizing-layers)
* 2. Use [LambdaLayer](https://tensorlayer.readthedocs.io/en/stable/modules/layers.html#lambda-layers); it can also accept functions that create new variables. With this layer you can connect any third-party TF library and your customized functions to TL. Here is an example of using Keras and TL together.

```python
import tensorflow as tf
import tensorlayer as tl
from keras.layers import *
from tensorlayer.layers import *

def my_fn(x):
    x = Dropout(0.8)(x)
    x = Dense(800, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(800, activation='relu')(x)
    x = Dropout(0.5)(x)
    logits = Dense(10, activation='linear')(x)
    return logits

x = tf.placeholder(tf.float32, [None, 784], name='x')  # e.g. MNIST input
network = InputLayer(x, name='input')
network = LambdaLayer(network, my_fn, name='keras')
...
```

## 10. Sentence tokenization
* Use [tl.nlp.process_sentence](https://tensorlayer.readthedocs.io/en/stable/modules/nlp.html#tensorlayer.nlp.process_sentence) to tokenize sentences; [NLTK and NLTK data](http://www.nltk.org/install.html) are required

```python
>>> captions = ["one two , three", "four five five"]  # two sentences
>>> processed_capts = []
>>> for c in captions:
>>>     c = tl.nlp.process_sentence(c, start_word="<S>", end_word="</S>")
>>>     processed_capts.append(c)
>>> print(processed_capts)
... [['<S>', 'one', 'two', ',', 'three', '</S>'],
...  ['<S>', 'four', 'five', 'five', '</S>']]
```

* Then use [tl.nlp.create_vocab](https://tensorlayer.readthedocs.io/en/stable/modules/nlp.html#tensorlayer.nlp.create_vocab) to create a vocabulary and save it as a txt file (it returns a [tl.nlp.SimpleVocabulary object](https://tensorlayer.readthedocs.io/en/stable/modules/nlp.html#tensorlayer.nlp.SimpleVocabulary) for word-to-id mapping only)

```python
>>> tl.nlp.create_vocab(processed_capts, word_counts_output_file='vocab.txt', min_word_count=1)
... [TL] Creating vocabulary.
...   Total words: 8
...   Words in vocabulary: 8
...   Wrote vocabulary file: vocab.txt
```

* Finally use [tl.nlp.Vocabulary](https://tensorlayer.readthedocs.io/en/stable/modules/nlp.html#vocabulary-class) to create a vocabulary object from the txt vocabulary file created by `tl.nlp.create_vocab`

```python
>>> vocab = tl.nlp.Vocabulary('vocab.txt', start_word="<S>", end_word="</S>", unk_word="<UNK>")
... INFO:tensorflow:Initializing vocabulary from file: vocab.txt
... [TL] Vocabulary from vocab.txt : <S> </S> <UNK>
...   vocabulary with 10 words (includes start_word, end_word, unk_word)
...     start_id: 2
...     end_id: 3
...     unk_id: 9
...     pad_id: 0
```

Then you can map words to IDs, or vice versa, as follows:
```python
>>> vocab.id_to_word(2)
... 'one'
>>> vocab.word_to_id('one')
... 2
>>> vocab.id_to_word(100)
... '<UNK>'
>>> vocab.word_to_id('hahahaha')
... 9
```

* More pre-processing functions for sentences are in [tl.prepro](https://tensorlayer.readthedocs.io/en/stable/modules/prepro.html#sequence) and [tl.nlp](https://tensorlayer.readthedocs.io/en/stable/modules/nlp.html)

## 11. Dynamic RNN and sequence length
* Apply zero padding to a batch of tokenized sentences as follows:
```python
>>> sequences = [[1,1,1,1,1],[2,2,2],[3,3]]
>>> sequences = tl.prepro.pad_sequences(sequences, maxlen=None,
...     dtype='int32', padding='post', truncating='pre', value=0.)
... [[1 1 1 1 1]
...  [2 2 2 0 0]
...  [3 3 0 0 0]]
```

* Use [tl.layers.retrieve_seq_length_op2](https://tensorlayer.readthedocs.io/en/stable/modules/layers.html#tensorlayer.layers.retrieve_seq_length_op2) to automatically compute the sequence length from a placeholder, and feed it to the `sequence_length` argument of [DynamicRNNLayer](https://tensorlayer.readthedocs.io/en/stable/modules/layers.html#dynamic-rnn-layer); a sketch of the wiring follows the snippet below

```python
>>> data = [[1,2,0,0,0], [1,2,3,0,0], [1,2,6,1,0]]
>>> o = tl.layers.retrieve_seq_length_op2(data)
>>> sess = tf.InteractiveSession()
>>> tl.layers.initialize_global_variables(sess)
>>> print(o.eval())
... [2 3 4]
```
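
A hypothetical sketch of the wiring (vocabulary size, embedding size, and hidden units are illustrative):

```python
x = tf.placeholder(tf.int64, [None, None], name='x')  # batch of zero-padded token ids
net = tl.layers.EmbeddingInputlayer(x, vocabulary_size=1000, embedding_size=50, name='embed')
net = tl.layers.DynamicRNNLayer(net,
    cell_fn=tf.contrib.rnn.BasicLSTMCell,
    n_hidden=64,
    # skip the zero padding when unrolling the RNN
    sequence_length=tl.layers.retrieve_seq_length_op2(x),
    return_last=True,
    name='dynamic_rnn')
```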

* Other methods: [issues18](https://github.com/tensorlayer/tensorlayer/issues/18)

## 12. Save models

- 1. [tl.files.save_npz](https://tensorlayer.readthedocs.io/en/stable/modules/files.html#save-network-into-list-npz) saves all model parameters (weights) into a list of arrays; restore using `tl.files.load_and_assign_npz`
- 2. [tl.files.save_npz_dict](https://tensorlayer.readthedocs.io/en/stable/modules/files.html#save-network-into-dict-npz) saves all model parameters (weights) into a dictionary of arrays, keyed by parameter name; restore using `tl.files.load_and_assign_npz_dict`
- 3. [tl.files.save_ckpt](https://tensorlayer.readthedocs.io/en/stable/modules/files.html#save-network-into-ckpt) saves all model parameters (weights) into a TensorFlow ckpt file; restore using `tl.files.load_ckpt`
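
A minimal sketch of the first option, assuming `network` is a built TL network and `sess` holds its trained parameters:

```python
# save all parameters into model.npz
tl.files.save_npz(network.all_params, name='model.npz', sess=sess)
# later, after rebuilding the same graph, load them back
tl.files.load_and_assign_npz(sess=sess, name='model.npz', network=network)
```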

## 13. Compatibility with other TF wrappers
TL can interact with other TF wrappers, which means that if you find some code or model implemented with another wrapper, you can just use it!
* Other TensorFlow layer implementations can be connected to TensorLayer via [LambdaLayer](https://tensorlayer.readthedocs.io/en/stable/modules/layers.html#lambda-layers); see the example [here](https://github.com/tensorlayer/tensorlayer/tree/master/examples/keras_tfslim)
* TF-Slim to TL: [SlimNetsLayer](https://tensorlayer.readthedocs.io/en/stable/modules/layers.html#tensorlayer.layers.SlimNetsLayer) (you can use all of Google's pre-trained convolutional models with this layer!)

## 14. Others
* The default `decay` of `BatchNormLayer` is 0.9; set it to 0.999 for large datasets.
* A Matplotlib issue can arise when importing TensorLayer [issues](https://github.com/tensorlayer/tensorlayer/issues/79); see the [FAQ](https://tensorlayer.readthedocs.io/en/stable/user/faq.html#visualization)

## Useful links
* [Awesome-TensorLayer](https://github.com/tensorlayer/awesome-tensorlayer) for all examples
* TL official sites: [Docs](https://tensorlayer.readthedocs.io), [Chinese Docs](https://tensorlayercn.readthedocs.io), [Github](https://github.com/tensorlayer/tensorlayer)
* [Learning Deep Learning with TF and TL](https://github.com/wagamamaz/tensorflow-tutorial)
* Follow [zsdonghao](https://github.com/zsdonghao) for further examples

## Author
- Zhang Rui
- Hao Dong