├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   ├── dependabot.yml
│   └── workflows
│       └── ci.yml
├── .gitignore
├── LICENSE
├── README.md
├── misc
│   ├── Adding_Task.png
│   ├── Copy_Memory_Task.png
│   ├── Dilated_Conv.png
│   ├── Non_Causal.png
│   ├── Receptive_Field_Formula.png
│   └── Sequential_MNIST_Task.png
├── requirements.txt
├── setup.py
├── tasks
│   ├── adding_problem
│   │   ├── README.md
│   │   ├── main.py
│   │   └── utils.py
│   ├── bidirect_tcn.py
│   ├── copy_memory
│   │   ├── README.md
│   │   ├── main.py
│   │   └── utils.py
│   ├── est_receptive_field.py
│   ├── exchange_rate
│   │   ├── demo.ipynb
│   │   ├── exchange_rate.txt
│   │   ├── main.py
│   │   └── utils.py
│   ├── imdb_tcn.py
│   ├── many_to_many.py
│   ├── mnist_pixel
│   │   ├── main.py
│   │   └── utils.py
│   ├── monthly-milk-production-pounds-p.csv
│   ├── multi_length_sequences.py
│   ├── non_causal.py
│   ├── plot_tcn_model.py
│   ├── receptive-field
│   │   ├── main.py
│   │   ├── run.sh
│   │   └── utils.py
│   ├── save_reload_sequential_model.py
│   ├── sequential.py
│   ├── tcn_call_test.py
│   ├── tcn_tensorboard.py
│   ├── time_series_forecasting.py
│   ├── video_classification.py
│   ├── visualise_activations.py
│   └── word_ptb
│       ├── README.md
│       ├── data
│       │   ├── README
│       │   ├── ptb.test.txt
│       │   ├── ptb.train.txt
│       │   └── ptb.valid.txt
│       ├── plot.py
│       ├── result.png
│       ├── run.sh
│       └── train.py
├── tcn
│   ├── __init__.py
│   └── tcn.py
└── tox.ini
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: [philipperemy]
2 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A description of what the bug is.
12 |
13 | **Paste a snippet**
14 | The snippet has to be concise and standalone. Everyone should be able to reproduce the bug by copying and running the snippet.
15 |
16 | **Dependencies**
17 | Specify which version of TensorFlow you are running.
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 |   - package-ecosystem: pip
4 |     directory: "/"
5 |     schedule:
6 |       interval: daily
7 |       time: "20:00"
8 |     open-pull-requests-limit: 10
9 |     ignore:
10 |       - dependency-name: gast
11 |         versions:
12 |           - "> 0.2.2, < 1"
13 |       - dependency-name: numpy
14 |         versions:
15 |           - "> 1.16.2, < 2"
16 |       - dependency-name: numpy
17 |         versions:
18 |           - "> 1.17.3, < 1.18"
19 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: Keras TCN CI
2 |
3 | on: [ push, pull_request ]
4 |
5 | jobs:
6 |   build:
7 |
8 |     runs-on: ubuntu-latest
9 |     strategy:
10 |       max-parallel: 4
11 |       matrix:
12 |         python-version: [ "3.10" ]
13 |
14 |     steps:
15 |       - uses: actions/checkout@v4
16 |       - name: Set up Python ${{ matrix.python-version }}
17 |         uses: actions/setup-python@v4
18 |         with:
19 |           python-version: ${{ matrix.python-version }}
20 |       - name: Install dependencies
21 |         run: |
22 |           sudo apt-get install -y graphviz
23 |           python -m pip install --upgrade pip
24 |           pip install tox flake8
25 |       - name: Lint with flake8
26 |         run: |
27 |           # stop the build if there are Python syntax errors or undefined names
28 |           flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
29 |           # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
30 |           flake8 . --count --exit-zero --max-complexity 10 --max-line-length 127 --statistics
31 |       - name: Test with tox
32 |         run: |
33 |           tox
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea/
6 | .DS_Store
7 |
8 | *.h5
9 | *.tsv
10 | *.tar.gz
11 | *out*
12 | credentials.json
13 |
14 | *.json
15 |
16 | nohup.out
17 | *.out
18 | *.txt
19 |
20 | # C extensions
21 | *.so
22 |
23 | # Distribution / packaging
24 | .Python
25 | env/
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | *.egg-info/
39 | .installed.cfg
40 | *.egg
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | .hypothesis/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # celery beat schedule file
91 | celerybeat-schedule
92 |
93 | # SageMath parsed files
94 | *.sage.py
95 |
96 | # dotenv
97 | .env
98 |
99 | # virtualenv
100 | .venv
101 | venv/
102 | ENV/
103 |
104 | # Spyder project settings
105 | .spyderproject
106 | .spyproject
107 |
108 | # Rope project settings
109 | .ropeproject
110 |
111 | # mkdocs documentation
112 | /site
113 |
114 | # mypy
115 | .mypy_cache/
116 |
117 | !/tasks/exchange_rate/exchange_rate.txt
118 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Philippe Rémy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Keras TCN
2 |
3 | *Keras Temporal Convolutional Network*. [[paper](https://arxiv.org/abs/1803.01271)]
4 |
5 | Tested with TensorFlow 2.9, 2.10, 2.11, 2.12, 2.13, 2.14, 2.15, 2.16, 2.17, 2.18 and 2.19 (Mar 13, 2025).
6 |
7 | For a fully working example of Keras TCN using **R Language**, [browse here](https://github.com/philipperemy/keras-tcn/issues/246).
8 |
9 | [![Downloads](https://pepy.tech/badge/keras-tcn)](https://pepy.tech/project/keras-tcn)
10 | [![Downloads](https://pepy.tech/badge/keras-tcn/month)](https://pepy.tech/project/keras-tcn)
11 | ![Keras TCN CI](https://github.com/philipperemy/keras-tcn/workflows/Keras%20TCN%20CI/badge.svg)
12 | ```bash
13 | pip install keras-tcn
14 | ```
15 |
16 | For [macOS users](https://developer.apple.com/metal/tensorflow-plugin/): to use the GPU, run `pip install tensorflow-metal`.
17 |
18 | ## Why TCN (Temporal Convolutional Network) instead of LSTM/GRU?
19 |
20 | - TCNs exhibit longer memory than recurrent architectures with the same capacity.
21 | - They perform better than LSTM/GRU on long time series (Seq. MNIST, Adding Problem, Copy Memory, Word-level PTB...).
22 | - They offer parallelism (convolutional layers), a flexible receptive field size (how far back the model can see), and stable gradients (no vanishing gradients as with backpropagation through time).
23 |
24 | ![Dilated causal convolutions](misc/Dilated_Conv.png)
25 |
26 | *Visualization of a stack of dilated causal convolutional layers (Wavenet, 2016)*
27 |
28 |
29 | ## TCN Layer
30 |
31 | ### TCN Class
32 |
33 | ```python
34 | TCN(
35 |     nb_filters=64,
36 |     kernel_size=3,
37 |     nb_stacks=1,
38 |     dilations=(1, 2, 4, 8, 16, 32),
39 |     padding='causal',
40 |     use_skip_connections=True,
41 |     dropout_rate=0.0,
42 |     return_sequences=False,
43 |     activation='relu',
44 |     kernel_initializer='he_normal',
45 |     use_batch_norm=False,
46 |     use_layer_norm=False,
47 |     go_backwards=False,
48 |     return_state=False,
49 |     **kwargs
50 | )
51 | ```
52 |
53 | ### Arguments
54 |
55 | - `nb_filters`: Integer. The number of filters to use in the convolutional layers. Would be similar to `units` for LSTM. Can be a list.
56 | - `kernel_size`: Integer. The size of the kernel to use in each convolutional layer.
57 | - `dilations`: List/Tuple. A dilation list. Example: `[1, 2, 4, 8, 16, 32, 64]`.
58 | - `nb_stacks`: Integer. The number of stacks of residual blocks to use.
59 | - `padding`: String. The padding to use in the convolutions. 'causal' for a causal network (as in the original implementation) and 'same' for a non-causal network.
60 | - `use_skip_connections`: Boolean. Whether to add skip connections from the input to each residual block.
61 | - `return_sequences`: Boolean. Whether to return the last output in the output sequence, or the full sequence.
62 | - `dropout_rate`: Float between 0 and 1. Fraction of the input units to drop.
63 | - `activation`: The activation used in the residual blocks: `o = activation(x + F(x))`.
64 | - `kernel_initializer`: Initializer for the kernel weights matrix (Conv1D).
65 | - `use_batch_norm`: Whether to use batch normalization in the residual layers or not.
66 | - `use_layer_norm`: Whether to use layer normalization in the residual layers or not.
67 | - `go_backwards`: Boolean (default False). If True, process the input sequence backwards and return the reversed sequence.
68 | - `return_state`: Boolean. Whether to return the last state in addition to the output. Default: False.
69 | - `kwargs`: Any other arguments for configuring the parent class `Layer`, for example `name` (a string) to name the layer. Use unique names when using multiple TCN layers in one model.
70 |
71 | ### Input shape
72 |
73 | 3D tensor with shape `(batch_size, timesteps, input_dim)`.
74 |
75 | `timesteps` can be `None`. This can be useful if each sequence is of a different length: [Multiple Length Sequence Example](tasks/multi_length_sequences.py).
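
A minimal sketch of a model built on variable-length sequences (the two input features and the filter count are illustrative):

```python
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

from tcn import TCN

i = Input(shape=(None, 2))  # timesteps=None: sequences may have any length.
o = TCN(nb_filters=16, return_sequences=False)(i)  # -> (batch_size, 16)
o = Dense(1)(o)
model = Model(inputs=i, outputs=o)
```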
76 |
77 | ### Output shape
78 |
79 | - if `return_sequences=True`: 3D tensor with shape `(batch_size, timesteps, nb_filters)`.
80 | - if `return_sequences=False`: 2D tensor with shape `(batch_size, nb_filters)`.
81 |
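As a quick sanity check of these shapes, here is a minimal sketch (the sizes are arbitrary):

```python
import numpy as np
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

from tcn import TCN

x = np.random.rand(32, 100, 1)  # (batch_size, timesteps, input_dim)

model = Sequential([
    TCN(nb_filters=64, return_sequences=False, input_shape=(100, 1)),  # -> (32, 64)
    Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy')
print(model.predict(x).shape)  # (32, 1)
```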
82 |
83 | ### How do I choose the correct set of parameters to configure my TCN layer?
84 |
85 | Here are some of my notes regarding my experience using TCN (a starting configuration following these notes is sketched after this list):
86 |
87 | - `nb_filters`: Present in any ConvNet architecture. It is linked to the predictive power of the model and affects the size of your network. The more, the better unless you start to overfit. It's similar to the number of units in an LSTM/GRU architecture too.
88 | - `kernel_size`: Controls the spatial area/volume considered in the convolutional ops. Good values are usually between 2 and 8. If you think your sequence heavily depends on t-1 and t-2, but less on the rest, then choose a kernel size of 2/3. For NLP tasks, we prefer bigger kernel sizes. A large kernel size will make your network much bigger.
89 | - `dilations`: It controls how deep your TCN layer is. Usually, consider a list of powers of two. You can guess how many dilations you need by matching the receptive field (of the TCN) with the length of the patterns in your sequence. For example, if your input sequence is periodic, you might want to have multiples of that period as dilations.
90 | - `nb_stacks`: Not very useful unless your sequences are very long (like waveforms with hundreds of thousands of time steps).
91 | - `padding`: I have only used `causal` since TCN stands for Temporal Convolutional Network. Causal padding prevents information leakage from the future.
92 | - `use_skip_connections`: Skip connections connect layers, similarly to DenseNet. It helps the gradients flow. Unless you experience a drop in performance, you should always activate it.
93 | - `return_sequences`: Same as the one present in the LSTM layer. Refer to the Keras doc for this parameter.
94 | - `dropout_rate`: Similar to `recurrent_dropout` for the LSTM layer. I usually don't use it much, or I set it to a low value like `0.05`.
95 | - `activation`: Leave it to default. I have never changed it.
96 | - `kernel_initializer`: If the training of the TCN gets stuck, it might be worth changing this parameter. For example: `glorot_uniform`.
97 |
98 | - `use_batch_norm`, `use_layer_norm`: Use normalization if your network is big enough and the task contains enough data. I usually prefer using `use_layer_norm`, but you can try and see which one works the best.
99 |
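Putting these notes together, a reasonable starting point might look like the sketch below (the values are illustrative, not prescriptive):

```python
from tcn import TCN

tcn_layer = TCN(
    nb_filters=64,                   # enough capacity; increase if underfitting
    kernel_size=3,                   # t-1, t-2 and t-3 matter at each level
    dilations=(1, 2, 4, 8, 16, 32),  # powers of two, sized to the sequence length
    padding='causal',                # no information leakage from the future
    use_skip_connections=True,
    dropout_rate=0.05,               # light regularization
    use_layer_norm=True
)
print(tcn_layer.receptive_field)     # how far back this configuration can see
```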
100 |
101 | ### Receptive field
102 |
103 | The receptive field is defined as the maximum number of steps back in time from the current sample at time T that a filter of the TCN (across blocks, layers and stacks) can see (the effective history), plus 1. The receptive field of the TCN can be calculated using the formula:
104 |
105 | ![Receptive field formula](misc/Receptive_Field_Formula.png)
106 |
107 | Rfield = 1 + 2⋅(K−1)⋅Nstack⋅Σi di
108 | where Nstack is the number of stacks, Nb is the number of residual blocks per stack, d is a vector containing the dilations of each residual block in each stack, and K is the kernel size. The 2 is there because there are two `Conv1d` layers in a single `ResidualBlock`.
109 |
110 | Ideally, you want your receptive field to be bigger than the largest length of your input sequences. If you pass a sequence longer than your receptive field into the model, any extra values (further back in the sequence) will be replaced with zeros.
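
A quick sketch of the formula in Python, using the notation above:

```python
def receptive_field(kernel_size, dilations, nb_stacks=1):
    # Rfield = 1 + 2 * (K - 1) * Nstack * sum(d). The factor 2 is there
    # because each residual block contains two Conv1D layers.
    return 1 + 2 * (kernel_size - 1) * nb_stacks * sum(dilations)

print(receptive_field(kernel_size=3, dilations=(1, 2, 4, 8, 16, 32)))  # 253
```

The `TCN` layer also exposes this value directly as `tcn_layer.receptive_field`.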
111 |
112 | #### Examples
113 |
114 | *NOTE*: Unlike the TCN, the example figures only include a single `Conv1d` per layer, so the formula becomes Rfield = 1 + (K−1)⋅Nstack⋅Σi di (without the factor 2).
115 |
116 | - If a dilated conv net has only one stack of residual blocks with a kernel size of `2` and dilations `[1, 2, 4, 8]`, its receptive field is `16`. The image below illustrates it:
117 |
118 |
119 |
120 | ks = 2, dilations = [1, 2, 4, 8], 1 block
121 |
122 |
123 | - If a dilated conv net has 2 stacks of residual blocks, you would have the situation below, that is, an increase in the receptive field up to 31:
124 |
125 |
126 |
127 | ks = 2, dilations = [1, 2, 4, 8], 2 blocks
128 |
129 |
130 |
131 | - If we increased the number of stacks to 3, the size of the receptive field would increase again, such as below:
132 |
133 |
134 |
135 | ks = 2, dilations = [1, 2, 4, 8], 3 blocks
136 |
137 |
138 |
139 | ### Non-causal TCN
140 |
141 | Making the TCN architecture non-causal allows it to take the future into consideration when making predictions, as shown in the figure below.
142 |
143 | However, it is then no longer suitable for real-time applications.
144 |
145 | ![Non-Causal TCN](misc/Non_Causal.png)
146 |
147 | *Non-Causal TCN - ks = 3, dilations = [1, 2, 4, 8], 1 block*
148 |
149 |
150 | To use a non-causal TCN, specify `padding='same'` when initializing the TCN layers (the layer accepts only `'causal'` and `'same'`).
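
For example:

```python
from tcn import TCN

non_causal_tcn = TCN(padding='same')  # each output can see past and future timesteps
```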
151 |
152 | ## Run
153 |
154 | Once `keras-tcn` is installed as a package, you can take a glimpse of what is possible with TCNs. Some example tasks are available in the repository:
155 |
156 | ```bash
157 | cd tasks/adding_problem/
158 | python main.py # run adding problem task
159 |
160 | cd tasks/copy_memory/
161 | python main.py # run copy memory task
162 |
163 | cd tasks/mnist_pixel/
164 | python main.py # run sequential mnist pixel task
165 | ```
166 |
167 | Reproducible results are possible on (NVIDIA) GPUs using the [tensorflow-determinism](https://github.com/NVIDIA/tensorflow-determinism) library. It was tested with keras-tcn by @lingdoc.
168 |
169 | ## Tasks
170 |
171 | ### Word PTB
172 |
173 | Language modeling remains one of the primary applications of recurrent networks. In this example, we show that TCN can beat LSTM on the [WordPTB](tasks/word_ptb/README.md) task, without too much tuning.
174 |
175 |
176 | ![TCN vs LSTM on WordPTB](tasks/word_ptb/result.png)
177 | *TCN vs LSTM (comparable number of weights)*
178 |
179 |
180 | ### Adding Task
181 |
182 | The task consists of feeding a large array of decimal numbers to the network, along with a boolean array of the same length. The objective is to sum the two decimals at the positions where the boolean array contains the two 1s. For example, with values `[0.3, 0.9, 0.2, 0.5]` and markers `[0, 1, 0, 1]`, the expected output is `0.9 + 0.5 = 1.4`.
183 |
184 | #### Explanation
185 |
186 |
187 | ![Adding Problem](misc/Adding_Task.png)
188 | *Adding Problem Task*
189 |
190 |
191 | #### Implementation results
192 |
193 | ```
194 | 782/782 [==============================] - 154s 197ms/step - loss: 0.8437 - val_loss: 0.1883
195 | 782/782 [==============================] - 154s 196ms/step - loss: 0.0702 - val_loss: 0.0111
196 | [...]
197 | 782/782 [==============================] - 152s 194ms/step - loss: 6.9630e-04 - val_loss: 3.7180e-04
198 | ```
199 |
200 | ### Copy Memory Task
201 |
202 | The copy memory task consists of a very large array:
203 | - At the beginning, there's the vector x of length N. This is the vector to copy.
204 | - At the end, N+1 9s are present. The first 9 is seen as a delimiter.
205 | - In the middle, only 0s are there.
206 |
207 | The idea is to copy the content of the vector x to the end of the large array. The task is made sufficiently complex by increasing the number of 0s in the middle. For example, with N = 3 and x = [5, 2, 7], the input looks like [5, 2, 7, 0, ..., 0, 9, 9, 9, 9] and the expected output ends with [5, 2, 7].
208 |
209 | #### Explanation
210 |
211 |
212 | ![Copy Memory](misc/Copy_Memory_Task.png)
213 | *Copy Memory Task*
214 |
215 |
216 | #### Implementation results (first epochs)
217 |
218 | ```
219 | 118/118 [==============================] - 17s 143ms/step - loss: 1.1732 - accuracy: 0.6725 - val_loss: 0.1119 - val_accuracy: 0.9796
220 | [...]
221 | 118/118 [==============================] - 15s 125ms/step - loss: 0.0268 - accuracy: 0.9885 - val_loss: 0.0206 - val_accuracy: 0.9908
222 | 118/118 [==============================] - 15s 125ms/step - loss: 0.0228 - accuracy: 0.9900 - val_loss: 0.0169 - val_accuracy: 0.9933
223 | ```
224 |
225 | ### Sequential MNIST
226 |
227 | #### Explanation
228 |
229 | The idea here is to consider MNIST images as 1-D sequences and feed them to the network. This task is particularly hard because each sequence is 28*28 = 784 elements long. In order to classify correctly, the network has to remember the entire sequence. Standard LSTMs are unable to perform well on this task.
230 |
231 |
232 | ![Sequential MNIST](misc/Sequential_MNIST_Task.png)
233 | *Sequential MNIST*
234 |
235 |
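A minimal sketch of the flattening step (using `tensorflow.keras.datasets.mnist`):

```python
from tensorflow.keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Each 28x28 image becomes a sequence of 784 scalar timesteps.
x_train = x_train.reshape(-1, 784, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784, 1).astype('float32') / 255.0
```
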
236 | #### Implementation results
237 |
238 | ```
239 | 1875/1875 [==============================] - 46s 25ms/step - loss: 0.0949 - accuracy: 0.9706 - val_loss: 0.0763 - val_accuracy: 0.9756
240 | 1875/1875 [==============================] - 46s 25ms/step - loss: 0.0831 - accuracy: 0.9743 - val_loss: 0.0656 - val_accuracy: 0.9807
241 | [...]
242 | 1875/1875 [==============================] - 46s 25ms/step - loss: 0.0486 - accuracy: 0.9840 - val_loss: 0.0572 - val_accuracy: 0.9832
243 | 1875/1875 [==============================] - 46s 25ms/step - loss: 0.0453 - accuracy: 0.9858 - val_loss: 0.0424 - val_accuracy: 0.9862
244 | ```
245 |
246 | ## R Language
247 |
248 | For a fully working example of Keras TCN using **R Language**, [browse here](https://github.com/philipperemy/keras-tcn/issues/246).
249 |
250 | ## References
251 | - https://github.com/locuslab/TCN/ (TCN for PyTorch)
252 | - https://arxiv.org/pdf/1803.01271 (An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling)
253 | - https://arxiv.org/pdf/1609.03499 (Original WaveNet paper)
254 | - https://github.com/Baichenjia/Tensorflow-TCN (TensorFlow Eager implementation of TCNs)
256 |
257 | ## Citation
258 |
259 | ```
260 | @misc{KerasTCN,
261 | author = {Philippe Remy},
262 | title = {Temporal Convolutional Networks for Keras},
263 | year = {2020},
264 | publisher = {GitHub},
265 | journal = {GitHub repository},
266 | howpublished = {\url{https://github.com/philipperemy/keras-tcn}},
267 | }
268 | ```
269 |
270 | ## Contributors
271 |
272 |
273 |
274 |
275 |
--------------------------------------------------------------------------------
/misc/Adding_Task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/misc/Adding_Task.png
--------------------------------------------------------------------------------
/misc/Copy_Memory_Task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/misc/Copy_Memory_Task.png
--------------------------------------------------------------------------------
/misc/Dilated_Conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/misc/Dilated_Conv.png
--------------------------------------------------------------------------------
/misc/Non_Causal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/misc/Non_Causal.png
--------------------------------------------------------------------------------
/misc/Receptive_Field_Formula.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/misc/Receptive_Field_Formula.png
--------------------------------------------------------------------------------
/misc/Sequential_MNIST_Task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/misc/Sequential_MNIST_Task.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | keract
3 | matplotlib
4 | pydot
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 |
4 | from setuptools import setup
5 |
6 | tensorflow = 'tensorflow'
7 | if platform.system() == 'Darwin' and platform.processor() == 'arm':
8 |     tensorflow = 'tensorflow-macos'
9 |     # https://github.com/grpc/grpc/issues/25082
10 |     os.environ['GRPC_PYTHON_BUILD_SYSTEM_OPENSSL'] = '1'
11 |     os.environ['GRPC_PYTHON_BUILD_SYSTEM_ZLIB'] = '1'
12 |
13 | install_requires = ['numpy', tensorflow]
14 |
15 | setup(
16 |     name='keras-tcn',
17 |     version='3.5.6',
18 |     description='Keras TCN',
19 |     author='Philippe Remy',
20 |     license_files=['MIT'],
21 |     long_description_content_type='text/markdown',
22 |     long_description=open('README.md').read(),
23 |     packages=['tcn'],
24 |     install_requires=install_requires
25 | )
26 |
--------------------------------------------------------------------------------
/tasks/adding_problem/README.md:
--------------------------------------------------------------------------------
1 | ## The Adding Problem
2 |
3 | ### Overview
4 |
5 | In this task, each input consists of a length-T sequence of depth 2. All values in the first
6 | dimension are chosen randomly in [0, 1]. The second dimension consists of all zeros except for
7 | two elements, which are marked by 1. The objective is to sum the two random values whose second
8 | dimensions are marked by 1. One can think of this as computing the dot product of the two dimensions.
9 |
10 | Simply predicting the sum to be 1 should give an MSE of about 0.1767.
11 |
12 | ### Data Generation
13 |
14 | See `data_generator` in `utils.py`.
15 |
16 | ### Note
17 |
18 | Because a TCN's receptive field depends on the depth of the network and the filter size, we need
19 | to make sure the model we use can cover the sequence length T.
20 |
21 | From: https://github.com/locuslab/TCN/
--------------------------------------------------------------------------------
/tasks/adding_problem/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tensorflow.keras.callbacks import Callback
3 | from utils import data_generator
4 |
5 | from tcn import compiled_tcn, tcn_full_summary
6 |
7 | x_train, y_train = data_generator(n=200000, seq_length=600)
8 | x_test, y_test = data_generator(n=40000, seq_length=600)
9 |
10 |
11 | class PrintSomeValues(Callback):
12 |
13 |     def on_epoch_begin(self, epoch, logs=None):
14 |         print('y_true, y_pred')
15 |         print(np.hstack([y_test[:5], self.model.predict(x_test[:5])]))
16 |
17 |
18 | def run_task():
19 |     model = compiled_tcn(
20 |         return_sequences=False,
21 |         num_feat=x_train.shape[2],
22 |         num_classes=0,
23 |         nb_filters=24,
24 |         kernel_size=8,
25 |         dilations=[2 ** i for i in range(9)],
26 |         nb_stacks=1,
27 |         max_len=x_train.shape[1],
28 |         use_skip_connections=False,
29 |         # use_weight_norm=True,
30 |         regression=True,
31 |         dropout_rate=0
32 |     )
33 |
34 |     tcn_full_summary(model)
35 |     model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=15,
36 |               batch_size=256, callbacks=[PrintSomeValues()])
37 |
38 |
39 | if __name__ == '__main__':
40 |     run_task()
41 |
--------------------------------------------------------------------------------
/tasks/adding_problem/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def data_generator(n, seq_length):
5 |     """
6 |     Args:
7 |         seq_length: Length of the adding problem data
8 |         n: # of data in the set
9 |     """
10 |     x_num = np.random.uniform(0, 1, (n, 1, seq_length))
11 |     x_mask = np.zeros([n, 1, seq_length])
12 |     y = np.zeros([n, 1])
13 |     for i in range(n):
14 |         positions = np.random.choice(seq_length, size=2, replace=False)
15 |         x_mask[i, 0, positions[0]] = 1
16 |         x_mask[i, 0, positions[1]] = 1
17 |         y[i, 0] = x_num[i, 0, positions[0]] + x_num[i, 0, positions[1]]
18 |     x = np.concatenate((x_num, x_mask), axis=1)
19 |     x = np.transpose(x, (0, 2, 1))
20 |     return x, y
21 |
22 |
23 | if __name__ == '__main__':
24 |     print(data_generator(n=20, seq_length=10))
25 |
--------------------------------------------------------------------------------
/tasks/bidirect_tcn.py:
--------------------------------------------------------------------------------
1 | """
2 | #Trains a TCN on the IMDB sentiment classification task.
3 | Output after 1 epochs on CPU: ~0.8611
4 | Time per epoch on CPU (Core i7): ~64s.
5 | Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
6 | """
7 | import numpy as np
8 | from tensorflow.keras import Input, Model
9 | from tensorflow.keras.datasets import imdb
10 | from tensorflow.keras.layers import Dense, Embedding, Bidirectional
11 | from tensorflow.keras.preprocessing import sequence
12 |
13 | from tcn import TCN
14 |
15 | max_features = 20000
16 | # cut texts after this number of words
17 | # (among top max_features most common words)
18 | maxlen = 100
19 | batch_size = 32
20 |
21 | print('Loading data...')
22 | (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
23 | print(len(x_train), 'train sequences')
24 | print(len(x_test), 'test sequences')
25 |
26 | print('Pad sequences (samples x time)')
27 | x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
28 | x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
29 | print('x_train shape:', x_train.shape)
30 | print('x_test shape:', x_test.shape)
31 | y_train = np.array(y_train)
32 | y_test = np.array(y_test)
33 |
34 | inputs = Input(shape=(None,), dtype="int32")
35 | x = Embedding(max_features, 128)(inputs)
36 | x = Bidirectional(TCN(64))(x)
37 | # Add a classifier
38 | outputs = Dense(1, activation="sigmoid")(x)
39 | model = Model(inputs, outputs)
40 | model.summary()
41 |
42 | print(f'Backward TCN receptive field: {model.layers[2].backward_layer.receptive_field}.')
43 | print(f'Forward TCN receptive field: {model.layers[2].forward_layer.receptive_field}.')
44 |
45 | model.summary()
46 | model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
47 |
48 | print('Train...')
49 | model.fit(
50 |     x_train, y_train,
51 |     batch_size=batch_size,
52 |     validation_data=(x_test, y_test)
53 | )
54 |
--------------------------------------------------------------------------------
/tasks/copy_memory/README.md:
--------------------------------------------------------------------------------
1 | ## Copying Memory Task
2 |
3 | ### Overview
4 |
5 | In this task, each input sequence has length T+20. The first 10 values are chosen randomly
6 | among the digits 1-8, with the rest being all zeros, except for the last 11 entries that are
7 | filled with the digit ‘9’ (the first ‘9’ is a delimiter). The goal is to generate an output
8 | of the same length that is zero everywhere, except for the last 10 values after the delimiter, where
9 | the model is expected to repeat the 10 values it encountered at the start of the input.
10 |
11 | ### Data Generation
12 |
13 | See `data_generator` in `utils.py`.
14 |
15 | ### Note
16 |
17 | - Because a TCN's receptive field depends on the depth of the network and the filter size, we need
18 |   to make sure the model we use can cover the sequence length T+20.
19 |
20 | - Using the `--seq_len` flag, one can change the # of values to recall (the typical setup is 10).
21 |
22 | From: https://github.com/locuslab/TCN/
23 |
--------------------------------------------------------------------------------
/tasks/copy_memory/main.py:
--------------------------------------------------------------------------------
1 | from uuid import uuid4
2 |
3 | import numpy as np
4 | from tensorflow.keras.callbacks import Callback
5 |
6 | from tcn import compiled_tcn
7 | from utils import data_generator
8 |
9 | x_train, y_train = data_generator(601, 10, 30000)
10 | x_test, y_test = data_generator(601, 10, 6000)
11 |
12 |
13 | class PrintSomeValues(Callback):
14 |
15 |     def on_epoch_begin(self, epoch, logs=None):
16 |         print('y_true')
17 |         print(np.array(y_test[:5, -10:].squeeze(), dtype=int))
18 |         print('y_pred')
19 |         print(self.model.predict(x_test[:5])[:, -10:].argmax(axis=-1))
20 |
21 |
22 | def run_task():
23 |     model = compiled_tcn(num_feat=1,
24 |                          num_classes=10,
25 |                          nb_filters=10,
26 |                          kernel_size=8,
27 |                          dilations=[2 ** i for i in range(9)],
28 |                          nb_stacks=1,
29 |                          max_len=x_train[0:1].shape[1],
30 |                          use_skip_connections=True,
31 |                          opt='rmsprop',
32 |                          lr=5e-4,
33 |                          # use_weight_norm=True,
34 |                          return_sequences=True)
35 |
36 |     print(f'x_train.shape = {x_train.shape}')
37 |     print(f'y_train.shape = {y_train.shape}')
38 |
39 |     psv = PrintSomeValues()
40 |
41 |     # Using sparse softmax.
42 |     # http://chappers.github.io/web%20micro%20log/2017/01/26/quick-models-in-keras/
43 |     model.summary()
44 |
45 |     model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=100,
46 |               callbacks=[psv], batch_size=256)
47 |
48 |     test_acc = model.evaluate(x=x_test, y=y_test)[1]  # accuracy.
49 |     with open(f'copy_memory_{str(uuid4())[0:5]}.txt', 'w') as w:
50 |         w.write(str(test_acc) + '\n')
51 |
52 |
53 | if __name__ == '__main__':
54 |     run_task()
55 |
--------------------------------------------------------------------------------
/tasks/copy_memory/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def data_generator(t, mem_length, b_size):
5 |     """
6 |     Generate data for the copying memory task
7 |     :param t: The total blank time length
8 |     :param mem_length: The length of the memory to be recalled
9 |     :param b_size: The batch size
10 |     :return: Input and target data tensor
11 |     """
12 |     seq = np.array(np.random.randint(1, 9, size=(b_size, mem_length)), dtype=float)
13 |     zeros = np.zeros((b_size, t))
14 |     marker = 9 * np.ones((b_size, mem_length + 1))
15 |     placeholders = np.zeros((b_size, mem_length))
16 |
17 |     x = np.array(np.concatenate((seq, zeros[:, :-1], marker), 1), dtype=int)
18 |     y = np.array(np.concatenate((placeholders, zeros, seq), 1), dtype=int)
19 |     return np.expand_dims(x, axis=2).astype(np.float32), np.expand_dims(y, axis=2).astype(np.float32)
20 |
21 |
22 | if __name__ == '__main__':
23 |     print(data_generator(t=601, mem_length=10, b_size=1)[0].flatten())
24 |
--------------------------------------------------------------------------------
/tasks/est_receptive_field.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tensorflow.keras.layers import Dense
3 | from tensorflow.keras.models import Sequential
4 |
5 | from tcn import TCN
6 |
7 |
8 | # if time_steps > tcn_layer.receptive_field, then we should not
9 | # be able to solve this task.
10 |
11 |
12 | def get_x_y(time_steps, size=1000):
13 |     pos_indices = np.random.choice(size, size=int(size // 2), replace=False)
14 |     x_train = np.zeros(shape=(size, time_steps, 1))
15 |     y_train = np.zeros(shape=(size, 1))
16 |     x_train[pos_indices, 0] = 1.0  # we introduce the target in the first timestep of the sequence.
17 |     y_train[pos_indices, 0] = 1.0  # the task is to see if the TCN can go back in time to find it.
18 |     return x_train, y_train
19 |
20 |
21 | def new_bounds(dilations, bounds, input_dim, kernel_size, nb_stacks):
22 |     # similar to the bisect algorithm.
23 |     middle = int(np.mean(bounds))
24 |     t1 = could_task_be_learned(dilations, bounds[0], input_dim, kernel_size, nb_stacks)
25 |     t_middle = could_task_be_learned(dilations, middle, input_dim, kernel_size, nb_stacks)
26 |     t2 = could_task_be_learned(dilations, bounds[1], input_dim, kernel_size, nb_stacks)
27 |     go_left = t1 and not t_middle
28 |     go_right = t_middle and not t2
29 |     if go_left:
30 |         assert not go_right
31 |     if go_right:
32 |         assert not go_left
33 |     assert go_left or go_right
34 |
35 |     if go_left:
36 |         return np.array([bounds[0], middle])
37 |     else:
38 |         return np.array([middle, bounds[1]])
39 |
40 |
41 | def est_receptive_field(kernel_size, nb_stacks, dilations):
42 |     print('K', 'S', 'D', kernel_size, nb_stacks, dilations)
43 |     input_dim = 1
44 |     bounds = np.array([5, 800])
45 |     while True:
46 |         bounds = new_bounds(dilations, bounds, input_dim, kernel_size, nb_stacks)
47 |         if bounds[1] - bounds[0] <= 1:
48 |             print(f'Receptive field: {bounds[0]}.')
49 |             break
50 |
51 |
52 | def could_task_be_learned(dilations, guess, input_dim, kernel_size, nb_stacks):
53 |     tcn_layer = TCN(
54 |         kernel_size=kernel_size,
55 |         dilations=dilations,
56 |         nb_stacks=nb_stacks,
57 |         input_shape=(guess, input_dim)
58 |     )
59 |
60 |     m = Sequential([
61 |         tcn_layer,
62 |         Dense(1, activation='sigmoid')
63 |     ])
64 |     m.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
65 |     x, y = get_x_y(guess)
66 |     m.fit(x, y, validation_split=0.2, verbose=0, epochs=2)
67 |     accuracy = m.evaluate(x, y, verbose=0)[1]
68 |     task_is_learned = accuracy > 0.95
69 |     return task_is_learned
70 |
71 |
72 | if __name__ == '__main__':
73 |     est_receptive_field(kernel_size=2, nb_stacks=1, dilations=(1, 2, 4))
74 |
--------------------------------------------------------------------------------
/tasks/exchange_rate/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2019-10-14T02:42:42.022793Z",
9 | "start_time": "2019-10-14T02:42:40.338165Z"
10 | },
11 | "pycharm": {
12 | "is_executing": false
13 | }
14 | },
15 | "outputs": [
16 | {
17 | "name": "stderr",
18 | "output_type": "stream",
19 | "text": [
20 | "/opt/conda/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
21 | " from ._conv import register_converters as _register_converters\n",
22 | "Using TensorFlow backend.\n"
23 | ]
24 | }
25 | ],
26 | "source": [
27 | "from tcn import compiled_tcn\n",
28 | "from utils import get_xy_kfolds\n",
29 | "from sklearn.metrics import mean_squared_error\n",
30 | "import numpy as np"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 2,
36 | "metadata": {
37 | "ExecuteTime": {
38 | "end_time": "2019-10-14T02:42:42.451012Z",
39 | "start_time": "2019-10-14T02:42:42.025145Z"
40 | }
41 | },
42 | "outputs": [],
43 | "source": [
44 | "folds,enc = get_xy_kfolds()"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 3,
50 | "metadata": {
51 | "ExecuteTime": {
52 | "end_time": "2019-10-14T06:50:51.811354Z",
53 | "start_time": "2019-10-14T02:42:42.453309Z"
54 | },
55 | "pycharm": {
56 | "is_executing": true,
57 | "name": "#%%\n"
58 | }
59 | },
60 | "outputs": [
61 | {
62 | "name": "stdout",
63 | "output_type": "stream",
64 | "text": [
65 | "x.shape= (?, 24)\n",
66 | "model.x = (?, 1000, 8)\n",
67 | "model.y = (?, 8)\n",
68 | "Epoch 1/100\n",
69 | "2794/2794 [==============================] - 20s 7ms/step - loss: 68.0583\n",
70 | "Epoch 2/100\n",
71 | "2794/2794 [==============================] - 19s 7ms/step - loss: 41.9572\n",
72 | "Epoch 3/100\n",
73 | "2794/2794 [==============================] - 19s 7ms/step - loss: 24.5881\n",
74 | "Epoch 4/100\n",
75 | "2794/2794 [==============================] - 19s 7ms/step - loss: 13.4458\n",
76 | "Epoch 5/100\n",
77 | "2794/2794 [==============================] - 19s 7ms/step - loss: 6.4616\n",
78 | "Epoch 6/100\n",
79 | "2794/2794 [==============================] - 19s 7ms/step - loss: 2.4977\n",
80 | "Epoch 7/100\n",
81 | "2794/2794 [==============================] - 19s 7ms/step - loss: 1.0665\n",
82 | "Epoch 8/100\n",
83 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.8594\n",
84 | "Epoch 9/100\n",
85 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.6765\n",
86 | "Epoch 10/100\n",
87 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.5553\n",
88 | "Epoch 11/100\n",
89 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.4711\n",
90 | "Epoch 12/100\n",
91 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.4122\n",
92 | "Epoch 13/100\n",
93 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.3630\n",
94 | "Epoch 14/100\n",
95 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.3219\n",
96 | "Epoch 15/100\n",
97 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.2862\n",
98 | "Epoch 16/100\n",
99 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.2579\n",
100 | "Epoch 17/100\n",
101 | "2794/2794 [==============================] - 20s 7ms/step - loss: 0.2329\n",
102 | "Epoch 18/100\n",
103 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.2135\n",
104 | "Epoch 19/100\n",
105 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1962\n",
106 | "Epoch 20/100\n",
107 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1802\n",
108 | "Epoch 21/100\n",
109 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1671\n",
110 | "Epoch 22/100\n",
111 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1553\n",
112 | "Epoch 23/100\n",
113 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1446\n",
114 | "Epoch 24/100\n",
115 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1347\n",
116 | "Epoch 25/100\n",
117 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1255\n",
118 | "Epoch 26/100\n",
119 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1182\n",
120 | "Epoch 27/100\n",
121 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1112\n",
122 | "Epoch 28/100\n",
123 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.1044\n",
124 | "Epoch 29/100\n",
125 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0990\n",
126 | "Epoch 30/100\n",
127 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0942\n",
128 | "Epoch 31/100\n",
129 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0882\n",
130 | "Epoch 32/100\n",
131 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0833\n",
132 | "Epoch 33/100\n",
133 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0789\n",
134 | "Epoch 34/100\n",
135 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0749\n",
136 | "Epoch 35/100\n",
137 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0715\n",
138 | "Epoch 36/100\n",
139 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0681\n",
140 | "Epoch 37/100\n",
141 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0650\n",
142 | "Epoch 38/100\n",
143 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0620\n",
144 | "Epoch 39/100\n",
145 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0591\n",
146 | "Epoch 40/100\n",
147 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0565\n",
148 | "Epoch 41/100\n",
149 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0539\n",
150 | "Epoch 42/100\n",
151 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0518\n",
152 | "Epoch 43/100\n",
153 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0498\n",
154 | "Epoch 44/100\n",
155 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0478\n",
156 | "Epoch 45/100\n",
157 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0457\n",
158 | "Epoch 46/100\n",
159 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0436\n",
160 | "Epoch 47/100\n",
161 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0418\n",
162 | "Epoch 48/100\n",
163 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0401\n",
164 | "Epoch 49/100\n",
165 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0386\n",
166 | "Epoch 50/100\n",
167 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0372\n",
168 | "Epoch 51/100\n",
169 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0359\n",
170 | "Epoch 52/100\n",
171 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0343\n",
172 | "Epoch 53/100\n",
173 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0330\n",
174 | "Epoch 54/100\n",
175 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0319\n",
176 | "Epoch 55/100\n",
177 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0309\n",
178 | "Epoch 56/100\n",
179 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0301\n",
180 | "Epoch 57/100\n",
181 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0288\n",
182 | "Epoch 58/100\n",
183 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0277\n",
184 | "Epoch 59/100\n",
185 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0270\n",
186 | "Epoch 60/100\n",
187 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0257\n",
188 | "Epoch 61/100\n",
189 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0249\n",
190 | "Epoch 62/100\n",
191 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0241\n",
192 | "Epoch 63/100\n",
193 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0232\n",
194 | "Epoch 64/100\n",
195 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0225\n",
196 | "Epoch 65/100\n",
197 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0221\n",
198 | "Epoch 66/100\n",
199 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0214\n",
200 | "Epoch 67/100\n",
201 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0206\n",
202 | "Epoch 68/100\n",
203 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0199\n",
204 | "Epoch 69/100\n",
205 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0195\n",
206 | "Epoch 70/100\n",
207 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0187\n",
208 | "Epoch 71/100\n",
209 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0181\n",
210 | "Epoch 72/100\n",
211 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0175\n",
212 | "Epoch 73/100\n",
213 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0172\n",
214 | "Epoch 74/100\n",
215 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0169\n",
216 | "Epoch 75/100\n",
217 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0161\n",
218 | "Epoch 76/100\n",
219 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0156\n",
220 | "Epoch 77/100\n",
221 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0153\n",
222 | "Epoch 78/100\n",
223 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0147\n",
224 | "Epoch 79/100\n",
225 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0143\n",
226 | "Epoch 80/100\n",
227 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0140\n",
228 | "Epoch 81/100\n",
229 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0136\n",
230 | "Epoch 82/100\n",
231 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0133\n",
232 | "Epoch 83/100\n",
233 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0129\n",
234 | "Epoch 84/100\n",
235 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0125\n",
236 | "Epoch 85/100\n",
237 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0122\n",
238 | "Epoch 86/100\n",
239 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0119\n",
240 | "Epoch 87/100\n",
241 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0116\n",
242 | "Epoch 88/100\n",
243 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0113\n",
244 | "Epoch 89/100\n",
245 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0110\n",
246 | "Epoch 90/100\n",
247 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0108\n",
248 | "Epoch 91/100\n",
249 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0105\n",
250 | "Epoch 92/100\n",
251 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0102\n",
252 | "Epoch 93/100\n",
253 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0099\n",
254 | "Epoch 94/100\n",
255 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0097\n",
256 | "Epoch 95/100\n"
257 | ]
258 | },
259 | {
260 | "name": "stdout",
261 | "output_type": "stream",
262 | "text": [
263 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0094\n",
264 | "Epoch 96/100\n",
265 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0093\n",
266 | "Epoch 97/100\n",
267 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0091\n",
268 | "Epoch 98/100\n",
269 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0089\n",
270 | "Epoch 99/100\n",
271 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0087\n",
272 | "Epoch 100/100\n",
273 | "2794/2794 [==============================] - 19s 7ms/step - loss: 0.0084\n"
274 | ]
275 | },
276 | {
277 | "data": {
278 | "text/plain": [
279 | ""
280 | ]
281 | },
282 | "execution_count": 3,
283 | "metadata": {},
284 | "output_type": "execute_result"
285 | },
286 | {
287 | "name": "stdout",
288 | "output_type": "stream",
289 | "text": [
290 | "train_set_size:2794\n",
291 | "y_true:[0.75585 1.8273 0.79321 0.810668 0.120824 0.009247 0.7094 0.601214]\n",
292 | "y_pred:[0.8300596 1.9546034 0.7468406 0.8549483 0.14127728 0.00971216\n",
293 | " 0.79883957 0.568041 ]\n",
294 | "mse:0.004417816002578773\n",
295 | "x.shape= (?, 24)\n",
296 | "model.x = (?, 1000, 8)\n",
297 | "model.y = (?, 8)\n",
298 | "Epoch 1/100\n",
299 | "3552/3552 [==============================] - 25s 7ms/step - loss: 264.8228\n",
300 | "Epoch 2/100\n",
301 | "3552/3552 [==============================] - 25s 7ms/step - loss: 150.9494\n",
302 | "Epoch 3/100\n",
303 | "3552/3552 [==============================] - 25s 7ms/step - loss: 77.5260\n",
304 | "Epoch 4/100\n",
305 | "3552/3552 [==============================] - 25s 7ms/step - loss: 34.8306\n",
306 | "Epoch 5/100\n",
307 | "3552/3552 [==============================] - 25s 7ms/step - loss: 12.4802\n",
308 | "Epoch 6/100\n",
309 | "3552/3552 [==============================] - 25s 7ms/step - loss: 3.1224\n",
310 | "Epoch 7/100\n",
311 | "3552/3552 [==============================] - 25s 7ms/step - loss: 1.8036\n",
312 | "Epoch 8/100\n",
313 | "3552/3552 [==============================] - 24s 7ms/step - loss: 1.2830\n",
314 | "Epoch 9/100\n",
315 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.9430\n",
316 | "Epoch 10/100\n",
317 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.7132\n",
318 | "Epoch 11/100\n",
319 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.5746\n",
320 | "Epoch 12/100\n",
321 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.4884\n",
322 | "Epoch 13/100\n",
323 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.4341\n",
324 | "Epoch 14/100\n",
325 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.4016\n",
326 | "Epoch 15/100\n",
327 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.3699\n",
328 | "Epoch 16/100\n",
329 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.3382\n",
330 | "Epoch 17/100\n",
331 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.3114\n",
332 | "Epoch 18/100\n",
333 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.2857\n",
334 | "Epoch 19/100\n",
335 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.2656\n",
336 | "Epoch 20/100\n",
337 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.2465\n",
338 | "Epoch 21/100\n",
339 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.2316\n",
340 | "Epoch 22/100\n",
341 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.2194\n",
342 | "Epoch 23/100\n",
343 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.2072\n",
344 | "Epoch 24/100\n",
345 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1959\n",
346 | "Epoch 25/100\n",
347 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.1858\n",
348 | "Epoch 26/100\n",
349 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.1753\n",
350 | "Epoch 27/100\n",
351 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1651\n",
352 | "Epoch 28/100\n",
353 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1583\n",
354 | "Epoch 29/100\n",
355 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1504\n",
356 | "Epoch 30/100\n",
357 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.1422\n",
358 | "Epoch 31/100\n",
359 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1365\n",
360 | "Epoch 32/100\n",
361 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1297\n",
362 | "Epoch 33/100\n",
363 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1266\n",
364 | "Epoch 34/100\n",
365 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1195\n",
366 | "Epoch 35/100\n",
367 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1142\n",
368 | "Epoch 36/100\n",
369 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.1085\n",
370 | "Epoch 37/100\n",
371 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.1030\n",
372 | "Epoch 38/100\n",
373 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0981\n",
374 | "Epoch 39/100\n",
375 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0933\n",
376 | "Epoch 40/100\n",
377 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0887\n",
378 | "Epoch 41/100\n",
379 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0842\n",
380 | "Epoch 42/100\n",
381 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0811\n",
382 | "Epoch 43/100\n",
383 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0768\n",
384 | "Epoch 44/100\n",
385 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0729\n",
386 | "Epoch 45/100\n",
387 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0697\n",
388 | "Epoch 46/100\n",
389 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0663\n",
390 | "Epoch 47/100\n",
391 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0630\n",
392 | "Epoch 48/100\n",
393 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0606\n",
394 | "Epoch 49/100\n",
395 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0576\n",
396 | "Epoch 50/100\n",
397 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0551\n",
398 | "Epoch 51/100\n",
399 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0535\n",
400 | "Epoch 52/100\n",
401 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0509\n",
402 | "Epoch 53/100\n",
403 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0494\n",
404 | "Epoch 54/100\n",
405 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0478\n",
406 | "Epoch 55/100\n",
407 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0450\n",
408 | "Epoch 56/100\n",
409 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0431\n",
410 | "Epoch 57/100\n",
411 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0407\n",
412 | "Epoch 58/100\n",
413 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0394\n",
414 | "Epoch 59/100\n",
415 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0377\n",
416 | "Epoch 60/100\n",
417 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0355\n",
418 | "Epoch 61/100\n",
419 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0343\n",
420 | "Epoch 62/100\n",
421 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0331\n",
422 | "Epoch 63/100\n",
423 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0325\n",
424 | "Epoch 64/100\n",
425 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0303\n",
426 | "Epoch 65/100\n",
427 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0297\n",
428 | "Epoch 66/100\n",
429 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0278\n",
430 | "Epoch 67/100\n",
431 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0269\n",
432 | "Epoch 68/100\n",
433 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0254\n",
434 | "Epoch 69/100\n",
435 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0245\n",
436 | "Epoch 70/100\n",
437 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0237\n",
438 | "Epoch 71/100\n",
439 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0226\n",
440 | "Epoch 72/100\n",
441 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0219\n",
442 | "Epoch 73/100\n",
443 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0210\n",
444 | "Epoch 74/100\n",
445 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0206\n",
446 | "Epoch 75/100\n",
447 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0197\n",
448 | "Epoch 76/100\n",
449 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0191\n",
450 | "Epoch 77/100\n",
451 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0181\n",
452 | "Epoch 78/100\n",
453 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0173\n",
454 | "Epoch 79/100\n",
455 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0169\n",
456 | "Epoch 80/100\n",
457 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0165\n",
458 | "Epoch 81/100\n",
459 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0157\n",
460 | "Epoch 82/100\n",
461 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0152\n",
462 | "Epoch 83/100\n",
463 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0150\n",
464 | "Epoch 84/100\n",
465 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0139\n",
466 | "Epoch 85/100\n",
467 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0138\n",
468 | "Epoch 86/100\n",
469 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0129\n",
470 | "Epoch 87/100\n",
471 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0127\n",
472 | "Epoch 88/100\n",
473 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0122\n",
474 | "Epoch 89/100\n",
475 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0122\n",
476 | "Epoch 90/100\n",
477 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0118\n",
478 | "Epoch 91/100\n",
479 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0113\n",
480 | "Epoch 92/100\n",
481 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0112\n",
482 | "Epoch 93/100\n"
483 | ]
484 | },
485 | {
486 | "name": "stdout",
487 | "output_type": "stream",
488 | "text": [
489 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0105\n",
490 | "Epoch 94/100\n",
491 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0104\n",
492 | "Epoch 95/100\n",
493 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0100\n",
494 | "Epoch 96/100\n",
495 | "3552/3552 [==============================] - 25s 7ms/step - loss: 0.0096\n",
496 | "Epoch 97/100\n",
497 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0095\n",
498 | "Epoch 98/100\n",
499 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0092\n",
500 | "Epoch 99/100\n",
501 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0090\n",
502 | "Epoch 100/100\n",
503 | "3552/3552 [==============================] - 24s 7ms/step - loss: 0.0085\n"
504 | ]
505 | },
506 | {
507 | "data": {
508 | "text/plain": [
509 | ""
510 | ]
511 | },
512 | "execution_count": 3,
513 | "metadata": {},
514 | "output_type": "execute_result"
515 | },
516 | {
517 | "name": "stdout",
518 | "output_type": "stream",
519 | "text": [
520 | "train_set_size:3552\n",
521 | "y_true:[0.88545 2.0517 1.018537 0.895375 0.134689 0.009079 0.75525 0.690846]\n",
522 | "y_pred:[0.83563536 2.0971928 0.92631006 0.9065408 0.14294215 0.01020975\n",
523 | " 0.7965754 0.6845535 ]\n",
524 | "mse:0.0018747940419921757\n",
525 | "x.shape= (?, 24)\n",
526 | "model.x = (?, 1000, 8)\n",
527 | "model.y = (?, 8)\n",
528 | "Epoch 1/100\n",
529 | "4311/4311 [==============================] - 30s 7ms/step - loss: 43.0135\n",
530 | "Epoch 2/100\n",
531 | "4311/4311 [==============================] - 30s 7ms/step - loss: 12.3071\n",
532 | "Epoch 3/100\n",
533 | "4311/4311 [==============================] - 30s 7ms/step - loss: 3.5359\n",
534 | "Epoch 4/100\n",
535 | "4311/4311 [==============================] - 30s 7ms/step - loss: 2.2983\n",
536 | "Epoch 5/100\n",
537 | "4311/4311 [==============================] - 30s 7ms/step - loss: 1.5991\n",
538 | "Epoch 6/100\n",
539 | "4311/4311 [==============================] - 30s 7ms/step - loss: 1.1558\n",
540 | "Epoch 7/100\n",
541 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.8747\n",
542 | "Epoch 8/100\n",
543 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.7059\n",
544 | "Epoch 9/100\n",
545 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.5858\n",
546 | "Epoch 10/100\n",
547 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.4895\n",
548 | "Epoch 11/100\n",
549 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.4085\n",
550 | "Epoch 12/100\n",
551 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.3448\n",
552 | "Epoch 13/100\n",
553 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.2970\n",
554 | "Epoch 14/100\n",
555 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.2587\n",
556 | "Epoch 15/100\n",
557 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.2278\n",
558 | "Epoch 16/100\n",
559 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.2028\n",
560 | "Epoch 17/100\n",
561 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1811\n",
562 | "Epoch 18/100\n",
563 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1627\n",
564 | "Epoch 19/100\n",
565 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1499\n",
566 | "Epoch 20/100\n",
567 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1339\n",
568 | "Epoch 21/100\n",
569 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1235\n",
570 | "Epoch 22/100\n",
571 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1145\n",
572 | "Epoch 23/100\n",
573 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.1038\n",
574 | "Epoch 24/100\n",
575 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0956\n",
576 | "Epoch 25/100\n",
577 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0874\n",
578 | "Epoch 26/100\n",
579 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0799\n",
580 | "Epoch 27/100\n",
581 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0730\n",
582 | "Epoch 28/100\n",
583 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0672\n",
584 | "Epoch 29/100\n",
585 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0619\n",
586 | "Epoch 30/100\n",
587 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0569\n",
588 | "Epoch 31/100\n",
589 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0529\n",
590 | "Epoch 32/100\n",
591 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0495\n",
592 | "Epoch 33/100\n",
593 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0461\n",
594 | "Epoch 34/100\n",
595 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0430\n",
596 | "Epoch 35/100\n",
597 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0399\n",
598 | "Epoch 36/100\n",
599 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0374\n",
600 | "Epoch 37/100\n",
601 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0350\n",
602 | "Epoch 38/100\n",
603 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0327\n",
604 | "Epoch 39/100\n",
605 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0306\n",
606 | "Epoch 40/100\n",
607 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0290\n",
608 | "Epoch 41/100\n",
609 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0270\n",
610 | "Epoch 42/100\n",
611 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0253\n",
612 | "Epoch 43/100\n",
613 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0243\n",
614 | "Epoch 44/100\n",
615 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0229\n",
616 | "Epoch 45/100\n",
617 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0217\n",
618 | "Epoch 46/100\n",
619 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0207\n",
620 | "Epoch 47/100\n",
621 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0193\n",
622 | "Epoch 48/100\n",
623 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0183\n",
624 | "Epoch 49/100\n",
625 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0175\n",
626 | "Epoch 50/100\n",
627 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0166\n",
628 | "Epoch 51/100\n",
629 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0156\n",
630 | "Epoch 52/100\n",
631 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0149\n",
632 | "Epoch 53/100\n",
633 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0141\n",
634 | "Epoch 54/100\n",
635 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0135\n",
636 | "Epoch 55/100\n",
637 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0130\n",
638 | "Epoch 56/100\n",
639 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0124\n",
640 | "Epoch 57/100\n",
641 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0119\n",
642 | "Epoch 58/100\n",
643 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0115\n",
644 | "Epoch 59/100\n",
645 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0110\n",
646 | "Epoch 60/100\n",
647 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0104\n",
648 | "Epoch 61/100\n",
649 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0101\n",
650 | "Epoch 62/100\n",
651 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0096\n",
652 | "Epoch 63/100\n",
653 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0092\n",
654 | "Epoch 64/100\n",
655 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0090\n",
656 | "Epoch 65/100\n",
657 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0089\n",
658 | "Epoch 66/100\n",
659 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0088\n",
660 | "Epoch 67/100\n",
661 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0083\n",
662 | "Epoch 68/100\n",
663 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0079\n",
664 | "Epoch 69/100\n",
665 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0075\n",
666 | "Epoch 70/100\n",
667 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0073\n",
668 | "Epoch 71/100\n",
669 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0071\n",
670 | "Epoch 72/100\n",
671 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0070\n",
672 | "Epoch 73/100\n",
673 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0069\n",
674 | "Epoch 74/100\n",
675 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0065\n",
676 | "Epoch 75/100\n",
677 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0064\n",
678 | "Epoch 76/100\n",
679 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0062\n",
680 | "Epoch 77/100\n",
681 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0060\n",
682 | "Epoch 78/100\n",
683 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0059\n",
684 | "Epoch 79/100\n",
685 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0058\n",
686 | "Epoch 80/100\n",
687 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0057\n",
688 | "Epoch 81/100\n",
689 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0056\n",
690 | "Epoch 82/100\n",
691 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0054\n",
692 | "Epoch 83/100\n",
693 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0054\n",
694 | "Epoch 84/100\n",
695 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0052\n",
696 | "Epoch 85/100\n",
697 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0052\n",
698 | "Epoch 86/100\n",
699 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0050\n",
700 | "Epoch 87/100\n",
701 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0050\n",
702 | "Epoch 88/100\n",
703 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0049\n",
704 | "Epoch 89/100\n",
705 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0048\n",
706 | "Epoch 90/100\n",
707 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0049\n",
708 | "Epoch 91/100\n",
709 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0048\n",
710 | "Epoch 92/100\n",
711 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0047\n",
712 | "Epoch 93/100\n"
713 | ]
714 | },
715 | {
716 | "name": "stdout",
717 | "output_type": "stream",
718 | "text": [
719 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0047\n",
720 | "Epoch 94/100\n",
721 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0047\n",
722 | "Epoch 95/100\n",
723 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0046\n",
724 | "Epoch 96/100\n",
725 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0044\n",
726 | "Epoch 97/100\n",
727 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0045\n",
728 | "Epoch 98/100\n",
729 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0045\n",
730 | "Epoch 99/100\n",
731 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0043\n",
732 | "Epoch 100/100\n",
733 | "4311/4311 [==============================] - 30s 7ms/step - loss: 0.0043\n"
734 | ]
735 | },
736 | {
737 | "data": {
738 | "text/plain": [
739 | ""
740 | ]
741 | },
742 | "execution_count": 3,
743 | "metadata": {},
744 | "output_type": "execute_result"
745 | },
746 | {
747 | "name": "stdout",
748 | "output_type": "stream",
749 | "text": [
750 | "train_set_size:4311\n",
751 | "y_true:[0.860363 1.474491 0.96965 0.87581 0.146342 0.010898 0.696379 0.71623 ]\n",
752 | "y_pred:[0.8938178 1.5261562 0.9444967 0.9710834 0.14255804 0.0103059\n",
753 | " 0.7613744 0.73095745]\n",
754 | "mse:0.0022442743423320343\n",
755 | "x.shape= (?, 24)\n",
756 | "model.x = (?, 1000, 8)\n",
757 | "model.y = (?, 8)\n",
758 | "Epoch 1/100\n",
759 | "5070/5070 [==============================] - 36s 7ms/step - loss: 45.8226\n",
760 | "Epoch 2/100\n",
761 | "5070/5070 [==============================] - 35s 7ms/step - loss: 4.6997\n",
762 | "Epoch 3/100\n",
763 | "5070/5070 [==============================] - 35s 7ms/step - loss: 2.3479\n",
764 | "Epoch 4/100\n",
765 | "5070/5070 [==============================] - 35s 7ms/step - loss: 1.6677\n",
766 | "Epoch 5/100\n",
767 | "5070/5070 [==============================] - 35s 7ms/step - loss: 1.2386\n",
768 | "Epoch 6/100\n",
769 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.9361\n",
770 | "Epoch 7/100\n",
771 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.7327\n",
772 | "Epoch 8/100\n",
773 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.5901\n",
774 | "Epoch 9/100\n",
775 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.4862\n",
776 | "Epoch 10/100\n",
777 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.4145\n",
778 | "Epoch 11/100\n",
779 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.3669\n",
780 | "Epoch 12/100\n",
781 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.3260\n",
782 | "Epoch 13/100\n",
783 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.2939\n",
784 | "Epoch 14/100\n",
785 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.2691\n",
786 | "Epoch 15/100\n",
787 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.2454\n",
788 | "Epoch 16/100\n",
789 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.2224\n",
790 | "Epoch 17/100\n",
791 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.2035\n",
792 | "Epoch 18/100\n",
793 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1859\n",
794 | "Epoch 19/100\n",
795 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1717\n",
796 | "Epoch 20/100\n",
797 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1597\n",
798 | "Epoch 21/100\n",
799 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1494\n",
800 | "Epoch 22/100\n",
801 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1403\n",
802 | "Epoch 23/100\n",
803 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1309\n",
804 | "Epoch 24/100\n",
805 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1249\n",
806 | "Epoch 25/100\n",
807 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1155\n",
808 | "Epoch 26/100\n",
809 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1089\n",
810 | "Epoch 27/100\n",
811 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.1026\n",
812 | "Epoch 28/100\n",
813 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0973\n",
814 | "Epoch 29/100\n",
815 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0913\n",
816 | "Epoch 30/100\n",
817 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0869\n",
818 | "Epoch 31/100\n",
819 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0823\n",
820 | "Epoch 32/100\n",
821 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0800\n",
822 | "Epoch 33/100\n",
823 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0743\n",
824 | "Epoch 34/100\n",
825 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0704\n",
826 | "Epoch 35/100\n",
827 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0667\n",
828 | "Epoch 36/100\n",
829 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0647\n",
830 | "Epoch 37/100\n",
831 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0611\n",
832 | "Epoch 38/100\n",
833 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0579\n",
834 | "Epoch 39/100\n",
835 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0553\n",
836 | "Epoch 40/100\n",
837 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0538\n",
838 | "Epoch 41/100\n",
839 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0535\n",
840 | "Epoch 42/100\n",
841 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0498\n",
842 | "Epoch 43/100\n",
843 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0478\n",
844 | "Epoch 44/100\n",
845 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0455\n",
846 | "Epoch 45/100\n",
847 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0433\n",
848 | "Epoch 46/100\n",
849 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0424\n",
850 | "Epoch 47/100\n",
851 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0406\n",
852 | "Epoch 48/100\n",
853 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0384\n",
854 | "Epoch 49/100\n",
855 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0382\n",
856 | "Epoch 50/100\n",
857 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0368\n",
858 | "Epoch 51/100\n",
859 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0347\n",
860 | "Epoch 52/100\n",
861 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0341\n",
862 | "Epoch 53/100\n",
863 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0331\n",
864 | "Epoch 54/100\n",
865 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0313\n",
866 | "Epoch 55/100\n",
867 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0304\n",
868 | "Epoch 56/100\n",
869 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0307\n",
870 | "Epoch 57/100\n",
871 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0287\n",
872 | "Epoch 58/100\n",
873 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0277\n",
874 | "Epoch 59/100\n",
875 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0268\n",
876 | "Epoch 60/100\n",
877 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0260\n",
878 | "Epoch 61/100\n",
879 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0250\n",
880 | "Epoch 62/100\n",
881 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0243\n",
882 | "Epoch 63/100\n",
883 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0235\n",
884 | "Epoch 64/100\n",
885 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0226\n",
886 | "Epoch 65/100\n",
887 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0229\n",
888 | "Epoch 66/100\n",
889 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0219\n",
890 | "Epoch 67/100\n",
891 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0215\n",
892 | "Epoch 68/100\n",
893 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0213\n",
894 | "Epoch 69/100\n",
895 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0198\n",
896 | "Epoch 70/100\n",
897 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0192\n",
898 | "Epoch 71/100\n",
899 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0188\n",
900 | "Epoch 72/100\n",
901 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0182\n",
902 | "Epoch 73/100\n",
903 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0183\n",
904 | "Epoch 74/100\n",
905 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0174\n",
906 | "Epoch 75/100\n",
907 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0183\n",
908 | "Epoch 76/100\n",
909 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0172\n",
910 | "Epoch 77/100\n",
911 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0162\n",
912 | "Epoch 78/100\n",
913 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0159\n",
914 | "Epoch 79/100\n",
915 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0151\n",
916 | "Epoch 80/100\n",
917 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0149\n",
918 | "Epoch 81/100\n",
919 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0144\n",
920 | "Epoch 82/100\n",
921 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0142\n",
922 | "Epoch 83/100\n",
923 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0137\n",
924 | "Epoch 84/100\n",
925 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0137\n",
926 | "Epoch 85/100\n",
927 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0132\n",
928 | "Epoch 86/100\n",
929 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0136\n",
930 | "Epoch 87/100\n",
931 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0125\n",
932 | "Epoch 88/100\n",
933 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0125\n",
934 | "Epoch 89/100\n",
935 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0121\n",
936 | "Epoch 90/100\n",
937 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0118\n",
938 | "Epoch 91/100\n",
939 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0117\n",
940 | "Epoch 92/100\n",
941 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0118\n",
942 | "Epoch 93/100\n"
943 | ]
944 | },
945 | {
946 | "name": "stdout",
947 | "output_type": "stream",
948 | "text": [
949 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0114\n",
950 | "Epoch 94/100\n",
951 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0109\n",
952 | "Epoch 95/100\n",
953 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0108\n",
954 | "Epoch 96/100\n",
955 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0105\n",
956 | "Epoch 97/100\n",
957 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0104\n",
958 | "Epoch 98/100\n",
959 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0102\n",
960 | "Epoch 99/100\n",
961 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0101\n",
962 | "Epoch 100/100\n",
963 | "5070/5070 [==============================] - 35s 7ms/step - loss: 0.0093\n"
964 | ]
965 | },
966 | {
967 | "data": {
968 | "text/plain": [
969 | ""
970 | ]
971 | },
972 | "execution_count": 3,
973 | "metadata": {},
974 | "output_type": "execute_result"
975 | },
976 | {
977 | "name": "stdout",
978 | "output_type": "stream",
979 | "text": [
980 | "train_set_size:5070\n",
981 | "y_true:[1.026905 1.611733 1.014096 1.079214 0.159627 0.012674 0.813603 0.819672]\n",
982 | "y_pred:[1.0476167 1.586348 1.0373206 1.1811144 0.18007559 0.01273255\n",
983 | " 0.7849961 0.86136806]\n",
984 | "mse:0.0018714395955997517\n",
985 | "x.shape= (?, 24)\n",
986 | "model.x = (?, 1000, 8)\n",
987 | "model.y = (?, 8)\n",
988 | "Epoch 1/100\n",
989 | "5829/5829 [==============================] - 41s 7ms/step - loss: 139.7857\n",
990 | "Epoch 2/100\n",
991 | "5829/5829 [==============================] - 40s 7ms/step - loss: 35.0750\n",
992 | "Epoch 3/100\n",
993 | "5829/5829 [==============================] - 40s 7ms/step - loss: 4.4459\n",
994 | "Epoch 4/100\n",
995 | "5829/5829 [==============================] - 40s 7ms/step - loss: 1.5621\n",
996 | "Epoch 5/100\n",
997 | "5829/5829 [==============================] - 41s 7ms/step - loss: 1.1352\n",
998 | "Epoch 6/100\n",
999 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.8480\n",
1000 | "Epoch 7/100\n",
1001 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.6581\n",
1002 | "Epoch 8/100\n",
1003 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.5356\n",
1004 | "Epoch 9/100\n",
1005 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.4467\n",
1006 | "Epoch 10/100\n",
1007 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.3843\n",
1008 | "Epoch 11/100\n",
1009 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.3407\n",
1010 | "Epoch 12/100\n",
1011 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.3037\n",
1012 | "Epoch 13/100\n",
1013 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.2729\n",
1014 | "Epoch 14/100\n",
1015 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.2480\n",
1016 | "Epoch 15/100\n",
1017 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.2266\n",
1018 | "Epoch 16/100\n",
1019 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.2131\n",
1020 | "Epoch 17/100\n",
1021 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1996\n",
1022 | "Epoch 18/100\n",
1023 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.1867\n",
1024 | "Epoch 19/100\n",
1025 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1748\n",
1026 | "Epoch 20/100\n",
1027 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1636\n",
1028 | "Epoch 21/100\n",
1029 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1545\n",
1030 | "Epoch 22/100\n",
1031 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1452\n",
1032 | "Epoch 23/100\n",
1033 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1379\n",
1034 | "Epoch 24/100\n",
1035 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.1302\n",
1036 | "Epoch 25/100\n",
1037 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1228\n",
1038 | "Epoch 26/100\n",
1039 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1191\n",
1040 | "Epoch 27/100\n",
1041 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1116\n",
1042 | "Epoch 28/100\n",
1043 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1069\n",
1044 | "Epoch 29/100\n",
1045 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.1012\n",
1046 | "Epoch 30/100\n",
1047 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0965\n",
1048 | "Epoch 31/100\n",
1049 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0911\n",
1050 | "Epoch 32/100\n",
1051 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0865\n",
1052 | "Epoch 33/100\n",
1053 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0818\n",
1054 | "Epoch 34/100\n",
1055 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0777\n",
1056 | "Epoch 35/100\n",
1057 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0751\n",
1058 | "Epoch 36/100\n",
1059 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0711\n",
1060 | "Epoch 37/100\n",
1061 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0678\n",
1062 | "Epoch 38/100\n",
1063 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0642\n",
1064 | "Epoch 39/100\n",
1065 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0616\n",
1066 | "Epoch 40/100\n",
1067 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0584\n",
1068 | "Epoch 41/100\n",
1069 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0554\n",
1070 | "Epoch 42/100\n",
1071 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0532\n",
1072 | "Epoch 43/100\n",
1073 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0499\n",
1074 | "Epoch 44/100\n",
1075 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0482\n",
1076 | "Epoch 45/100\n",
1077 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0455\n",
1078 | "Epoch 46/100\n",
1079 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0427\n",
1080 | "Epoch 47/100\n",
1081 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0404\n",
1082 | "Epoch 48/100\n",
1083 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0386\n",
1084 | "Epoch 49/100\n",
1085 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0368\n",
1086 | "Epoch 50/100\n",
1087 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0347\n",
1088 | "Epoch 51/100\n",
1089 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0331\n",
1090 | "Epoch 52/100\n",
1091 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0313\n",
1092 | "Epoch 53/100\n",
1093 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0298\n",
1094 | "Epoch 54/100\n",
1095 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0285\n",
1096 | "Epoch 55/100\n",
1097 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0269\n",
1098 | "Epoch 56/100\n",
1099 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0256\n",
1100 | "Epoch 57/100\n",
1101 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0246\n",
1102 | "Epoch 58/100\n",
1103 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0233\n",
1104 | "Epoch 59/100\n",
1105 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0221\n",
1106 | "Epoch 60/100\n",
1107 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0211\n",
1108 | "Epoch 61/100\n",
1109 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0213\n",
1110 | "Epoch 62/100\n",
1111 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0207\n",
1112 | "Epoch 63/100\n",
1113 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0190\n",
1114 | "Epoch 64/100\n",
1115 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0183\n",
1116 | "Epoch 65/100\n",
1117 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0174\n",
1118 | "Epoch 66/100\n",
1119 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0166\n",
1120 | "Epoch 67/100\n",
1121 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0160\n",
1122 | "Epoch 68/100\n",
1123 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0151\n",
1124 | "Epoch 69/100\n",
1125 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0145\n",
1126 | "Epoch 70/100\n",
1127 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0142\n",
1128 | "Epoch 71/100\n",
1129 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0140\n",
1130 | "Epoch 72/100\n",
1131 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0129\n",
1132 | "Epoch 73/100\n",
1133 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0123\n",
1134 | "Epoch 74/100\n",
1135 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0119\n",
1136 | "Epoch 75/100\n",
1137 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0119\n",
1138 | "Epoch 76/100\n",
1139 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0113\n",
1140 | "Epoch 77/100\n",
1141 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0110\n",
1142 | "Epoch 78/100\n",
1143 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0107\n",
1144 | "Epoch 79/100\n",
1145 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0102\n",
1146 | "Epoch 80/100\n",
1147 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0099\n",
1148 | "Epoch 81/100\n",
1149 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0095\n",
1150 | "Epoch 82/100\n",
1151 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0090\n",
1152 | "Epoch 83/100\n",
1153 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0086\n",
1154 | "Epoch 84/100\n",
1155 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0086\n",
1156 | "Epoch 85/100\n",
1157 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0083\n",
1158 | "Epoch 86/100\n",
1159 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0084\n",
1160 | "Epoch 87/100\n",
1161 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0081\n",
1162 | "Epoch 88/100\n",
1163 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0078\n",
1164 | "Epoch 89/100\n",
1165 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0075\n",
1166 | "Epoch 90/100\n",
1167 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0072\n",
1168 | "Epoch 91/100\n",
1169 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0072\n",
1170 | "Epoch 92/100\n",
1171 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0072\n",
1172 | "Epoch 93/100\n"
1173 | ]
1174 | },
1175 | {
1176 | "name": "stdout",
1177 | "output_type": "stream",
1178 | "text": [
1179 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0071\n",
1180 | "Epoch 94/100\n",
1181 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0066\n",
1182 | "Epoch 95/100\n",
1183 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0064\n",
1184 | "Epoch 96/100\n",
1185 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0064\n",
1186 | "Epoch 97/100\n",
1187 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0061\n",
1188 | "Epoch 98/100\n",
1189 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0062\n",
1190 | "Epoch 99/100\n",
1191 | "5829/5829 [==============================] - 41s 7ms/step - loss: 0.0059\n",
1192 | "Epoch 100/100\n",
1193 | "5829/5829 [==============================] - 40s 7ms/step - loss: 0.0057\n"
1194 | ]
1195 | },
1196 | {
1197 | "data": {
1198 | "text/plain": [
1199 | ""
1200 | ]
1201 | },
1202 | "execution_count": 3,
1203 | "metadata": {},
1204 | "output_type": "execute_result"
1205 | },
1206 | {
1207 | "name": "stdout",
1208 | "output_type": "stream",
1209 | "text": [
1210 | "train_set_size:5829\n",
1211 | "y_true:[0.845201 1.563942 0.877281 1.028278 0.162615 0.008391 0.779757 0.763601]\n",
1212 | "y_pred:[0.8108509 1.5758184 0.8468687 1.0259019 0.15817188 0.00943068\n",
1213 | " 0.74160624 0.7605794 ]\n",
1214 | "mse:0.0004671205911515846\n",
1215 | "0.002175088914730864\n"
1216 | ]
1217 | }
1218 | ],
1219 | "source": [
1220 | "mse_list = []\n",
1221 | "for train_x,train_y,test_x,test_y in folds:\n",
1222 | " model = compiled_tcn(return_sequences=False,\n",
1223 | " num_feat=test_x.shape[1],\n",
1224 | " nb_filters=24,\n",
1225 | " num_classes=0,\n",
1226 | " kernel_size=8,\n",
1227 | " dilations=[2 ** i for i in range(9)],\n",
1228 | " nb_stacks=1,\n",
1229 | " max_len=test_x.shape[0],\n",
1230 | " use_skip_connections=True,\n",
1231 | " regression=True,\n",
1232 | " dropout_rate=0,\n",
1233 | " output_len= test_y.shape[0])\n",
1234 | " model.fit(train_x,train_y,batch_size=256,epochs=100)\n",
1235 | " y_raw_pred = model.predict(np.array([test_x]))\n",
1236 | " y_pred = enc.inverse_transform(y_raw_pred).flatten()\n",
1237 | " y_true = enc.inverse_transform([test_y]).flatten()\n",
1238 | " mse_cur = mean_squared_error(y_true,y_pred)\n",
1239 | " mse_list.append(mse_cur)\n",
1240 | " print(f\"train_set_size:{train_x.shape[0]}\")\n",
1241 | " print(f\"y_true:{y_true}\")\n",
1242 | " print(f\"y_pred:{y_pred}\")\n",
1243 | " print(f\"mse:{mse_cur}\")\n",
1244 | "print(np.mean(mse_list))"
1245 | ]
1246 | },
1247 | {
1248 | "cell_type": "code",
1249 | "execution_count": 4,
1250 | "metadata": {
1251 | "ExecuteTime": {
1252 | "end_time": "2019-10-14T06:50:51.821444Z",
1253 | "start_time": "2019-10-14T06:50:51.815596Z"
1254 | },
1255 | "pycharm": {
1256 | "is_executing": false,
1257 | "name": "#%%\n"
1258 | }
1259 | },
1260 | "outputs": [
1261 | {
1262 | "name": "stdout",
1263 | "output_type": "stream",
1264 | "text": [
1265 | "total mse on test set: 0.002175088914730864\n"
1266 | ]
1267 | }
1268 | ],
1269 | "source": [
1270 | "print(f\"total mse on test set: {np.mean(mse_list)}\")"
1271 | ]
1272 | }
1273 | ],
1274 | "metadata": {
1275 | "kernelspec": {
1276 | "display_name": "Python 3",
1277 | "language": "python",
1278 | "name": "python3"
1279 | },
1280 | "language_info": {
1281 | "codemirror_mode": {
1282 | "name": "ipython",
1283 | "version": 3
1284 | },
1285 | "file_extension": ".py",
1286 | "mimetype": "text/x-python",
1287 | "name": "python",
1288 | "nbconvert_exporter": "python",
1289 | "pygments_lexer": "ipython3",
1290 | "version": "3.6.6"
1291 | },
1292 | "pycharm": {
1293 | "stem_cell": {
1294 | "cell_type": "raw",
1295 | "metadata": {
1296 | "collapsed": false
1297 | },
1298 | "source": []
1299 | }
1300 | },
1301 | "toc": {
1302 | "base_numbering": 1,
1303 | "nav_menu": {},
1304 | "number_sections": true,
1305 | "sideBar": true,
1306 | "skip_h1_title": false,
1307 | "title_cell": "Table of Contents",
1308 | "title_sidebar": "Contents",
1309 | "toc_cell": false,
1310 | "toc_position": {},
1311 | "toc_section_display": true,
1312 | "toc_window_display": false
1313 | }
1314 | },
1315 | "nbformat": 4,
1316 | "nbformat_minor": 1
1317 | }
1318 |
--------------------------------------------------------------------------------
/tasks/exchange_rate/main.py:
--------------------------------------------------------------------------------
1 | from tcn import compiled_tcn
2 | from utils import get_xy_kfolds
3 | from sklearn.metrics import mean_squared_error
4 | import numpy as np
5 |
6 | # dataset source: https://github.com/laiguokun/multivariate-time-series-data
7 | # exchange rate: the daily exchange rates of eight countries (Australia, Britain,
8 | # Canada, Switzerland, China, Japan, New Zealand and Singapore), ranging from
9 | # 1990 to 2016.
10 | # task: predict the next day's exchange rates for all eight currencies from their history
11 |
12 | folds, enc = get_xy_kfolds()
13 |
14 |
15 | if __name__ == '__main__':
16 | mse_list = []
17 | for train_x, train_y, test_x, test_y in folds:
18 | model = compiled_tcn(return_sequences=False,
19 | num_feat=test_x.shape[1],
20 | nb_filters=24,
21 | num_classes=0,
22 | kernel_size=8,
23 | dilations=[2 ** i for i in range(9)],
24 | nb_stacks=1,
25 | max_len=test_x.shape[0],
26 | use_skip_connections=True,
27 | regression=True,
28 | dropout_rate=0,
29 | output_len=test_y.shape[0])
30 | model.fit(train_x, train_y, batch_size=256, epochs=100)
31 | y_raw_pred = model.predict(np.array([test_x]))
32 | y_pred = enc.inverse_transform(y_raw_pred).flatten()
33 | y_true = enc.inverse_transform([test_y]).flatten()
34 | mse_cur = mean_squared_error(y_true, y_pred)
35 | mse_list.append(mse_cur)
36 | print(f"train_set_size:{train_x.shape[0]}")
37 | print(f"y_true:{y_true}")
38 | print(f"y_pred:{y_pred}")
39 | print(f"mse:{mse_cur}")
40 | print(f"finial loss on test set: {np.mean(mse_list)}")
41 |
--------------------------------------------------------------------------------
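
A note on the scoring in main.py above: the network trains and predicts in
MinMaxScaler space, and both predictions and targets are mapped back with
enc.inverse_transform before the MSE is computed, so the reported errors are in
the original exchange-rate units. A minimal sketch of that round-trip on toy
data (the array here is illustrative, not part of the dataset):

    import numpy as np
    from sklearn.preprocessing import MinMaxScaler

    raw = np.array([[1.0, 100.0], [2.0, 300.0], [3.0, 500.0]])
    enc = MinMaxScaler().fit(raw)
    scaled = enc.transform(raw)  # every column squashed into [0, 1]
    assert np.allclose(enc.inverse_transform(scaled), raw)  # exact round-trip

--------------------------------------------------------------------------------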
/tasks/exchange_rate/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.preprocessing import MinMaxScaler
3 |
4 |
5 | def get_xy_kfolds(split_index=(0.5, 0.6, 0.7, 0.8, 0.9), timesteps=1000):
6 | """
7 |     load the exchange rate dataset, preprocess it, then split it into expanding-window folds for CV
8 |     :param split_index: ratios of the whole dataset to use as the train set, one per fold
9 |     :param timesteps: length of a single train x sample
10 |     :return: (folds, enc): a list of (train_x, train_y, test_x, test_y) tuples and the fitted MinMaxScaler
11 | """
12 | df = np.loadtxt('exchange_rate.txt', delimiter=',')
13 | n = len(df)
14 | folds = []
15 | enc = MinMaxScaler()
16 | df = enc.fit_transform(df)
17 | for split_point in split_index:
18 | train_end = int(split_point * n)
19 | train_x, train_y = [], []
20 | for i in range(train_end - timesteps):
21 | train_x.append(df[i:i + timesteps])
22 | train_y.append(df[i + timesteps])
23 | train_x = np.array(train_x)
24 | train_y = np.array(train_y)
25 | test_x = df[train_end - timesteps + 1:train_end + 1]
26 | test_y = df[train_end + 1]
27 | folds.append((train_x, train_y, test_x, test_y))
28 | return folds, enc
29 |
30 |
31 | if __name__ == '__main__':
32 | print(get_xy_kfolds())
33 |
--------------------------------------------------------------------------------
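
Note that get_xy_kfolds above implements an expanding-window evaluation rather
than a shuffled k-fold: each fold trains on every window before the split point
and tests on the single observation right after it, so no future data leaks
into training. A minimal sketch of the same split on a toy series (names are
illustrative, not part of the original file):

    import numpy as np

    series = np.arange(20).reshape(-1, 1).astype(float)  # toy 1-feature series
    timesteps = 5
    for split_point in (0.5, 0.75):
        train_end = int(split_point * len(series))
        train_x = np.array([series[i:i + timesteps] for i in range(train_end - timesteps)])
        train_y = np.array([series[i + timesteps] for i in range(train_end - timesteps)])
        test_y = series[train_end + 1]
        assert train_x.shape == (train_end - timesteps, timesteps, 1)
        assert test_y[0] > train_y.max()  # the test target lies strictly in the future

--------------------------------------------------------------------------------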
/tasks/imdb_tcn.py:
--------------------------------------------------------------------------------
1 | """
2 | #Trains a TCN on the IMDB sentiment classification task.
3 | Output after 1 epoch on CPU: ~0.8611
4 | Time per epoch on CPU (Core i7): ~64s.
5 | Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
6 | """
7 | import numpy as np
8 | from tensorflow.keras import Sequential
9 | from tensorflow.keras.datasets import imdb
10 | from tensorflow.keras.layers import Dense, Embedding
11 | from tensorflow.keras.preprocessing import sequence
12 |
13 | from tcn import TCN
14 |
15 | max_features = 20000
16 | # cut texts after this number of words
17 | # (among top max_features most common words)
18 | maxlen = 100
19 | batch_size = 32
20 |
21 | print('Loading data...')
22 | (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
23 | print(len(x_train), 'train sequences')
24 | print(len(x_test), 'test sequences')
25 |
26 | print('Pad sequences (samples x time)')
27 | x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
28 | x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
29 | print('x_train shape:', x_train.shape)
30 | print('x_test shape:', x_test.shape)
31 | y_train = np.array(y_train)
32 | y_test = np.array(y_test)
33 |
34 | model = Sequential([
35 | Embedding(max_features, 128, input_shape=(maxlen,)),
36 | TCN(kernel_size=6, dilations=[1, 2, 4, 8, 16]),
37 | Dense(1, activation='sigmoid')
38 | ])
39 |
40 | print(f'TCN receptive field: {model.layers[1].receptive_field}.')
41 |
42 | model.summary()
43 | model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
44 |
45 | print('Train...')
46 | model.fit(
47 | x_train, y_train,
48 | batch_size=batch_size,
49 |     validation_data=(x_test, y_test)
50 | )
51 |
--------------------------------------------------------------------------------
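
For intuition on the model above, a hedged sketch of the shape flow for one
padded review (meant to run after the script, since it reuses model and
x_train): the Embedding maps word indices to 128-dim vectors, and the TCN
collapses the time axis into a single summary vector before the sigmoid.

    probe = x_train[:1]                  # (1, 100) integer word indices
    print(model.layers[0](probe).shape)  # (1, 100, 128) after the Embedding
    print(model.predict(probe).shape)    # (1, 1): one sigmoid probability per review

--------------------------------------------------------------------------------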
/tasks/many_to_many.py:
--------------------------------------------------------------------------------
1 | """
2 | #Many-to-many example with a TCN.
3 | Maps an input of shape (batch_size, 24, 8) to an output of shape (batch_size, 6, 2)
4 | using a TCN encoder, a RepeatVector to match the output length, and a Dense
5 | head applied at every output timestep. See the comments below for details.
6 | """
7 | import numpy as np
8 | from tensorflow.keras import Sequential
9 | from tensorflow.keras.layers import Dense
10 | from tensorflow.keras.layers import RepeatVector
11 |
12 | from tcn import TCN
13 |
14 | # many to many example.
15 | # the input to the TCN model has the shape (batch_size, 24, 8),
16 | # and the model output should have the shape (batch_size, 6, 2).
17 |
18 | # We apply the TCN on the input sequence of length 24 to produce a vector of size 64
19 | # (comparable to the last state of an LSTM). We repeat this vector 6 times to match the length
20 | # of the output. We obtain an output_shape = (output_timesteps, 64) where each vector of size 64
21 | # is identical (just duplicated output_timesteps times).
22 | # From there, we apply a fully connected layer to go from a dim of 64 to output_dim.
23 | # The kernel of this FC layer is (64, output_dim). That means each output dimension is parameterized
24 | # by 64 weights + 1 bias (applied at the TCN output; the RepeatVector has no weights, it just repeats).
25 |
26 | batch_size, timesteps, input_dim = 64, 24, 8
27 | output_timesteps, output_dim = 6, 2
28 |
29 | # dummy values here. There is nothing to learn. It's just to show how to do it.
30 | batch_x = np.random.uniform(size=(batch_size, timesteps, input_dim))
31 | batch_y = np.random.uniform(size=(batch_size, output_timesteps, output_dim))
32 |
33 | model = Sequential(
34 | layers=[
35 | TCN(input_shape=(timesteps, input_dim)), # output.shape = (batch, 64)
36 | RepeatVector(output_timesteps), # output.shape = (batch, output_timesteps, 64)
37 | Dense(output_dim) # output.shape = (batch, output_timesteps, output_dim)
38 | ]
39 | )
40 |
41 | model.summary()
42 | model.compile('adam', 'mse')
43 |
44 | print('Train...')
45 | model.fit(batch_x, batch_y, batch_size=batch_size)
46 |
--------------------------------------------------------------------------------
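
The "64 weights + 1 bias" claim in the comments above can be checked directly on
the built model; a small sketch (relies on model and output_dim from the script):

    dense = model.layers[-1]                    # the Dense(output_dim) head
    kernel, bias = dense.get_weights()
    assert kernel.shape == (64, output_dim)     # (64, 2): 64 weights per output dim
    assert bias.shape == (output_dim,)          # one bias per output dim
    print(dense.count_params())                 # 64 * 2 + 2 = 130

--------------------------------------------------------------------------------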
/tasks/mnist_pixel/main.py:
--------------------------------------------------------------------------------
1 | from utils import data_generator
2 |
3 | from tcn import compiled_tcn
4 |
5 |
6 | def run_task():
7 | (x_train, y_train), (x_test, y_test) = data_generator()
8 |
9 | model = compiled_tcn(return_sequences=False,
10 | num_feat=1,
11 | num_classes=10,
12 | nb_filters=20,
13 | kernel_size=6,
14 | dilations=[2 ** i for i in range(9)],
15 | nb_stacks=1,
16 | max_len=x_train[0:1].shape[1],
17 | # use_weight_norm=True,
18 | use_skip_connections=True)
19 |
20 | print(f'x_train.shape = {x_train.shape}')
21 | print(f'y_train.shape = {y_train.shape}')
22 | print(f'x_test.shape = {x_test.shape}')
23 | print(f'y_test.shape = {y_test.shape}')
24 |
25 | model.summary()
26 |
27 | model.fit(x_train, y_train.squeeze().argmax(axis=1), epochs=100,
28 | validation_data=(x_test, y_test.squeeze().argmax(axis=1)))
29 |
30 |
31 | if __name__ == '__main__':
32 | run_task()
33 |
--------------------------------------------------------------------------------
/tasks/mnist_pixel/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tensorflow.keras.datasets import mnist
3 | from tensorflow.keras.utils import to_categorical
4 |
5 |
6 | def data_generator():
7 | # input image dimensions
8 | img_rows, img_cols = 28, 28
9 | (x_train, y_train), (x_test, y_test) = mnist.load_data()
10 | x_train = x_train.reshape(-1, img_rows * img_cols, 1)
11 | x_test = x_test.reshape(-1, img_rows * img_cols, 1)
12 |
13 | num_classes = 10
14 | y_train = to_categorical(y_train, num_classes)
15 | y_test = to_categorical(y_test, num_classes)
16 |
17 | y_train = np.expand_dims(y_train, axis=2)
18 | y_test = np.expand_dims(y_test, axis=2)
19 |
20 | x_train = x_train.astype('float32')
21 | x_test = x_test.astype('float32')
22 | x_train /= 255
23 | x_test /= 255
24 |
25 | return (x_train, y_train), (x_test, y_test)
26 |
27 |
28 | if __name__ == '__main__':
29 | print(data_generator())
30 |
--------------------------------------------------------------------------------
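
What the reshape in data_generator does, on one toy "image": each 28x28 digit
becomes a sequence of 784 scalar timesteps read row by row, which is why the
TCN in main.py needs such a long receptive field. A minimal sketch:

    import numpy as np

    image = np.arange(28 * 28).reshape(28, 28)  # stand-in for one MNIST digit
    sequence = image.reshape(28 * 28, 1)        # (784, 1): one pixel per timestep
    assert sequence[28, 0] == image[1, 0]       # timestep 28 is the start of row 2

--------------------------------------------------------------------------------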
/tasks/monthly-milk-production-pounds-p.csv:
--------------------------------------------------------------------------------
1 | "month","milk_production_pounds"
2 | "1962-01",589
3 | "1962-02",561
4 | "1962-03",640
5 | "1962-04",656
6 | "1962-05",727
7 | "1962-06",697
8 | "1962-07",640
9 | "1962-08",599
10 | "1962-09",568
11 | "1962-10",577
12 | "1962-11",553
13 | "1962-12",582
14 | "1963-01",600
15 | "1963-02",566
16 | "1963-03",653
17 | "1963-04",673
18 | "1963-05",742
19 | "1963-06",716
20 | "1963-07",660
21 | "1963-08",617
22 | "1963-09",583
23 | "1963-10",587
24 | "1963-11",565
25 | "1963-12",598
26 | "1964-01",628
27 | "1964-02",618
28 | "1964-03",688
29 | "1964-04",705
30 | "1964-05",770
31 | "1964-06",736
32 | "1964-07",678
33 | "1964-08",639
34 | "1964-09",604
35 | "1964-10",611
36 | "1964-11",594
37 | "1964-12",634
38 | "1965-01",658
39 | "1965-02",622
40 | "1965-03",709
41 | "1965-04",722
42 | "1965-05",782
43 | "1965-06",756
44 | "1965-07",702
45 | "1965-08",653
46 | "1965-09",615
47 | "1965-10",621
48 | "1965-11",602
49 | "1965-12",635
50 | "1966-01",677
51 | "1966-02",635
52 | "1966-03",736
53 | "1966-04",755
54 | "1966-05",811
55 | "1966-06",798
56 | "1966-07",735
57 | "1966-08",697
58 | "1966-09",661
59 | "1966-10",667
60 | "1966-11",645
61 | "1966-12",688
62 | "1967-01",713
63 | "1967-02",667
64 | "1967-03",762
65 | "1967-04",784
66 | "1967-05",837
67 | "1967-06",817
68 | "1967-07",767
69 | "1967-08",722
70 | "1967-09",681
71 | "1967-10",687
72 | "1967-11",660
73 | "1967-12",698
74 | "1968-01",717
75 | "1968-02",696
76 | "1968-03",775
77 | "1968-04",796
78 | "1968-05",858
79 | "1968-06",826
80 | "1968-07",783
81 | "1968-08",740
82 | "1968-09",701
83 | "1968-10",706
84 | "1968-11",677
85 | "1968-12",711
86 | "1969-01",734
87 | "1969-02",690
88 | "1969-03",785
89 | "1969-04",805
90 | "1969-05",871
91 | "1969-06",845
92 | "1969-07",801
93 | "1969-08",764
94 | "1969-09",725
95 | "1969-10",723
96 | "1969-11",690
97 | "1969-12",734
98 | "1970-01",750
99 | "1970-02",707
100 | "1970-03",807
101 | "1970-04",824
102 | "1970-05",886
103 | "1970-06",859
104 | "1970-07",819
105 | "1970-08",783
106 | "1970-09",740
107 | "1970-10",747
108 | "1970-11",711
109 | "1970-12",751
110 | "1971-01",804
111 | "1971-02",756
112 | "1971-03",860
113 | "1971-04",878
114 | "1971-05",942
115 | "1971-06",913
116 | "1971-07",869
117 | "1971-08",834
118 | "1971-09",790
119 | "1971-10",800
120 | "1971-11",763
121 | "1971-12",800
122 | "1972-01",826
123 | "1972-02",799
124 | "1972-03",890
125 | "1972-04",900
126 | "1972-05",961
127 | "1972-06",935
128 | "1972-07",894
129 | "1972-08",855
130 | "1972-09",809
131 | "1972-10",810
132 | "1972-11",766
133 | "1972-12",805
134 | "1973-01",821
135 | "1973-02",773
136 | "1973-03",883
137 | "1973-04",898
138 | "1973-05",957
139 | "1973-06",924
140 | "1973-07",881
141 | "1973-08",837
142 | "1973-09",784
143 | "1973-10",791
144 | "1973-11",760
145 | "1973-12",802
146 | "1974-01",828
147 | "1974-02",778
148 | "1974-03",889
149 | "1974-04",902
150 | "1974-05",969
151 | "1974-06",947
152 | "1974-07",908
153 | "1974-08",867
154 | "1974-09",815
155 | "1974-10",812
156 | "1974-11",773
157 | "1974-12",813
158 | "1975-01",834
159 | "1975-02",782
160 | "1975-03",892
161 | "1975-04",903
162 | "1975-05",966
163 | "1975-06",937
164 | "1975-07",896
165 | "1975-08",858
166 | "1975-09",817
167 | "1975-10",827
168 | "1975-11",797
169 | "1975-12",843
--------------------------------------------------------------------------------
/tasks/multi_length_sequences.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tensorflow.keras import Sequential
3 | from tensorflow.keras.layers import Dense
4 |
5 | from tcn import TCN
6 |
7 | # if you increase the sequence length make sure the receptive field of the TCN is big enough.
8 | MAX_TIME_STEP = 30
9 |
10 | """
11 | Input: sequence of length 7
12 | Input: sequence of length 25
13 | Input: sequence of length 29
14 | Input: sequence of length 21
15 | Input: sequence of length 20
16 | Input: sequence of length 13
17 | Input: sequence of length 9
18 | Input: sequence of length 7
19 | Input: sequence of length 4
20 | Input: sequence of length 14
21 | Input: sequence of length 10
22 | Input: sequence of length 11
23 | ...
24 | """
25 |
26 |
27 | def get_x_y(max_time_steps):
28 | for k in range(int(1e9)):
29 | time_steps = np.random.choice(range(1, max_time_steps), size=1)[0]
30 | if k % 2 == 0:
31 | x_train = np.expand_dims([np.insert(np.zeros(shape=(time_steps, 1)), 0, 1)], axis=-1)
32 | y_train = [1]
33 | else:
34 | x_train = np.array([np.zeros(shape=(time_steps, 1))])
35 | y_train = [0]
36 | if k % 100 == 0:
37 | print(f'({k}) Input: sequence of length {time_steps}.')
38 | yield x_train, np.expand_dims(y_train, axis=-1)
39 |
40 |
41 | m = Sequential([
42 | TCN(input_shape=(None, 1)),
43 | Dense(1, activation='sigmoid')
44 | ])
45 |
46 | m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
47 |
48 | gen = get_x_y(max_time_steps=MAX_TIME_STEP)
49 | m.fit(gen, epochs=1, steps_per_epoch=1000, verbose=2)
50 |
--------------------------------------------------------------------------------
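
Because the model above is built with input_shape=(None, 1), the same weights
accept sequences of any length at inference time as well; a short sketch
(reuses m from the script):

    import numpy as np

    for length in (4, 11, 29):
        p = m.predict(np.zeros((1, length, 1)))
        print(length, p.shape)  # always (1, 1): one sigmoid score per sequence

--------------------------------------------------------------------------------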
/tasks/non_causal.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tensorflow.keras import Sequential
3 | from tensorflow.keras.layers import Dense
4 |
5 | from tcn import TCN
6 |
7 | # See the README.md to learn what the non-causal case is.
8 |
9 | model = Sequential([
10 | TCN(nb_filters=30, padding='same', input_shape=(5, 300)),
11 | Dense(1)
12 | ])
13 | model.compile(optimizer='adam', loss='mse')
14 | pred = model.predict(np.random.rand(1, 5, 300))
15 | print(pred.shape)
16 |
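17 | # For comparison, a minimal causal sketch: with padding='causal' each output step
18 | # depends only on current and past timesteps, whereas padding='same' (above)
19 | # lets the receptive field also cover future timesteps.
20 | causal_model = Sequential([
21 |     TCN(nb_filters=30, padding='causal', input_shape=(5, 300)),
22 |     Dense(1)
23 | ])
24 | causal_model.compile(optimizer='adam', loss='mse')
25 | print(causal_model.predict(np.random.rand(1, 5, 300)).shape)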
--------------------------------------------------------------------------------
/tasks/plot_tcn_model.py:
--------------------------------------------------------------------------------
1 | from tcn import TCN
2 | import tensorflow as tf
3 |
4 | timesteps = 32
5 | input_dim = 5
6 | input_shape = (timesteps, input_dim)
7 | forecast_horizon = 3
8 | num_features = 4
9 |
10 | inputs = tf.keras.layers.Input(shape=input_shape, name='input')
11 | tcn_out = TCN(nb_filters=64, kernel_size=3, nb_stacks=1, activation='relu')(inputs)
12 | outputs = tf.keras.layers.Dense(forecast_horizon * num_features, activation='linear')(tcn_out)
13 | outputs = tf.keras.layers.Reshape((forecast_horizon, num_features), name='output')(outputs)
14 | model = tf.keras.Model(inputs=inputs, outputs=outputs)
15 |
16 | tf.keras.utils.plot_model(
17 | model,
18 | to_file='TCN_model.png',
19 | show_shapes=True,
20 | show_dtype=True,
21 | show_layer_names=True,
22 | rankdir='TB',
23 | dpi=200,
24 | layer_range=None,
25 | )
26 |
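27 | # Note: plot_model depends on pydot and the Graphviz binaries being installed
28 | # (e.g. `pip install pydot` plus the graphviz system package); it raises an
29 | # ImportError otherwise.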
--------------------------------------------------------------------------------
/tasks/receptive-field/main.py:
--------------------------------------------------------------------------------
1 | from utils import data_generator
2 |
3 | from tcn import compiled_tcn
4 |
5 |
6 | def run_task(sequence_length=8):
7 | x_train, y_train = data_generator(batch_size=2048, sequence_length=sequence_length)
8 | print(x_train.shape)
9 | print(y_train.shape)
10 | model = compiled_tcn(return_sequences=False,
11 | num_feat=1,
12 | num_classes=10,
13 | nb_filters=10,
14 | kernel_size=10,
15 | dilations=[1, 2, 4, 8, 16, 32],
16 | nb_stacks=6,
17 | max_len=x_train[0:1].shape[1],
18 | use_skip_connections=False)
19 |
20 | print(f'x_train.shape = {x_train.shape}')
21 | print(f'y_train.shape = {y_train.shape}')
22 |
23 | # model.summary()
24 |
25 | model.fit(x_train, y_train, epochs=5)
26 | return model.evaluate(x_train, y_train)[1]
27 |
28 |
29 | def main():
30 | print('acc =', run_task(sequence_length=630))
31 |
32 |
33 | if __name__ == '__main__':
34 | main()
35 |
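36 | # Rough check against the TCN receptive-field formula
37 | # 1 + 2 * (kernel_size - 1) * nb_stacks * sum(dilations):
38 | # 1 + 2 * (10 - 1) * 6 * (1 + 2 + 4 + 8 + 16 + 32) = 6805 >= 630 timesteps,
39 | # so the model above can cover the whole input sequence.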
--------------------------------------------------------------------------------
/tasks/receptive-field/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | source ~/venv3.6/bin/activate
3 | pip uninstall -y keras-tcn
4 | cd ../..
5 | pip install . --upgrade
6 | cd tasks/receptive-field
7 | export CUDA_VISIBLE_DEVICES=; python main.py | grep acc
--------------------------------------------------------------------------------
/tasks/receptive-field/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def data_generator(batch_size=1024, sequence_length=32):
5 |     # randomly pick half of the batch to be positive examples.
6 | pos_indices = np.random.choice(batch_size, size=int(batch_size // 2), replace=False)
7 |
8 | x_train = np.zeros(shape=(batch_size, sequence_length))
9 | y_train = np.zeros(shape=(batch_size, 1))
10 | x_train[pos_indices, 0] = 1.0
11 | y_train[pos_indices, 0] = 1.0
12 |
13 | # y_train = to_categorical(y_train, num_classes=2)
14 |
15 | return np.expand_dims(x_train, axis=2), y_train
16 |
17 |
18 | if __name__ == '__main__':
19 | print(data_generator(batch_size=3, sequence_length=4))
20 |
--------------------------------------------------------------------------------
/tasks/save_reload_sequential_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tensorflow.keras.layers import Dense, Embedding
3 | from tensorflow.keras.models import Sequential, model_from_json, load_model
4 |
5 | from tcn import TCN, tcn_full_summary
6 |
7 | # define input shape
8 | max_len = 100
9 | max_features = 50
10 |
11 | # make model
12 | model = Sequential(layers=[Embedding(max_features, 16, input_shape=(max_len,)),
13 | TCN(nb_filters=12,
14 | dropout_rate=0.5,
15 | kernel_size=6,
16 | use_batch_norm=True,
17 | dilations=[1, 2, 4]),
18 | Dense(units=1, activation='sigmoid')])
19 |
20 | model.compile(loss='mae')
21 | model.fit(x=np.random.random((max_features, 100)), y=np.random.random((max_features, 1)))
22 |
23 | # get model as json string and save to file
24 | model_as_json = model.to_json()
25 | with open('model.json', "w") as json_file:
26 | json_file.write(model_as_json)
27 | # save weights to file (for this format, need h5py installed)
28 | model.save_weights('model.weights.h5')
29 |
30 | # Make inference.
31 | inputs = np.ones(shape=(1, 100))
32 | out1 = model.predict(inputs)[0, 0]
33 | print('*' * 80)
34 | print('Inference after creation:', out1)
35 |
36 | # load model from file
37 | loaded_json = open('model.json', 'r').read()
38 | reloaded_model = model_from_json(loaded_json, custom_objects={'TCN': TCN})
39 |
40 | tcn_full_summary(model, expand_residual_blocks=False)
41 |
42 | # restore weights
43 | reloaded_model.load_weights('model.weights.h5')
44 |
45 | # Make inference.
46 | out2 = reloaded_model.predict(inputs)[0, 0]
47 | print('*' * 80)
48 | print('Inference after loading:', out2)
49 |
50 | assert abs(out1 - out2) < 1e-6
51 |
52 | model.save('model.keras')
53 | out11 = load_model('model.keras').predict(inputs)[0, 0]
54 | out22 = model.predict(inputs)[0, 0]
55 | assert abs(out11 - out22) < 1e-6
56 |
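57 | # A minimal reload sketch for the .keras file saved above. TCN is registered via
58 | # @register_keras_serializable, so recent Keras can usually resolve it on its
59 | # own, but passing custom_objects explicitly is the safe route.
60 | fresh = load_model('model.keras', custom_objects={'TCN': TCN})
61 | assert abs(fresh.predict(inputs)[0, 0] - out1) < 1e-6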
--------------------------------------------------------------------------------
/tasks/sequential.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains a TCN on the IMDB sentiment classification task.
3 | Output after 1 epoch on CPU: ~0.8611.
4 | Time per epoch on CPU (Core i7): ~64s.
5 | Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
6 | """
7 | import numpy as np
8 | from tensorflow.keras import Sequential
9 | from tensorflow.keras.callbacks import Callback
10 | from tensorflow.keras.datasets import imdb
11 | from tensorflow.keras.layers import Dense, Dropout, Embedding
12 | from tensorflow.keras.preprocessing import sequence
13 |
14 | from tcn import TCN
15 |
16 | max_features = 20000
17 | # cut texts after this number of words
18 | # (among top max_features most common words)
19 | maxlen = 100
20 | batch_size = 32
21 |
22 | print('Loading data...')
23 | (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
24 | print(len(x_train), 'train sequences')
25 | print(len(x_test), 'test sequences')
26 |
27 | print('Pad sequences (samples x time)')
28 | x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
29 | x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
30 | print('x_train shape:', x_train.shape)
31 | print('x_test shape:', x_test.shape)
32 | y_train = np.array(y_train)
33 | y_test = np.array(y_test)
34 |
35 | model = Sequential()
36 | model.add(Embedding(max_features, 128, input_shape=(maxlen,)))
37 | model.add(TCN(
38 | nb_filters=64,
39 | kernel_size=6,
40 | dilations=[1, 2, 4, 8, 16, 32, 64]
41 | ))
42 | model.add(Dropout(0.5))
43 | model.add(Dense(1, activation='sigmoid'))
44 |
45 | model.summary()
46 |
47 | model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
48 |
49 |
50 | class TestCallback(Callback):
51 |
52 | def on_epoch_end(self, epoch, logs=None):
53 | print(logs)
54 | acc_key = 'val_accuracy' if 'val_accuracy' in logs else 'val_acc'
55 | assert logs[acc_key] > 0.78
56 |
57 |
58 | print('Train...')
59 | model.fit(
60 | x_train, y_train,
61 | batch_size=batch_size,
62 | validation_data=(x_test, y_test),
63 | callbacks=[TestCallback()]
64 | )
65 |
--------------------------------------------------------------------------------
/tasks/tcn_call_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | from tensorflow.keras import Input
5 | from tensorflow.keras import Model
6 |
7 | from tcn import TCN
8 |
9 | NB_FILTERS = 16
10 | TIME_STEPS = 20
11 |
12 | SEQ_LEN_1 = 5
13 | SEQ_LEN_2 = 1
14 | SEQ_LEN_3 = 10
15 |
16 |
17 | def predict_with_tcn(time_steps=None, padding='causal', return_sequences=True) -> list:
18 | input_dim = 4
19 | i = Input(batch_shape=(None, time_steps, input_dim))
20 | o = TCN(nb_filters=NB_FILTERS, return_sequences=return_sequences, padding=padding)(i)
21 | m = Model(inputs=[i], outputs=[o])
22 | m.compile(optimizer='adam', loss='mse')
23 | if time_steps is None:
24 | np.random.seed(123)
25 | return [
26 | m(np.random.rand(1, SEQ_LEN_1, input_dim)),
27 | m(np.random.rand(1, SEQ_LEN_2, input_dim)),
28 | m(np.random.rand(1, SEQ_LEN_3, input_dim))
29 | ]
30 | else:
31 | np.random.seed(123)
32 | return [m(np.random.rand(1, time_steps, input_dim))]
33 |
34 |
35 | class TCNCallTest(unittest.TestCase):
36 |
37 | def test_compute_output_for_multiple_config(self):
38 | # with time steps None.
39 | o1 = TCN(nb_filters=NB_FILTERS, return_sequences=True, padding='same').compute_output_shape((None, None, 4))
40 | self.assertListEqual(list(o1), [None, None, NB_FILTERS])
41 |
42 | o2 = TCN(nb_filters=NB_FILTERS, return_sequences=True, padding='causal').compute_output_shape((None, None, 4))
43 | self.assertListEqual(list(o2), [None, None, NB_FILTERS])
44 |
45 | o3 = TCN(nb_filters=NB_FILTERS, return_sequences=False, padding='same').compute_output_shape((None, None, 4))
46 | self.assertListEqual(list(o3), [None, NB_FILTERS])
47 |
48 | o4 = TCN(nb_filters=NB_FILTERS, return_sequences=False, padding='causal').compute_output_shape((None, None, 4))
49 | self.assertListEqual(list(o4), [None, NB_FILTERS])
50 |
51 | # with time steps known.
52 | o5 = TCN(nb_filters=NB_FILTERS, return_sequences=True, padding='same').compute_output_shape((None, 5, 4))
53 | self.assertListEqual(list(o5), [None, 5, NB_FILTERS])
54 |
55 | o6 = TCN(nb_filters=NB_FILTERS, return_sequences=True, padding='causal').compute_output_shape((None, 5, 4))
56 | self.assertListEqual(list(o6), [None, 5, NB_FILTERS])
57 |
58 | o7 = TCN(nb_filters=NB_FILTERS, return_sequences=False, padding='same').compute_output_shape((None, 5, 4))
59 | self.assertListEqual(list(o7), [None, NB_FILTERS])
60 |
61 | o8 = TCN(nb_filters=NB_FILTERS, return_sequences=False, padding='causal').compute_output_shape((None, 5, 4))
62 | self.assertListEqual(list(o8), [None, NB_FILTERS])
63 |
64 | def test_causal_time_dim_known_return_sequences(self):
65 | r = predict_with_tcn(time_steps=TIME_STEPS, padding='causal', return_sequences=True)
66 | self.assertListEqual([list(b.shape) for b in r], [[1, TIME_STEPS, NB_FILTERS]])
67 |
68 | def test_causal_time_dim_unknown_return_sequences(self):
69 | r = predict_with_tcn(time_steps=None, padding='causal', return_sequences=True)
70 | self.assertListEqual([list(b.shape) for b in r],
71 | [[1, SEQ_LEN_1, NB_FILTERS],
72 | [1, SEQ_LEN_2, NB_FILTERS],
73 | [1, SEQ_LEN_3, NB_FILTERS]])
74 |
75 | def test_non_causal_time_dim_known_return_sequences(self):
76 | r = predict_with_tcn(time_steps=TIME_STEPS, padding='same', return_sequences=True)
77 | self.assertListEqual([list(b.shape) for b in r], [[1, TIME_STEPS, NB_FILTERS]])
78 |
79 | def test_non_causal_time_dim_unknown_return_sequences(self):
80 | r = predict_with_tcn(time_steps=None, padding='same', return_sequences=True)
81 | self.assertListEqual([list(b.shape) for b in r],
82 | [[1, SEQ_LEN_1, NB_FILTERS],
83 | [1, SEQ_LEN_2, NB_FILTERS],
84 | [1, SEQ_LEN_3, NB_FILTERS]])
85 |
86 | def test_causal_time_dim_known_return_no_sequences(self):
87 | r = predict_with_tcn(time_steps=TIME_STEPS, padding='causal', return_sequences=False)
88 | self.assertListEqual([list(b.shape) for b in r], [[1, NB_FILTERS]])
89 |
90 | def test_causal_time_dim_unknown_return_no_sequences(self):
91 | r = predict_with_tcn(time_steps=None, padding='causal', return_sequences=False)
92 | self.assertListEqual([list(b.shape) for b in r], [[1, NB_FILTERS], [1, NB_FILTERS], [1, NB_FILTERS]])
93 |
94 | def test_non_causal_time_dim_known_return_no_sequences(self):
95 | r = predict_with_tcn(time_steps=TIME_STEPS, padding='same', return_sequences=False)
96 | self.assertListEqual([list(b.shape) for b in r], [[1, NB_FILTERS]])
97 |
98 | def test_non_causal_time_dim_unknown_return_no_sequences(self):
99 | r = predict_with_tcn(time_steps=None, padding='same', return_sequences=False)
100 | self.assertListEqual([list(b.shape) for b in r], [[1, NB_FILTERS], [1, NB_FILTERS], [1, NB_FILTERS]])
101 |
102 | def test_receptive_field(self):
103 | self.assertEqual(37, TCN(kernel_size=3, dilations=(1, 3, 5), nb_stacks=1).receptive_field)
104 | self.assertEqual(379, TCN(kernel_size=4, dilations=(1, 2, 4, 8, 16, 32), nb_stacks=1).receptive_field)
105 | self.assertEqual(253, TCN(kernel_size=3, dilations=(1, 2, 4, 8, 16, 32), nb_stacks=1).receptive_field)
106 | self.assertEqual(125, TCN(kernel_size=3, dilations=(1, 2, 4, 8, 16), nb_stacks=1).receptive_field)
107 | self.assertEqual(61, TCN(kernel_size=3, dilations=(1, 2, 4, 8), nb_stacks=1).receptive_field)
108 | self.assertEqual(29, TCN(kernel_size=3, dilations=(1, 2, 4), nb_stacks=1).receptive_field)
109 | self.assertEqual(57, TCN(kernel_size=3, dilations=(1, 2, 4), nb_stacks=2).receptive_field)
110 | self.assertEqual(121, TCN(kernel_size=3, dilations=(1, 2, 4, 8), nb_stacks=2).receptive_field)
111 | self.assertEqual(91, TCN(kernel_size=4, dilations=(1, 2, 4, 8), nb_stacks=1).receptive_field)
112 | self.assertEqual(25, TCN(kernel_size=5, dilations=(1, 2), nb_stacks=1).receptive_field)
113 | self.assertEqual(31, TCN(kernel_size=6, dilations=(1, 2), nb_stacks=1).receptive_field)
114 | # 1+(3-1)*1*(1+3+5)*2 = 37
115 | # 1+(4-1)*1*(1+2+4+8+16+32)*2 = 379
116 | # 1+(3-1)*1*(1+2+4+8+16+32)*2 = 253
117 | # 1+(3-1)*1*(1+2+4+8+16)*2 = 125
118 | # 1+(3-1)*1*(1+2+4+8)*2 = 61
119 | # 1+(3-1)*1*(1+2+4)*2 = 29
120 | # 1+(3-1)*2*(1+2+4)*2 = 57
121 | # 1+(3-1)*2*(1+2+4+8)*2 = 121
122 | # 1+(4-1)*1*(1+2+4+8)*2 = 91
123 | # 1+(5-1)*1*(1+2)*2 = 25
124 | # 1+(6-1)*1*(1+2)*2 = 31
125 |
126 |
127 | if __name__ == '__main__':
128 | unittest.main()
129 |
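130 | # General formula behind TCN.receptive_field (cf. tcn/tcn.py):
131 | #     1 + 2 * (kernel_size - 1) * nb_stacks * sum(dilations)
132 | # The factor 2 comes from the two dilated convolutions in each residual block.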
--------------------------------------------------------------------------------
/tasks/tcn_tensorboard.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains a TCN on the IMDB sentiment classification task.
3 | Output after 1 epoch on CPU: ~0.8611.
4 | Time per epoch on CPU (Core i7): ~64s.
5 | Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
6 | """
7 | import numpy as np
8 | from tensorflow.keras import Sequential
9 | from tensorflow.keras.callbacks import TensorBoard
10 | from tensorflow.keras.datasets import imdb
11 | from tensorflow.keras.layers import Dense, Dropout, Embedding
12 | from tensorflow.keras.preprocessing import sequence
13 |
14 | from tcn import TCN
15 |
16 | max_features = 20000
17 | maxlen = 100
18 | batch_size = 32
19 |
20 | print('Loading data...')
21 | (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
22 | print(len(x_train), 'train sequences')
23 | print(len(x_test), 'test sequences')
24 |
25 | print('Pad sequences (samples x time)')
26 | x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
27 | x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
28 | print('x_train shape:', x_train.shape)
29 | print('x_test shape:', x_test.shape)
30 | y_train = np.array(y_train)
31 | y_test = np.array(y_test)
32 |
33 | model = Sequential()
34 | model.add(Embedding(max_features, 128, input_shape=(maxlen,)))
35 | model.add(TCN(
36 | kernel_size=6,
37 | dilations=[1, 2, 4, 8, 16, 32, 64]
38 | ))
39 | model.add(Dropout(0.5))
40 | model.add(Dense(1, activation='sigmoid'))
41 | model.summary()
42 |
43 | model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
44 |
45 | # tensorboard --logdir logs_tcn
46 | # Browse to http://localhost:6006/#graphs&run=train
47 | # and double-click on TCN to expand the inner layers.
48 | # It takes time to write the graph to tensorboard. Wait until the first epoch is completed.
49 | tensorboard = TensorBoard(
50 | log_dir='logs_tcn',
51 | histogram_freq=1,
52 | write_images=True
53 | )
54 |
55 | print('Train...')
56 | model.fit(
57 | x_train, y_train,
58 | batch_size=batch_size,
59 | validation_data=(x_test, y_test),
60 | callbacks=[tensorboard],
61 | epochs=10
62 | )
63 |
--------------------------------------------------------------------------------
/tasks/time_series_forecasting.py:
--------------------------------------------------------------------------------
1 | # https://datamarket.com/data/set/22ox/monthly-milk-production-pounds-per-cow-jan-62-dec-75#!ds=22ox&display=line
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import pandas as pd
5 | from tensorflow.keras import Sequential
6 | from tensorflow.keras.layers import Dense
7 |
8 | from tcn import TCN
9 |
10 | ##
11 | # This is a very naive (toy) example of how to do time series forecasting.
12 | # - There is no train/test split here; everything is used for training, for simplicity.
13 | # - There is no input/output normalization.
14 | # - The model is simple.
15 | ##
16 |
17 | milk = pd.read_csv('monthly-milk-production-pounds-p.csv', index_col=0, parse_dates=True)
18 |
19 | print(milk.head())
20 |
21 | lookback_window = 12 # months.
22 |
23 | milk = milk.values # just keep np array here for simplicity.
24 |
25 | x, y = [], []
26 | for i in range(lookback_window, len(milk)):
27 | x.append(milk[i - lookback_window:i])
28 | y.append(milk[i])
29 | x = np.array(x)
30 | y = np.array(y)
31 |
32 | print(x.shape)
33 | print(y.shape)
34 |
35 | # noinspection PyArgumentEqualDefault
36 | model = Sequential([
37 | TCN(input_shape=(lookback_window, 1),
38 | kernel_size=2,
39 | use_skip_connections=False,
40 | use_batch_norm=False,
41 | # use_weight_norm=False,
42 | use_layer_norm=False
43 | ),
44 | Dense(1, activation='linear')
45 | ])
46 |
47 | model.summary()
48 | model.compile('adam', 'mae')
49 |
50 | print('Train...')
51 | model.fit(x, y, epochs=100, verbose=2)
52 |
53 | p = model.predict(x)
54 |
55 | plt.plot(p)
56 | plt.plot(y)
57 | plt.title('Monthly Milk Production (in pounds)')
58 | plt.legend(['predicted', 'actual'])
59 | plt.show()
60 |
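61 | # A minimal sketch (hypothetical extension): forecast 6 months beyond the data
62 | # by feeding each one-step prediction back into the lookback window.
63 | window = milk[-lookback_window:].reshape(1, lookback_window, 1).astype(float)
64 | future = []
65 | for _ in range(6):
66 |     nxt = float(model.predict(window, verbose=0)[0, 0])
67 |     future.append(nxt)
68 |     window = np.concatenate([window[:, 1:, :], [[[nxt]]]], axis=1)
69 | print('6-month forecast:', np.round(future, 1))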
--------------------------------------------------------------------------------
/tasks/video_classification.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow.keras.backend as K
3 | from tensorflow.keras import Input, Model
4 | from tensorflow.keras.layers import Conv2D
5 | from tensorflow.keras.layers import Dense
6 | from tensorflow.keras.layers import Lambda
7 | from tensorflow.keras.layers import MaxPool2D
8 |
9 | from tcn import TCN
10 |
11 | num_samples = 1000 # number of videos.
12 | num_frames = 240  # 10 seconds of video at 24 fps.
13 | h, w, c = 32, 32, 3  # definitely not an HD video: 32x32, color.
14 |
15 |
16 | def data():
17 |     # A very simple dummy dataset. The purpose is to show how to use an RNN/TCN
18 |     # in the context of video processing.
19 | inputs = np.zeros(shape=(num_samples, num_frames, h, w, c))
20 | targets = np.zeros(shape=(num_samples, 1))
21 | # class 0 => only 0.
22 |
23 | # class 1 => will contain some 1s.
24 | for i in range(num_samples):
25 | if np.random.uniform(low=0, high=1) > 0.50:
26 | for j in range(num_frames):
27 | inputs[i, j] = (np.random.uniform(low=0, high=1) > 0.90)
28 | targets[i] = 1
29 | return inputs, targets
30 |
31 |
32 | def train():
33 | # Good exercise: https://www.crcv.ucf.edu/data/UCF101.php
34 | # replace data() by this dataset.
35 | # Useful links:
36 | # - https://www.pyimagesearch.com/2019/07/15/video-classification-with-keras-and-deep-learning/
37 | # - https://github.com/sujiongming/UCF-101_video_classification
38 | x_train, y_train = data()
39 |
40 | inputs = Input(shape=(num_frames, h, w, c))
41 |     # fold num_frames into the batch dim to process all the frames independently of their order (CNN features).
42 | x = Lambda(lambda y: K.reshape(y, (-1, h, w, c)))(inputs)
43 | # apply convolutions to each image of each video.
44 | x = Conv2D(16, 5)(x)
45 | x = MaxPool2D()(x)
46 | # re-creates the videos by reshaping.
47 | # 3D input shape (batch, timesteps, input_dim)
48 | num_features_cnn = np.prod(K.int_shape(x)[1:])
49 | x = Lambda(lambda y: K.reshape(y, (-1, num_frames, num_features_cnn)))(x)
50 | # apply the RNN on the time dimension (num_frames dim).
51 | x = TCN(16)(x)
52 | x = Dense(1, activation='sigmoid')(x)
53 |
54 | model = Model(inputs=[inputs], outputs=[x])
55 | model.summary()
56 | model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
57 | print('Train...')
58 | model.fit(x_train, y_train, validation_split=0.2, epochs=5)
59 |
60 |
61 | if __name__ == '__main__':
62 | train()
63 |
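64 | # Shape walk-through for the default constants above:
65 | # (batch, 240, 32, 32, 3) -> reshape -> (batch*240, 32, 32, 3)
66 | # -> Conv2D(16, 5) -> (batch*240, 28, 28, 16) -> MaxPool2D -> (batch*240, 14, 14, 16)
67 | # -> reshape -> (batch, 240, 14*14*16) -> TCN(16) -> (batch, 16) -> Dense -> (batch, 1)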
--------------------------------------------------------------------------------
/tasks/visualise_activations.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains a TCN on the IMDB sentiment classification task.
3 | Output after 1 epoch on CPU: ~0.8611.
4 | Time per epoch on CPU (Core i7): ~64s.
5 | Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
6 | """
7 | import os
8 | import shutil
9 |
10 | import keract # pip install keract
11 | import keras
12 | import matplotlib.pyplot as plt
13 | import numpy as np
14 | from tensorflow.keras import Sequential
15 | from tensorflow.keras.datasets import imdb
16 | from tensorflow.keras.layers import Dense, Dropout, Embedding
17 | from tensorflow.keras.preprocessing import sequence
18 |
19 | from tcn import TCN
20 |
21 | index_from_ = 3
22 |
23 |
24 | def visualize(model, x, max_len, tcn_num_filters, tcn_layer_outputs, prefix):
25 | for i in range(len(x)):
26 | tcn_outputs = keract.get_activations(model, x[i:i + 1], nodes_to_evaluate=tcn_layer_outputs)
27 | tcn_blocks_outputs = [v for (k, v) in tcn_outputs.items() if v.shape == (1, max_len, tcn_num_filters)]
28 |         plt.figure(figsize=(10, 2))  # creates a figure 10 inches by 2 inches.
29 | plt.title('TCN internal outputs (one row = one residual block output)')
30 | plt.xlabel('Timesteps')
31 | plt.ylabel('Forward pass\n (top to bottom)')
32 | plt.imshow(np.max(np.vstack(tcn_blocks_outputs), axis=-1), cmap='jet', interpolation='bilinear')
33 | plt.savefig(f'acts/{prefix}_{i}.png', dpi=1000, bbox_inches='tight', pad_inches=0)
34 | plt.show()
35 | plt.clf()
36 | plt.cla()
37 | plt.close()
38 |
39 |
40 | def get_word_mappings():
41 | word_to_id_dict = keras.datasets.imdb.get_word_index()
42 | word_to_id_dict = {k: (v + index_from_) for k, v in word_to_id_dict.items()}
43 |     word_to_id_dict['<PAD>'] = 0
44 |     word_to_id_dict['<START>'] = 1
45 |     word_to_id_dict['<UNK>'] = 2
46 |     word_to_id_dict['<UNUSED>'] = 3
47 | id_to_word_dict = {value: key for key, value in word_to_id_dict.items()}
48 | return word_to_id_dict, id_to_word_dict
49 |
50 |
51 | def encode_text(x_):
52 | word_to_id, id_to_word = get_word_mappings()
53 | return [1] + [word_to_id[a] for a in x_.lower().replace('.', '').strip().split(' ')]
54 |
55 |
56 | def print_text(x_):
57 | word_to_id, id_to_word = get_word_mappings()
58 | print(' '.join(id_to_word[ii] for ii in x_))
59 |
60 |
61 | def main():
62 | max_features = 20000
63 | # cut texts after this number of words
64 | # (among top max_features most common words)
65 | max_len = 100
66 | batch_size = 32
67 | tcn_num_filters = 10
68 |
69 | print('Loading data...')
70 | (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features, index_from=index_from_)
71 | print(len(x_train), 'train sequences')
72 | print(len(x_test), 'test sequences')
73 |
74 | x_val = [
75 | encode_text('The movie was very good. I highly recommend.'), # will be at the end.
76 | encode_text(' '.join(["worst"] * 100)),
77 | encode_text("Put all speaking her delicate recurred possible. "
78 | "Set indulgence discretion insensible bed why announcing. "
79 | "Middleton fat two satisfied additions. "
80 | "So continued he or commanded household smallness delivered. "
81 | "Door poor on do walk in half. "
82 | "Roof his head the what. "
83 | "Society excited by cottage private an it seems. "
84 | "Fully begin on by wound an. "
85 | "The movie was very good. I highly recommend. "
86 | "At declared in as rejoiced of together. "
87 | "He impression collecting delightful unpleasant by prosperous as on. "
88 | "End too talent she object mrs wanted remove giving. "
89 | "Man request adapted spirits set pressed. "
90 | "Up to denoting subjects sensible feelings it indulged directly.")
91 | ]
92 |
93 | y_val = [1, 0, 1]
94 |
95 | print('Pad sequences (samples x time)')
96 | x_train = sequence.pad_sequences(x_train, maxlen=max_len)
97 | x_test = sequence.pad_sequences(x_test, maxlen=max_len)
98 | x_val = sequence.pad_sequences(x_val, maxlen=max_len)
99 | print('x_train shape:', x_train.shape)
100 | print('x_test shape:', x_test.shape)
101 | print('x_val shape:', x_val.shape)
102 | y_train = np.array(y_train)
103 | y_test = np.array(y_test)
104 | y_val = np.array(y_val)
105 |
106 |     x_val[x_val >= max_features] = 2  # map out-of-vocabulary ids to <UNK>.
107 |
108 | for i in range(10):
109 | print(f'x_test[{i}]=', end=' | ')
110 | print_text(x_test[i])
111 |
112 | for i in range(len(x_val)):
113 | print(f'x_val[{i}]=', end=' | ')
114 | print_text(x_val[i])
115 |
116 | temporal_conv_net = TCN(
117 | nb_filters=tcn_num_filters,
118 | kernel_size=7,
119 | dilations=[1, 2, 4, 8, 16, 32]
120 | )
121 |
122 | print(temporal_conv_net.receptive_field)
123 |
124 | model = Sequential()
125 | model.add(Embedding(max_features, 128, input_shape=(max_len,)))
126 | model.add(temporal_conv_net)
127 | model.add(Dropout(0.5))
128 | model.add(Dense(1, activation='sigmoid'))
129 |
130 | model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
131 |
132 | tcn_layer_outputs = list(temporal_conv_net.layers_outputs)
133 |
134 | model.fit(x_train, y_train,
135 | batch_size=batch_size,
136 | epochs=4,
137 |               validation_data=(x_test, y_test))
138 |
139 | if os.path.exists('acts'):
140 | shutil.rmtree('acts')
141 | os.makedirs('acts')
142 |
143 | print(model.predict_on_batch(x_val))
144 | print(y_val)
145 |
146 | visualize(model, x_test[0:10], max_len, tcn_num_filters, tcn_layer_outputs, 'x_test')
147 | visualize(model, x_val, max_len, tcn_num_filters, tcn_layer_outputs, 'x_val')
148 |
149 |
150 | if __name__ == '__main__':
151 | main()
152 |
--------------------------------------------------------------------------------
/tasks/word_ptb/README.md:
--------------------------------------------------------------------------------
1 | ## Word-level Language Modeling
2 |
3 | Ref: https://arxiv.org/pdf/1803.01271.
4 |
5 | ### Overview
6 |
7 | In word-level language modeling tasks, each element of the sequence is a word, and the model is expected to predict
8 | the next incoming word in the text. We evaluate the temporal convolutional network as a word-level language model on
9 | the PennTreebank (PTB) dataset.
10 |
11 | ### Data
12 |
13 | **PennTreebank**: We used the PennTreebank (PTB) (Marcus et al., 1993) for both character-level and word-level
14 | language modeling. When used as a character-level language corpus, PTB contains 5,059K characters for training,
15 | 396K for validation, and 446K for testing, with an alphabet
16 | size of 50. When used as a word-level language corpus,
17 | PTB contains 888K words for training, 70K for validation,
18 | and 79K for testing, with a vocabulary size of 10K. This
19 | is a highly studied but relatively small language modeling
20 | dataset (Miyamoto & Cho, 2016; Krueger et al., 2017; Merity et al., 2017).
21 |
22 | ### Results
23 |
24 | *Note that the implementation might differ slightly from what is quoted in the paper.*
25 |
26 | **Word-level language modeling**. Language modeling remains one of the primary applications of recurrent networks
27 | and many recent works have focused on optimizing LSTMs
28 | for this task (Krueger et al., 2017; Merity et al., 2017).
29 | Our implementation follows standard practice that ties the
30 | weights of encoder and decoder layers for both TCN and
31 | RNNs (Press & Wolf, 2016), which significantly reduces
32 | the number of parameters in the model. For training, we use
33 | SGD and anneal the learning rate by a factor of 0.5 for both
34 | TCN and RNNs when validation accuracy plateaus.
35 | On the smaller PTB corpus, an optimized LSTM architecture (with recurrent and embedding dropout, etc.) outperforms the TCN, while the TCN outperforms both GRU and
36 | vanilla RNN. However, on the much larger Wikitext-103
37 | corpus and the LAMBADA dataset (Paperno et al., 2016),
38 | without any hyperparameter search, the TCN outperforms the LSTM results of Grave et al. (2017), achieving much
39 | lower perplexities.
40 |
41 | **Character-level language modeling**. On character-level
42 | language modeling (PTB and text8, accuracy measured in
43 | bits per character), the generic TCN outperforms regularized LSTMs and GRUs as well as methods such as Normstabilized LSTMs (Krueger & Memisevic, 2015). (Specialized architectures exist that outperform all of these, see the
44 | supplement.)
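45 |
46 | ### Minimal data sketch
47 |
48 | A minimal sketch, mirroring `split_to_sequences` in `train.py`, of how next-word training pairs
49 | are built from a sequence of token ids (the toy corpus below is hypothetical):
50 |
51 | ```python
52 | import numpy as np
53 |
54 |
55 | def split_to_sequences(ids, len_):
56 |     # Each window of `len_` ids is paired with the id that follows it.
57 |     n = len(ids) - len_ - 1
58 |     x = np.array([ids[i:i + len_] for i in range(n)])
59 |     y = np.array([[ids[i + len_]] for i in range(n)])
60 |     return x, y
61 |
62 |
63 | x, y = split_to_sequences(list(range(10)), len_=4)
64 | print(x.shape, y.shape)  # (5, 4) (5, 1)
65 | ```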
--------------------------------------------------------------------------------
/tasks/word_ptb/data/README:
--------------------------------------------------------------------------------
1 | Data description:
2 |
3 | Penn Treebank Corpus
4 | - should be free for research purposes
5 | - the same processing of data as used in many LM papers, including "Empirical Evaluation and Combination of Advanced Language Modeling Techniques"
6 | - ptb.train.txt: train set
7 | - ptb.valid.txt: development set (should be used just for tuning hyper-parameters, but not for training)
8 | - ptb.test.txt: test set for reporting perplexity
9 |
10 | - ptb.char.*: the same data, just rewritten as sequences of characters, with spaces rewritten as '_' - useful for training character based models, as is shown in example 9
11 |
--------------------------------------------------------------------------------
/tasks/word_ptb/plot.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | from pathlib import Path
4 |
5 | import matplotlib.pyplot as plt
6 | import pandas as pd
7 |
8 |
9 | # export CUDA_VISIBLE_DEVICES=0; nohup python -u train.py --use_lstm --batch_size 256 --task char > lstm.log 2>&1 &
10 | # export CUDA_VISIBLE_DEVICES=1; nohup python -u train.py --batch_size 256 --task char > tcn.log 2>&1 &
11 | # Usage: python plot.py tcn.log lstm.log
12 |
13 | def keras_output_to_data_frame(filename) -> pd.DataFrame:
14 | suffix = Path(filename).stem
15 | headers = None
16 | data = []
17 | with open(filename) as r:
18 | lines = r.read().strip().split('\n')
19 | for line in lines:
20 | if 'ETA' not in line and 'loss' in line:
21 |                 matches = re.findall(r'[a-z_]+: [0-9]+\.[0-9]+', line)
22 | headers = [m.split(':')[0] + '_' + suffix for m in matches]
23 | data.append([float(m.split(':')[1]) for m in matches])
24 | return pd.DataFrame(data, columns=headers)
25 |
26 |
27 | def main():
28 | dfs = []
29 | colors = ['darkviolet', 'violet', 'deepskyblue', 'skyblue']
30 | for i, argument in enumerate(sys.argv):
31 | if i == 0:
32 | continue
33 | dfs.append(keras_output_to_data_frame(argument))
34 | m = pd.concat(dfs, axis=1)
35 | accuracy_columns = [c for c in list(m.columns) if 'acc' in c]
36 | loss_columns = [c for c in list(m.columns) if 'loss' in c]
37 | _, axs = plt.subplots(ncols=2, figsize=(12, 7), dpi=150)
38 | m.plot(y=accuracy_columns, title='Accuracy', legend=True, xlabel='epoch', color=colors,
39 | ylabel='accuracy', sort_columns=True, grid=True, ax=axs[0])
40 | plt.figure(1, figsize=(2, 5))
41 | m.plot(y=loss_columns, title='Loss', legend=True, color=colors,
42 | xlabel='epoch', ylabel='loss', sort_columns=True, grid=True, ax=axs[1])
43 | plt.savefig('result.png')
44 | plt.close()
45 |
46 |
47 | if __name__ == '__main__':
48 | main()
49 |
--------------------------------------------------------------------------------
/tasks/word_ptb/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philipperemy/keras-tcn/30a765c1daad74514874a6fb363fd428298af899/tasks/word_ptb/result.png
--------------------------------------------------------------------------------
/tasks/word_ptb/run.sh:
--------------------------------------------------------------------------------
1 | # export CUDA_VISIBLE_DEVICES=0; nohup python -u train.py --use_lstm --batch_size 256 --task char > lstm.log 2>&1 &
2 | # export CUDA_VISIBLE_DEVICES=1; nohup python -u train.py --batch_size 256 --task char > tcn.log 2>&1 &
3 |
4 | export CUDA_VISIBLE_DEVICES=0; nohup python -u train.py --use_lstm --batch_size 256 --task char > lstm_no_recurrent_dropout.log 2>&1 &
5 | export CUDA_VISIBLE_DEVICES=1; nohup python -u train.py --batch_size 256 --task char > tcn_boost.log 2>&1 &
--------------------------------------------------------------------------------
/tasks/word_ptb/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import nltk
4 | import numpy as np
5 | from nltk.tokenize import word_tokenize
6 | from tensorflow.keras import Sequential
7 | from tensorflow.keras.layers import Embedding, Dense, LSTM, Dropout
9 | from tqdm import tqdm
10 |
11 | from tcn import TCN
12 |
13 | nltk.download('punkt')
14 |
15 |
16 | def split_to_sequences(ids, len_):
17 | x = np.zeros(shape=(len(ids) - len_ - 1, len_))
18 | y = np.zeros(shape=(len(ids) - len_ - 1, 1))
19 | for index in tqdm(range(0, len(ids) - len_ - 1)):
20 | x[index] = ids[index:index + len_]
21 | y[index] = ids[index + len_]
22 | return x, y
23 |
24 |
25 | def main():
26 | parser = argparse.ArgumentParser(description='Sequence Modeling - The Word PTB')
27 | parser.add_argument('--batch_size', type=int, default=16, help='batch size')
28 | parser.add_argument('--emb_size', type=int, default=200, help='embedding size')
29 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs')
30 | parser.add_argument('--seq_len', type=int, default=80, help='sequence length')
31 | parser.add_argument('--use_lstm', action='store_true')
32 | parser.add_argument('--task', choices=['char', 'word'])
33 | args = parser.parse_args()
34 | print(args)
35 |
36 | # Prepare dataset...
37 | with open('data/ptb.train.txt', 'r') as f1, \
38 | open('data/ptb.valid.txt', 'r') as f2, \
39 | open('data/ptb.test.txt', 'r') as f3:
40 |         seq_train = f1.read().replace('<unk>', '')  # drop the PTB unknown-word markers.
41 |         seq_valid = f2.read().replace('<unk>', '')
42 |         seq_test = f3.read().replace('<unk>', '')
43 |
44 | if args.task == 'word':
45 | # split into words: [I, am, a, cat].
46 | seq_train = word_tokenize(seq_train)
47 | seq_valid = word_tokenize(seq_valid)
48 | seq_test = word_tokenize(seq_test)
49 | else:
50 | # split into characters: [I, ,a,m, ,a, ,c,a,t] ...
51 | seq_train = list(seq_train)
52 | seq_valid = list(seq_valid)
53 | seq_test = list(seq_test)
54 |
55 | vocab_train = set(seq_train)
56 | vocab_valid = set(seq_valid)
57 | vocab_test = set(seq_test)
58 |
59 | assert vocab_valid.issubset(vocab_train)
60 | assert vocab_test.issubset(vocab_train)
61 | size_vocab = len(vocab_train)
62 |
63 | # must have deterministic ordering for word2id dictionary to be reproducible
64 | vocab_train = sorted(vocab_train)
65 | word2id = {w: i for i, w in enumerate(vocab_train)}
66 |
67 | ids_train = [word2id[word] for word in seq_train]
68 | ids_valid = [word2id[word] for word in seq_valid]
69 | ids_test = [word2id[word] for word in seq_test]
70 |
71 | print(len(ids_train), len(ids_valid), len(ids_test))
72 |
73 | # Prepare inputs to model...
74 | x_train, y_train = split_to_sequences(ids_train, args.seq_len)
75 | x_val, y_val = split_to_sequences(ids_valid, args.seq_len)
76 |
77 | print(x_train.shape, y_train.shape)
78 | print(x_val.shape, y_val.shape)
79 |
80 | # Define the model.
81 | if args.use_lstm:
82 | model = Sequential(layers=[
83 | Embedding(size_vocab, args.emb_size),
84 | Dropout(rate=0.2),
85 | LSTM(128),
86 | Dense(size_vocab, activation='softmax')
87 | ])
88 | else:
89 | # noinspection PyArgumentEqualDefault
90 | tcn = TCN(
91 | nb_filters=70,
92 | kernel_size=3,
93 | dilations=(1, 2, 4, 8, 16),
94 | use_skip_connections=True,
95 | use_layer_norm=True
96 | )
97 | print(f'TCN.receptive_field: {tcn.receptive_field}.')
98 | model = Sequential(layers=[
99 | Embedding(size_vocab, args.emb_size),
100 | Dropout(rate=0.2),
101 | tcn,
102 | Dense(size_vocab, activation='softmax')
103 | ])
104 |
105 | # Compile and train.
106 | model.summary()
107 | model.compile('adam', 'sparse_categorical_crossentropy', metrics=['accuracy'])
108 | model.fit(
109 | x_train, y_train,
110 | batch_size=args.batch_size,
111 | validation_data=(x_val, y_val),
112 | epochs=args.epochs
113 | )
114 |
115 |
116 | if __name__ == '__main__':
117 | main()
118 |
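119 | # Note: word-level perplexity can be estimated from the reported
120 | # sparse_categorical_crossentropy as exp(loss), since that loss is the mean
121 | # per-token negative log-likelihood (natural log).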
--------------------------------------------------------------------------------
/tcn/__init__.py:
--------------------------------------------------------------------------------
1 | from tcn.tcn import TCN, compiled_tcn, tcn_full_summary # noqa
2 |
3 | __version__ = '3.5.6'
4 |
--------------------------------------------------------------------------------
/tcn/tcn.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from typing import List # noqa
3 |
4 | import tensorflow as tf
5 | try:
6 | # pylint: disable=E0611,E0401
7 | from keras.src.saving import register_keras_serializable # For recent Keras
8 | except ImportError:
9 | # pylint: disable=E0611,E0401
10 | from tensorflow.keras.saving import register_keras_serializable # For older versions
11 |
12 | # pylint: disable=E0611,E0401
13 | from tensorflow.keras import backend as K, Model, Input, optimizers
14 | # pylint: disable=E0611,E0401
15 | from tensorflow.keras import layers
16 | # pylint: disable=E0611,E0401
17 | from tensorflow.keras.layers import Activation, SpatialDropout1D, Lambda
18 | # pylint: disable=E0611,E0401
19 | from tensorflow.keras.layers import Layer, Conv1D, Dense, BatchNormalization, LayerNormalization
20 |
21 |
22 | def is_power_of_two(num: int):
23 | return num != 0 and ((num & (num - 1)) == 0)
24 |
25 |
26 | def adjust_dilations(dilations: list):
27 | if all([is_power_of_two(i) for i in dilations]):
28 | return dilations
29 | else:
30 | new_dilations = [2 ** i for i in dilations]
31 | return new_dilations
32 |
33 |
34 | class ResidualBlock(Layer):
35 |
36 | def __init__(self,
37 | dilation_rate: int,
38 | nb_filters: int,
39 | kernel_size: int,
40 | padding: str,
41 | activation: str = 'relu',
42 | dropout_rate: float = 0,
43 | kernel_initializer: str = 'he_normal',
44 | use_batch_norm: bool = False,
45 | use_layer_norm: bool = False,
46 | **kwargs):
47 | """Defines the residual block for the WaveNet TCN
48 | Args:
51 | dilation_rate: The dilation power of 2 we are using for this residual block
52 | nb_filters: The number of convolutional filters to use in this block
53 | kernel_size: The size of the convolutional kernel
54 | padding: The padding used in the convolutional layers, 'same' or 'causal'.
55 | activation: The final activation used in o = Activation(x + F(x))
56 | dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
57 | kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
58 | use_batch_norm: Whether to use batch normalization in the residual layers or not.
59 | use_layer_norm: Whether to use layer normalization in the residual layers or not.
60 |             kwargs: Any additional keyword arguments for the parent Layer class.
61 | """
62 |
63 | self.dilation_rate = dilation_rate
64 | self.nb_filters = nb_filters
65 | self.kernel_size = kernel_size
66 | self.padding = padding
67 | self.activation = activation
68 | self.dropout_rate = dropout_rate
69 | self.use_batch_norm = use_batch_norm
70 | self.use_layer_norm = use_layer_norm
71 | self.kernel_initializer = kernel_initializer
72 | self.layers = []
73 | self.shape_match_conv = None
74 | self.res_output_shape = None
75 | self.final_activation = None
76 | self.batch_norm_layers = []
77 | self.layer_norm_layers = []
78 |
79 | super(ResidualBlock, self).__init__(**kwargs)
80 |
81 | def _build_layer(self, layer):
82 | """Helper function for building layer
83 | Args:
84 | layer: Appends layer to internal layer list and builds it based on the current output
85 |                shape of ResidualBlock. Updates the current output shape.
86 | """
87 | self.layers.append(layer)
88 | self.layers[-1].build(self.res_output_shape)
89 | self.res_output_shape = self.layers[-1].compute_output_shape(self.res_output_shape)
90 |
91 | def build(self, input_shape):
92 |
93 | with K.name_scope(self.name): # name scope used to make sure weights get unique names
94 | self.layers = []
95 | self.res_output_shape = input_shape
96 |
97 | for k in range(2): # dilated conv block.
98 | name = 'conv1D_{}'.format(k)
99 | with K.name_scope(name): # name scope used to make sure weights get unique names
100 | conv = Conv1D(
101 | filters=self.nb_filters,
102 | kernel_size=self.kernel_size,
103 | dilation_rate=self.dilation_rate,
104 | padding=self.padding,
105 | name=name,
106 | kernel_initializer=self.kernel_initializer
107 | )
108 | self._build_layer(conv)
109 |
110 | with K.name_scope('norm_{}'.format(k)):
111 | if self.use_batch_norm:
112 | bn_layer = BatchNormalization()
113 | self.batch_norm_layers.append(bn_layer)
114 | self._build_layer(bn_layer)
115 | elif self.use_layer_norm:
116 | ln_layer = LayerNormalization()
117 | self.layer_norm_layers.append(ln_layer)
118 | self._build_layer(ln_layer)
119 |
120 | with K.name_scope('act_and_dropout_{}'.format(k)):
121 | self._build_layer(Activation(self.activation, name='Act_Conv1D_{}'.format(k)))
122 | self._build_layer(SpatialDropout1D(rate=self.dropout_rate, name='SDropout_{}'.format(k)))
123 |
124 | if self.nb_filters != input_shape[-1]:
125 | # 1x1 conv to match the shapes (channel dimension).
126 | name = 'matching_conv1D'
127 | with K.name_scope(name):
128 | # make and build this layer separately because it directly uses input_shape.
129 | # 1x1 conv.
130 | self.shape_match_conv = Conv1D(
131 | filters=self.nb_filters,
132 | kernel_size=1,
133 | padding='same',
134 | name=name,
135 | kernel_initializer=self.kernel_initializer
136 | )
137 | else:
138 | name = 'matching_identity'
139 | self.shape_match_conv = Lambda(lambda x: x, name=name)
140 |
141 | with K.name_scope(name):
142 | self.shape_match_conv.build(input_shape)
143 | self.res_output_shape = self.shape_match_conv.compute_output_shape(input_shape)
144 |
145 | self._build_layer(Activation(self.activation, name='Act_Conv_Blocks'))
146 | self.final_activation = Activation(self.activation, name='Act_Res_Block')
147 | self.final_activation.build(self.res_output_shape) # probably isn't necessary
148 |
149 | # this is done to force Keras to add the layers in the list to self._layers
150 | for layer in self.layers:
151 | self.__setattr__(layer.name, layer)
152 | self.__setattr__(self.shape_match_conv.name, self.shape_match_conv)
153 | self.__setattr__(self.final_activation.name, self.final_activation)
154 |
155 | super(ResidualBlock, self).build(input_shape) # done to make sure self.built is set True
156 |
157 | def call(self, inputs, training=None, **kwargs):
158 | """
159 | Returns: A tuple where the first element is the residual model tensor, and the second
160 | is the skip connection tensor.
161 | """
162 | # https://arxiv.org/pdf/1803.01271.pdf page 4, Figure 1 (b).
163 | # x1: Dilated Conv -> Norm -> Dropout (x2).
164 | # x2: Residual (1x1 matching conv - optional).
165 | # Output: x1 + x2.
166 | # x1 -> connected to skip connections.
167 | # x1 + x2 -> connected to the next block.
168 | # input
169 | # x1 x2
170 | # conv1D 1x1 Conv1D (optional)
171 | # ...
172 | # conv1D
173 | # ...
174 | # x1 + x2
175 | x1 = inputs
176 | for layer in self.layers:
177 | training_flag = 'training' in dict(inspect.signature(layer.call).parameters)
178 | x1 = layer(x1, training=training) if training_flag else layer(x1)
179 | x2 = self.shape_match_conv(inputs)
180 | x1_x2 = self.final_activation(layers.add([x2, x1], name='Add_Res'))
181 | return [x1_x2, x1]
182 |
183 | def compute_output_shape(self, input_shape):
184 | return [self.res_output_shape, self.res_output_shape]
185 |
186 |
187 | @register_keras_serializable()
188 | class TCN(Layer):
189 | """Creates a TCN layer.
190 |
191 | Input shape:
192 | A 3D tensor with shape (batch_size, timesteps, input_dim).
193 |
194 | Args:
195 | nb_filters: The number of filters to use in the convolutional layers. Can be a list.
196 | kernel_size: The size of the kernel to use in each convolutional layer.
197 | dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
198 | nb_stacks : The number of stacks of residual blocks to use.
199 | padding: The padding to use in the convolutional layers, 'causal' or 'same'.
200 |         use_skip_connections: Boolean. If we want to add skip connections from input to each residual block.
201 | return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
202 | activation: The activation used in the residual blocks o = Activation(x + F(x)).
203 | dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
204 | kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
205 | use_batch_norm: Whether to use batch normalization in the residual layers or not.
206 | use_layer_norm: Whether to use layer normalization in the residual layers or not.
207 | go_backwards: Boolean (default False). If True, process the input sequence backwards and
208 | return the reversed sequence.
209 | return_state: Boolean. Whether to return the last state in addition to the output. Default: False.
210 |         kwargs: Any other arguments for configuring the parent class Layer, e.g. name=str, the name of the layer.
211 |             Use unique names when using multiple TCN layers.
212 | Returns:
213 | A TCN layer.
214 | """
215 |
216 | def __init__(self,
217 | nb_filters=64,
218 | kernel_size=3,
219 | nb_stacks=1,
220 | dilations=(1, 2, 4, 8, 16, 32),
221 | padding='causal',
222 | use_skip_connections=True,
223 | dropout_rate=0.0,
224 | return_sequences=False,
225 | activation='relu',
226 | kernel_initializer='he_normal',
227 | use_batch_norm=False,
228 | use_layer_norm=False,
229 | go_backwards=False,
230 | return_state=False,
231 | **kwargs):
232 | self.stateful = False # TCN are not stateful. Keras requires this parameter.
233 | self.return_sequences = return_sequences
234 | self.dropout_rate = dropout_rate
235 | self.use_skip_connections = use_skip_connections
236 | self.dilations = dilations
237 | self.nb_stacks = nb_stacks
238 | self.kernel_size = kernel_size
239 | self.nb_filters = nb_filters
240 | self.activation_name = activation
241 | self.padding = padding
242 | self.kernel_initializer = kernel_initializer
243 | self.use_batch_norm = use_batch_norm
244 | self.use_layer_norm = use_layer_norm
245 | self.go_backwards = go_backwards
246 | self.return_state = return_state
247 | self.skip_connections = []
248 | self.residual_blocks = []
249 | self.layers_outputs = []
250 | self.build_output_shape = None
251 | self.slicer_layer = None # in case return_sequence=False
252 | self.output_slice_index = None # in case return_sequence=False
253 | self.padding_same_and_time_dim_unknown = False # edge case if padding='same' and time_dim = None
254 |
255 | if self.use_batch_norm + self.use_layer_norm > 1:
256 | raise ValueError('Only one normalization can be specified at once.')
257 |
258 | if isinstance(self.nb_filters, list):
259 | assert len(self.nb_filters) == len(self.dilations)
260 | if len(set(self.nb_filters)) > 1 and self.use_skip_connections:
261 | raise ValueError('Skip connections are not compatible '
262 | 'with a list of filters, unless they are all equal.')
263 |
264 | if padding != 'causal' and padding != 'same':
265 | raise ValueError("Only 'causal' or 'same' padding are compatible for this layer.")
266 |
267 | # initialize parent class
268 | super(TCN, self).__init__(**kwargs)
269 |
270 | @property
271 | def receptive_field(self):
272 | return 1 + 2 * (self.kernel_size - 1) * self.nb_stacks * sum(self.dilations)
273 |
274 | def tolist(self, shape):
275 | # noinspection PyBroadException
276 | try:
277 | return shape.as_list()
278 | except Exception:
279 | return list(shape)
280 |
281 | def build(self, input_shape):
282 |
283 | # member to hold current output shape of the layer for building purposes
284 | self.build_output_shape = input_shape
285 |
286 | # list to hold all the member ResidualBlocks
287 | self.residual_blocks = []
288 | total_num_blocks = self.nb_stacks * len(self.dilations)
289 | if not self.use_skip_connections:
290 | total_num_blocks += 1 # cheap way to do a false case for below
291 |
292 | for s in range(self.nb_stacks):
293 | for i, d in enumerate(self.dilations):
294 | res_block_filters = self.nb_filters[i] if isinstance(self.nb_filters, list) else self.nb_filters
295 | self.residual_blocks.append(ResidualBlock(dilation_rate=d,
296 | nb_filters=res_block_filters,
297 | kernel_size=self.kernel_size,
298 | padding=self.padding,
299 | activation=self.activation_name,
300 | dropout_rate=self.dropout_rate,
301 | use_batch_norm=self.use_batch_norm,
302 | use_layer_norm=self.use_layer_norm,
303 | kernel_initializer=self.kernel_initializer,
304 | name='residual_block_{}'.format(len(self.residual_blocks))))
305 | # build newest residual block
306 | self.residual_blocks[-1].build(self.build_output_shape)
307 | self.build_output_shape = self.residual_blocks[-1].res_output_shape
308 |
309 | # this is done to force keras to add the layers in the list to self._layers
310 | for layer in self.residual_blocks:
311 | self.__setattr__(layer.name, layer)
312 |
313 | self.output_slice_index = None
314 | if self.padding == 'same':
315 | time = self.tolist(self.build_output_shape)[1]
316 | if time is not None: # if time dimension is defined. e.g. shape = (bs, 500, input_dim).
317 | self.output_slice_index = int(self.tolist(self.build_output_shape)[1] / 2)
318 | else:
319 |                 # It will be known at call time; cf. self.call.
320 | self.padding_same_and_time_dim_unknown = True
321 |
322 | else:
323 | self.output_slice_index = -1 # causal case.
324 | self.slicer_layer = Lambda(lambda tt: tt[:, self.output_slice_index, :], name='Slice_Output')
325 | self.slicer_layer.build(self.tolist(self.build_output_shape))
326 |
327 | def compute_output_shape(self, input_shape):
328 | """
329 | Overridden in case keras uses it somewhere... no idea. Just trying to avoid future errors.
330 | """
331 | if not self.built:
332 | self.build(input_shape)
333 | if not self.return_sequences:
334 | batch_size = self.build_output_shape[0]
335 | batch_size = batch_size.value if hasattr(batch_size, 'value') else batch_size
336 | nb_filters = self.build_output_shape[-1]
337 | return [batch_size, nb_filters]
338 | else:
339 | # Compatibility tensorflow 1.x
340 | return [v.value if hasattr(v, 'value') else v for v in self.build_output_shape]
341 |
342 | def call(self, inputs, training=None, **kwargs):
343 | x = inputs
344 |
345 | if self.go_backwards:
346 | # reverse x in the time axis
347 | x = tf.reverse(x, axis=[1])
348 |
349 | self.layers_outputs = [x]
350 | self.skip_connections = []
351 | for res_block in self.residual_blocks:
352 | try:
353 | x, skip_out = res_block(x, training=training)
354 | except TypeError: # compatibility with tensorflow 1.x
355 | x, skip_out = res_block(K.cast(x, 'float32'), training=training)
356 | self.skip_connections.append(skip_out)
357 | self.layers_outputs.append(x)
358 |
359 | if self.use_skip_connections:
360 | if len(self.skip_connections) > 1:
361 | # Keras: A merge layer should be called on a list of at least 2 inputs. Got 1 input.
362 | x = layers.add(self.skip_connections, name='Add_Skip_Connections')
363 | else:
364 | x = self.skip_connections[0]
365 | self.layers_outputs.append(x)
366 |
367 | if not self.return_sequences:
368 | # case: time dimension is unknown. e.g. (bs, None, input_dim).
369 | if self.padding_same_and_time_dim_unknown:
370 | self.output_slice_index = K.shape(self.layers_outputs[-1])[1] // 2
371 | x = self.slicer_layer(x)
372 | self.layers_outputs.append(x)
373 | return x
374 |
375 | def get_config(self):
376 | """
377 |         Returns the config of the layer. This is used for saving and loading the model.
378 | :return: python dictionary with specs to rebuild layer
379 | """
380 | config = super(TCN, self).get_config()
381 | config['nb_filters'] = self.nb_filters
382 | config['kernel_size'] = self.kernel_size
383 | config['nb_stacks'] = self.nb_stacks
384 | config['dilations'] = self.dilations
385 | config['padding'] = self.padding
386 | config['use_skip_connections'] = self.use_skip_connections
387 | config['dropout_rate'] = self.dropout_rate
388 | config['return_sequences'] = self.return_sequences
389 | config['activation'] = self.activation_name
390 | config['use_batch_norm'] = self.use_batch_norm
391 | config['use_layer_norm'] = self.use_layer_norm
392 | config['kernel_initializer'] = self.kernel_initializer
393 | config['go_backwards'] = self.go_backwards
394 | config['return_state'] = self.return_state
395 | return config
396 |
397 |
398 | def compiled_tcn(num_feat, # type: int
399 | num_classes, # type: int
400 | nb_filters, # type: int
401 | kernel_size, # type: int
402 | dilations, # type: List[int]
403 | nb_stacks, # type: int
404 | max_len, # type: int
405 | output_len=1, # type: int
406 | padding='causal', # type: str
407 | use_skip_connections=False, # type: bool
408 | return_sequences=True,
409 | regression=False, # type: bool
410 | dropout_rate=0.05, # type: float
411 |                  name='tcn',  # type: str
412 |                  kernel_initializer='he_normal',  # type: str
413 |                  activation='relu',  # type: str
414 | opt='adam',
415 | lr=0.002,
416 | use_batch_norm=False,
417 | use_layer_norm=False):
418 | # type: (...) -> Model
419 | """Creates a compiled TCN model for a given task (i.e. regression or classification).
420 | Classification uses a sparse categorical loss. Please input class ids and not one-hot encodings.
421 |
422 | Args:
423 | num_feat: The number of features of your input, i.e. the last dimension of: (batch_size, timesteps, input_dim).
424 | num_classes: The size of the final dense layer, how many classes we are predicting.
425 | nb_filters: The number of filters to use in the convolutional layers.
426 | kernel_size: The size of the kernel to use in each convolutional layer.
427 | dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
428 | nb_stacks : The number of stacks of residual blocks to use.
429 | max_len: The maximum sequence length, use None if the sequence length is dynamic.
430 | padding: The padding to use in the convolutional layers.
431 |         use_skip_connections: Boolean. If we want to add skip connections from input to each residual block.
432 | return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
433 | regression: Whether the output should be continuous or discrete.
434 | dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
435 | activation: The activation used in the residual blocks o = Activation(x + F(x)).
436 | name: Name of the model. Useful when having multiple TCN.
437 | kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
438 | opt: Optimizer name.
439 | lr: Learning rate.
440 | use_batch_norm: Whether to use batch normalization in the residual layers or not.
441 | use_layer_norm: Whether to use layer normalization in the residual layers or not.
442 | Returns:
443 | A compiled keras TCN.
444 | """
445 |
446 | dilations = adjust_dilations(dilations)
447 |
448 | input_layer = Input(shape=(max_len, num_feat))
449 |
450 | x = TCN(nb_filters, kernel_size, nb_stacks, dilations, padding,
451 | use_skip_connections, dropout_rate, return_sequences,
452 | activation, kernel_initializer, use_batch_norm, use_layer_norm,
453 | name=name)(input_layer)
454 |
455 | print('x.shape=', x.shape)
456 |
457 | def get_opt():
458 | if opt == 'adam':
459 |             return optimizers.Adam(learning_rate=lr, clipnorm=1.)
460 |         elif opt == 'rmsprop':
461 |             return optimizers.RMSprop(learning_rate=lr, clipnorm=1.)
462 | else:
463 | raise Exception('Only Adam and RMSProp are available here')
464 |
465 | if not regression:
466 | # classification
467 | x = Dense(num_classes)(x)
468 | x = Activation('softmax')(x)
469 | output_layer = x
470 | model = Model(input_layer, output_layer)
471 |
472 | # https://github.com/keras-team/keras/pull/11373
473 | # It's now in Keras@master but still not available with pip.
474 | # TODO remove later.
475 | def accuracy(y_true, y_pred):
476 | # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
477 | if K.ndim(y_true) == K.ndim(y_pred):
478 | y_true = K.squeeze(y_true, -1)
479 | # convert dense predictions to labels
480 | y_pred_labels = K.argmax(y_pred, axis=-1)
481 | y_pred_labels = K.cast(y_pred_labels, K.floatx())
482 | return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
483 |
484 | model.compile(get_opt(), loss='sparse_categorical_crossentropy', metrics=[accuracy])
485 | else:
486 | # regression
487 | x = Dense(output_len)(x)
488 | x = Activation('linear')(x)
489 | output_layer = x
490 | model = Model(input_layer, output_layer)
491 | model.compile(get_opt(), loss='mean_squared_error')
492 | print('model.x = {}'.format(input_layer.shape))
493 | print('model.y = {}'.format(output_layer.shape))
494 | return model
495 |
496 |
497 | def tcn_full_summary(model: Model, expand_residual_blocks=True):
498 | import tensorflow as tf
499 | # 2.6.0-rc1, 2.5.0...
500 | versions = [int(v) for v in tf.__version__.split('-')[0].split('.')]
501 | if versions[0] <= 2 and versions[1] < 5:
502 | layers = model._layers.copy() # store existing layers
503 | model._layers.clear() # clear layers
504 |
505 | for i in range(len(layers)):
506 | if isinstance(layers[i], TCN):
507 | for layer in layers[i]._layers:
508 | if not isinstance(layer, ResidualBlock):
509 | if not hasattr(layer, '__iter__'):
510 | model._layers.append(layer)
511 | else:
512 | if expand_residual_blocks:
513 | for lyr in layer._layers:
514 | if not hasattr(lyr, '__iter__'):
515 | model._layers.append(lyr)
516 | else:
517 | model._layers.append(layer)
518 | else:
519 | model._layers.append(layers[i])
520 |
521 | model.summary() # print summary
522 |
523 | # restore original layers
524 | model._layers.clear()
525 | [model._layers.append(lyr) for lyr in layers]
526 | else:
527 | print('WARNING: tcn_full_summary: Compatible with tensorflow 2.5.0 or below.')
528 | print('Use tensorboard instead. Example in keras-tcn/tasks/tcn_tensorboard.py.')
529 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = {py3}-tensorflow-{2.17,2.18,2.19}
3 |
4 | [testenv]
5 | setenv =
6 | PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
7 | deps = pytest
8 | pylint
9 | flake8
10 | -rrequirements.txt
11 | tensorflow-2.17: tensorflow==2.17
12 | tensorflow-2.18: tensorflow==2.18
13 | tensorflow-2.19: tensorflow==2.19
14 | changedir = tasks/
15 | commands = pylint --disable=R,C,W,E1136 ../tcn
16 | flake8 ../tcn --count --select=E9,F63,F7,F82 --show-source --statistics
17 | flake8 ../tcn --count --exclude=michel,tests --max-line-length 127 --statistics
18 | python tcn_call_test.py
19 | python save_reload_sequential_model.py
20 | python sequential.py
21 | python multi_length_sequences.py
22 | python plot_tcn_model.py
23 | passenv = *
24 | install_command = pip install {packages}
25 |
--------------------------------------------------------------------------------